diff --git a/common/common.go b/common/common.go index a862c9281..4fbdfa22b 100644 --- a/common/common.go +++ b/common/common.go @@ -1,8 +1,10 @@ package common import ( + "bufio" "crypto/sha256" "fmt" + "io" ) type Runnable interface { @@ -10,6 +12,10 @@ type Runnable interface { Close() error } +func NewBufReadWriter(rw io.ReadWriter) *bufio.ReadWriter { + return bufio.NewReadWriter(bufio.NewReader(rw), bufio.NewWriter(rw)) +} + func SHA224String(password string) string { hash := sha256.New224() hash.Write([]byte(password)) diff --git a/protocol/direct/outbound.go b/protocol/direct/outbound.go index b51f35354..2f4e0cbbe 100644 --- a/protocol/direct/outbound.go +++ b/protocol/direct/outbound.go @@ -1,6 +1,7 @@ package direct import ( + "bufio" "context" "io" "net" @@ -12,16 +13,18 @@ import ( type DirectOutboundConnSession struct { protocol.ConnSession - conn io.ReadWriteCloser - request *protocol.Request + conn io.ReadWriteCloser + bufReadWriter *bufio.ReadWriter + request *protocol.Request } func (o *DirectOutboundConnSession) Read(p []byte) (int, error) { - return o.conn.Read(p) + return o.bufReadWriter.Read(p) } func (o *DirectOutboundConnSession) Write(p []byte) (int, error) { - return o.conn.Write(p) + defer o.bufReadWriter.Flush() + return o.bufReadWriter.Write(p) } func (o *DirectOutboundConnSession) Close() error { @@ -37,6 +40,7 @@ func NewOutboundConnSession(conn io.ReadWriteCloser, req *protocol.Request) (pro return nil, err } o.conn = newConn + o.bufReadWriter = common.NewBufReadWriter(newConn) } else { o.conn = conn } diff --git a/protocol/trojan/inbound.go b/protocol/trojan/inbound.go index 254ee86f0..3f15f1196 100644 --- a/protocol/trojan/inbound.go +++ b/protocol/trojan/inbound.go @@ -11,23 +11,24 @@ import ( type TrojanInboundConnSession struct { protocol.ConnSession - config *conf.GlobalConfig - request *protocol.Request - bufReader *bufio.Reader - conn net.Conn - uploaded int - downloaded int - userHash string + config *conf.GlobalConfig + request *protocol.Request + bufReadWriter *bufio.ReadWriter + conn net.Conn + uploaded int + downloaded int + userHash string } func (i *TrojanInboundConnSession) Write(p []byte) (int, error) { - n, err := i.conn.Write(p) + n, err := i.bufReadWriter.Write(p) + i.bufReadWriter.Flush() i.uploaded += n return n, err } func (i *TrojanInboundConnSession) Read(p []byte) (int, error) { - n, err := i.bufReader.Read(p) + n, err := i.bufReadWriter.Read(p) i.downloaded += n return n, err } @@ -42,7 +43,7 @@ func (i *TrojanInboundConnSession) GetRequest() *protocol.Request { } func (i *TrojanInboundConnSession) parseRequest() error { - userHash, err := i.bufReader.Peek(56) + userHash, err := i.bufReadWriter.Peek(56) if err != nil { return common.NewError("failed to read hash").Base(err) } @@ -57,9 +58,9 @@ func (i *TrojanInboundConnSession) parseRequest() error { logger.Warn("invalid hash or other protocol:", string(userHash)) return nil } - i.bufReader.Discard(56 + 2) + i.bufReadWriter.Discard(56 + 2) - cmd, err := i.bufReader.ReadByte() + cmd, err := i.bufReadWriter.ReadByte() network := "tcp" switch protocol.Command(cmd) { case protocol.Connect, protocol.Mux: @@ -73,7 +74,7 @@ func (i *TrojanInboundConnSession) parseRequest() error { return common.NewError("failed to read cmd").Base(err) } - req, err := protocol.ParseAddress(i.bufReader) + req, err := protocol.ParseAddress(i.bufReadWriter) if err != nil { return common.NewError("failed to parse address").Base(err) } @@ -81,15 +82,15 @@ func (i *TrojanInboundConnSession) parseRequest() error { req.NetworkType = network i.request = req - i.bufReader.Discard(2) + i.bufReadWriter.Discard(2) return nil } func NewInboundConnSession(conn net.Conn, config *conf.GlobalConfig) (protocol.ConnSession, error) { i := &TrojanInboundConnSession{ - config: config, - conn: conn, - bufReader: bufio.NewReader(conn), + config: config, + conn: conn, + bufReadWriter: common.NewBufReadWriter(conn), } if err := i.parseRequest(); err != nil { return nil, err diff --git a/protocol/trojan/outbound.go b/protocol/trojan/outbound.go index bd93cc023..d459f1d58 100644 --- a/protocol/trojan/outbound.go +++ b/protocol/trojan/outbound.go @@ -14,19 +14,21 @@ type TrojanOutboundConnSession struct { protocol.ConnSession config *conf.GlobalConfig conn io.ReadWriteCloser + bufReader *bufio.ReadWriter request *protocol.Request uploaded int downloaded int } func (o *TrojanOutboundConnSession) Write(p []byte) (int, error) { - n, err := o.conn.Write(p) + n, err := o.bufReader.Write(p) + o.bufReader.Flush() o.uploaded += n return n, err } func (o *TrojanOutboundConnSession) Read(p []byte) (int, error) { - n, err := o.conn.Read(p) + n, err := o.bufReader.Read(p) o.downloaded += n return n, err } @@ -70,9 +72,10 @@ func NewOutboundConnSession(req *protocol.Request, conn io.ReadWriteCloser, conf conn = tlsConn } o := &TrojanOutboundConnSession{ - request: req, - config: config, - conn: conn, + request: req, + config: config, + conn: conn, + bufReader: common.NewBufReadWriter(conn), } if err := o.writeRequest(); err != nil { return nil, common.NewError("failed to write request").Base(err)