Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
p4gefau1t committed Apr 1, 2020
2 parents 73d2ece + 0996b46 commit 519b4e6
Show file tree
Hide file tree
Showing 29 changed files with 1,130 additions and 1,045 deletions.
43 changes: 43 additions & 0 deletions cert/option.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package cert

import (
"flag"

"github.com/p4gefau1t/trojan-go/common"
)

type certOption struct {
args *string
common.OptionHandler
}

func (*certOption) Name() string {
return "cert"
}

func (*certOption) Priority() int {
return 10
}

func (c *certOption) Handle() error {
switch *c.args {
case "request":
RequestCertGuide()
return nil
case "renew":
RenewCertGuide()
return nil
case "INVALID":
return common.NewError("not specified")
default:
err := common.NewError("invalid args " + *c.args)
logger.Error(err)
return common.NewError("invalid args")
}
}

func init() {
common.RegisterOptionHandler(&certOption{
args: flag.String("cert", "INVALID", "Simple letsencrpyt cert acme client. Use \"-cert request\" to request a cert or \"-cert renew\" to renew a cert"),
})
}
17 changes: 3 additions & 14 deletions common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@ package common
import (
"bufio"
"crypto/sha256"
"database/sql"
"fmt"
"io"
"strings"
)

_ "github.com/go-sql-driver/mysql"
//_ "github.com/mattn/go-sqlite3"
const (
Version = "v0.0.15"
)

type Runnable interface {
Expand Down Expand Up @@ -50,13 +49,3 @@ func HumanFriendlyTraffic(bytes int) string {
}
return fmt.Sprintf("%.2f GiB", float32(bytes)/GiB)
}

func ConnectDatabase(driverName, username, password, ip string, port int, dbName string) (*sql.DB, error) {
path := strings.Join([]string{username, ":", password, "@tcp(", ip, ":", fmt.Sprintf("%d", port), ")/", dbName, "?charset=utf8"}, "")
return sql.Open(driverName, path)
}

func ConnectSQLite(dbName string) (*sql.DB, error) {
//for debug only
return sql.Open("sqlite3", dbName)
}
12 changes: 12 additions & 0 deletions common/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package common

import (
"database/sql"
"fmt"
"strings"
)

func ConnectDatabase(driverName, username, password, ip string, port int, dbName string) (*sql.DB, error) {
path := strings.Join([]string{username, ":", password, "@tcp(", ip, ":", fmt.Sprintf("%d", port), ")/", dbName, "?charset=utf8"}, "")
return sql.Open(driverName, path)
}
27 changes: 27 additions & 0 deletions common/option.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package common

type OptionHandler interface {
Name() string
Handle() error
Priority() int
}

var handlers map[string]OptionHandler = make(map[string]OptionHandler)

func RegisterOptionHandler(h OptionHandler) {
handlers[h.Name()] = h
}

func PopOptionHandler() (OptionHandler, error) {
var maxHandler OptionHandler = nil
for _, h := range handlers {
if maxHandler == nil || maxHandler.Priority() < h.Priority() {
maxHandler = h
}
}
if maxHandler == nil {
return nil, NewError("no option left")
}
delete(handlers, maxHandler.Name())
return maxHandler, nil
}
59 changes: 17 additions & 42 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,56 +2,31 @@ package main

import (
"flag"
"io/ioutil"
"os"
"os/signal"

"github.com/p4gefau1t/trojan-go/cert"
"github.com/p4gefau1t/trojan-go/conf"
"github.com/p4gefau1t/trojan-go/common"
"github.com/p4gefau1t/trojan-go/log"
"github.com/p4gefau1t/trojan-go/proxy"

_ "github.com/go-sql-driver/mysql"
_ "github.com/p4gefau1t/trojan-go/cert"
_ "github.com/p4gefau1t/trojan-go/proxy/client"
_ "github.com/p4gefau1t/trojan-go/proxy/forward"
_ "github.com/p4gefau1t/trojan-go/proxy/server"
_ "github.com/p4gefau1t/trojan-go/version"
)

var logger = log.New(os.Stdout)

