Skip to content

Commit

Permalink
Migrate clients implementation from reqwest to hyper (foundation-…
Browse files Browse the repository at this point in the history
…model-stack#252)

* migrate from reqwest to hyper

Signed-off-by: Paul Scoropan <[email protected]>

* updated test + some nits

Signed-off-by: Paul Scoropan <[email protected]>

* reorder dependencies

Signed-off-by: Paul Scoropan <[email protected]>

* more clean up. TODOs left: docstrings, debug! and info! lines where missing

Signed-off-by: Paul Scoropan <[email protected]>

* one more formatting nit

Signed-off-by: Paul Scoropan <[email protected]>

* minor http client refactor and update detector endpoint handling

Signed-off-by: Paul Scoropan <[email protected]>

* update openai client

Signed-off-by: Paul Scoropan <[email protected]>

* removed eventsource implementation, connected http_body_utils stream impl to eventsource_stream eventsource impl

Signed-off-by: Paul Scoropan <[email protected]>

* simplified RequestLike, ResponseLike

Signed-off-by: Paul Scoropan <[email protected]>

* use &self instead of self for HttpClient get/post

Signed-off-by: Paul Scoropan <[email protected]>

* make sure error responses are handled properly and error handling improvements

Signed-off-by: Paul Scoropan <[email protected]>

* doc strings

Signed-off-by: Paul Scoropan <[email protected]>

* removed RequestLike and ResponseLike

Signed-off-by: Paul Scoropan <[email protected]>

* [image tested successfully on cluster] marginal improvement to debug log statement

Signed-off-by: Paul Scoropan <[email protected]>

* fix tokio spawn tracing (needs explicit instrumentation to avoid creating new trace) and add missing traceparent propagation from responses

Signed-off-by: Paul Scoropan <[email protected]>

* more async tracing instrumentation fixes

Signed-off-by: Paul Scoropan <[email protected]>

* more tracing logic fixes - traces are being split due to async instrumentation logic

Signed-off-by: Paul Scoropan <[email protected]>

* added incoming response debug print - in process of debugging traceparent injection issue

Signed-off-by: Paul Scoropan <[email protected]>

* attempt to fix traceparent injection issue - testing on cluster

Signed-off-by: Paul Scoropan <[email protected]>

* more debug testing

Signed-off-by: Paul Scoropan <[email protected]>

* nits & add RequestBody, ResponseBody

Signed-off-by: Paul Scoropan <[email protected]>

* simplified client error handling and added missing handling for chat generation clients

Signed-off-by: Paul Scoropan <[email protected]>

* simplify error handling

Signed-off-by: Paul Scoropan <[email protected]>

* small nits

Signed-off-by: Paul Scoropan <[email protected]>

* rebase formatting

Signed-off-by: Paul Scoropan <[email protected]>

* fix client connect and request timeouts

Signed-off-by: Paul Scoropan <[email protected]>

* fix variable misuse and revert error message changes

Signed-off-by: Paul Scoropan <[email protected]>

---------

Signed-off-by: Paul Scoropan <[email protected]>
  • Loading branch information
pscoro authored Nov 25, 2024
1 parent a4964b1 commit 79434f2
Show file tree
Hide file tree
Showing 26 changed files with 1,654 additions and 774 deletions.
674 changes: 520 additions & 154 deletions Cargo.lock

Large diffs are not rendered by default.

13 changes: 11 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,36 @@ async-stream = "0.3.5"
axum = { version = "0.7.5", features = ["json"] }
axum-extra = "0.9.3"
clap = { version = "4.5.15", features = ["derive", "env"] }
eventsource = "0.5.0"
eventsource-stream = "0.2.3"
futures = "0.3.30"
futures-core = "0.3.30"
futures-timer = "3.0.3"
ginepro = "0.8.1"
http-body-util = "0.1.2"
http-serde = "2.1.1"
hyper = { version = "1.4.1", features = ["http1", "http2", "server"] }
hyper-rustls = { version = "0.27.3", features = ["ring"]}
hyper-util = { version = "0.1.7", features = ["server-auto", "server-graceful", "tokio"] }
mime = "0.3.17"
mio = "1.0.2"
opentelemetry = { version = "0.24.0", features = ["trace", "metrics"] }
opentelemetry-http = { version = "0.13.0", features = ["reqwest"] }
opentelemetry-otlp = { version = "0.17.0", features = ["http-proto"] }
opentelemetry_sdk = { version = "0.24.1", features = ["rt-tokio", "metrics"] }
pin-project-lite = "0.2.15"
prost = "0.13.1"
reqwest = { version = "0.12.5", features = ["blocking", "rustls-tls", "json"] }
reqwest-eventsource = "0.6.0"
rustls = {version = "0.23.12", default-features = false, features = ["std"]}
rustls = {version = "0.23.12", default-features = false, features = ["std", "ring"]}
rustls-pemfile = "2.1.3"
rustls-webpki = "0.102.6"
serde = { version = "1.0.206", features = ["derive"] }
serde_json = "1.0.124"
serde_yml = "0.0.11"
thiserror = "1.0.63"
tokio = { version = "1.39.2", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "fs"] }
tokio-rustls = { version = "0.26.0" }
tokio-rustls = { version = "0.26.0", features = ["ring"]}
tokio-stream = { version = "0.1.15", features = ["sync"] }
tonic = { version = "0.12.1", features = ["tls", "tls-roots", "tls-webpki-roots"] }
tower-http = { version = "0.5.2", features = ["trace"] }
Expand All @@ -51,6 +59,7 @@ tracing-opentelemetry = "0.25.0"
tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
url = "2.5.2"
uuid = { version = "1.10.0", features = ["v4", "fast-rng"] }
hyper-timeout = "0.5.2"

