diff --git a/cert/option.go b/cert/option.go new file mode 100644 index 000000000..dec9a2912 --- /dev/null +++ b/cert/option.go @@ -0,0 +1,43 @@ +package cert + +import ( + "flag" + + "github.com/p4gefau1t/trojan-go/common" +) + +type certOption struct { + args *string + common.OptionHandler +} + +func (*certOption) Name() string { + return "cert" +} + +func (*certOption) Priority() int { + return 10 +} + +func (c *certOption) Handle() error { + switch *c.args { + case "request": + RequestCertGuide() + return nil + case "renew": + RenewCertGuide() + return nil + case "INVALID": + return common.NewError("not specified") + default: + err := common.NewError("invalid args " + *c.args) + logger.Error(err) + return common.NewError("invalid args") + } +} + +func init() { + common.RegisterOptionHandler(&certOption{ + args: flag.String("cert", "INVALID", "Simple letsencrpyt cert acme client. Use \"-cert request\" to request a cert or \"-cert renew\" to renew a cert"), + }) +} diff --git a/common/common.go b/common/common.go index 30c23c845..37307562f 100644 --- a/common/common.go +++ b/common/common.go @@ -3,13 +3,12 @@ package common import ( "bufio" "crypto/sha256" - "database/sql" "fmt" "io" - "strings" +) - _ "github.com/go-sql-driver/mysql" - //_ "github.com/mattn/go-sqlite3" +const ( + Version = "v0.0.15" ) type Runnable interface { @@ -50,13 +49,3 @@ func HumanFriendlyTraffic(bytes int) string { } return fmt.Sprintf("%.2f GiB", float32(bytes)/GiB) } - -func ConnectDatabase(driverName, username, password, ip string, port int, dbName string) (*sql.DB, error) { - path := strings.Join([]string{username, ":", password, "@tcp(", ip, ":", fmt.Sprintf("%d", port), ")/", dbName, "?charset=utf8"}, "") - return sql.Open(driverName, path) -} - -func ConnectSQLite(dbName string) (*sql.DB, error) { - //for debug only - return sql.Open("sqlite3", dbName) -} diff --git a/common/db.go b/common/db.go new file mode 100644 index 000000000..cbe7b523b --- /dev/null +++ b/common/db.go @@ -0,0 +1,12 @@ +package common + +import ( + "database/sql" + "fmt" + "strings" +) + +func ConnectDatabase(driverName, username, password, ip string, port int, dbName string) (*sql.DB, error) { + path := strings.Join([]string{username, ":", password, "@tcp(", ip, ":", fmt.Sprintf("%d", port), ")/", dbName, "?charset=utf8"}, "") + return sql.Open(driverName, path) +} diff --git a/common/option.go b/common/option.go new file mode 100644 index 000000000..cf9984266 --- /dev/null +++ b/common/option.go @@ -0,0 +1,27 @@ +package common + +type OptionHandler interface { + Name() string + Handle() error + Priority() int +} + +var handlers map[string]OptionHandler = make(map[string]OptionHandler) + +func RegisterOptionHandler(h OptionHandler) { + handlers[h.Name()] = h +} + +func PopOptionHandler() (OptionHandler, error) { + var maxHandler OptionHandler = nil + for _, h := range handlers { + if maxHandler == nil || maxHandler.Priority() < h.Priority() { + maxHandler = h + } + } + if maxHandler == nil { + return nil, NewError("no option left") + } + delete(handlers, maxHandler.Name()) + return maxHandler, nil +} diff --git a/main.go b/main.go index d6058adc6..b33ebd274 100644 --- a/main.go +++ b/main.go @@ -2,56 +2,31 @@ package main import ( "flag" - "io/ioutil" "os" - "os/signal" - "github.com/p4gefau1t/trojan-go/cert" - "github.com/p4gefau1t/trojan-go/conf" + "github.com/p4gefau1t/trojan-go/common" "github.com/p4gefau1t/trojan-go/log" - "github.com/p4gefau1t/trojan-go/proxy" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/p4gefau1t/trojan-go/cert" + _ "github.com/p4gefau1t/trojan-go/proxy/client" + _ "github.com/p4gefau1t/trojan-go/proxy/forward" + _ "github.com/p4gefau1t/trojan-go/proxy/server" + _ "github.com/p4gefau1t/trojan-go/version" ) var logger = log.New(os.Stdout) func main() { - logger.Info("Trojan-Go initializing...") - configFile := flag.String("config", "config.json", "Config filename") - guideMode := flag.String("cert", "", "Simple letsencrpyt cert acme client. Use \"-cert request\" to request a cert or \"-cert renew\" to renew a cert") flag.Parse() - switch *guideMode { - case "request": - cert.RequestCertGuide() - return - case "renew": - cert.RenewCertGuide() - return - case "": - default: - logger.Error("Invalid cert arg") - return - } - data, err := ioutil.ReadFile(*configFile) - if err != nil { - logger.Fatal("Failed to read config file", err) - } - config, err := conf.ParseJSON(data) - if err != nil { - logger.Fatal("Failed to parse config file", err) - } - proxy := proxy.NewProxy(config) - errChan := make(chan error) - go func() { - errChan <- proxy.Run() - }() - - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, os.Interrupt) - select { - case <-sigs: - proxy.Close() - case err := <-errChan: - logger.Fatal(err) + for { + h, err := common.PopOptionHandler() + if err != nil { + logger.Fatal("invalid options") + } + err = h.Handle() + if err == nil { + break + } } - logger.Info("Trojan-Go exited") } diff --git a/protocol/direct/outbound.go b/protocol/direct/outbound.go index 2f4e0cbbe..2757cb014 100644 --- a/protocol/direct/outbound.go +++ b/protocol/direct/outbound.go @@ -23,8 +23,9 @@ func (o *DirectOutboundConnSession) Read(p []byte) (int, error) { } func (o *DirectOutboundConnSession) Write(p []byte) (int, error) { - defer o.bufReadWriter.Flush() - return o.bufReadWriter.Write(p) + n, err := o.bufReadWriter.Write(p) + o.bufReadWriter.Flush() + return n, err } func (o *DirectOutboundConnSession) Close() error { @@ -105,7 +106,7 @@ func (o *DirectOutboundPacketSession) WritePacket(req *protocol.Request, packet if err != nil { return 0, common.NewError("cannot dial udp").Base(err) } - logger.Info("UDP directly dialing to", remote) + logger.Debug("UDP directly dialing to", remote) n, err := conn.Write(packet) return n, err } diff --git a/protocol/http/inbound.go b/protocol/http/inbound.go index 81c8688b6..ab37a6f57 100644 --- a/protocol/http/inbound.go +++ b/protocol/http/inbound.go @@ -35,15 +35,16 @@ func (i *HTTPInboundTunnelConnSession) Read(p []byte) (int, error) { } func (i *HTTPInboundTunnelConnSession) Write(p []byte) (int, error) { - defer i.bufReadWriter.Flush() - return i.bufReadWriter.Write(p) + n, err := i.bufReadWriter.Write(p) + i.bufReadWriter.Flush() + return n, err } func (i *HTTPInboundTunnelConnSession) Close() error { return i.conn.Close() } -func (i *HTTPInboundTunnelConnSession) Respond(r io.Reader) error { +func (i *HTTPInboundTunnelConnSession) Respond() error { payload := fmt.Sprintf("HTTP/%d.%d 200 Connection established\r\n\r\n", i.httpRequest.ProtoMajor, i.httpRequest.ProtoMinor) _, err := i.Write([]byte(payload)) return err @@ -175,6 +176,7 @@ func (i *HTTPInboundPacketSession) ReadPacket() (*protocol.Request, []byte, erro } func (i *HTTPInboundPacketSession) WritePacket(req *protocol.Request, packet []byte) (int, error) { - defer i.bufReadWriter.Flush() - return i.bufReadWriter.Write(packet) + n, err := i.bufReadWriter.Write(packet) + i.bufReadWriter.Flush() + return n, err } diff --git a/protocol/nat/inbound.go b/protocol/nat/inbound.go index 1c5bc6f62..09d6f77e2 100644 --- a/protocol/nat/inbound.go +++ b/protocol/nat/inbound.go @@ -94,6 +94,7 @@ func (i *NATInboundPacketSession) cleanExpiredSession() { select { case <-time.After(protocol.UDPTimeout): case <-i.ctx.Done(): + i.conn.Close() return } } @@ -131,7 +132,7 @@ func (i *NATInboundPacketSession) ReadPacket() (*protocol.Request, []byte, error expire: time.Now().Add(protocol.UDPTimeout), } i.tableMutex.Unlock() - logger.Info("tproxy UDP packet from", src, "to", dst) + logger.Debug("tproxy UDP packet from", src, "to", dst) req := &protocol.Request{ IP: dst.IP, Port: uint16(dst.Port), @@ -147,7 +148,7 @@ func (i *NATInboundPacketSession) ReadPacket() (*protocol.Request, []byte, error func (i *NATInboundPacketSession) Close() error { i.cancel() - return i.conn.Close() + return nil } func NewInboundPacketSession(config *conf.GlobalConfig) (protocol.PacketSession, error) { diff --git a/protocol/protocol.go b/protocol/protocol.go index 89e169177..80d07f43e 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -30,8 +30,8 @@ const ( const ( MaxUDPPacketSize = 1024 * 4 - UDPTimeout = time.Second * 6 - TCPTimeout = time.Second * 6 + UDPTimeout = time.Second * 5 + TCPTimeout = time.Second * 5 ) type Request struct { @@ -69,7 +69,7 @@ type HasHash interface { } type NeedRespond interface { - Respond(io.Reader) error + Respond() error } type PacketReader interface { diff --git a/protocol/socks/inbound.go b/protocol/socks/inbound.go index 75cccb00a..3edfe68f5 100644 --- a/protocol/socks/inbound.go +++ b/protocol/socks/inbound.go @@ -3,13 +3,20 @@ package socks import ( "bufio" "bytes" + "context" "io" "net" + "os" + "sync" + "time" "github.com/p4gefau1t/trojan-go/common" + "github.com/p4gefau1t/trojan-go/log" "github.com/p4gefau1t/trojan-go/protocol" ) +var logger = log.New(os.Stdout) + type SocksConnInboundSession struct { protocol.ConnSession protocol.NeedRespond @@ -69,7 +76,7 @@ func (i *SocksConnInboundSession) parseRequest() error { return nil } -func (i *SocksConnInboundSession) Respond(r io.Reader) error { +func (i *SocksConnInboundSession) Respond() error { if i.request.Command == protocol.Connect { i.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) return nil @@ -86,8 +93,9 @@ func (i *SocksConnInboundSession) Read(p []byte) (int, error) { } func (i *SocksConnInboundSession) Write(p []byte) (int, error) { - defer i.bufReadWriter.Flush() - return i.bufReadWriter.Write(p) + n, err := i.bufReadWriter.Write(p) + i.bufReadWriter.Flush() + return n, err } func (i *SocksConnInboundSession) Close() error { @@ -115,13 +123,23 @@ func NewInboundConnSession(conn io.ReadWriteCloser, rw *bufio.ReadWriter) (proto return i, nil } +type udpSession struct { + src *net.UDPAddr + req *protocol.Request + expire time.Time +} + type SocksInboundPacketSession struct { protocol.PacketSession + conn *net.UDPConn - socks5Client *net.UDPAddr + sessionTable map[string]*udpSession + tableMutex sync.Mutex + ctx context.Context + cancel context.CancelFunc } -func (i *SocksInboundPacketSession) parsePacketHeader(rawPacket []byte) (*protocol.Request, []byte, error) { +func (i *SocksInboundPacketSession) parsePacket(rawPacket []byte) (*protocol.Request, []byte, error) { if len(rawPacket) <= 4 { return nil, nil, common.NewError("too short") } @@ -147,14 +165,46 @@ func (i *SocksInboundPacketSession) writePacketHeader(w io.Writer, req *protocol return nil } +func (i *SocksInboundPacketSession) cleanExpiredSession() { + for { + i.tableMutex.Lock() + now := time.Now() + for k, v := range i.sessionTable { + if now.After(v.expire) { + logger.Debug("deleting expired session", v.src, "req:", v.req) + delete(i.sessionTable, k) + } + } + i.tableMutex.Unlock() + select { + case <-time.After(protocol.UDPTimeout): + case <-i.ctx.Done(): + i.conn.Close() + return + } + } +} + func (i *SocksInboundPacketSession) ReadPacket() (*protocol.Request, []byte, error) { buf := make([]byte, protocol.MaxUDPPacketSize) - n, remote, err := i.conn.ReadFromUDP(buf) - i.socks5Client = remote + n, src, err := i.conn.ReadFromUDP(buf) if err != nil { return nil, nil, err } - return i.parsePacketHeader(buf[0:n]) + req, payload, err := i.parsePacket(buf[0:n]) + if err != nil { + return nil, nil, err + } + session := &udpSession{ + src: src, + req: req, + expire: time.Now().Add(protocol.UDPTimeout), + } + i.tableMutex.Lock() + i.sessionTable[req.String()] = session + i.tableMutex.Unlock() + logger.Debug("UDP read from", src, "req", req) + return req, payload, err } func (i *SocksInboundPacketSession) WritePacket(req *protocol.Request, packet []byte) (int, error) { @@ -163,16 +213,28 @@ func (i *SocksInboundPacketSession) WritePacket(req *protocol.Request, packet [] return 0, err } w.Write(packet) - return i.conn.WriteToUDP(w.Bytes(), i.socks5Client) + client, found := i.sessionTable[req.String()] + if !found { + return 0, common.NewError("session not found") + } + logger.Debug("UDP write to", client.src, "req", req) + return i.conn.WriteToUDP(w.Bytes(), client.src) } func (i *SocksInboundPacketSession) Close() error { + i.cancel() return i.conn.Close() } func NewInboundPacketSession(conn *net.UDPConn) (*SocksInboundPacketSession, error) { - i := &SocksInboundPacketSession{} + ctx, cancel := context.WithCancel(context.Background()) conn.SetWriteBuffer(0) - i.conn = conn + i := &SocksInboundPacketSession{ + ctx: ctx, + cancel: cancel, + sessionTable: make(map[string]*udpSession), + conn: conn, + } + go i.cleanExpiredSession() return i, nil } diff --git a/protocol/trojan/outbound.go b/protocol/trojan/outbound.go index 23fd81c35..140e81776 100644 --- a/protocol/trojan/outbound.go +++ b/protocol/trojan/outbound.go @@ -7,6 +7,7 @@ import ( "github.com/p4gefau1t/trojan-go/common" "github.com/p4gefau1t/trojan-go/conf" + "github.com/p4gefau1t/trojan-go/log" "github.com/p4gefau1t/trojan-go/protocol" ) @@ -76,6 +77,19 @@ func NewOutboundConnSession(req *protocol.Request, conn io.ReadWriteCloser, conf return nil, common.NewError("failed to verify hostname").Base(err) } } + if log.LogLevel == 0 { + state := tlsConn.ConnectionState() + chain := state.VerifiedChains + logger.Debug("tls handshaked", "cipher", tls.CipherSuiteName(state.CipherSuite)) + logger.Debug("chains:") + for i := range chain { + logger.Debug("--------------------------------") + for j := range chain[i] { + logger.Debug("subject:", chain[i][j].Subject, "issuer:", chain[i][j].Issuer) + } + } + logger.Debug("--------------------------------") + } conn = tlsConn } o := &TrojanOutboundConnSession{ diff --git a/proxy/client.go b/proxy/client/client.go similarity index 55% rename from proxy/client.go rename to proxy/client/client.go index b3387cf25..931960ba3 100644 --- a/proxy/client.go +++ b/proxy/client/client.go @@ -1,133 +1,85 @@ -package proxy +package client import ( "bufio" "context" - "math/rand" "net" - "sync" + "os" "time" "github.com/p4gefau1t/trojan-go/common" "github.com/p4gefau1t/trojan-go/conf" + "github.com/p4gefau1t/trojan-go/log" "github.com/p4gefau1t/trojan-go/protocol" "github.com/p4gefau1t/trojan-go/protocol/http" "github.com/p4gefau1t/trojan-go/protocol/mux" "github.com/p4gefau1t/trojan-go/protocol/socks" "github.com/p4gefau1t/trojan-go/protocol/trojan" - "github.com/xtaci/smux" + "github.com/p4gefau1t/trojan-go/proxy" ) -type muxID uint32 +var logger = log.New(os.Stdout) -func generateMuxID() muxID { - return muxID(rand.Uint32()) -} - -type muxClientInfo struct { - id muxID - client *smux.Session - lastActiveTime time.Time +type packetInfo struct { + request *protocol.Request + packet []byte } type Client struct { common.Runnable + proxy.Buildable - config *conf.GlobalConfig - ctx context.Context - cancel context.CancelFunc - - muxLock sync.Mutex - muxPool map[muxID]*muxClientInfo -} - -func (c *Client) newMuxClient() (*muxClientInfo, error) { - id := generateMuxID() - if _, found := c.muxPool[id]; found { - return nil, common.NewError("duplicated id") - } - req := &protocol.Request{ - Command: protocol.Mux, - DomainName: []byte("MUX_CONN"), - AddressType: protocol.DomainName, - } - conn, err := trojan.NewOutboundConnSession(req, nil, c.config) - if err != nil { - logger.Error(common.NewError("failed to dial tls tunnel").Base(err)) - return nil, err - } - - client, err := smux.Client(conn, nil) - common.Must(err) - logger.Info("mux TLS tunnel established, id:", id) - return &muxClientInfo{ - client: client, - id: id, - lastActiveTime: time.Now(), - }, nil -} - -func (c *Client) pickMuxClient() (*muxClientInfo, error) { - c.muxLock.Lock() - defer c.muxLock.Unlock() - - for _, info := range c.muxPool { - if !info.client.IsClosed() && (info.client.NumStreams() < c.config.TCP.MuxConcurrency || c.config.TCP.MuxConcurrency <= 0) { - info.lastActiveTime = time.Now() - return info, nil - } - } - - //not found - info, err := c.newMuxClient() - if err != nil { - return nil, err - } - c.muxPool[info.id] = info - return info, nil -} - -func (c *Client) openMuxConn() (*smux.Stream, *muxClientInfo, error) { - info, err := c.pickMuxClient() - if err != nil { - return nil, nil, err - } - stream, err := info.client.OpenStream() - if err != nil { - return nil, nil, err - } - info.lastActiveTime = time.Now() - return stream, info, nil + config *conf.GlobalConfig + ctx context.Context + cancel context.CancelFunc + mux *muxPoolManager + associatedChan chan int } -func (c *Client) checkAndCloseIdleMuxClient() { - muxIdleDuration := time.Duration(c.config.TCP.MuxIdleTimeout) * time.Second +func (c *Client) listenUDP() { for { - select { - case <-time.After(muxIdleDuration / 4): - c.muxLock.Lock() - for id, info := range c.muxPool { - if info.client.IsClosed() { - delete(c.muxPool, id) - logger.Info("mux", id, "is dead") - } else if info.client.NumStreams() == 0 && time.Now().Sub(info.lastActiveTime) > muxIdleDuration { - info.client.Close() - delete(c.muxPool, id) - logger.Info("mux", id, "is closed due to inactive") - } - } - if len(c.muxPool) != 0 { - logger.Info("current mux pool conn num", len(c.muxPool)) - } - c.muxLock.Unlock() - case <-c.ctx.Done(): - c.muxLock.Lock() - for id, info := range c.muxPool { - info.client.Close() - logger.Info("mux", id, "closed") + start: + listener, err := net.ListenUDP("udp", &net.UDPAddr{ + IP: c.config.LocalIP, + Port: int(c.config.LocalPort), + }) + if err != nil { + logger.Fatal(common.NewError("failed to listen udp").Base(err)) + } + inbound, err := socks.NewInboundPacketSession(listener) + <-c.associatedChan + common.Must(err) + logger.Debug("associated signal") + req := protocol.Request{ + DomainName: []byte("UDP_CONN"), + AddressType: protocol.DomainName, + Command: protocol.Associate, + } + tunnel, err := trojan.NewOutboundConnSession(&req, nil, c.config) + if err != nil { + logger.Error(err) + continue + } + outbound, err := trojan.NewPacketSession(tunnel) + common.Must(err) + alive := make(chan int) + go proxy.ProxyPacketWithAliveChan(inbound, outbound, alive) + for { + select { + case <-alive: + logger.Debug("keep alive..(alive)") + case <-c.associatedChan: + logger.Debug("keep alive..(associated)") + case <-time.After(protocol.UDPTimeout): + logger.Debug("time out, closing UDP tunnel") + outbound.Close() + inbound.Close() + goto start + case <-c.ctx.Done(): + outbound.Close() + inbound.Close() + return } - c.muxLock.Unlock() - return } } } @@ -135,75 +87,57 @@ func (c *Client) checkAndCloseIdleMuxClient() { func (c *Client) handleSocksConn(conn net.Conn, rw *bufio.ReadWriter) { inboundConn, err := socks.NewInboundConnSession(conn, rw) if err != nil { - logger.Error(common.NewError("failed to start new inbound session:").Base(err)) + logger.Error(common.NewError("failed to start new inbound session").Base(err)) return } defer inboundConn.Close() req := inboundConn.GetRequest() if req.Command == protocol.Associate { - outboundConn, err := trojan.NewOutboundConnSession(req, nil, c.config) - if err != nil { - logger.Error(common.NewError("failed to start new outbound session for UDP").Base(err)) - return - } - - listenConn, err := net.ListenUDP("udp", &net.UDPAddr{ - IP: c.config.LocalIP, - }) - if err != nil { - logger.Error(common.NewError("failed to listen udp:").Base(err)) - return - } - + //setting up the bind address to respond + //listenUDP() will handle the incoming udp packets req.IP = c.config.LocalIP - port, err := protocol.ParsePort(listenConn.LocalAddr()) - common.Must(err) - req.Port = port - req.AddressType = protocol.IPv4 - - inboundPacket, err := socks.NewInboundPacketSession(listenConn) - if err != nil { - logger.Error("failed to start inbound packet session:", err) - return + req.Port = c.config.LocalPort + if c.config.LocalIP.To16() != nil { + req.AddressType = protocol.IPv6 + } else { + req.AddressType = protocol.IPv4 } - defer inboundPacket.Close() - - outboundPacket, err := trojan.NewPacketSession(outboundConn) - common.Must(err) - go proxyPacket(inboundPacket, outboundPacket) - - inboundConn.(protocol.NeedRespond).Respond(nil) + //notify listenUDP to get ready for relaying udp packets + c.associatedChan <- 1 logger.Info("UDP associated to", req) + if err := inboundConn.(protocol.NeedRespond).Respond(); err != nil { + logger.Error("failed to repsond") + } //stop relaying UDP once TCP connection is closed var buf [1]byte _, err = conn.Read(buf[:]) - logger.Info("UDP conn ends", err) + logger.Debug(common.NewError("UDP conn ends").Base(err)) return } - if err := inboundConn.(protocol.NeedRespond).Respond(nil); err != nil { + if err := inboundConn.(protocol.NeedRespond).Respond(); err != nil { logger.Error(common.NewError("failed to respond").Base(err)) return } if c.config.TCP.Mux { - stream, info, err := c.openMuxConn() + stream, info, err := c.mux.OpenMuxConn() if err != nil { logger.Error(common.NewError("failed to open mux stream").Base(err)) return } - defer stream.Close() outboundConn, err := mux.NewOutboundMuxConnSession(stream, req) if err != nil { + stream.Close() logger.Error(common.NewError("fail to start trojan session over mux conn").Base(err)) return } defer outboundConn.Close() logger.Info("conn from", conn.RemoteAddr(), "mux tunneling to", req, "mux id", info.id) - proxyConn(inboundConn, outboundConn) + proxy.ProxyConn(inboundConn, outboundConn) } else { outboundConn, err := trojan.NewOutboundConnSession(req, nil, c.config) if err != nil { @@ -213,7 +147,7 @@ func (c *Client) handleSocksConn(conn net.Conn, rw *bufio.ReadWriter) { defer outboundConn.Close() logger.Info("conn from", conn.RemoteAddr(), "tunneling to", req) - proxyConn(inboundConn, outboundConn) + proxy.ProxyConn(inboundConn, outboundConn) } } @@ -227,13 +161,13 @@ func (c *Client) handleHTTPConn(conn net.Conn, rw *bufio.ReadWriter) { defer inboundConn.Close() req := inboundConn.GetRequest() - if err := inboundConn.(protocol.NeedRespond).Respond(nil); err != nil { + if err := inboundConn.(protocol.NeedRespond).Respond(); err != nil { logger.Error(common.NewError("failed to respond").Base(err)) return } if c.config.TCP.Mux { - stream, info, err := c.openMuxConn() + stream, info, err := c.mux.OpenMuxConn() if err != nil { logger.Error(common.NewError("failed to open mux stream").Base(err)) return @@ -246,7 +180,7 @@ func (c *Client) handleHTTPConn(conn net.Conn, rw *bufio.ReadWriter) { } defer outboundConn.Close() logger.Info("conn from", conn.RemoteAddr(), "mux tunneling to", req, "mux id", info.id) - proxyConn(inboundConn, outboundConn) + proxy.ProxyConn(inboundConn, outboundConn) } else { outboundConn, err := trojan.NewOutboundConnSession(req, nil, c.config) if err != nil { @@ -256,15 +190,11 @@ func (c *Client) handleHTTPConn(conn net.Conn, rw *bufio.ReadWriter) { defer outboundConn.Close() logger.Info("conn from", conn.RemoteAddr(), "tunneling to", req) - proxyConn(inboundConn, outboundConn) + proxy.ProxyConn(inboundConn, outboundConn) } } else { defer inboundPacket.Close() - type httpPacket struct { - request *protocol.Request - packet []byte - } - packetChan := make(chan *httpPacket, 128) + packetChan := make(chan *packetInfo, 128) readHTTPPackets := func() { for { @@ -273,7 +203,7 @@ func (c *Client) handleHTTPConn(conn net.Conn, rw *bufio.ReadWriter) { logger.Error(err) return } - packetChan <- &httpPacket{ + packetChan <- &packetInfo{ request: req, packet: packet, } @@ -286,7 +216,7 @@ func (c *Client) handleHTTPConn(conn net.Conn, rw *bufio.ReadWriter) { case packet := <-packetChan: var outboundConn protocol.ConnSession if c.config.TCP.Mux { - stream, info, err := c.openMuxConn() + stream, info, err := c.mux.OpenMuxConn() if err != nil { logger.Error(common.NewError("failed to open mux stream").Base(err)) continue @@ -315,15 +245,11 @@ func (c *Client) handleHTTPConn(conn net.Conn, rw *bufio.ReadWriter) { for { n, err := outboundConn.Read(buf[:]) if err != nil { - if err.Error() != "EOF" { - logger.Error(err) - } + logger.Debug(err) return } if _, err = inboundPacket.WritePacket(nil, buf[0:n]); err != nil { - if err.Error() != "EOF" { - logger.Error(err) - } + logger.Debug(err) return } } @@ -340,20 +266,13 @@ func (c *Client) handleHTTPConn(conn net.Conn, rw *bufio.ReadWriter) { } func (c *Client) Run() error { + go c.listenUDP() listener, err := net.Listen("tcp", c.config.LocalAddr.String()) if err != nil { return common.NewError("failed to listen local address").Base(err) } defer listener.Close() - ctx, cancel := context.WithCancel(context.Background()) - c.ctx = ctx - c.cancel = cancel - - if c.config.TCP.MuxIdleTimeout > 0 { - go c.checkAndCloseIdleMuxClient() - } - c.ctx = ctx logger.Info("client is running at", listener.Addr()) for { conn, err := listener.Accept() @@ -385,3 +304,21 @@ func (c *Client) Close() error { c.cancel() return nil } + +func (c *Client) Build(config *conf.GlobalConfig) (common.Runnable, error) { + c.ctx, c.cancel = context.WithCancel(context.Background()) + c.associatedChan = make(chan int) + if config.TCP.Mux { + var err error + c.mux, err = NewMuxPoolManager(c.ctx, config) + if err != nil { + logger.Fatal(err) + } + } + c.config = config + return c, nil +} + +func init() { + proxy.RegisterProxy(conf.Client, &Client{}) +} diff --git a/proxy/client/mux.go b/proxy/client/mux.go new file mode 100644 index 000000000..73fb6b771 --- /dev/null +++ b/proxy/client/mux.go @@ -0,0 +1,147 @@ +package client + +import ( + "context" + "math/rand" + "sync" + "time" + + "github.com/p4gefau1t/trojan-go/common" + "github.com/p4gefau1t/trojan-go/conf" + "github.com/p4gefau1t/trojan-go/protocol" + "github.com/p4gefau1t/trojan-go/protocol/trojan" + "github.com/xtaci/smux" +) + +type muxID uint32 + +func generateMuxID() muxID { + return muxID(rand.Uint32()) +} + +type muxClientInfo struct { + id muxID + client *smux.Session + lastActiveTime time.Time +} + +type muxPoolManager struct { + sync.Mutex + muxPool map[muxID]*muxClientInfo + config *conf.GlobalConfig + ctx context.Context +} + +func (m *muxPoolManager) newMuxClient() (*muxClientInfo, error) { + id := generateMuxID() + if _, found := m.muxPool[id]; found { + return nil, common.NewError("duplicated id") + } + req := &protocol.Request{ + Command: protocol.Mux, + DomainName: []byte("MUX_CONN"), + AddressType: protocol.DomainName, + } + conn, err := trojan.NewOutboundConnSession(req, nil, m.config) + if err != nil { + logger.Error(common.NewError("failed to dial tls tunnel").Base(err)) + return nil, err + } + + client, err := smux.Client(conn, nil) + common.Must(err) + logger.Info("mux TLS tunnel established, id:", id) + return &muxClientInfo{ + client: client, + id: id, + lastActiveTime: time.Now(), + }, nil +} + +func (m *muxPoolManager) pickMuxClient() (*muxClientInfo, error) { + m.Lock() + defer m.Unlock() + + for _, info := range m.muxPool { + if info.client.IsClosed() { + delete(m.muxPool, info.id) + logger.Info("mux", info.id, "is dead") + continue + } + if info.client.NumStreams() < m.config.TCP.MuxConcurrency || m.config.TCP.MuxConcurrency <= 0 { + info.lastActiveTime = time.Now() + return info, nil + } + } + + //not found + info, err := m.newMuxClient() + if err != nil { + return nil, err + } + m.muxPool[info.id] = info + return info, nil +} + +func (m *muxPoolManager) OpenMuxConn() (*smux.Stream, *muxClientInfo, error) { + info, err := m.pickMuxClient() + if err != nil { + return nil, nil, err + } + stream, err := info.client.OpenStream() + if err != nil { + return nil, nil, err + } + info.lastActiveTime = time.Now() + return stream, info, nil +} + +func (m *muxPoolManager) checkAndCloseIdleMuxClient() { + var muxIdleDuration, checkDuration time.Duration + if m.config.TCP.MuxIdleTimeout <= 0 { + muxIdleDuration = 0 + checkDuration = time.Second * 10 + logger.Warn("invalid mux idle timeout") + } else { + muxIdleDuration = time.Duration(m.config.TCP.MuxIdleTimeout) * time.Second + checkDuration = muxIdleDuration / 4 + } + for { + select { + case <-time.After(checkDuration): + m.Lock() + for id, info := range m.muxPool { + if info.client.IsClosed() { + delete(m.muxPool, id) + logger.Info("mux", id, "is dead") + } else if info.client.NumStreams() == 0 && time.Now().Sub(info.lastActiveTime) > muxIdleDuration { + info.client.Close() + delete(m.muxPool, id) + logger.Info("mux", id, "is closed due to inactive") + } + } + if len(m.muxPool) != 0 { + logger.Info("current mux pool conn num", len(m.muxPool)) + } + m.Unlock() + case <-m.ctx.Done(): + m.Lock() + for id, info := range m.muxPool { + info.client.Close() + logger.Info("mux", id, "closed") + } + m.Unlock() + return + } + } +} + +func NewMuxPoolManager(ctx context.Context, config *conf.GlobalConfig) (*muxPoolManager, error) { + m := &muxPoolManager{ + ctx: ctx, + config: config, + muxPool: make(map[muxID]*muxClientInfo), + } + go m.checkAndCloseIdleMuxClient() + return m, nil +} diff --git a/proxy/nat.go b/proxy/client/nat.go similarity index 54% rename from proxy/nat.go rename to proxy/client/nat.go index 5f22d3f54..73f5ff7da 100644 --- a/proxy/nat.go +++ b/proxy/client/nat.go @@ -1,6 +1,6 @@ // +build linux -package proxy +package client import ( "context" @@ -9,50 +9,69 @@ import ( "github.com/p4gefau1t/trojan-go/common" "github.com/p4gefau1t/trojan-go/conf" "github.com/p4gefau1t/trojan-go/protocol" + "github.com/p4gefau1t/trojan-go/protocol/mux" "github.com/p4gefau1t/trojan-go/protocol/nat" "github.com/p4gefau1t/trojan-go/protocol/trojan" + "github.com/p4gefau1t/trojan-go/proxy" ) type NAT struct { common.Runnable + proxy.Buildable config *conf.GlobalConfig ctx context.Context cancel context.CancelFunc packetInbound protocol.PacketSession listener net.Listener + mux *muxPoolManager } func (n *NAT) handleConn(conn net.Conn) { inbound, err := nat.NewInboundConnSession(conn) if err != nil { - logger.Error("failed to start inbound session", err) + logger.Error(common.NewError("failed to start inbound session").Base(err)) return } req := inbound.GetRequest() defer inbound.Close() + if n.config.TCP.Mux { + stream, info, err := n.mux.OpenMuxConn() + if err != nil { + logger.Error(common.NewError("failed to open mux stream").Base(err)) + return + } + outbound, err := mux.NewOutboundMuxConnSession(stream, req) + if err != nil { + stream.Close() + logger.Error(common.NewError("failed to start mux outbound session").Base(err)) + return + } + defer outbound.Close() + logger.Info("[transparent]conn from", conn.RemoteAddr(), "mux tunneling to", req, "mux id", info.id) + proxy.ProxyConn(inbound, outbound) + return + } outbound, err := trojan.NewOutboundConnSession(req, nil, n.config) if err != nil { logger.Error("failed to start outbound session", err) return } defer outbound.Close() - logger.Info("transparent nat from", conn.RemoteAddr(), "tunneling to", req) - proxyConn(inbound, outbound) + logger.Info("[transparent]conn from", conn.RemoteAddr(), "tunneling to", req) + proxy.ProxyConn(inbound, outbound) } func (n *NAT) listenUDP() { inbound, err := nat.NewInboundPacketSession(n.config) - n.packetInbound = inbound if err != nil { - logger.Error(err) - panic(err) + logger.Fatal(err) } + n.packetInbound = inbound defer inbound.Close() req := protocol.Request{ - IP: net.IPv4(233, 233, 233, 233), - Port: 2333, - AddressType: protocol.IPv4, + DomainName: []byte("UDP_CONN"), + AddressType: protocol.DomainName, Command: protocol.Associate, } for { @@ -67,7 +86,8 @@ func (n *NAT) listenUDP() { continue } outbound, err := trojan.NewPacketSession(tunnel) - proxyPacket(inbound, outbound) + common.Must(err) + proxy.ProxyPacket(inbound, outbound) tunnel.Close() } } @@ -87,6 +107,11 @@ func (n *NAT) Run() error { for { conn, err := n.listener.Accept() if err != nil { + select { + case <-n.ctx.Done(): + return nil + default: + } logger.Error(err) continue } @@ -97,5 +122,24 @@ func (n *NAT) Run() error { func (n *NAT) Close() error { logger.Info("shutting down nat...") n.cancel() + n.listener.Close() + n.packetInbound.Close() return nil } + +func (n *NAT) Build(config *conf.GlobalConfig) (common.Runnable, error) { + n.ctx, n.cancel = context.WithCancel(context.Background()) + n.config = config + if config.TCP.Mux { + mux, err := NewMuxPoolManager(n.ctx, config) + if err != nil { + logger.Fatal(err) + } + n.mux = mux + } + return n, nil +} + +func init() { + proxy.RegisterProxy(conf.NAT, &NAT{}) +} diff --git a/proxy/client_test.go b/proxy/client_test.go deleted file mode 100644 index 515371a7a..000000000 --- a/proxy/client_test.go +++ /dev/null @@ -1,89 +0,0 @@ -package proxy - -import ( - "crypto/x509" - "io/ioutil" - "net" - "testing" - - "github.com/p4gefau1t/trojan-go/common" - "github.com/p4gefau1t/trojan-go/conf" -) - -func TestClient(t *testing.T) { - serverCertBytes, err := ioutil.ReadFile("./server.crt") - common.Must(err) - pool := x509.NewCertPool() - pool.AppendCertsFromPEM(serverCertBytes) - ip := net.IPv4(127, 0, 0, 1) - port := 4444 - password := "pass123123" - config := &conf.GlobalConfig{ - LocalAddr: &net.TCPAddr{ - IP: ip, - Port: port, - }, - LocalIP: ip, - LocalPort: uint16(port), - RemoteAddr: &net.TCPAddr{ - IP: ip, - Port: 4445, - }, - Hash: map[string]string{common.SHA224String(password): password}, - } - config.TLS.CertPool = pool - config.TLS.SNI = "localhost" - - c := Client{ - config: config, - } - c.Run() -} - -func TestMuxClient(t *testing.T) { - serverCertBytes, err := ioutil.ReadFile("server.crt") - common.Must(err) - pool := x509.NewCertPool() - pool.AppendCertsFromPEM(serverCertBytes) - ip := net.IPv4(127, 0, 0, 1) - port := 4444 - password := "pass123123" - config := &conf.GlobalConfig{ - LocalAddr: &net.TCPAddr{ - IP: ip, - Port: port, - }, - LocalIP: ip, - LocalPort: uint16(port), - RemoteAddr: &net.TCPAddr{ - IP: ip, - Port: 4445, - }, - TCP: conf.TCPConfig{ - Mux: true, - MuxConcurrency: 8, - }, - Hash: map[string]string{common.SHA224String(password): password}, - } - config.TCP.MuxIdleTimeout = 10 - config.TLS.CertPool = pool - config.TLS.SNI = "localhost" - - c := Client{ - config: config, - muxPool: make(map[muxID]*muxClientInfo), - } - c.Run() -} - -func TestClientWithJSON(t *testing.T) { - data, err := ioutil.ReadFile("client.json") - common.Must(err) - config, err := conf.ParseJSON(data) - common.Must(err) - - client := Client{ - config: config, - } - client.Run() -} diff --git a/proxy/forward.go b/proxy/forward/forward.go similarity index 69% rename from proxy/forward.go rename to proxy/forward/forward.go index 0e9027e4e..f4a1d5b92 100644 --- a/proxy/forward.go +++ b/proxy/forward/forward.go @@ -1,13 +1,18 @@ -package proxy +package forward import ( "context" "net" + "os" "github.com/p4gefau1t/trojan-go/common" "github.com/p4gefau1t/trojan-go/conf" + "github.com/p4gefau1t/trojan-go/log" + "github.com/p4gefau1t/trojan-go/proxy" ) +var logger = log.New(os.Stdout) + type Forward struct { common.Runnable config *conf.GlobalConfig @@ -21,7 +26,7 @@ func (f *Forward) handleConn(conn net.Conn) { logger.Error("failed to connect to remote endpoint:", err) return } - proxyConn(newConn, conn) + proxy.ProxyConn(newConn, conn) } func (f *Forward) Run() error { @@ -50,3 +55,13 @@ func (f *Forward) Close() error { f.cancel() return nil } + +func (f *Forward) Build(config *conf.GlobalConfig) (common.Runnable, error) { + f.ctx, f.cancel = context.WithCancel(context.Background()) + f.config = config + return f, nil +} + +func init() { + proxy.RegisterProxy(conf.Forward, &Forward{}) +} diff --git a/proxy/forward_test.go b/proxy/forward_test.go deleted file mode 100644 index b76e6990b..000000000 --- a/proxy/forward_test.go +++ /dev/null @@ -1,90 +0,0 @@ -package proxy - -import ( - "crypto/tls" - "crypto/x509" - "io/ioutil" - "net" - "testing" - - "github.com/p4gefau1t/trojan-go/common" - "github.com/p4gefau1t/trojan-go/conf" -) - -func TestForward(t *testing.T) { - serverCertBytes, err := ioutil.ReadFile("./server.crt") - common.Must(err) - pool := x509.NewCertPool() - pool.AppendCertsFromPEM(serverCertBytes) - ip := net.IPv4(127, 0, 0, 1) - forwardPort := 5000 - password := "pass123123" - clientConfig := &conf.GlobalConfig{ - LocalAddr: &net.TCPAddr{ - IP: ip, - Port: 4444, - }, - LocalIP: ip, - LocalPort: uint16(forwardPort), - RemoteAddr: &net.TCPAddr{ - IP: ip, - Port: forwardPort, - }, - Hash: map[string]string{common.SHA224String(password): password}, - } - clientConfig.TLS.CertPool = pool - clientConfig.TLS.SNI = "localhost" - - c := Client{ - config: clientConfig, - } - go c.Run() - - key, err := tls.LoadX509KeyPair("server.crt", "server.key") - - common.Must(err) - serverPort := 4445 - serverConfig := &conf.GlobalConfig{ - LocalAddr: &net.TCPAddr{ - IP: ip, - Port: serverPort, - }, - LocalIP: ip, - LocalPort: uint16(forwardPort), - RemoteAddr: &net.TCPAddr{ - IP: ip, - Port: 80, - }, - RemoteIP: ip, - RemotePort: 80, - Hash: map[string]string{common.SHA224String(password): password}, - } - serverConfig.TLS.KeyPair = []tls.Certificate{key} - serverConfig.TLS.SNI = "localhost" - - server := Server{ - config: serverConfig, - } - go server.Run() - - forwardConfig := &conf.GlobalConfig{ - LocalAddr: &net.TCPAddr{ - IP: ip, - Port: forwardPort, - }, - LocalIP: ip, - LocalPort: uint16(forwardPort), - RemoteAddr: &net.TCPAddr{ - IP: ip, - Port: serverPort, - }, - RemoteIP: ip, - RemotePort: uint16(serverPort), - } - - forward := Forward{ - config: forwardConfig, - } - forward.Run() - -} diff --git a/proxy/nat_stub.go b/proxy/nat_stub.go deleted file mode 100644 index 74d6d16f0..000000000 --- a/proxy/nat_stub.go +++ /dev/null @@ -1,17 +0,0 @@ -// +build !linux - -package proxy - -import ( - "github.com/p4gefau1t/trojan-go/common" - "github.com/p4gefau1t/trojan-go/conf" -) - -type NAT struct { - common.Runnable - config *conf.GlobalConfig -} - -func (n *NAT) Run() error { - return common.NewError("not supported os") -} diff --git a/proxy/nat_test.go b/proxy/nat_test.go deleted file mode 100644 index 08191cfd5..000000000 --- a/proxy/nat_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package proxy - -import ( - "io/ioutil" - "testing" - - "github.com/p4gefau1t/trojan-go/common" - "github.com/p4gefau1t/trojan-go/conf" -) - -func TestNAT(t *testing.T) { - data, err := ioutil.ReadFile("nat2.json") - common.Must(err) - config, err := conf.ParseJSON(data) - common.Must(err) - - nat := NAT{ - config: config, - } - err = nat.Run() - common.Must(err) -} diff --git a/proxy/option.go b/proxy/option.go new file mode 100644 index 000000000..d7330f01b --- /dev/null +++ b/proxy/option.go @@ -0,0 +1,61 @@ +package proxy + +import ( + "flag" + "io/ioutil" + "os" + "os/signal" + + "github.com/p4gefau1t/trojan-go/common" + "github.com/p4gefau1t/trojan-go/conf" +) + +type proxyOption struct { + args *string + common.OptionHandler +} + +func (*proxyOption) Name() string { + return "proxy" +} + +func (*proxyOption) Priority() int { + return 0 +} + +func (c *proxyOption) Handle() error { + logger.Info("Trojan-Go proxy initializing...") + data, err := ioutil.ReadFile(*c.args) + if err != nil { + logger.Fatal(common.NewError("Failed to read config file").Base(err)) + } + config, err := conf.ParseJSON(data) + if err != nil { + logger.Fatal(common.NewError("Failed to parse config file").Base(err)) + } + proxy, err := NewProxy(config) + if err != nil { + logger.Fatal(err) + } + errChan := make(chan error) + go func() { + errChan <- proxy.Run() + }() + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, os.Interrupt) + select { + case <-sigs: + proxy.Close() + return nil + case err := <-errChan: + logger.Fatal(err) + return err + } +} + +func init() { + common.RegisterOptionHandler(&proxyOption{ + args: flag.String("config", "config.json", "Config filename"), + }) +} diff --git a/proxy/proxy.go b/proxy/proxy.go index 2315e69b1..4bef954d4 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -12,6 +12,10 @@ import ( var logger = log.New(os.Stdout) +type Buildable interface { + Build(config *conf.GlobalConfig) (common.Runnable, error) +} + func copyConn(dst io.Writer, src io.Reader, errChan chan error) { _, err := io.Copy(dst, src) errChan <- err @@ -32,58 +36,62 @@ func copyPacket(dst protocol.PacketWriter, src protocol.PacketReader, errChan ch } } -func proxyConn(a io.ReadWriteCloser, b io.ReadWriteCloser) { +func ProxyConn(a io.ReadWriteCloser, b io.ReadWriteCloser) { errChan := make(chan error, 2) go copyConn(a, b, errChan) go copyConn(b, a, errChan) err := <-errChan if err != nil { - if err.Error() != "EOF" { - logger.Error(common.NewError("conn proxy ends").Base(err)) - } - } else { - logger.Debug("conn proxy ends") + logger.Debug(common.NewError("conn proxy ends").Base(err)) } } -func proxyPacket(a protocol.PacketReadWriter, b protocol.PacketReadWriter) { +func ProxyPacket(a protocol.PacketReadWriter, b protocol.PacketReadWriter) { errChan := make(chan error, 2) go copyPacket(a, b, errChan) go copyPacket(b, a, errChan) err := <-errChan if err != nil { - if err.Error() != "EOF" { - logger.Error(common.NewError("packet proxy ends").Base(err)) - } - } else { - logger.Debug("packet proxy ends") + logger.Debug(common.NewError("packet proxy ends").Base(err)) } } -func NewProxy(config *conf.GlobalConfig) common.Runnable { - switch config.RunType { - case conf.Client: - client := &Client{ - config: config, - muxPool: make(map[muxID]*muxClientInfo), - } - return client - case conf.Server: - server := &Server{ - config: config, - } - return server - case conf.Forward: - forward := &Forward{ - config: config, +func copyPacketWithAliveChan(dst protocol.PacketWriter, src protocol.PacketReader, errChan chan error, aliveChan chan int) { + for { + req, packet, err := src.ReadPacket() + if err != nil { + errChan <- err + return } - return forward - case conf.NAT: - nat := &NAT{ - config: config, + _, err = dst.WritePacket(req, packet) + if err != nil { + errChan <- err + return } - return nat - default: - panic("invalid run type") + aliveChan <- 1 } } + +func ProxyPacketWithAliveChan(a protocol.PacketReadWriter, b protocol.PacketReadWriter, aliveChan chan int) { + errChan := make(chan error, 2) + go copyPacket(a, b, errChan) + go copyPacket(b, a, errChan) + err := <-errChan + if err != nil { + logger.Debug(common.NewError("packet proxy ends").Base(err)) + } +} + +var buildableMap map[conf.RunType]Buildable = make(map[conf.RunType]Buildable) + +func NewProxy(config *conf.GlobalConfig) (common.Runnable, error) { + runType := config.RunType + if buildable, found := buildableMap[runType]; found { + return buildable.Build(config) + } + return nil, common.NewError("invalid run_type") +} + +func RegisterProxy(t conf.RunType, b Buildable) { + buildableMap[t] = b +} diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go deleted file mode 100644 index 397415b55..000000000 --- a/proxy/proxy_test.go +++ /dev/null @@ -1,307 +0,0 @@ -package proxy - -import ( - "crypto/x509" - "io/ioutil" - "net" - "sync" - "testing" - "time" - - "github.com/p4gefau1t/trojan-go/common" - "github.com/p4gefau1t/trojan-go/conf" - "github.com/p4gefau1t/trojan-go/test" - "golang.org/x/net/proxy" -) - -func TestClientToServer(t *testing.T) { - go TestServer(t) - TestClient(t) -} - -func TestClientToDatabaseServer(t *testing.T) { - go TestServerWithDatabase(t) - TestClient(t) -} - -func TestClientToServerWithJSON(t *testing.T) { - go TestServerWithJSON(t) - TestClientWithJSON(t) -} - -func TestMuxClientToServer(t *testing.T) { - go TestMuxClient(t) - TestServer(t) -} - -func TestClientToPortReusingServer(t *testing.T) { - go TestClient(t) - TestPortReusingServer(t) -} - -func TestSNIConfig(t *testing.T) { - go ClientWithWrongSNI(t) - TestServer(t) -} - -func TestSQLite(t *testing.T) { - go TestClient(t) - TestServerWithSQLite(t) -} - -func ClientWithWrongSNI(t *testing.T) { - serverCertBytes, err := ioutil.ReadFile("./server.crt") - common.Must(err) - pool := x509.NewCertPool() - pool.AppendCertsFromPEM(serverCertBytes) - ip := net.IPv4(127, 0, 0, 1) - port := 4444 - password := "pass123123" - config := &conf.GlobalConfig{ - LocalAddr: &net.TCPAddr{ - IP: ip, - Port: port, - }, - LocalIP: ip, - LocalPort: uint16(port), - RemoteAddr: &net.TCPAddr{ - IP: ip, - Port: 4445, - }, - Hash: map[string]string{common.SHA224String(password): password}, - } - config.TLS.Verify = true - config.TLS.CertPool = pool - config.TLS.SNI = "localhost123" - config.TLS.VerifyHostname = true - - c := Client{ - config: config, - } - c.Run() - time.Sleep(time.Hour) -} - -func BenchmarkClientToServerHugePayload(b *testing.B) { - b.StopTimer() - data, err := ioutil.ReadFile("client.json") - common.Must(err) - clientConfig, err := conf.ParseJSON(data) - common.Must(err) - - client := Client{ - config: clientConfig, - } - go client.Run() - - data, err = ioutil.ReadFile("server.json") - common.Must(err) - serverConfig, err := conf.ParseJSON(data) - common.Must(err) - - server := Server{ - config: serverConfig, - } - go server.Run() - - tcpServer := test.RunBlackHoleTCPServer() - - mbytes := 512 - payload := test.GeneratePayload(1024 * 1024 * mbytes) - dialer, err := proxy.SOCKS5("tcp", clientConfig.LocalAddr.String(), nil, nil) - common.Must(err) - conn, err := dialer.Dial("tcp", tcpServer.String()) - common.Must(err) - b.StartTimer() - t1 := time.Now() - conn.Write(payload) - t2 := time.Now() - speed := float64(mbytes) / t2.Sub(t1).Seconds() - logger.Info("Speed: ", speed, "MBytes/s") - b.StopTimer() -} - -func BenchmarkClientToServerHugeConn(b *testing.B) { - b.StopTimer() - data, err := ioutil.ReadFile("client.json") - common.Must(err) - clientConfig, err := conf.ParseJSON(data) - common.Must(err) - - client := Client{ - config: clientConfig, - } - go client.Run() - - data, err = ioutil.ReadFile("server.json") - common.Must(err) - serverConfig, err := conf.ParseJSON(data) - common.Must(err) - - server := Server{ - config: serverConfig, - } - go server.Run() - - tcpServer := test.RunBlackHoleTCPServer() - - connNum := 1024 - mbytes := 1 - payload := test.GeneratePayload(1024 * 1024 * mbytes) - dialer, err := proxy.SOCKS5("tcp", clientConfig.LocalAddr.String(), nil, nil) - common.Must(err) - - wg := sync.WaitGroup{} - sender := func(wg *sync.WaitGroup) { - conn, err := dialer.Dial("tcp", tcpServer.String()) - common.Must(err) - conn.Write(payload) - conn.Close() - wg.Done() - } - b.StartTimer() - wg.Add(connNum) - t1 := time.Now() - for i := 0; i < connNum; i++ { - go sender(&wg) - } - wg.Wait() - t2 := time.Now() - speed := float64(mbytes) * float64(connNum) / t2.Sub(t1).Seconds() - logger.Info("Speed: ", speed, "MBytes/s") - b.StopTimer() -} - -func BenchmarkClientToContinuesHugeConn(b *testing.B) { - b.StopTimer() - data, err := ioutil.ReadFile("client.json") - common.Must(err) - clientConfig, err := conf.ParseJSON(data) - common.Must(err) - - client := Client{ - config: clientConfig, - } - go client.Run() - - data, err = ioutil.ReadFile("server.json") - common.Must(err) - serverConfig, err := conf.ParseJSON(data) - common.Must(err) - - server := Server{ - config: serverConfig, - } - go server.Run() - - tcpServer := test.RunBlackHoleTCPServer() - - connNum := 1024 - mbytes := 32 - payload := test.GeneratePayload(1024 * 1024 * mbytes) - dialer, err := proxy.SOCKS5("tcp", clientConfig.LocalAddr.String(), nil, nil) - common.Must(err) - - sender := func() { - conn, err := dialer.Dial("tcp", tcpServer.String()) - common.Must(err) - conn.Write(payload) - conn.Close() - } - b.StartTimer() - for i := 0; i < 100; i++ { - for i := 0; i < connNum; i++ { - go sender() - } - time.Sleep(time.Second / 10) - } - b.StopTimer() -} - -func BenchmarkMuxClientToServerHugePayload(b *testing.B) { - b.StopTimer() - data, err := ioutil.ReadFile("client.json") - common.Must(err) - clientConfig, err := conf.ParseJSON(data) - common.Must(err) - clientConfig.TCP.Mux = true - - client := Client{ - config: clientConfig, - } - go client.Run() - - data, err = ioutil.ReadFile("server.json") - common.Must(err) - serverConfig, err := conf.ParseJSON(data) - common.Must(err) - - server := Server{ - config: serverConfig, - } - go server.Run() - - tcpServer := test.RunBlackHoleTCPServer() - - mbytes := 128 - payload := test.GeneratePayload(1024 * 1024 * mbytes) - dialer, err := proxy.SOCKS5("tcp", clientConfig.LocalAddr.String(), nil, nil) - common.Must(err) - conn, err := dialer.Dial("tcp", tcpServer.String()) - common.Must(err) - b.StartTimer() - t1 := time.Now() - conn.Write(payload) - t2 := time.Now() - speed := float64(mbytes) / t2.Sub(t1).Seconds() - logger.Info("Speed: ", speed, "MBytes/s") - b.StopTimer() -} - -func BenchmarkMuxClientToContinuesHugeConn(b *testing.B) { - b.StopTimer() - data, err := ioutil.ReadFile("client.json") - common.Must(err) - clientConfig, err := conf.ParseJSON(data) - common.Must(err) - clientConfig.TCP.Mux = true - - client := Client{ - config: clientConfig, - } - go client.Run() - - data, err = ioutil.ReadFile("server.json") - common.Must(err) - serverConfig, err := conf.ParseJSON(data) - common.Must(err) - - server := Server{ - config: serverConfig, - } - go server.Run() - - tcpServer := test.RunBlackHoleTCPServer() - - connNum := 256 - mbytes := 16 - payload := test.GeneratePayload(1024 * 1024 * mbytes) - dialer, err := proxy.SOCKS5("tcp", clientConfig.LocalAddr.String(), nil, nil) - common.Must(err) - wg := sync.WaitGroup{} - wg.Add(connNum) - - sender := func() { - conn, err := dialer.Dial("tcp", tcpServer.String()) - common.Must(err) - conn.Write(payload) - conn.Close() - wg.Done() - } - b.StartTimer() - for i := 0; i < connNum; i++ { - go sender() - } - wg.Wait() - b.StopTimer() -} diff --git a/proxy/server.go b/proxy/server/server.go similarity index 89% rename from proxy/server.go rename to proxy/server/server.go index 972ff37ca..a9088ea12 100644 --- a/proxy/server.go +++ b/proxy/server/server.go @@ -1,24 +1,30 @@ -package proxy +package server import ( "context" "crypto/tls" "database/sql" "net" + "os" "reflect" "github.com/p4gefau1t/trojan-go/common" "github.com/p4gefau1t/trojan-go/conf" + "github.com/p4gefau1t/trojan-go/log" "github.com/p4gefau1t/trojan-go/protocol" "github.com/p4gefau1t/trojan-go/protocol/direct" "github.com/p4gefau1t/trojan-go/protocol/mux" "github.com/p4gefau1t/trojan-go/protocol/trojan" + "github.com/p4gefau1t/trojan-go/proxy" "github.com/p4gefau1t/trojan-go/stat" "github.com/xtaci/smux" ) +var logger = log.New(os.Stdout) + type Server struct { common.Runnable + proxy.Buildable listener net.Listener auth stat.Authenticator @@ -49,7 +55,7 @@ func (s *Server) handleMuxConn(stream *smux.Stream, passwordHash string) { } logger.Info("user", passwordHash, "mux tunneling to", req.String()) defer outboundConn.Close() - proxyConn(inboundConn, outboundConn) + proxy.ProxyConn(inboundConn, outboundConn) } func (s *Server) handleConn(conn net.Conn) { @@ -69,11 +75,7 @@ func (s *Server) handleConn(conn net.Conn) { for { stream, err := muxServer.AcceptStream() if err != nil { - if err.Error() == "EOF" { - logger.Info("mux conn from", conn.RemoteAddr(), "closed") - } else { - logger.Info("mux conn from", conn.RemoteAddr(), "closed: ", err) - } + logger.Debug("mux conn from", conn.RemoteAddr(), "closed: ", err) return } go s.handleMuxConn(stream, hash) @@ -92,7 +94,7 @@ func (s *Server) handleConn(conn net.Conn) { } defer outboundPacket.Close() logger.Info("UDP associated") - proxyPacket(inboundPacket, outboundPacket) + proxy.ProxyPacket(inboundPacket, outboundPacket) logger.Info("UDP tunnel closed") return } @@ -106,15 +108,14 @@ func (s *Server) handleConn(conn net.Conn) { defer outboundConn.Close() logger.Info("conn from", conn.RemoteAddr(), "tunneling to", req.String()) - proxyConn(inboundConn, outboundConn) + proxy.ProxyConn(inboundConn, outboundConn) } func (s *Server) handleInvalidConn(conn net.Conn, tlsConn *tls.Conn) { - + defer conn.Close() if len(s.config.TLS.HTTPResponse) > 0 { logger.Warn("trying to response a plain http response") conn.Write(s.config.TLS.HTTPResponse) - conn.Close() return } @@ -133,22 +134,18 @@ func (s *Server) handleInvalidConn(conn net.Conn, tlsConn *tls.Conn) { remote, err := net.Dial("tcp", s.config.TLS.FallbackAddr.String()) if err != nil { logger.Warn(common.NewError("failed to dial to tls fallback server").Base(err)) + return } logger.Warn("proxying this invalid tls conn to the tls fallback server") remote.Write(buf) - go proxyConn(conn, remote) + proxy.ProxyConn(conn, remote) } else { logger.Warn("fallback port is unspecified, closing") - conn.Close() } } func (s *Server) Run() error { - ctx, cancel := context.WithCancel(context.Background()) - s.ctx = ctx - s.cancel = cancel - var db *sql.DB var err error if s.config.MySQL.Enabled { @@ -163,11 +160,6 @@ func (s *Server) Run() error { if err != nil { return common.NewError("failed to connect to database server").Base(err) } - } else if s.config.SQLite.Enabled { - db, err = common.ConnectSQLite(s.config.SQLite.Database) - if err != nil { - return common.NewError("failed to connect to database server").Base(err) - } } if db == nil { s.auth = &stat.ConfigUserAuthenticator{ @@ -245,3 +237,13 @@ func (s *Server) Close() error { s.cancel() return nil } + +func (s *Server) Build(config *conf.GlobalConfig) (common.Runnable, error) { + s.config = config + s.ctx, s.cancel = context.WithCancel(context.Background()) + return s, nil +} + +func init() { + proxy.RegisterProxy(conf.Server, &Server{}) +} diff --git a/proxy/tcp_option.go b/proxy/server/tcp_option.go similarity index 96% rename from proxy/tcp_option.go rename to proxy/server/tcp_option.go index e8cb3d731..554536e11 100644 --- a/proxy/tcp_option.go +++ b/proxy/server/tcp_option.go @@ -1,6 +1,6 @@ // +build !windows -package proxy +package server import ( "net" diff --git a/proxy/tcp_option_stub.go b/proxy/server/tcp_option_stub.go similarity index 92% rename from proxy/tcp_option_stub.go rename to proxy/server/tcp_option_stub.go index e7466488c..0f85ad442 100644 --- a/proxy/tcp_option_stub.go +++ b/proxy/server/tcp_option_stub.go @@ -1,6 +1,6 @@ // +build windows -package proxy +package server import ( "net" diff --git a/proxy/server_test.go b/proxy/server_test.go deleted file mode 100644 index 3e0d56fb4..000000000 --- a/proxy/server_test.go +++ /dev/null @@ -1,202 +0,0 @@ -package proxy - -import ( - "crypto/tls" - "io/ioutil" - "net" - "testing" - - "github.com/p4gefau1t/trojan-go/common" - "github.com/p4gefau1t/trojan-go/conf" -) - -func TestServer(t *testing.T) { - key, err := tls.LoadX509KeyPair("server.crt", "server.key") - common.Must(err) - ip := net.IPv4(127, 0, 0, 1) - port := 4445 - password := "pass123123" - config := &conf.GlobalConfig{ - LocalAddr: &net.TCPAddr{ - IP: ip, - Port: port, - }, - LocalIP: ip, - LocalPort: uint16(port), - RemoteAddr: &net.TCPAddr{ - IP: ip, - Port: 80, - }, - RemoteIP: ip, - RemotePort: 80, - Hash: map[string]string{common.SHA224String(password): password}, - } - config.TLS.KeyPair = []tls.Certificate{key} - config.TLS.SNI = "localhost" - - server := Server{ - config: config, - } - server.Run() -} - -func TestServerWithJSON(t *testing.T) { - data, err := ioutil.ReadFile("server.json") - common.Must(err) - config, err := conf.ParseJSON(data) - common.Must(err) - - server := Server{ - config: config, - } - common.Must(server.Run()) -} - -func TestServerWithDatabase(t *testing.T) { - key, err := tls.LoadX509KeyPair("server.crt", "server.key") - common.Must(err) - ip := net.IPv4(127, 0, 0, 1) - port := 4445 - //password := "pass123123" - config := &conf.GlobalConfig{ - LocalAddr: &net.TCPAddr{ - IP: ip, - Port: port, - }, - LocalIP: ip, - LocalPort: uint16(port), - RemoteAddr: &net.TCPAddr{ - IP: ip, - Port: 80, - }, - RemoteIP: ip, - RemotePort: 80, - Hash: map[string]string{}, - MySQL: conf.MySQLConfig{ - Enabled: true, - ServerHost: "127.0.0.1", - ServerPort: 3306, - Username: "root", - Password: "password", - Database: "trojan", - }, - } - config.TLS.KeyPair = []tls.Certificate{key} - config.TLS.SNI = "localhost" - - server := Server{ - config: config, - } - common.Must(server.Run()) -} - -func TestPortReusingServer(t *testing.T) { - key, err := tls.LoadX509KeyPair("server.crt", "server.key") - common.Must(err) - ip := net.IPv4(127, 0, 0, 1) - port := 4445 - password := "pass123123" - config := &conf.GlobalConfig{ - LocalAddr: &net.TCPAddr{ - IP: ip, - Port: port, - }, - LocalIP: ip, - LocalPort: uint16(port), - RemoteAddr: &net.TCPAddr{ - IP: ip, - Port: 80, - }, - RemoteIP: ip, - RemotePort: 80, - Hash: map[string]string{common.SHA224String(password): password}, - TCP: conf.TCPConfig{ - ReusePort: true, - }, - } - config.TLS.KeyPair = []tls.Certificate{key} - config.TLS.SNI = "localhost" - - server1 := Server{ - config: config, - } - server2 := Server{ - config: config, - } - go server1.Run() - server2.Run() - //common.Must(server2.Run()) - //time.Sleep(time.Hour) -} - -func TestServerTCPRedirecting(t *testing.T) { - key, err := tls.LoadX509KeyPair("server.crt", "server.key") - common.Must(err) - ip := net.IPv4(127, 0, 0, 1) - port := 4445 - password := "pass123123" - config := &conf.GlobalConfig{ - LocalAddr: &net.TCPAddr{ - IP: ip, - Port: port, - }, - LocalIP: ip, - LocalPort: uint16(port), - RemoteAddr: &net.TCPAddr{ - IP: ip, - Port: 80, - }, - RemoteIP: ip, - RemotePort: 80, - Hash: map[string]string{common.SHA224String(password): password}, - } - config.TLS.KeyPair = []tls.Certificate{key} - config.TLS.SNI = "localhost" - - addr := &net.TCPAddr{ - IP: net.IPv4(127, 0, 0, 1), - Port: 443, - } - - common.Must(err) - config.TLS.FallbackAddr = addr - - server := Server{ - config: config, - } - server.Run() -} - -func TestServerWithSQLite(t *testing.T) { - key, err := tls.LoadX509KeyPair("server.crt", "server.key") - common.Must(err) - ip := net.IPv4(127, 0, 0, 1) - port := 4445 - //password := "pass123123" - config := &conf.GlobalConfig{ - LocalAddr: &net.TCPAddr{ - IP: ip, - Port: port, - }, - LocalIP: ip, - LocalPort: uint16(port), - RemoteAddr: &net.TCPAddr{ - IP: ip, - Port: 80, - }, - RemoteIP: ip, - RemotePort: 80, - Hash: map[string]string{}, - SQLite: conf.SQLiteConfig{ - Enabled: true, - Database: "test.db", - }, - } - config.TLS.KeyPair = []tls.Certificate{key} - config.TLS.SNI = "localhost" - - server := Server{ - config: config, - } - common.Must(server.Run()) -} diff --git a/test/proxy_test.go b/test/proxy_test.go new file mode 100644 index 000000000..427cc8720 --- /dev/null +++ b/test/proxy_test.go @@ -0,0 +1,430 @@ +package test + +import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" + "net" + "net/http" + "sync" + "testing" + "time" + + _ "net/http/pprof" + + "github.com/p4gefau1t/trojan-go/common" + "github.com/p4gefau1t/trojan-go/conf" + "github.com/p4gefau1t/trojan-go/log" + "github.com/p4gefau1t/trojan-go/proxy/client" + "github.com/p4gefau1t/trojan-go/proxy/server" + "golang.org/x/net/proxy" +) + +var cert string = ` +-----BEGIN CERTIFICATE----- +MIIDZTCCAk0CFFphZh018B5iAD9F5fV4y0AlD0LxMA0GCSqGSIb3DQEBCwUAMG8x +CzAJBgNVBAYTAlVTMQ0wCwYDVQQIDARNYXJzMRMwEQYDVQQHDAppVHJhbnN3YXJw +MRMwEQYDVQQKDAppVHJhbnN3YXJwMRMwEQYDVQQLDAppVHJhbnN3YXJwMRIwEAYD +VQQDDAlsb2NhbGhvc3QwHhcNMjAwMzMxMTAwMDUxWhcNMzAwMzI5MTAwMDUxWjBv +MQswCQYDVQQGEwJVUzENMAsGA1UECAwETWFyczETMBEGA1UEBwwKaVRyYW5zd2Fy +cDETMBEGA1UECgwKaVRyYW5zd2FycDETMBEGA1UECwwKaVRyYW5zd2FycDESMBAG +A1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA +ml44fThYMkCcT627o7ibEs7mq2WOhImjDwYijYJ1684BatrCsHJNcw8PJGTuP+tg +GdngmALjA3l+RipjaE/UK4FJrAjruphA/hOCjZfWqk8KBR4qk0OltxCMWJlp/XCM +9ny1ogFdWUlBbqThs4NWSOUESgxf/Be2njeiOrngGR31qxSiLCLBvafIhKqq/4av +Rlx0Ht770uvF97MlAj1ASAvzTZICHAfUZxEdWl0J4MBbG7SNcnMBbyAF+s60eFTa +4RGMfRGnUa2Fzz/gfjhvfSIGeLQ3JRG6sl6jkc5xe0PZzhq3UNpK0gtQ48yy9CSP +neZnrynoKks7XC2bizsr3QIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQAHS/xuG5+F +yGU3N6V4kv+HbKqHaXNOq4zKVsCc1k7vg4MFFpKUJKxtJYooCI8n2ypp5XRUTIGQ +bmEbVcIPqm9Rf/4vHtF0falNCwieAbXDkiEHoykRmmU1UE/ccPA7X8NO9aVLJAJO +N2Li8MH0Ixgs02pQH56eyGKoRBWPR5C3ETQ9Leqvazg6Dn1iJWvmfF0mOte5228s +mZJOntF9t8MZOJdIWGdrUHn6euRfhd0btkmL/NUDzeCTwJcuPORLxkBbCP5mTC6G +GnLS5Z4oRYgCgvT2pLtcM0r48hYjwgjXFQ4zalkW6YI9LPpqwwMhhOzINlXjBaDi +Haz8uKI4EciU +-----END CERTIFICATE----- +` + +var key string = ` +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAml44fThYMkCcT627o7ibEs7mq2WOhImjDwYijYJ1684BatrC +sHJNcw8PJGTuP+tgGdngmALjA3l+RipjaE/UK4FJrAjruphA/hOCjZfWqk8KBR4q +k0OltxCMWJlp/XCM9ny1ogFdWUlBbqThs4NWSOUESgxf/Be2njeiOrngGR31qxSi +LCLBvafIhKqq/4avRlx0Ht770uvF97MlAj1ASAvzTZICHAfUZxEdWl0J4MBbG7SN +cnMBbyAF+s60eFTa4RGMfRGnUa2Fzz/gfjhvfSIGeLQ3JRG6sl6jkc5xe0PZzhq3 +UNpK0gtQ48yy9CSPneZnrynoKks7XC2bizsr3QIDAQABAoIBAFpYUo9W7qdakSFA ++NS1Mm0rkm01nteLBlfAq3BOrl030DSNm+xQuWthoOcX+yiFxVTb40qURfC+plzC +ajOepPphTJDXF7+5ZDBPktTzzLsYTzD3mstdiBtAICOqhhHCUX3hNxx91/htm1H6 +Re4eK921y3DbFUIhTswCm3vrVXDc4yTXtURGllVzo40K/1Of39CpufKFdpJ81HV+ +h/VW++h3o+sFV4KqcqIjClxBfDxoJpBaRlOCunTiHqZNvqO+EPqPR5zdn34werjU +xQEvPzmz+ClwnaEXQxYWgIcYQii9VNsHogDxEw4R31S7lVrUt0f0atDmGJip1lPb +E7IomAECgYEAzKQ3PzBV46nUNfVO9SODpf14Z+xYfLKouPC+Qnepwp0V0JS6zY1+ +Wzskyb80drjnoQraWSEvGsX+tEWeLcnjN7JuMu/U8DPKRcQ+Q2dsVo/q4sfBOgvl +VhPNMZLfa7NIkRUx2KXku++Ep0Xtak0dskrfQrZnvhymRPyWuIMM6IECgYEAwRwL +Gt/ZZdUueE/hwT3c1hNn6igeDLOwK2t6frib+Ofw5oCAQxtTROvP1ljlnWUPkeIS +uzTusmqucalcK3lCHIsyHLwApOI/B31M971pxMVBRZ0wIbBaoarCGND7gi8JUPFR +VErGcAB5YnpRlmfLPEgw2o7DpjsDc2KmdE9oNV0CgYEAmfNEWLYtNztxGTK1treD +96ELLutf2lexlIgQKgLJ5E22tpbdPXwfvdRtpZTBjDsojj+S6hCL1lFzfv0MtZe2 +5xTF0G4avKXJmti6moy4tRpJ81ehZuDCJBJ7gLrkd6qFghf2yuxqenQDUK/Lnvfq +ylGHSjHdM+lrsGRxotd8I4ECgYBoo4GA9nseqv2bQ+3YgGUBu1I7l7FwwI1decfO +ksoxfb0Tqd3WfyAH4J+mTlVdjD17lzz/JBeTpisQe+ztwa8JOIPW/ih7L/1nWYYz +V/fQH/LWfe5u0tjJcXXrbJJcYJBzw8+GFV6hoiAkNJOxJF0ENToDtAhgMuoTxAje +TYjyIQKBgQCmHkLLq0Bj3FpIOVrwo2gNvQteNPa7jkkGp4lljO8JQUHhCHDGWKEH +MUJ0EFsxS/EaQa+rW6jHhs3GyBA2TxmC783stAOOEX+hO/zpcbzdCWgp6eZ0aGMW +WS94/5WE/lwHJi8ZPSjH1AURCzXhUi4fGvBrNBtry95e+jcEvP5c0g== +-----END RSA PRIVATE KEY----- +` + +func getKeyPair() []tls.Certificate { + cert, err := tls.X509KeyPair([]byte(cert), []byte(key)) + common.Must(err) + return []tls.Certificate{cert} +} + +func getTLSConfig() conf.TLSConfig { + KeyPair := getKeyPair() + pool := x509.NewCertPool() + if ok := pool.AppendCertsFromPEM([]byte(cert)); !ok { + panic("invalid cert") + } + c := conf.TLSConfig{ + CertPool: pool, + KeyPair: KeyPair, + Verify: true, + VerifyHostname: true, + SNI: "localhost", + } + return c +} + +func getLocalAddr(port int) net.Addr { + return &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: port, + } +} + +func getLocalIP() net.IP { + return net.IPv4(127, 0, 0, 1) +} + +func getHash(password string) map[string]string { + hash := common.SHA224String(password) + m := make(map[string]string) + m[hash] = password + return m +} + +func TestClientJSON(t *testing.T) { + data, err := ioutil.ReadFile("client.json") + common.Must(err) + config, err := conf.ParseJSON(data) + common.Must(err) + c := client.Client{} + c.Build(config) + c.Run() +} + +func TestClient(t *testing.T) { + config := &conf.GlobalConfig{ + LocalIP: getLocalIP(), + LocalPort: 4444, + LocalAddr: getLocalAddr(4444), + RemoteIP: getLocalIP(), + RemotePort: 4445, + RemoteAddr: getLocalAddr(4445), + TLS: getTLSConfig(), + Hash: getHash("pass123"), + } + c := client.Client{} + c.Build(config) + common.Must(c.Run()) +} + +func TestServer(t *testing.T) { + config := &conf.GlobalConfig{ + LocalIP: getLocalIP(), + LocalPort: 4445, + LocalAddr: getLocalAddr(4445), + RemoteIP: getLocalIP(), + RemotePort: 80, + RemoteAddr: getLocalAddr(80), + TLS: getTLSConfig(), + Hash: getHash("pass123"), + } + s := server.Server{} + s.Build(config) + common.Must(s.Run()) +} + +func TestNAT(t *testing.T) { + config := &conf.GlobalConfig{ + LocalIP: getLocalIP(), + LocalPort: 4445, + LocalAddr: getLocalAddr(4445), + RemoteIP: getLocalIP(), + RemotePort: 80, + RemoteAddr: getLocalAddr(80), + TLS: getTLSConfig(), + Hash: getHash("pass123"), + } + n := client.NAT{} + n.Build(config) + common.Must(n.Run()) +} + +func TestMuxClient(t *testing.T) { + config := &conf.GlobalConfig{ + LocalIP: getLocalIP(), + LocalPort: 4444, + LocalAddr: getLocalAddr(4444), + RemoteIP: getLocalIP(), + RemotePort: 4445, + RemoteAddr: getLocalAddr(4445), + TLS: getTLSConfig(), + Hash: getHash("pass123"), + TCP: conf.TCPConfig{ + Mux: true, + MuxConcurrency: 8, + MuxIdleTimeout: 30, + }, + } + client := client.Client{} + client.Build(config) + client.Run() +} + +func TestClientAndServer(t *testing.T) { + go func() { + err := http.ListenAndServe("0.0.0.0:8000", nil) + logger.Error(err) + }() + go TestClient(t) + TestServer(t) +} + +func TestMuxClientAndServer(t *testing.T) { + go func() { + err := http.ListenAndServe("0.0.0.0:8000", nil) + logger.Error(err) + }() + go TestMuxClient(t) + TestServer(t) +} + +func BenchmarkNormalClientToServer(b *testing.B) { + log.LogLevel = 5 + config1 := &conf.GlobalConfig{ + LocalIP: getLocalIP(), + LocalPort: 4444, + LocalAddr: getLocalAddr(4444), + RemoteIP: getLocalIP(), + RemotePort: 4445, + RemoteAddr: getLocalAddr(4445), + TLS: getTLSConfig(), + Hash: getHash("pass123"), + } + c := client.Client{} + c.Build(config1) + go c.Run() + + config2 := &conf.GlobalConfig{ + LocalIP: getLocalIP(), + LocalPort: 4445, + LocalAddr: getLocalAddr(4445), + RemoteIP: getLocalIP(), + RemotePort: 80, + RemoteAddr: getLocalAddr(80), + TLS: getTLSConfig(), + Hash: getHash("pass123"), + } + s := server.Server{} + s.Build(config2) + go s.Run() + + target := RunBlackHoleTCPServer() + dialer, err := proxy.SOCKS5("tcp", getLocalAddr(4444).String(), nil, nil) + common.Must(err) + conn, err := dialer.Dial("tcp", target.String()) + common.Must(err) + mbytes := 512 + payload := GeneratePayload(1024 * 1024 * mbytes) + t1 := time.Now() + conn.Write(payload) + t2 := time.Now() + speed := float64(mbytes) / t2.Sub(t1).Seconds() + b.Log("Speed: ", speed, "MB/s") + conn.Close() +} + +func BenchmarkMuxClientToServer(b *testing.B) { + log.LogLevel = 5 + config1 := &conf.GlobalConfig{ + LocalIP: getLocalIP(), + LocalPort: 4444, + LocalAddr: getLocalAddr(4444), + RemoteIP: getLocalIP(), + RemotePort: 4445, + RemoteAddr: getLocalAddr(4445), + TLS: getTLSConfig(), + Hash: getHash("pass123"), + TCP: conf.TCPConfig{ + Mux: true, + MuxConcurrency: 8, + MuxIdleTimeout: 30, + }, + } + c := client.Client{} + c.Build(config1) + go c.Run() + + config2 := &conf.GlobalConfig{ + LocalIP: getLocalIP(), + LocalPort: 4445, + LocalAddr: getLocalAddr(4445), + RemoteIP: getLocalIP(), + RemotePort: 80, + RemoteAddr: getLocalAddr(80), + TLS: getTLSConfig(), + Hash: getHash("pass123"), + } + s := server.Server{} + s.Build(config2) + go s.Run() + + target := RunBlackHoleTCPServer() + dialer, err := proxy.SOCKS5("tcp", getLocalAddr(4444).String(), nil, nil) + common.Must(err) + conn, err := dialer.Dial("tcp", target.String()) + common.Must(err) + mbytes := 512 + payload := GeneratePayload(1024 * 1024 * mbytes) + t1 := time.Now() + conn.Write(payload) + t2 := time.Now() + speed := float64(mbytes) / t2.Sub(t1).Seconds() + b.Log("Speed: ", speed, "MB/s") + conn.Close() +} + +func BenchmarkNormalClientToServerHighConcurrency(b *testing.B) { + log.LogLevel = 5 + config1 := &conf.GlobalConfig{ + LocalIP: getLocalIP(), + LocalPort: 4444, + LocalAddr: getLocalAddr(4444), + RemoteIP: getLocalIP(), + RemotePort: 4445, + RemoteAddr: getLocalAddr(4445), + TLS: getTLSConfig(), + Hash: getHash("pass123"), + } + c := client.Client{} + c.Build(config1) + go c.Run() + + config2 := &conf.GlobalConfig{ + LocalIP: getLocalIP(), + LocalPort: 4445, + LocalAddr: getLocalAddr(4445), + RemoteIP: getLocalIP(), + RemotePort: 80, + RemoteAddr: getLocalAddr(80), + TLS: getTLSConfig(), + Hash: getHash("pass123"), + } + s := server.Server{} + s.Build(config2) + go s.Run() + + target := RunBlackHoleTCPServer() + dialer, err := proxy.SOCKS5("tcp", getLocalAddr(4444).String(), nil, nil) + common.Must(err) + + connNum := 128 + mbytes := 128 + payload := GeneratePayload(1024 * 1024 * mbytes) + + wg := sync.WaitGroup{} + sender := func(wg *sync.WaitGroup) { + conn, err := dialer.Dial("tcp", target.String()) + common.Must(err) + conn.Write(payload) + conn.Close() + wg.Done() + } + + wg.Add(connNum) + + t1 := time.Now() + for i := 0; i < connNum; i++ { + go sender(&wg) + } + wg.Wait() + t2 := time.Now() + speed := float64(mbytes) * float64(connNum) / t2.Sub(t1).Seconds() + b.Log("Speed: ", speed, "MB/s") +} + +func BenchmarkMuxClientToServerHighConcurrency(b *testing.B) { + log.LogLevel = 5 + config1 := &conf.GlobalConfig{ + LocalIP: getLocalIP(), + LocalPort: 4444, + LocalAddr: getLocalAddr(4444), + RemoteIP: getLocalIP(), + RemotePort: 4445, + RemoteAddr: getLocalAddr(4445), + TLS: getTLSConfig(), + Hash: getHash("pass123"), + TCP: conf.TCPConfig{ + Mux: true, + MuxConcurrency: 8, + MuxIdleTimeout: 30, + }, + } + c := client.Client{} + c.Build(config1) + go c.Run() + + config2 := &conf.GlobalConfig{ + LocalIP: getLocalIP(), + LocalPort: 4445, + LocalAddr: getLocalAddr(4445), + RemoteIP: getLocalIP(), + RemotePort: 80, + RemoteAddr: getLocalAddr(80), + TLS: getTLSConfig(), + Hash: getHash("pass123"), + } + s := server.Server{} + s.Build(config2) + go s.Run() + + target := RunBlackHoleTCPServer() + dialer, err := proxy.SOCKS5("tcp", getLocalAddr(4444).String(), nil, nil) + common.Must(err) + + connNum := 128 + mbytes := 128 + payload := GeneratePayload(1024 * 1024 * mbytes) + + wg := sync.WaitGroup{} + sender := func(wg *sync.WaitGroup) { + conn, err := dialer.Dial("tcp", target.String()) + common.Must(err) + conn.Write(payload) + conn.Close() + wg.Done() + } + + wg.Add(connNum) + + t1 := time.Now() + for i := 0; i < connNum; i++ { + go sender(&wg) + } + wg.Wait() + t2 := time.Now() + speed := float64(mbytes) * float64(connNum) / t2.Sub(t1).Seconds() + b.Log("Speed: ", speed, "MB/s") +} diff --git a/test/test.go b/test/test.go index 8dff9f119..1b4f7e2cc 100644 --- a/test/test.go +++ b/test/test.go @@ -34,6 +34,7 @@ func RunBlackHoleTCPServer() net.Addr { common.Must(err) blackhole := func(conn net.Conn) { io.Copy(ioutil.Discard, conn) + conn.Close() } serve := func() { for { diff --git a/version/version.go b/version/version.go new file mode 100644 index 000000000..723fbd0c9 --- /dev/null +++ b/version/version.go @@ -0,0 +1,41 @@ +package version + +import ( + "flag" + "os" + + "github.com/p4gefau1t/trojan-go/common" + "github.com/p4gefau1t/trojan-go/log" +) + +var logger = log.New(os.Stdout) + +type versionOption struct { + arg *bool + common.OptionHandler +} + +func (*versionOption) Name() string { + return "help" +} + +func (*versionOption) Priority() int { + return 0 +} + +func (c *versionOption) Handle() error { + if *c.arg { + logger.Info("Trojan-Go", common.Version) + logger.Info("Developed by PageFault(p4gefau1t)") + logger.Info("Lisensed under GNU General Public License v3") + logger.Info("GitHub Repository: https://github.com/p4gefau1t/trojan-go") + return nil + } + return common.NewError("not set") +} + +func init() { + common.RegisterOptionHandler(&versionOption{ + arg: flag.Bool("version", false, "Display version and help info"), + }) +}