diff --git a/dubbo/src/extension/invoker_extension.rs b/dubbo/src/extension/invoker_extension.rs new file mode 100644 index 00000000..8d59e5b4 --- /dev/null +++ b/dubbo/src/extension/invoker_extension.rs @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use crate::{ + extension::{ + invoker_extension::proxy::InvokerProxy, Extension, ExtensionFactories, ExtensionMetaInfo, + LoadExtensionPromise, + }, + params::extension_param::{ExtensionName, ExtensionType}, + url::UrlParam, + StdError, Url, +}; +use async_trait::async_trait; +use bytes::Bytes; +use futures_core::Stream; +use std::{collections::HashMap, future::Future, marker::PhantomData, pin::Pin}; +use thiserror::Error; + +#[async_trait] +pub trait Invoker { + async fn invoke( + &self, + invocation: GrpcInvocation, + ) -> Result + Send + 'static>>, StdError>; + + async fn url(&self) -> Result; +} + +pub enum CallType { + Unary, + ClientStream, + ServerStream, + BiStream, +} + +pub struct GrpcInvocation { + service_name: String, + method_name: String, + arguments: Vec, + attachments: HashMap, + call_type: CallType, +} + +pub struct Argument { + name: String, + value: Box> + Send + 'static>, +} + +pub trait Serializable { + fn serialize(&self, serialization_type: String) -> Result; +} + +pub trait Deserializable { + fn deserialize(&self, bytes: Bytes, deserialization_type: String) -> Result + where + Self: Sized; +} + +pub mod proxy { + use crate::{ + extension::invoker_extension::{GrpcInvocation, Invoker}, + StdError, Url, + }; + use async_trait::async_trait; + use bytes::Bytes; + use futures_core::Stream; + use std::pin::Pin; + use thiserror::Error; + use tokio::sync::{mpsc::Sender, oneshot}; + use tracing::error; + + pub(super) enum InvokerOpt { + Invoke( + GrpcInvocation, + oneshot::Sender + Send + 'static>>, StdError>>, + ), + Url(oneshot::Sender>), + } + + #[derive(Clone)] + pub struct InvokerProxy { + tx: Sender, + } + + #[async_trait] + impl Invoker for InvokerProxy { + async fn invoke( + &self, + invocation: GrpcInvocation, + ) -> Result + Send + 'static>>, StdError> { + let (tx, rx) = oneshot::channel(); + let ret = self.tx.send(InvokerOpt::Invoke(invocation, tx)).await; + match ret { + Ok(_) => {} + Err(err) => { + error!( + "call invoke method failed by invoker proxy, error: {:?}", + err + ); + return Err(InvokerProxyError::new( + "call invoke method failed by invoker proxy", + ) + .into()); + } + } + let ret = rx.await?; + ret + } + + async fn url(&self) -> Result { + let (tx, rx) = oneshot::channel(); + let ret = self.tx.send(InvokerOpt::Url(tx)).await; + match ret { + Ok(_) => {} + Err(err) => { + error!("call url method failed by invoker proxy, error: {:?}", err); + return Err( + InvokerProxyError::new("call url method failed by invoker proxy").into(), + ); + } + } + let ret = rx.await?; + ret + } + } + + impl From> for InvokerProxy { + fn from(invoker: Box) -> Self { + let (tx, mut rx) = tokio::sync::mpsc::channel(64); + tokio::spawn(async move { + while let Some(opt) = rx.recv().await { + match opt { + InvokerOpt::Invoke(invocation, tx) => { + let result = invoker.invoke(invocation).await; + let callback_ret = tx.send(result); + match callback_ret { + Ok(_) => {} + Err(Err(err)) => { + error!("invoke method has been called, but callback to caller failed. {:?}", err); + } + _ => {} + } + } + InvokerOpt::Url(tx) => { + let ret = tx.send(invoker.url().await); + match ret { + Ok(_) => {} + Err(err) => { + error!("url method has been called, but callback to caller failed. {:?}", err); + } + } + } + } + } + }); + InvokerProxy { tx } + } + } + + #[derive(Error, Debug)] + #[error("invoker proxy error: {0}")] + pub struct InvokerProxyError(String); + + impl InvokerProxyError { + pub fn new(msg: &str) -> Self { + InvokerProxyError(msg.to_string()) + } + } +} + +#[derive(Default)] +pub(super) struct InvokerExtensionLoader { + factories: HashMap, +} + +impl InvokerExtensionLoader { + pub fn register(&mut self, extension_name: String, factory: InvokerExtensionFactory) { + self.factories.insert(extension_name, factory); + } + + pub fn remove(&mut self, extension_name: String) { + self.factories.remove(&extension_name); + } + + pub fn load(&mut self, url: Url) -> Result, StdError> { + let extension_name = url.query::(); + let Some(extension_name) = extension_name else { + return Err(InvokerExtensionLoaderError::new( + "load invoker extension failed, extension mustn't be empty", + ) + .into()); + }; + let extension_name = extension_name.value(); + let factory = self.factories.get_mut(&extension_name); + let Some(factory) = factory else { + let err_msg = format!( + "load {} invoker extension failed, can not found extension factory", + extension_name + ); + return Err(InvokerExtensionLoaderError(err_msg).into()); + }; + factory.create(url) + } +} + +type InvokerExtensionConstructor = fn( + Url, +) -> Pin< + Box, StdError>> + Send + 'static>, +>; +pub(crate) struct InvokerExtensionFactory { + constructor: InvokerExtensionConstructor, + instances: HashMap>, +} + +impl InvokerExtensionFactory { + pub fn new(constructor: InvokerExtensionConstructor) -> Self { + Self { + constructor, + instances: HashMap::default(), + } + } +} + +impl InvokerExtensionFactory { + pub fn create(&mut self, url: Url) -> Result, StdError> { + let key = url.to_string(); + + match self.instances.get(&key) { + Some(instance) => Ok(instance.clone()), + None => { + let constructor = self.constructor; + let creator = move |url: Url| { + let invoker_future = constructor(url); + Box::pin(async move { + let invoker = invoker_future.await?; + Ok(InvokerProxy::from(invoker)) + }) + as Pin< + Box< + dyn Future> + + Send + + 'static, + >, + > + }; + + let promise: LoadExtensionPromise = + LoadExtensionPromise::new(Box::new(creator), url); + self.instances.insert(key, promise.clone()); + Ok(promise) + } + } + } +} + +pub struct InvokerExtension(PhantomData) +where + T: Invoker + Send + 'static; + +impl ExtensionMetaInfo for InvokerExtension +where + T: Invoker + Send + 'static, + T: Extension>, +{ + fn name() -> String { + T::name() + } + + fn extension_type() -> ExtensionType { + ExtensionType::Invoker + } + + fn extension_factory() -> ExtensionFactories { + ExtensionFactories::InvokerExtensionFactory(InvokerExtensionFactory::new( + ::create, + )) + } +} + +#[derive(Error, Debug)] +#[error("{0}")] +pub struct InvokerExtensionLoaderError(String); + +impl InvokerExtensionLoaderError { + pub fn new(msg: &str) -> Self { + InvokerExtensionLoaderError(msg.to_string()) + } +} diff --git a/dubbo/src/extension/mod.rs b/dubbo/src/extension/mod.rs index 1229ff4e..724b64c7 100644 --- a/dubbo/src/extension/mod.rs +++ b/dubbo/src/extension/mod.rs @@ -15,10 +15,14 @@ * limitations under the License. */ +mod invoker_extension; pub mod registry_extension; use crate::{ - extension::registry_extension::proxy::RegistryProxy, + extension::{ + invoker_extension::proxy::InvokerProxy, + registry_extension::{proxy::RegistryProxy, RegistryExtension}, + }, logger::tracing::{error, info}, params::extension_param::ExtensionType, registry::registry::StaticRegistry, @@ -35,6 +39,7 @@ pub static EXTENSIONS: once_cell::sync::Lazy = #[derive(Default)] struct ExtensionDirectory { registry_extension_loader: registry_extension::RegistryExtensionLoader, + invoker_extension_loader: invoker_extension::InvokerExtensionLoader, } impl ExtensionDirectory { @@ -47,8 +52,8 @@ impl ExtensionDirectory { // register static registry extension let _ = extension_directory.register( StaticRegistry::name(), - StaticRegistry::convert_to_extension_factories(), - ExtensionType::Registry, + RegistryExtension::::extension_factory(), + RegistryExtension::::extension_type(), ); while let Some(extension_opt) = rx.recv().await { @@ -93,6 +98,15 @@ impl ExtensionDirectory { .register(extension_name, registry_extension_factory); Ok(()) } + _ => Ok(()), + }, + ExtensionType::Invoker => match extension_factories { + ExtensionFactories::InvokerExtensionFactory(invoker_extension_factory) => { + self.invoker_extension_loader + .register(extension_name, invoker_extension_factory); + Ok(()) + } + _ => Ok(()), }, } } @@ -107,6 +121,10 @@ impl ExtensionDirectory { self.registry_extension_loader.remove(extension_name); Ok(()) } + ExtensionType::Invoker => { + self.invoker_extension_loader.remove(extension_name); + Ok(()) + } } } @@ -128,14 +146,37 @@ impl ExtensionDirectory { let _ = callback.send(Ok(Extensions::Registry(extension))); } Err(err) => { - error!("load extension failed: {}", err); + error!("load registry extension failed: {}", err); let _ = callback.send(Err(err)); } } }); } Err(err) => { - error!("load extension failed: {}", err); + error!("load registry extension failed: {}", err); + let _ = callback.send(Err(err)); + } + } + } + ExtensionType::Invoker => { + let extension = self.invoker_extension_loader.load(url); + match extension { + Ok(mut extension) => { + tokio::spawn(async move { + let invoker = extension.resolve().await; + match invoker { + Ok(invoker) => { + let _ = callback.send(Ok(Extensions::Invoker(invoker))); + } + Err(err) => { + error!("load invoker extension failed: {}", err); + let _ = callback.send(Err(err)); + } + } + }); + } + Err(err) => { + error!("load invoker extension failed: {}", err); let _ = callback.send(Err(err)); } } @@ -241,12 +282,10 @@ impl ExtensionDirectoryCommander { #[allow(private_bounds)] pub async fn register(&self) -> Result<(), StdError> where - T: Extension, T: ExtensionMetaInfo, - T: ConvertToExtensionFactories, { let extension_name = T::name(); - let extension_factories = T::convert_to_extension_factories(); + let extension_factories = T::extension_factory(); let extension_type = T::extension_type(); info!( @@ -286,7 +325,6 @@ impl ExtensionDirectoryCommander { #[allow(private_bounds)] pub async fn remove(&self) -> Result<(), StdError> where - T: Extension, T: ExtensionMetaInfo, { let extension_name = T::name(); @@ -355,6 +393,45 @@ impl ExtensionDirectoryCommander { match extensions { Extensions::Registry(proxy) => Ok(proxy), + _ => { + panic!("load registry extension failed: invalid extension type"); + } + } + } + + pub async fn load_invoker(&self, url: Url) -> Result { + let url_str = url.to_string(); + info!("load invoker extension: {}", url_str); + + let (tx, rx) = oneshot::channel(); + + let send = self + .sender + .send(ExtensionOpt::Load(url, ExtensionType::Invoker, tx)) + .await; + + let Ok(_) = send else { + let err_msg = format!("load invoker extension failed: {}", url_str); + return Err(LoadExtensionError::new(err_msg).into()); + }; + + let extensions = rx.await; + + let Ok(extension) = extensions else { + let err_msg = format!("load invoker extension failed: {}", url_str); + return Err(LoadExtensionError::new(err_msg).into()); + }; + + let Ok(extensions) = extension else { + let err_msg = format!("load invoker extension failed: {}", url_str); + return Err(LoadExtensionError::new(err_msg).into()); + }; + + match extensions { + Extensions::Invoker(proxy) => Ok(proxy), + _ => { + panic!("load invoker extension failed: invalid extension type"); + } } } } @@ -374,11 +451,9 @@ enum ExtensionOpt { ), } -pub(crate) trait Sealed {} - #[allow(private_bounds)] #[async_trait::async_trait] -pub trait Extension: Sealed { +pub trait Extension { type Target; fn name() -> String; @@ -388,20 +463,19 @@ pub trait Extension: Sealed { #[allow(private_bounds)] pub(crate) trait ExtensionMetaInfo { + fn name() -> String; fn extension_type() -> ExtensionType; + fn extension_factory() -> ExtensionFactories; } pub(crate) enum Extensions { Registry(RegistryProxy), + Invoker(InvokerProxy), } pub(crate) enum ExtensionFactories { RegistryExtensionFactory(registry_extension::RegistryExtensionFactory), -} - -#[allow(private_bounds)] -pub(crate) trait ConvertToExtensionFactories { - fn convert_to_extension_factories() -> ExtensionFactories; + InvokerExtensionFactory(invoker_extension::InvokerExtensionFactory), } #[derive(Error, Debug)] diff --git a/dubbo/src/extension/registry_extension.rs b/dubbo/src/extension/registry_extension.rs index d625f02e..2e9291b3 100644 --- a/dubbo/src/extension/registry_extension.rs +++ b/dubbo/src/extension/registry_extension.rs @@ -15,7 +15,7 @@ * limitations under the License. */ -use std::{collections::HashMap, future::Future, pin::Pin}; +use std::{collections::HashMap, future::Future, marker::PhantomData, pin::Pin}; use async_trait::async_trait; use thiserror::Error; @@ -30,8 +30,7 @@ use crate::{ use proxy::RegistryProxy; use crate::extension::{ - ConvertToExtensionFactories, Extension, ExtensionFactories, ExtensionMetaInfo, ExtensionType, - LoadExtensionPromise, + Extension, ExtensionFactories, ExtensionMetaInfo, ExtensionType, LoadExtensionPromise, }; // extension://0.0.0.0/?extension-type=registry&extension-name=nacos®istry-url=nacos://127.0.0.1:8848 @@ -63,24 +62,24 @@ pub trait Registry { fn url(&self) -> &Url; } -impl crate::extension::Sealed for T where T: Registry + Send + Sync + 'static {} +pub struct RegistryExtension(PhantomData) +where + T: Registry + Send + Sync + 'static; -impl ExtensionMetaInfo for T +impl ExtensionMetaInfo for RegistryExtension where T: Registry + Send + Sync + 'static, T: Extension>, { + fn name() -> String { + T::name() + } + fn extension_type() -> ExtensionType { ExtensionType::Registry } -} -impl ConvertToExtensionFactories for T -where - T: Registry + Send + Sync + 'static, - T: Extension>, -{ - fn convert_to_extension_factories() -> ExtensionFactories { + fn extension_factory() -> ExtensionFactories { ExtensionFactories::RegistryExtensionFactory(RegistryExtensionFactory::new( ::create, )) diff --git a/dubbo/src/params/extension_param.rs b/dubbo/src/params/extension_param.rs index 93e0a16c..08ec1c97 100644 --- a/dubbo/src/params/extension_param.rs +++ b/dubbo/src/params/extension_param.rs @@ -51,6 +51,7 @@ impl FromStr for ExtensionName { pub enum ExtensionType { Registry, + Invoker, } impl UrlParam for ExtensionType { @@ -63,12 +64,14 @@ impl UrlParam for ExtensionType { fn value(&self) -> Self::TargetType { match self { ExtensionType::Registry => "registry".to_owned(), + ExtensionType::Invoker => "invoker".to_owned(), } } fn as_str(&self) -> Cow { match self { ExtensionType::Registry => Cow::Borrowed("registry"), + ExtensionType::Invoker => Cow::Borrowed("invoker"), } } } diff --git a/examples/echo/src/generated/grpc.examples.echo.rs b/examples/echo/src/generated/grpc.examples.echo.rs index fc48dc5c..ee8cc1e6 100644 --- a/examples/echo/src/generated/grpc.examples.echo.rs +++ b/examples/echo/src/generated/grpc.examples.echo.rs @@ -43,7 +43,9 @@ pub mod echo_client { let invocation = RpcInvocation::default() .with_service_unique_name(String::from("grpc.examples.echo.Echo")) .with_method_name(String::from("UnaryEcho")); - let path = http::uri::PathAndQuery::from_static("/grpc.examples.echo.Echo/UnaryEcho"); + let path = http::uri::PathAndQuery::from_static( + "/grpc.examples.echo.Echo/UnaryEcho", + ); self.inner.unary(request, path, invocation).await } /// ServerStreamingEcho is server side streaming. @@ -100,7 +102,9 @@ pub mod echo_server { request: Request, ) -> Result, dubbo::status::Status>; ///Server streaming response type for the ServerStreamingEcho method. - type ServerStreamingEchoStream: futures_util::Stream> + type ServerStreamingEchoStream: futures_util::Stream< + Item = Result, + > + Send + 'static; /// ServerStreamingEcho is server side streaming. @@ -114,14 +118,19 @@ pub mod echo_server { request: Request>, ) -> Result, dubbo::status::Status>; ///Server streaming response type for the BidirectionalStreamingEcho method. - type BidirectionalStreamingEchoStream: futures_util::Stream> + type BidirectionalStreamingEchoStream: futures_util::Stream< + Item = Result, + > + Send + 'static; /// BidirectionalStreamingEcho is bidi streaming. async fn bidirectional_streaming_echo( &self, request: Request>, - ) -> Result, dubbo::status::Status>; + ) -> Result< + Response, + dubbo::status::Status, + >; } /// Echo is the echo service. #[derive(Debug)] @@ -151,7 +160,10 @@ pub mod echo_server { type Response = http::Response; type Error = std::convert::Infallible; type Future = BoxFuture; - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + fn poll_ready( + &mut self, + _cx: &mut Context<'_>, + ) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, req: http::Request) -> Self::Future { @@ -164,16 +176,24 @@ pub mod echo_server { } impl UnarySvc for UnaryEchoServer { type Response = super::EchoResponse; - type Future = BoxFuture, dubbo::status::Status>; - fn call(&mut self, request: Request) -> Self::Future { + type Future = BoxFuture< + Response, + dubbo::status::Status, + >; + fn call( + &mut self, + request: Request, + ) -> Self::Future { let inner = self.inner.0.clone(); let fut = async move { inner.unary_echo(request).await }; Box::pin(fut) } } let fut = async move { - let mut server = - TripleServer::::new(); + let mut server = TripleServer::< + super::EchoRequest, + super::EchoResponse, + >::new(); let res = server.unary(UnaryEchoServer { inner }, req).await; Ok(res) }; @@ -184,20 +204,30 @@ pub mod echo_server { struct ServerStreamingEchoServer { inner: _Inner, } - impl ServerStreamingSvc for ServerStreamingEchoServer { + impl ServerStreamingSvc + for ServerStreamingEchoServer { type Response = super::EchoResponse; type ResponseStream = T::ServerStreamingEchoStream; - type Future = - BoxFuture, dubbo::status::Status>; - fn call(&mut self, request: Request) -> Self::Future { + type Future = BoxFuture< + Response, + dubbo::status::Status, + >; + fn call( + &mut self, + request: Request, + ) -> Self::Future { let inner = self.inner.0.clone(); - let fut = async move { inner.server_streaming_echo(request).await }; + let fut = async move { + inner.server_streaming_echo(request).await + }; Box::pin(fut) } } let fut = async move { - let mut server = - TripleServer::::new(); + let mut server = TripleServer::< + super::EchoRequest, + super::EchoResponse, + >::new(); let res = server .server_streaming(ServerStreamingEchoServer { inner }, req) .await; @@ -210,21 +240,29 @@ pub mod echo_server { struct ClientStreamingEchoServer { inner: _Inner, } - impl ClientStreamingSvc for ClientStreamingEchoServer { + impl ClientStreamingSvc + for ClientStreamingEchoServer { type Response = super::EchoResponse; - type Future = BoxFuture, dubbo::status::Status>; + type Future = BoxFuture< + Response, + dubbo::status::Status, + >; fn call( &mut self, request: Request>, ) -> Self::Future { let inner = self.inner.0.clone(); - let fut = async move { inner.client_streaming_echo(request).await }; + let fut = async move { + inner.client_streaming_echo(request).await + }; Box::pin(fut) } } let fut = async move { - let mut server = - TripleServer::::new(); + let mut server = TripleServer::< + super::EchoRequest, + super::EchoResponse, + >::new(); let res = server .client_streaming(ClientStreamingEchoServer { inner }, req) .await; @@ -237,39 +275,54 @@ pub mod echo_server { struct BidirectionalStreamingEchoServer { inner: _Inner, } - impl StreamingSvc for BidirectionalStreamingEchoServer { + impl StreamingSvc + for BidirectionalStreamingEchoServer { type Response = super::EchoResponse; type ResponseStream = T::BidirectionalStreamingEchoStream; - type Future = - BoxFuture, dubbo::status::Status>; + type Future = BoxFuture< + Response, + dubbo::status::Status, + >; fn call( &mut self, request: Request>, ) -> Self::Future { let inner = self.inner.0.clone(); - let fut = - async move { inner.bidirectional_streaming_echo(request).await }; + let fut = async move { + inner.bidirectional_streaming_echo(request).await + }; Box::pin(fut) } } let fut = async move { - let mut server = - TripleServer::::new(); + let mut server = TripleServer::< + super::EchoRequest, + super::EchoResponse, + >::new(); let res = server - .bidi_streaming(BidirectionalStreamingEchoServer { inner }, req) + .bidi_streaming( + BidirectionalStreamingEchoServer { + inner, + }, + req, + ) .await; Ok(res) }; Box::pin(fut) } - _ => Box::pin(async move { - Ok(http::Response::builder() - .status(200) - .header("grpc-status", "12") - .header("content-type", "application/grpc") - .body(empty_body()) - .unwrap()) - }), + _ => { + Box::pin(async move { + Ok( + http::Response::builder() + .status(200) + .header("grpc-status", "12") + .header("content-type", "application/grpc") + .body(empty_body()) + .unwrap(), + ) + }) + } } } } diff --git a/examples/greeter/src/greeter/client.rs b/examples/greeter/src/greeter/client.rs index d2edae5f..f6fea892 100644 --- a/examples/greeter/src/greeter/client.rs +++ b/examples/greeter/src/greeter/client.rs @@ -22,7 +22,7 @@ pub mod protos { use dubbo::codegen::*; -use dubbo::extension; +use dubbo::{extension, extension::registry_extension::RegistryExtension}; use futures_util::StreamExt; use protos::{greeter_client::GreeterClient, GreeterRequest}; use registry_nacos::NacosRegistry; @@ -31,7 +31,9 @@ use registry_nacos::NacosRegistry; async fn main() { dubbo::logger::init(); - let _ = extension::EXTENSIONS.register::().await; + let _ = extension::EXTENSIONS + .register::>() + .await; let builder = ClientBuilder::new().with_registry("nacos://127.0.0.1:8848".parse().unwrap()); diff --git a/examples/greeter/src/greeter/server.rs b/examples/greeter/src/greeter/server.rs index 17a3ac49..aecc1f8e 100644 --- a/examples/greeter/src/greeter/server.rs +++ b/examples/greeter/src/greeter/server.rs @@ -27,6 +27,7 @@ use dubbo::{ codegen::*, config::RootConfig, extension, + extension::registry_extension::RegistryExtension, logger::{ tracing::{info, span}, Level, @@ -60,7 +61,9 @@ async fn main() { Err(_err) => panic!("err: {:?}", _err), // response was droped }; - let _ = extension::EXTENSIONS.register::().await; + let _ = extension::EXTENSIONS + .register::>() + .await; let mut f = Dubbo::new() .with_config(r) .add_registry("nacos://127.0.0.1:8848/");