Skip to content

Commit

Permalink
Merge pull request foundation-model-stack#233 from mdevino/chat-standalone-endpoint
Browse files Browse the repository at this point in the history

Chat standalone endpoint
  • Loading branch information
gkumbhat authored Oct 18, 2024
2 parents a9e62c2 + 47a8dbd commit 68101a8
Show file tree
Hide file tree
Showing 5 changed files with 297 additions and 20 deletions.
63 changes: 58 additions & 5 deletions src/clients/detector/text_chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,19 @@
*/

use async_trait::async_trait;
use hyper::{HeaderMap, StatusCode};
use serde::Serialize;

use super::DEFAULT_PORT;
use super::{DetectorError, DEFAULT_PORT, DETECTOR_ID_HEADER_NAME};
use crate::{
clients::{create_http_client, Client, HttpClient},
clients::{create_http_client, openai::Message, Client, Error, HttpClient},
config::ServiceConfig,
health::HealthCheckResult,
models::{DetectionResult, DetectorParams},
};

const CHAT_DETECTOR_ENDPOINT: &str = "/api/v1/text/chat";

#[cfg_attr(test, faux::create)]
#[derive(Clone)]
pub struct TextChatDetectorClient {
Expand All @@ -46,9 +51,37 @@ impl TextChatDetectorClient {
}
}

pub async fn text_chat(&self) {
let _url = self.client.base_url().join("/api/v1/text/chat").unwrap();
todo!()
/// Calls the detector's `/api/v1/text/chat` endpoint and returns any
/// detections found in the supplied chat messages.
///
/// # Arguments
/// * `model_id` - detector id, forwarded in the detector-id header
/// * `request` - chat messages plus detector parameters
/// * `headers` - additional headers to propagate to the detector
///
/// # Errors
/// Returns an error if the HTTP request fails, or the detector's error
/// response when it replies with a non-OK status.
pub async fn text_chat(
    &self,
    model_id: &str,
    request: ChatDetectionRequest,
    headers: HeaderMap,
) -> Result<Vec<DetectionResult>, Error> {
    // Joining a constant, well-formed path onto the configured base URL
    // cannot fail, so a panic here indicates a programming error.
    let url = self
        .client
        .base_url()
        .join(CHAT_DETECTOR_ENDPOINT)
        .expect("joining constant chat detector endpoint path cannot fail");
    let request = self
        .client
        .post(url)
        .headers(headers)
        .header(DETECTOR_ID_HEADER_NAME, model_id)
        .json(&request);

    tracing::debug!("Request being sent to chat detector: {:?}", request);
    let response = request.send().await?;
    tracing::debug!("Response received from chat detector: {:?}", response);

    if response.status() == StatusCode::OK {
        Ok(response.json().await?)
    } else {
        let code = response.status().as_u16();
        // Fall back to a minimal error when the body is not a valid
        // DetectorError; built lazily so the successful-decode path
        // does not allocate the fallback.
        let error = response
            .json::<DetectorError>()
            .await
            .unwrap_or_else(|_| DetectorError {
                code,
                message: "".into(),
            });
        Err(error.into())
    }
}
}

Expand All @@ -67,3 +100,23 @@ impl Client for TextChatDetectorClient {
}
}
}

/// A struct representing a request to a detector compatible with the
/// /api/v1/text/chat endpoint.
#[derive(Debug, Serialize)]
pub struct ChatDetectionRequest {
    /// Chat messages to run detection on
    pub messages: Vec<Message>,

    /// Detector-specific parameters, e.g. thresholds
    pub detector_params: DetectorParams,
}

impl ChatDetectionRequest {
    /// Constructs a request from chat messages and detector parameters.
    pub fn new(messages: Vec<Message>, detector_params: DetectorParams) -> Self {
        Self {
            messages,
            detector_params,
        }
    }
}
79 changes: 78 additions & 1 deletion src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ use std::collections::HashMap;
use serde::{Deserialize, Serialize};

use crate::{
clients::detector::{ContentAnalysisResponse, ContextType},
clients::{
self,
detector::{ContentAnalysisResponse, ContextType},
openai::Content,
},
health::HealthCheckCache,
pb,
};
Expand Down Expand Up @@ -939,6 +943,79 @@ pub struct ContextDocsResult {
pub detections: Vec<DetectionResult>,
}

/// The request format expected in the /api/v2/text/detection/chat endpoint.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ChatDetectionHttpRequest {
    /// The map of detectors to be used, along with their respective parameters, e.g. thresholds.
    pub detectors: HashMap<String, DetectorParams>,

    /// The list of messages to run detections on.
    pub messages: Vec<clients::openai::Message>,
}

impl ChatDetectionHttpRequest {
    /// Upfront validation of user request.
    ///
    /// # Errors
    /// Returns `ValidationError::Required` when `detectors` or `messages`
    /// is empty.
    pub fn validate(&self) -> Result<(), ValidationError> {
        // Validate required parameters
        if self.detectors.is_empty() {
            return Err(ValidationError::Required("detectors".into()));
        }
        if self.messages.is_empty() {
            return Err(ValidationError::Required("messages".into()));
        }

        Ok(())
    }

