From 46c69d20eea9846ff7294edbd525f80c49294f69 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Fri, 15 Nov 2024 16:16:21 +0000 Subject: [PATCH] allign StatusPacket with other response packet --- client.go | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/client.go b/client.go index 9e435001..133b485d 100644 --- a/client.go +++ b/client.go @@ -402,12 +402,12 @@ func getPacket[RESP respPacket[PKT], PKT any](ctx context.Context, cancel <-chan return resp, nil case sshfx.PacketTypeStatus: - var status sshfx.StatusPacket + status := new(sshfx.StatusPacket) if err := status.UnmarshalPacketBody(&raw.Data); err != nil { return nil, err } - return nil, statusToError(&status, false) + return nil, statusToError(status, false) default: return nil, fmt.Errorf("unexpected packet type: %s", raw.PacketType) @@ -471,11 +471,10 @@ func (cl *Client) sendPacket(ctx context.Context, cancel <-chan struct{}, req ss return err } - var resp sshfx.StatusPacket - return cl.recvStatus(ctx, reqid, ch, &resp) + return cl.recvStatus(ctx, reqid, ch, nil) } -func (cl *Client) recvStatus(ctx context.Context, reqid uint32, ch chan result, resp *sshfx.StatusPacket) error { +func (cl *Client) recvStatus(ctx context.Context, reqid uint32, ch chan result, hint *sshfx.StatusPacket) error { raw, err := cl.conn.recv(ctx, reqid, ch) if err != nil { return err @@ -484,11 +483,15 @@ func (cl *Client) recvStatus(ctx context.Context, reqid uint32, ch chan result, switch raw.PacketType { case sshfx.PacketTypeStatus: - if err := resp.UnmarshalPacketBody(&raw.Data); err != nil { + if hint == nil { + hint = new(sshfx.StatusPacket) + } + + if err := hint.UnmarshalPacketBody(&raw.Data); err != nil { return err } - return statusToError(resp, true) + return statusToError(hint, true) default: return fmt.Errorf("unexpected packet type: %s", raw.PacketType) @@ -1549,10 +1552,10 @@ func (f *File) writeat(ctx context.Context, b []byte, off int64) (written int, e go func() { defer close(errCh) - var status sshfx.StatusPacket + statusHint := new(sshfx.StatusPacket) for work := range workCh { - err := f.cl.recvStatus(ctx, work.reqid, work.res, &status) + err := f.cl.recvStatus(ctx, work.reqid, work.res, statusHint) if err != nil { errCh <- rwErr{work.off, err} @@ -1784,10 +1787,10 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) { go func() { defer close(errCh) - var status sshfx.StatusPacket + statusHint := new(sshfx.StatusPacket) for work := range workCh { - err := f.cl.recvStatus(ctx, work.reqid, work.res, &status) + err := f.cl.recvStatus(ctx, work.reqid, work.res, statusHint) if err != nil { errCh <- rwErr{work.off, err}