diff --git a/client.go b/client.go
index f9bc5ce2..577e988a 100644
--- a/client.go
+++ b/client.go
@@ -387,27 +387,11 @@ func (c *Client) opendir(path string) (string, error) {
 // Stat returns a FileInfo structure describing the file specified by path 'p'.
 // If 'p' is a symbolic link, the returned FileInfo structure describes the referent file.
 func (c *Client) Stat(p string) (os.FileInfo, error) {
-	id := c.nextID()
-	typ, data, err := c.sendPacket(nil, &sshFxpStatPacket{
-		ID:   id,
-		Path: p,
-	})
+	fs, err := c.stat(p)
 	if err != nil {
 		return nil, err
 	}
-	switch typ {
-	case sshFxpAttrs:
-		sid, data := unmarshalUint32(data)
-		if sid != id {
-			return nil, &unexpectedIDErr{id, sid}
-		}
-		attr, _ := unmarshalAttrs(data)
-		return fileInfoFromStat(attr, path.Base(p)), nil
-	case sshFxpStatus:
-		return nil, normaliseError(unmarshalStatus(id, data))
-	default:
-		return nil, unimplementedPacketErr(typ)
-	}
+	return fileInfoFromStat(fs, path.Base(p)), nil
 }
 
 // Lstat returns a FileInfo structure describing the file specified by path 'p'.
@@ -638,6 +622,30 @@ func (c *Client) close(handle string) error {
 	}
 }
 
