rscp

git clone https://orangeshoelaces.net/git/rscp.git

/rscp.go

   1 package main
   2 
import (
	"errors"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"os"
	"path"
	"strconv"
	"strings"
	"syscall"
)
  14 
  15 const (
  16         S_IWUSR = 00200
  17         S_IRWXU = 00700
  18         S_ISUID = 04000
  19         S_ISGID = 02000
  20 
  21         MaxErrLen = 1024
  22         DirScanBatchSize = 256
  23 )
  24 
  25 var (
  26         iamSource     = flag.Bool("f", false, "Run in source mode")
  27         iamSink       = flag.Bool("t", false, "Run in sink mode")
  28         bwLimit       = flag.Uint("l", 0, "Limit the bandwidth, specified in Kbit/s")
  29         iamRecursive  = flag.Bool("r", false, "Copy directoires recursively following any symlinks")
  30         targetDir     = flag.Bool("d", false, "Target should be a directory")
  31         preserveAttrs = flag.Bool("p", false, "Preserve modification and access times and mode from original file")
  32 
  33         protocolErr = FatalError("protocol error")
  34 
  35         in io.Reader  = os.Stdin
  36         out io.Writer = os.Stdout
  37 )
  38 
  39 func main() {
  40         flag.Parse()
  41         var args = flag.Args()
  42 
  43         var validMode = (*iamSource || *iamSink) && !(*iamSource && *iamSink)
  44         var validArgc = (*iamSource && len(args) > 0) || (*iamSink && len(args) == 1)
  45 
  46         if !validMode || !validArgc {
  47                 usage()
  48         }
  49 
  50         if *bwLimit > 0 {
  51                 st := NewBwStats(*bwLimit * 1024)
  52                 in = CapReader(in, st)
  53                 out = CapWriter(out, st)
  54         }
  55 
  56         var err error
  57 
  58         if *iamSource {
  59                 err = source(args)
  60         } else {
  61                 err = sink(args[0], false)
  62         }
  63 
  64         if err != nil {
  65                 fmt.Fprintln(os.Stderr, err)
  66                 os.Exit(1)
  67         }
  68 }
  69 
  70 func source(paths []string) error {
  71         if err := ack(); err != nil {
  72                 return err
  73         }
  74 
  75         var sendErrs []error
  76         for _, path := range paths {
  77                 if err := send(path); isFatal(err) {
  78                         return err
  79                 } else if err != nil {
  80                         sendErrs = append(sendErrs, err)
  81                 }
  82         }
  83 
  84         if len(sendErrs) > 0 {
  85                 return AccError{sendErrs}
  86         }
  87         return nil
  88 }
  89 
  90 func sink(path string, recur bool) error {
  91         var errs []error
  92         var times *FileTimes
  93 
  94         if *targetDir {
  95                 if st, err := os.Stat(path); err != nil {
  96                         return teeError(FatalError(err.Error()))
  97                 } else if !st.IsDir() {
  98                         return teeError(FatalError(path + ": is not a directory"))
  99                 }
 100         }
 101 
 102         if _, err := fmt.Fprint(out, "\x00"); err != nil {
 103                 return FatalError(err.Error())
 104         }
 105 
 106         for first := true; ; first = false {
 107                 prefix := []byte{0}
 108                 if _, err := in.Read(prefix); err != nil {
 109                         if err == io.EOF {
 110                                 break
 111                         }
 112                         return FatalError(err.Error())
 113                 }
 114                 line, err := readLine()
 115                 if err != nil {
 116                         return FatalError(err.Error())
 117                 }
 118 
 119                 switch prefix[0] {
 120                 case '\x01':
 121                         errs = append(errs, errors.New(line))
 122 
 123                 case '\x02':
 124                         return FatalError(line)
 125 
 126                 case 'E':
 127                         if !recur {
 128                                 return teeError(protocolErr)
 129                         }
 130                         if _, err := fmt.Fprint(out, "\x00"); err != nil {
 131                                 return FatalError(err.Error())
 132                         }
 133                         goto Out
 134 
 135                 case 'T':
 136                         if times == nil {
 137                                 times = new(FileTimes)
 138                         }
 139                         if n, err := fmt.Sscanf(line, "%d %d %d %d",
 140                                 &times.Mtime.Sec, &times.Mtime.Usec,
 141                                 &times.Atime.Sec, &times.Atime.Usec); err != nil {
 142 
 143                                 return teeError(FatalError(err.Error()))
 144                         } else if n != 4 {
 145                                 return teeError(protocolErr)
 146                         }
 147                         if _, err := fmt.Fprint(out, "\x00"); err != nil {
 148                                 return FatalError(err.Error())
 149                         }
 150 
 151                 case 'D':
 152                         if err := sinkDir(path, line, times); isFatal(err) {
 153                                 return err
 154                         } else if err != nil {
 155                                 errs = append(errs, err)
 156                         }
 157                         times = nil
 158 
 159                 case 'C':
 160                         if err := sinkFile(path, line, times); isFatal(err) {
 161                                 return err
 162                         } else if err != nil {
 163                                 errs = append(errs, err)
 164                         }
 165                         times = nil
 166 
 167                 default:
 168                         err := protocolErr
 169                         if first {
 170                                 compLine := append([]byte{prefix[0]}, line...)
 171                                 err = FatalError(string(compLine))
 172                         }
 173                         return teeError(err)
 174                 }
 175         }
 176 Out:
 177         if len(errs) > 0 {
 178                 return AccError{errs}
 179         }
 180         return nil
 181 }
 182 
 183 func sinkDir(parent, line string, times *FileTimes) error {
 184         if !*iamRecursive {
 185                 return teeError(FatalError("received directory without -r flag"))
 186         }
 187 
 188         perm, _, name, err := parseSubj(line)
 189         if err != nil {
 190                 return teeError(FatalError(err.Error()))
 191         }
 192 
 193         name = path.Join(parent, name)
 194 
 195         resetPerm, err := prepareDir(name, perm)
 196         if err != nil {
 197                 return teeError(err)
 198         }
 199 
 200         var errs []error
 201         if err := sink(name, true); isFatal(err) {
 202                 return err
 203         } else if err != nil {
 204                 errs = append(errs, err)
 205         }
 206 
 207         var pendErrs []error
 208         if times != nil {
 209                 t := []syscall.Timeval{times.Atime, times.Mtime}
 210                 if err := syscall.Utimes(name, t); err != nil {
 211                         pendErrs = append(pendErrs, err)
 212                 }
 213         }
 214         if resetPerm {
 215                 if err := os.Chmod(name, perm); err != nil {
 216                         pendErrs = append(pendErrs, err)
 217                 }
 218         }
 219         if len(pendErrs) > 0 {
 220                 errs = append(errs, pendErrs...)
 221                 if err := sendError(AccError{pendErrs}); err != nil {
 222                         return err
 223                 }
 224         }
 225 
 226         if len(errs) > 0 {
 227                 return AccError{errs}
 228         }
 229         return nil
 230 }
 231 
 232 func sinkFile(name, line string, times *FileTimes) error {
 233         perm, size, subj, err := parseSubj(line)
 234         if err != nil {
 235                 return teeError(FatalError(err.Error()))
 236         }
 237 
 238         exists := false
 239         if st, err := os.Stat(name); err == nil {
 240                 exists = true
 241                 if st.IsDir() {
 242                         name = path.Join(name, subj)
 243                 }
 244         }
 245 
 246         f, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE, perm|S_IWUSR)
 247         if err != nil {
 248                 return teeError(err)
 249         }
 250         defer f.Close() /* will sync explicitly */
 251 
 252         st, err := f.Stat()
 253         if err != nil {
 254                 return teeError(err)
 255         }
 256 
 257         if _, err := fmt.Fprint(out, "\x00"); err != nil {
 258                 return FatalError(err.Error())
 259         }
 260 
 261         var pendErrs []error
 262         if wr, err := io.Copy(f, io.LimitReader(in, size)); err != nil {
 263                 if _, err := io.Copy(ioutil.Discard, io.LimitReader(in, size-wr)); err != nil {
 264                         return teeError(FatalError(err.Error()))
 265                 }
 266                 pendErrs = append(pendErrs, err)
 267         }
 268 
 269         if !exists || st.Mode().IsRegular() {
 270                 if err := f.Truncate(size); err != nil {
 271                         pendErrs = append(pendErrs, err)
 272                 }
 273         }
 274         if err := f.Sync(); err != nil {
 275                 pendErrs = append(pendErrs, err)
 276         }
 277         if *preserveAttrs || !exists {
 278                 if err := f.Chmod(perm); err != nil {
 279                         pendErrs = append(pendErrs, err)
 280                 }
 281         }
 282         if times != nil {
 283                 if err := syscall.Utimes(name,
 284                         []syscall.Timeval{times.Atime, times.Mtime}); err != nil {
 285 
 286                         pendErrs = append(pendErrs, err)
 287                 }
 288         }
 289 
 290         ackErr := ack()
 291         if isFatal(ackErr) {
 292                 return ackErr
 293         }
 294 
 295         var sentErr error
 296         if len(pendErrs) > 0 {
 297                 sentErr = AccError{pendErrs}
 298                 if err := sendError(sentErr); err != nil {
 299                         return err
 300                 }
 301         } else {
 302                 if _, err := fmt.Fprint(out, "\x00"); err != nil {
 303                         return FatalError(err.Error())
 304                 }
 305         }
 306 
 307         if ackErr != nil {
 308                 return AccError{append(pendErrs, ackErr)}
 309         }
 310         return sentErr
 311 }
 312 
 313 func prepareDir(name string, perm os.FileMode) (bool, error) {
 314         resetPerm := false
 315         if st, err := os.Stat(name); err == nil {
 316                 if !st.IsDir() {
 317                         return resetPerm, errors.New(name + ": is not a directory")
 318                 }
 319                 if *preserveAttrs {
 320                         if err := os.Chmod(name, perm); err != nil {
 321                                 return resetPerm, err
 322                         }
 323                 }
 324         } else if os.IsNotExist(err) {
 325                 if err := os.Mkdir(name, perm|S_IRWXU); err != nil {
 326                         return resetPerm, err
 327                 }
 328                 resetPerm = true
 329         } else {
 330                 return resetPerm, err
 331         }
 332         return resetPerm, nil
 333 }
 334 
 335 func send(name string) error {
 336         f, err := os.Open(name)
 337         if err != nil {
 338                 return teeError(err)
 339         }
 340         defer f.Close()
 341 
 342         st, err := f.Stat()
 343         if err != nil {
 344                 return teeError(err)
 345         }
 346 
 347         if mode := st.Mode(); mode.IsDir() {
 348                 if *iamRecursive {
 349                         return sendDir(f, st)
 350                 }
 351                 return teeError(errors.New(name + ": is a directory"))
 352         } else if !mode.IsRegular() {
 353                 return teeError(errors.New(name + ": not a regular file"))
 354         }
 355 
 356         if *preserveAttrs {
 357                 if err := sendAttr(st); err != nil {
 358                         return err
 359                 }
 360         }
 361 
 362         base := st.Name()
 363         if _, err := fmt.Fprintf(out, "C%04o %d %s\n",
 364                 toPosixPerm(st.Mode()), st.Size(), base); err != nil {
 365 
 366                 return FatalError(err.Error())
 367         }
 368         if err := ack(); err != nil {
 369                 return err
 370         }
 371 
 372         if sent, err := io.Copy(out, f); err != nil {
 373                 patch := io.LimitReader(ConstReader(0), st.Size()-sent)
 374                 if _, err := io.Copy(out, patch); err != nil {
 375                         return FatalError(err.Error())
 376                 }
 377                 if err := ack(); err != nil {
 378                         return err
 379                 }
 380                 return teeError(err)
 381         }
 382 
 383         if _, err := fmt.Fprint(out, "\x00"); err != nil {
 384                 return FatalError(err.Error())
 385         }
 386         return ack()
 387 }
 388 
 389 func sendDir(dir *os.File, st os.FileInfo) error {
 390         if *preserveAttrs {
 391                 if err := sendAttr(st); err != nil {
 392                         return err
 393                 }
 394         }
 395 
 396         if _, err := fmt.Fprintf(out, "D%04o %d %s\n",
 397                 toPosixPerm(st.Mode()), 0, st.Name()); err != nil {
 398 
 399                 return FatalError(err.Error())
 400         }
 401         if err := ack(); err != nil {
 402                 return err
 403         }
 404 
 405         var sendErrs []error
 406         for {
 407                 children, err := dir.Readdir(DirScanBatchSize)
 408                 for _, child := range children {
 409                         if err := send(path.Join(dir.Name(), child.Name())); isFatal(err) {
 410                                 return err
 411                         } else if err != nil {
 412                                 sendErrs = append(sendErrs, err)
 413                         }
 414                 }
 415                 if err == io.EOF {
 416                         break
 417                 } else if err != nil {
 418                         return teeError(err)
 419                 }
 420         }
 421 
 422         if _, err := fmt.Fprintf(out, "E\n"); err != nil {
 423                 return FatalError(err.Error())
 424         }
 425         ackErr := ack()
 426         if isFatal(ackErr) {
 427                 return ackErr
 428         }
 429 
 430         if len(sendErrs) > 0 {
 431                 return AccError{sendErrs}
 432         }
 433         return ackErr
 434 }
 435 
 436 func parseSubj(line string) (perm os.FileMode, size int64, name string, err error) {
 437         n := 0
 438         pperm := 0
 439         if n, err = fmt.Sscanf(line, "%o %d %s", &pperm, &size, &name); err != nil {
 440                 return
 441         } else if n != 3 {
 442                 err = protocolErr
 443                 return
 444         }
 445         perm = toStdPerm(pperm)
 446         if name == ".." || strings.ContainsRune(name, '/') {
 447                 err = FatalError(name + ": invalid name")
 448         }
 449         return
 450 }
 451 
 452 func sendAttr(st os.FileInfo) error {
 453         mtime := st.ModTime().Unix()
 454         atime := int64(0)
 455 
 456         if sysStat, ok := st.Sys().(*syscall.Stat_t); ok {
 457                 atime, _ = sysStat.Atim.Unix()
 458         }
 459 
 460         if _, err := fmt.Fprintf(out, "T%d 0 %d 0\n", mtime, atime); err != nil {
 461                 return FatalError(err.Error())
 462         }
 463         return ack()
 464 }
 465 
 466 func ack() error {
 467         kind := []byte{0}
 468         if _, err := in.Read(kind); err != nil {
 469                 return FatalError(err.Error())
 470         }
 471         if kind[0] == 0 {
 472                 return nil
 473         }
 474 
 475         l, err := readLine()
 476         if err != nil {
 477                 return FatalError(err.Error())
 478         }
 479 
 480         switch kind[0] {
 481         case 1:
 482                 return errors.New(l)
 483         case 2:
 484                 return FatalError(l)
 485         default:
 486                 return protocolErr
 487         }
 488 }
 489 
 490 func teeError(err error) error {
 491         if err := sendError(err); err != nil {
 492                 return err
 493         }
 494         return err
 495 }
 496 
 497 func sendError(err error) error {
 498         line := strings.Replace(err.Error(), "\n", "; ", -1)
 499         /* make complete protocol line with zero terminator (i.e \x01%s\n\x00) fit into MaxErrLen buffer */
 500         if len(line) > MaxErrLen-3 {
 501                 line = line[:MaxErrLen-6] + "..."
 502         }
 503         if _, err := fmt.Fprintf(out, "\x01%s\n", line); err != nil {
 504                 return FatalError(err.Error())
 505         }
 506         return nil
 507 }
 508 
 509 func readLine() (string, error) {
 510         l := make([]byte, 0, 64)
 511         ch := []byte{0}
 512 
 513         for {
 514                 if _, err := in.Read(ch); err != nil {
 515                         return "", err
 516                 } else {
 517                         if ch[0] == '\n' {
 518                                 break
 519                         }
 520                         l = append(l, ch[0])
 521                 }
 522         }
 523 
 524         return string(l), nil
 525 }
 526 
 527 func toPosixPerm(perm os.FileMode) int {
 528         pp := perm & os.ModePerm
 529         if perm&os.ModeSetuid != 0 {
 530                 pp |= S_ISUID
 531         }
 532         if perm&os.ModeSetgid != 0 {
 533                 pp |= S_ISGID
 534         }
 535         return int(pp)
 536 }
 537 
 538 func toStdPerm(posixPerm int) os.FileMode {
 539         perm := os.FileMode(posixPerm) & os.ModePerm
 540         if posixPerm&S_ISUID != 0 {
 541                 perm |= os.ModeSetuid
 542         }
 543         if posixPerm&S_ISGID != 0 {
 544                 perm |= os.ModeSetgid
 545         }
 546         return perm
 547 }
 548 
 549 func usage() {
 550         fmt.Fprintf(os.Stderr, "Usage: rscp -f [-pr] [-l limit] file1 ...\n"+
 551                 "       rscp -t [-prd] [-l limit] directory\n")
 552         flag.PrintDefaults()
 553         os.Exit(1)
 554 }
 555 
 556 type FileTimes struct {
 557         Atime syscall.Timeval
 558         Mtime syscall.Timeval
 559 }
 560 
 561 type FatalError string
 562 
 563 func (e FatalError) Error() string {
 564         return string(e)
 565 }
 566 
 567 func isFatal(err error) bool {
 568         _, isFatal := err.(FatalError)
 569         return isFatal
 570 }
 571 
 572 type AccError struct {
 573         Errors []error
 574 }
 575 
 576 func (e AccError) Error() string {
 577         ve := []interface{}{}
 578         for _, err := range e.Errors {
 579                 ve = append(ve, err)
 580         }
 581         return fmt.Sprintln(ve...)
 582 }
 583 
 584 type ConstReader byte
 585 
 586 func (c ConstReader) Read(b []byte) (int, error) {
 587         for i, _ := range b {
 588                 b[i] = byte(c)
 589         }
 590         return len(b), nil
 591 }