From a0a975abd9f8ff1743885ed99c631688c34f1f05 Mon Sep 17 00:00:00 2001 From: miamia0 Date: Wed, 10 Jan 2024 22:09:30 +0800 Subject: [PATCH] use flush to send --- Cargo.lock | 47 ++-- volo-thrift/Cargo.toml | 1 + volo-thrift/src/client/mod.rs | 15 +- volo-thrift/src/codec/default/mod.rs | 51 +++- volo-thrift/src/codec/mod.rs | 8 + volo-thrift/src/transport/mod.rs | 1 - volo-thrift/src/transport/multiplex/client.rs | 38 +-- volo-thrift/src/transport/multiplex/mod.rs | 1 + volo-thrift/src/transport/multiplex/server.rs | 4 +- .../transport/multiplex/thrift_transport.rs | 252 +++++++++++++----- volo-thrift/src/transport/multiplex/utils.rs | 54 ++++ volo-thrift/src/transport/pingpong/server.rs | 4 +- .../transport/pingpong/thrift_transport.rs | 4 +- 13 files changed, 356 insertions(+), 124 deletions(-) create mode 100644 volo-thrift/src/transport/multiplex/utils.rs diff --git a/Cargo.lock b/Cargo.lock index dfdff1e2..090961c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1134,9 +1134,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.151" +version = "0.2.152" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" +checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" [[package]] name = "libgit2-sys" @@ -1163,9 +1163,9 @@ dependencies = [ [[package]] name = "libz-sys" -version = "1.1.12" +version = "1.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d97137b25e321a73eef1418d1d5d2eda4d77e12813f8e6dead84bc52c5870a7b" +checksum = "5f526fdd09d99e19742883e43de41e1aa9e36db0c7ab7f935165d611c5cccc66" dependencies = [ "cc", "libc", @@ -1397,18 +1397,18 @@ dependencies = [ [[package]] name = "num_enum" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683751d591e6d81200c39fb0d1032608b77724f34114db54f571ff1317b337c0" +checksum = "02339744ee7253741199f897151b38e72257d13802d4ee837285cc2990a90845" dependencies = [ "num_enum_derive", ] [[package]] name = "num_enum_derive" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c11e44798ad209ccdd91fc192f0526a369a01234f7373e1b141c96d7cee4f0e" +checksum = "681030a937600a36906c185595136d26abfebb4aa9c65701cefcaf8578bb982b" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -1763,11 +1763,10 @@ dependencies = [ [[package]] name = "proc-macro-crate" -version = "2.0.1" +version = "3.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97dc5fea232fc28d2f597b37c4876b348a40e33f3b02cc975c8d006d78d94b1a" +checksum = "6b2685dd208a3771337d8d386a89840f0f43cd68be8dae90a5f8c2384effc9cd" dependencies = [ - "toml_datetime", "toml_edit", ] @@ -2552,6 +2551,15 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "tokio-condvar" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7233b09174540ef9bf9fc8326bcad6ccebc631e7c9a1e2e48d956a133056f9d" +dependencies = [ + "tokio", +] + [[package]] name = "tokio-macros" version = "2.2.0" @@ -2622,9 +2630,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.2" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "185d8ab0dfbb35cf1399a6344d8484209c088f75f8f68230da55d48d95d43e3d" +checksum = "a1a195ec8c9da26928f773888e0742ca3ca1040c6cd859c919c9f59c1954ab35" dependencies = [ "serde", "serde_spanned", @@ -2634,18 +2642,18 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.3" +version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cda73e2f1397b1262d6dfdcef8aafae14d1de7748d66822d3bfeeb6d03e5e4b" +checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1" dependencies = [ "serde", ] [[package]] name = "toml_edit" -version = "0.20.2" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "396e4d48bbb2b7554c944bde63101b5ae446cff6ec4a24227428f15eb72ef338" +checksum = "d34d383cd00a163b4a5b85053df514d45bc330f6de7737edfe0a93311d1eaa03" dependencies = [ "indexmap 2.1.0", "serde", @@ -3050,6 +3058,7 @@ dependencies = [ "pin-project", "thiserror", "tokio", + "tokio-condvar", "tracing", "volo", ] @@ -3347,9 +3356,9 @@ checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" [[package]] name = "winnow" -version = "0.5.32" +version = "0.5.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8434aeec7b290e8da5c3f0d628cb0eac6cabcb31d14bb74f779a08109a5914d6" +checksum = "b7520bbdec7211caa7c4e682eb1fbe07abe20cee6756b6e00f537c82c11816aa" dependencies = [ "memchr", ] diff --git a/volo-thrift/Cargo.toml b/volo-thrift/Cargo.toml index 0b1412ac..a596319a 100644 --- a/volo-thrift/Cargo.toml +++ b/volo-thrift/Cargo.toml @@ -44,6 +44,7 @@ tokio = { workspace = true, features = [ "parking_lot", ] } tracing.workspace = true +tokio-condvar = "0.1.0" [features] default = [] diff --git a/volo-thrift/src/client/mod.rs b/volo-thrift/src/client/mod.rs index c8d80a1d..8641f35e 100644 --- a/volo-thrift/src/client/mod.rs +++ b/volo-thrift/src/client/mod.rs @@ -444,9 +444,10 @@ impl ClientBuilder +pub struct MessageService where - Resp: EntryMessage + Send + 'static, + Req: EntryMessage + Send + 'static + Sync, + Resp: EntryMessage + Send + 'static + Sync, MkT: MakeTransport, MkC: MakeCodec + Sync, { @@ -455,13 +456,13 @@ where #[cfg(feature = "multiplex")] inner: motore::utils::Either< pingpong::Client, - crate::transport::multiplex::Client, + crate::transport::multiplex::Client, >, } -impl Service for MessageService +impl Service for MessageService where - Req: EntryMessage + 'static + Send, + Req: Send + 'static + EntryMessage + Sync, Resp: Send + 'static + EntryMessage + Sync, MkT: MakeTransport, MkC: MakeCodec + Sync, @@ -506,8 +507,8 @@ where + Clone + Sync, Req: EntryMessage + Send + 'static + Sync + Clone, - Resp: EntryMessage + Send + 'static, - IL: Layer>, + Resp: EntryMessage + Send + 'static + Sync, + IL: Layer>, IL::Service: Service> + Sync + Clone + Send + 'static, >::Error: Send + Into, diff --git a/volo-thrift/src/codec/default/mod.rs b/volo-thrift/src/codec/default/mod.rs index 9987d026..7854a794 100644 --- a/volo-thrift/src/codec/default/mod.rs +++ b/volo-thrift/src/codec/default/mod.rs @@ -116,8 +116,7 @@ pub struct DefaultEncoder { impl Encoder for DefaultEncoder { - #[inline] - async fn encode( + async fn send( &mut self, cx: &mut Cx, msg: ThriftMessage, @@ -177,6 +176,54 @@ impl Encoder } // write_result } + + #[inline] + async fn encode( + &mut self, + cx: &mut Cx, + msg: ThriftMessage, + ) -> Result<(), crate::Error> { + cx.stats_mut().record_encode_start_at(); + + // first, we need to get the size of the message + let (real_size, malloc_size) = self.encoder.size(cx, &msg)?; + trace!( + "[VOLO] codec encode message real size: {}, malloc size: {}", + real_size, + malloc_size + ); + cx.stats_mut().set_write_size(real_size); + + // then we reserve the size of the message in the linked bytes + self.linked_bytes.reserve(malloc_size); + // after that, we encode the message into the linked bytes + self.encoder + .encode(cx, &mut self.linked_bytes, msg) + .map_err(|e| { + // record the error time + cx.stats_mut().record_encode_end_at(); + e + })?; + + cx.stats_mut().record_encode_end_at(); + Ok(()) + } + + async fn flush(&mut self) -> Result<(), crate::Error> { + self.linked_bytes + .write_all_vectored(&mut self.writer) + .await + .map_err(TransportError::from)?; + + match self.writer.flush().await.map_err(TransportError::from) { + Ok(()) => Ok(()), + Err(e) => Err(e.into()), + } + } + + async fn reset(&mut self) { + self.linked_bytes.reset(); + } } pub struct DefaultDecoder { diff --git a/volo-thrift/src/codec/mod.rs b/volo-thrift/src/codec/mod.rs index 93118cb4..c78608bd 100644 --- a/volo-thrift/src/codec/mod.rs +++ b/volo-thrift/src/codec/mod.rs @@ -25,11 +25,19 @@ pub trait Decoder: Send + 'static { /// /// Note: [`Encoder`] should be designed to be ready for reuse. pub trait Encoder: Send + 'static { + fn reset(&mut self) -> impl Future + Send; + fn send( + &mut self, + cx: &mut Cx, + msg: ThriftMessage, + ) -> impl Future> + Send; fn encode( &mut self, cx: &mut Cx, msg: ThriftMessage, ) -> impl Future> + Send; + + fn flush(&mut self) -> impl Future> + Send; } /// [`MakeCodec`] receives an [`AsyncRead`] and an [`AsyncWrite`] and returns a diff --git a/volo-thrift/src/transport/mod.rs b/volo-thrift/src/transport/mod.rs index f48bcafb..b881cb4f 100644 --- a/volo-thrift/src/transport/mod.rs +++ b/volo-thrift/src/transport/mod.rs @@ -1,5 +1,4 @@ pub(crate) mod incoming; -#[cfg(feature = "multiplex")] pub mod multiplex; pub mod pingpong; pub mod pool; diff --git a/volo-thrift/src/transport/multiplex/client.rs b/volo-thrift/src/transport/multiplex/client.rs index 1f02e978..eeddefbf 100644 --- a/volo-thrift/src/transport/multiplex/client.rs +++ b/volo-thrift/src/transport/multiplex/client.rs @@ -15,18 +15,18 @@ use crate::{ EntryMessage, Error, ThriftMessage, }; -pub struct MakeClientTransport +pub struct MakeClientTransport where MkT: MakeTransport, MkC: MakeCodec, { make_transport: MkT, make_codec: MkC, - _phantom: PhantomData Resp>, + _phantom: PhantomData<(fn() -> Resp, fn() -> Req)>, } -impl, Resp> Clone - for MakeClientTransport +impl, Req, Resp> Clone + for MakeClientTransport { fn clone(&self) -> Self { Self { @@ -37,7 +37,7 @@ impl, Resp> Cl } } -impl MakeClientTransport +impl MakeClientTransport where MkT: MakeTransport, MkC: MakeCodec, @@ -52,13 +52,14 @@ where } } -impl UnaryService
for MakeClientTransport +impl UnaryService
for MakeClientTransport where MkT: MakeTransport, MkC: MakeCodec + Sync, - Resp: EntryMessage + Send + 'static, + Resp: EntryMessage + Send + 'static + Sync, + Req: EntryMessage + Send + 'static + Sync, { - type Response = ThriftTransport; + type Response = ThriftTransport; type Error = io::Error; async fn call(&self, target: Address) -> Result { @@ -73,22 +74,24 @@ where } } -pub struct Client +pub struct Client where MkT: MakeTransport, MkC: MakeCodec + Sync, - Resp: EntryMessage + Send + 'static, + Resp: EntryMessage + Send + 'static + Sync, + Req: EntryMessage + Send + 'static + Sync, { #[allow(clippy::type_complexity)] - make_transport: PooledMakeTransport, Address>, + make_transport: PooledMakeTransport, Address>, _marker: PhantomData, } -impl Clone for Client +impl Clone for Client where MkT: MakeTransport, MkC: MakeCodec + Sync, - Resp: EntryMessage + Send + 'static, + Resp: EntryMessage + Send + 'static + Sync, + Req: EntryMessage + Send + 'static + Sync, { fn clone(&self) -> Self { Self { @@ -98,11 +101,12 @@ where } } -impl Client +impl Client where MkT: MakeTransport, MkC: MakeCodec + Sync, - Resp: EntryMessage + Send + 'static, + Resp: EntryMessage + Send + 'static + Sync, + Req: EntryMessage + Send + 'static + Sync, { pub fn new(make_transport: MkT, pool_cfg: Option, make_codec: MkC) -> Self { let make_transport = MakeClientTransport::new(make_transport, make_codec); @@ -114,9 +118,9 @@ where } } -impl Service> for Client +impl Service> for Client where - Req: Send + 'static + EntryMessage, + Req: Send + 'static + EntryMessage + Sync, Resp: EntryMessage + Send + 'static + Sync, MkT: MakeTransport, MkC: MakeCodec + Sync, diff --git a/volo-thrift/src/transport/multiplex/mod.rs b/volo-thrift/src/transport/multiplex/mod.rs index 22ce370d..3e222966 100644 --- a/volo-thrift/src/transport/multiplex/mod.rs +++ b/volo-thrift/src/transport/multiplex/mod.rs @@ -1,6 +1,7 @@ mod client; mod server; mod thrift_transport; +pub mod utils; pub use client::Client; pub use server::serve; diff --git a/volo-thrift/src/transport/multiplex/server.rs b/volo-thrift/src/transport/multiplex/server.rs index dfc86b3b..34bfa11d 100644 --- a/volo-thrift/src/transport/multiplex/server.rs +++ b/volo-thrift/src/transport/multiplex/server.rs @@ -52,7 +52,7 @@ pub async fn serve( msg = send_rx.recv() => { match msg { Some((mi, mut cx, msg)) => { - if let Err(e) = metainfo::METAINFO.scope(RefCell::new(mi), encoder.encode::(&mut cx, msg)).await { + if let Err(e) = metainfo::METAINFO.scope(RefCell::new(mi), encoder.send::(&mut cx, msg)).await { // log it error!("[VOLO] server send response error: {:?}, cx: {:?}, peer_addr: {:?}", e, cx, peer_addr); stat_tracer.iter().for_each(|f| f(&cx)); @@ -71,7 +71,7 @@ pub async fn serve( error_msg = error_send_rx.recv() => { match error_msg { Some((mut cx, msg)) => { - if let Err(e) = encoder.encode::(&mut cx, msg).await { + if let Err(e) = encoder.send::(&mut cx, msg).await { // log it error!("[VOLO] server send error error: {:?}, cx: {:?}, peer_addr: {:?}", e, cx, peer_addr); } diff --git a/volo-thrift/src/transport/multiplex/thrift_transport.rs b/volo-thrift/src/transport/multiplex/thrift_transport.rs index 0564dc03..95920ad6 100644 --- a/volo-thrift/src/transport/multiplex/thrift_transport.rs +++ b/volo-thrift/src/transport/multiplex/thrift_transport.rs @@ -1,5 +1,7 @@ use std::{ cell::RefCell, + collections::VecDeque, + marker::PhantomData, sync::{ atomic::{AtomicBool, AtomicUsize}, Arc, @@ -12,6 +14,7 @@ use tokio::{ io::{AsyncRead, AsyncWrite}, sync::{oneshot, Mutex}, }; +use tokio_condvar::Condvar; use volo::{ context::{Role, RpcInfo}, net::Address, @@ -20,7 +23,10 @@ use volo::{ use crate::{ codec::{Decoder, Encoder, MakeCodec}, context::{ClientContext, ThriftContext}, - transport::pool::{Poolable, Reservation}, + transport::{ + multiplex::utils::TxHashMap, + pool::{Poolable, Reservation}, + }, ApplicationError, ApplicationErrorKind, EntryMessage, Error, ThriftMessage, }; @@ -30,18 +36,13 @@ lazy_static::lazy_static! { } #[pin_project] -pub struct ThriftTransport { - write_half: Arc>>, - dirty: Arc, +pub struct ThriftTransport { + _phantom1: PhantomData E>, + #[allow(clippy::type_complexity)] tx_map: Arc< - Mutex< - fxhash::FxHashMap< - i32, - oneshot::Sender< - crate::Result)>>, - >, - >, + TxHashMap< + oneshot::Sender)>>>, >, >, write_error: Arc, @@ -49,25 +50,120 @@ pub struct ThriftTransport { read_error: Arc, // read connection is closed read_closed: Arc, + // TODO make this to lockless + batch_queue: Arc>>>, + queue_cv: Arc, } -impl Clone for ThriftTransport { +impl Clone for ThriftTransport { fn clone(&self) -> Self { Self { - write_half: self.write_half.clone(), - dirty: self.dirty.clone(), tx_map: self.tx_map.clone(), write_error: self.write_error.clone(), read_error: self.read_error.clone(), read_closed: self.read_closed.clone(), + batch_queue: self.batch_queue.clone(), + _phantom1: PhantomData, + queue_cv: self.queue_cv.clone(), } } } -impl ThriftTransport +impl ThriftTransport where E: Encoder, + Req: EntryMessage + Send + 'static + Sync, + Resp: EntryMessage + Send + 'static + Sync, { + pub fn write_loop(&self, mut write_half: WriteHalf) { + let batch_queu = self.batch_queue.clone(); + let inner_tx_map = self.tx_map.clone(); + let inner_read_error: Arc = self.read_error.clone(); + let inner_read_closed = self.read_closed.clone(); + let inner_write_error = self.write_error.clone(); + let queue_cv = self.queue_cv.clone(); + tokio::spawn(async move { + let mut resolved = Vec::with_capacity(32); + let mut has_error; + loop { + { + resolved.clear(); + write_half.reset().await; + has_error = false; + let mut queue = batch_queu.lock().await; + while queue.is_empty() + && !inner_read_error.load(std::sync::atomic::Ordering::Relaxed) + && !inner_read_closed.load(std::sync::atomic::Ordering::Relaxed) + { + queue = queue_cv.wait(queue).await; + } + + if inner_read_error.load(std::sync::atomic::Ordering::Relaxed) + || inner_read_closed.load(std::sync::atomic::Ordering::Relaxed) + { + return; + } + + while !queue.is_empty() { + let current = queue.pop_front().unwrap(); + let seq = current.meta.seq_id; + resolved.push(seq); + let mut cx = ClientContext::new( + seq, + RpcInfo::with_role(Role::Client), + pilota::thrift::TMessageType::Call, + ); + let res = write_half.encode(&mut cx, current).await; + match res { + Ok(_) => {} + Err(err) => { + tracing::error!( + "[VOLO] multiplex connection encode error: {}", + err + ); + inner_write_error.store(true, std::sync::atomic::Ordering::Relaxed); + has_error = true; + while !queue.is_empty() { + let current = queue.pop_front().unwrap(); + resolved.push(current.meta.seq_id); + } + break; + } + } + } + if has_error { + for seq in resolved.iter() { + let _ = inner_tx_map.remove(seq).await.unwrap().send(Err( + Error::Application(ApplicationError::new( + ApplicationErrorKind::UNKNOWN, + format!("write error "), + )), + )); + } + return; + } + let res = write_half.flush().await; + match res { + Ok(_) => {} + Err(err) => { + tracing::error!("[VOLO] multiplex connection flush error: {}", err,); + inner_write_error.store(true, std::sync::atomic::Ordering::Relaxed); + for seq in resolved.iter() { + let _ = inner_tx_map.remove(&seq).await.unwrap().send(Err( + Error::Application(ApplicationError::new( + ApplicationErrorKind::UNKNOWN, + err.to_string(), + )), + )); + } + return; + } + } + } + } + }); + } + pub fn new< R: AsyncRead + Send + Sync + Unpin + 'static, W: AsyncWrite + Send + Sync + Unpin + 'static, @@ -91,12 +187,9 @@ where let write_half = WriteHalf { encoder, id }; #[allow(clippy::type_complexity)] let tx_map: Arc< - Mutex< - fxhash::FxHashMap< - i32, - oneshot::Sender< - crate::Result)>>, - >, + TxHashMap< + oneshot::Sender< + crate::Result)>>, >, >, > = Default::default(); @@ -107,6 +200,9 @@ where let inner_read_error = read_error.clone(); let read_closed = Arc::new(AtomicBool::new(false)); let inner_read_closed = read_closed.clone(); + let queue_cv = Arc::new(Condvar::new()); + let inner_queue_cv = queue_cv.clone(); + //// read loop tokio::spawn(async move { metainfo::METAINFO .scope(RefCell::new(Default::default()), async move { @@ -131,35 +227,42 @@ where e, target ); - let mut tx_map = inner_tx_map.lock().await; inner_read_error.store(true, std::sync::atomic::Ordering::Relaxed); - for (_, tx) in tx_map.drain() { - let _ = tx.send(Err(Error::Application(ApplicationError::new( - ApplicationErrorKind::UNKNOWN, - format!("multiplex connection error: {e}, target: {target}"), - )))); - } + inner_queue_cv.notify_all(); + + inner_tx_map + .for_all_drain(|tx| { + let _ = + tx.send(Err(Error::Application(ApplicationError::new( + ApplicationErrorKind::UNKNOWN, + format!( + "multiplex connection error: {e}, target: {target}" + ), + )))); + }) + .await; return; } // we have checked the error above, so it's safe to unwrap here let res = res.unwrap(); if res.is_none() { // the connection is closed - let mut tx_map = inner_tx_map.lock().await; - if !tx_map.is_empty() { + if !inner_tx_map.is_empty().await { inner_read_error.store(true, std::sync::atomic::Ordering::Relaxed); - for (_, tx) in tx_map.drain() { - let _ = tx.send(Ok(None)); - } + inner_tx_map + .for_all_drain(|tx| { + let _ = tx.send(Ok(None)); + }) + .await; } inner_read_closed.store(true, std::sync::atomic::Ordering::Relaxed); + inner_queue_cv.notify_all(); return; } // now we get ThriftMessage let res = res.unwrap(); let seq_id = res.meta.seq_id; - let mut tx_map = inner_tx_map.lock().await; - if let Some(tx) = tx_map.remove(&seq_id) { + if let Some(tx) = inner_tx_map.remove(&seq_id).await { metainfo::METAINFO.with(|mi| { let mi = mi.take(); let _ = tx.send(Ok(Some((mi, cx, res)))); @@ -176,23 +279,27 @@ where }) .await; }); - Self { - write_half: Arc::new(Mutex::new(write_half)), - dirty: Arc::new(AtomicBool::new(false)), + let ret = Self { tx_map, write_error, read_error, read_closed, - } + batch_queue: Default::default(), + _phantom1: PhantomData, + queue_cv, + }; + ret.write_loop(write_half); + ret } } -impl ThriftTransport +impl ThriftTransport where E: Encoder, Resp: EntryMessage, + Req: EntryMessage, { - pub async fn send( + pub async fn send( &self, cx: &mut ClientContext, msg: ThriftMessage, @@ -211,38 +318,21 @@ where "multiplex connection closed".to_string(), ))); } - let (tx, rx) = oneshot::channel(); - let mut tx_map = self.tx_map.lock().await; - let seq_id = msg.meta.seq_id; - if !oneway { - tx_map.insert(seq_id, tx); - } - drop(tx_map); - let mut wh = self.write_half.lock().await; - // check connection dirty - if self.dirty.load(std::sync::atomic::Ordering::Relaxed) { - // connection is dirty, we should also set write error to indicate the connection should - // not be reused - self.write_error - .store(true, std::sync::atomic::Ordering::Relaxed); + if self.write_error.load(std::sync::atomic::Ordering::Relaxed) { return Err(Error::Application(ApplicationError::new( ApplicationErrorKind::UNKNOWN, - "multiplex connection is dirty".to_string(), + "multiplex connection error".to_string(), ))); } - self.dirty.store(true, std::sync::atomic::Ordering::Relaxed); - let res = wh.send(cx, msg).await; - self.dirty - .store(false, std::sync::atomic::Ordering::Relaxed); - drop(wh); - if let Err(e) = res { - self.write_error - .store(true, std::sync::atomic::Ordering::Relaxed); - if !oneway { - let mut tx_map = self.tx_map.lock().await; - tx_map.remove(&seq_id); - } - return Err(e); + + let (tx, rx) = oneshot::channel(); + let seq_id = msg.meta.seq_id; + if !oneway { + self.tx_map.insert(seq_id, tx).await; + } + { + self.batch_queue.lock().await.push_back(msg); + self.queue_cv.notify_all(); } if oneway { return Ok(None); @@ -335,6 +425,7 @@ pub struct WriteHalf { id: usize, } +#[allow(dead_code)] impl WriteHalf where E: Encoder, @@ -343,18 +434,35 @@ where &mut self, cx: &mut impl ThriftContext, msg: ThriftMessage, + ) -> Result<(), Error> { + self.encoder.send(cx, msg).await.map_err(|mut e| { + e.append_msg(&format!(", rpcinfo: {:?}", cx.rpc_info())); + tracing::error!("[VOLO] transport[{}] encode error: {:?}", self.id, e); + e + }) + } + pub async fn reset(&mut self) { + self.encoder.reset().await; + } + + pub async fn encode( + &mut self, + cx: &mut impl ThriftContext, + msg: ThriftMessage, ) -> Result<(), Error> { self.encoder.encode(cx, msg).await.map_err(|mut e| { e.append_msg(&format!(", rpcinfo: {:?}", cx.rpc_info())); tracing::error!("[VOLO] transport[{}] encode error: {:?}", self.id, e); e - })?; + }) + } - Ok(()) + pub async fn flush(&mut self) -> Result<(), Error> { + self.encoder.flush().await } } -impl Poolable for ThriftTransport { +impl Poolable for ThriftTransport { fn reusable(&self) -> bool { !self.write_error.load(std::sync::atomic::Ordering::Relaxed) && !self.read_error.load(std::sync::atomic::Ordering::Relaxed) diff --git a/volo-thrift/src/transport/multiplex/utils.rs b/volo-thrift/src/transport/multiplex/utils.rs new file mode 100644 index 00000000..2a397634 --- /dev/null +++ b/volo-thrift/src/transport/multiplex/utils.rs @@ -0,0 +1,54 @@ +use std::array; + +use tokio::sync::Mutex; + +const SHARD_COUNT: usize = 64; + +pub struct TxHashMap { + sharded: [Mutex>; SHARD_COUNT], +} + +impl Default for TxHashMap { + fn default() -> Self { + TxHashMap { + sharded: array::from_fn(|_| Default::default()), + } + } +} + +impl TxHashMap +where + T: Sized, +{ + pub async fn remove(&self, key: &i32) -> Option { + self.sharded[(*key % (SHARD_COUNT as i32)) as usize] + .lock() + .await + .remove(key) + } + + pub async fn is_empty(&self) -> bool { + for s in self.sharded.iter() { + if !s.lock().await.is_empty() { + return false; + } + } + true + } + + pub async fn insert(&self, key: i32, value: T) -> Option { + self.sharded[(key % (SHARD_COUNT as i32)) as usize] + .lock() + .await + .insert(key, value) + } + + pub async fn for_all_drain(&self, mut f: impl FnMut(T) -> ()) { + for sharded in self.sharded.iter() { + let mut s = sharded.lock().await; + for data in s.drain() { + f(data.1) + } + } + } +} diff --git a/volo-thrift/src/transport/pingpong/server.rs b/volo-thrift/src/transport/pingpong/server.rs index 38521542..95976140 100644 --- a/volo-thrift/src/transport/pingpong/server.rs +++ b/volo-thrift/src/transport/pingpong/server.rs @@ -87,7 +87,7 @@ pub async fn serve( ThriftMessage::mk_server_resp(&cx, resp.map_err(|e| e.into())) .unwrap(); if let Err(e) = async { - let result = encoder.encode(&mut cx, msg).await; + let result = encoder.send(&mut cx, msg).await; span_provider.leave_encode(&cx); result }.instrument(span_provider.on_encode(tracing_cx)).await { @@ -111,7 +111,7 @@ pub async fn serve( if !matches!(e, Error::Transport(_)) { let msg = ThriftMessage::mk_server_resp(&cx, Err::(e)) .unwrap(); - if let Err(e) = encoder.encode(&mut cx, msg).await { + if let Err(e) = encoder.send(&mut cx, msg).await { error!("[VOLO] server send error error: {:?}, cx: {:?}, peer_addr: {:?}", e, cx, peer_addr); } } diff --git a/volo-thrift/src/transport/pingpong/thrift_transport.rs b/volo-thrift/src/transport/pingpong/thrift_transport.rs index 3cba2171..f60e3fd2 100644 --- a/volo-thrift/src/transport/pingpong/thrift_transport.rs +++ b/volo-thrift/src/transport/pingpong/thrift_transport.rs @@ -130,9 +130,9 @@ where cx: &mut impl ThriftContext, msg: ThriftMessage, ) -> Result<(), Error> { - self.encoder.encode(cx, msg).await.map_err(|mut e| { + self.encoder.send(cx, msg).await.map_err(|mut e| { e.append_msg(&format!(", rpcinfo: {:?}", cx.rpc_info())); - tracing::error!("[VOLO] transport[{}] encode error: {:?}", self.id, e); + tracing::error!("[VOLO] transport[{}] send error: {:?}", self.id, e); e })?;