Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
p4gefau1t committed Mar 27, 2020
2 parents 433e7e3 + 49e36d5 commit c9ba8e2
Show file tree
Hide file tree
Showing 11 changed files with 89 additions and 68 deletions.
27 changes: 21 additions & 6 deletions cert/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"io/ioutil"
"os"
Expand All @@ -25,6 +24,8 @@ import (

var logger = log.New(os.Stdout)
var caDir string = "https://acme-v02.api.letsencrypt.org/directory"
var tlsPort string = "443"
var httpPort string = "80"

type User struct {
Email string
Expand Down Expand Up @@ -79,9 +80,6 @@ func loadUserKey() (*ecdsa.PrivateKey, error) {
func saveServerKeyAndCert(cert *certificate.Resource) error {
ioutil.WriteFile("server.key", cert.PrivateKey, os.ModePerm)
ioutil.WriteFile("server.crt", cert.Certificate, os.ModePerm)
data, err := json.Marshal(cert)
common.Must(err)
ioutil.WriteFile("server.json", data, os.ModePerm)
return nil
}

Expand Down Expand Up @@ -122,12 +120,12 @@ func obtainCertificate(domain, email string, userKey *ecdsa.PrivateKey, serverKe
// (used later when we attempt to pass challenges). Keep in mind that you still
// need to proxy challenge traffic to port 5002 and 5001.
//err = client.Challenge.SetHTTP01Provider(http01.NewProviderServer("", "5002"))
err = client.Challenge.SetHTTP01Provider(http01.NewProviderServer("", ""))
err = client.Challenge.SetHTTP01Provider(http01.NewProviderServer("", httpPort))
if err != nil {
return nil, err
}
//err = client.Challenge.SetTLSALPN01Provider(tlsalpn01.NewProviderServer("", "5001"))
err = client.Challenge.SetTLSALPN01Provider(tlsalpn01.NewProviderServer("", ""))
err = client.Challenge.SetTLSALPN01Provider(tlsalpn01.NewProviderServer("", tlsPort))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -159,7 +157,24 @@ func obtainCertificate(domain, email string, userKey *ecdsa.PrivateKey, serverKe
return certificates, nil
}

func isFilesExist(nameList []string) bool {
fileInfo, err := ioutil.ReadDir("./")
common.Must(err)
for _, v := range fileInfo {
name := v.Name()
for _, u := range nameList {
if name == u {
return true
}
}
}
return false
}

func RequestCert(domain, email string) error {
if isFilesExist([]string{"server.key", "server.crt"}) {
return common.NewError("cert files(server.key, server.crt) already exist")
}
userKey, err := loadUserKey()
if err != nil {
logger.Warn("failed to load user key, trying to create one..")
Expand Down
8 changes: 4 additions & 4 deletions cert/cert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ import (

func TestCreate(t *testing.T) {
caDir = "https://127.0.0.1:14000/dir"
tlsPort = "5001"
httpPort = "5002"
common.Must(RequestCert("localhost", "[email protected]"))
}

func TestRenew(t *testing.T) {
caDir = "https://127.0.0.1:14000/dir"
tlsPort = "5001"
httpPort = "5002"
common.Must(RenewCert("localhost", "[email protected]"))
}

func TestCertGuide(t *testing.T) {
RequestCertGuide()
}
2 changes: 0 additions & 2 deletions cert/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ func askForConfirmation() bool {
}

func RequestCertGuide() {
//caDir = "https://127.0.0.1:14000/dir"
logger.Info("Guide mode: request cert")

logger.Warn("To perform a ACME challenge, trojan-go need the ROOT PRIVILEGE to bind port 80 and 443")
Expand Down Expand Up @@ -96,7 +95,6 @@ func RequestCertGuide() {
}

func RenewCertGuide() {
//caDir = "https://127.0.0.1:14000/dir"
logger.Info("Guide mode: renew cert")

logger.Warn("To perform a ACME challenge, trojan-go need the ROOT PRIVILEGE to bind port 80 and 443")
Expand Down
1 change: 1 addition & 0 deletions conf/conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ type MySQLConfig struct {
Database string `json:"database"`
Username string `json:"username"`
Password string `json:"password"`
CheckRate int `json:"check_rate"`
}

type SQLiteConfig struct {
Expand Down
1 change: 1 addition & 0 deletions conf/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func ParseJSON(data []byte) (*GlobalConfig, error) {
config.TLS.VerifyHostname = true
config.TLS.SessionTicket = true
config.TCP.MuxIdleTimeout = 5
config.MySQL.CheckRate = 60

err := json.Unmarshal(data, &config)
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ func main() {
case "renew":
cert.RenewCertGuide()
return
case "":
default:
logger.Error("Invalid cert arg")
return
}
data, err := ioutil.ReadFile(*configFile)
if err != nil {
Expand All @@ -44,6 +47,7 @@ func main() {

sigs := make(chan os.Signal, 1)
signal.Notify(sigs, os.Interrupt)
logger.Info("Trojan-Go interrupted")
select {
case <-sigs:
proxy.Close()
Expand Down
2 changes: 1 addition & 1 deletion proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,11 @@ func (c *Client) Run() error {

func (c *Client) Close() error {
logger.Info("shutting down client..")
c.cancel()
c.muxClientLock.Lock()
defer c.muxClientLock.Unlock()
if c.muxClient != nil {
c.muxClient.Close()
}
c.cancel()
return nil
}
17 changes: 11 additions & 6 deletions proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ import (
type Server struct {
common.Runnable

auth stat.Authenticator
meter stat.TrafficMeter
config *conf.GlobalConfig
ctx context.Context
cancel context.CancelFunc
listener net.Listener
auth stat.Authenticator
meter stat.TrafficMeter
config *conf.GlobalConfig
ctx context.Context
cancel context.CancelFunc
}

func (s *Server) handleMuxConn(stream *smux.Stream, passwordHash string) {
Expand Down Expand Up @@ -173,7 +174,7 @@ func (s *Server) Run() error {
if err != nil {
return common.NewError("failed to init auth").Base(err)
}
s.meter, err = stat.NewDBTrafficMeter(db)
s.meter, err = stat.NewDBTrafficMeter(s.config, db)
if err != nil {
return common.NewError("failed to init traffic meter").Base(err)
}
Expand All @@ -200,6 +201,7 @@ func (s *Server) Run() error {
return err
}
}
s.listener = listener
defer listener.Close()

tlsConfig := &tls.Config{
Expand Down Expand Up @@ -232,6 +234,9 @@ func (s *Server) Run() error {

func (s *Server) Close() error {
logger.Info("shutting down server..")
if s.listener != nil {
s.listener.Close()
}
s.cancel()
return nil
}
71 changes: 29 additions & 42 deletions stat/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/p4gefau1t/trojan-go/common"
"github.com/p4gefau1t/trojan-go/conf"
)

type trafficInfo struct {
Expand All @@ -17,16 +18,13 @@ type trafficInfo struct {

type DBTrafficMeter struct {
TrafficMeter
db *sql.DB
trafficChan chan *trafficInfo
ctx context.Context
cancel context.CancelFunc
db *sql.DB
trafficChan chan *trafficInfo
ctx context.Context
cancel context.CancelFunc
updateDuration time.Duration
}

const (
statsUpdateDuration = time.Second * 5
)

func (c *DBTrafficMeter) Count(passwordHash string, sent int, recv int) {
c.trafficChan <- &trafficInfo{
passwordHash: passwordHash,
Expand Down Expand Up @@ -56,12 +54,12 @@ func (c *DBTrafficMeter) dbDaemon() {
}
t.sent += u.sent
t.recv += u.recv
case <-time.After(statsUpdateDuration):
case <-time.After(c.updateDuration):
break
case <-c.ctx.Done():
return
}
if time.Now().Sub(beginTime) > statsUpdateDuration {
if time.Now().Sub(beginTime) > c.updateDuration {
break
}
}
Expand Down Expand Up @@ -97,25 +95,12 @@ func (c *DBTrafficMeter) dbDaemon() {
}
}

func NewDBTrafficMeter(db *sql.DB) (TrafficMeter, error) {
/*
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS users (
id INT UNSIGNED AUTO_INCREMENT,
username VARCHAR(64) NOT NULL,
password CHAR(56) NOT NULL,
quota BIGINT NOT NULL DEFAULT 0,
download BIGINT UNSIGNED NOT NULL DEFAULT 0,
upload BIGINT UNSIGNED NOT NULL DEFAULT 0,
PRIMARY KEY (id)
);`)
if err != nil {
logger.Warn(common.NewError("cannot check and create table").Base(err))
}
*/
func NewDBTrafficMeter(config *conf.GlobalConfig, db *sql.DB) (TrafficMeter, error) {
c := &DBTrafficMeter{
db: db,
trafficChan: make(chan *trafficInfo, 1024),
ctx: context.Background(),
db: db,
trafficChan: make(chan *trafficInfo, 1024),
ctx: context.Background(),
updateDuration: time.Duration(config.MySQL.CheckRate) * time.Second,
}
go c.dbDaemon()
return c, nil
Expand All @@ -130,10 +115,11 @@ type userInfo struct {
}

type DBAuthenticator struct {
db *sql.DB
validUsers sync.Map
ctx context.Context
cancel context.CancelFunc
db *sql.DB
validUsers sync.Map
ctx context.Context
cancel context.CancelFunc
updateDuration time.Duration
Authenticator
}

Expand All @@ -147,23 +133,23 @@ func (a *DBAuthenticator) CheckHash(hash string) bool {

func (a *DBAuthenticator) updateDaemon() {
for {
rows, err := a.db.Query("SELECT username,password,quota,download,upload FROM users")
rows, err := a.db.Query("SELECT password,quota,download,upload FROM users")
if err != nil {
logger.Error(common.NewError("failed to pull data from the database").Base(err))
time.Sleep(statsUpdateDuration)
time.Sleep(a.updateDuration)
continue
}
newValidUsers := make(map[string]string)
for rows.Next() {
var username, passwordHash string
var passwordHash string
var quota, download, upload int64
err := rows.Scan(&username, &passwordHash, &quota, &download, &upload)
err := rows.Scan(&passwordHash, &quota, &download, &upload)
if err != nil {
logger.Error(common.NewError("failed to obtain data from the query result").Base(err))
break
}
if download+upload < quota || quota < 0 {
newValidUsers[passwordHash] = username
newValidUsers[passwordHash] = "valid"
}
}
//delete those out of quota
Expand All @@ -177,7 +163,7 @@ func (a *DBAuthenticator) updateDaemon() {
a.validUsers.Store(k, v)
}
select {
case <-time.After(statsUpdateDuration):
case <-time.After(a.updateDuration):
break
case <-a.ctx.Done():
return
Expand All @@ -190,12 +176,13 @@ func (a *DBAuthenticator) Close() error {
return a.db.Close()
}

func NewDBAuthenticator(db *sql.DB) (Authenticator, error) {
func NewDBAuthenticator(config *conf.GlobalConfig, db *sql.DB) (Authenticator, error) {
ctx, cancel := context.WithCancel(context.Background())
a := &DBAuthenticator{
db: db,
cancel: cancel,
ctx: ctx,
db: db,
cancel: cancel,
ctx: ctx,
updateDuration: time.Duration(config.MySQL.CheckRate) * time.Second,
}
go a.updateDaemon()
return a, nil
Expand Down
22 changes: 16 additions & 6 deletions stat/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

_ "github.com/go-sql-driver/mysql"
"github.com/p4gefau1t/trojan-go/common"
"github.com/p4gefau1t/trojan-go/conf"
)

func TestDBTrafficMeter(t *testing.T) {
Expand All @@ -21,16 +22,18 @@ func TestDBTrafficMeter(t *testing.T) {
dbName := "trojan"
path := strings.Join([]string{userName, ":", password, "@tcp(", ip, ":", port, ")/", dbName, "?charset=utf8"}, "")
db, err := sql.Open("mysql", path)
hash := common.SHA224String("hashhash")
common.Must(err)
defer db.Close()
c := &DBTrafficMeter{
db: db,
trafficChan: make(chan *trafficInfo, 1024),
ctx: context.Background(),
db: db,
trafficChan: make(chan *trafficInfo, 1024),
ctx: context.Background(),
updateDuration: time.Second * 5,
}
simulation := func() {
for i := 0; i < 100; i++ {
c.Count("hashhash", rand.Intn(500), rand.Intn(500))
c.Count(hash, rand.Intn(500), rand.Intn(500))
time.Sleep(time.Duration(int64(time.Millisecond) * rand.Int63n(300)))
}
fmt.Println("done")
Expand All @@ -52,8 +55,15 @@ func TestDBAuthenticator(t *testing.T) {
db, err := sql.Open("mysql", path)
common.Must(err)
defer db.Close()
a, err := NewDBAuthenticator(db)
config := conf.GlobalConfig{
MySQL: conf.MySQLConfig{
CheckRate: 2,
},
}
a, err := NewDBAuthenticator(&config, db)
common.Must(err)
time.Sleep(time.Second * 5)
fmt.Println(a.CheckHash("hashhash"), a.CheckHash("jasdlkflfejlqjef"))
hash := common.SHA224String("hashhash")
fmt.Println(common.SHA224String("hashhash"))
fmt.Println(a.CheckHash(hash), a.CheckHash("jasdlkflfejlqjef"))
}
Loading

0 comments on commit c9ba8e2

Please sign in to comment.