diff --git a/server/processing/processor.go b/server/processing/processor.go
index a3f1849..fe351bb 100644
--- a/server/processing/processor.go
+++ b/server/processing/processor.go
@@ -10,7 +10,6 @@ import (
 	"github.com/teilomillet/gollm"
 	"github.com/teilomillet/hapax/config"
-	"github.com/teilomillet/hapax/server/middleware"
 )
 
 // Processor handles request processing and response formatting for LLM interactions.
@@ -91,51 +90,65 @@ func (p *Processor) ProcessRequest(ctx context.Context, req *Request) (*Response
 		return nil, fmt.Errorf("request cannot be nil")
 	}
 
-	// Select the appropriate template, falling back to default
-	tmpl := p.templates["default"]
-	if t, ok := p.templates[req.Type]; ok {
-		tmpl = t
-	}
-	if tmpl == nil {
-		return nil, fmt.Errorf("no template found for type: %s", req.Type)
-	}
+	var promptMessages []gollm.PromptMessage
 
-	// Execute the template with the request data
-	var buf bytes.Buffer
-	err := tmpl.Execute(&buf, req)
-	if err != nil {
-		return nil, fmt.Errorf("template execution failed: %w", err)
+	// Always start with the system prompt if one is configured
+	if p.defaultPrompt != "" {
+		promptMessages = append(promptMessages, gollm.PromptMessage{
+			Role:    "system",
+			Content: p.defaultPrompt,
+		})
 	}
 
-	// Create an LLM prompt with system context
-	prompt := &gollm.Prompt{
-		Messages: []gollm.PromptMessage{
-			{
-				Role:    "system",
-				Content: p.defaultPrompt,
-			},
-			{
-				Role:    "user",
-				Content: buf.String(),
-			},
-		},
-	}
+	// From here there are two clear paths: a conversation or a single input
+	if len(req.Messages) > 0 {
+		// For conversations, convert the messages directly
+		promptMessages = append(promptMessages, convertMessages(req.Messages)...)
+	} else if req.Input != "" {
+		// For single inputs, we still use the template system
+		tmpl := p.templates["default"]
+		if t, ok := p.templates[req.Type]; ok {
+			tmpl = t
+		}
+		if tmpl == nil {
+			return nil, fmt.Errorf("no template found for type: %s", req.Type)
+		}
 
-	// Pass timeout header to LLM context if present
-	if timeoutHeader := ctx.Value("X-Test-Timeout"); timeoutHeader != nil {
-		ctx = context.WithValue(ctx, middleware.XTestTimeoutKey, timeoutHeader)
+		var buf bytes.Buffer
+		if err := tmpl.Execute(&buf, req); err != nil {
+			return nil, fmt.Errorf("template execution failed: %w", err)
+		}
+
+		promptMessages = append(promptMessages, gollm.PromptMessage{
+			Role:    "user",
+			Content: buf.String(),
+		})
+	} else {
+		return nil, fmt.Errorf("request must contain either messages or input")
 	}
 
-	// Send request to LLM
+	prompt := &gollm.Prompt{Messages: promptMessages}
+
 	response, err := p.llm.Generate(ctx, prompt)
 	if err != nil {
 		return nil, fmt.Errorf("LLM processing failed: %w", err)
 	}
 
-	// Apply response formatting
 	return p.formatResponse(response), nil
 }
 
+// convertMessages converts our Message type into gollm.PromptMessage values.
+func convertMessages(messages []Message) []gollm.PromptMessage {
+	promptMessages := make([]gollm.PromptMessage, len(messages))
+	for i, msg := range messages {
+		promptMessages[i] = gollm.PromptMessage{
+			Role:    msg.Role,
+			Content: msg.Content,
+		}
+	}
+	return promptMessages
+}
+
 // formatResponse applies configured formatting options to the LLM response:
 // 1. Cleans JSON if enabled (removes markdown blocks, formats JSON)
 // 2. Trims whitespace if enabled
diff --git a/server/processing/processor_test.go b/server/processing/processor_test.go
index 650eb20..6338eb5 100644
--- a/server/processing/processor_test.go
+++ b/server/processing/processor_test.go
@@ -2,6 +2,7 @@ package processing
 
 import (
 	"context"
+	"fmt"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -78,26 +79,38 @@ func TestNewProcessor(t *testing.T) {
-// For example, chat messages are formatted as "role: content\n" to match
-// the chat template: "{{range .Messages}}{{.Role}}: {{.Content}}\n{{end}}"
+// Chat requests now bypass the template system entirely, so the mock keys on
+// the last user message rather than on rendered template output.
 func TestProcessRequest(t *testing.T) {
-	// Map of input prompts to expected LLM responses
-	// Note: The keys must match the exact output of our templates
+	// Mock responses keyed by the last user message in the conversation
 	mockResponses := map[string]string{
-		"Hello":      "World",  // Simple completion
-		"user: Hi\n": "Hello!", // Chat completion (matches template format)
-		"Test":       "Very long response that should be truncated",
-		"undefined":  "", // Default response for unmatched inputs
+		"Hello":     "World",
+		"Hi":        "Hello!",
+		"Test":      "Very long response that should be truncated",
+		"undefined": "",
 	}
 
-	// Create a mock LLM that returns predefined responses based on the input prompt
-	// For chat messages, we check prompt.Messages[1] because index 0 is the system prompt
+	// Mock LLM that responds based on the last non-system message
	mockLLM := mocks.NewMockLLM(func(ctx context.Context, prompt *gollm.Prompt) (string, error) {
-		if len(prompt.Messages) < 2 {
+		// Guard against an empty prompt
+		if len(prompt.Messages) == 0 {
 			return mockResponses["undefined"], nil
 		}
-		return mockResponses[prompt.Messages[1].Content], nil
+
+		// Find the last non-system message
+		var lastContent string
+		for i := len(prompt.Messages) - 1; i >= 0; i-- {
+			if prompt.Messages[i].Role != "system" {
+				lastContent = prompt.Messages[i].Content
+				break
+			}
+		}
+
+		// Return the corresponding canned response
+		if response, ok := mockResponses[lastContent]; ok {
+			return response, nil
+		}
+		return mockResponses["undefined"], nil
 	})
 
-	// Configure the processor with both simple and chat templates
-	// Also set up response formatting to test truncation and cleaning
+	// Configure templates and response formatting (truncation, cleaning)
 	cfg := &config.ProcessingConfig{
 		RequestTemplates: map[string]string{
 			"default": "{{.Input}}",
@@ -106,7 +119,7 @@ func TestProcessRequest(t *testing.T) {
 		ResponseFormatting: config.ResponseFormattingConfig{
 			CleanJSON:      true,
 			TrimWhitespace: true,
-			MaxLength:      10, // Short length to test truncation
+			MaxLength:      10, // short enough to exercise truncation
 		},
 	}
 
@@ -157,6 +170,7 @@ func TestProcessRequest(t *testing.T) {
 		},
 	}
 
+	// Run each case as a subtest
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			resp, err := proc.ProcessRequest(ctx, tt.req)
@@ -222,3 +236,271 @@ func TestFormatResponse(t *testing.T) {
 		})
 	}
 }
+
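+// TestSingleInputTemplatePath is a minimal sketch of the non-conversation path:
+// when a request carries no Messages, the processor should render Input through
+// the type-specific template, falling back to "default" for unknown types. The
+// "question" template and its wording are illustrative assumptions for this
+// sketch, not part of the production configuration.
+func TestSingleInputTemplatePath(t *testing.T) {
+	var captured string
+	mockLLM := mocks.NewMockLLM(func(ctx context.Context, prompt *gollm.Prompt) (string, error) {
+		// The last prompt message should carry the rendered template output
+		captured = prompt.Messages[len(prompt.Messages)-1].Content
+		return "ok", nil
+	})
+
+	cfg := &config.ProcessingConfig{
+		RequestTemplates: map[string]string{
+			"default":  "{{.Input}}",
+			"question": "Q: {{.Input}}",
+		},
+	}
+
+	proc, err := NewProcessor(cfg, mockLLM)
+	assert.NoError(t, err)
+
+	ctx := context.Background()
+
+	// A known type uses its own template...
+	_, err = proc.ProcessRequest(ctx, &Request{Type: "question", Input: "Why?"})
+	assert.NoError(t, err)
+	assert.Equal(t, "Q: Why?", captured)
+
+	// ...and an unknown type falls back to "default"
+	_, err = proc.ProcessRequest(ctx, &Request{Type: "made-up", Input: "Hello"})
+	assert.NoError(t, err)
+	assert.Equal(t, "Hello", captured)
+}
+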
+// TestProcessMultiTurnConversation verifies that the processor correctly handles
+// multi-turn conversations with different message types and maintains conversation context.
+func TestProcessMultiTurnConversation(t *testing.T) {
+	// Create a mock LLM that verifies message handling and returns appropriate responses
+	mockLLM := mocks.NewMockLLM(func(ctx context.Context, prompt *gollm.Prompt) (string, error) {
+		// Verify that messages are being passed correctly
+		assert.NotNil(t, prompt.Messages)
+
+		// Use the last message to determine the response
+		lastMsg := prompt.Messages[len(prompt.Messages)-1]
+
+		// Return different responses based on the message content
+		switch lastMsg.Content {
+		case "Hello, Claude":
+			return "Hello! How can I assist you today?", nil
+		case "Can you explain LLMs?":
+			// Verify that previous messages are preserved
+			assert.True(t, len(prompt.Messages) >= 3, "Expected previous messages to be included")
+			return "Large Language Models (LLMs) are AI systems that process and generate text...", nil
+		default:
+			return "I understand. What else would you like to know?", nil
+		}
+	})
+
+	// Create a processor with a minimal test configuration
+	cfg := &config.ProcessingConfig{
+		ResponseFormatting: config.ResponseFormattingConfig{
+			TrimWhitespace: true,
+		},
+	}
+
+	proc, err := NewProcessor(cfg, mockLLM)
+	assert.NoError(t, err)
+
+	// Set a system prompt to verify it is included
+	proc.SetDefaultPrompt("You are a helpful AI assistant.")
+
+	ctx := context.Background()
+
+	// Test cases for different conversation patterns
+	tests := []struct {
+		name        string
+		messages    []Message
+		wantContent string
+		wantErr     bool
+	}{
+		{
+			name: "single message conversation",
+			messages: []Message{
+				{Role: "user", Content: "Hello, Claude"},
+			},
+			wantContent: "Hello! How can I assist you today?",
+			wantErr:     false,
+		},
+		{
+			name: "multi-turn conversation",
+			messages: []Message{
+				{Role: "user", Content: "Hello, Claude"},
+				{Role: "assistant", Content: "Hello! How can I assist you today?"},
+				{Role: "user", Content: "Can you explain LLMs?"},
+			},
+			wantContent: "Large Language Models (LLMs) are AI systems that process and generate text...",
+			wantErr:     false,
+		},
+		{
+			name:     "conversation with empty messages",
+			messages: []Message{},
+			wantErr:  true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			req := &Request{
+				Type:     "chat",
+				Messages: tt.messages,
+			}
+
+			resp, err := proc.ProcessRequest(ctx, req)
+			if tt.wantErr {
+				assert.Error(t, err)
+				assert.Nil(t, resp)
+			} else {
+				assert.NoError(t, err)
+				assert.Equal(t, tt.wantContent, resp.Content)
+			}
+		})
+	}
+}
+
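+// TestConvertMessages is a small sketch that exercises the convertMessages helper
+// directly: every Message should map to a gollm.PromptMessage with the same role
+// and content, in the same order, and without a system prompt being injected.
+func TestConvertMessages(t *testing.T) {
+	in := []Message{
+		{Role: "user", Content: "Hello"},
+		{Role: "assistant", Content: "Hi there"},
+		{Role: "user", Content: "Bye"},
+	}
+
+	out := convertMessages(in)
+
+	if assert.Len(t, out, len(in)) {
+		for i := range in {
+			assert.Equal(t, in[i].Role, out[i].Role)
+			assert.Equal(t, in[i].Content, out[i].Content)
+		}
+	}
+
+	// A nil slice converts to an empty, non-nil slice
+	assert.Empty(t, convertMessages(nil))
+}
+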
+// TestMessageOrderPreservation ensures that messages are passed to the LLM in the
+// correct order and that the conversation context is maintained properly.
+func TestMessageOrderPreservation(t *testing.T) {
+	var capturedMessages []gollm.PromptMessage
+
+	mockLLM := mocks.NewMockLLM(func(ctx context.Context, prompt *gollm.Prompt) (string, error) {
+		// Capture messages for verification
+		capturedMessages = prompt.Messages
+		return "Response", nil
+	})
+
+	cfg := &config.ProcessingConfig{
+		ResponseFormatting: config.ResponseFormattingConfig{
+			TrimWhitespace: true,
+		},
+	}
+
+	proc, err := NewProcessor(cfg, mockLLM)
+	assert.NoError(t, err)
+	proc.SetDefaultPrompt("System instruction")
+
+	ctx := context.Background()
+
+	// Test a complex conversation sequence
+	req := &Request{
+		Type: "chat",
+		Messages: []Message{
+			{Role: "user", Content: "First message"},
+			{Role: "assistant", Content: "First response"},
+			{Role: "user", Content: "Second message"},
+		},
+	}
+
+	_, err = proc.ProcessRequest(ctx, req)
+	assert.NoError(t, err)
+
+	// Verify message count, order, and content; the length check guards the
+	// index accesses below against a panic if the capture went wrong
+	if assert.Len(t, capturedMessages, 4, "Expected system prompt plus three conversation messages") {
+		assert.Equal(t, "system", capturedMessages[0].Role)
+		assert.Equal(t, "System instruction", capturedMessages[0].Content)
+		assert.Equal(t, "user", capturedMessages[1].Role)
+		assert.Equal(t, "First message", capturedMessages[1].Content)
+		assert.Equal(t, "assistant", capturedMessages[2].Role)
+		assert.Equal(t, "First response", capturedMessages[2].Content)
+		assert.Equal(t, "user", capturedMessages[3].Role)
+		assert.Equal(t, "Second message", capturedMessages[3].Content)
+	}
+}
+
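+// TestMessagesTakePrecedenceOverInput sketches the branch ordering inside
+// ProcessRequest: when a request carries both Messages and Input, the
+// conversation path should win and the template system should be bypassed.
+func TestMessagesTakePrecedenceOverInput(t *testing.T) {
+	var captured []gollm.PromptMessage
+	mockLLM := mocks.NewMockLLM(func(ctx context.Context, prompt *gollm.Prompt) (string, error) {
+		captured = prompt.Messages
+		return "ok", nil
+	})
+
+	cfg := &config.ProcessingConfig{
+		RequestTemplates: map[string]string{"default": "{{.Input}}"},
+	}
+	proc, err := NewProcessor(cfg, mockLLM)
+	assert.NoError(t, err)
+
+	req := &Request{
+		Type:     "chat",
+		Input:    "ignored when messages are present",
+		Messages: []Message{{Role: "user", Content: "From messages"}},
+	}
+
+	_, err = proc.ProcessRequest(context.Background(), req)
+	assert.NoError(t, err)
+
+	// The last prompt message should come from Messages, not from Input
+	if assert.NotEmpty(t, captured) {
+		last := captured[len(captured)-1]
+		assert.Equal(t, "user", last.Role)
+		assert.Equal(t, "From messages", last.Content)
+	}
+}
+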
+// TestAnthropicStyleConversations simulates an Anthropic-style multi-turn chat
+// and verifies that the system prompt and full conversation history reach the LLM.
+func TestAnthropicStyleConversations(t *testing.T) {
+	// Create a mock LLM that tracks conversation state and verifies message
+	// handling, returning responses that depend on the conversation context
+	var capturedMessages []gollm.PromptMessage
+
+	mockLLM := mocks.NewMockLLM(func(ctx context.Context, prompt *gollm.Prompt) (string, error) {
+		// Store the messages for verification
+		capturedMessages = prompt.Messages
+
+		// Simulate different responses based on conversation context
+		lastMsg := prompt.Messages[len(prompt.Messages)-1]
+		switch lastMsg.Content {
+		case "Hello, Claude":
+			return "Hello! I'm here to help.", nil
+		case "Tell me about language models":
+			// We should see the previous exchange in context
+			if len(prompt.Messages) < 3 {
+				return "", fmt.Errorf("expected previous conversation context")
+			}
+			return "Language models are AI systems that process and generate text based on patterns learned from training data.", nil
+		case "Can you elaborate on that?":
+			// This should have the full conversation history
+			if len(prompt.Messages) < 5 {
+				return "", fmt.Errorf("missing conversation history")
+			}
+			return "Let me build on my previous explanation...", nil
+		default:
+			return "I didn't understand that specific query.", nil
+		}
+	})
+
+	// Create a processor with a simple configuration
+	cfg := &config.ProcessingConfig{
+		ResponseFormatting: config.ResponseFormattingConfig{
+			TrimWhitespace: true,
+		},
+	}
+
+	processor, err := NewProcessor(cfg, mockLLM)
+	assert.NoError(t, err, "Processor creation should succeed")
+
+	// Set a system prompt to verify it is maintained throughout the conversation
+	processor.SetDefaultPrompt("You are a helpful AI assistant.")
+
+	ctx := context.Background()
+
+	// Simulate a multi-turn conversation
+	conversationSteps := []struct {
+		name          string
+		messages      []Message
+		wantResponse  string
+		wantMsgCount  int // Expected number of messages including system prompt
+		shouldSucceed bool
+	}{
+		{
+			name: "initial greeting",
+			messages: []Message{
+				{Role: "user", Content: "Hello, Claude"},
+			},
+			wantResponse:  "Hello! I'm here to help.",
+			wantMsgCount:  2, // System prompt + user message
+			shouldSucceed: true,
+		},
+		{
+			name: "second turn with context",
+			messages: []Message{
+				{Role: "user", Content: "Hello, Claude"},
+				{Role: "assistant", Content: "Hello! I'm here to help."},
+				{Role: "user", Content: "Tell me about language models"},
+			},
+			wantResponse:  "Language models are AI systems that process and generate text based on patterns learned from training data.",
+			wantMsgCount:  4, // System prompt + 3 conversation messages
+			shouldSucceed: true,
+		},
+		{
+			name: "third turn with full history",
+			messages: []Message{
+				{Role: "user", Content: "Hello, Claude"},
+				{Role: "assistant", Content: "Hello! I'm here to help."},
+				{Role: "user", Content: "Tell me about language models"},
+				{Role: "assistant", Content: "Language models are AI systems that process and generate text based on patterns learned from training data."},
+				{Role: "user", Content: "Can you elaborate on that?"},
+			},
+			wantResponse:  "Let me build on my previous explanation...",
+			wantMsgCount:  6, // System prompt + 5 conversation messages
+			shouldSucceed: true,
+		},
+	}
+
+	for _, step := range conversationSteps {
+		t.Run(step.name, func(t *testing.T) {
+			// Reset captured messages for this step
+			capturedMessages = nil
+
+			// Create and send the request
+			req := &Request{
+				Type:     "chat",
+				Messages: step.messages,
+			}
+
+			resp, err := processor.ProcessRequest(ctx, req)
+
+			// Verify the results
+			if step.shouldSucceed {
+				assert.NoError(t, err, "Request should succeed")
+				assert.NotNil(t, resp, "Response should not be nil")
+				assert.Equal(t, step.wantResponse, resp.Content, "Response content should match expected")
+
+				// Verify message count and system prompt
+				assert.Equal(t, step.wantMsgCount, len(capturedMessages),
+					"Incorrect number of messages in conversation")
+				assert.Equal(t, "system", capturedMessages[0].Role,
+					"First message should be the system prompt")
+				assert.Equal(t, "You are a helpful AI assistant.",
+					capturedMessages[0].Content, "System prompt should be preserved")
+
+				// Verify conversation order is maintained
+				if len(step.messages) > 0 {
+					lastMsg := capturedMessages[len(capturedMessages)-1]
+					assert.Equal(t, step.messages[len(step.messages)-1].Content,
+						lastMsg.Content, "Last message content should match")
+					assert.Equal(t, step.messages[len(step.messages)-1].Role,
+						lastMsg.Role, "Last message role should match")
+				}
+			} else {
+				assert.Error(t, err, "Request should fail")
+			}
+		})
+	}
+}
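+
+// TestEmptyRequestRejected sketches the failure path: a request with neither
+// Messages nor Input should be rejected up front, before the LLM is ever called.
+func TestEmptyRequestRejected(t *testing.T) {
+	called := false
+	mockLLM := mocks.NewMockLLM(func(ctx context.Context, prompt *gollm.Prompt) (string, error) {
+		called = true
+		return "", nil
+	})
+
+	cfg := &config.ProcessingConfig{
+		RequestTemplates: map[string]string{"default": "{{.Input}}"},
+	}
+	proc, err := NewProcessor(cfg, mockLLM)
+	assert.NoError(t, err)
+
+	resp, err := proc.ProcessRequest(context.Background(), &Request{Type: "chat"})
+	assert.Error(t, err)
+	assert.Nil(t, resp)
+	assert.False(t, called, "LLM should not be invoked for an empty request")
+}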