Skip to content

Commit

Permalink
feat[ai]: Add conversation endpoints.
Browse files Browse the repository at this point in the history
LLM internal state messages is simplified to be a single field either on ChatState or in the Conversation table.
We now keep timestamps for individual ChatMessages.

Also drop the db connection arg from the completions method as the DB is
Not needed for those calls.

Signed-off-by: Hiram Chirino <[email protected]>
  • Loading branch information
chirino committed Jan 10, 2025
1 parent 44ade24 commit 773e6d5
Show file tree
Hide file tree
Showing 7 changed files with 894 additions and 174 deletions.
251 changes: 247 additions & 4 deletions modules/fundamental/src/ai/endpoints/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
#[cfg(test)]
mod test;

use crate::ai::model::{Conversation, ConversationSummary};
use crate::{
ai::model::{AiFlags, AiTool, ChatState},
ai::service::AiService,
Error,
};
use actix_http::header;
use actix_web::{get, post, web, HttpResponse, Responder};
use actix_web::{delete, get, post, put, web, HttpResponse, Responder};
use itertools::Itertools;
use time::OffsetDateTime;
use trustify_auth::authenticator::user::UserDetails;
use trustify_auth::{authorizer::Require, Ai};
use trustify_common::db::query::Query;
use trustify_common::db::Database;
use trustify_common::model::{Paginated, PaginatedResults};
use uuid::Uuid;

pub fn configure(config: &mut utoipa_actix_web::service_config::ServiceConfig, db: Database) {
let service = AiService::new(db.clone());
Expand All @@ -19,7 +25,12 @@ pub fn configure(config: &mut utoipa_actix_web::service_config::ServiceConfig, d
.service(completions)
.service(flags)
.service(tools)
.service(tool_call);
.service(tool_call)
.service(create_conversation)
.service(update_conversation)
.service(list_conversations)
.service(get_conversation)
.service(delete_conversation);
}

#[utoipa::path(
Expand All @@ -35,11 +46,10 @@ pub fn configure(config: &mut utoipa_actix_web::service_config::ServiceConfig, d
#[post("/v1/ai/completions")]
pub async fn completions(
service: web::Data<AiService>,
db: web::Data<Database>,
request: web::Json<ChatState>,
_: Require<Ai>,
) -> actix_web::Result<impl Responder> {
let response = service.completions(&request, db.as_ref()).await?;
let response = service.completions(&request).await?;
Ok(HttpResponse::Ok().json(response))
}

Expand Down Expand Up @@ -123,3 +133,236 @@ pub async fn tool_call(
.insert_header((header::CONTENT_TYPE, "text/plain"))
.body(result))
}

#[utoipa::path(
tag = "ai",
operation_id = "createConversation",
responses(
(status = 200, description = "The resulting conversation", body = Conversation),
(status = 400, description = "The request was invalid"),
(status = 404, description = "The AI service is not enabled")
)
)]
#[post("/v1/ai/conversations")]
pub async fn create_conversation(_: Require<Ai>) -> actix_web::Result<impl Responder, Error> {
// generate an assistant response
let uuid = Uuid::now_v7();
let response = Conversation {
id: uuid,
messages: Default::default(),
updated_at: to_offset_date_time(uuid)?,
seq: 0,
};

Ok(HttpResponse::Ok().json(response))
}

fn to_offset_date_time(uuid: Uuid) -> Result<OffsetDateTime, Error> {
match uuid.get_timestamp() {
Some(ts) => match OffsetDateTime::from_unix_timestamp(ts.to_unix().0 as i64) {
Ok(ts) => Ok(ts),
Err(e) => Err(Error::Internal(e.to_string())),
},
None => Err(Error::Internal("uuid generation failure".into())),
}
}

#[utoipa::path(
tag = "ai",
operation_id = "updateConversation",
params(
("id", Path, description = "Opaque ID of the conversation")
),
request_body = Conversation,
responses(
(status = 200, description = "The resulting conversation", body = Conversation),
(status = 400, description = "The request was invalid"),
(status = 404, description = "The AI service is not enabled or the conversation was not found")
)
)]
#[put("/v1/ai/conversations/{id}")]
pub async fn update_conversation(
service: web::Data<AiService>,
db: web::Data<Database>,
id: web::Path<Uuid>,
user: UserDetails,
request: web::Json<Conversation>,
_: Require<Ai>,
) -> actix_web::Result<impl Responder> {
let user_id = user.id;

let conversation_id = id.into_inner();

let (conversation, messages) = service
.upsert_conversation(
conversation_id,
user_id,
&request.messages,
request.seq,
db.as_ref(),
)
.await?;

let conversation = Conversation {
id: conversation.id,
updated_at: conversation.updated_at,
messages,
seq: conversation.seq,
};

Ok(HttpResponse::Ok().json(conversation))
}

