diff --git a/Readme.md b/Readme.md index 210693f..70c998a 100644 --- a/Readme.md +++ b/Readme.md @@ -14,148 +14,34 @@ The SOCKS protocol is defined in [rfc1928](https://tools.ietf.org/html/rfc1928) ```golang type ProxyHandler interface { - PreHandler(Request) (io.ReadWriteCloser, *Error) - CopyFromClientToRemote(context.Context, io.ReadCloser, io.WriteCloser) error - CopyFromRemoteToClient(context.Context, io.ReadCloser, io.WriteCloser) error - Cleanup() error + Init(Request) (io.ReadWriteCloser, *Error) + ReadFromClient(context.Context, io.ReadCloser, io.WriteCloser) error + ReadFromRemote(context.Context, io.ReadCloser, io.WriteCloser) error + Close() error Refresh(ctx context.Context) } ``` -### PreHandler +### Init -PreHandler is called before the copy operations and it should return a connection to the target that is ready to receive data. +Init is called before the copy operations and it should return a connection to the target that is ready to receive data. -### CopyFromClientToRemote +### ReadFromClient -CopyFromClientToRemote is the method that handles the data copy from the client (you) to the remote connection. You can see the `DefaultHandler` for a sample implementation. +ReadFromClient is the method that handles the data copy from the client (you) to the remote connection. You can see the `DefaultHandler` for a sample implementation. -### CopyFromRemoteToClient +### ReadFromRemote -CopyFromRemoteToClient is the method that handles the data copy from the remote connection to the client (you). You can see the `DefaultHandler` for a sample implementation. +ReadFromRemote is the method that handles the data copy from the remote connection to the client (you). You can see the `DefaultHandler` for a sample implementation. -### Cleanup +### Close -Cleanup is called after the request finishes or errors out. It is used to clean up any connections in your custom implementation. +Close is called after the request finishes or errors out. It is used to clean up any connections in your custom implementation. ### Refresh Refresh is called in a seperate goroutine and should loop forever to do refreshes of the connection if needed. The passed in context is cancelled after the request so be sure to check on the Done event. -## Usage +## Examples -### Default Usage - -```golang -package main - -import ( - "time", - - socks "github.com/firefart/gosocks" - "github.com/sirupsen/logrus" -) - -func main() { - handler := socks.DefaultHandler{ - Timeout: 1*time.Second, - } - listen := "127.0.0.1:1080" - p := socks.Proxy{ - ServerAddr: listen, - Proxyhandler: handler, - Timeout: 1*time.Second, - Log: logrus.New(), - } - p.Log.Infof("starting SOCKS server on %s", listen) - if err := p.Start(); err != nil { - panic(err) - } - <-p.Done -} -``` - -### Usage with custom handlers - -```golang -package main - -import ( - "time" - "io" - "fmt" - "net" - "context" - - socks "github.com/firefart/gosocks" - "github.com/sirupsen/logrus" -) - -func main() { - log := logrus.New() - handler := MyCustomHandler{ - Timeout: 1*time.Second, - PropA: "A", - PropB: "B", - Log: log, - } - p := socks.Proxy{ - ServerAddr: "127.0.0.1:1080", - Proxyhandler: handler, - Timeout: 1*time.Second, - Log: log, - } - log.Infof("starting SOCKS server on %s", listen) - if err := p.Start(); err != nil { - panic(err) - } - <-p.Done -} - -type MyCustomHandler struct { - Timeout time.Duration, - PropA string, - PropB string, - Log Logger, -} - -func (s *MyCustomHandler) PreHandler(request socks.Request) (io.ReadWriteCloser, *socks.Error) { - conn, err := net.DialTimeout("tcp", s.Server, s.Timeout) - if err != nil { - return nil, &socks.SocksError{Reason: socks.RequestReplyHostUnreachable, Err: fmt.Errorf("error on connecting to server: %w", err)} - } - return conn, nil -} - -func (s *MyCustomHandler) Refresh(ctx context.Context) { - tick := time.NewTicker(10 * time.Second) - select { - case <-ctx.Done(): - return - case <-tick.C: - s.Log.Debug("refreshing connection") - } -} - -func (s *MyCustomHandler) CopyFromRemoteToClient(remote io.ReadCloser, client io.WriteCloser) error { - i, err := io.Copy(client, remote) - if err != nil { - return err - } - s.Log.Debugf("wrote %d bytes to client", i) - return nil -} - -func (s *MyCustomHandler) CopyFromClientToRemote(client io.ReadCloser, remote io.WriteCloser) error { - i, err := io.Copy(remote, client) - if err != nil { - return err - } - s.Log.Debugf("wrote %d bytes to remote", i) - return nil -} - -func (s *MyCustomHandler) Cleanup() error { - return nil -} -``` +For examples please have a look at the examples folder diff --git a/defaulthandler.go b/defaulthandler.go index af8565c..dea8657 100644 --- a/defaulthandler.go +++ b/defaulthandler.go @@ -15,35 +15,35 @@ type DefaultHandler struct { log Logger } -// PreHandler is the default socks5 implementation -func (s DefaultHandler) PreHandler(request Request) (io.ReadWriteCloser, error) { +// Init is the default socks5 implementation +func (s DefaultHandler) Init(request Request) (io.ReadWriteCloser, *Error) { target := fmt.Sprintf("%s:%d", request.DestinationAddress, request.DestinationPort) s.log.Infof("Connecting to target %s", target) remote, err := net.DialTimeout("tcp", target, s.Timeout) if err != nil { - return nil, err + return nil, NewError(RequestReplyNetworkUnreachable, err) } return remote, nil } -// CopyFromClientToRemote is the default socks5 implementation -func (s DefaultHandler) CopyFromClientToRemote(ctx context.Context, client, remote io.ReadWriteCloser) error { - if _, err := io.Copy(client, remote); err != nil { +// ReadFromClient is the default socks5 implementation +func (s DefaultHandler) ReadFromClient(ctx context.Context, client io.ReadCloser, remote io.WriteCloser) error { + if _, err := io.Copy(remote, client); err != nil { return err } return nil } -// CopyFromRemoteToClient is the default socks5 implementation -func (s DefaultHandler) CopyFromRemoteToClient(ctx context.Context, remote, client io.ReadWriteCloser) error { - if _, err := io.Copy(remote, client); err != nil { +// ReadFromRemote is the default socks5 implementation +func (s DefaultHandler) ReadFromRemote(ctx context.Context, remote io.ReadCloser, client io.WriteCloser) error { + if _, err := io.Copy(client, remote); err != nil { return err } return nil } -// Cleanup is the default socks5 implementation -func (s DefaultHandler) Cleanup() error { +// Close is the default socks5 implementation +func (s DefaultHandler) Close() error { return nil } diff --git a/examples/custom/main.go b/examples/custom/main.go new file mode 100644 index 0000000..edaaf11 --- /dev/null +++ b/examples/custom/main.go @@ -0,0 +1,81 @@ +package main + +import ( + "context" + "fmt" + "io" + "net" + "time" + + socks "github.com/firefart/gosocks" +) + +func main() { + log := &socks.NilLogger{} + handler := MyCustomHandler{ + Timeout: 1 * time.Second, + PropA: "A", + PropB: "B", + Log: log, + } + p := socks.Proxy{ + ServerAddr: "127.0.0.1:1080", + Proxyhandler: &handler, + Timeout: 1 * time.Second, + Log: log, + } + log.Infof("starting SOCKS server on %s", p.ServerAddr) + if err := p.Start(); err != nil { + panic(err) + } + <-p.Done +} + +type MyCustomHandler struct { + Timeout time.Duration + PropA string + PropB string + Log socks.Logger +} + +func (s *MyCustomHandler) Init(request socks.Request) (io.ReadWriteCloser, *socks.Error) { + target := fmt.Sprintf("%s:%d", request.DestinationAddress, request.DestinationPort) + s.Log.Infof("Connecting to target %s", target) + remote, err := net.DialTimeout("tcp", target, s.Timeout) + if err != nil { + return nil, socks.NewError(socks.RequestReplyNetworkUnreachable, err) + } + return remote, nil +} + +func (s *MyCustomHandler) Refresh(ctx context.Context) { + tick := time.NewTicker(10 * time.Second) + select { + case <-ctx.Done(): + return + case <-tick.C: + s.Log.Debug("refreshing connection") + } +} + +func (s *MyCustomHandler) ReadFromRemote(ctx context.Context, remote io.ReadCloser, client io.WriteCloser) error { + i, err := io.Copy(client, remote) + if err != nil { + return err + } + s.Log.Debugf("wrote %d bytes to client", i) + return nil +} + +func (s *MyCustomHandler) ReadFromClient(ctx context.Context, client io.ReadCloser, remote io.WriteCloser) error { + i, err := io.Copy(remote, client) + if err != nil { + return err + } + s.Log.Debugf("wrote %d bytes to remote", i) + return nil +} + +func (s MyCustomHandler) Close() error { + return nil +} diff --git a/examples/default/main.go b/examples/default/main.go new file mode 100644 index 0000000..f0ea45c --- /dev/null +++ b/examples/default/main.go @@ -0,0 +1,25 @@ +package main + +import ( + "time" + + socks "github.com/firefart/gosocks" +) + +func main() { + handler := socks.DefaultHandler{ + Timeout: 1 * time.Second, + } + listen := "127.0.0.1:1080" + p := socks.Proxy{ + ServerAddr: listen, + Proxyhandler: handler, + Timeout: 1 * time.Second, + Log: &socks.NilLogger{}, + } + p.Log.Infof("starting SOCKS server on %s", listen) + if err := p.Start(); err != nil { + panic(err) + } + <-p.Done +} diff --git a/go.mod b/go.mod index 433a6ba..bbca117 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/firefart/gosocks -go 1.18 +go 1.21 diff --git a/parsers.go b/parsers.go index 436d6f2..aed0b9b 100644 --- a/parsers.go +++ b/parsers.go @@ -11,35 +11,47 @@ import ( +----+-----+-------+------+----------+----------+ | 1 | 1 | X'00' | 1 | Variable | 2 | +----+-----+-------+------+----------+----------+ + o VER protocol version: X'05' o CMD - o CONNECT X'01' - o BIND X'02' - o UDP ASSOCIATE X'03' + + o CONNECT X'01' + o BIND X'02' + o UDP ASSOCIATE X'03' + o RSV RESERVED o ATYP address type of following address - o IP V4 address: X'01' - o DOMAINNAME: X'03' - o IP V6 address: X'04' + + o IP V4 address: X'01' + o DOMAINNAME: X'03' + o IP V6 address: X'04' + o DST.ADDR desired destination address o DST.PORT desired destination port in network octet - order + + order In an address field (DST.ADDR, BND.ADDR), the ATYP field specifies the type of address contained within the field: + o X'01' + the address is a version-4 IP address, with a length of 4 octets + o X'03' + the address field contains a fully-qualified domain name. The first octet of the address field contains the number of octets of name that follow, there is no terminating NUL octet. + o X'04' + the address is a version-6 IP address, with a length of 16 octets. */ func parseRequest(buf []byte) (*Request, *Error) { r := &Request{} if len(buf) < 7 { - return nil, &Error{Reason: RequestReplyConnectionRefused, Err: fmt.Errorf("invalid request header length (%d)", len(buf))} + return nil, NewError(RequestReplyConnectionRefused, fmt.Errorf("invalid request header length (%d)", len(buf))) } version := buf[0] switch version { @@ -48,7 +60,7 @@ func parseRequest(buf []byte) (*Request, *Error) { case byte(Version5): r.Version = Version5 default: - return nil, &Error{Reason: RequestReplyConnectionRefused, Err: fmt.Errorf("Invalid Socks version %#x", version)} + return nil, NewError(RequestReplyConnectionRefused, fmt.Errorf("Invalid Socks version %#x", version)) } cmd := buf[1] switch cmd { @@ -59,7 +71,7 @@ func parseRequest(buf []byte) (*Request, *Error) { // case byte(RequestCmdAssociate): // r.Command = RequestCmdAssociate default: - return nil, &Error{Reason: RequestReplyCommandNotSupported, Err: fmt.Errorf("Command %#x not supported", cmd)} + return nil, NewError(RequestReplyCommandNotSupported, fmt.Errorf("Command %#x not supported", cmd)) } addresstype := buf[3] switch addresstype { @@ -70,7 +82,7 @@ func parseRequest(buf []byte) (*Request, *Error) { case byte(RequestAddressTypeDomainname): r.AddressType = RequestAddressTypeDomainname default: - return nil, &Error{Reason: RequestReplyAddressTypeNotSupported, Err: fmt.Errorf("AddressType %#x not supported", addresstype)} + return nil, NewError(RequestReplyAddressTypeNotSupported, fmt.Errorf("AddressType %#x not supported", addresstype)) } switch r.AddressType { @@ -88,7 +100,7 @@ func parseRequest(buf []byte) (*Request, *Error) { p := buf[5+addrLen : 5+addrLen+2] r.DestinationPort = binary.BigEndian.Uint16(p) default: - return nil, &Error{Reason: RequestReplyAddressTypeNotSupported, Err: fmt.Errorf("AddressType %#x not supported", addresstype)} + return nil, NewError(RequestReplyAddressTypeNotSupported, fmt.Errorf("AddressType %#x not supported", addresstype)) } return r, nil @@ -106,7 +118,7 @@ func parseHeader(buf []byte) (Header, error) { case byte(Version5): h.Version = Version5 default: - return h, fmt.Errorf("Could not get socks version from header") + return h, fmt.Errorf("could not get socks version from header") } numMethods := buf[1] if len(buf) < int(numMethods)+2 { diff --git a/proxy.go b/proxy.go index b97d732..363fa75 100644 --- a/proxy.go +++ b/proxy.go @@ -9,16 +9,15 @@ import ( // ProxyHandler is the interface for handling the proxy requests type ProxyHandler interface { - PreHandler(Request) (io.ReadWriteCloser, *Error) - CopyFromClientToRemote(context.Context, io.ReadCloser, io.WriteCloser) error - CopyFromRemoteToClient(context.Context, io.ReadCloser, io.WriteCloser) error - Cleanup() error + Init(Request) (io.ReadWriteCloser, *Error) + ReadFromClient(context.Context, io.ReadCloser, io.WriteCloser) error + ReadFromRemote(context.Context, io.ReadCloser, io.WriteCloser) error + Close() error Refresh(context.Context) } // Proxy is the main struct type Proxy struct { - ClientAddr string ServerAddr string Done chan struct{} Proxyhandler ProxyHandler diff --git a/reply.go b/reply.go index 359961d..c861b6f 100644 --- a/reply.go +++ b/reply.go @@ -1,94 +1,99 @@ package socks import ( + "bytes" "encoding/binary" - "fmt" - "net" - "net/netip" - "strconv" ) /* - +----+-----+-------+------+----------+----------+ - |VER | REP | RSV | ATYP | BND.ADDR | BND.PORT | - +----+-----+-------+------+----------+----------+ - | 1 | 1 | X'00' | 1 | Variable | 2 | - +----+-----+-------+------+----------+----------+ + +----+-----+-------+------+----------+----------+ + |VER | REP | RSV | ATYP | BND.ADDR | BND.PORT | + +----+-----+-------+------+----------+----------+ + | 1 | 1 | X'00' | 1 | Variable | 2 | + +----+-----+-------+------+----------+----------+ - Where: + Where: - o VER protocol version: X'05' - o REP Reply field: - o X'00' succeeded - o X'01' general SOCKS server failure - o X'02' connection not allowed by ruleset - o X'03' Network unreachable - o X'04' Host unreachable - o X'05' Connection refused - o X'06' TTL expired - o X'07' Command not supported - o X'08' Address type not supported - o X'09' to X'FF' unassigned - o RSV RESERVED - o ATYP address type of following address - o IP V4 address: X'01' - o DOMAINNAME: X'03' - o IP V6 address: X'04' - o BND.ADDR server bound address - o BND.PORT server bound port in network octet order + o VER protocol version: X'05' + o REP Reply field: + o X'00' succeeded + o X'01' general SOCKS server failure + o X'02' connection not allowed by ruleset + o X'03' Network unreachable + o X'04' Host unreachable + o X'05' Connection refused + o X'06' TTL expired + o X'07' Command not supported + o X'08' Address type not supported + o X'09' to X'FF' unassigned + o RSV RESERVED + o ATYP address type of following address + o IP V4 address: X'01' + o DOMAINNAME: X'03' + o IP V6 address: X'04' + o BND.ADDR server bound address + o BND.PORT server bound port in network octet order - Fields marked RESERVED (RSV) must be set to X'00'. - CONNECT - - In the reply to a CONNECT, BND.PORT contains the port number that the - server assigned to connect to the target host, while BND.ADDR - contains the associated IP address. The supplied BND.ADDR is often - different from the IP address that the client uses to reach the SOCKS - server, since such servers are often multi-homed. It is expected - that the SOCKS server will use DST.ADDR and DST.PORT, and the - client-side source address and port in evaluating the CONNECT - request. + Fields marked RESERVED (RSV) must be set to X'00'. + CONNECT + In the reply to a CONNECT, BND.PORT contains the port number that the + server assigned to connect to the target host, while BND.ADDR + contains the associated IP address. The supplied BND.ADDR is often + different from the IP address that the client uses to reach the SOCKS + server, since such servers are often multi-homed. It is expected + that the SOCKS server will use DST.ADDR and DST.PORT, and the + client-side source address and port in evaluating the CONNECT + request. */ -func requestReply(in net.Addr, reply RequestReplyReason) ([]byte, error) { - var buf []byte - buf = append(buf, Version5.Value()) - buf = append(buf, reply.Value()) - // reserved - buf = append(buf, 0x00) +func requestReply(request *Request, reply RequestReplyReason) ([]byte, error) { + buffer := bytes.NewBuffer(nil) + if err := binary.Write(buffer, binary.BigEndian, Version5.Value()); err != nil { + return nil, err + } + if err := binary.Write(buffer, binary.BigEndian, reply.Value()); err != nil { + return nil, err + } + if err := binary.Write(buffer, binary.BigEndian, byte(0x00)); err != nil { + return nil, err + } - if in != nil { - host, port, err := net.SplitHostPort(in.String()) - if err != nil { - return nil, err - } - ip, err := netip.ParseAddr(host) - if err != nil { + if request != nil { + if err := binary.Write(buffer, binary.BigEndian, request.AddressType.Value()); err != nil { return nil, err } // type - if ip.Is4() { - buf = append(buf, RequestAddressTypeIPv4.Value()) - } else if ip.Is6() { - buf = append(buf, RequestAddressTypeIPv6.Value()) + if request.AddressType == RequestAddressTypeIPv4 { + if err := binary.Write(buffer, binary.BigEndian, request.DestinationAddress); err != nil { + return nil, err + } + } else if request.AddressType == RequestAddressTypeIPv6 { + if err := binary.Write(buffer, binary.BigEndian, request.DestinationAddress); err != nil { + return nil, err + } } else { - return nil, fmt.Errorf("ip %s invalid", ip.String()) + if err := binary.Write(buffer, binary.BigEndian, byte(len(request.DestinationAddress))); err != nil { + return nil, err + } + if err := binary.Write(buffer, binary.BigEndian, request.DestinationAddress); err != nil { + return nil, err + } } - buf = append(buf, ip.AsSlice()...) - portInt, err := strconv.ParseUint(port, 10, 16) - if err != nil { + if err := binary.Write(buffer, binary.BigEndian, request.DestinationPort); err != nil { return nil, err } - var portByte = make([]byte, 2) - binary.BigEndian.PutUint16(portByte, uint16(portInt)) - buf = append(buf, portByte...) } else { - // type - buf = append(buf, RequestAddressTypeIPv4.Value()) - // error reply - buf = append(buf, []byte{0, 0, 0, 0}...) + if err := binary.Write(buffer, binary.BigEndian, RequestAddressTypeIPv4.Value()); err != nil { + return nil, err + } + if err := binary.Write(buffer, binary.BigEndian, []byte{0, 0, 0, 0}); err != nil { + return nil, err + } + if err := binary.Write(buffer, binary.BigEndian, uint16(0)); err != nil { + return nil, err + } } - return buf, nil + return buffer.Bytes(), nil } diff --git a/socks.go b/socks.go index f3c2d79..ae930f3 100644 --- a/socks.go +++ b/socks.go @@ -10,7 +10,7 @@ import ( "sync" ) -func (p *Proxy) handle(conn io.ReadWriteCloser) { +func (p *Proxy) handle(conn net.Conn) { defer conn.Close() defer func() { p.Log.Debug("client connection closed") @@ -19,11 +19,8 @@ func (p *Proxy) handle(conn io.ReadWriteCloser) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if c, ok := conn.(net.Conn); ok { - p.Log.Debugf("got connection from %s", c.RemoteAddr().String()) - } else { - p.Log.Debug("got connection") - } + p.Log.Debugf("got connection from %s", conn.RemoteAddr().String()) + if err := p.socks(ctx, conn); err != nil { // send error reply p.Log.Errorf("socks error: %v", err.Err) @@ -34,10 +31,10 @@ func (p *Proxy) handle(conn io.ReadWriteCloser) { } } -func (p *Proxy) socks(ctx context.Context, conn io.ReadWriteCloser) *Error { +func (p *Proxy) socks(ctx context.Context, conn net.Conn) *Error { defer func() { - if err := p.Proxyhandler.Cleanup(); err != nil { - p.Log.Errorf("error on cleanup: %v", err) + if err := p.Proxyhandler.Close(); err != nil { + p.Log.Errorf("error on close: %v", err) } }() @@ -50,22 +47,17 @@ func (p *Proxy) socks(ctx context.Context, conn io.ReadWriteCloser) *Error { return err } - p.Log.Infof("Connecting to %s", request.getDestinationString()) + p.Log.Infof("Connecting to %s", request.GetDestinationString()) // Should we assume connection succeed here? - remote, err := p.Proxyhandler.PreHandler(*request) + remote, err := p.Proxyhandler.Init(*request) if err != nil { return err } defer remote.Close() + p.Log.Infof("Connection established %s - %s", conn.RemoteAddr().String(), request.GetDestinationString()) - var ip net.Addr - if r, ok := remote.(net.Conn); ok { - ip = r.LocalAddr() - } else { - ip = nil - } - err = p.handleRequestReply(ctx, conn, ip) + err = p.handleRequestReply(ctx, conn, request) if err != nil { return err } @@ -88,10 +80,10 @@ func (p *Proxy) socks(ctx context.Context, conn io.ReadWriteCloser) *Error { // stop refreshing the connection cancel() if err := <-errChannel1; err != nil { - return &Error{Reason: RequestReplyHostUnreachable, Err: err} + return NewError(RequestReplyHostUnreachable, err) } if err := <-errChannel2; err != nil { - return &Error{Reason: RequestReplyHostUnreachable, Err: err} + return NewError(RequestReplyHostUnreachable, err) } p.Log.Debug("end of connection handling") @@ -107,7 +99,7 @@ func (p *Proxy) copyClientToRemote(ctx context.Context, client io.ReadCloser, re errChannel <- nil return default: - if err := p.Proxyhandler.CopyFromClientToRemote(ctx, client, remote); err != nil { + if err := p.Proxyhandler.ReadFromClient(ctx, client, remote); err != nil { errChannel <- fmt.Errorf("error on copy from Client to Remote: %v", err) return } @@ -125,7 +117,7 @@ func (p *Proxy) copyRemoteToClient(ctx context.Context, remote io.ReadCloser, cl errChannel <- nil return default: - if err := p.Proxyhandler.CopyFromRemoteToClient(ctx, remote, client); err != nil { + if err := p.Proxyhandler.ReadFromRemote(ctx, remote, client); err != nil { errChannel <- fmt.Errorf("error on copy from Remote to Client: %v", err) return } @@ -151,18 +143,18 @@ func (p *Proxy) socksErrorReply(ctx context.Context, conn io.ReadWriteCloser, re func (p *Proxy) handleConnect(ctx context.Context, conn io.ReadWriteCloser) *Error { buf, err := connectionRead(ctx, conn, p.Timeout) if err != nil { - return &Error{Reason: RequestReplyConnectionRefused, Err: err} + return NewError(RequestReplyConnectionRefused, err) } header, err := parseHeader(buf) if err != nil { - return &Error{Reason: RequestReplyConnectionRefused, Err: err} + return NewError(RequestReplyConnectionRefused, err) } switch header.Version { case Version4: - return &Error{Reason: RequestReplyCommandNotSupported, Err: fmt.Errorf("socks4 not yet implemented")} + return NewError(RequestReplyCommandNotSupported, fmt.Errorf("socks4 not yet implemented")) case Version5: default: - return &Error{Reason: RequestReplyCommandNotSupported, Err: fmt.Errorf("version %#x not yet implemented", byte(header.Version))} + return NewError(RequestReplyCommandNotSupported, fmt.Errorf("version %#x not yet implemented", byte(header.Version))) } methodSupported := false @@ -173,14 +165,14 @@ func (p *Proxy) handleConnect(ctx context.Context, conn io.ReadWriteCloser) *Err } } if !methodSupported { - return &Error{Reason: RequestReplyMethodNotSupported, Err: fmt.Errorf("we currently only support no authentication")} + return NewError(RequestReplyMethodNotSupported, fmt.Errorf("we currently only support no authentication")) } reply := make([]byte, 2) reply[0] = byte(Version5) reply[1] = byte(MethodNoAuthRequired) err = connectionWrite(ctx, conn, reply, p.Timeout) if err != nil { - return &Error{Reason: RequestReplyGeneralFailure, Err: fmt.Errorf("could not send connect reply: %w", err)} + return NewError(RequestReplyGeneralFailure, fmt.Errorf("could not send connect reply: %w", err)) } return nil } @@ -188,7 +180,7 @@ func (p *Proxy) handleConnect(ctx context.Context, conn io.ReadWriteCloser) *Err func (p *Proxy) handleRequest(ctx context.Context, conn io.ReadWriteCloser) (*Request, *Error) { buf, err := connectionRead(ctx, conn, p.Timeout) if err != nil { - return nil, &Error{Reason: RequestReplyGeneralFailure, Err: fmt.Errorf("error on ConnectionRead: %w", err)} + return nil, NewError(RequestReplyGeneralFailure, fmt.Errorf("error on ConnectionRead: %w", err)) } request, err2 := parseRequest(buf) if err2 != nil { @@ -197,14 +189,14 @@ func (p *Proxy) handleRequest(ctx context.Context, conn io.ReadWriteCloser) (*Re return request, nil } -func (p *Proxy) handleRequestReply(ctx context.Context, conn io.ReadWriteCloser, addr net.Addr) *Error { - repl, err := requestReply(addr, RequestReplySucceeded) +func (p *Proxy) handleRequestReply(ctx context.Context, conn io.ReadWriteCloser, request *Request) *Error { + repl, err := requestReply(request, RequestReplySucceeded) if err != nil { - return &Error{Reason: RequestReplyGeneralFailure, Err: fmt.Errorf("error on requestReply: %w", err)} + return NewError(RequestReplyGeneralFailure, fmt.Errorf("error on requestReply: %w", err)) } err = connectionWrite(ctx, conn, repl, p.Timeout) if err != nil { - return &Error{Reason: RequestReplyGeneralFailure, Err: fmt.Errorf("error on RequestResponse: %w", err)} + return NewError(RequestReplyGeneralFailure, fmt.Errorf("error on RequestResponse: %w", err)) } return nil diff --git a/types.go b/types.go index 50faac5..2d722f5 100644 --- a/types.go +++ b/types.go @@ -21,7 +21,7 @@ type Request struct { DestinationPort uint16 } -func (r Request) getDestinationString() string { +func (r Request) GetDestinationString() string { switch r.AddressType { case RequestAddressTypeDomainname: return fmt.Sprintf("%s:%d", r.DestinationAddress, r.DestinationPort) @@ -142,3 +142,7 @@ type Error struct { // Error returns the underying error string func (e *Error) Error() string { return e.Err.Error() } + +func NewError(reason RequestReplyReason, err error) *Error { + return &Error{Reason: reason, Err: err} +}