func main() {
logger.Info("Trojan-Go initializing...")
configFile := flag.String("config", "config.json", "Config filename")
guideMode := flag.String("cert", "", "Simple letsencrpyt cert acme client. Use \"-cert request\" to request a cert or \"-cert renew\" to renew a cert")
flag.Parse()
switch *guideMode {
case "request":
cert.RequestCertGuide()
return
case "renew":
cert.RenewCertGuide()
return
case "":
default:
logger.Error("Invalid cert arg")
return
}
data, err := ioutil.ReadFile(*configFile)
if err != nil {
logger.Fatal("Failed to read config file", err)
}
config, err := conf.ParseJSON(data)
if err != nil {
logger.Fatal("Failed to parse config file", err)
}
proxy := proxy.NewProxy(config)
errChan := make(chan error)
go func() {
errChan <- proxy.Run()
}()

sigs := make(chan os.Signal, 1)
signal.Notify(sigs, os.Interrupt)
select {
case <-sigs:
proxy.Close()
case err := <-errChan:
logger.Fatal(err)
for {
h, err := common.PopOptionHandler()
if err != nil {
logger.Fatal("invalid options")
}
err = h.Handle()
if err == nil {
break
}
}
logger.Info("Trojan-Go exited")
}
7 changes: 4 additions & 3 deletions protocol/direct/outbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ func (o *DirectOutboundConnSession) Read(p []byte) (int, error) {
}

func (o *DirectOutboundConnSession) Write(p []byte) (int, error) {
defer o.bufReadWriter.Flush()
return o.bufReadWriter.Write(p)
n, err := o.bufReadWriter.Write(p)
o.bufReadWriter.Flush()
return n, err
}

func (o *DirectOutboundConnSession) Close() error {
Expand Down Expand Up @@ -105,7 +106,7 @@ func (o *DirectOutboundPacketSession) WritePacket(req *protocol.Request, packet
if err != nil {
return 0, common.NewError("cannot dial udp").Base(err)
}
logger.Info("UDP directly dialing to", remote)
logger.Debug("UDP directly dialing to", remote)
n, err := conn.Write(packet)
return n, err
}
Expand Down
12 changes: 7 additions & 5 deletions protocol/http/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,16 @@ func (i *HTTPInboundTunnelConnSession) Read(p []byte) (int, error) {
}

func (i *HTTPInboundTunnelConnSession) Write(p []byte) (int, error) {
defer i.bufReadWriter.Flush()
return i.bufReadWriter.Write(p)
n, err := i.bufReadWriter.Write(p)
i.bufReadWriter.Flush()
return n, err
}

func (i *HTTPInboundTunnelConnSession) Close() error {
return i.conn.Close()
}

func (i *HTTPInboundTunnelConnSession) Respond(r io.Reader) error {
func (i *HTTPInboundTunnelConnSession) Respond() error {
payload := fmt.Sprintf("HTTP/%d.%d 200 Connection established\r\n\r\n", i.httpRequest.ProtoMajor, i.httpRequest.ProtoMinor)
_, err := i.Write([]byte(payload))
return err
Expand Down Expand Up @@ -175,6 +176,7 @@ func (i *HTTPInboundPacketSession) ReadPacket() (*protocol.Request, []byte, erro
}

func (i *HTTPInboundPacketSession) WritePacket(req *protocol.Request, packet []byte) (int, error) {
defer i.bufReadWriter.Flush()
return i.bufReadWriter.Write(packet)
n, err := i.bufReadWriter.Write(packet)
i.bufReadWriter.Flush()
return n, err
}
5 changes: 3 additions & 2 deletions protocol/nat/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ func (i *NATInboundPacketSession) cleanExpiredSession() {
select {
case <-time.After(protocol.UDPTimeout):
case <-i.ctx.Done():
i.conn.Close()
return
}
}
Expand Down Expand Up @@ -131,7 +132,7 @@ func (i *NATInboundPacketSession) ReadPacket() (*protocol.Request, []byte, error
expire: time.Now().Add(protocol.UDPTimeout),
}
i.tableMutex.Unlock()
logger.Info("tproxy UDP packet from", src, "to", dst)
logger.Debug("tproxy UDP packet from", src, "to", dst)
req := &protocol.Request{
IP: dst.IP,
Port: uint16(dst.Port),
Expand All @@ -147,7 +148,7 @@ func (i *NATInboundPacketSession) ReadPacket() (*protocol.Request, []byte, error

func (i *NATInboundPacketSession) Close() error {
i.cancel()
return i.conn.Close()
return nil
}

func NewInboundPacketSession(config *conf.GlobalConfig) (protocol.PacketSession, error) {
Expand Down
6 changes: 3 additions & 3 deletions protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ const (

const (
MaxUDPPacketSize = 1024 * 4
UDPTimeout = time.Second * 6
TCPTimeout = time.Second * 6
UDPTimeout = time.Second * 5
TCPTimeout = time.Second * 5
)

type Request struct {
Expand Down Expand Up @@ -69,7 +69,7 @@ type HasHash interface {
}

type NeedRespond interface {
Respond(io.Reader) error
Respond() error
}

type PacketReader interface {
Expand Down
Loading

0 comments on commit 519b4e6

Please sign in to comment.