Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add thresholding for detector results #52

Merged
merged 14 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ detectors:
port: 8080
tls: caikit
chunker_id: en_regex
config: {}
default_threshold: 0.5
tls:
caikit:
cert_path: /path/to/tls.crt
Expand Down
28 changes: 26 additions & 2 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,70 +6,93 @@ use std::{
use serde::Deserialize;
use tracing::debug;

/// Configuration for service needed for
/// orchestrator to communicate with it
#[derive(Debug, Clone, Deserialize)]
pub struct ServiceConfig {
pub hostname: String,
pub port: Option<u16>,
pub tls: Option<Tls>,
}

/// TLS provider
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum Tls {
Name(String),
Config(TlsConfig),
}

/// Client TLS configuration
#[derive(Debug, Clone, Deserialize)]
pub struct TlsConfig {
pub cert_path: Option<PathBuf>,
pub key_path: Option<PathBuf>,
pub client_ca_cert_path: Option<PathBuf>,
}

/// Generation service provider
#[derive(Debug, Clone, Copy, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum GenerationProvider {
Tgis,
Nlp,
}

/// Generate service configuration
#[derive(Debug, Clone, Deserialize)]
pub struct GenerationConfig {
/// Generation service provider
pub provider: GenerationProvider,
/// Generation service connection information
pub service: ServiceConfig,
}

/// Chunker parser type
#[derive(Debug, Clone, Copy, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChunkerType {
Sentence,
All,
}

/// Configuration for each chunker
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
pub struct ChunkerConfig {
/// Chunker type
pub r#type: ChunkerType,
/// Chunker service connection information
pub service: ServiceConfig,
}

/// Configuration for each detector
#[derive(Debug, Clone, Deserialize)]
pub struct DetectorConfig {
/// Detector service connection information
pub service: ServiceConfig,
/// ID of chunker that this detector will use
pub chunker_id: String,
//pub config: HashMap<String, String>,
/// Default threshold with which to filter detector results by score
pub default_threshold: f32,
}

/// Overall orchestrator server configuration
#[derive(Debug, Clone, Deserialize)]
pub struct OrchestratorConfig {
/// Generation service and associated configuration
pub generation: GenerationConfig,
/// Chunker services and associated configurations
pub chunkers: HashMap<String, ChunkerConfig>,
/// Detector services and associated configurations
pub detectors: HashMap<String, DetectorConfig>,
/// Map of TLS connections, allowing reuse across services
/// that may require the same TLS information
pub tls: HashMap<String, TlsConfig>,
}

