diff --git a/crates/tabby/src/services/chat.rs b/crates/tabby/src/services/chat.rs index c7f303db51b2..86f95fdb18d9 100644 --- a/crates/tabby/src/services/chat.rs +++ b/crates/tabby/src/services/chat.rs @@ -169,3 +169,72 @@ pub async fn create_chat_service(logger: Arc, chat: &ModelConfi ChatService::new(engine, logger) } + +#[cfg(test)] +mod tests { + use std::sync::Mutex; + + use anyhow::Result; + use async_trait::async_trait; + use futures::StreamExt; + use tabby_inference::ChatCompletionOptions; + + use super::*; + + struct MockChatCompletionStream; + + #[async_trait] + impl ChatCompletionStream for MockChatCompletionStream { + async fn chat_completion( + &self, + _messages: &[Message], + _options: ChatCompletionOptions, + ) -> Result> { + let s = stream! { + yield "Hello, world!".into(); + }; + Ok(Box::pin(s)) + } + } + + struct MockEventLogger(Mutex>); + + impl EventLogger for MockEventLogger { + fn write(&self, x: tabby_common::api::event::LogEntry) { + self.0.lock().unwrap().push(x.event); + } + } + + #[tokio::test] + async fn test_chat_service() { + let engine = Arc::new(MockChatCompletionStream); + let logger = Arc::new(MockEventLogger(Default::default())); + let service = Arc::new(ChatService::new(engine, logger.clone())); + + let request = ChatCompletionRequest { + messages: vec![Message { + role: "user".into(), + content: "Hello, computer!".into(), + }], + temperature: None, + seed: None, + presence_penalty: None, + }; + let mut output = service.generate(request).await; + let response = output.next().await.unwrap(); + assert_eq!(response.choices[0].delta.content, "Hello, world!"); + + let finish = output.next().await.unwrap(); + assert_eq!(finish.choices[0].delta.content, ""); + assert_eq!(finish.choices[0].finish_reason.as_ref().unwrap(), "stop"); + + assert!(output.next().await.is_none()); + + let event = &logger.0.lock().unwrap()[0]; + let Event::ChatCompletion { output, .. } = event else { + panic!("Expected ChatCompletion event"); + }; + assert_eq!(output.role, "assistant"); + assert_eq!(output.content, "Hello, world!"); + } +}