From 76fd836c7763960bfe2e9dbc6d72650967667065 Mon Sep 17 00:00:00 2001 From: p4gefau1t Date: Mon, 23 Mar 2020 14:18:26 -0400 Subject: [PATCH] add tls fallback proxy --- common/common.go | 2 +- conf/conf.go | 6 +++- conf/parse.go | 66 ++++++++++++++++---------------------- protocol/trojan/inbound.go | 2 +- proxy/server.go | 30 ++++++++++++++--- proxy/server_test.go | 5 +-- 6 files changed, 63 insertions(+), 48 deletions(-) diff --git a/common/common.go b/common/common.go index 5a4bebe24..30c23c845 100644 --- a/common/common.go +++ b/common/common.go @@ -48,7 +48,7 @@ func HumanFriendlyTraffic(bytes int) string { if bytes <= GiB { return fmt.Sprintf("%.2f MiB", float32(bytes)/MiB) } - return fmt.Sprintf("%.2f TiB", float32(bytes)/GiB) + return fmt.Sprintf("%.2f GiB", float32(bytes)/GiB) } func ConnectDatabase(driverName, username, password, ip string, port int, dbName string) (*sql.DB, error) { diff --git a/conf/conf.go b/conf/conf.go index ef6f059d0..a6bde13d1 100644 --- a/conf/conf.go +++ b/conf/conf.go @@ -23,10 +23,14 @@ type TLSConfig struct { KeyPassword string `json:"key_password"` Cipher string `json:"cipher"` CipherTLS13 string `json:"cipher_tls13"` - HTTPFile string `json:"plain_http_response"` PreferServerCipher bool `json:"prefer_server_cipher"` SNI string `json:"sni"` + HTTPFile string `json:"plain_http_response"` + FallbackHost string `json:"fallback_addr"` + FallbackPort uint16 `json:"fallback_port"` + FallbackAddr net.Addr + CertPool *x509.CertPool KeyPair []tls.Certificate HTTPResponse []byte diff --git a/conf/parse.go b/conf/parse.go index 86f81d6b4..88c9e3057 100644 --- a/conf/parse.go +++ b/conf/parse.go @@ -5,6 +5,7 @@ import ( "crypto/x509" "encoding/json" "encoding/pem" + "fmt" "io/ioutil" "net" "os" @@ -16,19 +17,18 @@ import ( var logger = log.New(os.Stdout) -func ConvertToIP(s string) ([]net.IP, error) { - ip := net.ParseIP(s) - if ip == nil { - ips, err := net.LookupIP(s) - if err != nil { - return nil, err - } - if len(ips) == 0 { - return nil, common.NewError("cannot resolve host:" + s) - } - return ips, nil +func convertToAddr(preferV4 bool, host string, port uint16) (*net.TCPAddr, error) { + ip := net.ParseIP(host) + if ip != nil { + return &net.TCPAddr{ + IP: ip, + Port: int(port), + }, nil } - return []net.IP{ip}, nil + if preferV4 { + return net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", host, port)) + } + return net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%d", host, port)) } func ParseJSON(data []byte) (*GlobalConfig, error) { @@ -108,39 +108,27 @@ func ParseJSON(data []byte) (*GlobalConfig, error) { default: return nil, common.NewError("invalid run type") } - localIPs, err := ConvertToIP(config.LocalHost) + + localAddr, err := convertToAddr(config.TCP.PreferIPV4, config.LocalHost, config.LocalPort) if err != nil { - return nil, err + return nil, common.NewError("invalid local address").Base(err) } - remoteIPs, err := ConvertToIP(config.RemoteHost) + config.LocalAddr = localAddr + config.LocalIP = localAddr.IP + + remoteAddr, err := convertToAddr(config.TCP.PreferIPV4, config.RemoteHost, config.RemotePort) if err != nil { - return nil, err + return nil, common.NewError("invalid remote address").Base(err) } + config.RemoteAddr = remoteAddr + config.RemoteIP = remoteAddr.IP - config.LocalIP = localIPs[0] - config.RemoteIP = remoteIPs[0] - - if config.TCP.PreferIPV4 { - for _, ip := range localIPs { - if ip.To4() != nil { - config.LocalIP = ip - break - } - } - for _, ip := range remoteIPs { - if ip.To4() != nil { - config.RemoteIP = ip - break - } + if config.TLS.FallbackHost != "" { + fallbackAddr, err := convertToAddr(config.TCP.PreferIPV4, config.TLS.FallbackHost, config.TLS.FallbackPort) + if err != nil { + return nil, common.NewError("invalid tls fallback address").Base(err) } - } - config.LocalAddr = &net.TCPAddr{ - IP: config.LocalIP, - Port: int(config.LocalPort), - } - config.RemoteAddr = &net.TCPAddr{ - IP: config.RemoteIP, - Port: int(config.RemotePort), + config.TLS.FallbackAddr = fallbackAddr } if config.TLS.Cipher != "" || config.TLS.CipherTLS13 != "" { diff --git a/protocol/trojan/inbound.go b/protocol/trojan/inbound.go index 5dfb40dc6..97cbcc22d 100644 --- a/protocol/trojan/inbound.go +++ b/protocol/trojan/inbound.go @@ -65,7 +65,7 @@ func (i *TrojanInboundConnSession) parseRequest() error { Port: i.config.RemotePort, NetworkType: "tcp", } - logger.Warn("invalid hash or other protocol:", string(userHash)) + logger.Warn("remote", i.conn.RemoteAddr(), "invalid hash or other protocol:", string(userHash)) return nil } i.passwordHash = string(userHash) diff --git a/proxy/server.go b/proxy/server.go index d0235664b..07784e2f4 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "database/sql" "net" + "reflect" "github.com/p4gefau1t/trojan-go/common" "github.com/p4gefau1t/trojan-go/conf" @@ -28,12 +29,12 @@ type Server struct { func (s *Server) handleMuxConn(stream *smux.Stream, passwordHash string) { inboundConn, err := mux.NewInboundMuxConnSession(stream, passwordHash) - inboundConn.(protocol.NeedMeter).SetMeter(s.meter) if err != nil { stream.Close() logger.Error(common.NewError("cannot start inbound session").Base(err)) return } + inboundConn.(protocol.NeedMeter).SetMeter(s.meter) defer inboundConn.Close() req := inboundConn.GetRequest() if req.Command != protocol.Connect { @@ -52,11 +53,11 @@ func (s *Server) handleMuxConn(stream *smux.Stream, passwordHash string) { func (s *Server) handleConn(conn net.Conn) { inboundConn, err := trojan.NewInboundConnSession(conn, s.config, s.auth) - if err != nil { - logger.Error(err) + logger.Error(common.NewError("failed to start inbound session, remote:" + conn.RemoteAddr().String()).Base(err)) return } + req := inboundConn.GetRequest() hash := inboundConn.(protocol.HasHash).GetHash() @@ -191,10 +192,31 @@ func (s *Server) Run() error { tlsConn := tls.Server(conn, tlsConfig) err = tlsConn.Handshake() if err != nil { - logger.Warn(common.NewError("failed to perform handshake, responsing http payload").Base(err)) + logger.Warn(common.NewError("failed to perform tls handshake, remote:" + conn.RemoteAddr().String()).Base(err)) + if len(s.config.TLS.HTTPResponse) > 0 { + logger.Warn("trying to response a plain http response") conn.Write(s.config.TLS.HTTPResponse) + continue } + + if s.config.TLS.FallbackAddr != nil { + //HACK + //obtain the bytes buffered by the tls conn + v := reflect.ValueOf(*tlsConn) + buf := v.FieldByName("rawInput").FieldByName("buf").Bytes() + logger.Debug("payload:" + string(buf)) + + remote, err := net.Dial("tcp", s.config.TLS.FallbackAddr.String()) + if err != nil { + logger.Warn(common.NewError("failed to dial to tls fallback server").Base(err)) + } + logger.Warn("proxying this invalid tls conn to the tls fallback server") + remote.Write(buf) + go proxyConn(conn, remote) + continue + } + conn.Close() continue } diff --git a/proxy/server_test.go b/proxy/server_test.go index eae529b58..d7d5b6b54 100644 --- a/proxy/server_test.go +++ b/proxy/server_test.go @@ -152,9 +152,10 @@ func TestServerTCPRedirecting(t *testing.T) { } config.TLS.KeyPair = []tls.Certificate{key} config.TLS.SNI = "localhost" - payload, err := ioutil.ReadFile("http.txt") + addr, err := net.ResolveTCPAddr("tcp", "localhost:443") common.Must(err) - config.TLS.HTTPResponse = payload + config.TLS.FallbackAddr = addr + server := Server{ config: config, }