Skip to content

Commit

Permalink
capture interrupt
Browse files Browse the repository at this point in the history
  • Loading branch information
p4gefau1t committed Mar 22, 2020
1 parent 350a8a7 commit d61907c
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 34 deletions.
20 changes: 16 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"flag"
"io/ioutil"
"os"
"os/signal"

"github.com/p4gefau1t/trojan-go/conf"
"github.com/p4gefau1t/trojan-go/proxy"
Expand All @@ -13,7 +14,7 @@ import (
var logger = log.New(os.Stdout).WithColor()

func main() {
logger.Info("Trojan-Go initializing")
logger.Info("Trojan-Go initializing...")
configFile := flag.String("config", "config.json", "Config file name")
flag.Parse()
data, err := ioutil.ReadFile(*configFile)
Expand All @@ -24,8 +25,19 @@ func main() {
if err != nil {
logger.Fatal("Failed to parse config file", err)
}
err = proxy.NewProxy(config).Run()
if err != nil {
logger.Fatal("Error occured", err)
proxy := proxy.NewProxy(config)
errChan := make(chan error)
go func() {
errChan <- proxy.Run()
}()

sigs := make(chan os.Signal, 1)
signal.Notify(sigs, os.Interrupt)
select {
case <-sigs:
proxy.Close()
case err := <-errChan:
logger.Fatal(err)
}
logger.Info("Trojan-Go exited")
}
33 changes: 27 additions & 6 deletions proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ type muxConn struct {
}

type Client struct {
config *conf.GlobalConfig
common.Runnable

config *conf.GlobalConfig
muxClient *smux.Session
muxClientLock sync.Mutex
muxConnCount int32
lastActiveTime time.Time
ctx context.Context
cancel context.CancelFunc
}

func (c *Client) checkAndNewMuxClient() {
Expand Down Expand Up @@ -222,19 +224,27 @@ func (c *Client) handleConn(conn net.Conn) {

func (c *Client) Run() error {
listener, err := net.Listen("tcp", c.config.LocalAddr.String())
//TODO
ctx, _ := context.WithCancel(context.Background())
if err != nil {
return common.NewError("failed to listen local address").Base(err)
}
defer listener.Close()

ctx, cancel := context.WithCancel(context.Background())
c.ctx = ctx
c.cancel = cancel

if c.config.TCP.MuxIdleTimeout > 0 {
go c.checkAndCloseIdleMuxClient()
}
c.ctx = ctx
if err != nil {
return err
}
logger.Info("running client at", listener.Addr())
for {
conn, err := listener.Accept()
if err != nil {
select {
case <-c.ctx.Done():
default:
}
logger.Error("error occured when accpeting conn", err)
continue
}
Expand All @@ -245,3 +255,14 @@ func (c *Client) Run() error {
}
}
}

func (c *Client) Close() error {
logger.Info("shutting down client..")
c.cancel()
c.muxClientLock.Lock()
defer c.muxClientLock.Unlock()
if c.muxClient != nil {
c.muxClient.Close()
}
return nil
}
15 changes: 15 additions & 0 deletions proxy/forward.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package proxy

import (
"context"
"net"

"github.com/p4gefau1t/trojan-go/common"
Expand All @@ -10,6 +11,8 @@ import (
type Forward struct {
common.Runnable
config *conf.GlobalConfig
ctx context.Context
cancel context.CancelFunc
}

func (f *Forward) handleConn(conn net.Conn) {
Expand All @@ -26,12 +29,24 @@ func (f *Forward) Run() error {
if err != nil {
return common.NewError("failed to listen local address").Base(err)
}
defer listener.Close()
for {
conn, err := listener.Accept()
if err != nil {
select {
case <-f.ctx.Done():
return nil
default:
}
logger.Error(err)
continue
}
go f.handleConn(conn)
}
}

func (f *Forward) Close() error {
logger.Info("")
f.cancel()
return nil
}
47 changes: 32 additions & 15 deletions proxy/nat.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package proxy

import (
"context"
"net"

"github.com/p4gefau1t/trojan-go/common"
Expand All @@ -14,38 +15,35 @@ import (

type NAT struct {
common.Runnable
config *conf.GlobalConfig

config *conf.GlobalConfig
ctx context.Context
cancel context.CancelFunc
packetInbound protocol.PacketSession
listener net.Listener
}

func (n *NAT) handleConn(conn net.Conn) {
inbound, err := nat.NewInboundConnSession(conn)
if err != nil {
logger.Error("failed to start inbound session", err)
return
}
req := inbound.GetRequest()
defer inbound.Close()
outbound, err := trojan.NewOutboundConnSession(req, nil, n.config)
if err != nil {
logger.Error("failed to start outbound session", err)
return
}
defer outbound.Close()
logger.Info("transparent nat from", conn.RemoteAddr(), "tunneling to", req)
proxyConn(inbound, outbound)
}

func (n *NAT) listenTCP(l net.Listener) {
for {
conn, err := l.Accept()
if err != nil {
logger.Error(err)
continue
}
go n.handleConn(conn)
}
}

func (n *NAT) listenUDP() {
inbound, err := nat.NewInboundPacketSession(n.config)
n.packetInbound = inbound
if err != nil {
logger.Error(err)
panic(err)
Expand All @@ -60,6 +58,11 @@ func (n *NAT) listenUDP() {
for {
tunnel, err := trojan.NewOutboundConnSession(&req, nil, n.config)
if err != nil {
select {
case <-n.ctx.Done():
return
default:
}
logger.Error(err)
continue
}
Expand All @@ -70,15 +73,29 @@ func (n *NAT) listenUDP() {
}

func (n *NAT) Run() error {
go n.listenUDP()
logger.Info("nat running at", n.config.LocalAddr)
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: n.config.LocalIP,
Port: int(n.config.LocalPort),
})
if err != nil {
return err
}
go n.listenUDP()
n.listenTCP(tcpListener)
n.listener = listener
defer listener.Close()
for {
conn, err := n.listener.Accept()
if err != nil {
logger.Error(err)
continue
}
go n.handleConn(conn)
}
}

func (n *NAT) Close() error {
logger.Info("shutting down nat...")
n.cancel()
return nil
}
43 changes: 34 additions & 9 deletions proxy/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package proxy

import (
"context"
"crypto/tls"
"database/sql"
"net"
Expand All @@ -16,10 +17,13 @@ import (
)

type Server struct {
config *conf.GlobalConfig
common.Runnable
auth stat.Authenticator
meter stat.TrafficMeter

auth stat.Authenticator
meter stat.TrafficMeter
config *conf.GlobalConfig
ctx context.Context
cancel context.CancelFunc
}

func (s *Server) handleMuxConn(stream *smux.Stream, passwordHash string) {
Expand Down Expand Up @@ -104,12 +108,10 @@ func (s *Server) handleConn(conn net.Conn) {
}

func (s *Server) Run() error {
tlsConfig := &tls.Config{
Certificates: s.config.TLS.KeyPair,
CipherSuites: s.config.TLS.CipherSuites,
PreferServerCipherSuites: s.config.TLS.PreferServerCipher,
SessionTicketsDisabled: !s.config.TLS.SessionTicket,
}
ctx, cancel := context.WithCancel(context.Background())
s.ctx = ctx
s.cancel = cancel

var db *sql.DB
var err error
if s.config.MySQL.Enabled {
Expand Down Expand Up @@ -145,6 +147,8 @@ func (s *Server) Run() error {
return common.NewError("failed to init traffic meter").Base(err)
}
}
defer s.auth.Close()
defer s.meter.Close()
logger.Info("Server running at", s.config.LocalAddr)

var listener net.Listener
Expand All @@ -156,15 +160,31 @@ func (s *Server) Run() error {
s.config.LocalIP,
s.config.LocalAddr.String(),
)
if err != nil {
return err
}
} else {
listener, err = net.Listen("tcp", s.config.LocalAddr.String())
if err != nil {
return err
}
}
defer listener.Close()

tlsConfig := &tls.Config{
Certificates: s.config.TLS.KeyPair,
CipherSuites: s.config.TLS.CipherSuites,
PreferServerCipherSuites: s.config.TLS.PreferServerCipher,
SessionTicketsDisabled: !s.config.TLS.SessionTicket,
}
for {
conn, err := listener.Accept()
if err != nil {
select {
case <-s.ctx.Done():
return nil
default:
}
logger.Warn(err)
continue
}
Expand All @@ -180,5 +200,10 @@ func (s *Server) Run() error {
}
go s.handleConn(tlsConn)
}
}

func (s *Server) Close() error {
logger.Info("shutting down server..")
s.cancel()
return nil
}

0 comments on commit d61907c

Please sign in to comment.