    /// Validates for the "/api/v1/text/chat" endpoint.
    pub fn validate_for_text(&self) -> Result<(), ValidationError> {
        self.validate()?;
        self.validate_messages()?;
        validate_detector_params(&self.detectors)?;

        Ok(())
    }

    /// Validates that every message has content, and that content is either
    /// a plain string or a content array containing only "text" parts.
    fn validate_messages(&self) -> Result<(), ValidationError> {
        for message in &self.messages {
            match &message.content {
                Some(content) => Self::validate_content_type(content)?,
                None => {
                    return Err(ValidationError::Invalid(
                        "Message content cannot be empty".into(),
                    ))
                }
            }
        }
        Ok(())
    }

    /// Validates that a content array contains only parts of type "text";
    /// plain string content is always valid.
    fn validate_content_type(content: &Content) -> Result<(), ValidationError> {
        match content {
            Content::Array(parts) => {
                if parts.iter().any(|part| part.r#type != "text") {
                    return Err(ValidationError::Invalid(
                        "Only content of type text is allowed".into(),
                    ));
                }
                Ok(())
            }
            // If message.content is a string, it is a valid message.
            Content::String(_) => Ok(()),
        }
    }
}

/// The response format of the /api/v2/text/detection/chat endpoint
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatDetectionResult {
    /// Detection results aggregated across all requested detectors
    pub detections: Vec<DetectionResult>,
}

/// The request format expected in the /api/v2/text/detect/generated endpoint.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DetectionOnGeneratedHttpRequest {
Expand Down
34 changes: 31 additions & 3 deletions src/orchestrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use uuid::Uuid;

