From 2c58592d8c6b8349dfed943ba8169db2ab8b90dc Mon Sep 17 00:00:00 2001 From: Loyalsoldier <10487845+Loyalsoldier@users.noreply.github.com> Date: Mon, 3 May 2021 06:19:57 +0800 Subject: [PATCH] Fix: data race (#317) --- common/io.go | 15 ++++++- test/util/util.go | 86 +++++++++++++++++++------------------- tunnel/adapter/server.go | 6 +++ tunnel/transport/server.go | 6 +++ 4 files changed, 69 insertions(+), 44 deletions(-) diff --git a/common/io.go b/common/io.go index b80b641a9..0a1a952c5 100644 --- a/common/io.go +++ b/common/io.go @@ -3,11 +3,13 @@ package common import ( "io" "net" + "sync" "github.com/p4gefau1t/trojan-go/log" ) type RewindReader struct { + mu sync.Mutex rawReader io.Reader buf []byte bufReadIdx int @@ -17,13 +19,16 @@ type RewindReader struct { } func (r *RewindReader) Read(p []byte) (int, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.rewound { if len(r.buf) > r.bufReadIdx { n := copy(p, r.buf[r.bufReadIdx:]) r.bufReadIdx += n return n, nil } - r.rewound = false //all buffering content has been read + r.rewound = false // all buffering content has been read } n, err := r.rawReader.Read(p) if r.buffering { @@ -59,19 +64,24 @@ func (r *RewindReader) Discard(n int) (int, error) { } func (r *RewindReader) Rewind() { + r.mu.Lock() if r.bufferSize == 0 { panic("no buffer") } r.rewound = true r.bufReadIdx = 0 + r.mu.Unlock() } func (r *RewindReader) StopBuffering() { + r.mu.Lock() r.buffering = false + r.mu.Unlock() } func (r *RewindReader) SetBufferSize(size int) { - if size == 0 { //disable buffering + r.mu.Lock() + if size == 0 { // disable buffering if !r.buffering { panic("reader is disabled") } @@ -88,6 +98,7 @@ func (r *RewindReader) SetBufferSize(size int) { r.bufferSize = size r.buf = make([]byte, 0, size) } + r.mu.Unlock() } type RewindConn struct { diff --git a/test/util/util.go b/test/util/util.go index 57329b43e..63e89fd50 100644 --- a/test/util/util.go +++ b/test/util/util.go @@ -12,30 +12,33 @@ import ( // CheckConn checks if two netConn were connected and work properly func CheckConn(a net.Conn, b net.Conn) bool { - payload1 := [1024]byte{} - payload2 := [1024]byte{} - rand.Reader.Read(payload1[:]) - rand.Reader.Read(payload2[:]) + payload1 := make([]byte, 1024) + payload2 := make([]byte, 1024) + + result1 := make([]byte, 1024) + result2 := make([]byte, 1024) + + rand.Reader.Read(payload1) + rand.Reader.Read(payload2) - result1 := [1024]byte{} - result2 := [1024]byte{} wg := sync.WaitGroup{} wg.Add(2) + go func() { - a.Write(payload1[:]) - a.Read(result2[:]) + a.Write(payload1) + a.Read(result2) wg.Done() }() + go func() { - b.Read(result1[:]) - b.Write(payload2[:]) + b.Read(result1) + b.Write(payload2) wg.Done() }() + wg.Wait() - if !bytes.Equal(payload1[:], result1[:]) || !bytes.Equal(payload2[:], result2[:]) { - return false - } - return true + + return bytes.Equal(payload1, result1) && bytes.Equal(payload2, result2) } // CheckPacketOverConn checks if two PacketConn streaming over a connection work properly @@ -45,55 +48,54 @@ func CheckPacketOverConn(a, b net.PacketConn) bool { IP: net.ParseIP("127.0.0.1"), Port: port, } - payload1 := [1024]byte{} - payload2 := [1024]byte{} - rand.Reader.Read(payload1[:]) - rand.Reader.Read(payload2[:]) - result1 := [1024]byte{} - result2 := [1024]byte{} + payload1 := make([]byte, 1024) + payload2 := make([]byte, 1024) + + result1 := make([]byte, 1024) + result2 := make([]byte, 1024) - common.Must2(a.WriteTo(payload1[:], addr)) - _, addr1, err := b.ReadFrom(result1[:]) + rand.Reader.Read(payload1) + rand.Reader.Read(payload2) + + common.Must2(a.WriteTo(payload1, addr)) + _, addr1, err := b.ReadFrom(result1) common.Must(err) if addr1.String() != addr.String() { return false } - common.Must2(a.WriteTo(payload2[:], addr)) - _, addr2, err := b.ReadFrom(result2[:]) + common.Must2(a.WriteTo(payload2, addr)) + _, addr2, err := b.ReadFrom(result2) common.Must(err) if addr2.String() != addr.String() { return false } - if !bytes.Equal(payload1[:], result1[:]) || !bytes.Equal(payload2[:], result2[:]) { - return false - } - return true + + return bytes.Equal(payload1, result1) && bytes.Equal(payload2, result2) } func CheckPacket(a, b net.PacketConn) bool { - payload1 := [1024]byte{} - payload2 := [1024]byte{} - rand.Reader.Read(payload1[:]) - rand.Reader.Read(payload2[:]) + payload1 := make([]byte, 1024) + payload2 := make([]byte, 1024) - result1 := [1024]byte{} - result2 := [1024]byte{} + result1 := make([]byte, 1024) + result2 := make([]byte, 1024) - _, err := a.WriteTo(payload1[:], b.LocalAddr()) + rand.Reader.Read(payload1) + rand.Reader.Read(payload2) + + _, err := a.WriteTo(payload1, b.LocalAddr()) common.Must(err) - _, _, err = b.ReadFrom(result1[:]) + _, _, err = b.ReadFrom(result1) common.Must(err) - _, err = b.WriteTo(payload2[:], a.LocalAddr()) + _, err = b.WriteTo(payload2, a.LocalAddr()) common.Must(err) - _, _, err = a.ReadFrom(result2[:]) + _, _, err = a.ReadFrom(result2) common.Must(err) - if !bytes.Equal(payload1[:], result1[:]) || !bytes.Equal(payload2[:], result2[:]) { - return false - } - return true + + return bytes.Equal(payload1, result1) && bytes.Equal(payload2, result2) } func GetTestAddr() string { diff --git a/tunnel/adapter/server.go b/tunnel/adapter/server.go index de4f4cd38..94fe42203 100644 --- a/tunnel/adapter/server.go +++ b/tunnel/adapter/server.go @@ -3,6 +3,7 @@ package adapter import ( "context" "net" + "sync" "github.com/p4gefau1t/trojan-go/common" "github.com/p4gefau1t/trojan-go/config" @@ -18,6 +19,7 @@ type Server struct { udpListener net.PacketConn socksConn chan tunnel.Conn httpConn chan tunnel.Conn + socksLock sync.RWMutex nextSocks bool ctx context.Context cancel context.CancelFunc @@ -45,7 +47,9 @@ func (s *Server) acceptConnLoop() { log.Error(common.NewError("failed to detect proxy protocol type").Base(err)) continue } + s.socksLock.RLock() if buf[0] == 5 && s.nextSocks { + s.socksLock.RUnlock() log.Debug("socks5 connection") s.socksConn <- &freedom.Conn{ Conn: rewindConn, @@ -68,7 +72,9 @@ func (s *Server) AcceptConn(overlay tunnel.Tunnel) (tunnel.Conn, error) { return nil, common.NewError("adapter closed") } } else if _, ok := overlay.(*socks.Tunnel); ok { + s.socksLock.Lock() s.nextSocks = true + s.socksLock.Unlock() select { case conn := <-s.socksConn: return conn, nil diff --git a/tunnel/transport/server.go b/tunnel/transport/server.go index 94e258fa6..421ad12bb 100644 --- a/tunnel/transport/server.go +++ b/tunnel/transport/server.go @@ -8,6 +8,7 @@ import ( "os" "os/exec" "strconv" + "sync" "time" "github.com/p4gefau1t/trojan-go/common" @@ -22,6 +23,7 @@ type Server struct { cmd *exec.Cmd connChan chan tunnel.Conn wsChan chan tunnel.Conn + httpLock sync.RWMutex nextHTTP bool ctx context.Context cancel context.CancelFunc @@ -50,7 +52,9 @@ func (s *Server) acceptLoop() { go func(tcpConn net.Conn) { log.Info("tcp connection from", tcpConn.RemoteAddr()) + s.httpLock.RLock() if s.nextHTTP { // plaintext mode enabled + s.httpLock.RUnlock() // we use real http header parser to mimic a real http server rewindConn := common.NewRewindConn(tcpConn) rewindConn.SetBufferSize(512) @@ -84,7 +88,9 @@ func (s *Server) acceptLoop() { func (s *Server) AcceptConn(overlay tunnel.Tunnel) (tunnel.Conn, error) { // TODO fix import cycle if overlay != nil && (overlay.Name() == "WEBSOCKET" || overlay.Name() == "HTTP") { + s.httpLock.Lock() s.nextHTTP = true + s.httpLock.Unlock() select { case conn := <-s.wsChan: return conn, nil