Skip to content

Commit

Permalink
add Headers to output_ws.go (#1192)
Browse files Browse the repository at this point in the history
  • Loading branch information
ivan-stankov-salt-security authored Jun 3, 2023
1 parent e546fdb commit 3062446
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 5 deletions.
10 changes: 9 additions & 1 deletion output_ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ import (
"crypto/tls"
"encoding/base64"
"fmt"
"github.com/gorilla/websocket"
"hash/fnv"
"log"
"net/http"
"net/url"
"strings"
"time"

"github.com/gorilla/websocket"
)

// WebSocketOutput used for sending raw tcp payloads
Expand All @@ -33,6 +34,8 @@ type WebSocketOutputConfig struct {
Sticky bool `json:"output-ws-sticky"`
SkipVerify bool `json:"output-ws-skip-verify"`
Workers int `json:"output-ws-workers"`

Headers map[string][]string `json:"output-ws-headers"`
}

// NewWebSocketOutput constructor for WebSocketOutput
Expand All @@ -49,6 +52,11 @@ func NewWebSocketOutput(address string, config *WebSocketOutputConfig) PluginWri
o.headers = http.Header{
"Authorization": []string{"Basic " + base64.StdEncoding.EncodeToString([]byte(u.User.String()))},
}
for k, values := range config.Headers {
for _, v := range values {
o.headers.Add(k, v)
}
}

u.User = nil // must be after creating the headers
o.address = u.String()
Expand Down
26 changes: 22 additions & 4 deletions output_ws_test.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,30 @@
package goreplay

import (
"github.com/gorilla/websocket"
"log"
"net/http"
"sync"
"testing"

"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
)

func TestWebSocketOutput(t *testing.T) {
wg := new(sync.WaitGroup)

var gotHeader http.Header
wsAddr := startWebsocket(func(data []byte) {
wg.Done()
}, func(header http.Header) {
gotHeader = header
})
input := NewTestInput()
output := NewWebSocketOutput(wsAddr, &WebSocketOutputConfig{Workers: 1})
headers := map[string][]string{
"key1": {"value1"},
"key2": {"value2"},
}
output := NewWebSocketOutput(wsAddr, &WebSocketOutputConfig{Workers: 1, Headers: headers})

plugins := &InOutPlugins{
Inputs: []PluginReader{input},
Expand All @@ -32,12 +41,21 @@ func TestWebSocketOutput(t *testing.T) {

wg.Wait()
emitter.Close()

if assert.NotNil(t, gotHeader) {
assert.Equal(t, "Basic dXNlcjE=", gotHeader.Get("Authorization"))
for k, values := range headers {
assert.Equal(t, 1, len(values))
assert.Equal(t, values[0], gotHeader.Get(k))
}
}
}

func startWebsocket(cb func([]byte)) string {
func startWebsocket(cb func([]byte), headercb func(http.Header)) string {
upgrader := websocket.Upgrader{}

http.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
headercb(r.Header)
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Print("upgrade:", err)
Expand All @@ -60,5 +78,5 @@ func startWebsocket(cb func([]byte)) string {
}
}()

return "ws://localhost:8081/test"
return "ws://user1@localhost:8081/test"
}

0 comments on commit 3062446

Please sign in to comment.