#[utoipa::path(
tag = "ai",
operation_id = "listConversations",
params(
Query,
Paginated,
),
responses(
(status = 200, description = "The resulting list of conversation summaries", body = PaginatedResults<ConversationSummary>),
(status = 404, description = "The AI service is not enabled")
)
)]
#[get("/v1/ai/conversations")]
// Gets the list of the user's previous conversations
pub async fn list_conversations(
service: web::Data<AiService>,
web::Query(search): web::Query<Query>,
web::Query(paginated): web::Query<Paginated>,
db: web::Data<Database>,
user: UserDetails,
_: Require<Ai>,
) -> actix_web::Result<impl Responder> {
let user_id = user.id;

let result = service
.fetch_conversations(user_id, search, paginated, db.as_ref())
.await?;

let result = PaginatedResults {
items: result
.items
.into_iter()
.map(|c| ConversationSummary {
id: c.id,
summary: c.summary,
updated_at: c.updated_at,
})
.collect(),
total: result.total,
};

Ok(HttpResponse::Ok().json(result))
}

#[utoipa::path(
tag = "ai",
operation_id = "getConversation",
params(
("id", Path, description = "Opaque ID of the conversation")
),
responses(
(status = 200, description = "The resulting conversation", body = Conversation),
(status = 400, description = "The request was invalid"),
(status = 404, description = "The AI service is not enabled or the conversation was not found")
)
)]
#[get("/v1/ai/conversations/{id}")]
pub async fn get_conversation(
service: web::Data<AiService>,
db: web::Data<Database>,
id: web::Path<Uuid>,
user: UserDetails,
_: Require<Ai>,
) -> actix_web::Result<impl Responder> {
let user_id = user.id;

let uuid = id.into_inner();
let conversation = service.fetch_conversation(uuid, db.as_ref()).await?;

match conversation {
// return an empty conversation i
None => Ok(HttpResponse::Ok().json(Conversation {
id: uuid,
messages: Default::default(),
updated_at: to_offset_date_time(uuid)?,
seq: 0,
})),

// Found the conversation
Some((conversation, internal_state)) => {
// verify that the conversation belongs to the user
if conversation.user_id != user_id {
// make this error look like a not found error to avoid leaking
// existence of the conversation
Err(Error::NotFound("conversation not found".to_string()))?;
}

Ok(HttpResponse::Ok().json(Conversation {
id: conversation.id,
updated_at: conversation.updated_at,
messages: internal_state.chat_messages(),
seq: conversation.seq,
}))
}
}
}

#[utoipa::path(
tag = "ai",
operation_id = "deleteConversation",
params(
("id", Path, description = "Opaque ID of the conversation")
),
responses(
(status = 200, description = "The resulting conversation", body = Conversation),
(status = 400, description = "The request was invalid"),
(status = 404, description = "The AI service is not enabled or the conversation was not found")
)
)]
#[delete("/v1/ai/conversations/{id}")]
pub async fn delete_conversation(
service: web::Data<AiService>,
db: web::Data<Database>,
id: web::Path<Uuid>,
user: UserDetails,
_: Require<Ai>,
) -> actix_web::Result<impl Responder> {
let user_id = user.id;
let conversation_id = id.into_inner();

let conversation = service
.fetch_conversation(conversation_id, db.as_ref())
.await?;

match conversation {
// the conversation_id might be invalid
None => Err(Error::NotFound("conversation not found".to_string()))?,

// Found the conversation
Some((conversation, internal_state)) => {
// verify that the conversation belongs to the user
if conversation.user_id != user_id {
// make this error look like a not found error to avoid leaking
// existence of the conversation
Err(Error::NotFound("conversation not found".to_string()))?;
}

let rows_affected = service
.delete_conversation(conversation_id, db.as_ref())
.await?;
match rows_affected {
0 => Ok(HttpResponse::NotFound().finish()),
1 => Ok(HttpResponse::Ok().json(Conversation {
id: conversation.id,
updated_at: conversation.updated_at,
messages: internal_state.chat_messages(),
seq: conversation.seq,
})),
_ => Err(Error::Internal("Unexpected number of rows affected".into()))?,
}
}
}
}
Loading

0 comments on commit 773e6d5

Please sign in to comment.