From fd62d693eea7cc201f420f52c38501fd67982ad8 Mon Sep 17 00:00:00 2001 From: ivan-stankov-salt-security <124668375+ivan-stankov-salt-security@users.noreply.github.com> Date: Mon, 7 Aug 2023 16:47:15 +0300 Subject: [PATCH] add GetInitMessage and WriteBeforeMessage to output_tcp.go (#1193) * add GetInitMessage and WriteBeforeMessage to output_tcp.go * try to fix code duplication --- output_tcp.go | 31 ++++++++++++--- output_tcp_test.go | 94 +++++++++++++++++++++++++++++++++++++--------- 2 files changed, 102 insertions(+), 23 deletions(-) diff --git a/output_tcp.go b/output_tcp.go index 88527288..0d5f0eb1 100644 --- a/output_tcp.go +++ b/output_tcp.go @@ -29,6 +29,9 @@ type TCPOutputConfig struct { Sticky bool `json:"output-tcp-sticky"` SkipVerify bool `json:"output-tcp-skip-verify"` Workers int `json:"output-tcp-workers"` + + GetInitMessage func() *Message `json:"-"` + WriteBeforeMessage func(conn net.Conn, msg *Message) error `json:"-"` } // NewTCPOutput constructor for TCPOutput @@ -78,14 +81,14 @@ func (o *TCPOutput) worker(bufferIndex int) { defer conn.Close() + if o.config.GetInitMessage != nil { + msg := o.config.GetInitMessage() + _ = o.writeToConnection(conn, msg) + } + for { msg := <-o.buf[bufferIndex] - if _, err = conn.Write(msg.Meta); err == nil { - if _, err = conn.Write(msg.Data); err == nil { - _, err = conn.Write(payloadSeparatorAsBytes) - } - } - + err = o.writeToConnection(conn, msg) if err != nil { Debug(2, "INFO: TCP output connection closed, reconnecting") go o.worker(bufferIndex) @@ -95,6 +98,22 @@ func (o *TCPOutput) worker(bufferIndex int) { } } +func (o *TCPOutput) writeToConnection(conn net.Conn, msg *Message) (err error) { + if o.config.WriteBeforeMessage != nil { + err = o.config.WriteBeforeMessage(conn, msg) + } + + if err == nil { + if _, err = conn.Write(msg.Meta); err == nil { + if _, err = conn.Write(msg.Data); err == nil { + _, err = conn.Write(payloadSeparatorAsBytes) + } + } + } + + return err +} + func (o *TCPOutput) getBufferIndex(msg *Message) int { if !o.config.Sticky { o.workerIndex++ diff --git a/output_tcp_test.go b/output_tcp_test.go index 8a860b16..5163c951 100644 --- a/output_tcp_test.go +++ b/output_tcp_test.go @@ -4,9 +4,12 @@ import ( "bufio" "log" "net" + "strings" "sync" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestTCPOutput(t *testing.T) { @@ -15,24 +18,8 @@ func TestTCPOutput(t *testing.T) { listener := startTCP(func(data []byte) { wg.Done() }) - input := NewTestInput() output := NewTCPOutput(listener.Addr().String(), &TCPOutputConfig{Workers: 10}) - - plugins := &InOutPlugins{ - Inputs: []PluginReader{input}, - Outputs: []PluginWriter{output}, - } - - emitter := NewEmitter() - go emitter.Start(plugins, Settings.Middleware) - - for i := 0; i < 10; i++ { - wg.Add(1) - input.EmitGET() - } - - wg.Wait() - emitter.Close() + runTCPOutput(wg, output, 10, false) } func startTCP(cb func([]byte)) net.Listener { @@ -131,3 +118,76 @@ func getTestBytes() *Message { Data: []byte("GET / HTTP/1.1\r\nHost: www.w3.org\r\nUser-Agent: Go 1.1 package http\r\nAccept-Encoding: gzip\r\n\r\n"), } } + +func TestTCPOutputGetInitMessage(t *testing.T) { + wg := new(sync.WaitGroup) + + var dataList [][]byte + listener := startTCP(func(data []byte) { + dataList = append(dataList, data) + wg.Done() + }) + getInitMessage := func() *Message { + return &Message{ + Meta: []byte{}, + Data: []byte("test1"), + } + } + output := NewTCPOutput(listener.Addr().String(), &TCPOutputConfig{Workers: 1, GetInitMessage: getInitMessage}) + + runTCPOutput(wg, output, 1, true) + + if assert.Equal(t, 2, len(dataList)) { + assert.Equal(t, "test1", string(dataList[0])) + } +} + +func TestTCPOutputGetInitMessageAndWriteBeforeMessage(t *testing.T) { + wg := new(sync.WaitGroup) + + var dataList [][]byte + listener := startTCP(func(data []byte) { + dataList = append(dataList, data) + wg.Done() + }) + getInitMessage := func() *Message { + return &Message{ + Meta: []byte{}, + Data: []byte("test2"), + } + } + writeBeforeMessage := func(conn net.Conn, _ *Message) error { + _, err := conn.Write([]byte("before")) + return err + } + output := NewTCPOutput(listener.Addr().String(), &TCPOutputConfig{Workers: 1, GetInitMessage: getInitMessage, WriteBeforeMessage: writeBeforeMessage}) + + runTCPOutput(wg, output, 1, true) + + if assert.Equal(t, 2, len(dataList)) { + assert.Equal(t, "beforetest2", string(dataList[0])) + assert.True(t, strings.HasPrefix(string(dataList[1]), "before")) + } +} + +func runTCPOutput(wg *sync.WaitGroup, output PluginWriter, repeat int, initMessage bool) { + input := NewTestInput() + plugins := &InOutPlugins{ + Inputs: []PluginReader{input}, + Outputs: []PluginWriter{output}, + } + + emitter := NewEmitter() + go emitter.Start(plugins, Settings.Middleware) + + if initMessage { + wg.Add(1) + } + for i := 0; i < repeat; i++ { + wg.Add(1) + input.EmitGET() + } + + wg.Wait() + emitter.Close() +}