Skip to content

Commit

Permalink
Merge pull request #68 from gkumbhat/add_mock_lib
Browse files Browse the repository at this point in the history
Add mock lib
  • Loading branch information
gkumbhat authored Jun 7, 2024
2 parents 2354acd + 627cfab commit e82dfbf
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 9 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ uuid = { version = "1.8.0", features = ["v4", "fast-rng"] }
[build-dependencies]
tonic-build = "0.11.0"

[dev-dependencies]
faux = "0.1.10"

[profile.release]
debug = false
incremental = true
Expand Down
1 change: 1 addition & 0 deletions src/clients/chunker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use crate::{

const MODEL_ID_HEADER_NAME: &str = "mm-model-id";

#[cfg_attr(test, derive(Default))]
#[derive(Clone)]
pub struct ChunkerClient {
clients: HashMap<String, ChunkersServiceClient<LoadBalancedChannel>>,
Expand Down
1 change: 1 addition & 0 deletions src/clients/detector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::config::ServiceConfig;

const DETECTOR_ID_HEADER_NAME: &str = "detector-id";

#[cfg_attr(test, derive(Default))]
#[derive(Clone)]
pub struct DetectorClient {
clients: HashMap<String, HttpClient>,
Expand Down
35 changes: 35 additions & 0 deletions src/clients/tgis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ use crate::{
},
};

#[cfg_attr(test, faux::create)]
#[derive(Clone)]
pub struct TgisClient {
clients: HashMap<String, GenerationServiceClient<LoadBalancedChannel>>,
}

#[cfg_attr(test, faux::methods)]
impl TgisClient {
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
let clients = create_grpc_clients(default_port, config, GenerationServiceClient::new).await;
Expand Down Expand Up @@ -85,3 +87,36 @@ impl TgisClient {
.into_inner())
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::pb::fmaas::model_info_response;

#[tokio::test]
async fn test_model_info() {
// Initialize a mock object from `TgisClient`
let mut mock_client = TgisClient::faux();

let request = ModelInfoRequest {
model_id: "test-model-1".to_string(),
};

let expected_response = ModelInfoResponse {
max_sequence_length: 2,
max_new_tokens: 20,
max_beam_width: 3,
model_kind: model_info_response::ModelKind::DecoderOnly.into(),
max_beam_sequence_lengths: [].to_vec(),
};
// Construct a behavior for the mock object
faux::when!(mock_client.model_info(request.clone()))
.once()
.then_return(Ok(expected_response.clone()));
// Test the mock object's behaviour
assert_eq!(
mock_client.model_info(request).await.unwrap(),
expected_response
);
}
}
32 changes: 23 additions & 9 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,39 +8,42 @@ use tracing::debug;

/// Configuration for service needed for
/// orchestrator to communicate with it
#[derive(Debug, Clone, Deserialize)]
#[derive(Clone, Debug, Default, Deserialize)]
pub struct ServiceConfig {
pub hostname: String,
pub port: Option<u16>,
pub tls: Option<Tls>,
}

/// TLS provider
#[derive(Debug, Clone, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
#[serde(untagged)]
pub enum Tls {
Name(String),
Config(TlsConfig),
}

/// Client TLS configuration
#[derive(Debug, Clone, Deserialize)]
#[derive(Clone, Debug, Default, Deserialize)]
pub struct TlsConfig {
pub cert_path: Option<PathBuf>,
pub key_path: Option<PathBuf>,
pub client_ca_cert_path: Option<PathBuf>,
}

/// Generation service provider
#[derive(Debug, Clone, Copy, Deserialize)]
#[cfg_attr(test, derive(Default))]
#[derive(Clone, Copy, Debug, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum GenerationProvider {
#[cfg_attr(test, default)]
Tgis,
Nlp,
}

/// Generate service configuration
#[derive(Debug, Clone, Deserialize)]
#[cfg_attr(test, derive(Default))]
#[derive(Clone, Debug, Deserialize)]
pub struct GenerationConfig {
/// Generation service provider
pub provider: GenerationProvider,
Expand All @@ -49,16 +52,19 @@ pub struct GenerationConfig {
}

/// Chunker parser type
#[derive(Debug, Clone, Copy, Deserialize)]
#[cfg_attr(test, derive(Default))]
#[derive(Clone, Copy, Debug, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChunkerType {
#[cfg_attr(test, default)]
Sentence,
All,
}

/// Configuration for each chunker
#[cfg_attr(test, derive(Default))]
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct ChunkerConfig {
/// Chunker type
pub r#type: ChunkerType,
Expand All @@ -67,7 +73,7 @@ pub struct ChunkerConfig {
}

/// Configuration for each detector
#[derive(Debug, Clone, Deserialize)]
#[derive(Clone, Debug, Default, Deserialize)]
pub struct DetectorConfig {
/// Detector service connection information
pub service: ServiceConfig,
Expand All @@ -78,7 +84,8 @@ pub struct DetectorConfig {
}

/// Overall orchestrator server configuration
#[derive(Debug, Clone, Deserialize)]
#[cfg_attr(test, derive(Default))]
#[derive(Clone, Debug, Deserialize)]
pub struct OrchestratorConfig {
/// Generation service and associated configuration
pub generation: GenerationConfig,
Expand Down Expand Up @@ -150,6 +157,13 @@ fn service_tls_name_to_config(
service
}

#[cfg(test)]
impl Default for Tls {
fn default() -> Self {
Tls::Name("dummy_tls".to_string())
}
}

#[cfg(test)]
mod tests {
use anyhow::Error;
Expand Down
80 changes: 80 additions & 0 deletions src/orchestrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,28 @@ impl StreamingClassificationWithGenTask {
#[cfg(test)]
mod tests {
use super::*;
use crate::{
models::FinishReason,
pb::fmaas::{
BatchedGenerationRequest, BatchedGenerationResponse, GenerationResponse, StopReason,
},
};

async fn get_test_context(
gen_client: GenerationClient,
chunker_client: Option<ChunkerClient>,
detector_client: Option<DetectorClient>,
) -> Context {
let chunker_client = chunker_client.unwrap_or_default();
let detector_client = detector_client.unwrap_or_default();

Context {
generation_client: gen_client,
chunker_client,
detector_client,
config: OrchestratorConfig::default(),
}
}

#[test]
fn test_apply_masks() {
Expand All @@ -709,4 +731,62 @@ mod tests {
let s = "哈囉世界";
assert_eq!(slice_codepoints(s, 3, 4), "界");
}

// Test for TGIS generation with default parameter
#[tokio::test]
async fn test_tgis_generate_with_default_params() {
// Initialize a mock object from `TgisClient`
let mut mock_client = TgisClient::faux();

let sample_text = String::from("sample text");
let text_gen_model_id = String::from("test-llm-id-1");

let generation_response = GenerationResponse {
text: String::from("sample response worked"),
stop_reason: StopReason::EosToken.into(),
stop_sequence: String::from("\n"),
generated_token_count: 3,
seed: 7,
..Default::default()
};

let client_generation_response = BatchedGenerationResponse {
responses: [generation_response].to_vec(),
};

let expected_generate_req_args = BatchedGenerationRequest {
model_id: text_gen_model_id.clone(),
prefix_id: None,
requests: [GenerationRequest {
text: sample_text.clone(),
}]
.to_vec(),
params: None,
};

let expected_generate_response = ClassifiedGeneratedTextResult {
generated_text: Some(client_generation_response.responses[0].text.clone()),
finish_reason: Some(FinishReason::EosToken),
generated_token_count: Some(3),
seed: Some(7),
..Default::default()
};

// Construct a behavior for the mock object
faux::when!(mock_client.generate(expected_generate_req_args))
.once() // TODO: Add with_args
.then_return(Ok(client_generation_response));

let mock_generation_client = GenerationClient::Tgis(mock_client.clone());

let ctx: Context = get_test_context(mock_generation_client, None, None).await;

// Test request formulation and response processing is as expected
assert_eq!(
generate(ctx.into(), text_gen_model_id, sample_text, None)
.await
.unwrap(),
expected_generate_response
);
}
}

0 comments on commit e82dfbf

Please sign in to comment.