[build-dependencies]
tonic-build = "0.12.1"
Expand Down
105 changes: 51 additions & 54 deletions src/clients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,17 @@ use async_trait::async_trait;
use axum::http::{Extensions, HeaderMap};
use futures::Stream;
use ginepro::LoadBalancedChannel;
use tokio::{fs::File, io::AsyncReadExt};
use hyper_timeout::TimeoutConnector;
use hyper_util::rt::TokioExecutor;
use tonic::{metadata::MetadataMap, Request};
use tracing::{debug, instrument};
use tracing::{debug, instrument, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;
use url::Url;

use crate::{
config::{ServiceConfig, Tls},
health::HealthCheckResult,
tracing_utils::with_traceparent_header,
utils::{tls, trace::with_traceparent_header},
};

pub mod errors;
Expand All @@ -60,7 +62,7 @@ pub use generation::GenerationClient;

pub mod openai;

const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(60);
const DEFAULT_CONNECT_TIMEOUT_SEC: u64 = 60;
const DEFAULT_REQUEST_TIMEOUT_SEC: u64 = 600;

pub type BoxStream<T> = Pin<Box<dyn Stream<Item = T> + Send>>;
Expand Down Expand Up @@ -198,7 +200,10 @@ impl ClientMap {
}

#[instrument(skip_all, fields(hostname = service_config.hostname))]
pub async fn create_http_client(default_port: u16, service_config: &ServiceConfig) -> HttpClient {
pub async fn create_http_client(
default_port: u16,
service_config: &ServiceConfig,
) -> Result<HttpClient, Error> {
let port = service_config.port.unwrap_or(default_port);
let protocol = match service_config.tls {
Some(_) => "https",
Expand All @@ -210,53 +215,36 @@ pub async fn create_http_client(default_port: u16, service_config: &ServiceConfi
.set_port(Some(port))
.unwrap_or_else(|_| panic!("error setting port: {}", port));
debug!(%base_url, "creating HTTP client");

let connect_timeout = Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SEC);
let request_timeout = Duration::from_secs(
service_config
.request_timeout
.unwrap_or(DEFAULT_REQUEST_TIMEOUT_SEC),
);
let mut builder = reqwest::ClientBuilder::new()
.connect_timeout(DEFAULT_CONNECT_TIMEOUT)
.timeout(request_timeout);
if let Some(Tls::Config(tls_config)) = &service_config.tls {
let mut cert_buf = Vec::new();
let cert_path = tls_config.cert_path.as_ref().unwrap().as_path();
File::open(cert_path)
.await
.unwrap_or_else(|error| panic!("error reading cert from {cert_path:?}: {error}"))
.read_to_end(&mut cert_buf)
.await
.unwrap();

if let Some(key_path) = &tls_config.key_path {
File::open(key_path)
.await
.unwrap_or_else(|error| panic!("error reading key from {key_path:?}: {error}"))
.read_to_end(&mut cert_buf)
.await
.unwrap();
}
let identity = reqwest::Identity::from_pem(&cert_buf)
.unwrap_or_else(|error| panic!("error parsing bundled client certificate: {error}"));

builder = builder.use_rustls_tls().identity(identity);
builder = builder.danger_accept_invalid_certs(tls_config.insecure.unwrap_or(false));

if let Some(client_ca_cert_path) = &tls_config.client_ca_cert_path {
let ca_cert = tokio::fs::read(client_ca_cert_path)
let https_conn_builder = match &service_config.tls {
Some(Tls::Config(tls)) => hyper_rustls::HttpsConnectorBuilder::new().with_tls_config(
tls::build_client_config(tls)
.await
.unwrap_or_else(|error| {
panic!("error reading cert from {client_ca_cert_path:?}: {error}")
});
let cacert = reqwest::Certificate::from_pem(&ca_cert)
.unwrap_or_else(|error| panic!("error parsing ca cert: {error}"));
builder = builder.add_root_certificate(cacert)
}
}
let client = builder
.build()
.unwrap_or_else(|error| panic!("error creating http client: {error}"));
HttpClient::new(base_url, client)
.map_err(|e| e.into_client_error())?,
),
Some(_) => panic!("unexpected unresolved TLS in client builder"),
None => hyper_rustls::HttpsConnectorBuilder::new()
.with_tls_config(tls::build_insecure_client_config()),
};
let https_conn = https_conn_builder
.https_or_http()
.enable_http1()
.enable_http2()
.build();

let mut timeout_conn = TimeoutConnector::new(https_conn);
timeout_conn.set_connect_timeout(Some(connect_timeout));

let client =
hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(timeout_conn);
Ok(HttpClient::new(base_url, request_timeout, client))
}

#[instrument(skip_all, fields(hostname = service_config.hostname))]
Expand All @@ -273,13 +261,14 @@ pub async fn create_grpc_client<C>(
let mut base_url = Url::parse(&format!("{}://{}", protocol, &service_config.hostname)).unwrap();
base_url.set_port(Some(port)).unwrap();
debug!(%base_url, "creating gRPC client");
let connect_timeout = Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SEC);
let request_timeout = Duration::from_secs(
service_config
.request_timeout
.unwrap_or(DEFAULT_REQUEST_TIMEOUT_SEC),
);
let mut builder = LoadBalancedChannel::builder((service_config.hostname.clone(), port))
.connect_timeout(DEFAULT_CONNECT_TIMEOUT)
.connect_timeout(connect_timeout)
.timeout(request_timeout);

let client_tls_config = if let Some(Tls::Config(tls_config)) = &service_config.tls {
Expand Down Expand Up @@ -349,31 +338,39 @@ pub fn is_valid_hostname(hostname: &str) -> bool {
/// Turns a gRPC client request body of type `T` and header map into a `tonic::Request<T>`.
/// Will also inject the current `traceparent` header into the request based on the current span.
fn grpc_request_with_headers<T>(request: T, headers: HeaderMap) -> Request<T> {
let headers = with_traceparent_header(headers);
let ctx = Span::current().context();
let headers = with_traceparent_header(&ctx, headers);
let metadata = MetadataMap::from_headers(headers);
Request::from_parts(metadata, Extensions::new(), request)
}

#[cfg(test)]
mod tests {
use errors::grpc_to_http_code;
use http_body_util::BodyExt;
use hyper::{http, StatusCode};
use reqwest::Response;

use super::*;
use crate::{
clients::http::Response,
health::{HealthCheckResult, HealthStatus},
pb::grpc::health::v1::{health_check_response::ServingStatus, HealthCheckResponse},
};

async fn mock_http_response(
status: StatusCode,
body: &str,
) -> Result<Response, reqwest::Error> {
Ok(reqwest::Response::from(
async fn mock_http_response(status: StatusCode, body: &str) -> Result<Response, Error> {
Ok(Response(
http::Response::builder()
.status(status)
.body(body.to_string())
.body(
body.to_string()
.map_err(|e| {
panic!(
"infallible error parsing string body in test response: {}",
e
)
})
.boxed(),
)
.unwrap(),
))
}
Expand Down
2 changes: 1 addition & 1 deletion src/clients/chunker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use crate::{
caikit_data_model::nlp::{ChunkerTokenizationStreamResult, Token, TokenizationResults},
grpc::health::v1::{health_client::HealthClient, HealthCheckRequest},
},
tracing_utils::trace_context_from_grpc_response,
utils::trace::trace_context_from_grpc_response,
};

const DEFAULT_PORT: u16 = 8085;
Expand Down
84 changes: 58 additions & 26 deletions src/clients/detector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@ use std::fmt::Debug;

use axum::http::HeaderMap;
use hyper::StatusCode;
use reqwest::Response;
use serde::{Deserialize, Serialize};
use tracing::info;
use serde::Deserialize;
use tracing::instrument;
use url::Url;

use super::{
http::{HttpClientExt, RequestBody, ResponseBody},
Error,
};

pub mod text_contents;
pub use text_contents::*;
pub mod text_chat;
Expand All @@ -33,9 +37,6 @@ pub use text_context_doc::*;
pub mod text_generation;
pub use text_generation::*;

use super::{Error, HttpClient};
use crate::tracing_utils::{trace_context_from_http_response, with_traceparent_header};

const DEFAULT_PORT: u16 = 8080;
const DETECTOR_ID_HEADER_NAME: &str = "detector-id";

Expand All @@ -54,24 +55,55 @@ impl From<DetectorError> for Error {
}
}

/// Make a POST request for an HTTP detector client and return the response.
/// Also injects the `traceparent` header from the current span and traces the response.
pub async fn post_with_headers<T: Debug + Serialize>(
client: HttpClient,
url: Url,
request: T,
headers: HeaderMap,
model_id: &str,
) -> Result<Response, Error> {
let headers = with_traceparent_header(headers);
info!(?url, ?headers, ?request, "sending client request");
let response = client
.post(url)
.headers(headers)
.header(DETECTOR_ID_HEADER_NAME, model_id)
.json(&request)
.send()
.await?;
trace_context_from_http_response(&response);
Ok(response)
/// Marker trait that should be implemented by all detector clients.
///
/// It declares no methods of its own. A type that implements both this trait and
/// `HttpClientExt` (currently all detector clients are HTTP) is implicitly extended with
/// `DetectorClientExt`, which provides an HTTP detector specific post function.
pub trait DetectorClient {}

/// Provides a helper extension for HTTP detector clients.
///
/// A blanket implementation in this module covers every type that is both a
/// `DetectorClient` and an `HttpClientExt`.
pub trait DetectorClientExt: HttpClientExt {
/// Wraps the post function with extra detector functionality
/// (detector id header injection & error handling).
///
/// * `model_id` - detector id sent in the `detector-id` request header.
/// * `url` - URL the request is posted to.
/// * `headers` - headers to send along with the detector id header.
/// * `request` - request body.
///
/// In the blanket implementation a `200 OK` response is deserialized into `U`; any
/// other status is turned into an `Error` built from the detector's error body, or
/// from the bare status code when that body cannot be parsed.
async fn post_to_detector<U: ResponseBody>(
&self,
model_id: &str,
url: Url,
headers: HeaderMap,
request: impl RequestBody,
) -> Result<U, Error>;

/// Wraps call to inner HTTP client endpoint function, returning the `Url` for `path`.
fn endpoint(&self, path: &str) -> Url;
}

impl<C: DetectorClient + HttpClientExt> DetectorClientExt for C {
#[instrument(skip_all, fields(model_id, url))]
async fn post_to_detector<U: ResponseBody>(
&self,
model_id: &str,
url: Url,
headers: HeaderMap,
request: impl RequestBody,
) -> Result<U, Error> {
let mut headers = headers;
headers.append(DETECTOR_ID_HEADER_NAME, model_id.parse().unwrap());
let response = self.inner().post(url, headers, request).await?;

let status = response.status();
match status {
StatusCode::OK => Ok(response.json().await?),
_ => Err(response
.json::<DetectorError>()
.await
.unwrap_or(DetectorError {
code: status.as_u16(),
message: "".into(),
})
.into()),
}
}

fn endpoint(&self, path: &str) -> Url {
self.inner().endpoint(path)
}
}
Loading

0 comments on commit 79434f2

Please sign in to comment.