Skip to content

Commit

Permalink
Merge pull request #436 from pkg/patch/sequential-concurrent-read-requests
Browse files Browse the repository at this point in the history

[bugfix] Sequentially issue read requests, process results concurrently
  • Loading branch information
puellanivis authored May 22, 2021
2 parents 9744aee + d5fa851 commit de44fbb
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 53 deletions.
137 changes: 88 additions & 49 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,27 +387,11 @@ func (c *Client) opendir(path string) (string, error) {
// Stat returns a FileInfo structure describing the file specified by path 'p'.
// If 'p' is a symbolic link, the returned FileInfo structure describes the referent file.
func (c *Client) Stat(p string) (os.FileInfo, error) {
	// Delegate the SSH_FXP_STAT round trip to the unexported stat helper,
	// then wrap the raw attributes in an os.FileInfo named after the
	// final element of the path.
	fs, err := c.stat(p)
	if err != nil {
		return nil, err
	}
	return fileInfoFromStat(fs, path.Base(p)), nil
}

// Lstat returns a FileInfo structure describing the file specified by path 'p'.
Expand Down Expand Up @@ -638,6 +622,30 @@ func (c *Client) close(handle string) error {
}
}

// stat issues an SSH_FXP_STAT request for path and decodes the server's
// reply into a *FileStat. Status replies are normalised into Go errors;
// any other packet type is reported as unimplemented.
func (c *Client) stat(path string) (*FileStat, error) {
	reqID := c.nextID()
	typ, payload, err := c.sendPacket(nil, &sshFxpStatPacket{
		ID:   reqID,
		Path: path,
	})
	if err != nil {
		return nil, err
	}
	switch typ {
	case sshFxpStatus:
		// The server reported an error (or OK, which is unexpected here).
		return nil, normaliseError(unmarshalStatus(reqID, payload))
	case sshFxpAttrs:
		sid, rest := unmarshalUint32(payload)
		if sid != reqID {
			return nil, &unexpectedIDErr{reqID, sid}
		}
		attrs, _ := unmarshalAttrs(rest)
		return attrs, nil
	default:
		return nil, unimplementedPacketErr(typ)
	}
}

