Skip to content

Commit

Permalink
[#260] Support HTTPS local with rustls
Browse files Browse the repository at this point in the history
  • Loading branch information
zonyitoo committed May 31, 2020
1 parent 5a1d9ef commit 07307fc
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 4 deletions.
19 changes: 19 additions & 0 deletions src/bin/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,14 @@ fn main() {
);
}

#[cfg(feature = "local-http-rustls")]
{
app = clap_app!(@app (app)
(@arg TLS_IDENTITY_CERT_PATH: --("tls-identity-certificate") +takes_value required_if("PROTOCOL", "https") requires[TLS_IDENTITY_PRIVATE_KEY_PATH] "TLS identity certificate (PEM) path for HTTPS server")
(@arg TLS_IDENTITY_PRIVATE_KEY_PATH: --("tls-identity-private-key") +takes_value required_if("PROTOCOL", "https") requires[TLS_IDENTITY_CERT_PATH] "TLS identity private key (PEM), PKCS #8 or RSA syntax, for HTTPS server")
);
}

let matches = app.get_matches();
drop(available_ciphers);

Expand Down Expand Up @@ -298,6 +306,17 @@ fn main() {
}
}

#[cfg(feature = "local-http-rustls")]
{
if let Some(cpath) = matches.value_of("TLS_IDENTITY_CERT_PATH") {
config.tls_identity_certificate_path = Some(cpath.into());
}

if let Some(kpath) = matches.value_of("TLS_IDENTITY_PRIVATE_KEY_PATH") {
config.tls_identity_private_key_path = Some(kpath.into());
}
}

// DONE READING options

if config.local_addr.is_none() {
Expand Down
13 changes: 12 additions & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -977,11 +977,18 @@ pub struct Config {
///
/// Set to `true` if you want to query IPv6 addresses before IPv4
pub ipv6_first: bool,
/// TLS cryptographic identify (X509)
/// TLS cryptographic identity (X509), PKCS #12 format
#[cfg(feature = "local-http-native-tls")]
pub tls_identity_path: Option<PathBuf>,
/// TLS cryptographic identity's password
#[cfg(feature = "local-http-native-tls")]
pub tls_identity_password: Option<String>,
/// TLS cryptographic identity, certificate file path (PEM)
#[cfg(feature = "local-http-rustls")]
pub tls_identity_certificate_path: Option<PathBuf>,
/// TLS cryptographic identity, private keys (PEM), RSA or PKCS #8
#[cfg(feature = "local-http-rustls")]
pub tls_identity_private_key_path: Option<PathBuf>,
}

/// Configuration parsing error kind
Expand Down Expand Up @@ -1076,6 +1083,10 @@ impl Config {
tls_identity_path: None,
#[cfg(feature = "local-http-native-tls")]
tls_identity_password: None,
#[cfg(feature = "local-http-rustls")]
tls_identity_certificate_path: None,
#[cfg(feature = "local-http-rustls")]
tls_identity_private_key_path: None,
}
}

