Skip to content

Commit

Permalink
add tls over ws
Browse files Browse the repository at this point in the history
  • Loading branch information
p4gefau1t committed Apr 9, 2020
1 parent 9aaa346 commit 2c63c60
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 98 deletions.
1 change: 1 addition & 0 deletions conf/conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ type WebsocketConfig struct {
Enabled bool `json:"enabled"`
HostName string `json:"hostname"`
Path string `json:"path"`
Password string `json:"password"`
}

type GlobalConfig struct {
Expand Down
93 changes: 5 additions & 88 deletions protocol/trojan/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,12 @@ import (
"context"
"io"
"net"
"net/http"
"time"

"github.com/p4gefau1t/trojan-go/common"
"github.com/p4gefau1t/trojan-go/conf"
"github.com/p4gefau1t/trojan-go/log"
"github.com/p4gefau1t/trojan-go/protocol"
"github.com/p4gefau1t/trojan-go/stat"
"golang.org/x/net/websocket"
)

type TrojanInboundConnSession struct {
Expand Down Expand Up @@ -109,88 +106,6 @@ func (i *TrojanInboundConnSession) parseRequest() error {
return nil
}

//Fake response writer
//Websocket ServeHTTP method uses its Hijack method to get the Readwriter
type wsHttpResponseWriter struct {
http.Hijacker
http.ResponseWriter

ReadWriter *bufio.ReadWriter
Conn net.Conn
}

func (w *wsHttpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.Conn, w.ReadWriter, nil
}

func (i *TrojanInboundConnSession) parseWebsocket() (bool, error) {
correct := "GET " + i.config.Websocket.Path
first, err := i.bufReadWriter.Peek(len(correct))
if err != nil {
return false, err
}
if !bytes.Equal([]byte(correct), first) {
//it may be a normal trojan conn
log.Debug("not a ws conn", string(first))
return true, common.NewError("invalid header")
}

httpRequest, err := http.ReadRequest(i.bufReadWriter.Reader)
if err != nil {
//malformed http request
return false, err
}

url := "wss://" + i.config.Websocket.HostName + i.config.Websocket.Path
origin := "https://" + i.config.Websocket.HostName
wsConfig, err := websocket.NewConfig(url, origin)

if httpRequest.URL.String() != i.config.Websocket.Path {
log.Error("invalid websocket path, url", httpRequest.URL, "origin", httpRequest.Header.Get("Origin"))
i.readBytes = bytes.NewBuffer([]byte{})
httpRequest.Write(i.readBytes)
return false, common.NewError("invalid url")
}

handshaked := make(chan struct{})

var wsConn *websocket.Conn
wsServer := websocket.Server{
Config: *wsConfig,
Handler: func(conn *websocket.Conn) {
wsConn = conn //store the websocket after handshaking
log.Debug("websocket obtained")
handshaked <- struct{}{}
//this function will NOT return unless the connection is ended
//or the websocket will be closed by ServeHTTP method
<-i.ctx.Done()
},
Handshake: func(wsConfig *websocket.Config, httpRequest *http.Request) error {
log.Debug("websocket url", httpRequest.URL, "origin", httpRequest.Header.Get("Origin"))
return nil
},
}

responseWriter := &wsHttpResponseWriter{
Conn: i.conn.(net.Conn),
ReadWriter: i.bufReadWriter,
}
go wsServer.ServeHTTP(responseWriter, httpRequest)

select {
case <-handshaked:
case <-time.After(protocol.TCPTimeout):
}

if wsConn == nil {
return false, common.NewError("failed to perform websocket handshake")
}
//setup new readwriter
i.conn = wsConn
i.bufReadWriter = common.NewBufReadWriter(wsConn)
return true, nil
}

func (i *TrojanInboundConnSession) SetAuth(auth stat.Authenticator) {
i.auth = auth
}
Expand All @@ -212,11 +127,13 @@ func NewInboundConnSession(conn net.Conn, config *conf.GlobalConfig, auth stat.A
cancel: cancel,
}
if i.config.Websocket.Enabled {
validConn, err := i.parseWebsocket()
if err == nil {
ws, err := NewInboundWebsocket(conn, i.bufReadWriter, i.ctx, config)
if ws != nil {
log.Debug("websocket conn")
i.conn = ws
i.bufReadWriter = common.NewBufReadWriter(ws)
}
if !validConn {
if err != nil {
//no need to continue parsing
i.request = &protocol.Request{
IP: i.config.RemoteIP,
Expand Down
13 changes: 3 additions & 10 deletions protocol/trojan/outbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/p4gefau1t/trojan-go/conf"
"github.com/p4gefau1t/trojan-go/log"
"github.com/p4gefau1t/trojan-go/protocol"
"golang.org/x/net/websocket"
)

type TrojanOutboundConnSession struct {
Expand Down Expand Up @@ -84,17 +83,11 @@ func NewOutboundConnSession(req *protocol.Request, conn io.ReadWriteCloser, conf
}
conn = tlsConn
if config.Websocket.Enabled {
url := "wss://" + config.Websocket.HostName + config.Websocket.Path
origin := "https://" + config.Websocket.HostName
config, err := websocket.NewConfig(url, origin)
ws, err := NewOutboundWebosocket(tlsConn, config)
if err != nil {
return nil, err
return nil, common.NewError("failed to start websocket connection").Base(err)
}
wsConn, err := websocket.NewClient(config, conn)
if err != nil {
return nil, err
}
conn = wsConn
conn = ws
}
}
o := &TrojanOutboundConnSession{
Expand Down
193 changes: 193 additions & 0 deletions protocol/trojan/websocket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
package trojan

import (
"bufio"
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/md5"
"crypto/rand"
"crypto/tls"
"io"
"net"
"net/http"
"time"

"github.com/p4gefau1t/trojan-go/common"
"github.com/p4gefau1t/trojan-go/conf"
"github.com/p4gefau1t/trojan-go/log"
"github.com/p4gefau1t/trojan-go/protocol"
"golang.org/x/net/websocket"
)

//this AES layer is used for obfuscation purpose
type obfReadWriteCloser struct {
*websocket.Conn
r cipher.StreamReader
w cipher.StreamWriter
}

func (rwc *obfReadWriteCloser) Read(p []byte) (int, error) {
return rwc.r.Read(p)
}

func (rwc *obfReadWriteCloser) Write(p []byte) (int, error) {
return rwc.w.Write(p)
}

func (rwc *obfReadWriteCloser) Close() error {
return rwc.Conn.Close()
}

func NewObfReadWriteCloser(password string, conn *websocket.Conn, iv []byte) *obfReadWriteCloser {
md5Hash := md5.New()
md5Hash.Write([]byte(password))
key := md5Hash.Sum(nil)
block, err := aes.NewCipher(key)
common.Must(err)
return &obfReadWriteCloser{
Conn: conn,
r: cipher.StreamReader{
S: cipher.NewCTR(block, iv),
R: conn,
},
w: cipher.StreamWriter{
S: cipher.NewCTR(block, iv),
W: conn,
},
}
}

//Fake response writer
//Websocket ServeHTTP method uses its Hijack method to get the Readwriter
type wsHttpResponseWriter struct {
http.Hijacker
http.ResponseWriter

ReadWriter *bufio.ReadWriter
Conn net.Conn
}

func (w *wsHttpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.Conn, w.ReadWriter, nil
}

func NewOutboundWebosocket(conn net.Conn, config *conf.GlobalConfig) (io.ReadWriteCloser, error) {
url := "wss://" + config.Websocket.HostName + config.Websocket.Path
origin := "https://" + config.Websocket.HostName
wsConfig, err := websocket.NewConfig(url, origin)
if err != nil {
return nil, err
}
wsConn, err := websocket.NewClient(wsConfig, conn)
if err != nil {
return nil, err
}
tlsConfig := &tls.Config{
CipherSuites: config.TLS.CipherSuites,
RootCAs: config.TLS.CertPool,
ServerName: config.TLS.SNI,
SessionTicketsDisabled: !config.TLS.SessionTicket,
ClientSessionCache: tls.NewLRUClientSessionCache(-1),
//InsecureSkipVerify: !config.TLS.Verify, //must verify it
}
var transport net.Conn = wsConn
if config.Websocket.Password != "" {
iv := [aes.BlockSize]byte{}
rand.Reader.Read(iv[:])
wsConn.Write(iv[:])
transport = NewObfReadWriteCloser(config.Websocket.Password, wsConn, iv[:])
}
tlsConn := tls.Client(transport, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
return nil, err
}
if config.LogLevel == 0 {
state := tlsConn.ConnectionState()
chain := state.VerifiedChains
log.Debug("Websocket TLS handshaked", "cipher:", tls.CipherSuiteName(state.CipherSuite))
for i := range chain {
for j := range chain[i] {
log.Debug("subject:", chain[i][j].Subject, ", issuer:", chain[i][j].Issuer)
}
}
}
return tlsConn, nil
}

func NewInboundWebsocket(conn io.ReadWriteCloser, rw *bufio.ReadWriter, ctx context.Context, config *conf.GlobalConfig) (io.ReadWriteCloser, error) {
correct := "GET " + config.Websocket.Path + " HTTP/1.1\r\n"
first, err := rw.Peek(len(correct))
if err != nil {
return nil, err
}
if !bytes.Equal([]byte(correct), first) {
//it may be a normal trojan conn
log.Debug("not a ws conn", string(first))
return nil, nil
}

httpRequest, err := http.ReadRequest(rw.Reader)
if err != nil {
//malformed http request
return nil, err
}

url := "wss://" + config.Websocket.HostName + config.Websocket.Path
origin := "https://" + config.Websocket.HostName
wsConfig, err := websocket.NewConfig(url, origin)

handshaked := make(chan struct{})

var wsConn *websocket.Conn
wsServer := websocket.Server{
Config: *wsConfig,
Handler: func(conn *websocket.Conn) {
wsConn = conn //store the websocket after handshaking
log.Debug("websocket obtained")
handshaked <- struct{}{}
//this function will NOT return unless the connection is ended
//or the websocket will be closed by ServeHTTP method
<-ctx.Done()
},
Handshake: func(wsConfig *websocket.Config, httpRequest *http.Request) error {
log.Debug("websocket url", httpRequest.URL, "origin", httpRequest.Header.Get("Origin"))
return nil
},
}

responseWriter := &wsHttpResponseWriter{
Conn: conn.(net.Conn),
ReadWriter: rw,
}
go wsServer.ServeHTTP(responseWriter, httpRequest)

select {
case <-handshaked:
case <-time.After(protocol.TCPTimeout):
}

if wsConn == nil {
return nil, common.NewError("failed to perform websocket handshake")
}

var transport net.Conn = wsConn
tlsConfig := &tls.Config{
Certificates: config.TLS.KeyPair,
CipherSuites: config.TLS.CipherSuites,
PreferServerCipherSuites: config.TLS.PreferServerCipher,
SessionTicketsDisabled: !config.TLS.SessionTicket,
}
if config.Websocket.Password != "" {
iv := [aes.BlockSize]byte{}
rand.Reader.Read(iv[:])
wsConn.Read(iv[:])
transport = NewObfReadWriteCloser(config.Websocket.Password, wsConn, iv[:])
}
tlsConn := tls.Server(transport, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
return nil, err
}
return tlsConn, nil
}
3 changes: 3 additions & 0 deletions test/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ func TestWebsocketClient(t *testing.T) {
Enabled: true,
HostName: "127.0.0.1",
Path: "/websocket",
Password: "testpassword",
},
}
c := client.Client{}
Expand All @@ -260,6 +261,7 @@ func TestWebsocketMuxClient(t *testing.T) {
Enabled: true,
HostName: "127.0.0.1",
Path: "/websocket",
Password: "testpassword",
},
Mux: conf.MuxConfig{
Enabled: true,
Expand All @@ -286,6 +288,7 @@ func TestWebsocketServer(t *testing.T) {
Enabled: true,
HostName: "127.0.0.1",
Path: "/websocket",
Password: "testpassword",
},
}
s := server.Server{}
Expand Down

0 comments on commit 2c63c60

Please sign in to comment.