Merge pull request #308 from grafana/stream-test-fix-event-done
stream_test: more realistic data + test happy path
yoziru authored Apr 29, 2024
2 parents 0014fad + 36d1279 commit 55637b7
Showing 2 changed files with 58 additions and 17 deletions.
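
For context (an illustrative sketch, not part of the diff): with these changes the mock server's SSE stream looks roughly like this on the wire. Each event is a data: line terminated by a blank line, and the end of the stream is signalled either by the data: [DONE] sentinel or by an event explicitly named "done":

data: {"choices":[{"delta":{"content":"response0"}, ... }]}

data: {"choices":[{"delta":{"content":"response1"}, ... }]}

data: {"choices":[{"delta":{},"finish_reason":"stop", ... }]}

event: done
data: [DONE]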
7 changes: 6 additions & 1 deletion packages/grafana-llm-app/pkg/plugin/stream.go
@@ -70,9 +70,14 @@ func (a *App) runOpenAIChatCompletionsStream(ctx context.Context, req *backend.R
return nil
case event := <-eventStream.Events:
var body map[string]interface{}
if event == nil {
// make sure we have an event; otherwise event.Data() will panic
log.DefaultLogger.Warn(fmt.Sprintf("proxy: stream: event is nil, ending (in sad branch): %s", req.Path))
return nil
}
eventData := event.Data()
// If the event data is "[DONE]", then we're done.
if eventData == "[DONE]" {
if eventData == "[DONE]" || event.Event() == "done" {
err = sender.SendJSON([]byte(`{"choices": [{"delta": {"done": true}}]}`))
if err != nil {
err = fmt.Errorf("proxy: stream: error sending done: %w", err)
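
Background on the nil guard above (a standalone sketch, not code from this repo): once the upstream connection closes, the events channel is closed, and a receive from a closed, drained Go channel yields the element type's zero value. For a channel of pointers that zero value is nil, so calling a method that dereferences the receiver, as event.Data() does, panics:

package main

import "fmt"

type event struct{ payload string }

// Data dereferences its receiver, so it panics when called on a nil *event.
func (e *event) Data() string { return e.payload }

func main() {
	events := make(chan *event, 1)
	events <- &event{payload: "hello"}
	close(events)

	fmt.Println((<-events).Data()) // "hello": the buffered value is still delivered
	e := <-events                  // closed and drained: e is the zero value, nil
	if e == nil {
		fmt.Println("nil event, stream ended") // what the added guard catches
		return
	}
	fmt.Println(e.Data()) // without the guard this line would panic
}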
68 changes: 52 additions & 16 deletions packages/grafana-llm-app/pkg/plugin/stream_test.go
@@ -18,7 +18,7 @@ type mockStreamServer struct {
request *http.Request
}

func newMockOpenAIStreamServer(t *testing.T, statusCode int, finish chan (struct{})) *mockStreamServer {
func newMockOpenAIStreamServer(t *testing.T, statusCode int, includeDone bool) *mockStreamServer {
server := &mockStreamServer{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("mock server got request: %s", r.URL.String())
@@ -33,17 +33,26 @@ func newMockOpenAIStreamServer(t *testing.T, statusCode int, finish chan (struct
}

w.Header().Set("Content-Type", "text/event-stream")
streamMessages := []byte{}
for i := 0; i < 10; i++ {
// Actual body isn't really important here.
body := fmt.Sprintf(`data: {"choices": [{"text": "%d"}]}\n`, i)
_, err := w.Write([]byte(body))
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
w.(http.Flusher).Flush()
data := fmt.Sprintf(`{"choices":[{"delta":{"content":"response%d"},"finish_reason":null,"index":0,"logprobs":null}],"id":"mock-chat-id","model":"gpt-4-turbo","object":"chat.completion.chunk","p":"p","system_fingerprint":"abc"}`, i)
dataBytes := []byte("data: " + data + "\n\n")
streamMessages = append(streamMessages, dataBytes...)
}

_, _ = w.Write([]byte(`data: [DONE]`))
// final message has finish reason
streamMessages = append(streamMessages, []byte("data: "+`{"choices":[{"delta":{},"finish_reason":"stop","index":0,"logprobs":null}],"created":1714142715,"id":"mock-chat-id","model":"gpt-4-turbo","object":"chat.completion.chunk","p":"ppppppppppp","system_fingerprint":"abc"}`+"\n\n")...)

// done messages
if includeDone {
streamMessages = append(streamMessages, []byte("event: done\n")...)
streamMessages = append(streamMessages, []byte("data: [DONE]\n\n")...)
}
_, err := w.Write(streamMessages)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
w.(http.Flusher).Flush()
})
server.server = httptest.NewServer(handler)
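
The handler above accumulates the whole SSE payload and writes it in one shot before flushing. A hypothetical, self-contained round trip of the same pattern (server names and payloads invented for illustration):

package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
)

func main() {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		stream := []byte{}
		for i := 0; i < 3; i++ {
			// Every SSE event is a "data:" line terminated by a blank line.
			stream = append(stream, []byte(fmt.Sprintf("data: {\"chunk\":%d}\n\n", i))...)
		}
		// A named event: the "event:" field precedes its "data:" field.
		stream = append(stream, []byte("event: done\ndata: [DONE]\n\n")...)
		if _, err := w.Write(stream); err != nil {
			return // client went away; nothing more to do
		}
		w.(http.Flusher).Flush()
	})
	server := httptest.NewServer(handler)
	defer server.Close()

	resp, err := http.Get(server.URL)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	raw, _ := io.ReadAll(resp.Body)
	fmt.Print(string(raw)) // the raw SSE frames, in order
}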
@@ -64,10 +73,12 @@ func TestRunStream(t *testing.T) {
"model": "gpt-3.5-turbo",
"messages": []
}`)
for _, tc := range []struct {
name string
settings Settings
statusCode int

testCases := []struct {
name string
settings Settings
statusCode int
includeDone bool

expErr string
expMessageCount int
@@ -90,12 +101,30 @@
expErr: "401",
expMessageCount: 0,
},
} {
{
name: "happy path",
settings: Settings{OpenAI: OpenAISettings{Provider: openAIProviderOpenAI}},
statusCode: http.StatusOK,
includeDone: true,

expErr: "",
expMessageCount: 11, // 9 messages + 1 finish reason + 1 done
},
{
name: "happy path without EOF",
settings: Settings{OpenAI: OpenAISettings{Provider: openAIProviderOpenAI}},
statusCode: http.StatusOK,

expErr: "",
expMessageCount: 11, // 9 messages + 1 finish reason + 1 done
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ctx := context.Background()
finish := make(chan struct{})
// Start up a mock server that captures the request and streams an SSE response with the configured status code.
server := newMockOpenAIStreamServer(t, tc.statusCode, finish)
server := newMockOpenAIStreamServer(t, tc.statusCode, tc.includeDone)

// Initialize app (need to set OpenAISettings:URL in here)
settings := tc.settings
@@ -142,8 +171,8 @@ func TestRunStream(t *testing.T) {
t.Fatalf("RunStream error: %s", err)
}

n := len(r.messages)
if tc.expErr != "" {
n := len(r.messages)
var got EventError
if err = json.Unmarshal(r.messages[n-1], &got); err != nil {
t.Fatalf("got non-JSON error message %s", r.messages[n-1])
@@ -154,7 +183,14 @@ func TestRunStream(t *testing.T) {
if tc.expMessageCount != n-1 {
t.Fatalf("expected %d non-error messages, got %d", tc.expMessageCount, n-1)
}
return
}

// expect the right number of messages
if tc.expMessageCount != n {
t.Fatalf("expected %d messages, got %d", tc.expMessageCount, n)
}

})
}
}
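
To exercise just these cases locally, standard Go tooling applies; assuming the module layout shown in the file paths above, something like go test -run TestRunStream ./pkg/plugin/... from packages/grafana-llm-app should run both happy-path variants alongside the error cases.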
