forked from txthinking/brook
-
Notifications
You must be signed in to change notification settings - Fork 0
/
websocket.go
118 lines (114 loc) · 2.89 KB
/
websocket.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
// Copyright (c) 2016-present Cloud <[email protected]>
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of version 3 of the GNU General Public
// License as published by the Free Software Foundation.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package brook
import (
"bufio"
"bytes"
"crypto/rand"
"crypto/sha1"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"time"
x1 "github.com/txthinking/x"
)
func WebSocketDial(src, dst, addr, host, path string, tc *tls.Config, timeout int) (net.Conn, error) {
var c net.Conn
var err error
if src == "" || dst == "" {
c, err = DialTCP("tcp", "", addr)
}
if src != "" && dst != "" {
c, err = NATDial("tcp", src, dst, addr)
}
if err != nil {
return nil, err
}
if timeout != 0 {
if err := c.SetDeadline(time.Now().Add(time.Duration(timeout) * time.Second)); err != nil {
c.Close()
return nil, err
}
}
if tc != nil {
c1 := tls.Client(c, tc)
if !tc.InsecureSkipVerify {
if err := c1.Handshake(); err != nil {
c1.Close()
return nil, err
}
s := host
h, _, err := net.SplitHostPort(host)
if err == nil {
s = h
}
if err := c1.VerifyHostname(s); err != nil {
c1.Close()
return nil, err
}
}
c = c1
}
p := x1.BP16.Get().([]byte)
if _, err := io.ReadFull(rand.Reader, p); err != nil {
x1.BP16.Put(p)
c.Close()
return nil, err
}
k := base64.StdEncoding.EncodeToString(p)
x1.BP16.Put(p)
b := make([]byte, 0, 300)
b = append(b, []byte("GET "+path+" HTTP/1.1\r\n")...)
b = append(b, []byte(fmt.Sprintf("Host: %s\r\n", host))...)
b = append(b, []byte("Upgrade: websocket\r\n")...)
b = append(b, []byte("Connection: Upgrade\r\n")...)
b = append(b, []byte(fmt.Sprintf("Sec-WebSocket-Key: %s\r\n", k))...)
b = append(b, []byte("Sec-WebSocket-Version: 13\r\n\r\n")...)
if _, err := c.Write(b); err != nil {
c.Close()
return nil, err
}
r := bufio.NewReader(c)
for {
b, err = r.ReadBytes('\n')
if err != nil {
c.Close()
return nil, err
}
b = bytes.TrimSpace(b)
if len(b) == 0 {
break
}
if bytes.HasPrefix(b, []byte("HTTP/1.1 ")) {
if !bytes.Contains(b, []byte("101")) {
c.Close()
return nil, errors.New(string(b))
}
}
if bytes.HasPrefix(b, []byte("Sec-WebSocket-Accept: ")) {
h := sha1.New()
h.Write([]byte(k))
h.Write([]byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
ak := base64.StdEncoding.EncodeToString(h.Sum(nil))
if string(b[len("Sec-WebSocket-Accept: "):]) != ak {
c.Close()
return nil, errors.New(string(b))
}
}
}
return c, nil
}