func (c *Client) fstat(handle string) (*FileStat, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpFstatPacket{
Expand Down Expand Up @@ -1160,23 +1168,19 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
}

// For concurrency, we want to guess how many concurrent workers we should use.
var fileSize uint64
var fileStat *FileStat
if f.c.useFstat {
fileStat, err := f.c.fstat(f.handle)
if err != nil {
return 0, err
}
fileSize = fileStat.Size
fileStat, err = f.c.fstat(f.handle)
} else {
fi, err := f.c.Stat(f.path)
if err != nil {
return 0, err
}
fileSize = uint64(fi.Size())
fileStat, err = f.c.stat(f.path)
}
if err != nil {
return 0, err
}

if fileSize <= uint64(f.c.maxPacket) {
// We should be able to handle this in one Read.
fileSize := fileStat.Size
if fileSize <= uint64(f.c.maxPacket) || !isRegular(fileStat.Mode) {
// only regular files are guaranteed to return (full read) xor (partial read, next error)
return f.writeToSequential(w)
}

Expand All @@ -1187,6 +1191,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
// Now that concurrency64 is saturated to an int value, we know this assignment cannot possibly overflow.
concurrency := int(concurrency64)

chunkSize := f.c.maxPacket
pool := newBufPool(concurrency, chunkSize)
resPool := newResChanPool(concurrency)

cancel := make(chan struct{})
var wg sync.WaitGroup
defer func() {
Expand All @@ -1200,7 +1208,6 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {

type writeWork struct {
b []byte
n int
off int64
err error

Expand All @@ -1209,7 +1216,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
writeCh := make(chan writeWork)

type readWork struct {
off int64
id uint32
res chan result
off int64

cur, next chan writeWork
}
readCh := make(chan readWork)
Expand All @@ -1219,49 +1229,78 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
defer close(readCh)

off := f.offset
chunkSize := int64(f.c.maxPacket)

cur := writeCh
for {
id := f.c.nextID()
res := resPool.Get()

next := make(chan writeWork)
readWork := readWork{
off: off,
id: id,
res: res,
off: off,

cur: cur,
next: next,
}

f.c.dispatchRequest(res, &sshFxpReadPacket{
ID: id,
Handle: f.handle,
Offset: uint64(off),
Len: uint32(chunkSize),
})

select {
case readCh <- readWork:
case <-cancel:
return
}

off += chunkSize
off += int64(chunkSize)
cur = next
}
}()

pool := newBufPool(concurrency, f.c.maxPacket)

wg.Add(concurrency)
for i := 0; i < concurrency; i++ {
// Map_i: each worker gets readWork, and does the Read into a buffer at the given offset.
go func() {
defer wg.Done()

ch := make(chan result, 1) // reusable channel

for readWork := range readCh {
b := pool.Get()

n, err := f.readChunkAt(ch, b, readWork.off)
if n < 0 {
panic("sftp.File: returned negative count from readChunkAt")
var b []byte
var n int

s := <-readWork.res
resPool.Put(readWork.res)

err := s.err
if err == nil {
switch s.typ {
case sshFxpStatus:
err = normaliseError(unmarshalStatus(readWork.id, s.data))

case sshFxpData:
sid, data := unmarshalUint32(s.data)
if readWork.id != sid {
err = &unexpectedIDErr{readWork.id, sid}

} else {
l, data := unmarshalUint32(data)
b = pool.Get()[:l]
n = copy(b, data[:l])
b = b[:n]
}

default:
err = unimplementedPacketErr(s.typ)
}
}

writeWork := writeWork{
b: b,
n: n,
off: readWork.off,
err: err,

Expand Down Expand Up @@ -1290,10 +1329,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
}

// Because writes are serialized, this will always be the last successfully read byte.
f.offset = packet.off + int64(packet.n)
f.offset = packet.off + int64(len(packet.b))

if packet.n > 0 {
n, err := w.Write(packet.b[:packet.n])
if len(packet.b) > 0 {
n, err := w.Write(packet.b)
written += int64(n)
if err != nil {
return written, err
Expand Down
59 changes: 55 additions & 4 deletions pool.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package sftp

// bufPool provides a pool of byte-slices to be reused in various parts of the package.
// It is safe to use concurrently through a pointer.
type bufPool struct {
ch chan []byte
blen int
Expand All @@ -13,17 +15,66 @@ func newBufPool(depth, bufLen int) *bufPool {
}

func (p *bufPool) Get() []byte {
select {
case b := <-p.ch:
return b
default:
if p == nil {
// functional default: no reuse.
return make([]byte, p.blen)
}

for {
select {
case b := <-p.ch:
if cap(b) < p.blen {
// just in case: throw away any buffer with insufficient capacity.
continue
}

return b[:p.blen]

default:
return make([]byte, p.blen)
}
}
}

// Put offers b back to the pool. The buffer is silently dropped when the
// pool is nil (no reuse), when the pool is already full, or when b's
// capacity is unsuitable for reuse.
func (p *bufPool) Put(b []byte) {
	if p == nil {
		// no pool: nothing to return the buffer to.
		return
	}

	c := cap(b)
	if c < p.blen || c > p.blen*2 {
		// An undersized buffer would panic when resliced back to p.blen,
		// and an oversized one would pin excess memory — drop both.
		return
	}

	// Non-blocking send: when the pool is full, let b be garbage collected.
	select {
	case p.ch <- b:
	default:
	}
}

// resChanPool recycles single-use result channels, so hot request paths
// avoid allocating a fresh channel for every outstanding packet.
// It is safe for concurrent use.
type resChanPool chan chan result

// newResChanPool returns a pool that retains at most depth idle channels.
func newResChanPool(depth int) resChanPool {
	return make(chan chan result, depth)
}

// Get hands out an idle pooled channel when one is available; otherwise it
// allocates a new channel with capacity 1, so a responder can always send
// without blocking.
func (p resChanPool) Get() chan result {
	select {
	case ch := <-p:
		return ch
	default:
	}
	return make(chan result, 1)
}

// Put returns ch to the pool, dropping it when the pool is already full.
func (p resChanPool) Put(ch chan result) {
	select {
	case p <- ch:
	default:
	}
}
5 changes: 5 additions & 0 deletions stat_plan9.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ func translateSyscallError(err error) (uint32, bool) {
return 0, false
}

// isRegular reports whether the file-type bits of mode denote a regular
// file (as opposed to a directory, symlink, device, or other special file).
func isRegular(mode uint32) bool {
	return syscall.S_IFREG == mode&S_IFMT
}

// toFileMode converts sftp filemode bits to the os.FileMode specification
func toFileMode(mode uint32) os.FileMode {
var fm = os.FileMode(mode & 0777)
Expand Down
5 changes: 5 additions & 0 deletions stat_posix.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ func translateSyscallError(err error) (uint32, bool) {
return 0, false
}

// isRegular returns true if the mode describes a regular file.
// It masks mode with S_IFMT to isolate the file-type bits and compares the
// result against syscall.S_IFREG; every other type (directory, symlink,
// socket, device, FIFO) yields false.
func isRegular(mode uint32) bool {
	return mode&S_IFMT == syscall.S_IFREG
}

// toFileMode converts sftp filemode bits to the os.FileMode specification
func toFileMode(mode uint32) os.FileMode {
var fm = os.FileMode(mode & 0777)
Expand Down

0 comments on commit de44fbb

Please sign in to comment.