impl OrchestratorConfig {
/// Load overall orchestrator server configuration
pub async fn load(path: impl AsRef<Path>) -> Self {
let path = path.as_ref();
let s = tokio::fs::read_to_string(path)
Expand Down Expand Up @@ -108,6 +131,7 @@ impl OrchestratorConfig {
todo!()
}

/// Get ID of chunker associated with a particular detector
pub fn get_chunker_id(&self, detector_id: &str) -> Option<String> {
self.detectors
.get(detector_id)
Expand Down Expand Up @@ -157,7 +181,7 @@ detectors:
hostname: localhost
port: 9000
chunker_id: sentence-en
config: {}
default_threshold: 0.5
tls: {}
"#;
let config: OrchestratorConfig = serde_yml::from_str(s)?;
Expand Down
22 changes: 17 additions & 5 deletions src/models.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
#![allow(unused_qualifications)]

use crate::{pb, server};
use crate::pb;
use std::collections::HashMap;

// TODO: When detector API is updated, consider if fields
// like 'threshold' can be named options instead of the
// use a generic HashMap with Values here
// ref. https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/37
pub type DetectorParams = HashMap<String, serde_json::Value>;

/// User request to orchestrator
Expand All @@ -28,15 +32,23 @@ pub struct GuardrailsHttpRequest {
pub text_gen_parameters: Option<GuardrailsTextGenerationParameters>,
}

#[derive(Debug, thiserror::Error)]
pub enum ValidationError {
#[error("`{0}` is required")]
Required(String),
#[error("{0}")]
Invalid(String),
}

impl GuardrailsHttpRequest {
/// Upfront validation of user request
pub fn validate(&self) -> Result<(), server::Error> {
pub fn validate(&self) -> Result<(), ValidationError> {
// Validate required parameters
if self.model_id.is_empty() {
return Err(server::Error::Validation("`model_id` is required".into()));
return Err(ValidationError::Required("model_id".into()));
}
if self.inputs.is_empty() {
return Err(server::Error::Validation("`inputs` is required".into()));
return Err(ValidationError::Required("inputs".into()));
}
// Validate masks
let input_range = 0..self.inputs.len();
Expand All @@ -48,7 +60,7 @@ impl GuardrailsHttpRequest {
if !input_masks.iter().all(|(start, end)| {
input_range.contains(start) && input_range.contains(end) && start < end
}) {
return Err(server::Error::Validation("invalid masks".into()));
return Err(ValidationError::Invalid("invalid masks".into()));
}
}
Ok(())
Expand Down
30 changes: 23 additions & 7 deletions src/orchestrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,15 +238,22 @@ async fn detect(
let ctx = ctx.clone();
let detector_id = detector_id.clone();
let detector_params = detector_params.clone();
let chunker_id =
// Get the detector config
let detector_config =
ctx.config
.get_chunker_id(&detector_id)
.detectors
.get(&detector_id)
.ok_or_else(|| Error::DetectorNotFound {
detector_id: detector_id.clone(),
})?;
let chunks = chunks.get(&chunker_id).unwrap().clone();
// Get the default threshold to use if threshold is not provided by the user
let default_threshold = detector_config.default_threshold;
// Get chunker for detector
let chunker_id = detector_config.chunker_id.as_str();
let chunks = chunks.get(chunker_id).unwrap().clone();
Ok(tokio::spawn(async move {
handle_detection_task(ctx, detector_id, detector_params, chunks).await
handle_detection_task(ctx, detector_id, default_threshold, detector_params, chunks)
.await
}))
})
.collect::<Result<Vec<_>, Error>>()?;
Expand Down Expand Up @@ -316,6 +323,7 @@ async fn handle_chunk_task(
async fn handle_detection_task(
ctx: Arc<Context>,
detector_id: String,
default_threshold: f32,
detector_params: DetectorParams,
chunks: Vec<Chunk>,
) -> Result<Vec<TokenClassificationResult>, Error> {
Expand All @@ -325,7 +333,10 @@ async fn handle_detection_task(
let detector_id = detector_id.clone();
let detector_params = detector_params.clone();
async move {
let request = DetectorRequest::new(chunk.text.clone(), detector_params);
// NOTE: The detector request is expected to change and not actually
// take parameters. Any parameters will be ignored for now
// ref. https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/37
let request = DetectorRequest::new(chunk.text.clone(), detector_params.clone());
debug!(
%detector_id,
?request,
Expand All @@ -344,14 +355,19 @@ async fn handle_detection_task(
?response,
"received detector response"
);
// Filter results based on threshold (if applicable) here
let results = response
.detections
.into_iter()
.map(|detection| {
.filter_map(|detection| {
let mut result: TokenClassificationResult = detection.into();
result.start += chunk.offset as u32;
result.end += chunk.offset as u32;
result
let threshold = detector_params
.get("threshold")
.and_then(|v| v.as_f64())
.unwrap_or(default_threshold as f64);
(result.score >= threshold).then_some(result)
})
.collect::<Vec<_>>();
Ok::<Vec<TokenClassificationResult>, Error>(results)
Expand Down
6 changes: 6 additions & 0 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,9 @@ impl IntoResponse for Error {
(code, Json(error)).into_response()
}
}

impl From<models::ValidationError> for Error {
fn from(value: models::ValidationError) -> Self {
Self::Validation(value.to_string())
}
}