diff --git a/Cargo.toml b/Cargo.toml index 56321242..81ad8a0a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,9 @@ uuid = { version = "1.8.0", features = ["v4", "fast-rng"] } [build-dependencies] tonic-build = "0.11.0" +[dev-dependencies] +faux = "0.1.10" + [profile.release] debug = false incremental = true diff --git a/src/clients/chunker.rs b/src/clients/chunker.rs index 493ad8e6..548e6269 100644 --- a/src/clients/chunker.rs +++ b/src/clients/chunker.rs @@ -20,6 +20,7 @@ use crate::{ const MODEL_ID_HEADER_NAME: &str = "mm-model-id"; +#[cfg_attr(test, derive(Default))] #[derive(Clone)] pub struct ChunkerClient { clients: HashMap>, diff --git a/src/clients/detector.rs b/src/clients/detector.rs index 86e5e54a..f6b3a31f 100644 --- a/src/clients/detector.rs +++ b/src/clients/detector.rs @@ -7,6 +7,7 @@ use crate::config::ServiceConfig; const DETECTOR_ID_HEADER_NAME: &str = "detector-id"; +#[cfg_attr(test, derive(Default))] #[derive(Clone)] pub struct DetectorClient { clients: HashMap, diff --git a/src/clients/tgis.rs b/src/clients/tgis.rs index b1dd3942..33e6a67d 100644 --- a/src/clients/tgis.rs +++ b/src/clients/tgis.rs @@ -15,11 +15,13 @@ use crate::{ }, }; +#[cfg_attr(test, faux::create)] #[derive(Clone)] pub struct TgisClient { clients: HashMap>, } +#[cfg_attr(test, faux::methods)] impl TgisClient { pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self { let clients = create_grpc_clients(default_port, config, GenerationServiceClient::new).await; @@ -85,3 +87,36 @@ impl TgisClient { .into_inner()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::pb::fmaas::model_info_response; + + #[tokio::test] + async fn test_model_info() { + // Initialize a mock object from `TgisClient` + let mut mock_client = TgisClient::faux(); + + let request = ModelInfoRequest { + model_id: "test-model-1".to_string(), + }; + + let expected_response = ModelInfoResponse { + max_sequence_length: 2, + max_new_tokens: 20, + max_beam_width: 3, + model_kind: model_info_response::ModelKind::DecoderOnly.into(), + max_beam_sequence_lengths: [].to_vec(), + }; + // Construct a behavior for the mock object + faux::when!(mock_client.model_info(request.clone())) + .once() + .then_return(Ok(expected_response.clone())); + // Test the mock object's behaviour + assert_eq!( + mock_client.model_info(request).await.unwrap(), + expected_response + ); + } +} diff --git a/src/config.rs b/src/config.rs index ef64b43e..440a59a8 100644 --- a/src/config.rs +++ b/src/config.rs @@ -8,7 +8,7 @@ use tracing::debug; /// Configuration for service needed for /// orchestrator to communicate with it -#[derive(Debug, Clone, Deserialize)] +#[derive(Clone, Debug, Default, Deserialize)] pub struct ServiceConfig { pub hostname: String, pub port: Option, @@ -16,7 +16,7 @@ pub struct ServiceConfig { } /// TLS provider -#[derive(Debug, Clone, Deserialize)] +#[derive(Clone, Debug, Deserialize)] #[serde(untagged)] pub enum Tls { Name(String), @@ -24,7 +24,7 @@ pub enum Tls { } /// Client TLS configuration -#[derive(Debug, Clone, Deserialize)] +#[derive(Clone, Debug, Default, Deserialize)] pub struct TlsConfig { pub cert_path: Option, pub key_path: Option, @@ -32,15 +32,18 @@ pub struct TlsConfig { } /// Generation service provider -#[derive(Debug, Clone, Copy, Deserialize)] +#[cfg_attr(test, derive(Default))] +#[derive(Clone, Copy, Debug, Deserialize)] #[serde(rename_all = "lowercase")] pub enum GenerationProvider { + #[cfg_attr(test, default)] Tgis, Nlp, } /// Generate service configuration -#[derive(Debug, Clone, Deserialize)] +#[cfg_attr(test, derive(Default))] +#[derive(Clone, Debug, Deserialize)] pub struct GenerationConfig { /// Generation service provider pub provider: GenerationProvider, @@ -49,16 +52,19 @@ pub struct GenerationConfig { } /// Chunker parser type -#[derive(Debug, Clone, Copy, Deserialize)] +#[cfg_attr(test, derive(Default))] +#[derive(Clone, Copy, Debug, Deserialize)] #[serde(rename_all = "lowercase")] pub enum ChunkerType { + #[cfg_attr(test, default)] Sentence, All, } /// Configuration for each chunker +#[cfg_attr(test, derive(Default))] #[allow(dead_code)] -#[derive(Debug, Clone, Deserialize)] +#[derive(Clone, Debug, Deserialize)] pub struct ChunkerConfig { /// Chunker type pub r#type: ChunkerType, @@ -67,7 +73,7 @@ pub struct ChunkerConfig { } /// Configuration for each detector -#[derive(Debug, Clone, Deserialize)] +#[derive(Clone, Debug, Default, Deserialize)] pub struct DetectorConfig { /// Detector service connection information pub service: ServiceConfig, @@ -78,7 +84,8 @@ pub struct DetectorConfig { } /// Overall orchestrator server configuration -#[derive(Debug, Clone, Deserialize)] +#[cfg_attr(test, derive(Default))] +#[derive(Clone, Debug, Deserialize)] pub struct OrchestratorConfig { /// Generation service and associated configuration pub generation: GenerationConfig, @@ -150,6 +157,13 @@ fn service_tls_name_to_config( service } +#[cfg(test)] +impl Default for Tls { + fn default() -> Self { + Tls::Name("dummy_tls".to_string()) + } +} + #[cfg(test)] mod tests { use anyhow::Error; diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 926d5d6b..bc42c577 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -689,6 +689,28 @@ impl StreamingClassificationWithGenTask { #[cfg(test)] mod tests { use super::*; + use crate::{ + models::FinishReason, + pb::fmaas::{ + BatchedGenerationRequest, BatchedGenerationResponse, GenerationResponse, StopReason, + }, + }; + + async fn get_test_context( + gen_client: GenerationClient, + chunker_client: Option, + detector_client: Option, + ) -> Context { + let chunker_client = chunker_client.unwrap_or_default(); + let detector_client = detector_client.unwrap_or_default(); + + Context { + generation_client: gen_client, + chunker_client, + detector_client, + config: OrchestratorConfig::default(), + } + } #[test] fn test_apply_masks() { @@ -709,4 +731,62 @@ mod tests { let s = "哈囉世界"; assert_eq!(slice_codepoints(s, 3, 4), "界"); } + + // Test for TGIS generation with default parameter + #[tokio::test] + async fn test_tgis_generate_with_default_params() { + // Initialize a mock object from `TgisClient` + let mut mock_client = TgisClient::faux(); + + let sample_text = String::from("sample text"); + let text_gen_model_id = String::from("test-llm-id-1"); + + let generation_response = GenerationResponse { + text: String::from("sample response worked"), + stop_reason: StopReason::EosToken.into(), + stop_sequence: String::from("\n"), + generated_token_count: 3, + seed: 7, + ..Default::default() + }; + + let client_generation_response = BatchedGenerationResponse { + responses: [generation_response].to_vec(), + }; + + let expected_generate_req_args = BatchedGenerationRequest { + model_id: text_gen_model_id.clone(), + prefix_id: None, + requests: [GenerationRequest { + text: sample_text.clone(), + }] + .to_vec(), + params: None, + }; + + let expected_generate_response = ClassifiedGeneratedTextResult { + generated_text: Some(client_generation_response.responses[0].text.clone()), + finish_reason: Some(FinishReason::EosToken), + generated_token_count: Some(3), + seed: Some(7), + ..Default::default() + }; + + // Construct a behavior for the mock object + faux::when!(mock_client.generate(expected_generate_req_args)) + .once() // TODO: Add with_args + .then_return(Ok(client_generation_response)); + + let mock_generation_client = GenerationClient::Tgis(mock_client.clone()); + + let ctx: Context = get_test_context(mock_generation_client, None, None).await; + + // Test request formulation and response processing is as expected + assert_eq!( + generate(ctx.into(), text_gen_model_id, sample_text, None) + .await + .unwrap(), + expected_generate_response + ); + } }