diff --git a/test/scenario/proxy_test.go b/test/scenario/proxy_test.go index 45b9897e3..1630f57b3 100644 --- a/test/scenario/proxy_test.go +++ b/test/scenario/proxy_test.go @@ -275,6 +275,60 @@ websocket: } } +func TestPluginWebsocket(t *testing.T) { + serverPort := common.PickPort("tcp", "127.0.0.1") + socksPort := common.PickPort("tcp", "127.0.0.1") + + clientData := fmt.Sprintf(` +run-type: client +local-addr: 127.0.0.1 +local-port: %d +remote-addr: 127.0.0.1 +remote-port: %d +password: + - password +transport-plugin: + enabled: true + type: plaintext +shadowsocks: + enabled: true + method: AEAD_CHACHA20_POLY1305 + password: 12345678 +mux: + enabled: true +websocket: + enabled: true + path: /ws + hostname: 127.0.0.1 +`, socksPort, serverPort) + serverData := fmt.Sprintf(` +run-type: server +local-addr: 127.0.0.1 +local-port: %d +remote-addr: 127.0.0.1 +remote-port: %s +disable-http-check: true +password: + - password +transport-plugin: + enabled: true + type: plaintext +shadowsocks: + enabled: true + method: AEAD_CHACHA20_POLY1305 + password: 12345678 +websocket: + enabled: true + path: /ws + hostname: 127.0.0.1 +`, serverPort, util.HTTPPort) + + if !CheckClientServer(clientData, serverData, socksPort) { + t.Fail() + } + +} + func TestForward(t *testing.T) { serverPort := common.PickPort("tcp", "127.0.0.1") clientPort := common.PickPort("tcp", "127.0.0.1") diff --git a/tunnel/transport/server.go b/tunnel/transport/server.go index f72b483b3..c66de1b41 100644 --- a/tunnel/transport/server.go +++ b/tunnel/transport/server.go @@ -70,6 +70,7 @@ func (s *Server) acceptLoop() { } return } + log.Info("tcp connection from", tcpConn.RemoteAddr()) go func(tcpConn net.Conn) { var transportConn net.Conn if s.plugin { @@ -96,12 +97,12 @@ func (s *Server) acceptLoop() { // ------------------------ WAR ZONE ---------------------------- - rewindConn := common.NewRewindConn(tcpConn) - rewindConn.SetBufferSize(2048) + handshakeRewindConn := common.NewRewindConn(tcpConn) + handshakeRewindConn.SetBufferSize(2048) - tlsConn := tls.Server(rewindConn, tlsConfig) + tlsConn := tls.Server(handshakeRewindConn, tlsConfig) err = tlsConn.Handshake() - rewindConn.StopBuffering() + handshakeRewindConn.StopBuffering() if err != nil { if !sniVerified { @@ -110,18 +111,18 @@ func (s *Server) acceptLoop() { log.Error(common.NewError("tls client hello with wrong sni").Base(err)) } else if strings.Contains(err.Error(), "first record does not look like a TLS handshake") { // not a valid tls client hello - rewindConn.Rewind() + handshakeRewindConn.Rewind() log.Error(common.NewError("failed to perform tls handshake with " + tlsConn.RemoteAddr().String() + ", redirecting").Base(err)) if s.fallbackAddress != nil { s.redir.Redirect(&redirector.Redirection{ - InboundConn: rewindConn, + InboundConn: handshakeRewindConn, RedirectTo: s.fallbackAddress, }) } else if s.httpResp != nil { - rewindConn.Write(s.httpResp) - rewindConn.Close() + handshakeRewindConn.Write(s.httpResp) + handshakeRewindConn.Close() } else { - rewindConn.Close() + handshakeRewindConn.Close() } } else { // other cases, simply close it @@ -137,22 +138,22 @@ func (s *Server) acceptLoop() { } // we use real http header parser to mimic a real http server - tlsRewindConn := common.NewRewindConn(transportConn) - tlsRewindConn.SetBufferSize(512) - defer tlsRewindConn.StopBuffering() - r := bufio.NewReader(tlsRewindConn) + rewindConn := common.NewRewindConn(transportConn) + rewindConn.SetBufferSize(512) + defer rewindConn.StopBuffering() + r := bufio.NewReader(rewindConn) httpReq, err := http.ReadRequest(r) - tlsRewindConn.Rewind() + rewindConn.Rewind() if err != nil { // this is not a http request, pass it to trojan protocol layer for further inspection s.connChan <- &Conn{ - Conn: tlsRewindConn, + Conn: rewindConn, } } else { // this is a http request, pass it to websocket protocol layer log.Debug("http req: ", httpReq) s.wsChan <- &Conn{ - Conn: tlsRewindConn, + Conn: rewindConn, } } }(tcpConn) @@ -240,6 +241,7 @@ func NewServer(ctx context.Context, _ tunnel.Server) (*Server, error) { } server := &Server{ connChan: make(chan tunnel.Conn, 32), + wsChan: make(chan tunnel.Conn, 32), tcpListener: tcpListener, redir: redirector.NewRedirector(ctx), cmd: cmd,