diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 00000000..854fa9a2 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,30 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/go +{ + "name": "Go", + "image": "mcr.microsoft.com/devcontainers/go", + "features": { + "ghcr.io/guiyomh/features/golangci-lint:0": {}, + "ghcr.io/devcontainers-contrib/features/go-task:1": {} + }, + "postCreateCommand": "go mod download", + // Features to add to the dev container. More info: https://containers.dev/features. + // "features": {}, + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + // Use 'postCreateCommand' to run commands after the container is created. + // "postCreateCommand": "go version", + // Configure tool-specific properties. + "customizations": { + "vscode": { + "extensions": [ + "golang.go", + "shardulm94.trailing-spaces", + "IBM.output-colorizer", + "task.vscode-task", + "github.vscode-github-actions", + "redhat.vscode-yaml" + ] + } + } +} \ No newline at end of file diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 070040e4..1902caaf 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -42,7 +42,7 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -53,7 +53,7 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild - uses: github/codeql-action/autobuild@v2 + uses: github/codeql-action/autobuild@v3 # ℹī¸ Command-line programs to run using the OS shell. # 📚 https://git.io/JvXDl @@ -67,4 +67,4 @@ jobs: # make release - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 06e6b5df..f2169d74 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -36,6 +36,6 @@ jobs: uses: docker/build-push-action@v5 with: push: true - platforms: linux/amd64,linux/arm64,linux/arm/v6,linux/arm/v7 + platforms: linux/amd64,linux/arm/v7,linux/arm64/v8,linux/386,linux/ppc64le tags: | ghcr.io/firefart/stunner:latest diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 7632f2b6..62ecb65c 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -4,20 +4,20 @@ jobs: build: name: Build runs-on: ubuntu-latest - strategy: - matrix: - go: ["stable"] steps: - - name: Set up Go ${{ matrix.go }} + - name: Set up Go uses: actions/setup-go@v5 with: - go-version: ${{ matrix.go }} + go-version: "stable" + + - name: Install Task + uses: arduino/setup-task@v2 - name: Check out code uses: actions/checkout@v4 - name: build cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} @@ -28,8 +28,11 @@ jobs: run: | go get -v -t -d ./... - - name: Build - run: make build + - name: Build linux + run: task linux + + - name: Build windows + run: task windows - name: Test - run: make test + run: task test diff --git a/.github/workflows/update.yml b/.github/workflows/update.yml index fec7c19e..7cb404cf 100644 --- a/.github/workflows/update.yml +++ b/.github/workflows/update.yml @@ -19,9 +19,12 @@ jobs: with: go-version: "stable" + - name: Install Task + uses: arduino/setup-task@v2 + - name: update run: | - make update + task update - name: setup git config run: | diff --git a/.gitignore b/.gitignore index 72e75567..c19698fe 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,4 @@ config.json *.secret *.env stunner -*.sh \ No newline at end of file +*.sh diff --git a/Makefile b/Makefile deleted file mode 100644 index ae27b920..00000000 --- a/Makefile +++ /dev/null @@ -1,42 +0,0 @@ -.DEFAULT_GOAL := build - -.PHONY: update -update: - go get -u - go mod tidy - -.PHONY: build -build: test - go fmt ./... - go vet ./... - go build - -.PHONY: run -run: build - ./stunner - -.PHONY: lint -lint: - "$$(go env GOPATH)/bin/golangci-lint" run ./... - go mod tidy - -.PHONY: lint-update -lint-update: - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $$(go env GOPATH)/bin - $$(go env GOPATH)/bin/golangci-lint --version - -.PHONY: test -test: - go test -race -cover ./... - -.PHONY: tag -tag: - @[ "${TAG}" ] && echo "Tagging a new version ${TAG}" || ( echo "TAG is not set"; exit 1 ) - git tag -a "${TAG}" -m "${TAG}" - git push origin "${TAG}" - -.PHONY: windows -windows: - GOOS=windows GOARCH=amd64 go fmt ./... - GOOS=windows GOARCH=amd64 go vet ./... - GOOS=windows GOARCH=amd64 go build diff --git a/Taskfile.yml b/Taskfile.yml new file mode 100644 index 00000000..8e247a34 --- /dev/null +++ b/Taskfile.yml @@ -0,0 +1,64 @@ +version: "3" + +vars: + PROGRAM: stunner + +tasks: + update: + cmds: + - go get -u + - go mod tidy -v + + build: + aliases: [default] + cmds: + - go fmt ./... + - go vet ./... + - go build -o {{.OUTPUT_FILE | default .PROGRAM}} + env: + CGO_ENABLED: 0 + + linux: + cmds: + - task: build + env: + CGO_ENABLED: 0 + GOOS: linux + GOARCH: amd64 + + windows: + cmds: + - task: build + vars: + OUTPUT_FILE: "{{.PROGRAM}}.exe" + env: + CGO_ENABLED: 0 + GOOS: windows + GOARCH: amd64 + + test: + env: + CGO_ENABLED: 1 # required by -race + cmds: + - go test -race -cover ./... + + lint: + cmds: + - golangci-lint run ./... --timeout=30m + - go mod tidy + + lint-update: + cmds: + - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b {{ .GOPATH }}/bin + - golangci-lint --version + vars: + GOPATH: + sh: go env GOPATH + + tag: + cmds: + - git tag -a "${TAG}" -m "${TAG}" + - git push origin "${TAG}" + preconditions: + - sh: '[[ -n "${TAG}" ]]' + msg: "Please set the TAG environment variable" diff --git a/go.mod b/go.mod index 73e1f3db..d31c8645 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,10 @@ module github.com/firefart/stunner go 1.21 require ( - github.com/firefart/gosocks v0.3.0 - github.com/pion/dtls/v2 v2.2.8 + github.com/firefart/gosocks v0.4.1 + github.com/pion/dtls/v2 v2.2.10 github.com/sirupsen/logrus v1.9.3 - github.com/urfave/cli/v2 v2.26.0 + github.com/urfave/cli/v2 v2.27.1 ) require ( @@ -14,8 +14,8 @@ require ( github.com/pion/logging v0.2.2 // indirect github.com/pion/transport/v2 v2.2.4 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect - github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect - golang.org/x/crypto v0.16.0 // indirect - golang.org/x/net v0.19.0 // indirect - golang.org/x/sys v0.15.0 // indirect + github.com/xrash/smetrics v0.0.0-20231213231151-1d8dd44e695e // indirect + golang.org/x/crypto v0.18.0 // indirect + golang.org/x/net v0.20.0 // indirect + golang.org/x/sys v0.16.0 // indirect ) diff --git a/go.sum b/go.sum index ea972c5d..07f482dd 100644 --- a/go.sum +++ b/go.sum @@ -3,13 +3,12 @@ github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/firefart/gosocks v0.3.0 h1:k8cVQmHQipzdOgpiwHAxKUxXE+cwk5W1i6rteoQMzWM= -github.com/firefart/gosocks v0.3.0/go.mod h1:Kboswm/Albj/QPeuVNAfyp3j7zxfrBMr8B1IhWd5EK0= -github.com/pion/dtls/v2 v2.2.8 h1:BUroldfiIbV9jSnC6cKOMnyiORRWrWWpV11JUyEu5OA= -github.com/pion/dtls/v2 v2.2.8/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= +github.com/firefart/gosocks v0.4.1 h1:E5l/BrJE9yPNlOZ6WymBNeEFd5sP8QkM40QonpsfAJY= +github.com/firefart/gosocks v0.4.1/go.mod h1:zflfN1fX57OOzUz6nAZyem3xUf2aIGSFwWbcFrKtjto= +github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA= +github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= -github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g= github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo= github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -24,31 +23,28 @@ github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpE github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/urfave/cli/v2 v2.26.0 h1:3f3AMg3HpThFNT4I++TKOejZO8yU55t3JnnSr4S4QEI= -github.com/urfave/cli/v2 v2.26.0/go.mod h1:8qnjx1vcq5s2/wpsqoZFndg2CE5tNFyrTvS6SinrnYQ= -github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= -github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= +github.com/urfave/cli/v2 v2.27.1 h1:8xSQ6szndafKVRmfyeUMxkNUJQMjL1F2zmsZ+qHpfho= +github.com/urfave/cli/v2 v2.27.1/go.mod h1:8qnjx1vcq5s2/wpsqoZFndg2CE5tNFyrTvS6SinrnYQ= +github.com/xrash/smetrics v0.0.0-20231213231151-1d8dd44e695e h1:+SOyEddqYF09QP7vr7CgJ1eti3pY9Fn3LHO1M1r/0sI= +github.com/xrash/smetrics v0.0.0-20231213231151-1d8dd44e695e/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= -golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= -golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= +golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= -golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= -golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= +golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= +golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -59,23 +55,23 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= +golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= diff --git a/internal/cmd/bruteforce.go b/internal/cmd/bruteforce.go index 066c5209..7d776e24 100644 --- a/internal/cmd/bruteforce.go +++ b/internal/cmd/bruteforce.go @@ -2,6 +2,7 @@ package cmd import ( "bufio" + "context" "fmt" "os" "strings" @@ -43,7 +44,7 @@ func (opts BruteforceOpts) Validate() error { return nil } -func BruteForce(opts BruteforceOpts) error { +func BruteForce(ctx context.Context, opts BruteforceOpts) error { if err := opts.Validate(); err != nil { return err } @@ -56,7 +57,7 @@ func BruteForce(opts BruteforceOpts) error { scanner := bufio.NewScanner(pfile) for scanner.Scan() { - if err := testPassword(opts, scanner.Text()); err != nil { + if err := testPassword(ctx, opts, scanner.Text()); err != nil { return err } } @@ -67,15 +68,15 @@ func BruteForce(opts BruteforceOpts) error { return nil } -func testPassword(opts BruteforceOpts, password string) error { - remote, err := internal.Connect(opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) +func testPassword(ctx context.Context, opts BruteforceOpts, password string) error { + remote, err := internal.Connect(ctx, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) if err != nil { return err } addressFamily := internal.AllocateProtocolIgnore allocateRequest := internal.AllocateRequest(internal.RequestedTransportUDP, addressFamily) - allocateResponse, err := allocateRequest.SendAndReceive(opts.Log, remote, opts.Timeout) + allocateResponse, err := allocateRequest.SendAndReceive(ctx, opts.Log, remote, opts.Timeout) if err != nil { return fmt.Errorf("error on sending AllocateRequest: %w", err) } @@ -87,7 +88,7 @@ func testPassword(opts BruteforceOpts, password string) error { nonce := string(allocateResponse.GetAttribute(internal.AttrNonce).Value) allocateRequest = internal.AllocateRequestAuth(opts.Username, password, nonce, realm, internal.RequestedTransportUDP, addressFamily) - allocateResponse, err = allocateRequest.SendAndReceive(opts.Log, remote, opts.Timeout) + allocateResponse, err = allocateRequest.SendAndReceive(ctx, opts.Log, remote, opts.Timeout) if err != nil { return fmt.Errorf("error on sending AllocateRequest Auth: %w", err) } diff --git a/internal/cmd/brutetransports.go b/internal/cmd/brutetransports.go index 600459b8..24b1d522 100644 --- a/internal/cmd/brutetransports.go +++ b/internal/cmd/brutetransports.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "strings" "time" @@ -42,20 +43,20 @@ func (opts BruteTransportOpts) Validate() error { return nil } -func BruteTransports(opts BruteTransportOpts) error { +func BruteTransports(ctx context.Context, opts BruteTransportOpts) error { if err := opts.Validate(); err != nil { return err } for i := 0; i <= 255; i++ { - conn, err := internal.Connect(opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) + conn, err := internal.Connect(ctx, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) if err != nil { return err } x := internal.RequestedTransport(uint32(i)) allocateRequest := internal.AllocateRequest(x, internal.AllocateProtocolIgnore) - allocateResponse, err := allocateRequest.SendAndReceive(opts.Log, conn, opts.Timeout) + allocateResponse, err := allocateRequest.SendAndReceive(ctx, opts.Log, conn, opts.Timeout) if err != nil { return fmt.Errorf("error on sending allocate request: %w", err) } @@ -64,7 +65,7 @@ func BruteTransports(opts BruteTransportOpts) error { nonce := string(allocateResponse.GetAttribute(internal.AttrNonce).Value) allocateRequest = internal.AllocateRequestAuth(opts.Username, opts.Password, nonce, realm, x, internal.AllocateProtocolIgnore) - allocateResponse, err = allocateRequest.SendAndReceive(opts.Log, conn, opts.Timeout) + allocateResponse, err = allocateRequest.SendAndReceive(ctx, opts.Log, conn, opts.Timeout) if err != nil { return fmt.Errorf("error on sending allocate request auth: %w", err) } diff --git a/internal/cmd/info.go b/internal/cmd/info.go index cc867a4d..0470409b 100644 --- a/internal/cmd/info.go +++ b/internal/cmd/info.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "strings" "time" @@ -35,12 +36,12 @@ func (opts InfoOpts) Validate() error { return nil } -func Info(opts InfoOpts) error { +func Info(ctx context.Context, opts InfoOpts) error { if err := opts.Validate(); err != nil { return err } - if attr, err := testStun(opts); err != nil { + if attr, err := testStun(ctx, opts); err != nil { opts.Log.Debugf("STUN error: %v", err) opts.Log.Error("this server does not support the STUN protocol") } else { @@ -48,7 +49,7 @@ func Info(opts InfoOpts) error { printAttributes(opts, attr) } - if attr, err := testTurn(opts, internal.RequestedTransportUDP); err != nil { + if attr, err := testTurn(ctx, opts, internal.RequestedTransportUDP); err != nil { opts.Log.Debugf("TURN UDP error: %v", err) opts.Log.Error("this server does not support the TURN UDP protocol") } else { @@ -56,7 +57,7 @@ func Info(opts InfoOpts) error { printAttributes(opts, attr) } - if attr, err := testTurn(opts, internal.RequestedTransportTCP); err != nil { + if attr, err := testTurn(ctx, opts, internal.RequestedTransportTCP); err != nil { opts.Log.Debugf("TURN TCP error: %v", err) opts.Log.Error("this server does not support the TURN TCP protocol") } else { @@ -67,15 +68,15 @@ func Info(opts InfoOpts) error { return nil } -func testStun(opts InfoOpts) ([]internal.Attribute, error) { - conn, err := internal.Connect(opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) +func testStun(ctx context.Context, opts InfoOpts) ([]internal.Attribute, error) { + conn, err := internal.Connect(ctx, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) if err != nil { return nil, err } defer conn.Close() bindingRequest := internal.BindingRequest() - bindingResponse, err := bindingRequest.SendAndReceive(opts.Log, conn, opts.Timeout) + bindingResponse, err := bindingRequest.SendAndReceive(ctx, opts.Log, conn, opts.Timeout) if err != nil { return nil, fmt.Errorf("error on sending binding request: %w", err) } @@ -86,15 +87,15 @@ func testStun(opts InfoOpts) ([]internal.Attribute, error) { return bindingResponse.Attributes, nil } -func testTurn(opts InfoOpts, proto internal.RequestedTransport) ([]internal.Attribute, error) { - conn, err := internal.Connect(opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) +func testTurn(ctx context.Context, opts InfoOpts, proto internal.RequestedTransport) ([]internal.Attribute, error) { + conn, err := internal.Connect(ctx, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) if err != nil { return nil, err } defer conn.Close() allocateRequest := internal.AllocateRequest(proto, internal.AllocateProtocolIgnore) - allocateResponse, err := allocateRequest.SendAndReceive(opts.Log, conn, opts.Timeout) + allocateResponse, err := allocateRequest.SendAndReceive(ctx, opts.Log, conn, opts.Timeout) if err != nil { return nil, fmt.Errorf("error on sending allocate request: %w", err) } diff --git a/internal/cmd/memoryleak.go b/internal/cmd/memoryleak.go index c67a3040..5fed900f 100644 --- a/internal/cmd/memoryleak.go +++ b/internal/cmd/memoryleak.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "net/netip" "strings" @@ -56,12 +57,12 @@ func (opts MemoryleakOpts) Validate() error { return nil } -func MemoryLeak(opts MemoryleakOpts) error { +func MemoryLeak(ctx context.Context, opts MemoryleakOpts) error { if err := opts.Validate(); err != nil { return err } - remote, realm, nonce, err := internal.SetupTurnConnection(opts.Log, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout, opts.TargetHost, opts.TargetPort, opts.Username, opts.Password) + remote, realm, nonce, err := internal.SetupTurnConnection(ctx, opts.Log, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout, opts.TargetHost, opts.TargetPort, opts.Username, opts.Password) if err != nil { return err } @@ -76,7 +77,7 @@ func MemoryLeak(opts MemoryleakOpts) error { return fmt.Errorf("error on generating ChannelBind request: %w", err) } opts.Log.Debugf("ChannelBind Request:\n%s", channelBindRequest.String()) - channelBindResponse, err := channelBindRequest.SendAndReceive(opts.Log, remote, opts.Timeout) + channelBindResponse, err := channelBindRequest.SendAndReceive(ctx, opts.Log, remote, opts.Timeout) if err != nil { return fmt.Errorf("error on sending ChannelBind request: %w", err) } @@ -91,7 +92,7 @@ func MemoryLeak(opts MemoryleakOpts) error { toSend = append(toSend, helper.PutUint16(opts.Size)...) toSend = append(toSend, []byte("xxx")...) toSend = internal.Padding(toSend) - err := helper.ConnectionWrite(remote, toSend, opts.Timeout) + err := helper.ConnectionWrite(ctx, remote, toSend, opts.Timeout) if err != nil { return fmt.Errorf("error on sending data: %w", err) } diff --git a/internal/cmd/rangescan.go b/internal/cmd/rangescan.go index a2b21450..914a7aa6 100644 --- a/internal/cmd/rangescan.go +++ b/internal/cmd/rangescan.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "errors" "fmt" "net/netip" @@ -45,7 +46,7 @@ func (opts RangeScanOpts) Validate() error { return nil } -func RangeScan(opts RangeScanOpts) error { +func RangeScan(ctx context.Context, opts RangeScanOpts) error { if err := opts.Validate(); err != nil { return err } @@ -105,7 +106,7 @@ func RangeScan(opts RangeScanOpts) error { return fmt.Errorf("target is no valid ip address: %w", err) } - suc, err := scanUDP(opts, ip, 80) + suc, err := scanUDP(ctx, opts, ip, 80) if err != nil { opts.Log.Errorf("UDP %s: %v", ip, err) } @@ -121,7 +122,7 @@ func RangeScan(opts RangeScanOpts) error { return fmt.Errorf("target is no valid ip address: %w", err) } - suc, err := scanTCP(opts, ip, 80) + suc, err := scanTCP(ctx, opts, ip, 80) if err != nil { opts.Log.Errorf("TCP %s: %v", ip, err) } @@ -132,8 +133,8 @@ func RangeScan(opts RangeScanOpts) error { return nil } -func scanTCP(opts RangeScanOpts, targetHost netip.Addr, targetPort uint16) (bool, error) { - conn, err := internal.Connect(opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) +func scanTCP(ctx context.Context, opts RangeScanOpts, targetHost netip.Addr, targetPort uint16) (bool, error) { + conn, err := internal.Connect(ctx, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) if err != nil { return false, err } @@ -145,7 +146,7 @@ func scanTCP(opts RangeScanOpts, targetHost netip.Addr, targetPort uint16) (bool } allocateRequest := internal.AllocateRequest(internal.RequestedTransportTCP, addressFamily) - allocateResponse, err := allocateRequest.SendAndReceive(opts.Log, conn, opts.Timeout) + allocateResponse, err := allocateRequest.SendAndReceive(ctx, opts.Log, conn, opts.Timeout) if err != nil { return false, fmt.Errorf("error on sending allocate request 1: %w", err) } @@ -157,7 +158,7 @@ func scanTCP(opts RangeScanOpts, targetHost netip.Addr, targetPort uint16) (bool nonce := string(allocateResponse.GetAttribute(internal.AttrNonce).Value) allocateRequest = internal.AllocateRequestAuth(opts.Username, opts.Password, nonce, realm, internal.RequestedTransportTCP, addressFamily) - allocateResponse, err = allocateRequest.SendAndReceive(opts.Log, conn, opts.Timeout) + allocateResponse, err = allocateRequest.SendAndReceive(ctx, opts.Log, conn, opts.Timeout) if err != nil { return false, fmt.Errorf("error on sending allocate request 2: %w", err) } @@ -169,7 +170,7 @@ func scanTCP(opts RangeScanOpts, targetHost netip.Addr, targetPort uint16) (bool if err != nil { return false, fmt.Errorf("error on generating Connect request: %w", err) } - connectResponse, err := connectRequest.SendAndReceive(opts.Log, conn, opts.Timeout) + connectResponse, err := connectRequest.SendAndReceive(ctx, opts.Log, conn, opts.Timeout) if err != nil { // ignore timeouts, a timeout means open port if errors.Is(err, helper.ErrTimeout) { @@ -184,8 +185,8 @@ func scanTCP(opts RangeScanOpts, targetHost netip.Addr, targetPort uint16) (bool return true, nil } -func scanUDP(opts RangeScanOpts, targetHost netip.Addr, targetPort uint16) (bool, error) { - remote, _, _, err := internal.SetupTurnConnection(opts.Log, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout, targetHost, targetPort, opts.Username, opts.Password) +func scanUDP(ctx context.Context, opts RangeScanOpts, targetHost netip.Addr, targetPort uint16) (bool, error) { + remote, _, _, err := internal.SetupTurnConnection(ctx, opts.Log, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout, targetHost, targetPort, opts.Username, opts.Password) if err != nil { return false, err } diff --git a/internal/cmd/socks.go b/internal/cmd/socks.go index e4df00bf..aeaaba29 100644 --- a/internal/cmd/socks.go +++ b/internal/cmd/socks.go @@ -52,13 +52,12 @@ func (opts SocksOpts) Validate() error { return nil } -func Socks(opts SocksOpts) error { +func Socks(ctx context.Context, opts SocksOpts) error { if err := opts.Validate(); err != nil { return err } handler := &socksimplementations.SocksTurnTCPHandler{ - Ctx: context.Background(), Server: opts.TurnServer, TURNUsername: opts.Username, TURNPassword: opts.Password, @@ -74,7 +73,7 @@ func Socks(opts SocksOpts) error { Log: opts.Log, } opts.Log.Infof("starting SOCKS server on %s", opts.Listen) - if err := p.Start(); err != nil { + if err := p.Start(ctx); err != nil { return err } <-p.Done diff --git a/internal/cmd/tcpscanner.go b/internal/cmd/tcpscanner.go index bf417767..e7ba4849 100644 --- a/internal/cmd/tcpscanner.go +++ b/internal/cmd/tcpscanner.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "crypto/tls" "encoding/hex" "fmt" @@ -55,7 +56,7 @@ func (opts TCPScannerOpts) Validate() error { return nil } -func TCPScanner(opts TCPScannerOpts) error { +func TCPScanner(ctx context.Context, opts TCPScannerOpts) error { if err := opts.Validate(); err != nil { return err } @@ -79,7 +80,7 @@ func TCPScanner(opts TCPScannerOpts) error { return fmt.Errorf("Invalid port %s: %w", port, err) } opts.Log.Debugf("Scanning %s:%d", ip.IP.String(), portI) - if err := httpScan(opts, ip.IP, uint16(portI)); err != nil { + if err := httpScan(ctx, opts, ip.IP, uint16(portI)); err != nil { opts.Log.Errorf("error on running HTTP Scan for %s:%d: %v", ip.IP.String(), portI, err) } } @@ -88,8 +89,8 @@ func TCPScanner(opts TCPScannerOpts) error { return nil } -func httpScan(opts TCPScannerOpts, ip netip.Addr, port uint16) error { - controlConnection, dataConnection, err := internal.SetupTurnTCPConnection(opts.Log, opts.TurnServer, opts.UseTLS, opts.Timeout, ip, port, opts.Username, opts.Password) +func httpScan(ctx context.Context, opts TCPScannerOpts, ip netip.Addr, port uint16) error { + _, _, controlConnection, dataConnection, err := internal.SetupTurnTCPConnection(ctx, opts.Log, opts.TurnServer, opts.UseTLS, opts.Timeout, ip, port, opts.Username, opts.Password) if err != nil { return err } @@ -103,10 +104,10 @@ func httpScan(opts TCPScannerOpts, ip netip.Addr, port uint16) error { if useTLS { tlsConn := tls.Client(dataConnection, &tls.Config{InsecureSkipVerify: true}) - if err := helper.ConnectionWrite(tlsConn, []byte(httpRequest), opts.Timeout); err != nil { + if err := helper.ConnectionWrite(ctx, tlsConn, []byte(httpRequest), opts.Timeout); err != nil { return fmt.Errorf("error on sending TLS data: %w", err) } - data, err := helper.ConnectionRead(tlsConn, opts.Timeout) + data, err := helper.ConnectionRead(ctx, tlsConn, opts.Timeout) if err != nil { return fmt.Errorf("error on reading after sending TLS data: %w", err) } @@ -116,10 +117,10 @@ func httpScan(opts TCPScannerOpts, ip netip.Addr, port uint16) error { } // plain text connection - if err := helper.ConnectionWrite(dataConnection, []byte(httpRequest), opts.Timeout); err != nil { + if err := helper.ConnectionWrite(ctx, dataConnection, []byte(httpRequest), opts.Timeout); err != nil { return fmt.Errorf("error on sending data: %w", err) } - data, err := helper.ConnectionRead(dataConnection, opts.Timeout) + data, err := helper.ConnectionRead(ctx, dataConnection, opts.Timeout) if err != nil { return fmt.Errorf("error on reading after sending data: %w", err) } diff --git a/internal/cmd/udpscanner.go b/internal/cmd/udpscanner.go index 607b251b..289c06b1 100644 --- a/internal/cmd/udpscanner.go +++ b/internal/cmd/udpscanner.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "errors" "fmt" "math/rand" @@ -56,7 +57,7 @@ func (opts UDPScannerOpts) Validate() error { return nil } -func UDPScanner(opts UDPScannerOpts) error { +func UDPScanner(ctx context.Context, opts UDPScannerOpts) error { if err := opts.Validate(); err != nil { return err } @@ -74,10 +75,10 @@ func UDPScanner(opts UDPScannerOpts) error { continue } opts.Log.Debugf("Scanning %s", ip.IP.String()) - if err := snmpScan(opts, ip.IP, 161, opts.CommunityString); err != nil { + if err := snmpScan(ctx, opts, ip.IP, 161, opts.CommunityString); err != nil { opts.Log.Errorf("error on running SNMP Scan for ip %s: %v", ip.IP.String(), err) } - if err := dnsScan(opts, ip.IP, 53, opts.DomainName); err != nil { + if err := dnsScan(ctx, opts, ip.IP, 53, opts.DomainName); err != nil { opts.Log.Errorf("error on running DNS Scan for ip %s: %v", ip.IP.String(), err) } } @@ -85,8 +86,8 @@ func UDPScanner(opts UDPScannerOpts) error { return nil } -func snmpScan(opts UDPScannerOpts, ip netip.Addr, port uint16, community string) error { - remote, realm, nonce, err := internal.SetupTurnConnection(opts.Log, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout, ip, port, opts.Username, opts.Password) +func snmpScan(ctx context.Context, opts UDPScannerOpts, ip netip.Addr, port uint16, community string) error { + remote, realm, nonce, err := internal.SetupTurnConnection(ctx, opts.Log, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout, ip, port, opts.Username, opts.Password) if err != nil { // ignore timeouts if errors.Is(err, helper.ErrTimeout) { @@ -105,7 +106,7 @@ func snmpScan(opts UDPScannerOpts, ip netip.Addr, port uint16, community string) return fmt.Errorf("error on generating ChannelBindRequest: %w", err) } - channelBindResponse, err := channelBindRequest.SendAndReceive(opts.Log, remote, opts.Timeout) + channelBindResponse, err := channelBindRequest.SendAndReceive(ctx, opts.Log, remote, opts.Timeout) if err != nil { return fmt.Errorf("error on sending ChannelBindRequest: %w", err) } @@ -147,12 +148,12 @@ func snmpScan(opts UDPScannerOpts, ip netip.Addr, port uint16, community string) buf = append(buf, helper.PutUint16(uint16(snmpLen))...) buf = append(buf, snmp...) - err = helper.ConnectionWrite(remote, buf, opts.Timeout) + err = helper.ConnectionWrite(ctx, remote, buf, opts.Timeout) if err != nil { return fmt.Errorf("error on sending SNMP request: %w", err) } - resp, err := helper.ConnectionRead(remote, opts.Timeout) + resp, err := helper.ConnectionRead(ctx, remote, opts.Timeout) if err != nil { // ignore timeouts if errors.Is(err, helper.ErrTimeout) { @@ -172,8 +173,8 @@ func snmpScan(opts UDPScannerOpts, ip netip.Addr, port uint16, community string) return nil } -func dnsScan(opts UDPScannerOpts, ip netip.Addr, port uint16, dnsName string) error { - remote, realm, nonce, err := internal.SetupTurnConnection(opts.Log, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout, ip, port, opts.Username, opts.Password) +func dnsScan(ctx context.Context, opts UDPScannerOpts, ip netip.Addr, port uint16, dnsName string) error { + remote, realm, nonce, err := internal.SetupTurnConnection(ctx, opts.Log, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout, ip, port, opts.Username, opts.Password) if err != nil { // ignore timeouts if errors.Is(err, helper.ErrTimeout) { @@ -192,7 +193,7 @@ func dnsScan(opts UDPScannerOpts, ip netip.Addr, port uint16, dnsName string) er return fmt.Errorf("error on generating ChannelBindRequest: %w", err) } - channelBindResponse, err := channelBindRequest.SendAndReceive(opts.Log, remote, opts.Timeout) + channelBindResponse, err := channelBindRequest.SendAndReceive(ctx, opts.Log, remote, opts.Timeout) if err != nil { return fmt.Errorf("error on sending ChannelBindRequest: %w", err) } @@ -239,12 +240,12 @@ func dnsScan(opts UDPScannerOpts, ip netip.Addr, port uint16, dnsName string) er buf = append(buf, helper.PutUint16(uint16(dnsLen))...) buf = append(buf, dns...) - err = helper.ConnectionWrite(remote, buf, opts.Timeout) + err = helper.ConnectionWrite(ctx, remote, buf, opts.Timeout) if err != nil { return fmt.Errorf("error on sending DNS request: %w", err) } - resp, err := helper.ConnectionRead(remote, opts.Timeout) + resp, err := helper.ConnectionRead(ctx, remote, opts.Timeout) if err != nil { // ignore timeouts if errors.Is(err, helper.ErrTimeout) { diff --git a/internal/connection.go b/internal/connection.go index 7da902e6..01d4575f 100644 --- a/internal/connection.go +++ b/internal/connection.go @@ -11,7 +11,7 @@ import ( "github.com/pion/dtls/v2" ) -func Connect(protocol string, turnServer string, useTLS bool, timeout time.Duration) (net.Conn, error) { +func Connect(ctx context.Context, protocol string, turnServer string, useTLS bool, timeout time.Duration) (net.Conn, error) { if !useTLS { // non TLS connection conn, err := net.DialTimeout(protocol, turnServer, timeout) @@ -39,7 +39,7 @@ func Connect(protocol string, turnServer string, useTLS bool, timeout time.Durat if err != nil { return nil, fmt.Errorf("error on establishing a connection to the server: %w", err) } - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() dtlsConn, err := dtls.ClientWithContext(ctx, conn, &dtls.Config{ InsecureSkipVerify: true, @@ -54,12 +54,12 @@ func Connect(protocol string, turnServer string, useTLS bool, timeout time.Durat } // send serializes a STUN object and sends it on the provided connection -func (s *Stun) send(conn net.Conn, timeout time.Duration) error { +func (s *Stun) send(ctx context.Context, conn net.Conn, timeout time.Duration) error { data, err := s.Serialize() if err != nil { return fmt.Errorf("Serialize: %w", err) } - if err := helper.ConnectionWrite(conn, data, timeout); err != nil { + if err := helper.ConnectionWrite(ctx, conn, data, timeout); err != nil { return fmt.Errorf("ConnectionWrite: %w", err) } @@ -67,13 +67,13 @@ func (s *Stun) send(conn net.Conn, timeout time.Duration) error { } // SendAndReceive sends a TURN request on a connection and gets a response -func (s *Stun) SendAndReceive(logger DebugLogger, conn net.Conn, timeout time.Duration) (*Stun, error) { +func (s *Stun) SendAndReceive(ctx context.Context, logger DebugLogger, conn net.Conn, timeout time.Duration) (*Stun, error) { logger.Debugf("Sending\n%s", s.String()) - err := s.send(conn, timeout) + err := s.send(ctx, conn, timeout) if err != nil { return nil, fmt.Errorf("Send: %w", err) } - buffer, err := helper.ConnectionRead(conn, timeout) + buffer, err := helper.ConnectionRead(ctx, conn, timeout) if err != nil { return nil, fmt.Errorf("ConnectionRead: %w", err) } diff --git a/internal/helper/connection.go b/internal/helper/connection.go index bf138aad..9e9bc799 100644 --- a/internal/helper/connection.go +++ b/internal/helper/connection.go @@ -1,8 +1,8 @@ package helper import ( + "context" "errors" - "fmt" "io" "net" "time" @@ -11,56 +11,65 @@ import ( var ErrTimeout = errors.New("timeout occurred. you can try to increase the timeout if the server responds too slowly") // ConnectionRead reads all data from a connection -func ConnectionRead(conn net.Conn, timeout time.Duration) ([]byte, error) { +func ConnectionRead(ctx context.Context, conn net.Conn, timeout time.Duration) ([]byte, error) { var ret []byte - if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil { - return nil, fmt.Errorf("could not set read deadline: %w", err) - } + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() bufLen := 1024 for { - buf := make([]byte, bufLen) - i, err := conn.Read(buf) - if err != nil { - if err != io.EOF { - // also return read data on timeout so caller can use it - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - return ret, ErrTimeout + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + buf := make([]byte, bufLen) + i, err := conn.Read(buf) + if err != nil { + if err != io.EOF { + // also return read data on timeout so caller can use it + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return ret, ErrTimeout + } + return nil, err } - return nil, err + return ret, nil + } + ret = append(ret, buf[:i]...) + // we've read all data, bail out + if i < bufLen { + return ret, nil } - return ret, nil - } - ret = append(ret, buf[:i]...) - // we've read all data, bail out - if i < bufLen { - return ret, nil } } } // ConnectionWrite makes sure to write all data to a connection -func ConnectionWrite(conn net.Conn, data []byte, timeout time.Duration) error { +func ConnectionWrite(ctx context.Context, conn net.Conn, data []byte, timeout time.Duration) error { toWriteLeft := len(data) written := 0 - err := conn.SetWriteDeadline(time.Now().Add(timeout)) - if err != nil { - return fmt.Errorf("could not set write deadline: %w", err) - } + var err error + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() for { - written, err = conn.Write(data[written:toWriteLeft]) - if err != nil { - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - return ErrTimeout - } else { - return err + select { + case <-ctx.Done(): + return ctx.Err() + default: + written, err = conn.Write(data[written:toWriteLeft]) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return ErrTimeout + } else { + return err + } } + if written == toWriteLeft { + return nil + } + toWriteLeft -= written } - if written == toWriteLeft { - return nil - } - toWriteLeft -= written } } diff --git a/internal/helpers_turn.go b/internal/helpers_turn.go index b58bb3d4..922ffc82 100644 --- a/internal/helpers_turn.go +++ b/internal/helpers_turn.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "context" "encoding/binary" "fmt" "net" @@ -127,8 +128,8 @@ func ConvertXORAddr(input []byte, transactionID string) (string, uint16, error) // CreatePermission // // it returns the connection, the realm, the nonce and an error -func SetupTurnConnection(logger DebugLogger, connectProtocol string, turnServer string, useTLS bool, timeout time.Duration, targetHost netip.Addr, targetPort uint16, username, password string) (net.Conn, string, string, error) { - remote, err := Connect(connectProtocol, turnServer, useTLS, timeout) +func SetupTurnConnection(ctx context.Context, logger DebugLogger, connectProtocol string, turnServer string, useTLS bool, timeout time.Duration, targetHost netip.Addr, targetPort uint16, username, password string) (net.Conn, string, string, error) { + remote, err := Connect(ctx, connectProtocol, turnServer, useTLS, timeout) if err != nil { return nil, "", "", err } @@ -139,7 +140,7 @@ func SetupTurnConnection(logger DebugLogger, connectProtocol string, turnServer } allocateRequest := AllocateRequest(RequestedTransportUDP, addressFamily) - allocateResponse, err := allocateRequest.SendAndReceive(logger, remote, timeout) + allocateResponse, err := allocateRequest.SendAndReceive(ctx, logger, remote, timeout) if err != nil { return nil, "", "", fmt.Errorf("error on sending AllocateRequest: %w", err) } @@ -151,7 +152,7 @@ func SetupTurnConnection(logger DebugLogger, connectProtocol string, turnServer nonce := string(allocateResponse.GetAttribute(AttrNonce).Value) allocateRequest = AllocateRequestAuth(username, password, nonce, realm, RequestedTransportUDP, addressFamily) - allocateResponse, err = allocateRequest.SendAndReceive(logger, remote, timeout) + allocateResponse, err = allocateRequest.SendAndReceive(ctx, logger, remote, timeout) if err != nil { return nil, "", "", fmt.Errorf("error on sending AllocateRequest Auth: %w", err) } @@ -162,7 +163,7 @@ func SetupTurnConnection(logger DebugLogger, connectProtocol string, turnServer if err != nil { return nil, "", "", fmt.Errorf("error on generating CreatePermissionRequest: %w", err) } - permissionResponse, err := permissionRequest.SendAndReceive(logger, remote, timeout) + permissionResponse, err := permissionRequest.SendAndReceive(ctx, logger, remote, timeout) if err != nil { return nil, "", "", fmt.Errorf("error on sending CreatePermissionRequest: %w", err) } diff --git a/internal/helpers_turntcp.go b/internal/helpers_turntcp.go index 0fb04726..c4225320 100644 --- a/internal/helpers_turntcp.go +++ b/internal/helpers_turntcp.go @@ -1,12 +1,17 @@ package internal import ( + "context" "fmt" "net" "net/netip" "time" ) +type keepAlive interface { + SetKeepAlive(bool) +} + // SetupTurnTCPConnection executes the following: // // Allocate Unauth (to get realm and nonce) @@ -16,19 +21,16 @@ import ( // ConnectionBind // // it returns the controlConnection, the dataConnection and an error -func SetupTurnTCPConnection(logger DebugLogger, turnServer string, useTLS bool, timeout time.Duration, targetHost netip.Addr, targetPort uint16, username, password string) (*net.TCPConn, *net.TCPConn, error) { +func SetupTurnTCPConnection(ctx context.Context, logger DebugLogger, turnServer string, useTLS bool, timeout time.Duration, targetHost netip.Addr, targetPort uint16, username, password string) (string, string, net.Conn, net.Conn, error) { // protocol needs to be tcp - controlConnectionRaw, err := Connect("tcp", turnServer, useTLS, timeout) + controlConnection, err := Connect(ctx, "tcp", turnServer, useTLS, timeout) if err != nil { - return nil, nil, fmt.Errorf("error on establishing control connection: %w", err) + return "", "", nil, nil, fmt.Errorf("error on establishing control connection: %w", err) } - controlConnection, ok := controlConnectionRaw.(*net.TCPConn) - if !ok { - return nil, nil, fmt.Errorf("could not cast control connection to TCPConn") - } - if err := controlConnection.SetKeepAlive(true); err != nil { - return nil, nil, fmt.Errorf("could not set KeepAlive on control connection: %w", err) + if x, ok := controlConnection.(keepAlive); ok { + logger.Debug("controlconnection: set keepalive to true") + x.SetKeepAlive(true) } logger.Debugf("opened turn tcp control connection from %s to %s", controlConnection.LocalAddr().String(), controlConnection.RemoteAddr().String()) @@ -39,63 +41,60 @@ func SetupTurnTCPConnection(logger DebugLogger, turnServer string, useTLS bool, } allocateRequest := AllocateRequest(RequestedTransportTCP, addressFamily) - allocateResponse, err := allocateRequest.SendAndReceive(logger, controlConnection, timeout) + allocateResponse, err := allocateRequest.SendAndReceive(ctx, logger, controlConnection, timeout) if err != nil { - return nil, nil, fmt.Errorf("error on sending allocate request 1: %w", err) + return "", "", nil, nil, fmt.Errorf("error on sending allocate request 1: %w", err) } if allocateResponse.Header.MessageType.Class != MsgTypeClassError { - return nil, nil, fmt.Errorf("MessageClass is not Error (should be not authenticated)") + return "", "", nil, nil, fmt.Errorf("MessageClass is not Error (should be not authenticated)") } realm := string(allocateResponse.GetAttribute(AttrRealm).Value) nonce := string(allocateResponse.GetAttribute(AttrNonce).Value) allocateRequest = AllocateRequestAuth(username, password, nonce, realm, RequestedTransportTCP, addressFamily) - allocateResponse, err = allocateRequest.SendAndReceive(logger, controlConnection, timeout) + allocateResponse, err = allocateRequest.SendAndReceive(ctx, logger, controlConnection, timeout) if err != nil { - return nil, nil, fmt.Errorf("error on sending allocate request 2: %w", err) + return "", "", nil, nil, fmt.Errorf("error on sending allocate request 2: %w", err) } if allocateResponse.Header.MessageType.Class == MsgTypeClassError { - return nil, nil, fmt.Errorf("error on allocate response: %s", allocateResponse.GetErrorString()) + return "", "", nil, nil, fmt.Errorf("error on allocate response: %s", allocateResponse.GetErrorString()) } connectRequest, err := ConnectRequestAuth(username, password, nonce, realm, targetHost, targetPort) if err != nil { - return nil, nil, fmt.Errorf("error on generating Connect request: %w", err) + return "", "", nil, nil, fmt.Errorf("error on generating Connect request: %w", err) } - connectResponse, err := connectRequest.SendAndReceive(logger, controlConnection, timeout) + connectResponse, err := connectRequest.SendAndReceive(ctx, logger, controlConnection, timeout) if err != nil { - return nil, nil, fmt.Errorf("error on sending Connect request: %w", err) + return "", "", nil, nil, fmt.Errorf("error on sending Connect request: %w", err) } if connectResponse.Header.MessageType.Class == MsgTypeClassError { - return nil, nil, fmt.Errorf("error on Connect response: %s", connectResponse.GetErrorString()) + return "", "", nil, nil, fmt.Errorf("error on Connect response: %s", connectResponse.GetErrorString()) } connectionID := connectResponse.GetAttribute(AttrConnectionID).Value - dataConnectionRaw, err := Connect("tcp", turnServer, useTLS, timeout) + dataConnection, err := Connect(ctx, "tcp", turnServer, useTLS, timeout) if err != nil { - return nil, nil, fmt.Errorf("error on establishing data connection: %w", err) + return "", "", nil, nil, fmt.Errorf("error on establishing data connection: %w", err) } - dataConnection, ok := dataConnectionRaw.(*net.TCPConn) - if !ok { - return nil, nil, fmt.Errorf("could not cast data connection to TCPConn") - } - if err := dataConnection.SetKeepAlive(true); err != nil { - return nil, nil, fmt.Errorf("could not set KeepAlive on data connection: %w", err) + if x, ok := dataConnection.(keepAlive); ok { + logger.Debug("dataconnection: set keepalive to true") + x.SetKeepAlive(true) } logger.Debugf("opened turn tcp data connection from %s to %s", dataConnection.LocalAddr().String(), dataConnection.RemoteAddr().String()) connectionBindRequest := ConnectionBindRequest(connectionID, username, password, nonce, realm) - connectionBindResponse, err := connectionBindRequest.SendAndReceive(logger, dataConnection, timeout) + connectionBindResponse, err := connectionBindRequest.SendAndReceive(ctx, logger, dataConnection, timeout) if err != nil { - return nil, nil, fmt.Errorf("error on sending ConnectionBind request: %w", err) + return "", "", nil, nil, fmt.Errorf("error on sending ConnectionBind request: %w", err) } if connectionBindResponse.Header.MessageType.Class == MsgTypeClassError { - return nil, nil, fmt.Errorf("error on ConnectionBind reposnse: %s", connectionBindResponse.GetErrorString()) + return "", "", nil, nil, fmt.Errorf("error on ConnectionBind reposnse: %s", connectionBindResponse.GetErrorString()) } - return controlConnection, dataConnection, nil + return realm, nonce, controlConnection, dataConnection, nil } diff --git a/internal/logger.go b/internal/logger.go index a010651d..843b298f 100644 --- a/internal/logger.go +++ b/internal/logger.go @@ -1,5 +1,6 @@ package internal type DebugLogger interface { + Debug(...interface{}) Debugf(format string, args ...interface{}) } diff --git a/internal/socksimplementations/socksturntcphandler.go b/internal/socksimplementations/socksturntcphandler.go index 68c2a691..782b8501 100644 --- a/internal/socksimplementations/socksturntcphandler.go +++ b/internal/socksimplementations/socksturntcphandler.go @@ -2,6 +2,7 @@ package socksimplementations import ( "context" + "errors" "fmt" "io" "net" @@ -17,7 +18,6 @@ import ( // SocksTurnTCPHandler is the implementation of a TCP TURN server type SocksTurnTCPHandler struct { - Ctx context.Context ControlConnection net.Conn TURNUsername string TURNPassword string @@ -26,10 +26,12 @@ type SocksTurnTCPHandler struct { UseTLS bool DropNonPrivateRequests bool Log *logrus.Logger + realm string + nonce string } // PreHandler connects to the STUN server, sets the connection up and returns the data connections -func (s *SocksTurnTCPHandler) Init(request socks.Request) (io.ReadWriteCloser, *socks.Error) { +func (s *SocksTurnTCPHandler) Init(ctx context.Context, request socks.Request) (io.ReadWriteCloser, *socks.Error) { var target netip.Addr var err error switch request.AddressType { @@ -45,7 +47,7 @@ func (s *SocksTurnTCPHandler) Init(request socks.Request) (io.ReadWriteCloser, * target = ip } else { // input is a hostname - names, err := helper.ResolveName(s.Ctx, string(request.DestinationAddress)) + names, err := helper.ResolveName(ctx, string(request.DestinationAddress)) if err != nil { return nil, socks.NewError(socks.RequestReplyHostUnreachable, err) } @@ -63,10 +65,12 @@ func (s *SocksTurnTCPHandler) Init(request socks.Request) (io.ReadWriteCloser, * return nil, socks.NewError(socks.RequestReplyHostUnreachable, fmt.Errorf("dropping non private connection to %s:%d", target.String(), request.DestinationPort)) } - controlConnection, dataConnection, err := internal.SetupTurnTCPConnection(s.Log, s.Server, s.UseTLS, s.Timeout, target, request.DestinationPort, s.TURNUsername, s.TURNPassword) + realm, nonce, controlConnection, dataConnection, err := internal.SetupTurnTCPConnection(ctx, s.Log, s.Server, s.UseTLS, s.Timeout, target, request.DestinationPort, s.TURNUsername, s.TURNPassword) if err != nil { return nil, socks.NewError(socks.RequestReplyHostUnreachable, err) } + s.realm = realm + s.nonce = nonce // we need to keep this connection open s.ControlConnection = controlConnection @@ -75,60 +79,102 @@ func (s *SocksTurnTCPHandler) Init(request socks.Request) (io.ReadWriteCloser, * // Refresh is used to refresh an active connection every 2 minutes func (s *SocksTurnTCPHandler) Refresh(ctx context.Context) { - nonce := "" - realm := "" - tick := time.NewTicker(2 * time.Minute) - select { - case <-ctx.Done(): - return - case <-tick.C: - s.Log.Debug("[socks] refreshing connection") - refresh := internal.RefreshRequest(s.TURNUsername, s.TURNPassword, nonce, realm) - response, err := refresh.SendAndReceive(s.Log, s.ControlConnection, s.Timeout) - if err != nil { - s.Log.Error(err) + nonce := s.nonce + realm := s.realm + tick := time.NewTicker(5 * time.Minute) // default timeout on coturn is 600 seconds (10 minutes) + for { + select { + case <-ctx.Done(): return - } - // should happen on a stale nonce - if response.Header.MessageType.Class == internal.MsgTypeClassError { - realm := string(response.GetAttribute(internal.AttrRealm).Value) - nonce := string(response.GetAttribute(internal.AttrNonce).Value) - refresh = internal.RefreshRequest(s.TURNUsername, s.TURNPassword, nonce, realm) - response, err = refresh.SendAndReceive(s.Log, s.ControlConnection, s.Timeout) + case <-tick.C: + s.Log.Debug("[socks] refreshing connection") + refresh := internal.RefreshRequest(s.TURNUsername, s.TURNPassword, nonce, realm) + response, err := refresh.SendAndReceive(ctx, s.Log, s.ControlConnection, s.Timeout) if err != nil { s.Log.Error(err) return } + // should happen on a stale nonce if response.Header.MessageType.Class == internal.MsgTypeClassError { - s.Log.Error(response.GetErrorString()) - return + realm := string(response.GetAttribute(internal.AttrRealm).Value) + nonce := string(response.GetAttribute(internal.AttrNonce).Value) + s.nonce = nonce + s.realm = realm + refresh = internal.RefreshRequest(s.TURNUsername, s.TURNPassword, nonce, realm) + response, err = refresh.SendAndReceive(ctx, s.Log, s.ControlConnection, s.Timeout) + if err != nil { + s.Log.Error(err) + return + } + if response.Header.MessageType.Class == internal.MsgTypeClassError { + s.Log.Error(response.GetErrorString()) + return + } } } } } +const bufferLength = 1024 * 100 + // ReadFromClient is used to copy data func (s *SocksTurnTCPHandler) ReadFromClient(ctx context.Context, client io.ReadCloser, remote io.WriteCloser) error { - i, err := io.Copy(remote, client) - if err != nil { - return fmt.Errorf("CopyFromRemoteToClient: %w", err) + for { + // anonymous func for defer + // this might not be the fastest, but it does the trick + err := func() error { + ctx, cancel := context.WithTimeout(ctx, s.Timeout) + defer cancel() + select { + case <-ctx.Done(): + return ctx.Err() + default: + i, err := io.CopyN(remote, client, bufferLength) + if errors.Is(err, io.EOF) { + return nil + } else if err != nil { + return fmt.Errorf("ReadFromClient: %w", err) + } + s.Log.Debugf("[socks] wrote %d bytes to client", i) + } + return nil + }() + if err != nil { + return err + } } - s.Log.Debugf("[socks] wrote %d bytes to client", i) - return nil } // ReadFromRemote is used to copy data func (s *SocksTurnTCPHandler) ReadFromRemote(ctx context.Context, remote io.ReadCloser, client io.WriteCloser) error { - i, err := io.Copy(client, remote) - if err != nil { - return fmt.Errorf("CopyFromClientToRemote: %w", err) + for { + // anonymous func for defer + // this might not be the fastest, but it does the trick + err := func() error { + ctx, cancel := context.WithTimeout(ctx, s.Timeout) + defer cancel() + select { + case <-ctx.Done(): + return ctx.Err() + default: + i, err := io.CopyN(client, remote, bufferLength) + if errors.Is(err, io.EOF) { + return nil + } else if err != nil { + return fmt.Errorf("ReadFromRemote: %w", err) + } + s.Log.Debugf("[socks] wrote %d bytes to remote", i) + } + return nil + }() + if err != nil { + return err + } } - s.Log.Debugf("[socks] wrote %d bytes to remote", i) - return nil } // Cleanup closes the stored control connection -func (s *SocksTurnTCPHandler) Close() error { +func (s *SocksTurnTCPHandler) Close(ctx context.Context) error { if s.ControlConnection != nil { return s.ControlConnection.Close() } diff --git a/internal/socksimplementations/socksturnudphandler.go b/internal/socksimplementations/socksturnudphandler.go index 5cc039cf..c411471d 100644 --- a/internal/socksimplementations/socksturnudphandler.go +++ b/internal/socksimplementations/socksturnudphandler.go @@ -17,7 +17,6 @@ import ( // SocksTurnUDPHandler is the implementation of a UDP TURN server type SocksTurnUDPHandler struct { - Ctx context.Context TURNUsername string TURNPassword string Server string @@ -30,7 +29,7 @@ type SocksTurnUDPHandler struct { } // PreHandler creates a connection to the target server and returns a connection to send data -func (s *SocksTurnUDPHandler) PreHandler(request socks.Request) (io.ReadWriteCloser, *socks.Error) { +func (s *SocksTurnUDPHandler) Init(ctx context.Context, request socks.Request) (io.ReadWriteCloser, *socks.Error) { var target netip.Addr var err error switch request.AddressType { @@ -41,7 +40,7 @@ func (s *SocksTurnUDPHandler) PreHandler(request socks.Request) (io.ReadWriteClo } target = tmp case socks.RequestAddressTypeDomainname: - names, err := helper.ResolveName(s.Ctx, string(request.DestinationAddress)) + names, err := helper.ResolveName(ctx, string(request.DestinationAddress)) if err != nil { return nil, socks.NewError(socks.RequestReplyHostUnreachable, err) } @@ -58,7 +57,7 @@ func (s *SocksTurnUDPHandler) PreHandler(request socks.Request) (io.ReadWriteClo return nil, socks.NewError(socks.RequestReplyHostUnreachable, fmt.Errorf("dropping non private connection to %s:%d", target.String(), request.DestinationPort)) } - remote, realm, nonce, err := internal.SetupTurnConnection(s.Log, s.ConnectProtocol, s.Server, s.UseTLS, s.Timeout, target, request.DestinationPort, s.TURNUsername, s.TURNPassword) + remote, realm, nonce, err := internal.SetupTurnConnection(ctx, s.Log, s.ConnectProtocol, s.Server, s.UseTLS, s.Timeout, target, request.DestinationPort, s.TURNUsername, s.TURNPassword) if err != nil { return nil, socks.NewError(socks.RequestReplyHostUnreachable, err) } @@ -73,7 +72,7 @@ func (s *SocksTurnUDPHandler) PreHandler(request socks.Request) (io.ReadWriteClo return nil, socks.NewError(socks.RequestReplyHostUnreachable, fmt.Errorf("error on generating ChannelBindRequest: %w", err)) } s.Log.Debugf("ChannelBind Request:\n%s", channelBindRequest.String()) - channelBindResponse, err := channelBindRequest.SendAndReceive(s.Log, remote, s.Timeout) + channelBindResponse, err := channelBindRequest.SendAndReceive(ctx, s.Log, remote, s.Timeout) if err != nil { return nil, socks.NewError(socks.RequestReplyHostUnreachable, fmt.Errorf("error on sending ChannelBindRequest: %w", err)) } @@ -85,7 +84,7 @@ func (s *SocksTurnUDPHandler) PreHandler(request socks.Request) (io.ReadWriteClo } // CopyFromRemoteToClient is used to send data and remove the extra channel data header -func (s *SocksTurnUDPHandler) CopyFromRemoteToClient(ctx context.Context, remote io.ReadCloser, client io.WriteCloser) error { +func (s *SocksTurnUDPHandler) ReadFromRemote(ctx context.Context, remote io.ReadCloser, client io.WriteCloser) error { clientConn, ok := client.(net.Conn) if !ok { return fmt.Errorf("could not cast client to net.Conn") @@ -95,7 +94,7 @@ func (s *SocksTurnUDPHandler) CopyFromRemoteToClient(ctx context.Context, remote return fmt.Errorf("could not cast remote to net.Conn") } - recv, err := helper.ConnectionRead(remoteConn, s.Timeout) + recv, err := helper.ConnectionRead(ctx, remoteConn, s.Timeout) if err != nil { return err } @@ -106,7 +105,7 @@ func (s *SocksTurnUDPHandler) CopyFromRemoteToClient(ctx context.Context, remote } s.Log.Debugf("received %d bytes on channel %02x", len(data), channel) - err = helper.ConnectionWrite(clientConn, data, s.Timeout) + err = helper.ConnectionWrite(ctx, clientConn, data, s.Timeout) if err != nil { return err } @@ -114,7 +113,7 @@ func (s *SocksTurnUDPHandler) CopyFromRemoteToClient(ctx context.Context, remote } // CopyFromClientToRemote is used to send data and add the extra channel data header -func (s *SocksTurnUDPHandler) CopyFromClientToRemote(ctx context.Context, client io.ReadCloser, remote io.WriteCloser) error { +func (s *SocksTurnUDPHandler) ReadFromClient(ctx context.Context, client io.ReadCloser, remote io.WriteCloser) error { clientConn, ok := client.(net.Conn) if !ok { return fmt.Errorf("could not cast client to net.Conn") @@ -124,7 +123,7 @@ func (s *SocksTurnUDPHandler) CopyFromClientToRemote(ctx context.Context, client return fmt.Errorf("could not cast remote to net.Conn") } - toSend, err := helper.ConnectionRead(clientConn, s.Timeout) + toSend, err := helper.ConnectionRead(ctx, clientConn, s.Timeout) if err != nil { return err } @@ -136,7 +135,7 @@ func (s *SocksTurnUDPHandler) CopyFromClientToRemote(ctx context.Context, client buf = append(buf, helper.PutUint16(uint16(toSendLen))...) buf = append(buf, toSend...) - err = helper.ConnectionWrite(remoteConn, buf, s.Timeout) + err = helper.ConnectionWrite(ctx, remoteConn, buf, s.Timeout) if err != nil { return err } @@ -148,6 +147,6 @@ func (s *SocksTurnUDPHandler) Refresh(_ context.Context) { } // Cleanup is not used in this implementation -func (s *SocksTurnUDPHandler) Cleanup() error { +func (s *SocksTurnUDPHandler) Close(ctx context.Context) error { return nil } diff --git a/main.go b/main.go index 75f9bb55..8655dffd 100644 --- a/main.go +++ b/main.go @@ -52,7 +52,7 @@ func main() { &cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"}, &cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"}, &cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"}, - &cli.DurationFlag{Name: "timeout", Value: 1 * time.Second, Usage: "connect timeout to turn server"}, + &cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"}, }, Before: func(ctx *cli.Context) error { if ctx.Bool("debug") { @@ -65,7 +65,7 @@ func main() { useTLS := c.Bool("tls") protocol := c.String("protocol") timeout := c.Duration("timeout") - return cmd.Info(cmd.InfoOpts{ + return cmd.Info(c.Context, cmd.InfoOpts{ TurnServer: turnServer, UseTLS: useTLS, Protocol: protocol, @@ -86,7 +86,7 @@ func main() { &cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"}, &cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"}, &cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"}, - &cli.DurationFlag{Name: "timeout", Value: 1 * time.Second, Usage: "connect timeout to turn server"}, + &cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"}, &cli.StringFlag{Name: "username", Aliases: []string{"u"}, Required: true, Usage: "username for the turn server"}, &cli.StringFlag{Name: "password", Aliases: []string{"p"}, Required: true, Usage: "password for the turn server"}, }, @@ -103,7 +103,7 @@ func main() { timeout := c.Duration("timeout") username := c.String("username") password := c.String("password") - return cmd.BruteTransports(cmd.BruteTransportOpts{ + return cmd.BruteTransports(c.Context, cmd.BruteTransportOpts{ TurnServer: turnServer, UseTLS: useTLS, Protocol: protocol, @@ -125,7 +125,7 @@ func main() { &cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"}, &cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"}, &cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"}, - &cli.DurationFlag{Name: "timeout", Value: 1 * time.Second, Usage: "connect timeout to turn server"}, + &cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"}, &cli.StringFlag{Name: "username", Aliases: []string{"u"}, Required: true, Usage: "username for the turn server"}, &cli.StringFlag{Name: "passfile", Aliases: []string{"p"}, Required: true, Usage: "passwordfile to use for bruteforce"}, }, @@ -142,7 +142,7 @@ func main() { timeout := c.Duration("timeout") username := c.String("username") passwordFile := c.String("passfile") - return cmd.BruteForce(cmd.BruteforceOpts{ + return cmd.BruteForce(c.Context, cmd.BruteforceOpts{ TurnServer: turnServer, UseTLS: useTLS, Protocol: protocol, @@ -168,7 +168,7 @@ func main() { &cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"}, &cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"}, &cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"}, - &cli.DurationFlag{Name: "timeout", Value: 1 * time.Second, Usage: "connect timeout to turn server"}, + &cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"}, &cli.StringFlag{Name: "username", Aliases: []string{"u"}, Required: true, Usage: "username for the turn server"}, &cli.StringFlag{Name: "password", Aliases: []string{"p"}, Required: true, Usage: "password for the turn server"}, &cli.StringFlag{Name: "target", Aliases: []string{"t"}, Required: true, Usage: "Target to leak memory to in the form host:port. Should be a public server under your control"}, @@ -206,7 +206,7 @@ func main() { } size := c.Uint("size") - return cmd.MemoryLeak(cmd.MemoryleakOpts{ + return cmd.MemoryLeak(c.Context, cmd.MemoryleakOpts{ TurnServer: turnServer, UseTLS: useTLS, Protocol: protocol, @@ -231,7 +231,7 @@ func main() { &cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"}, &cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"}, &cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"}, - &cli.DurationFlag{Name: "timeout", Value: 1 * time.Second, Usage: "connect timeout to turn server"}, + &cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"}, &cli.StringFlag{Name: "username", Aliases: []string{"u"}, Required: true, Usage: "username for the turn server"}, &cli.StringFlag{Name: "password", Aliases: []string{"p"}, Required: true, Usage: "password for the turn server"}, }, @@ -248,7 +248,7 @@ func main() { timeout := c.Duration("timeout") username := c.String("username") password := c.String("password") - return cmd.RangeScan(cmd.RangeScanOpts{ + return cmd.RangeScan(c.Context, cmd.RangeScanOpts{ TurnServer: turnServer, UseTLS: useTLS, Protocol: protocol, @@ -269,7 +269,7 @@ func main() { &cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"}, &cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"}, &cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"}, - &cli.DurationFlag{Name: "timeout", Value: 1 * time.Second, Usage: "connect timeout to turn server"}, + &cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"}, &cli.StringFlag{Name: "username", Aliases: []string{"u"}, Required: true, Usage: "username for the turn server"}, &cli.StringFlag{Name: "password", Aliases: []string{"p"}, Required: true, Usage: "password for the turn server"}, &cli.StringFlag{Name: "listen", Aliases: []string{"l"}, Value: "127.0.0.1:1080", Usage: "Address and port to listen on"}, @@ -290,7 +290,7 @@ func main() { password := c.String("password") listen := c.String("listen") dropPublic := c.Bool("drop-public") - return cmd.Socks(cmd.SocksOpts{ + return cmd.Socks(c.Context, cmd.SocksOpts{ TurnServer: turnServer, UseTLS: useTLS, Protocol: protocol, @@ -312,7 +312,7 @@ func main() { &cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"}, &cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"}, &cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"}, - &cli.DurationFlag{Name: "timeout", Value: 1 * time.Second, Usage: "connect timeout to turn server"}, + &cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"}, &cli.StringFlag{Name: "username", Aliases: []string{"u"}, Required: true, Usage: "username for the turn server"}, &cli.StringFlag{Name: "password", Aliases: []string{"p"}, Required: true, Usage: "password for the turn server"}, &cli.StringFlag{Name: "ports", Value: "80,443,8080,8081", Usage: "Ports to check"}, @@ -337,7 +337,7 @@ func main() { ips := c.StringSlice("ip") - return cmd.TCPScanner(cmd.TCPScannerOpts{ + return cmd.TCPScanner(c.Context, cmd.TCPScannerOpts{ TurnServer: turnServer, UseTLS: useTLS, Protocol: protocol, @@ -360,7 +360,7 @@ func main() { &cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"}, &cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"}, &cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"}, - &cli.DurationFlag{Name: "timeout", Value: 1 * time.Second, Usage: "connect timeout to turn server"}, + &cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"}, &cli.StringFlag{Name: "username", Aliases: []string{"u"}, Required: true, Usage: "username for the turn server"}, &cli.StringFlag{Name: "password", Aliases: []string{"p"}, Required: true, Usage: "password for the turn server"}, &cli.StringFlag{Name: "community-string", Value: "public", Usage: "SNMP community string to use for scanning"}, @@ -383,7 +383,7 @@ func main() { communityString := c.String("community-string") domain := c.String("domain") ips := c.StringSlice("ip") - return cmd.UDPScanner(cmd.UDPScannerOpts{ + return cmd.UDPScanner(c.Context, cmd.UDPScannerOpts{ TurnServer: turnServer, UseTLS: useTLS, Protocol: protocol,