+func (c *Client) stat(path string) (*FileStat, error) {
+	id := c.nextID()
+	typ, data, err := c.sendPacket(nil, &sshFxpStatPacket{
+		ID:   id,
+		Path: path,
+	})
+	if err != nil {
+		return nil, err
+	}
+	switch typ {
+	case sshFxpAttrs:
+		sid, data := unmarshalUint32(data)
+		if sid != id {
+			return nil, &unexpectedIDErr{id, sid}
+		}
+		attr, _ := unmarshalAttrs(data)
+		return attr, nil
+	case sshFxpStatus:
+		return nil, normaliseError(unmarshalStatus(id, data))
+	default:
+		return nil, unimplementedPacketErr(typ)
+	}
+}
+
 func (c *Client) fstat(handle string) (*FileStat, error) {
 	id := c.nextID()
 	typ, data, err := c.sendPacket(nil, &sshFxpFstatPacket{
@@ -1160,23 +1168,19 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
 	}
 
 	// For concurrency, we want to guess how many concurrent workers we should use.
-	var fileSize uint64
+	var fileStat *FileStat
 	if f.c.useFstat {
-		fileStat, err := f.c.fstat(f.handle)
-		if err != nil {
-			return 0, err
-		}
-		fileSize = fileStat.Size
+		fileStat, err = f.c.fstat(f.handle)
 	} else {
-		fi, err := f.c.Stat(f.path)
-		if err != nil {
-			return 0, err
-		}
-		fileSize = uint64(fi.Size())
+		fileStat, err = f.c.stat(f.path)
+	}
+	if err != nil {
+		return 0, err
 	}
 
-	if fileSize <= uint64(f.c.maxPacket) {
-		// We should be able to handle this in one Read.
+	fileSize := fileStat.Size
+	if fileSize <= uint64(f.c.maxPacket) || !isRegular(fileStat.Mode) {
+		// only regular files are guaranteed to return (full read) xor (partial read, next error)
 		return f.writeToSequential(w)
 	}
 
@@ -1187,6 +1191,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
 	// Now that concurrency64 is saturated to an int value, we know this assignment cannot possibly overflow.
 	concurrency := int(concurrency64)
 
+	chunkSize := f.c.maxPacket
+	pool := newBufPool(concurrency, chunkSize)
+	resPool := newResChanPool(concurrency)
+
 	cancel := make(chan struct{})
 	var wg sync.WaitGroup
 	defer func() {
@@ -1200,7 +1208,6 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
 
 	type writeWork struct {
 		b   []byte
-		n   int
 		off int64
 		err error
 
@@ -1209,7 +1216,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
 	writeCh := make(chan writeWork)
 
 	type readWork struct {
-		off       int64
+		id  uint32
+		res chan result
+
+		off       int64
 		cur, next chan writeWork
 	}
 	readCh := make(chan readWork)
@@ -1219,49 +1229,78 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
 		defer close(readCh)
 
 		off := f.offset
-		chunkSize := int64(f.c.maxPacket)
 
 		cur := writeCh
 		for {
+			id := f.c.nextID()
+			res := resPool.Get()
+
 			next := make(chan writeWork)
 			readWork := readWork{
-				off:  off,
+				id:  id,
+				res: res,
+
+				off:  off,
 				cur:  cur,
 				next: next,
 			}
 
+			f.c.dispatchRequest(res, &sshFxpReadPacket{
+				ID:     id,
+				Handle: f.handle,
+				Offset: uint64(off),
+				Len:    uint32(chunkSize),
+			})
+
 			select {
 			case readCh <- readWork:
 			case <-cancel:
 				return
 			}
 
-			off += chunkSize
+			off += int64(chunkSize)
 			cur = next
 		}
 	}()
 
-	pool := newBufPool(concurrency, f.c.maxPacket)
-
 	wg.Add(concurrency)
 	for i := 0; i < concurrency; i++ {
 		// Map_i: each worker gets readWork, and does the Read into a buffer at the given offset.
 		go func() {
 			defer wg.Done()
 
-			ch := make(chan result, 1) // reusable channel
-
 			for readWork := range readCh {
-				b := pool.Get()
-
-				n, err := f.readChunkAt(ch, b, readWork.off)
-				if n < 0 {
-					panic("sftp.File: returned negative count from readChunkAt")
+				var b []byte
+				var n int
+
+				s := <-readWork.res
+				resPool.Put(readWork.res)
+
+				err := s.err
+				if err == nil {
+					switch s.typ {
+					case sshFxpStatus:
+						err = normaliseError(unmarshalStatus(readWork.id, s.data))
+
+					case sshFxpData:
+						sid, data := unmarshalUint32(s.data)
+						if readWork.id != sid {
+							err = &unexpectedIDErr{readWork.id, sid}
+
+						} else {
+							l, data := unmarshalUint32(data)
+							b = pool.Get()[:l]
+							n = copy(b, data[:l])
+							b = b[:n]
+						}
+
+					default:
+						err = unimplementedPacketErr(s.typ)
+					}
 				}
 
 				writeWork := writeWork{
 					b:   b,
-					n:   n,
 					off: readWork.off,
 					err: err,
 
@@ -1290,10 +1329,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
 		}
 
 		// Because writes are serialized, this will always be the last successfully read byte.
-		f.offset = packet.off + int64(packet.n)
+		f.offset = packet.off + int64(len(packet.b))
 
-		if packet.n > 0 {
-			n, err := w.Write(packet.b[:packet.n])
+		if len(packet.b) > 0 {
+			n, err := w.Write(packet.b)
 			written += int64(n)
 			if err != nil {
 				return written, err
diff --git a/pool.go b/pool.go
index c074637a..563f64bd 100644
--- a/pool.go
+++ b/pool.go
@@ -1,5 +1,7 @@
 package sftp
 
+// bufPool provides a pool of byte-slices to be reused in various parts of the package.
+// It is safe to use concurrently through a pointer.
 type bufPool struct {
 	ch   chan []byte
 	blen int
@@ -13,17 +15,66 @@ func newBufPool(depth, bufLen int) *bufPool {
 }
 
 func (p *bufPool) Get() []byte {
-	select {
-	case b := <-p.ch:
-		return b
-	default:
+	if p == nil {
+		// functional default: no reuse.
 		return make([]byte, p.blen)
 	}
+
+	for {
+		select {
+		case b := <-p.ch:
+			if cap(b) < p.blen {
+				// just in case: throw away any buffer with insufficient capacity.
+				continue
+			}
+
+			return b[:p.blen]
+
+		default:
+			return make([]byte, p.blen)
+		}
+	}
 }
 
 func (p *bufPool) Put(b []byte) {
+	if p == nil {
+		// functional default: no reuse.
+		return
+	}
+
+	if cap(b) < p.blen || cap(b) > p.blen*2 {
+		// DO NOT reuse buffers with insufficient capacity.
+		// This could cause panics when resizing to p.blen.
+
+		// DO NOT reuse buffers with excessive capacity.
+		// This could cause memory leaks.
+		return
+	}
+
 	select {
 	case p.ch <- b:
 	default:
 	}
 }
+
+type resChanPool chan chan result
+
+func newResChanPool(depth int) resChanPool {
+	return make(chan chan result, depth)
+}
+
+func (p resChanPool) Get() chan result {
+	select {
+	case ch := <-p:
+		return ch
+	default:
+		return make(chan result, 1)
+	}
+}
+
+func (p resChanPool) Put(ch chan result) {
+	select {
+	case p <- ch:
+	default:
+	}
+}
diff --git a/stat_plan9.go b/stat_plan9.go
index 25074fe5..418f121c 100644
--- a/stat_plan9.go
+++ b/stat_plan9.go
@@ -41,6 +41,11 @@ func translateSyscallError(err error) (uint32, bool) {
 	return 0, false
 }
 
+// isRegular returns true if the mode describes a regular file.
+func isRegular(mode uint32) bool {
+	return mode&S_IFMT == syscall.S_IFREG
+}
+
 // toFileMode converts sftp filemode bits to the os.FileMode specification
 func toFileMode(mode uint32) os.FileMode {
 	var fm = os.FileMode(mode & 0777)
diff --git a/stat_posix.go b/stat_posix.go
index 71080acb..98b60e77 100644
--- a/stat_posix.go
+++ b/stat_posix.go
@@ -43,6 +43,11 @@ func translateSyscallError(err error) (uint32, bool) {
 	return 0, false
 }
 
+// isRegular returns true if the mode describes a regular file.
+func isRegular(mode uint32) bool {
+	return mode&S_IFMT == syscall.S_IFREG
+}
+
 // toFileMode converts sftp filemode bits to the os.FileMode specification
 func toFileMode(mode uint32) os.FileMode {
 	var fm = os.FileMode(mode & 0777)