diff --git a/config/config.yaml b/config/config.yaml index 7e1a4b63..db040233 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -17,7 +17,7 @@ detectors: port: 8080 tls: caikit chunker_id: en_regex - config: {} + default_threshold: 0.5 tls: caikit: cert_path: /path/to/tls.crt diff --git a/src/config.rs b/src/config.rs index 95c42baf..ef64b43e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -6,6 +6,8 @@ use std::{ use serde::Deserialize; use tracing::debug; +/// Configuration for a service, needed for the +/// orchestrator to communicate with it #[derive(Debug, Clone, Deserialize)] pub struct ServiceConfig { pub hostname: String, @@ -13,6 +15,7 @@ pub struct ServiceConfig { pub tls: Option, } +/// TLS provider #[derive(Debug, Clone, Deserialize)] #[serde(untagged)] pub enum Tls { @@ -20,6 +23,7 @@ pub enum Tls { Config(TlsConfig), } +/// Client TLS configuration #[derive(Debug, Clone, Deserialize)] pub struct TlsConfig { pub cert_path: Option, @@ -27,6 +31,7 @@ pub struct TlsConfig { pub client_ca_cert_path: Option, } +/// Generation service provider #[derive(Debug, Clone, Copy, Deserialize)] #[serde(rename_all = "lowercase")] pub enum GenerationProvider { @@ -34,12 +39,16 @@ pub enum GenerationProvider { Nlp, } +/// Generation service configuration #[derive(Debug, Clone, Deserialize)] pub struct GenerationConfig { + /// Generation service provider pub provider: GenerationProvider, + /// Generation service connection information pub service: ServiceConfig, } +/// Chunker parser type #[derive(Debug, Clone, Copy, Deserialize)] #[serde(rename_all = "lowercase")] pub enum ChunkerType { @@ -47,29 +56,43 @@ pub enum ChunkerType { All, } +/// Configuration for each chunker #[allow(dead_code)] #[derive(Debug, Clone, Deserialize)] pub struct ChunkerConfig { + /// Chunker type pub r#type: ChunkerType, + /// Chunker service connection information pub service: ServiceConfig, } +/// Configuration for each detector #[derive(Debug, Clone, Deserialize)] pub struct DetectorConfig
{ + /// Detector service connection information pub service: ServiceConfig, + /// ID of chunker that this detector will use pub chunker_id: String, - //pub config: HashMap, + /// Default threshold with which to filter detector results by score + pub default_threshold: f32, } +/// Overall orchestrator server configuration #[derive(Debug, Clone, Deserialize)] pub struct OrchestratorConfig { + /// Generation service and associated configuration pub generation: GenerationConfig, + /// Chunker services and associated configurations pub chunkers: HashMap, + /// Detector services and associated configurations pub detectors: HashMap, + /// Map of TLS connections, allowing reuse across services + /// that may require the same TLS information pub tls: HashMap, } impl OrchestratorConfig { + /// Load overall orchestrator server configuration pub async fn load(path: impl AsRef) -> Self { let path = path.as_ref(); let s = tokio::fs::read_to_string(path) @@ -108,6 +131,7 @@ impl OrchestratorConfig { todo!() } + /// Get ID of chunker associated with a particular detector pub fn get_chunker_id(&self, detector_id: &str) -> Option { self.detectors .get(detector_id) @@ -157,7 +181,7 @@ detectors: hostname: localhost port: 9000 chunker_id: sentence-en - config: {} + default_threshold: 0.5 tls: {} "#; let config: OrchestratorConfig = serde_yml::from_str(s)?; diff --git a/src/models.rs b/src/models.rs index 276cb143..5b48522c 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,8 +1,12 @@ #![allow(unused_qualifications)] -use crate::{pb, server}; +use crate::pb; use std::collections::HashMap; +// TODO: When detector API is updated, consider if fields +// like 'threshold' can be named options instead of +// using a generic HashMap with Values here +// ref.
https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/37 pub type DetectorParams = HashMap; /// User request to orchestrator @@ -28,15 +32,23 @@ pub struct GuardrailsHttpRequest { pub text_gen_parameters: Option, } +#[derive(Debug, thiserror::Error)] +pub enum ValidationError { + #[error("`{0}` is required")] + Required(String), + #[error("{0}")] + Invalid(String), +} + impl GuardrailsHttpRequest { /// Upfront validation of user request - pub fn validate(&self) -> Result<(), server::Error> { + pub fn validate(&self) -> Result<(), ValidationError> { // Validate required parameters if self.model_id.is_empty() { - return Err(server::Error::Validation("`model_id` is required".into())); + return Err(ValidationError::Required("model_id".into())); } if self.inputs.is_empty() { - return Err(server::Error::Validation("`inputs` is required".into())); + return Err(ValidationError::Required("inputs".into())); } // Validate masks let input_range = 0..self.inputs.len(); @@ -48,7 +60,7 @@ impl GuardrailsHttpRequest { if !input_masks.iter().all(|(start, end)| { input_range.contains(start) && input_range.contains(end) && start < end }) { - return Err(server::Error::Validation("invalid masks".into())); + return Err(ValidationError::Invalid("invalid masks".into())); } } Ok(()) diff --git a/src/orchestrator.rs b/src/orchestrator.rs index e6bb1dff..6328cd5a 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -238,15 +238,22 @@ async fn detect( let ctx = ctx.clone(); let detector_id = detector_id.clone(); let detector_params = detector_params.clone(); - let chunker_id = + // Get the detector config + let detector_config = ctx.config - .get_chunker_id(&detector_id) + .detectors + .get(&detector_id) .ok_or_else(|| Error::DetectorNotFound { detector_id: detector_id.clone(), })?; - let chunks = chunks.get(&chunker_id).unwrap().clone(); + // Get the default threshold to use if threshold is not provided by the user + let default_threshold = 
detector_config.default_threshold; + // Get chunker for detector + let chunker_id = detector_config.chunker_id.as_str(); + let chunks = chunks.get(chunker_id).unwrap().clone(); Ok(tokio::spawn(async move { - handle_detection_task(ctx, detector_id, detector_params, chunks).await + handle_detection_task(ctx, detector_id, default_threshold, detector_params, chunks) + .await })) }) .collect::, Error>>()?; @@ -316,6 +323,7 @@ async fn handle_chunk_task( async fn handle_detection_task( ctx: Arc, detector_id: String, + default_threshold: f32, detector_params: DetectorParams, chunks: Vec, ) -> Result, Error> { @@ -325,7 +333,10 @@ async fn handle_detection_task( let detector_id = detector_id.clone(); let detector_params = detector_params.clone(); async move { - let request = DetectorRequest::new(chunk.text.clone(), detector_params); + // NOTE: The detector request is expected to change and not actually + // take parameters. Any parameters will be ignored for now + // ref. https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/37 + let request = DetectorRequest::new(chunk.text.clone(), detector_params.clone()); debug!( %detector_id, ?request, @@ -344,14 +355,19 @@ async fn handle_detection_task( ?response, "received detector response" ); + // Filter results based on threshold (if applicable) here let results = response .detections .into_iter() - .map(|detection| { + .filter_map(|detection| { let mut result: TokenClassificationResult = detection.into(); result.start += chunk.offset as u32; result.end += chunk.offset as u32; - result + let threshold = detector_params + .get("threshold") + .and_then(|v| v.as_f64()) + .unwrap_or(default_threshold as f64); + (result.score >= threshold).then_some(result) }) .collect::>(); Ok::, Error>(results) diff --git a/src/server.rs b/src/server.rs index 61468815..028294f2 100644 --- a/src/server.rs +++ b/src/server.rs @@ -183,3 +183,9 @@ impl IntoResponse for Error { (code, Json(error)).into_response() } } + +impl From 
for Error { + fn from(value: models::ValidationError) -> Self { + Self::Validation(value.to_string()) + } +}