Skip to content

Commit

Permalink
allign StatusPacket with other response packet
Browse files Browse the repository at this point in the history
  • Loading branch information
puellanivis committed Nov 15, 2024
1 parent 98fad75 commit 46c69d2
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,12 +402,12 @@ func getPacket[RESP respPacket[PKT], PKT any](ctx context.Context, cancel <-chan
return resp, nil

case sshfx.PacketTypeStatus:
var status sshfx.StatusPacket
status := new(sshfx.StatusPacket)
if err := status.UnmarshalPacketBody(&raw.Data); err != nil {
return nil, err
}

return nil, statusToError(&status, false)
return nil, statusToError(status, false)

default:
return nil, fmt.Errorf("unexpected packet type: %s", raw.PacketType)
Expand Down Expand Up @@ -471,11 +471,10 @@ func (cl *Client) sendPacket(ctx context.Context, cancel <-chan struct{}, req ss
return err
}

var resp sshfx.StatusPacket
return cl.recvStatus(ctx, reqid, ch, &resp)
return cl.recvStatus(ctx, reqid, ch, nil)
}

func (cl *Client) recvStatus(ctx context.Context, reqid uint32, ch chan result, resp *sshfx.StatusPacket) error {
func (cl *Client) recvStatus(ctx context.Context, reqid uint32, ch chan result, hint *sshfx.StatusPacket) error {
raw, err := cl.conn.recv(ctx, reqid, ch)
if err != nil {
return err
Expand All @@ -484,11 +483,15 @@ func (cl *Client) recvStatus(ctx context.Context, reqid uint32, ch chan result,

switch raw.PacketType {
case sshfx.PacketTypeStatus:
if err := resp.UnmarshalPacketBody(&raw.Data); err != nil {
if hint == nil {
hint = new(sshfx.StatusPacket)
}

if err := hint.UnmarshalPacketBody(&raw.Data); err != nil {
return err
}

return statusToError(resp, true)
return statusToError(hint, true)

default:
return fmt.Errorf("unexpected packet type: %s", raw.PacketType)
Expand Down Expand Up @@ -1549,10 +1552,10 @@ func (f *File) writeat(ctx context.Context, b []byte, off int64) (written int, e
go func() {
defer close(errCh)

var status sshfx.StatusPacket
statusHint := new(sshfx.StatusPacket)

for work := range workCh {
err := f.cl.recvStatus(ctx, work.reqid, work.res, &status)
err := f.cl.recvStatus(ctx, work.reqid, work.res, statusHint)
if err != nil {
errCh <- rwErr{work.off, err}

Expand Down Expand Up @@ -1784,10 +1787,10 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) {
go func() {
defer close(errCh)

var status sshfx.StatusPacket
statusHint := new(sshfx.StatusPacket)

for work := range workCh {
err := f.cl.recvStatus(ctx, work.reqid, work.res, &status)
err := f.cl.recvStatus(ctx, work.reqid, work.res, statusHint)
if err != nil {
errCh <- rwErr{work.off, err}

Expand Down

0 comments on commit 46c69d2

Please sign in to comment.