git clone https://orangeshoelaces.net/git/rscp.git
Author: Vasily Kolobkov on 06/21/2017
Committer: Vasily Kolobkov on 06/21/2017
Limit bandwidth
bwcapio.go | 83 ++++++++
rscp.go | 47 ++--
2 files changed, 111 insertions(+), 19 deletions(-)
diff --git a/bwcapio.go b/bwcapio.go
new file mode 100644
index 0000000..16b9922
--- /dev/null
+++ b/bwcapio.go
@@ -0,0 +1,83 @@
+package main
+
+import (
+ "io"
+ "time"
+)
+
+type BwStats struct {
+ Last time.Time /* time of last observed event */
+ Wnd uint /* unmetered bytes */
+ Thresh uint /* delay after at least this much bytes */
+ Rate uint /* bandwidth limit in bits/second */
+}
+
+func NewBwStats(rate uint) *BwStats {
+ return &BwStats{Wnd: 0, Thresh: rate, Rate: rate}
+}
+
+func CapReader(r io.Reader, st *BwStats) io.Reader {
+ if st == nil {
+ panic("nil stats")
+ }
+ return &BwCapReader{r, st}
+}
+
+func CapWriter(w io.Writer, st *BwStats) io.Writer {
+ if st == nil {
+ panic("nil stats")
+ }
+ return &BwCapWriter{w, st}
+}
+
+type BwCapReader struct {
+ Base io.Reader
+ Stats *BwStats
+}
+
+func (r *BwCapReader) Read(p []byte) (int, error) {
+ n, err := r.Base.Read(p)
+ bwCap(r.Stats, n)
+ return n, err
+}
+
+type BwCapWriter struct {
+ Base io.Writer
+ Stats *BwStats
+}
+
+func (w *BwCapWriter) Write(p []byte) (int, error) {
+ n, err := w.Base.Write(p)
+ bwCap(w.Stats, n)
+ return n, err
+}
+
+func bwCap(st *BwStats, transfered int) {
+ if transfered <= 0 {
+ return
+ }
+ if st.Last.IsZero() {
+ st.Last = time.Now()
+ return
+ }
+ st.Wnd += uint(transfered)
+ if st.Wnd < st.Thresh {
+ return
+ }
+
+ bits := st.Wnd * 8
+ exp := time.Duration((1e9 * bits) / st.Rate)
+ ahead := exp - time.Since(st.Last)
+
+ if ahead > 0 {
+ if ahead.Seconds() > 1 {
+ st.Thresh /= 2
+ } else if ahead < 10*time.Millisecond {
+ st.Thresh *= 2
+ }
+ time.Sleep(ahead)
+ }
+
+ st.Wnd = 0
+ st.Last = time.Now()
+}
diff --git a/rscp.go b/rscp.go
index f4da194..3f3f998 100644
--- a/rscp.go
+++ b/rscp.go
@@ -24,12 +24,15 @@ const (
var (
iamSource = flag.Bool("f", false, "Run in source mode")
iamSink = flag.Bool("t", false, "Run in sink mode")
- bwLimit = flag.Int("l", 0, "Limit the bandwidth, specified in Kbit/s")
+ bwLimit = flag.Uint("l", 0, "Limit the bandwidth, specified in Kbit/s")
iamRecursive = flag.Bool("r", false, "Copy directoires recursively following any symlinks")
targetDir = flag.Bool("d", false, "Target should be a directory")
preserveAttrs = flag.Bool("p", false, "Preserve modification and access times and mode from original file")
protocolErr = FatalError("protocol error")
+
+ in io.Reader = os.Stdin
+ out io.Writer = os.Stdout
)
func main() {
@@ -43,6 +46,12 @@ func main() {
usage()
}
+ if *bwLimit > 0 {
+ st := NewBwStats(*bwLimit * 1024)
+ in = CapReader(in, st)
+ out = CapWriter(out, st)
+ }
+
var err error
if *iamSource {
@@ -89,11 +98,11 @@ func sink(path string, recur bool) error {
}
}
- fmt.Fprint(os.Stdout, "\x00")
+ fmt.Fprint(out, "\x00")
for first := true; ; first = false {
prefix := []byte{0}
- if _, err := os.Stdin.Read(prefix); err != nil {
+ if _, err := in.Read(prefix); err != nil {
if err == io.EOF {
break
}
@@ -115,7 +124,7 @@ func sink(path string, recur bool) error {
if !recur {
return teeError(protocolErr)
}
- fmt.Fprint(os.Stdout, "\x00")
+ fmt.Fprint(out, "\x00")
case 'T':
if times == nil {
@@ -129,7 +138,7 @@ func sink(path string, recur bool) error {
} else if n != 4 {
return teeError(protocolErr)
}
- fmt.Fprint(os.Stdout, "\x00")
+ fmt.Fprint(out, "\x00")
case 'D':
if err := sinkDir(path, line, times); isFatal(err) {
@@ -235,11 +244,11 @@ func sinkFile(name, line string, times *FileTimes) error {
return teeError(err)
}
- fmt.Fprint(os.Stdout, "\x00")
+ fmt.Fprint(out, "\x00")
var pendErrs []error
- if wr, err := io.Copy(f, io.LimitReader(os.Stdin, size)); err != nil {
- if _, err := io.Copy(ioutil.Discard, io.LimitReader(os.Stdin, size-wr)); err != nil {
+ if wr, err := io.Copy(f, io.LimitReader(in, size)); err != nil {
+ if _, err := io.Copy(ioutil.Discard, io.LimitReader(in, size-wr)); err != nil {
return teeError(FatalError(err.Error()))
}
pendErrs = append(pendErrs, err)
@@ -274,7 +283,7 @@ func sinkFile(name, line string, times *FileTimes) error {
sentErr = AccError{pendErrs}
sendError(sentErr)
} else {
- fmt.Fprint(os.Stdout, "\x00")
+ fmt.Fprint(out, "\x00")
}
if ackErr != nil {
@@ -336,14 +345,14 @@ func send(name string) error {
}
}
- fmt.Fprintf(os.Stdout, "C%04o %d %s\n", toPosixPerm(st.Mode()), st.Size(), name)
+ fmt.Fprintf(out, "C%04o %d %s\n", toPosixPerm(st.Mode()), st.Size(), name)
if err := ack(); err != nil {
return err
}
- if sent, err := io.Copy(os.Stdout, f); err != nil {
+ if sent, err := io.Copy(out, f); err != nil {
patch := io.LimitReader(ConstReader(0), st.Size()-sent)
- if _, err := io.Copy(os.Stdout, patch); err != nil {
+ if _, err := io.Copy(out, patch); err != nil {
return FatalError(err.Error())
}
if err := ack(); err != nil {
@@ -352,7 +361,7 @@ func send(name string) error {
return teeError(err)
}
- fmt.Fprint(os.Stdout, "\x00")
+ fmt.Fprint(out, "\x00")
return ack()
}
@@ -368,7 +377,7 @@ func sendDir(dir *os.File, st os.FileInfo) error {
}
}
- fmt.Fprintf(os.Stdout, "D%04o %d %s\n", toPosixPerm(st.Mode()), 0, st.Name())
+ fmt.Fprintf(out, "D%04o %d %s\n", toPosixPerm(st.Mode()), 0, st.Name())
if err := ack(); err != nil {
return err
}
@@ -383,7 +392,7 @@ func sendDir(dir *os.File, st os.FileInfo) error {
}
}
- fmt.Fprintf(os.Stdout, "E\n")
+ fmt.Fprintf(out, "E\n")
ackErr := ack()
if isFatal(ackErr) {
return ackErr
@@ -419,13 +428,13 @@ func sendAttr(st os.FileInfo) error {
atime, _ = sysStat.Atim.Unix()
}
- fmt.Fprintf(os.Stdout, "T%d 0 %d 0\n", mtime, atime)
+ fmt.Fprintf(out, "T%d 0 %d 0\n", mtime, atime)
return ack()
}
func ack() error {
kind := []byte{0}
- if _, err := os.Stdin.Read(kind); err != nil {
+ if _, err := in.Read(kind); err != nil {
return FatalError(err.Error())
}
if kind[0] == 0 {
@@ -458,7 +467,7 @@ func sendError(err error) {
if len(line) > MaxErrLen-3 {
line = line[:MaxErrLen-6] + "..."
}
- fmt.Fprintf(os.Stdout, "\x01%s\n", line)
+ fmt.Fprintf(out, "\x01%s\n", line)
}
func readLine() (string, error) {
@@ -466,7 +475,7 @@ func readLine() (string, error) {
ch := []byte{0}
for {
- if _, err := os.Stdin.Read(ch); err != nil {
+ if _, err := in.Read(ch); err != nil {
return "", err
} else {
if ch[0] == '\n' {