Skip to content

Commit

Permalink
added error messages to various unwraps (foundation-model-stack#243)
Browse files Browse the repository at this point in the history
* added error messages to various unwraps

Signed-off-by: resoluteCoder <[email protected]>

* added detector id vars to expect

Signed-off-by: resoluteCoder <[email protected]>

* changed expects to unwrap or else due to lint

Signed-off-by: resoluteCoder <[email protected]>

---------

Signed-off-by: resoluteCoder <[email protected]>
  • Loading branch information
resoluteCoder authored Nov 4, 2024
1 parent 521c80f commit f2010e1
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 13 deletions.
7 changes: 5 additions & 2 deletions src/clients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,11 @@ pub async fn create_http_client(default_port: u16, service_config: &ServiceConfi
Some(_) => "https",
None => "http",
};
let mut base_url = Url::parse(&format!("{}://{}", protocol, &service_config.hostname)).unwrap();
base_url.set_port(Some(port)).unwrap();
let mut base_url = Url::parse(&format!("{}://{}", protocol, &service_config.hostname))
.unwrap_or_else(|e| panic!("error parsing base url: {}", e));
base_url
.set_port(Some(port))
.unwrap_or_else(|_| panic!("error setting port: {}", port));
debug!(%base_url, "creating HTTP client");
let request_timeout = Duration::from_secs(
service_config
Expand Down
13 changes: 10 additions & 3 deletions src/orchestrator/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,18 @@ async fn streaming_output_detection_task(
// Create a mutable copy of the parameters, so that we can modify it based on processing
let mut detector_params = detector_params.clone();
let detector_id = detector_id.to_string();
let chunker_id = ctx.config.get_chunker_id(&detector_id).unwrap();
let chunker_id = ctx
.config
.get_chunker_id(&detector_id)
.expect("chunker id is not found");

// Get the detector config
// TODO: Add error handling
let detector_config = ctx.config.detectors.get(&detector_id).unwrap();
let detector_config = ctx
.config
.detectors
.get(&detector_id)
.expect("detector config not found");

// Get the default threshold to use if threshold is not provided by the user
let default_threshold = detector_config.default_threshold;
Expand Down Expand Up @@ -394,7 +401,7 @@ async fn detection_task(
let client = ctx
.clients
.get_as::<TextContentsDetectorClient>(&detector_id)
.unwrap();
.unwrap_or_else(|| panic!("text contents detector client not found for {}", detector_id));
match client.text_contents(&detector_id, request, headers)
.await
.map_err(|error| Error::DetectorRequestFailed { id: detector_id.clone(), error }) {
Expand Down
4 changes: 3 additions & 1 deletion src/orchestrator/streaming/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,9 @@ impl ResultActor {
result.token_classification_results.output = Some(detections);
if input_start_index == 0 {
// Get input_token_count and seed from first generation message
let first = generations.first().unwrap();
let first = generations
.first()
.expect("first element in classified generated text stream result not found");
result.input_token_count = first.input_token_count;
result.seed = first.seed;
// Get input_tokens from second generation message (if specified)
Expand Down
24 changes: 20 additions & 4 deletions src/orchestrator/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,13 +270,19 @@ impl Orchestrator {
let ctx = ctx.clone();
let detector_id = detector_id.clone();
let detector_params = detector_params.clone();
let detector_config = ctx.config.detectors.get(&detector_id).unwrap();
let detector_config =
ctx.config.detectors.get(&detector_id).unwrap_or_else(|| {
panic!("detector config not found for {}", detector_id)
});

let chunker_id = detector_config.chunker_id.as_str();

let default_threshold = detector_config.default_threshold;

let chunk = chunks.get(chunker_id).unwrap().clone();
let chunk = chunks
.get(chunker_id)
.unwrap_or_else(|| panic!("chunk not found for {}", chunker_id))
.clone();

let headers = headers.clone();

Expand Down Expand Up @@ -754,7 +760,12 @@ pub async fn detect_for_generation(
let client = ctx
.clients
.get_as::<TextGenerationDetectorClient>(&detector_id)
.unwrap();
.unwrap_or_else(|| {
panic!(
"text generation detector client not found for {}",
detector_id
)
});
let response = client
.text_generation(&detector_id, request, headers)
.await
Expand Down Expand Up @@ -845,7 +856,12 @@ pub async fn detect_for_context(
let client = ctx
.clients
.get_as::<TextContextDocDetectorClient>(&detector_id)
.unwrap();
.unwrap_or_else(|| {
panic!(
"text context doc detector client not found for {}",
detector_id
)
});
let response = client
.text_context_doc(&detector_id, request, headers)
.await
Expand Down
14 changes: 11 additions & 3 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,23 @@ pub async fn run(
// Configure mTLS if client CA is provided
let client_auth = if tls_client_ca_cert_path.is_some() {
info!("Configuring TLS trust certificate (mTLS) for incoming connections");
let client_certs = load_certs(tls_client_ca_cert_path.as_ref().unwrap());
let client_certs = load_certs(
tls_client_ca_cert_path
.as_ref()
.expect("error loading certs for mTLS"),
);
let mut client_auth_certs = RootCertStore::empty();
for client_cert in client_certs {
// Should be only one
client_auth_certs.add(client_cert).unwrap();
client_auth_certs
.add(client_cert.clone())
.unwrap_or_else(|e| {
panic!("error adding client cert {:?}: {}", client_cert, e)
});
}
WebPkiClientVerifier::builder(client_auth_certs.into())
.build()
.unwrap()
.unwrap_or_else(|e| panic!("error building client verifier: {}", e))
} else {
WebPkiClientVerifier::no_client_auth()
};
Expand Down

0 comments on commit f2010e1

Please sign in to comment.