diff --git a/wshandle/client.go b/wshandle/client.go index ce2fea73..1d8b1f27 100644 --- a/wshandle/client.go +++ b/wshandle/client.go @@ -1,10 +1,14 @@ package wshandle import ( + "crypto/tls" "log" + "net" + "net/http" "net/url" "os" "os/signal" + "strings" "sync" "time" @@ -25,6 +29,7 @@ type WsConn struct { var wsconn *WsConn var hostP = GetenvDefault("NEXTTRACE_HOSTPORT", "api.leo.moe") +var host, port, fast_ip string func (c *WsConn) keepAlive() { go func() { @@ -105,10 +110,16 @@ func (c *WsConn) messageSendHandler() { } func (c *WsConn) recreateWsConn() { - u := url.URL{Scheme: "wss", Host: hostP, Path: "/v2/ipGeoWs"} + u := url.URL{Scheme: "wss", Host: fast_ip + ":" + port, Path: "/v2/ipGeoWs"} // log.Printf("connecting to %s", u.String()) - - ws, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + requestHeader := http.Header{ + "Host": []string{host}, + } + dialer := websocket.DefaultDialer + dialer.TLSClientConfig = &tls.Config{ + ServerName: host, + } + ws, _, err := websocket.DefaultDialer.Dial(u.String(), requestHeader) c.Conn = ws if err != nil { log.Println("dial:", err) @@ -129,11 +140,47 @@ func createWsConn() *WsConn { // 设置终端中断通道 interrupt := make(chan os.Signal, 1) signal.Notify(interrupt, os.Interrupt) + // 解析域名 + hostArr := strings.Split(hostP, ":") + // 判断是否有指定端口 + if len(hostArr) > 1 { + // 判断是否为 IPv6 + if strings.HasPrefix(hostP, "[") { + tmp := strings.Split(hostP, "]") + host = tmp[0] + host = host[1:] + if port = tmp[1]; port != "" { + port = port[1:] + } + } else { + host, port = hostArr[0], hostArr[1] + } + } else { + host = hostP + } + if port == "" { + // 默认端口 + port = "443" + } + // 默认配置完成,开始寻找最优 IP + fast_ip = GetFastIP(host, port) - u := url.URL{Scheme: "wss", Host: hostP, Path: "/v2/ipGeoWs"} + // 如果 host 是一个 IP 使用默认域名 + if valid := net.ParseIP(host); valid != nil { + host = "api.leo.moe" + } + // 判断是否是一个 IP + requestHeader := http.Header{ + "Host": []string{host}, + } + dialer := websocket.DefaultDialer + dialer.TLSClientConfig = &tls.Config{ + ServerName: host, + } + u := url.URL{Scheme: "wss", Host: fast_ip + ":" + port, Path: "/v2/ipGeoWs"} // log.Printf("connecting to %s", u.String()) - c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + c, _, err := websocket.DefaultDialer.Dial(u.String(), requestHeader) wsconn = &WsConn{ Conn: c, @@ -169,9 +216,9 @@ func GetWsConn() *WsConn { return wsconn } func GetenvDefault(key, defVal string) string { - val, ok := os.LookupEnv(key) - if ok { - return val - } - return defVal - } + val, ok := os.LookupEnv(key) + if ok { + return val + } + return defVal +} diff --git a/wshandle/latency.go b/wshandle/latency.go new file mode 100644 index 00000000..158c956c --- /dev/null +++ b/wshandle/latency.go @@ -0,0 +1,66 @@ +package wshandle + +import ( + "fmt" + "log" + "net" + "strings" + "time" + + "github.com/fatih/color" +) + +var ( + result string + results = make(chan string) +) + +func GetFastIP(domain string, port string) string { + + ips, err := net.LookupIP(domain) + if err != nil { + log.Fatal("DNS 解析失败,请检查您的系统 DNS 设置") + return "" + } + + for _, ip := range ips { + go checkLatency(ip.String(), port) + } + + select { + case result = <-results: + case <-time.After(1 * time.Second): + + } + if result == "" { + log.Fatal("IP 连接均超时,请检查您的网络") + } + res := strings.Split(result, "-") + + if len(ips) > 1 { + fmt.Fprintf(color.Output, "%s 已为您优选最近的节点 %s - %s\n", + color.New(color.FgWhite, color.Bold).Sprintf("[NextTrace API]"), + color.New(color.FgGreen, color.Bold).Sprintf("%s", res[0]), + color.New(color.FgCyan, color.Bold).Sprintf("%sms", res[1]), + ) + } + + return res[0] +} + +func checkLatency(ip string, port string) { + start := time.Now() + if !strings.Contains(ip, ".") { + ip = "[" + ip + "]" + } + conn, err := net.DialTimeout("tcp", ip+":"+port, time.Second*1) + if err != nil { + return + } + defer conn.Close() + if result == "" { + result = fmt.Sprintf("%s-%.2f", ip, float64(time.Since(start))/float64(time.Millisecond)) + results <- result + return + } +} diff --git a/wshandle/latency_test.go b/wshandle/latency_test.go new file mode 100644 index 00000000..ba5ab82c --- /dev/null +++ b/wshandle/latency_test.go @@ -0,0 +1,9 @@ +package wshandle + +import ( + "testing" +) + +func TestGetFastIP(t *testing.T) { + GetFastIP("api.leo.moe", "443") +}