diff --git a/dubbo/src/extension/mod.rs b/dubbo/src/extension/mod.rs index bfd198a8..412afc7f 100644 --- a/dubbo/src/extension/mod.rs +++ b/dubbo/src/extension/mod.rs @@ -24,7 +24,7 @@ use dubbo_base::{extension_param::ExtensionType, url::UrlParam, StdError, Url}; use dubbo_logger::tracing::{error, info}; use std::{future::Future, pin::Pin, sync::Arc}; use thiserror::Error; -use tokio::sync::{oneshot, Semaphore}; +use tokio::sync::{oneshot, RwLock}; pub static EXTENSIONS: once_cell::sync::Lazy = once_cell::sync::Lazy::new(|| ExtensionDirectory::init()); @@ -141,61 +141,95 @@ impl ExtensionDirectory { } } +type ExtensionCreator = Box< + dyn Fn(Url) -> Pin> + Send + 'static>> + + Send + + Sync + + 'static, +>; +pub(crate) struct ExtensionPromiseResolver { + resolved_data: Option, + creator: ExtensionCreator, + url: Url, +} + +impl ExtensionPromiseResolver +where + T: Send + Clone + 'static, +{ + fn new(creator: ExtensionCreator, url: Url) -> Self { + ExtensionPromiseResolver { + resolved_data: None, + creator, + url, + } + } + + fn resolved_data(&self) -> Option { + self.resolved_data.clone() + } + + async fn resolve(&mut self) -> Result { + match (self.creator)(self.url.clone()).await { + Ok(data) => { + self.resolved_data = Some(data.clone()); + Ok(data) + } + Err(err) => { + error!("create extension failed: {}", err); + Err(LoadExtensionError::new( + "load extension failed, create extension occur an error".to_string(), + ) + .into()) + } + } + } +} + pub(crate) struct LoadExtensionPromise { - extension: Arc>, - fut: Option> + Send + 'static>>>, - semaphore: Arc, + resolver: Arc>>, } impl LoadExtensionPromise where T: Send + Clone + 'static, { - pub(crate) fn new( - fut: Pin> + Send + 'static>>, - ) -> Self { + pub(crate) fn new(creator: ExtensionCreator, url: Url) -> Self { + let resolver = ExtensionPromiseResolver::new(creator, url); LoadExtensionPromise { - extension: Arc::new(None), - fut: Some(fut), - semaphore: Arc::new(Semaphore::new(0)), + resolver: Arc::new(RwLock::new(resolver)), } } - fn get_extension(&self) -> Option { - self.extension.as_ref().as_ref().map(|a| a.clone()) - } - pub(crate) async fn resolve(&mut self) -> Result { - let extension = self.get_extension(); - if let Some(extension) = extension { + // get read lock + let resolver_read_lock = self.resolver.read().await; + // if extension is not None, return it + if let Some(extension) = resolver_read_lock.resolved_data() { return Ok(extension); } + drop(resolver_read_lock); - let fut = self.fut.take(); - let Some(mut fut) = fut else { - let _ = self.semaphore.acquire().await; + let resolver_write_lock = self.resolver.try_write(); + let Ok(mut resolver_write_lock) = resolver_write_lock else { + // can not get write lock + // wait until this extension is created + let resolver_read_lock = self.resolver.read().await; // check it again - let extension = self.get_extension(); - if let Some(extension) = extension { - info!("promise has been resolved."); + if let Some(extension) = resolver_read_lock.resolved_data() { return Ok(extension); } + return Err(LoadExtensionError::new("load extension canceled ".to_string()).into()); }; - match fut.as_mut().await { + match resolver_write_lock.resolve().await { Ok(extension) => { info!("create extension success"); - let ptr = Arc::as_ptr(&self.extension) as *mut Option; - unsafe { - *ptr = Some(extension.clone()); - } - self.semaphore.close(); Ok(extension) } Err(err) => { error!("create extension failed: {}", err); - self.semaphore.close(); Err(LoadExtensionError::new( "load extension failed, create extension occur an error".to_string(), ) @@ -208,9 +242,7 @@ where impl Clone for LoadExtensionPromise { fn clone(&self) -> Self { LoadExtensionPromise { - extension: self.extension.clone(), - fut: None, - semaphore: self.semaphore.clone(), + resolver: self.resolver.clone(), } } } diff --git a/dubbo/src/extension/registry_extension.rs b/dubbo/src/extension/registry_extension.rs index b00fafa6..ce998bce 100644 --- a/dubbo/src/extension/registry_extension.rs +++ b/dubbo/src/extension/registry_extension.rs @@ -79,26 +79,9 @@ where T: Extension>, { fn convert_to_extension_factories() -> ExtensionFactories { - fn constrain(f: F) -> F - where - F: Fn( - Url, - ) -> Pin< - Box< - dyn Future, StdError>> - + Send, - >, - >, - { - f - } - - let constructor = constrain(|url: Url| { - let f = ::create(url); - Box::pin(f) - }); - - ExtensionFactories::RegistryExtensionFactory(RegistryExtensionFactory::new(constructor)) + ExtensionFactories::RegistryExtensionFactory(RegistryExtensionFactory::new( + ::create, + )) } } @@ -166,14 +149,26 @@ impl RegistryExtensionFactory { Ok(proxy) } None => { - let registry = (self.constructor)(url); - let fut = Box::pin(async move { - let registry = registry.await?; - let proxy = >>::from(registry); - Ok(proxy) - }); - - let promise = LoadExtensionPromise::new(fut); + let constructor = self.constructor; + + let creator = move |url: Url| { + let registry = constructor(url); + Box::pin(async move { + let registry = registry.await?; + let proxy = + >>::from(registry); + Ok(proxy) + }) + as Pin< + Box< + dyn Future> + + Send + + 'static, + >, + > + }; + + let promise = LoadExtensionPromise::new(Box::new(creator), url); self.instances.insert(url_str, promise.clone()); Ok(promise) }