Skip to content

Commit

Permalink
use bufio
Browse files Browse the repository at this point in the history
  • Loading branch information
p4gefau1t committed Mar 20, 2020
1 parent 9284a9c commit 9138755
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 26 deletions.
6 changes: 6 additions & 0 deletions common/common.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
package common

import (
"bufio"
"crypto/sha256"
"fmt"
"io"
)

type Runnable interface {
Run() error
Close() error
}

func NewBufReadWriter(rw io.ReadWriter) *bufio.ReadWriter {
return bufio.NewReadWriter(bufio.NewReader(rw), bufio.NewWriter(rw))
}

func SHA224String(password string) string {
hash := sha256.New224()
hash.Write([]byte(password))
Expand Down
12 changes: 8 additions & 4 deletions protocol/direct/outbound.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package direct

import (
"bufio"
"context"
"io"
"net"
Expand All @@ -12,16 +13,18 @@ import (

type DirectOutboundConnSession struct {
protocol.ConnSession
conn io.ReadWriteCloser
request *protocol.Request
conn io.ReadWriteCloser
bufReadWriter *bufio.ReadWriter
request *protocol.Request
}

func (o *DirectOutboundConnSession) Read(p []byte) (int, error) {
return o.conn.Read(p)
return o.bufReadWriter.Read(p)
}

func (o *DirectOutboundConnSession) Write(p []byte) (int, error) {
return o.conn.Write(p)
defer o.bufReadWriter.Flush()
return o.bufReadWriter.Write(p)
}

func (o *DirectOutboundConnSession) Close() error {
Expand All @@ -37,6 +40,7 @@ func NewOutboundConnSession(conn io.ReadWriteCloser, req *protocol.Request) (pro
return nil, err
}
o.conn = newConn
o.bufReadWriter = common.NewBufReadWriter(newConn)
} else {
o.conn = conn
}
Expand Down
35 changes: 18 additions & 17 deletions protocol/trojan/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,24 @@ import (

type TrojanInboundConnSession struct {
protocol.ConnSession
config *conf.GlobalConfig
request *protocol.Request
bufReader *bufio.Reader
conn net.Conn
uploaded int
downloaded int
userHash string
config *conf.GlobalConfig
request *protocol.Request
bufReadWriter *bufio.ReadWriter
conn net.Conn
uploaded int
downloaded int
userHash string
}

func (i *TrojanInboundConnSession) Write(p []byte) (int, error) {
n, err := i.conn.Write(p)
n, err := i.bufReadWriter.Write(p)
i.bufReadWriter.Flush()
i.uploaded += n
return n, err
}

func (i *TrojanInboundConnSession) Read(p []byte) (int, error) {
n, err := i.bufReader.Read(p)
n, err := i.bufReadWriter.Read(p)
i.downloaded += n
return n, err
}
Expand All @@ -42,7 +43,7 @@ func (i *TrojanInboundConnSession) GetRequest() *protocol.Request {
}

func (i *TrojanInboundConnSession) parseRequest() error {
userHash, err := i.bufReader.Peek(56)
userHash, err := i.bufReadWriter.Peek(56)
if err != nil {
return common.NewError("failed to read hash").Base(err)
}
Expand All @@ -57,9 +58,9 @@ func (i *TrojanInboundConnSession) parseRequest() error {
logger.Warn("invalid hash or other protocol:", string(userHash))
return nil
}
i.bufReader.Discard(56 + 2)
i.bufReadWriter.Discard(56 + 2)

cmd, err := i.bufReader.ReadByte()
cmd, err := i.bufReadWriter.ReadByte()
network := "tcp"
switch protocol.Command(cmd) {
case protocol.Connect, protocol.Mux:
Expand All @@ -73,23 +74,23 @@ func (i *TrojanInboundConnSession) parseRequest() error {
return common.NewError("failed to read cmd").Base(err)
}

req, err := protocol.ParseAddress(i.bufReader)
req, err := protocol.ParseAddress(i.bufReadWriter)
if err != nil {
return common.NewError("failed to parse address").Base(err)
}
req.Command = protocol.Command(cmd)
req.NetworkType = network
i.request = req

i.bufReader.Discard(2)
i.bufReadWriter.Discard(2)
return nil
}

func NewInboundConnSession(conn net.Conn, config *conf.GlobalConfig) (protocol.ConnSession, error) {
i := &TrojanInboundConnSession{
config: config,
conn: conn,
bufReader: bufio.NewReader(conn),
config: config,
conn: conn,
bufReadWriter: common.NewBufReadWriter(conn),
}
if err := i.parseRequest(); err != nil {
return nil, err
Expand Down
13 changes: 8 additions & 5 deletions protocol/trojan/outbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,21 @@ type TrojanOutboundConnSession struct {
protocol.ConnSession
config *conf.GlobalConfig
conn io.ReadWriteCloser
bufReader *bufio.ReadWriter
request *protocol.Request
uploaded int
downloaded int
}

func (o *TrojanOutboundConnSession) Write(p []byte) (int, error) {
n, err := o.conn.Write(p)
n, err := o.bufReader.Write(p)
o.bufReader.Flush()
o.uploaded += n
return n, err
}

func (o *TrojanOutboundConnSession) Read(p []byte) (int, error) {
n, err := o.conn.Read(p)
n, err := o.bufReader.Read(p)
o.downloaded += n
return n, err
}
Expand Down Expand Up @@ -70,9 +72,10 @@ func NewOutboundConnSession(req *protocol.Request, conn io.ReadWriteCloser, conf
conn = tlsConn
}
o := &TrojanOutboundConnSession{
request: req,
config: config,
conn: conn,
request: req,
config: config,
conn: conn,
bufReader: common.NewBufReadWriter(conn),
}
if err := o.writeRequest(); err != nil {
return nil, common.NewError("failed to write request").Base(err)
Expand Down

0 comments on commit 9138755

Please sign in to comment.