From 2c63c6026c9a3c072e16daa0af954edf0645b274 Mon Sep 17 00:00:00 2001 From: p4gefau1t Date: Thu, 9 Apr 2020 10:34:07 -0400 Subject: [PATCH] add tls over ws --- conf/conf.go | 1 + protocol/trojan/inbound.go | 93 +---------------- protocol/trojan/outbound.go | 13 +-- protocol/trojan/websocket.go | 193 +++++++++++++++++++++++++++++++++++ test/proxy_test.go | 3 + 5 files changed, 205 insertions(+), 98 deletions(-) create mode 100644 protocol/trojan/websocket.go diff --git a/conf/conf.go b/conf/conf.go index 418fc9aaa..4d645345f 100644 --- a/conf/conf.go +++ b/conf/conf.go @@ -100,6 +100,7 @@ type WebsocketConfig struct { Enabled bool `json:"enabled"` HostName string `json:"hostname"` Path string `json:"path"` + Password string `json:"password"` } type GlobalConfig struct { diff --git a/protocol/trojan/inbound.go b/protocol/trojan/inbound.go index 5dcbdc1bc..7d34bb090 100644 --- a/protocol/trojan/inbound.go +++ b/protocol/trojan/inbound.go @@ -6,15 +6,12 @@ import ( "context" "io" "net" - "net/http" - "time" "github.com/p4gefau1t/trojan-go/common" "github.com/p4gefau1t/trojan-go/conf" "github.com/p4gefau1t/trojan-go/log" "github.com/p4gefau1t/trojan-go/protocol" "github.com/p4gefau1t/trojan-go/stat" - "golang.org/x/net/websocket" ) type TrojanInboundConnSession struct { @@ -109,88 +106,6 @@ func (i *TrojanInboundConnSession) parseRequest() error { return nil } -//Fake response writer -//Websocket ServeHTTP method uses its Hijack method to get the Readwriter -type wsHttpResponseWriter struct { - http.Hijacker - http.ResponseWriter - - ReadWriter *bufio.ReadWriter - Conn net.Conn -} - -func (w *wsHttpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return w.Conn, w.ReadWriter, nil -} - -func (i *TrojanInboundConnSession) parseWebsocket() (bool, error) { - correct := "GET " + i.config.Websocket.Path - first, err := i.bufReadWriter.Peek(len(correct)) - if err != nil { - return false, err - } - if !bytes.Equal([]byte(correct), first) { - //it may be a normal trojan conn - log.Debug("not a ws conn", string(first)) - return true, common.NewError("invalid header") - } - - httpRequest, err := http.ReadRequest(i.bufReadWriter.Reader) - if err != nil { - //malformed http request - return false, err - } - - url := "wss://" + i.config.Websocket.HostName + i.config.Websocket.Path - origin := "https://" + i.config.Websocket.HostName - wsConfig, err := websocket.NewConfig(url, origin) - - if httpRequest.URL.String() != i.config.Websocket.Path { - log.Error("invalid websocket path, url", httpRequest.URL, "origin", httpRequest.Header.Get("Origin")) - i.readBytes = bytes.NewBuffer([]byte{}) - httpRequest.Write(i.readBytes) - return false, common.NewError("invalid url") - } - - handshaked := make(chan struct{}) - - var wsConn *websocket.Conn - wsServer := websocket.Server{ - Config: *wsConfig, - Handler: func(conn *websocket.Conn) { - wsConn = conn //store the websocket after handshaking - log.Debug("websocket obtained") - handshaked <- struct{}{} - //this function will NOT return unless the connection is ended - //or the websocket will be closed by ServeHTTP method - <-i.ctx.Done() - }, - Handshake: func(wsConfig *websocket.Config, httpRequest *http.Request) error { - log.Debug("websocket url", httpRequest.URL, "origin", httpRequest.Header.Get("Origin")) - return nil - }, - } - - responseWriter := &wsHttpResponseWriter{ - Conn: i.conn.(net.Conn), - ReadWriter: i.bufReadWriter, - } - go wsServer.ServeHTTP(responseWriter, httpRequest) - - select { - case <-handshaked: - case <-time.After(protocol.TCPTimeout): - } - - if wsConn == nil { - return false, common.NewError("failed to perform websocket handshake") - } - //setup new readwriter - i.conn = wsConn - i.bufReadWriter = common.NewBufReadWriter(wsConn) - return true, nil -} - func (i *TrojanInboundConnSession) SetAuth(auth stat.Authenticator) { i.auth = auth } @@ -212,11 +127,13 @@ func NewInboundConnSession(conn net.Conn, config *conf.GlobalConfig, auth stat.A cancel: cancel, } if i.config.Websocket.Enabled { - validConn, err := i.parseWebsocket() - if err == nil { + ws, err := NewInboundWebsocket(conn, i.bufReadWriter, i.ctx, config) + if ws != nil { log.Debug("websocket conn") + i.conn = ws + i.bufReadWriter = common.NewBufReadWriter(ws) } - if !validConn { + if err != nil { //no need to continue parsing i.request = &protocol.Request{ IP: i.config.RemoteIP, diff --git a/protocol/trojan/outbound.go b/protocol/trojan/outbound.go index 8e41cb55a..df7507803 100644 --- a/protocol/trojan/outbound.go +++ b/protocol/trojan/outbound.go @@ -9,7 +9,6 @@ import ( "github.com/p4gefau1t/trojan-go/conf" "github.com/p4gefau1t/trojan-go/log" "github.com/p4gefau1t/trojan-go/protocol" - "golang.org/x/net/websocket" ) type TrojanOutboundConnSession struct { @@ -84,17 +83,11 @@ func NewOutboundConnSession(req *protocol.Request, conn io.ReadWriteCloser, conf } conn = tlsConn if config.Websocket.Enabled { - url := "wss://" + config.Websocket.HostName + config.Websocket.Path - origin := "https://" + config.Websocket.HostName - config, err := websocket.NewConfig(url, origin) + ws, err := NewOutboundWebosocket(tlsConn, config) if err != nil { - return nil, err + return nil, common.NewError("failed to start websocket connection").Base(err) } - wsConn, err := websocket.NewClient(config, conn) - if err != nil { - return nil, err - } - conn = wsConn + conn = ws } } o := &TrojanOutboundConnSession{ diff --git a/protocol/trojan/websocket.go b/protocol/trojan/websocket.go new file mode 100644 index 000000000..940a18e06 --- /dev/null +++ b/protocol/trojan/websocket.go @@ -0,0 +1,193 @@ +package trojan + +import ( + "bufio" + "bytes" + "context" + "crypto/aes" + "crypto/cipher" + "crypto/md5" + "crypto/rand" + "crypto/tls" + "io" + "net" + "net/http" + "time" + + "github.com/p4gefau1t/trojan-go/common" + "github.com/p4gefau1t/trojan-go/conf" + "github.com/p4gefau1t/trojan-go/log" + "github.com/p4gefau1t/trojan-go/protocol" + "golang.org/x/net/websocket" +) + +//this AES layer is used for obfuscation purpose +type obfReadWriteCloser struct { + *websocket.Conn + r cipher.StreamReader + w cipher.StreamWriter +} + +func (rwc *obfReadWriteCloser) Read(p []byte) (int, error) { + return rwc.r.Read(p) +} + +func (rwc *obfReadWriteCloser) Write(p []byte) (int, error) { + return rwc.w.Write(p) +} + +func (rwc *obfReadWriteCloser) Close() error { + return rwc.Conn.Close() +} + +func NewObfReadWriteCloser(password string, conn *websocket.Conn, iv []byte) *obfReadWriteCloser { + md5Hash := md5.New() + md5Hash.Write([]byte(password)) + key := md5Hash.Sum(nil) + block, err := aes.NewCipher(key) + common.Must(err) + return &obfReadWriteCloser{ + Conn: conn, + r: cipher.StreamReader{ + S: cipher.NewCTR(block, iv), + R: conn, + }, + w: cipher.StreamWriter{ + S: cipher.NewCTR(block, iv), + W: conn, + }, + } +} + +//Fake response writer +//Websocket ServeHTTP method uses its Hijack method to get the Readwriter +type wsHttpResponseWriter struct { + http.Hijacker + http.ResponseWriter + + ReadWriter *bufio.ReadWriter + Conn net.Conn +} + +func (w *wsHttpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return w.Conn, w.ReadWriter, nil +} + +func NewOutboundWebosocket(conn net.Conn, config *conf.GlobalConfig) (io.ReadWriteCloser, error) { + url := "wss://" + config.Websocket.HostName + config.Websocket.Path + origin := "https://" + config.Websocket.HostName + wsConfig, err := websocket.NewConfig(url, origin) + if err != nil { + return nil, err + } + wsConn, err := websocket.NewClient(wsConfig, conn) + if err != nil { + return nil, err + } + tlsConfig := &tls.Config{ + CipherSuites: config.TLS.CipherSuites, + RootCAs: config.TLS.CertPool, + ServerName: config.TLS.SNI, + SessionTicketsDisabled: !config.TLS.SessionTicket, + ClientSessionCache: tls.NewLRUClientSessionCache(-1), + //InsecureSkipVerify: !config.TLS.Verify, //must verify it + } + var transport net.Conn = wsConn + if config.Websocket.Password != "" { + iv := [aes.BlockSize]byte{} + rand.Reader.Read(iv[:]) + wsConn.Write(iv[:]) + transport = NewObfReadWriteCloser(config.Websocket.Password, wsConn, iv[:]) + } + tlsConn := tls.Client(transport, tlsConfig) + if err := tlsConn.Handshake(); err != nil { + return nil, err + } + if config.LogLevel == 0 { + state := tlsConn.ConnectionState() + chain := state.VerifiedChains + log.Debug("Websocket TLS handshaked", "cipher:", tls.CipherSuiteName(state.CipherSuite)) + for i := range chain { + for j := range chain[i] { + log.Debug("subject:", chain[i][j].Subject, ", issuer:", chain[i][j].Issuer) + } + } + } + return tlsConn, nil +} + +func NewInboundWebsocket(conn io.ReadWriteCloser, rw *bufio.ReadWriter, ctx context.Context, config *conf.GlobalConfig) (io.ReadWriteCloser, error) { + correct := "GET " + config.Websocket.Path + " HTTP/1.1\r\n" + first, err := rw.Peek(len(correct)) + if err != nil { + return nil, err + } + if !bytes.Equal([]byte(correct), first) { + //it may be a normal trojan conn + log.Debug("not a ws conn", string(first)) + return nil, nil + } + + httpRequest, err := http.ReadRequest(rw.Reader) + if err != nil { + //malformed http request + return nil, err + } + + url := "wss://" + config.Websocket.HostName + config.Websocket.Path + origin := "https://" + config.Websocket.HostName + wsConfig, err := websocket.NewConfig(url, origin) + + handshaked := make(chan struct{}) + + var wsConn *websocket.Conn + wsServer := websocket.Server{ + Config: *wsConfig, + Handler: func(conn *websocket.Conn) { + wsConn = conn //store the websocket after handshaking + log.Debug("websocket obtained") + handshaked <- struct{}{} + //this function will NOT return unless the connection is ended + //or the websocket will be closed by ServeHTTP method + <-ctx.Done() + }, + Handshake: func(wsConfig *websocket.Config, httpRequest *http.Request) error { + log.Debug("websocket url", httpRequest.URL, "origin", httpRequest.Header.Get("Origin")) + return nil + }, + } + + responseWriter := &wsHttpResponseWriter{ + Conn: conn.(net.Conn), + ReadWriter: rw, + } + go wsServer.ServeHTTP(responseWriter, httpRequest) + + select { + case <-handshaked: + case <-time.After(protocol.TCPTimeout): + } + + if wsConn == nil { + return nil, common.NewError("failed to perform websocket handshake") + } + + var transport net.Conn = wsConn + tlsConfig := &tls.Config{ + Certificates: config.TLS.KeyPair, + CipherSuites: config.TLS.CipherSuites, + PreferServerCipherSuites: config.TLS.PreferServerCipher, + SessionTicketsDisabled: !config.TLS.SessionTicket, + } + if config.Websocket.Password != "" { + iv := [aes.BlockSize]byte{} + rand.Reader.Read(iv[:]) + wsConn.Read(iv[:]) + transport = NewObfReadWriteCloser(config.Websocket.Password, wsConn, iv[:]) + } + tlsConn := tls.Server(transport, tlsConfig) + if err := tlsConn.Handshake(); err != nil { + return nil, err + } + return tlsConn, nil +} diff --git a/test/proxy_test.go b/test/proxy_test.go index 852982e83..e0399568c 100644 --- a/test/proxy_test.go +++ b/test/proxy_test.go @@ -239,6 +239,7 @@ func TestWebsocketClient(t *testing.T) { Enabled: true, HostName: "127.0.0.1", Path: "/websocket", + Password: "testpassword", }, } c := client.Client{} @@ -260,6 +261,7 @@ func TestWebsocketMuxClient(t *testing.T) { Enabled: true, HostName: "127.0.0.1", Path: "/websocket", + Password: "testpassword", }, Mux: conf.MuxConfig{ Enabled: true, @@ -286,6 +288,7 @@ func TestWebsocketServer(t *testing.T) { Enabled: true, HostName: "127.0.0.1", Path: "/websocket", + Password: "testpassword", }, } s := server.Server{}