use crate::{
clients::{
self,
chunker::ChunkerClient,
detector::{
text_context_doc::ContextType, TextChatDetectorClient, TextContextDocDetectorClient,
Expand All @@ -40,9 +41,9 @@ use crate::{
config::{DetectorType, GenerationProvider, OrchestratorConfig},
health::HealthCheckCache,
models::{
ContextDocsHttpRequest, DetectionOnGeneratedHttpRequest, DetectorParams,
GenerationWithDetectionHttpRequest, GuardrailsConfig, GuardrailsHttpRequest,
GuardrailsTextGenerationParameters, TextContentDetectionHttpRequest,
ChatDetectionHttpRequest, ContextDocsHttpRequest, DetectionOnGeneratedHttpRequest,
DetectorParams, GenerationWithDetectionHttpRequest, GuardrailsConfig,
GuardrailsHttpRequest, GuardrailsTextGenerationParameters, TextContentDetectionHttpRequest,
},
};

Expand Down Expand Up @@ -382,6 +383,33 @@ impl ContextDocsDetectionTask {
}
}

/// Task for the /api/v2/text/detection/chat endpoint
#[derive(Debug)]
pub struct ChatDetectionTask {
    /// Request unique identifier
    pub request_id: Uuid,

    /// Detectors configuration
    pub detectors: HashMap<String, DetectorParams>,

    /// Messages to run detection on
    pub messages: Vec<clients::openai::Message>,

    /// Headers to propagate to detector requests
    pub headers: HeaderMap,
}

impl ChatDetectionTask {
    /// Builds a task from a validated HTTP request plus request metadata.
    pub fn new(request_id: Uuid, request: ChatDetectionHttpRequest, headers: HeaderMap) -> Self {
        Self {
            request_id,
            detectors: request.detectors,
            messages: request.messages,
            headers,
        }
    }
}

/// Task for the /api/v2/text/detection/generated endpoint
#[derive(Debug)]
pub struct DetectionOnGenerationTask {
Expand Down
115 changes: 107 additions & 8 deletions src/orchestrator/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,28 @@ use futures::{
use tracing::{debug, error, info, instrument};

use super::{
apply_masks, get_chunker_ids, Chunk, ClassificationWithGenTask, Context,
apply_masks, get_chunker_ids, ChatDetectionTask, Chunk, ClassificationWithGenTask, Context,
ContextDocsDetectionTask, DetectionOnGenerationTask, Error, GenerationWithDetectionTask,
Orchestrator, TextContentDetectionTask,
};
use crate::{
clients::{
chunker::{tokenize_whole_doc, ChunkerClient, DEFAULT_CHUNKER_ID},
detector::{
ContentAnalysisRequest, ContentAnalysisResponse, ContextDocsDetectionRequest,
ContextType, GenerationDetectionRequest, TextContentsDetectorClient,
TextContextDocDetectorClient, TextGenerationDetectorClient,
ChatDetectionRequest, ContentAnalysisRequest, ContentAnalysisResponse,
ContextDocsDetectionRequest, ContextType, GenerationDetectionRequest,
TextChatDetectorClient, TextContentsDetectorClient, TextContextDocDetectorClient,
TextGenerationDetectorClient,
},
openai::Message,
GenerationClient,
},
models::{
ClassifiedGeneratedTextResult, ContextDocsResult, DetectionOnGenerationResult,
DetectionResult, DetectorParams, GenerationWithDetectionResult,
GuardrailsTextGenerationParameters, InputWarning, InputWarningReason,
TextContentDetectionResult, TextGenTokenClassificationResults, TokenClassificationResult,
ChatDetectionResult, ClassifiedGeneratedTextResult, ContextDocsResult,
DetectionOnGenerationResult, DetectionResult, DetectorParams,
GenerationWithDetectionResult, GuardrailsTextGenerationParameters, InputWarning,
InputWarningReason, TextContentDetectionResult, TextGenTokenClassificationResults,
TokenClassificationResult,
},
orchestrator::UNSUITABLE_INPUT_MESSAGE,
pb::caikit::runtime::chunkers,
Expand Down Expand Up @@ -447,6 +450,61 @@ impl Orchestrator {
}
}
}

/// Handles detections on chat messages (without performing generation).
///
/// Fans out one request per configured detector concurrently, flattens the
/// per-detector detection lists into a single result, and fails fast on the
/// first detector error.
pub async fn handle_chat_detection(
    &self,
    task: ChatDetectionTask,
) -> Result<ChatDetectionResult, Error> {
    info!(
        request_id = ?task.request_id,
        detectors = ?task.detectors,
        "handling detection on chat content task"
    );
    let ctx = self.ctx.clone();
    let headers = task.headers;

    // NOTE: the async move block captures task.detectors / task.messages by
    // disjoint field capture (Rust 2021), leaving task.request_id readable
    // in the match arms below.
    let task_handle = tokio::spawn(async move {
        // Call all detectors concurrently; try_join_all short-circuits on
        // the first error.
        let detections = try_join_all(
            task.detectors
                .iter()
                .map(|(detector_id, detector_params)| {
                    // Clone per-detector inputs so each future owns its data.
                    let ctx = ctx.clone();
                    let detector_id = detector_id.clone();
                    let detector_params = detector_params.clone();
                    let messages = task.messages.clone();
                    let headers = headers.clone();
                    async {
                        detect_for_chat(ctx, detector_id, detector_params, messages, headers)
                            .await
                    }
                })
                .collect::<Vec<_>>(),
        )
        .await?
        .into_iter()
        .flatten()
        .collect::<Vec<_>>();

        Ok(ChatDetectionResult { detections })
    });
    match task_handle.await {
        // Task completed successfully
        Ok(Ok(result)) => Ok(result),
        // Task failed, return error propagated from child task that failed
        Ok(Err(error)) => {
            error!(request_id = ?task.request_id, %error, "detection task on chat failed");
            Err(error)
        }
        // Task cancelled or panicked
        Err(error) => {
            let error = error.into();
            error!(request_id = ?task.request_id, %error, "detection task on chat failed");
            Err(error)
        }
    }
}
}

/// Handles input detection task.
Expand Down Expand Up @@ -711,6 +769,47 @@ pub async fn detect_for_generation(
Ok::<Vec<DetectionResult>, Error>(response)
}

/// Calls a detector that implements the /api/v1/text/chat endpoint.
///
/// Runs the configured chat detector over `messages` and returns only the
/// detections whose score exceeds the effective threshold (the caller-supplied
/// threshold, falling back to the detector's configured default).
///
/// # Errors
/// Returns `Error::DetectorNotFound` if `detector_id` is not configured, or
/// `Error::DetectorRequestFailed` if the detector call fails.
pub async fn detect_for_chat(
    ctx: Arc<Context>,
    detector_id: String,
    detector_params: DetectorParams,
    messages: Vec<Message>,
    headers: HeaderMap,
) -> Result<Vec<DetectionResult>, Error> {
    // Always look up the detector config so an unknown detector id is
    // reported even when the caller supplied an explicit threshold.
    let default_threshold = ctx
        .config
        .detectors
        .get(&detector_id)
        .ok_or_else(|| Error::DetectorNotFound(detector_id.clone()))?
        .default_threshold;
    let threshold = detector_params.threshold().unwrap_or(default_threshold);
    let request = ChatDetectionRequest::new(messages, detector_params);
    debug!(%detector_id, ?request, "sending chat detector request");
    // NOTE(review): clients are presumably registered for every configured
    // detector at startup, making this lookup infallible — confirm.
    let client = ctx
        .clients
        .get_as::<TextChatDetectorClient>(&detector_id)
        .unwrap();
    let response = client
        .text_chat(&detector_id, request, headers)
        .await
        .map(|results| {
            // Keep only detections above the effective threshold.
            results
                .into_iter()
                .filter(|detection| detection.score > threshold)
                .collect()
        })
        .map_err(|error| Error::DetectorRequestFailed {
            id: detector_id.clone(),
            error,
        })?;
    debug!(%detector_id, ?response, "received chat detector response");
    Ok(response)
}

/// Calls a detector that implements the /api/v1/text/doc endpoint
pub async fn detect_for_context(
ctx: Arc<Context>,
Expand Down
Loading

0 comments on commit 68101a8

Please sign in to comment.