rscp

git clone https://orangeshoelaces.net/git/rscp.git

325bc9134425b1097c808f44820a209283db59c0

Author: Vasily Kolobkov on 06/21/2017

Committer: Vasily Kolobkov on 06/21/2017

Act as a sink

Stats

main.go | 341 +++++++-
1 file changed, 310 insertions(+), 31 deletions(-)

Patch

diff --git a/main.go b/main.go
index a479a7a..b5dfe70 100644
--- a/main.go
+++ b/main.go
@@ -5,18 +5,29 @@ import (
 	"flag"
 	"fmt"
 	"io"
+	"io/ioutil"
 	"os"
 	"path"
+	"strings"
 	"syscall"
 )
 
+const (
+	S_IWUSR = 00200
+	S_IRWXU = 00700
+	S_ISUID = 04000
+	S_ISGID = 02000
+)
+
 var (
-	iamSource    = flag.Bool("f", false, "Run in source mode")
-	iamSink      = flag.Bool("t", false, "Run in sink mode")
-	bwLimit      = flag.Int("l", 0, "Limit the bandwidth, specified in Kbit/s")
-	iamRecursive = flag.Bool("r", false, "Copy directoires recursively following any symlinks")
-	targetDir    = flag.Bool("d", false, "Target should be a directory")
-	preserveAttr = flag.Bool("p", false, "Preserve modification and access times and mode from original file")
+	iamSource     = flag.Bool("f", false, "Run in source mode")
+	iamSink       = flag.Bool("t", false, "Run in sink mode")
+	bwLimit       = flag.Int("l", 0, "Limit the bandwidth, specified in Kbit/s")
+	iamRecursive  = flag.Bool("r", false, "Copy directoires recursively following any symlinks")
+	targetDir     = flag.Bool("d", false, "Target should be a directory")
+	preserveAttrs = flag.Bool("p", false, "Preserve modification and access times and mode from original file")
+
+	protocolErr = FatalError("protocol error")
 )
 
 func main() {
@@ -35,10 +46,11 @@ func main() {
 	if *iamSource {
 		err = source(args)
 	} else {
-		err = sink(args[0])
+		err = sink(args[0], false)
 	}
 
 	if err != nil {
+		fmt.Fprintln(os.Stderr, err)
 		os.Exit(1)
 	}
 }
@@ -50,10 +62,9 @@ func source(paths []string) error {
 
 	var sendErrs []error
 	for _, path := range paths {
-		if err := send(path); err != nil {
-			if _, ok := err.(FatalError); ok {
-				return err
-			}
+		if err := send(path); isFatal(err) {
+			return err
+		} else if err != nil {
 			sendErrs = append(sendErrs, err)
 		}
 	}
@@ -64,12 +75,236 @@ func source(paths []string) error {
 	return nil
 }
 
-func sink(arg string) error {
+func sink(path string, recur bool) error {
+	var errs []error
+	var times *FileTimes
+
+	if *targetDir {
+		if st, err := os.Stat(path); err != nil {
+			return teeError(FatalError(err.Error()))
+		} else if !st.IsDir() {
+			return teeError(FatalError(path + ": is not a directory"))
+		}
+	}
+
+	fmt.Fprint(os.Stdout, "\x00")
+
+	for first := true; ; first = false {
+		prefix := []byte{0}
+		if _, err := os.Stdin.Read(prefix); err != nil {
+			if err == io.EOF {
+				break
+			}
+			return FatalError(err.Error())
+		}
+		line, err := readLine()
+		if err != nil {
+			return FatalError(err.Error())
+		}
+
+		switch prefix[0] {
+		case '\x01':
+			errs = append(errs, errors.New(line))
+
+		case '\x02':
+			return FatalError(line)
+
+		case 'E':
+			if !recur {
+				return teeError(protocolErr)
+			}
+			fmt.Fprint(os.Stdout, "\x00")
+
+		case 'T':
+			if times == nil {
+				times = new(FileTimes)
+			}
+			if n, err := fmt.Sscanf(line, "%d %d %d %d",
+				&times.Mtime.Sec, &times.Mtime.Usec,
+				&times.Atime.Sec, &times.Atime.Usec); err != nil {
+
+				return teeError(FatalError(err.Error()))
+			} else if n != 4 {
+				return teeError(protocolErr)
+			}
+			fmt.Fprint(os.Stdout, "\x00")
+
+		case 'D':
+			if err := sinkDir(path, line, times); isFatal(err) {
+				return err
+			} else if err != nil {
+				errs = append(errs, err)
+			}
+			times = nil
+
+		case 'C':
+			if err := sinkFile(path, line, times); isFatal(err) {
+				return err
+			} else if err != nil {
+				errs = append(errs, err)
+			}
+			times = nil
+
+		default:
+			err := protocolErr
+			if first {
+				compLine := append([]byte{prefix[0]}, line...)
+				err = FatalError(string(compLine))
+			}
+			return teeError(err)
+		}
+	}
+
+	if len(errs) > 0 {
+		return AccError{errs}
+	}
 	return nil
 }
 
-func send(path string) error {
-	f, err := os.Open(path)
+func sinkDir(parent, line string, times *FileTimes) error {
+	if !*iamRecursive {
+		return teeError(FatalError("received directory without -r flag"))
+	}
+
+	perm, _, name, err := parseSubj(line)
+	if err != nil {
+		return teeError(FatalError(err.Error()))
+	}
+
+	name = path.Join(parent, name)
+
+	resetPerm, err := prepareDir(name, perm)
+	if err != nil {
+		return teeError(err)
+	}
+
+	var errs []error
+	if err := sink(name, true); isFatal(err) {
+		return err
+	} else if err != nil {
+		errs = append(errs, err)
+	}
+
+	var pendErrs []error
+	if times != nil {
+		t := []syscall.Timeval{times.Atime, times.Mtime}
+		if err := syscall.Utimes(name, t); err != nil {
+			pendErrs = append(pendErrs, err)
+		}
+	}
+	if resetPerm {
+		if err := os.Chmod(name, perm); err != nil {
+			pendErrs = append(pendErrs, err)
+		}
+	}
+	if len(pendErrs) > 0 {
+		errs = append(errs, pendErrs...)
+		sendError(AccError{pendErrs})
+	}
+
+	if len(errs) > 0 {
+		return AccError{errs}
+	}
+	return nil
+}
+
+func sinkFile(name, line string, times *FileTimes) error {
+	perm, size, subj, err := parseSubj(line)
+	if err != nil {
+		return teeError(FatalError(err.Error()))
+	}
+
+	exists := false
+	if st, err := os.Stat(name); err == nil {
+		exists = true
+		if st.IsDir() {
+			name = path.Join(name, subj)
+		}
+	}
+
+	f, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE, perm|S_IWUSR)
+	if err != nil {
+		return teeError(err)
+	}
+	defer f.Close() /* will sync explicitly */
+
+	st, err := f.Stat()
+	if err != nil {
+		return teeError(err)
+	}
+
+	fmt.Fprint(os.Stdout, "\x00")
+
+	var pendErrs []error
+	if wr, err := io.Copy(f, io.LimitReader(os.Stdin, size)); err != nil {
+		if _, err := io.Copy(ioutil.Discard, io.LimitReader(os.Stdin, size-wr)); err != nil {
+			return teeError(FatalError(err.Error()))
+		}
+		pendErrs = append(pendErrs, err)
+	}
+
+	if !exists || st.Mode().IsRegular() {
+		if err := f.Truncate(size); err != nil {
+			pendErrs = append(pendErrs, err)
+		}
+	}
+	if err := f.Sync(); err != nil {
+		pendErrs = append(pendErrs, err)
+	}
+	if *preserveAttrs || !exists {
+		if err := f.Chmod(perm); err != nil {
+			pendErrs = append(pendErrs, err)
+		}
+	}
+	if times != nil {
+		if err := syscall.Utimes(name, []syscall.Timeval{times.Atime, times.Mtime}); err != nil {
+			pendErrs = append(pendErrs, err)
+		}
+	}
+
+	ackErr := ack()
+	if isFatal(ackErr) {
+		return ackErr
+	}
+
+	var sentErr error
+	if len(pendErrs) > 0 {
+		sentErr = AccError{pendErrs}
+		sendError(sentErr)
+	} else {
+		fmt.Fprint(os.Stdout, "\x00")
+	}
+
+	if ackErr != nil {
+		return AccError{append(pendErrs, ackErr)}
+	}
+	return sentErr
+}
+
+func prepareDir(name string, perm os.FileMode) (bool, error) {
+	resetPerm := false
+	if st, err := os.Stat(name); err == nil {
+		if !st.IsDir() {
+			return resetPerm, errors.New(name + ": is not a directory")
+		}
+		if *preserveAttrs {
+			if err := os.Chmod(name, perm); err != nil {
+				return resetPerm, err
+			}
+		}
+	} else if os.IsNotExist(err) {
+		if err := os.Mkdir(name, perm|S_IRWXU); err != nil {
+			return resetPerm, err
+		}
+		resetPerm = true
+	} else {
+		return resetPerm, err
+	}
+	return resetPerm, nil
+}
+
+func send(name string) error {
+	f, err := os.Open(name)
 	if err != nil {
 		return teeError(err)
 	}
@@ -79,7 +314,7 @@ func send(path string) error {
 	if err != nil {
 		return teeError(err)
 	}
-	name := st.Name()
+	name = st.Name()
 
 	switch st.Mode() & os.ModeType {
 	case 0: /* regular file */
@@ -88,18 +323,18 @@ func send(path string) error {
 		if *iamRecursive {
 			return sendDir(f, st)
 		}
-		return teeError(errors.New(fmt.Sprintf("%s: is a directory", name)))
+		return teeError(errors.New(name + ": is a directory"))
 	default:
-		return teeError(errors.New(fmt.Sprintf("%s: not a regular file", name)))
+		return teeError(errors.New(name + ": not a regular file"))
 	}
 
-	if *preserveAttr {
+	if *preserveAttrs {
 		if err := sendAttr(st); err != nil {
 			return err
 		}
 	}
 
-	fmt.Fprintf(os.Stdout, "C%04o %d %s\n", toPosixMode(st.Mode()), st.Size(), name)
+	fmt.Fprintf(os.Stdout, "C%04o %d %s\n", toPosixPerm(st.Mode()), st.Size(), name)
 	if err := ack(); err != nil {
 		return err
 	}
@@ -115,7 +350,7 @@ func send(path string) error {
 		return teeError(err)
 	}
 
-	fmt.Fprintf(os.Stdout, "\x00")
+	fmt.Fprint(os.Stdout, "\x00")
 	return ack()
 }
 
@@ -125,13 +360,13 @@ func sendDir(dir *os.File, st os.FileInfo) error {
 		return teeError(err)
 	}
 
-	if *preserveAttr {
+	if *preserveAttrs {
 		if err := sendAttr(st); err != nil {
 			return err
 		}
 	}
 
-	fmt.Fprintf(os.Stdout, "D%04o %d %s\n", toPosixMode(st.Mode()), 0, st.Name())
+	fmt.Fprintf(os.Stdout, "D%04o %d %s\n", toPosixPerm(st.Mode()), 0, st.Name())
 	if err := ack(); err != nil {
 		return err
 	}
@@ -148,6 +383,9 @@ func sendDir(dir *os.File, st os.FileInfo) error {
 
 	fmt.Fprintf(os.Stdout, "E\n")
 	ackErr := ack()
+	if isFatal(ackErr) {
+		return ackErr
+	}
 
 	if len(sendErrs) > 0 {
 		return AccError{sendErrs}
@@ -155,6 +393,22 @@ func sendDir(dir *os.File, st os.FileInfo) error {
 	return ackErr
 }
 
+func parseSubj(line string) (perm os.FileMode, size int64, name string, err error) {
+	n := 0
+	pperm := 0
+	if n, err = fmt.Sscanf(line, "%o %d %s", &pperm, &size, &name); err != nil {
+		return
+	} else if n != 3 {
+		err = protocolErr
+		return
+	}
+	perm = toStdPerm(pperm)
+	if name == ".." || strings.ContainsRune(name, '/') {
+		err = FatalError(name + ": invalid name")
+	}
+	return
+}
+
 func sendAttr(st os.FileInfo) error {
 	mtime := st.ModTime().Unix()
 	atime := int64(0)
@@ -187,15 +441,19 @@ func ack() error {
 	case 2:
 		return FatalError(l)
 	default:
-		return FatalError("Protocol error")
+		return protocolErr
 	}
 }
 
 func teeError(err error) error {
-	fmt.Fprintf(os.Stdout, "\x01%s\n", err)
+	sendError(err)
 	return err
 }
 
+func sendError(err error) {
+	fmt.Fprintf(os.Stdout, "\x01%s\n", strings.Replace(err.Error(), "\n", "; ", -1))
+}
+
 func readLine() (string, error) {
 	l := make([]byte, 0, 64)
 	ch := []byte{0}
@@ -214,15 +472,26 @@ func readLine() (string, error) {
 	return string(l), nil
 }
 
-func toPosixMode(m os.FileMode) int {
-	pm := m & os.ModePerm
-	if m&os.ModeSetuid != 0 {
-		pm |= 04000
+func toPosixPerm(perm os.FileMode) int {
+	pp := perm & os.ModePerm
+	if perm&os.ModeSetuid != 0 {
+		pp |= S_ISUID
+	}
+	if perm&os.ModeSetgid != 0 {
+		pp |= S_ISGID
+	}
+	return int(pp)
+}
+
+func toStdPerm(posixPerm int) os.FileMode {
+	perm := os.FileMode(posixPerm) & os.ModePerm
+	if posixPerm&S_ISUID != 0 {
+		perm |= os.ModeSetuid
 	}
-	if m&os.ModeSetgid != 0 {
-		pm |= 02000
+	if posixPerm&S_ISGID != 0 {
+		perm |= os.ModeSetgid
 	}
-	return int(pm)
+	return perm
 }
 
 func usage() {
@@ -232,12 +501,22 @@ func usage() {
 	os.Exit(1)
 }
 
+type FileTimes struct {
+	Atime syscall.Timeval
+	Mtime syscall.Timeval
+}
+
 type FatalError string
 
 func (e FatalError) Error() string {
 	return string(e)
 }
 
+func isFatal(err error) bool {
+	_, isFatal := err.(FatalError)
+	return isFatal
+}
+
 type AccError struct {
 	Errors []error
 }