Expand Down
8 changes: 5 additions & 3 deletions src/relay/tcprelay/http_tls/native_tls.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
//! TLS support implementation by native-tls
//! TLS support by [native-tls](https://crates.io/crates/native-tls)
use std::{
fs::File,
future::Future,
io,
io::Read,
io::{self, Read},
net::SocketAddr,
pin::Pin,
sync::Arc,
Expand All @@ -16,6 +15,7 @@ use hyper::server::{
accept::Accept,
conn::{AddrIncoming, AddrStream},
};
use log::trace;
use native_tls::Identity;
use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite};
Expand All @@ -34,6 +34,8 @@ impl TlsAcceptor {
let id_path = config.tls_identity_path.as_ref().expect("identity path");
let id_pwd = config.tls_identity_password.as_ref().expect("identify password");

trace!("creating TLS acceptor with identity: {}", id_path.display());

let mut id_file = File::open(id_path)?;
let mut id_buf = Vec::new();
id_file.read_to_end(&mut id_buf)?;
Expand Down
202 changes: 202 additions & 0 deletions src/relay/tcprelay/http_tls/rustls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
//! TLS support by [rustls](https://crates.io/crates/rustls)
use std::{
fs::File,
future::Future,
io::{self, BufReader, Read},
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{self, Poll},
};

use futures::ready;
use hyper::server::{
accept::Accept,
conn::{AddrIncoming, AddrStream},
};
use log::trace;
use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::rustls::{self, NoClientAuth, PrivateKey, ServerConfig};

use crate::config::Config;

#[pin_project]
pub struct TlsAcceptor {
acceptor: tokio_rustls::TlsAcceptor,
#[pin]
incoming: AddrIncoming,
}

impl TlsAcceptor {
pub fn bind(config: &Config, addr: &SocketAddr) -> io::Result<TlsAcceptor> {
let id_cert_path = config.tls_identity_certificate_path.as_ref().expect("certificate path");
let id_key_path = config.tls_identity_private_key_path.as_ref().expect("private key path");

trace!(
"creating TLS acceptor with cert: {}, private key: {}",
id_cert_path.display(),
id_key_path.display()
);

let id_cert_file = File::open(id_cert_path)?;
let id_cert = match rustls::internal::pemfile::certs(&mut BufReader::new(id_cert_file)) {
Ok(certs) => certs,
Err(..) => {
let err = io::Error::new(io::ErrorKind::InvalidData, "error while loading certificates");
return Err(err);
}
};

let mut id_key_file = File::open(id_key_path)?;
let mut id_key_buf = Vec::new();
id_key_file.read_to_end(&mut id_key_buf)?;

let mut id_key = TlsAcceptor::load_pkcs8_private_key(&id_key_buf)?;
if id_key.is_empty() {
id_key = TlsAcceptor::load_rsa_private_key(&id_key_buf)?;
}

if id_key.is_empty() {
let err = io::Error::new(
io::ErrorKind::InvalidInput,
"cannot find any PKCS #8 or RSA private keys",
);
return Err(err);
}

let mut config = ServerConfig::new(NoClientAuth::new());
if let Err(err) = config.set_single_cert(id_cert, id_key.remove(0)) {
let err = io::Error::new(io::ErrorKind::Other, format!("setting certificate: {}", err));
return Err(err);
}
config.set_protocols(&["h2".into(), "http/1.1".into()]);

let server_config = Arc::new(config);

Ok(TlsAcceptor {
acceptor: From::from(server_config),
incoming: match AddrIncoming::bind(addr) {
Ok(incoming) => incoming,
Err(err) => {
let err = io::Error::new(io::ErrorKind::Other, format!("hyper bind: {}", err));
return Err(err);
}
},
})
}

fn load_pkcs8_private_key(key: &[u8]) -> io::Result<Vec<PrivateKey>> {
match rustls::internal::pemfile::pkcs8_private_keys(&mut BufReader::new(key)) {
Ok(pk) => Ok(pk),
Err(..) => {
let err = io::Error::new(io::ErrorKind::InvalidData, "error while loading PKCS #8 private keys");
Err(err)
}
}
}

fn load_rsa_private_key(key: &[u8]) -> io::Result<Vec<PrivateKey>> {
match rustls::internal::pemfile::rsa_private_keys(&mut BufReader::new(key)) {
Ok(pk) => Ok(pk),
Err(..) => {
let err = io::Error::new(io::ErrorKind::InvalidData, "error while loading PKCS #8 private keys");
Err(err)
}
}
}

pub fn local_addr(&self) -> SocketAddr {
self.incoming.local_addr()
}
}

impl Accept for TlsAcceptor {
type Conn = TlsStream;
type Error = io::Error;

fn poll_accept(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let this = self.project();
match ready!(this.incoming.poll_accept(cx)) {
Some(Ok(stream)) => {
let remote_addr = stream.remote_addr();
Poll::Ready(Some(Ok(TlsStream {
state: TlsStreamState::Handshaking(this.acceptor.accept(stream)),
remote_addr,
})))
}
Some(Err(e)) => Poll::Ready(Some(Err(e))),
None => Poll::Ready(None),
}
}
}

enum TlsStreamState {
Handshaking(tokio_rustls::Accept<AddrStream>),
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
}

pub struct TlsStream {
state: TlsStreamState,
remote_addr: SocketAddr,
}

impl TlsStream {
pub fn remote_addr(&self) -> SocketAddr {
self.remote_addr
}
}

macro_rules! forward_stream_method {
($self:expr, $cx:expr, $method:ident $(, $param:expr)*) => {{
let this = $self.get_mut();

loop {
match this.state {
TlsStreamState::Handshaking(ref mut accept_fut) => {
match ready!(Pin::new(accept_fut).poll($cx)) {
Ok(stream) => {
this.state = TlsStreamState::Streaming(stream);
}
Err(err) => {
let err = io::Error::new(io::ErrorKind::Other, format!("tls handshake: {}", err));
return Poll::Ready(Err(err));
}
}
}
TlsStreamState::Streaming(ref mut stream) => {
return Pin::new(stream).$method($cx, $($param),*);
}
}
}
}};
}

impl AsyncRead for TlsStream {
fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
forward_stream_method!(self, cx, poll_read, buf)
}
}

impl AsyncWrite for TlsStream {
fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
forward_stream_method!(self, cx, poll_write, buf)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
match this.state {
TlsStreamState::Handshaking(..) => Poll::Ready(Ok(())),
TlsStreamState::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
}
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
match this.state {
TlsStreamState::Handshaking(..) => Poll::Ready(Ok(())),
TlsStreamState::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}

0 comments on commit 07307fc

Please sign in to comment.