Skip to content

Commit

Permalink
🐛 Fixed missing models field breaking deserialization by adding valid…
Browse files Browse the repository at this point in the history
…ation (#91)

Signed-off-by: Paulo Caldeira <[email protected]>
  • Loading branch information
pmcjr committed Jul 1, 2024
1 parent 4b2d30f commit 2636700
Showing 1 changed file with 56 additions and 10 deletions.
66 changes: 56 additions & 10 deletions src/models.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#![allow(unused_qualifications)]

use crate::pb;
use std::collections::HashMap;

use crate::pb;

/// Parameters relevant to each detector
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct DetectorParams {
Expand Down Expand Up @@ -39,6 +40,8 @@ pub enum ValidationError {
Required(String),
#[error("{0}")]
Invalid(String),
#[error("{0} field not present in {1}")]
Missing(String, String),
}

impl GuardrailsHttpRequest {
Expand All @@ -64,6 +67,28 @@ impl GuardrailsHttpRequest {
return Err(ValidationError::Invalid("invalid masks".into()));
}
}
// Validate Guardrail config input/output models
let config = self.guardrail_config.as_ref();
let (config_input, config_output) = (
config.and_then(|config| config.input.as_ref()),
config.and_then(|config| config.output.as_ref()),
);
if let Some(config_input) = config_input {
if config_input.models.is_none() {
return Err(ValidationError::Missing(
"models".into(),
"guardrail config input".into(),
));
}
}
if let Some(config_output) = config_output {
if config_output.models.is_none() {
return Err(ValidationError::Missing(
"models".into(),
"guardrail config output".into(),
));
}
}
Ok(())
}
}
Expand Down Expand Up @@ -93,7 +118,9 @@ impl GuardrailsConfig {
}

pub fn output_detectors(&self) -> Option<&HashMap<String, DetectorParams>> {
self.output.as_ref().and_then(|output| output.models.as_ref())
self.output
.as_ref()
.and_then(|output| output.models.as_ref())
}
}

Expand All @@ -102,7 +129,6 @@ impl GuardrailsConfig {
pub struct GuardrailsConfigInput {
/// Map of model name to model specific parameters
#[serde(rename = "models")]
#[serde(skip_serializing_if = "Option::is_none")]
pub models: Option<HashMap<String, DetectorParams>>,
/// Vector of spans are in the form of (span_start, span_end) corresponding
/// to spans of input text on which to run input detection
Expand All @@ -116,7 +142,6 @@ pub struct GuardrailsConfigInput {
pub struct GuardrailsConfigOutput {
/// Map of model name to model specific parameters
#[serde(rename = "models")]
#[serde(skip_serializing_if = "Option::is_none")]
pub models: Option<HashMap<String, DetectorParams>>,
}

Expand Down Expand Up @@ -792,21 +817,42 @@ mod tests {
let error = result.unwrap_err().to_string();
assert!(error.contains("invalid masks"));

// Missing models expected OK
// No config input models
let request = GuardrailsHttpRequest {
model_id: "model".to_string(),
inputs: "The cow jumped over the moon!".to_string(),
inputs: "This is ignored anyway!".to_string(),
guardrail_config: Some(GuardrailsConfig {
input: Some(GuardrailsConfigInput {
masks: Some(vec![(5, 8)]),
masks: None,
models: None,
}),
output: Some(GuardrailsConfigOutput {
models: None,
models: Some(HashMap::new()),
}),
}),
text_gen_parameters: None,
};
assert!(request.validate().is_ok());
let result = request.validate();
assert!(result.is_err());
let error = result.unwrap_err().to_string();
assert!(error.contains("models field not present in guardrail config input"));

// No config output models
let request = GuardrailsHttpRequest {
model_id: "model".to_string(),
inputs: "This is ignored anyway!".to_string(),
guardrail_config: Some(GuardrailsConfig {
input: Some(GuardrailsConfigInput {
masks: None,
models: Some(HashMap::new()),
}),
output: Some(GuardrailsConfigOutput { models: None }),
}),
text_gen_parameters: None,
};
let result = request.validate();
assert!(result.is_err());
let error = result.unwrap_err().to_string();
assert!(error.contains("models field not present in guardrail config output"));
}
}
}

0 comments on commit 2636700

Please sign in to comment.