From ee5ce02c9512da8e4a4d401ab856dcb35892df6c Mon Sep 17 00:00:00 2001 From: wangrui Date: Fri, 3 Jan 2025 15:56:54 +0800 Subject: [PATCH] refactor: memcached::binary::Connection --- core/Cargo.toml | 2 +- core/src/services/memcached/backend.rs | 57 +++----- core/src/services/memcached/binary.rs | 192 +++++++++---------------- 3 files changed, 92 insertions(+), 159 deletions(-) diff --git a/core/Cargo.toml b/core/Cargo.toml index 5151607b489..92b4115b08a 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -47,7 +47,7 @@ rust-version = "1.75" version = "0.50.1" [features] -default = ["reqwest/rustls-tls", "executors-tokio", "services-memory" ,"services-memcached"] +default = ["reqwest/rustls-tls", "executors-tokio", "services-memory"] # Build test utils or not. # diff --git a/core/src/services/memcached/backend.rs b/core/src/services/memcached/backend.rs index 8e41ba0c10e..c821582cf4a 100644 --- a/core/src/services/memcached/backend.rs +++ b/core/src/services/memcached/backend.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use std::time::Duration; -use super::binary::{self, Conn}; +use super::binary; use crate::raw::adapters::kv; use crate::raw::*; use crate::services::MemcachedConfig; @@ -149,7 +149,7 @@ impl Builder for MemcachedBuilder { ); }; if self.config.enable_tls { - rustls::crypto::aws_lc_rs::default_provider() + rustls::crypto::ring::default_provider() .install_default() .map_err(|_err| { Error::new( @@ -325,40 +325,37 @@ impl bb8::ManageConnection for MemcacheConnectionManager { /// TODO: Implement unix stream support. async fn connect(&self) -> Result { - let conn = if self.enable_tls { + let mut conn = if self.enable_tls { let mut root_cert_store = rustls::RootCertStore::empty(); - let native_certs = rustls_native_certs::load_native_certs(); - if native_certs.errors.is_empty() { - for cert in rustls_native_certs::load_native_certs().expect("unreachable!") { - root_cert_store.add(cert).map_err(|err| { - Error::new(ErrorKind::Unexpected, "tls connect failed").set_source(err) - })?; - } + let native_certs = rustls_native_certs::load_native_certs() + .expect("errors occurred while loading certificates"); + for cert in native_certs { + root_cert_store.add(cert).map_err(|err| { + Error::new(ErrorKind::Unexpected, "load cafile failed").set_source(err) + })?; } - + let tls_config = + rustls::ClientConfig::builder().with_root_certificates(root_cert_store); let config = if let (Some(cert_path), Some(key_path)) = (&self.tls_cert, &self.tls_key) { let cert_chain = CertificateDer::pem_file_iter(cert_path) .map_err(|err| { - Error::new(ErrorKind::Unexpected, "tls connect failed").set_source(err) + Error::new(ErrorKind::Unexpected, "load tls cert failed").set_source(err) })? .filter_map(Result::ok) .collect(); let key_der = PrivateKeyDer::from_pem_file(key_path).map_err(|err| { - Error::new(ErrorKind::Unexpected, "tls connect failed").set_source(err) + Error::new(ErrorKind::Unexpected, "load tls key failed").set_source(err) })?; - rustls::ClientConfig::builder() - .with_root_certificates(root_cert_store) + tls_config .with_client_auth_cert(cert_chain, key_der) .map_err(|err| { - Error::new(ErrorKind::Unexpected, "tls connect failed").set_source(err) + Error::new(ErrorKind::Unexpected, "build tls client failed").set_source(err) })? } else { - rustls::ClientConfig::builder() - .with_root_certificates(root_cert_store) - .with_no_client_auth() + tls_config.with_no_client_auth() }; let connector = TlsConnector::from(Arc::new(config)); @@ -381,29 +378,17 @@ impl bb8::ManageConnection for MemcacheConnectionManager { Error::new(ErrorKind::Unexpected, "tls connect failed").set_source(err) })?; - let mut conn = binary::TlsConnection::new(conn); - - if let (Some(username), Some(password)) = - (self.username.as_ref(), self.password.as_ref()) - { - conn.auth(username, password).await?; - } - binary::Connection::Tls(conn) + binary::Connection::new(Box::new(conn)) } else { let conn = TcpStream::connect(&self.address) .await .map_err(new_std_io_error)?; - let mut conn = binary::TcpConnection::new(conn); - - if let (Some(username), Some(password)) = - (self.username.as_ref(), self.password.as_ref()) - { - conn.auth(username, password).await?; - } - - binary::Connection::Tcp(conn) + binary::Connection::new(Box::new(conn)) }; + if let (Some(username), Some(password)) = (self.username.as_ref(), self.password.as_ref()) { + conn.auth(username, password).await?; + } Ok(conn) } diff --git a/core/src/services/memcached/binary.rs b/core/src/services/memcached/binary.rs index 23e29814f4b..db9250853d2 100644 --- a/core/src/services/memcached/binary.rs +++ b/core/src/services/memcached/binary.rs @@ -95,39 +95,6 @@ impl PacketHeader { } } -pub enum Connection { - Tls(TlsConnection), - Tcp(TcpConnection), -} - -impl Connection { - pub async fn version(&mut self) -> Result { - match self { - Self::Tls(conn) => conn.version().await, - Self::Tcp(conn) => conn.version().await, - } - } - pub async fn get(&mut self, key: &str) -> Result>> { - match self { - Self::Tls(conn) => conn.get(key).await, - Self::Tcp(conn) => conn.get(key).await, - } - } - - pub async fn set(&mut self, key: &str, val: &[u8], expiration: u32) -> Result<()> { - match self { - Self::Tls(conn) => conn.set(key, val, expiration).await, - Self::Tcp(conn) => conn.set(key, val, expiration).await, - } - } - pub async fn delete(&mut self, key: &str) -> Result<()> { - match self { - Self::Tls(conn) => conn.delete(key).await, - Self::Tcp(conn) => conn.delete(key).await, - } - } -} - pub struct Response { header: PacketHeader, _key: Vec, @@ -135,95 +102,21 @@ pub struct Response { value: Vec, } -pub struct TlsConnection { - io: BufReader>, +pub struct Connection { + io: BufReader>, } -impl TlsConnection { - pub fn new(io: TlsStream) -> Self { - Self { - io: BufReader::new(io), - } - } -} - -impl Conn for TlsConnection { - type T = TlsStream; - fn get_conn(&mut self) -> &mut Self::T { - self.io.get_mut() - } -} -pub struct TcpConnection { - io: BufReader, -} - -impl TcpConnection { - pub fn new(io: TcpStream) -> Self { +impl Connection { + pub fn new(io: Box) -> Self { Self { io: BufReader::new(io), } } -} - -impl Conn for TcpConnection { - type T = TcpStream; - fn get_conn(&mut self) -> &mut Self::T { + pub fn get_mut(&mut self) -> &mut Box { self.io.get_mut() } -} - -pub async fn parse_response( - reader: &mut T, -) -> Result { - let header = PacketHeader::read::(reader) - .await - .map_err(new_std_io_error)?; - - if header.vbucket_id_or_status != constants::OK_STATUS - && header.vbucket_id_or_status != constants::KEY_NOT_FOUND - { - return Err( - Error::new(ErrorKind::Unexpected, "unexpected status received") - .with_context("message", format!("{}", header.vbucket_id_or_status)), - ); - } - - let mut extras = vec![0x0; header.extras_length as usize]; - reader - .read_exact(extras.as_mut_slice()) - .await - .map_err(new_std_io_error)?; - - let mut key = vec![0x0; header.key_length as usize]; - reader - .read_exact(key.as_mut_slice()) - .await - .map_err(new_std_io_error)?; - - let mut value = vec![ - 0x0; - (header.total_body_length - u32::from(header.key_length) - u32::from(header.extras_length)) - as usize - ]; - reader - .read_exact(value.as_mut_slice()) - .await - .map_err(new_std_io_error)?; - - Ok(Response { - header, - _key: key, - _extras: extras, - value, - }) -} -#[async_trait::async_trait] -pub trait Conn { - type T: AsyncWrite + std::marker::Unpin + tokio::io::AsyncRead + std::marker::Send; - - fn get_conn(&mut self) -> &mut Self::T; - async fn auth(&mut self, username: &str, password: &str) -> Result<()> { - let writer = self.get_conn(); + pub async fn auth(&mut self, username: &str, password: &str) -> Result<()> { + let writer = self.get_mut(); let key = "PLAIN"; let request_header = PacketHeader { magic: Magic::Request as u8, @@ -249,8 +142,8 @@ pub trait Conn { Ok(()) } - async fn version(&mut self) -> Result { - let writer = self.get_conn(); + pub async fn version(&mut self) -> Result { + let writer = self.get_mut(); let request_header = PacketHeader { magic: Magic::Request as u8, opcode: Opcode::Version as u8, @@ -271,8 +164,8 @@ pub trait Conn { } } - async fn get(&mut self, key: &str) -> Result>> { - let writer = self.get_conn(); + pub async fn get(&mut self, key: &str) -> Result>> { + let writer = self.get_mut(); let request_header = PacketHeader { magic: Magic::Request as u8, opcode: Opcode::Get as u8, @@ -300,8 +193,8 @@ pub trait Conn { } } - async fn set(&mut self, key: &str, val: &[u8], expiration: u32) -> Result<()> { - let writer = self.get_conn(); + pub async fn set(&mut self, key: &str, val: &[u8], expiration: u32) -> Result<()> { + let writer = self.get_mut(); let request_header = PacketHeader { magic: Magic::Request as u8, opcode: Opcode::Set as u8, @@ -337,8 +230,8 @@ pub trait Conn { Ok(()) } - async fn delete(&mut self, key: &str) -> Result<()> { - let writer = self.get_conn(); + pub async fn delete(&mut self, key: &str) -> Result<()> { + let writer = self.get_mut(); let request_header = PacketHeader { magic: Magic::Request as u8, opcode: Opcode::Delete as u8, @@ -359,3 +252,58 @@ pub trait Conn { Ok(()) } } + +pub async fn parse_response( + reader: &mut T, +) -> Result { + let header = PacketHeader::read::(reader) + .await + .map_err(new_std_io_error)?; + + if header.vbucket_id_or_status != constants::OK_STATUS + && header.vbucket_id_or_status != constants::KEY_NOT_FOUND + { + return Err( + Error::new(ErrorKind::Unexpected, "unexpected status received") + .with_context("message", format!("{}", header.vbucket_id_or_status)), + ); + } + + let mut extras = vec![0x0; header.extras_length as usize]; + reader + .read_exact(extras.as_mut_slice()) + .await + .map_err(new_std_io_error)?; + + let mut key = vec![0x0; header.key_length as usize]; + reader + .read_exact(key.as_mut_slice()) + .await + .map_err(new_std_io_error)?; + + let mut value = vec![ + 0x0; + (header.total_body_length - u32::from(header.key_length) - u32::from(header.extras_length)) + as usize + ]; + reader + .read_exact(value.as_mut_slice()) + .await + .map_err(new_std_io_error)?; + + Ok(Response { + header, + _key: key, + _extras: extras, + value, + }) +} + +#[async_trait::async_trait] +pub trait Connect: + AsyncWrite + std::marker::Unpin + tokio::io::AsyncRead + std::marker::Send +{ +} + +impl Connect for TcpStream {} +impl Connect for TlsStream {}