Skip to content

Commit

Permalink
refactor: memcached::binary::Connection
Browse files Browse the repository at this point in the history
  • Loading branch information
wangrui committed Jan 3, 2025
1 parent c215fe0 commit ee5ce02
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 159 deletions.
2 changes: 1 addition & 1 deletion core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ rust-version = "1.75"
version = "0.50.1"

[features]
default = ["reqwest/rustls-tls", "executors-tokio", "services-memory" ,"services-memcached"]
default = ["reqwest/rustls-tls", "executors-tokio", "services-memory"]

# Build test utils or not.
#
Expand Down
57 changes: 21 additions & 36 deletions core/src/services/memcached/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use std::sync::Arc;
use std::time::Duration;

use super::binary::{self, Conn};
use super::binary;
use crate::raw::adapters::kv;
use crate::raw::*;
use crate::services::MemcachedConfig;
Expand Down Expand Up @@ -149,7 +149,7 @@ impl Builder for MemcachedBuilder {
);
};
if self.config.enable_tls {
rustls::crypto::aws_lc_rs::default_provider()
rustls::crypto::ring::default_provider()
.install_default()
.map_err(|_err| {
Error::new(
Expand Down Expand Up @@ -325,40 +325,37 @@ impl bb8::ManageConnection for MemcacheConnectionManager {

/// TODO: Implement unix stream support.
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
let conn = if self.enable_tls {
let mut conn = if self.enable_tls {
let mut root_cert_store = rustls::RootCertStore::empty();

let native_certs = rustls_native_certs::load_native_certs();
if native_certs.errors.is_empty() {
for cert in rustls_native_certs::load_native_certs().expect("unreachable!") {
root_cert_store.add(cert).map_err(|err| {
Error::new(ErrorKind::Unexpected, "tls connect failed").set_source(err)
})?;
}
let native_certs = rustls_native_certs::load_native_certs()
.expect("errors occurred while loading certificates");
for cert in native_certs {
root_cert_store.add(cert).map_err(|err| {
Error::new(ErrorKind::Unexpected, "load cafile failed").set_source(err)
})?;
}

let tls_config =
rustls::ClientConfig::builder().with_root_certificates(root_cert_store);
let config = if let (Some(cert_path), Some(key_path)) = (&self.tls_cert, &self.tls_key)
{
let cert_chain = CertificateDer::pem_file_iter(cert_path)
.map_err(|err| {
Error::new(ErrorKind::Unexpected, "tls connect failed").set_source(err)
Error::new(ErrorKind::Unexpected, "load tls cert failed").set_source(err)
})?
.filter_map(Result::ok)
.collect();
let key_der = PrivateKeyDer::from_pem_file(key_path).map_err(|err| {
Error::new(ErrorKind::Unexpected, "tls connect failed").set_source(err)
Error::new(ErrorKind::Unexpected, "load tls key failed").set_source(err)
})?;

rustls::ClientConfig::builder()
.with_root_certificates(root_cert_store)
tls_config
.with_client_auth_cert(cert_chain, key_der)
.map_err(|err| {
Error::new(ErrorKind::Unexpected, "tls connect failed").set_source(err)
Error::new(ErrorKind::Unexpected, "build tls client failed").set_source(err)
})?
} else {
rustls::ClientConfig::builder()
.with_root_certificates(root_cert_store)
.with_no_client_auth()
tls_config.with_no_client_auth()
};

let connector = TlsConnector::from(Arc::new(config));
Expand All @@ -381,29 +378,17 @@ impl bb8::ManageConnection for MemcacheConnectionManager {
Error::new(ErrorKind::Unexpected, "tls connect failed").set_source(err)
})?;

let mut conn = binary::TlsConnection::new(conn);

if let (Some(username), Some(password)) =
(self.username.as_ref(), self.password.as_ref())
{
conn.auth(username, password).await?;
}
binary::Connection::Tls(conn)
binary::Connection::new(Box::new(conn))
} else {
let conn = TcpStream::connect(&self.address)
.await
.map_err(new_std_io_error)?;

let mut conn = binary::TcpConnection::new(conn);

if let (Some(username), Some(password)) =
(self.username.as_ref(), self.password.as_ref())
{
conn.auth(username, password).await?;
}

binary::Connection::Tcp(conn)
binary::Connection::new(Box::new(conn))
};
if let (Some(username), Some(password)) = (self.username.as_ref(), self.password.as_ref()) {
conn.auth(username, password).await?;
}
Ok(conn)
}

Expand Down
192 changes: 70 additions & 122 deletions core/src/services/memcached/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,135 +95,28 @@ impl PacketHeader {
}
}

pub enum Connection {
Tls(TlsConnection),
Tcp(TcpConnection),
}

impl Connection {
pub async fn version(&mut self) -> Result<String> {
match self {
Self::Tls(conn) => conn.version().await,
Self::Tcp(conn) => conn.version().await,
}
}
pub async fn get(&mut self, key: &str) -> Result<Option<Vec<u8>>> {
match self {
Self::Tls(conn) => conn.get(key).await,
Self::Tcp(conn) => conn.get(key).await,
}
}

pub async fn set(&mut self, key: &str, val: &[u8], expiration: u32) -> Result<()> {
match self {
Self::Tls(conn) => conn.set(key, val, expiration).await,
Self::Tcp(conn) => conn.set(key, val, expiration).await,
}
}
pub async fn delete(&mut self, key: &str) -> Result<()> {
match self {
Self::Tls(conn) => conn.delete(key).await,
Self::Tcp(conn) => conn.delete(key).await,
}
}
}

pub struct Response {
header: PacketHeader,
_key: Vec<u8>,
_extras: Vec<u8>,
value: Vec<u8>,
}

pub struct TlsConnection {
io: BufReader<TlsStream<TcpStream>>,
pub struct Connection {
io: BufReader<Box<dyn Connect>>,
}

impl TlsConnection {
pub fn new(io: TlsStream<TcpStream>) -> Self {
Self {
io: BufReader::new(io),
}
}
}

impl Conn for TlsConnection {
type T = TlsStream<TcpStream>;
fn get_conn(&mut self) -> &mut Self::T {
self.io.get_mut()
}
}
pub struct TcpConnection {
io: BufReader<TcpStream>,
}

impl TcpConnection {
pub fn new(io: TcpStream) -> Self {
impl Connection {
pub fn new(io: Box<dyn Connect>) -> Self {
Self {
io: BufReader::new(io),
}
}
}

impl Conn for TcpConnection {
type T = TcpStream;
fn get_conn(&mut self) -> &mut Self::T {
pub fn get_mut(&mut self) -> &mut Box<dyn Connect> {
self.io.get_mut()
}
}

pub async fn parse_response<T: AsyncWriteExt + std::marker::Unpin + tokio::io::AsyncRead>(
reader: &mut T,
) -> Result<Response> {
let header = PacketHeader::read::<T>(reader)
.await
.map_err(new_std_io_error)?;

if header.vbucket_id_or_status != constants::OK_STATUS
&& header.vbucket_id_or_status != constants::KEY_NOT_FOUND
{
return Err(
Error::new(ErrorKind::Unexpected, "unexpected status received")
.with_context("message", format!("{}", header.vbucket_id_or_status)),
);
}

let mut extras = vec![0x0; header.extras_length as usize];
reader
.read_exact(extras.as_mut_slice())
.await
.map_err(new_std_io_error)?;

let mut key = vec![0x0; header.key_length as usize];
reader
.read_exact(key.as_mut_slice())
.await
.map_err(new_std_io_error)?;

let mut value = vec![
0x0;
(header.total_body_length - u32::from(header.key_length) - u32::from(header.extras_length))
as usize
];
reader
.read_exact(value.as_mut_slice())
.await
.map_err(new_std_io_error)?;

Ok(Response {
header,
_key: key,
_extras: extras,
value,
})
}
#[async_trait::async_trait]
pub trait Conn {
type T: AsyncWrite + std::marker::Unpin + tokio::io::AsyncRead + std::marker::Send;

fn get_conn(&mut self) -> &mut Self::T;
async fn auth(&mut self, username: &str, password: &str) -> Result<()> {
let writer = self.get_conn();
pub async fn auth(&mut self, username: &str, password: &str) -> Result<()> {
let writer = self.get_mut();
let key = "PLAIN";
let request_header = PacketHeader {
magic: Magic::Request as u8,
Expand All @@ -249,8 +142,8 @@ pub trait Conn {
Ok(())
}

async fn version(&mut self) -> Result<String> {
let writer = self.get_conn();
pub async fn version(&mut self) -> Result<String> {
let writer = self.get_mut();
let request_header = PacketHeader {
magic: Magic::Request as u8,
opcode: Opcode::Version as u8,
Expand All @@ -271,8 +164,8 @@ pub trait Conn {
}
}

async fn get(&mut self, key: &str) -> Result<Option<Vec<u8>>> {
let writer = self.get_conn();
pub async fn get(&mut self, key: &str) -> Result<Option<Vec<u8>>> {
let writer = self.get_mut();
let request_header = PacketHeader {
magic: Magic::Request as u8,
opcode: Opcode::Get as u8,
Expand Down Expand Up @@ -300,8 +193,8 @@ pub trait Conn {
}
}

async fn set(&mut self, key: &str, val: &[u8], expiration: u32) -> Result<()> {
let writer = self.get_conn();
pub async fn set(&mut self, key: &str, val: &[u8], expiration: u32) -> Result<()> {
let writer = self.get_mut();
let request_header = PacketHeader {
magic: Magic::Request as u8,
opcode: Opcode::Set as u8,
Expand Down Expand Up @@ -337,8 +230,8 @@ pub trait Conn {
Ok(())
}

async fn delete(&mut self, key: &str) -> Result<()> {
let writer = self.get_conn();
pub async fn delete(&mut self, key: &str) -> Result<()> {
let writer = self.get_mut();
let request_header = PacketHeader {
magic: Magic::Request as u8,
opcode: Opcode::Delete as u8,
Expand All @@ -359,3 +252,58 @@ pub trait Conn {
Ok(())
}
}

pub async fn parse_response<T: AsyncWriteExt + std::marker::Unpin + tokio::io::AsyncRead>(
reader: &mut T,
) -> Result<Response> {
let header = PacketHeader::read::<T>(reader)
.await
.map_err(new_std_io_error)?;

if header.vbucket_id_or_status != constants::OK_STATUS
&& header.vbucket_id_or_status != constants::KEY_NOT_FOUND
{
return Err(
Error::new(ErrorKind::Unexpected, "unexpected status received")
.with_context("message", format!("{}", header.vbucket_id_or_status)),
);
}

let mut extras = vec![0x0; header.extras_length as usize];
reader
.read_exact(extras.as_mut_slice())
.await
.map_err(new_std_io_error)?;

let mut key = vec![0x0; header.key_length as usize];
reader
.read_exact(key.as_mut_slice())
.await
.map_err(new_std_io_error)?;

let mut value = vec![
0x0;
(header.total_body_length - u32::from(header.key_length) - u32::from(header.extras_length))
as usize
];
reader
.read_exact(value.as_mut_slice())
.await
.map_err(new_std_io_error)?;

Ok(Response {
header,
_key: key,
_extras: extras,
value,
})
}

#[async_trait::async_trait]
pub trait Connect:
AsyncWrite + std::marker::Unpin + tokio::io::AsyncRead + std::marker::Send
{
}

impl Connect for TcpStream {}
impl Connect for TlsStream<TcpStream> {}

0 comments on commit ee5ce02

Please sign in to comment.