rscp

git clone https://orangeshoelaces.net/git/rscp.git

2ca3faee2b5cbf2660a0ab12669a81e2ff89d3e4

Author: Vasily Kolobkov on 06/21/2017

Committer: Vasily Kolobkov on 06/21/2017

Limit bandwidth

Stats

bwcapio.go | 83 ++++++++
rscp.go    | 47 ++--
2 files changed, 111 insertions(+), 19 deletions(-)

Patch

diff --git a/bwcapio.go b/bwcapio.go
new file mode 100644
index 0000000..16b9922
--- /dev/null
+++ b/bwcapio.go
@@ -0,0 +1,83 @@
+package main
+
+import (
+	"io"
+	"time"
+)
+
+type BwStats struct {
+	Last   time.Time /* time of last observed event */
+	Wnd    uint      /* unmetered bytes */
+	Thresh uint      /* delay after at least this much bytes */
+	Rate   uint      /* bandwidth limit in bits/second */
+}
+
+func NewBwStats(rate uint) *BwStats {
+	return &BwStats{Wnd: 0, Thresh: rate, Rate: rate}
+}
+
+func CapReader(r io.Reader, st *BwStats) io.Reader {
+	if st == nil {
+		panic("nil stats")
+	}
+	return &BwCapReader{r, st}
+}
+
+func CapWriter(w io.Writer, st *BwStats) io.Writer {
+	if st == nil {
+		panic("nil stats")
+	}
+	return &BwCapWriter{w, st}
+}
+
+type BwCapReader struct {
+	Base  io.Reader
+	Stats *BwStats
+}
+
+func (r *BwCapReader) Read(p []byte) (int, error) {
+	n, err := r.Base.Read(p)
+	bwCap(r.Stats, n)
+	return n, err
+}
+
+type BwCapWriter struct {
+	Base  io.Writer
+	Stats *BwStats
+}
+
+func (w *BwCapWriter) Write(p []byte) (int, error) {
+	n, err := w.Base.Write(p)
+	bwCap(w.Stats, n)
+	return n, err
+}
+
+func bwCap(st *BwStats, transfered int) {
+	if transfered <= 0 {
+		return 
+	}
+	if st.Last.IsZero() {
+		st.Last = time.Now()
+		return
+	}
+	st.Wnd += uint(transfered)
+	if st.Wnd < st.Thresh {
+		return
+	}
+
+	bits := st.Wnd * 8
+	exp := time.Duration((1e9 * bits) / st.Rate)
+	ahead := exp - time.Since(st.Last)
+
+	if ahead > 0 {
+		if ahead.Seconds() > 1 {
+			st.Thresh /= 2
+		} else if ahead < 10*time.Millisecond {
+			st.Thresh *= 2
+		}
+		time.Sleep(ahead)
+	}
+
+	st.Wnd = 0
+	st.Last = time.Now()
+}
diff --git a/rscp.go b/rscp.go
index f4da194..3f3f998 100644
--- a/rscp.go
+++ b/rscp.go
@@ -24,12 +24,15 @@ const (
 var (
 	iamSource     = flag.Bool("f", false, "Run in source mode")
 	iamSink       = flag.Bool("t", false, "Run in sink mode")
-	bwLimit       = flag.Int("l", 0, "Limit the bandwidth, specified in Kbit/s")
+	bwLimit       = flag.Uint("l", 0, "Limit the bandwidth, specified in Kbit/s")
 	iamRecursive  = flag.Bool("r", false, "Copy directoires recursively following any symlinks")
 	targetDir     = flag.Bool("d", false, "Target should be a directory")
 	preserveAttrs = flag.Bool("p", false, "Preserve modification and access times and mode from original file")
 
 	protocolErr = FatalError("protocol error")
+
+	in io.Reader  = os.Stdin
+	out io.Writer = os.Stdout
 )
 
 func main() {
@@ -43,6 +46,12 @@ func main() {
 		usage()
 	}
 
+	if *bwLimit > 0 {
+		st := NewBwStats(*bwLimit * 1024)
+		in = CapReader(in, st)
+		out = CapWriter(out, st)
+	}
+
 	var err error
 
 	if *iamSource {
@@ -89,11 +98,11 @@ func sink(path string, recur bool) error {
 		}
 	}
 
-	fmt.Fprint(os.Stdout, "\x00")
+	fmt.Fprint(out, "\x00")
 
 	for first := true; ; first = false {
 		prefix := []byte{0}
-		if _, err := os.Stdin.Read(prefix); err != nil {
+		if _, err := in.Read(prefix); err != nil {
 			if err == io.EOF {
 				break
 			}
@@ -115,7 +124,7 @@ func sink(path string, recur bool) error {
 			if !recur {
 				return teeError(protocolErr)
 			}
-			fmt.Fprint(os.Stdout, "\x00")
+			fmt.Fprint(out, "\x00")
 
 		case 'T':
 			if times == nil {
@@ -129,7 +138,7 @@ func sink(path string, recur bool) error {
 			} else if n != 4 {
 				return teeError(protocolErr)
 			}
-			fmt.Fprint(os.Stdout, "\x00")
+			fmt.Fprint(out, "\x00")
 
 		case 'D':
 			if err := sinkDir(path, line, times); isFatal(err) {
@@ -235,11 +244,11 @@ func sinkFile(name, line string, times *FileTimes) error {
 		return teeError(err)
 	}
 
-	fmt.Fprint(os.Stdout, "\x00")
+	fmt.Fprint(out, "\x00")
 
 	var pendErrs []error
-	if wr, err := io.Copy(f, io.LimitReader(os.Stdin, size)); err != nil {
-		if _, err := io.Copy(ioutil.Discard, io.LimitReader(os.Stdin, size-wr)); err != nil {
+	if wr, err := io.Copy(f, io.LimitReader(in, size)); err != nil {
+		if _, err := io.Copy(ioutil.Discard, io.LimitReader(in, size-wr)); err != nil {
 			return teeError(FatalError(err.Error()))
 		}
 		pendErrs = append(pendErrs, err)
@@ -274,7 +283,7 @@ func sinkFile(name, line string, times *FileTimes) error {
 		sentErr = AccError{pendErrs}
 		sendError(sentErr)
 	} else {
-		fmt.Fprint(os.Stdout, "\x00")
+		fmt.Fprint(out, "\x00")
 	}
 
 	if ackErr != nil {
@@ -336,14 +345,14 @@ func send(name string) error {
 		}
 	}
 
-	fmt.Fprintf(os.Stdout, "C%04o %d %s\n", toPosixPerm(st.Mode()), st.Size(), name)
+	fmt.Fprintf(out, "C%04o %d %s\n", toPosixPerm(st.Mode()), st.Size(), name)
 	if err := ack(); err != nil {
 		return err
 	}
 
-	if sent, err := io.Copy(os.Stdout, f); err != nil {
+	if sent, err := io.Copy(out, f); err != nil {
 		patch := io.LimitReader(ConstReader(0), st.Size()-sent)
-		if _, err := io.Copy(os.Stdout, patch); err != nil {
+		if _, err := io.Copy(out, patch); err != nil {
 			return FatalError(err.Error())
 		}
 		if err := ack(); err != nil {
@@ -352,7 +361,7 @@ func send(name string) error {
 		return teeError(err)
 	}
 
-	fmt.Fprint(os.Stdout, "\x00")
+	fmt.Fprint(out, "\x00")
 	return ack()
 }
 
@@ -368,7 +377,7 @@ func sendDir(dir *os.File, st os.FileInfo) error {
 		}
 	}
 
-	fmt.Fprintf(os.Stdout, "D%04o %d %s\n", toPosixPerm(st.Mode()), 0, st.Name())
+	fmt.Fprintf(out, "D%04o %d %s\n", toPosixPerm(st.Mode()), 0, st.Name())
 	if err := ack(); err != nil {
 		return err
 	}
@@ -383,7 +392,7 @@ func sendDir(dir *os.File, st os.FileInfo) error {
 		}
 	}
 
-	fmt.Fprintf(os.Stdout, "E\n")
+	fmt.Fprintf(out, "E\n")
 	ackErr := ack()
 	if isFatal(ackErr) {
 		return ackErr
@@ -419,13 +428,13 @@ func sendAttr(st os.FileInfo) error {
 		atime, _ = sysStat.Atim.Unix()
 	}
 
-	fmt.Fprintf(os.Stdout, "T%d 0 %d 0\n", mtime, atime)
+	fmt.Fprintf(out, "T%d 0 %d 0\n", mtime, atime)
 	return ack()
 }
 
 func ack() error {
 	kind := []byte{0}
-	if _, err := os.Stdin.Read(kind); err != nil {
+	if _, err := in.Read(kind); err != nil {
 		return FatalError(err.Error())
 	}
 	if kind[0] == 0 {
@@ -458,7 +467,7 @@ func sendError(err error) {
 	if len(line) > MaxErrLen-3 {
 		line = line[:MaxErrLen-6] + "..."
 	}
-	fmt.Fprintf(os.Stdout, "\x01%s\n", line)
+	fmt.Fprintf(out, "\x01%s\n", line)
 }
 
 func readLine() (string, error) {
@@ -466,7 +475,7 @@ func readLine() (string, error) {
 	ch := []byte{0}
 
 	for {
-		if _, err := os.Stdin.Read(ch); err != nil {
+		if _, err := in.Read(ch); err != nil {
 			return "", err
 		} else {
 			if ch[0] == '\n' {