Skip to content

Commit

Permalink
Merge pull request #78 from alanbraz/detector-mtls
Browse files Browse the repository at this point in the history
tests added and enable self signed certs #69
  • Loading branch information
gkumbhat authored Jun 18, 2024
2 parents 4ca76b6 + 262574f commit ab1bc23
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 5 deletions.
10 changes: 9 additions & 1 deletion config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,19 @@ detectors:
service:
hostname: https://localhost/api/v1/text/contents # full url / endpoint currently expected
port: 8080
tls: caikit
tls: detector
chunker_id: en_regex
default_threshold: 0.5
tls:
caikit:
cert_path: /path/to/tls.crt
key_path: /path/to/tls.key
client_ca_cert_path: /path/to/ca.crt
detector:
cert_path: /path/to/tls.crt
key_path: /path/to/tls.key
client_ca_cert_path: /path/to/ca.crt
insecure: false
detector_bundle_no_ca:
cert_path: /path/to/client-bundle.pem
insecure: true
46 changes: 42 additions & 4 deletions src/clients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@ use std::{collections::HashMap, time::Duration};
use futures::future::join_all;
use ginepro::LoadBalancedChannel;
use reqwest::StatusCode;
use tracing::debug;
use url::Url;

use tokio::fs::File;
use tokio::io::AsyncReadExt;

use crate::config::{ServiceConfig, Tls};

pub mod chunker;
Expand Down Expand Up @@ -111,13 +115,47 @@ pub async fn create_http_clients(
.connect_timeout(DEFAULT_CONNECT_TIMEOUT)
.timeout(DEFAULT_REQUEST_TIMEOUT);
if let Some(Tls::Config(tls_config)) = &service_config.tls {
let mut cert_buf = Vec::new();
let cert_path = tls_config.cert_path.as_ref().unwrap().as_path();
let cert_pem = tokio::fs::read(cert_path).await.unwrap_or_else(|error| {
panic!("error reading cert from {cert_path:?}: {error}")
File::open(cert_path)
.await
.unwrap_or_else(|error| {
panic!("error reading cert from {cert_path:?}: {error}")
})
.read_to_end(&mut cert_buf)
.await
.unwrap();

if let Some(key_path) = &tls_config.key_path {
File::open(key_path)
.await
.unwrap_or_else(|error| {
panic!("error reading key from {key_path:?}: {error}")
})
.read_to_end(&mut cert_buf)
.await
.unwrap();
}
let identity = reqwest::Identity::from_pem(&cert_buf).unwrap_or_else(|error| {
panic!("error parsing bundled client certificate: {error}")
});
let identity = reqwest::Identity::from_pem(&cert_pem)
.unwrap_or_else(|error| panic!("error parsing cert: {error}"));

builder = builder.use_rustls_tls().identity(identity);

debug!(?tls_config.insecure);
builder = builder.danger_accept_invalid_certs(tls_config.insecure.unwrap_or(false));

if let Some(client_ca_cert_path) = &tls_config.client_ca_cert_path {
let ca_cert =
tokio::fs::read(client_ca_cert_path)
.await
.unwrap_or_else(|error| {
panic!("error reading cert from {client_ca_cert_path:?}: {error}")
});
let cacert = reqwest::Certificate::from_pem(&ca_cert)
.unwrap_or_else(|error| panic!("error parsing ca cert: {error}"));
builder = builder.add_root_certificate(cacert)
}
}
let client = builder
.build()
Expand Down
82 changes: 82 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub struct TlsConfig {
pub cert_path: Option<PathBuf>,
pub key_path: Option<PathBuf>,
pub client_ca_cert_path: Option<PathBuf>,
pub insecure: Option<bool>,
}

/// Generation service provider
Expand Down Expand Up @@ -202,4 +203,85 @@ tls: {}
assert!(config.chunkers.len() == 2 && config.detectors.len() == 1);
Ok(())
}

#[test]
fn test_deserialize_config_detector_tls_signed() -> Result<(), Error> {
let s = r#"
generation:
provider: tgis
service:
hostname: localhost
port: 8000
chunkers:
sentence-en:
type: sentence
service:
hostname: localhost
port: 9000
sentence-ja:
type: sentence
service:
hostname: localhost
port: 9000
detectors:
hap:
service:
hostname: localhost
port: 9000
tls: detector
chunker_id: sentence-en
default_threshold: 0.5
tls:
detector:
cert_path: /certs/client.pem
"#;
let config: OrchestratorConfig = serde_yml::from_str(s)?;
assert!(config.chunkers.len() == 2 && config.detectors.len() == 1);
assert!(config.tls.len() == 1 && config.tls.contains_key("detector"));
Ok(())
}

#[test]
fn test_deserialize_config_detector_tls_insecure() -> Result<(), Error> {
let s = r#"
generation:
provider: tgis
service:
hostname: localhost
port: 8000
chunkers:
sentence-en:
type: sentence
service:
hostname: localhost
port: 9000
sentence-ja:
type: sentence
service:
hostname: localhost
port: 9000
detectors:
hap:
service:
hostname: localhost
port: 9000
tls: detector
chunker_id: sentence-en
default_threshold: 0.5
tls:
detector:
client_ca_cert_path: /certs/ca.pem
cert_path: /certs/client.pem
key_path: /certs/client-key.pem
insecure: true
"#;
let config: OrchestratorConfig = serde_yml::from_str(s)?;
assert!(config.chunkers.len() == 2 && config.detectors.len() == 1);
assert!(
config.tls.len() == 1 && config.tls.get("detector").unwrap().insecure == Some(true)
);
Ok(())
}
}

//

0 comments on commit ab1bc23

Please sign in to comment.