diff --git a/src/models.rs b/src/models.rs index b5b8482f..da1e3e5d 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,8 +1,9 @@ #![allow(unused_qualifications)] -use crate::pb; use std::collections::HashMap; +use crate::pb; + /// Parameters relevant to each detector #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] pub struct DetectorParams { @@ -39,6 +40,8 @@ pub enum ValidationError { Required(String), #[error("{0}")] Invalid(String), + #[error("{0} field not present in {1}")] + Missing(String, String), } impl GuardrailsHttpRequest { @@ -64,6 +67,28 @@ impl GuardrailsHttpRequest { return Err(ValidationError::Invalid("invalid masks".into())); } } + // Validate Guardrail config input/output models + let config = self.guardrail_config.as_ref(); + let (config_input, config_output) = ( + config.and_then(|config| config.input.as_ref()), + config.and_then(|config| config.output.as_ref()), + ); + if let Some(config_input) = config_input { + if config_input.models.is_none() { + return Err(ValidationError::Missing( + "models".into(), + "guardrail config input".into(), + )); + } + } + if let Some(config_output) = config_output { + if config_output.models.is_none() { + return Err(ValidationError::Missing( + "models".into(), + "guardrail config output".into(), + )); + } + } Ok(()) } } @@ -93,7 +118,9 @@ impl GuardrailsConfig { } pub fn output_detectors(&self) -> Option<&HashMap> { - self.output.as_ref().and_then(|output| output.models.as_ref()) + self.output + .as_ref() + .and_then(|output| output.models.as_ref()) } } @@ -102,7 +129,6 @@ impl GuardrailsConfig { pub struct GuardrailsConfigInput { /// Map of model name to model specific parameters #[serde(rename = "models")] - #[serde(skip_serializing_if = "Option::is_none")] pub models: Option>, /// Vector of spans are in the form of (span_start, span_end) corresponding /// to spans of input text on which to run input detection @@ -116,7 +142,6 @@ pub struct GuardrailsConfigInput { pub struct GuardrailsConfigOutput { /// Map of model name to model specific parameters #[serde(rename = "models")] - #[serde(skip_serializing_if = "Option::is_none")] pub models: Option>, } @@ -792,21 +817,42 @@ mod tests { let error = result.unwrap_err().to_string(); assert!(error.contains("invalid masks")); - // Missing models expected OK + // No config input models let request = GuardrailsHttpRequest { model_id: "model".to_string(), - inputs: "The cow jumped over the moon!".to_string(), + inputs: "This is ignored anyway!".to_string(), guardrail_config: Some(GuardrailsConfig { input: Some(GuardrailsConfigInput { - masks: Some(vec![(5, 8)]), + masks: None, models: None, }), output: Some(GuardrailsConfigOutput { - models: None, + models: Some(HashMap::new()), }), }), text_gen_parameters: None, }; - assert!(request.validate().is_ok()); + let result = request.validate(); + assert!(result.is_err()); + let error = result.unwrap_err().to_string(); + assert!(error.contains("models field not present in guardrail config input")); + + // No config output models + let request = GuardrailsHttpRequest { + model_id: "model".to_string(), + inputs: "This is ignored anyway!".to_string(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + masks: None, + models: Some(HashMap::new()), + }), + output: Some(GuardrailsConfigOutput { models: None }), + }), + text_gen_parameters: None, + }; + let result = request.validate(); + assert!(result.is_err()); + let error = result.unwrap_err().to_string(); + assert!(error.contains("models field not present in guardrail config output")); } -} +} \ No newline at end of file