diff --git a/cert/cert.go b/cert/cert.go index 4f9e91055..33daf4c48 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -24,6 +24,8 @@ import ( var logger = log.New(os.Stdout) var caDir string = "https://acme-v02.api.letsencrypt.org/directory" +var tlsPort string = "443" +var httpPort string = "80" type User struct { Email string @@ -118,12 +120,12 @@ func obtainCertificate(domain, email string, userKey *ecdsa.PrivateKey, serverKe // (used later when we attempt to pass challenges). Keep in mind that you still // need to proxy challenge traffic to port 5002 and 5001. //err = client.Challenge.SetHTTP01Provider(http01.NewProviderServer("", "5002")) - err = client.Challenge.SetHTTP01Provider(http01.NewProviderServer("", "")) + err = client.Challenge.SetHTTP01Provider(http01.NewProviderServer("", httpPort)) if err != nil { return nil, err } //err = client.Challenge.SetTLSALPN01Provider(tlsalpn01.NewProviderServer("", "5001")) - err = client.Challenge.SetTLSALPN01Provider(tlsalpn01.NewProviderServer("", "")) + err = client.Challenge.SetTLSALPN01Provider(tlsalpn01.NewProviderServer("", tlsPort)) if err != nil { return nil, err } @@ -155,7 +157,24 @@ func obtainCertificate(domain, email string, userKey *ecdsa.PrivateKey, serverKe return certificates, nil } +func isFilesExist(nameList []string) bool { + fileInfo, err := ioutil.ReadDir("./") + common.Must(err) + for _, v := range fileInfo { + name := v.Name() + for _, u := range nameList { + if name == u { + return true + } + } + } + return false +} + func RequestCert(domain, email string) error { + if isFilesExist([]string{"server.key", "server.crt"}) { + return common.NewError("cert files(server.key, server.crt) already exist") + } userKey, err := loadUserKey() if err != nil { logger.Warn("failed to load user key, trying to create one..") diff --git a/cert/cert_test.go b/cert/cert_test.go index d83abaebf..b5e9c4894 100644 --- a/cert/cert_test.go +++ b/cert/cert_test.go @@ -8,14 +8,14 @@ import ( func TestCreate(t *testing.T) { caDir = "https://127.0.0.1:14000/dir" + tlsPort = "5001" + httpPort = "5002" common.Must(RequestCert("localhost", "test@email.com")) } func TestRenew(t *testing.T) { caDir = "https://127.0.0.1:14000/dir" + tlsPort = "5001" + httpPort = "5002" common.Must(RenewCert("localhost", "test@email.com")) } - -func TestCertGuide(t *testing.T) { - RequestCertGuide() -} diff --git a/cert/cli.go b/cert/cli.go index 5792268d2..0420b697f 100644 --- a/cert/cli.go +++ b/cert/cli.go @@ -46,7 +46,6 @@ func askForConfirmation() bool { } func RequestCertGuide() { - //caDir = "https://127.0.0.1:14000/dir" logger.Info("Guide mode: request cert") logger.Warn("To perform a ACME challenge, trojan-go need the ROOT PRIVILEGE to bind port 80 and 443") @@ -96,7 +95,6 @@ func RequestCertGuide() { } func RenewCertGuide() { - //caDir = "https://127.0.0.1:14000/dir" logger.Info("Guide mode: renew cert") logger.Warn("To perform a ACME challenge, trojan-go need the ROOT PRIVILEGE to bind port 80 and 443") diff --git a/main.go b/main.go index adfa3ebf1..06ed2db9b 100644 --- a/main.go +++ b/main.go @@ -26,7 +26,10 @@ func main() { case "renew": cert.RenewCertGuide() return + case "": default: + logger.Error("Invalid cert arg") + return } data, err := ioutil.ReadFile(*configFile) if err != nil { @@ -44,6 +47,7 @@ func main() { sigs := make(chan os.Signal, 1) signal.Notify(sigs, os.Interrupt) + logger.Info("Trojan-Go interrupted") select { case <-sigs: proxy.Close() diff --git a/proxy/client.go b/proxy/client.go index 9ec0d760f..0e50a2efa 100644 --- a/proxy/client.go +++ b/proxy/client.go @@ -258,11 +258,11 @@ func (c *Client) Run() error { func (c *Client) Close() error { logger.Info("shutting down client..") - c.cancel() c.muxClientLock.Lock() defer c.muxClientLock.Unlock() if c.muxClient != nil { c.muxClient.Close() } + c.cancel() return nil } diff --git a/proxy/server.go b/proxy/server.go index 31670ef81..f31283803 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -20,11 +20,12 @@ import ( type Server struct { common.Runnable - auth stat.Authenticator - meter stat.TrafficMeter - config *conf.GlobalConfig - ctx context.Context - cancel context.CancelFunc + listener net.Listener + auth stat.Authenticator + meter stat.TrafficMeter + config *conf.GlobalConfig + ctx context.Context + cancel context.CancelFunc } func (s *Server) handleMuxConn(stream *smux.Stream, passwordHash string) { @@ -200,6 +201,7 @@ func (s *Server) Run() error { return err } } + s.listener = listener defer listener.Close() tlsConfig := &tls.Config{ @@ -232,6 +234,9 @@ func (s *Server) Run() error { func (s *Server) Close() error { logger.Info("shutting down server..") + if s.listener != nil { + s.listener.Close() + } s.cancel() return nil }