From d61907c897b9ff13c6ae02b097890f429a7bbe5e Mon Sep 17 00:00:00 2001 From: p4gefau1t Date: Sun, 22 Mar 2020 04:17:39 -0400 Subject: [PATCH] capture interrupt --- main.go | 20 ++++++++++++++++---- proxy/client.go | 33 +++++++++++++++++++++++++++------ proxy/forward.go | 15 +++++++++++++++ proxy/nat.go | 47 ++++++++++++++++++++++++++++++++--------------- proxy/server.go | 43 ++++++++++++++++++++++++++++++++++--------- 5 files changed, 124 insertions(+), 34 deletions(-) diff --git a/main.go b/main.go index f4126a6fd..ef21053ac 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "flag" "io/ioutil" "os" + "os/signal" "github.com/p4gefau1t/trojan-go/conf" "github.com/p4gefau1t/trojan-go/proxy" @@ -13,7 +14,7 @@ import ( var logger = log.New(os.Stdout).WithColor() func main() { - logger.Info("Trojan-Go initializing") + logger.Info("Trojan-Go initializing...") configFile := flag.String("config", "config.json", "Config file name") flag.Parse() data, err := ioutil.ReadFile(*configFile) @@ -24,8 +25,19 @@ func main() { if err != nil { logger.Fatal("Failed to parse config file", err) } - err = proxy.NewProxy(config).Run() - if err != nil { - logger.Fatal("Error occured", err) + proxy := proxy.NewProxy(config) + errChan := make(chan error) + go func() { + errChan <- proxy.Run() + }() + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, os.Interrupt) + select { + case <-sigs: + proxy.Close() + case err := <-errChan: + logger.Fatal(err) } + logger.Info("Trojan-Go exited") } diff --git a/proxy/client.go b/proxy/client.go index 367239a10..9ec0d760f 100644 --- a/proxy/client.go +++ b/proxy/client.go @@ -22,13 +22,15 @@ type muxConn struct { } type Client struct { - config *conf.GlobalConfig common.Runnable + + config *conf.GlobalConfig muxClient *smux.Session muxClientLock sync.Mutex muxConnCount int32 lastActiveTime time.Time ctx context.Context + cancel context.CancelFunc } func (c *Client) checkAndNewMuxClient() { @@ -222,19 +224,27 @@ func (c *Client) handleConn(conn net.Conn) { func (c *Client) Run() error { listener, err := net.Listen("tcp", c.config.LocalAddr.String()) - //TODO - ctx, _ := context.WithCancel(context.Background()) + if err != nil { + return common.NewError("failed to listen local address").Base(err) + } + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + c.ctx = ctx + c.cancel = cancel + if c.config.TCP.MuxIdleTimeout > 0 { go c.checkAndCloseIdleMuxClient() } c.ctx = ctx - if err != nil { - return err - } logger.Info("running client at", listener.Addr()) for { conn, err := listener.Accept() if err != nil { + select { + case <-c.ctx.Done(): + default: + } logger.Error("error occured when accpeting conn", err) continue } @@ -245,3 +255,14 @@ func (c *Client) Run() error { } } } + +func (c *Client) Close() error { + logger.Info("shutting down client..") + c.cancel() + c.muxClientLock.Lock() + defer c.muxClientLock.Unlock() + if c.muxClient != nil { + c.muxClient.Close() + } + return nil +} diff --git a/proxy/forward.go b/proxy/forward.go index c35194dea..c8465decb 100644 --- a/proxy/forward.go +++ b/proxy/forward.go @@ -1,6 +1,7 @@ package proxy import ( + "context" "net" "github.com/p4gefau1t/trojan-go/common" @@ -10,6 +11,8 @@ import ( type Forward struct { common.Runnable config *conf.GlobalConfig + ctx context.Context + cancel context.CancelFunc } func (f *Forward) handleConn(conn net.Conn) { @@ -26,12 +29,24 @@ func (f *Forward) Run() error { if err != nil { return common.NewError("failed to listen local address").Base(err) } + defer listener.Close() for { conn, err := listener.Accept() if err != nil { + select { + case <-f.ctx.Done(): + return nil + default: + } logger.Error(err) continue } go f.handleConn(conn) } } + +func (f *Forward) Close() error { + logger.Info("") + f.cancel() + return nil +} diff --git a/proxy/nat.go b/proxy/nat.go index d0e5ea4fc..5f22d3f54 100644 --- a/proxy/nat.go +++ b/proxy/nat.go @@ -3,6 +3,7 @@ package proxy import ( + "context" "net" "github.com/p4gefau1t/trojan-go/common" @@ -14,38 +15,35 @@ import ( type NAT struct { common.Runnable - config *conf.GlobalConfig + + config *conf.GlobalConfig + ctx context.Context + cancel context.CancelFunc + packetInbound protocol.PacketSession + listener net.Listener } func (n *NAT) handleConn(conn net.Conn) { inbound, err := nat.NewInboundConnSession(conn) if err != nil { logger.Error("failed to start inbound session", err) + return } req := inbound.GetRequest() defer inbound.Close() outbound, err := trojan.NewOutboundConnSession(req, nil, n.config) if err != nil { logger.Error("failed to start outbound session", err) + return } defer outbound.Close() logger.Info("transparent nat from", conn.RemoteAddr(), "tunneling to", req) proxyConn(inbound, outbound) } -func (n *NAT) listenTCP(l net.Listener) { - for { - conn, err := l.Accept() - if err != nil { - logger.Error(err) - continue - } - go n.handleConn(conn) - } -} - func (n *NAT) listenUDP() { inbound, err := nat.NewInboundPacketSession(n.config) + n.packetInbound = inbound if err != nil { logger.Error(err) panic(err) @@ -60,6 +58,11 @@ func (n *NAT) listenUDP() { for { tunnel, err := trojan.NewOutboundConnSession(&req, nil, n.config) if err != nil { + select { + case <-n.ctx.Done(): + return + default: + } logger.Error(err) continue } @@ -70,15 +73,29 @@ func (n *NAT) listenUDP() { } func (n *NAT) Run() error { + go n.listenUDP() logger.Info("nat running at", n.config.LocalAddr) - tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{ + listener, err := net.ListenTCP("tcp", &net.TCPAddr{ IP: n.config.LocalIP, Port: int(n.config.LocalPort), }) if err != nil { return err } - go n.listenUDP() - n.listenTCP(tcpListener) + n.listener = listener + defer listener.Close() + for { + conn, err := n.listener.Accept() + if err != nil { + logger.Error(err) + continue + } + go n.handleConn(conn) + } +} + +func (n *NAT) Close() error { + logger.Info("shutting down nat...") + n.cancel() return nil } diff --git a/proxy/server.go b/proxy/server.go index df24f5aab..815ff8a2d 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -1,6 +1,7 @@ package proxy import ( + "context" "crypto/tls" "database/sql" "net" @@ -16,10 +17,13 @@ import ( ) type Server struct { - config *conf.GlobalConfig common.Runnable - auth stat.Authenticator - meter stat.TrafficMeter + + auth stat.Authenticator + meter stat.TrafficMeter + config *conf.GlobalConfig + ctx context.Context + cancel context.CancelFunc } func (s *Server) handleMuxConn(stream *smux.Stream, passwordHash string) { @@ -104,12 +108,10 @@ func (s *Server) handleConn(conn net.Conn) { } func (s *Server) Run() error { - tlsConfig := &tls.Config{ - Certificates: s.config.TLS.KeyPair, - CipherSuites: s.config.TLS.CipherSuites, - PreferServerCipherSuites: s.config.TLS.PreferServerCipher, - SessionTicketsDisabled: !s.config.TLS.SessionTicket, - } + ctx, cancel := context.WithCancel(context.Background()) + s.ctx = ctx + s.cancel = cancel + var db *sql.DB var err error if s.config.MySQL.Enabled { @@ -145,6 +147,8 @@ func (s *Server) Run() error { return common.NewError("failed to init traffic meter").Base(err) } } + defer s.auth.Close() + defer s.meter.Close() logger.Info("Server running at", s.config.LocalAddr) var listener net.Listener @@ -156,15 +160,31 @@ func (s *Server) Run() error { s.config.LocalIP, s.config.LocalAddr.String(), ) + if err != nil { + return err + } } else { listener, err = net.Listen("tcp", s.config.LocalAddr.String()) if err != nil { return err } } + defer listener.Close() + + tlsConfig := &tls.Config{ + Certificates: s.config.TLS.KeyPair, + CipherSuites: s.config.TLS.CipherSuites, + PreferServerCipherSuites: s.config.TLS.PreferServerCipher, + SessionTicketsDisabled: !s.config.TLS.SessionTicket, + } for { conn, err := listener.Accept() if err != nil { + select { + case <-s.ctx.Done(): + return nil + default: + } logger.Warn(err) continue } @@ -180,5 +200,10 @@ func (s *Server) Run() error { } go s.handleConn(tlsConn) } +} +func (s *Server) Close() error { + logger.Info("shutting down server..") + s.cancel() + return nil }