Skip to content

Commit

Permalink
Ftr: use RwLock instead of unsafe
Browse files Browse the repository at this point in the history
  • Loading branch information
onewe committed Mar 18, 2024
1 parent 4e02c28 commit 659e66f
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 60 deletions.
96 changes: 64 additions & 32 deletions dubbo/src/extension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use dubbo_base::{extension_param::ExtensionType, url::UrlParam, StdError, Url};
use dubbo_logger::tracing::{error, info};
use std::{future::Future, pin::Pin, sync::Arc};
use thiserror::Error;
use tokio::sync::{oneshot, Semaphore};
use tokio::sync::{oneshot, RwLock};

pub static EXTENSIONS: once_cell::sync::Lazy<ExtensionDirectoryCommander> =
once_cell::sync::Lazy::new(|| ExtensionDirectory::init());
Expand Down Expand Up @@ -141,61 +141,95 @@ impl ExtensionDirectory {
}
}

type ExtensionCreator<T> = Box<
dyn Fn(Url) -> Pin<Box<dyn Future<Output = Result<T, StdError>> + Send + 'static>>
+ Send
+ Sync
+ 'static,
>;
pub(crate) struct ExtensionPromiseResolver<T> {
resolved_data: Option<T>,
creator: ExtensionCreator<T>,
url: Url,
}

impl<T> ExtensionPromiseResolver<T>
where
T: Send + Clone + 'static,
{
fn new(creator: ExtensionCreator<T>, url: Url) -> Self {
ExtensionPromiseResolver {
resolved_data: None,
creator,
url,
}
}

fn resolved_data(&self) -> Option<T> {
self.resolved_data.clone()
}

async fn resolve(&mut self) -> Result<T, StdError> {
match (self.creator)(self.url.clone()).await {
Ok(data) => {
self.resolved_data = Some(data.clone());
Ok(data)
}
Err(err) => {
error!("create extension failed: {}", err);
Err(LoadExtensionError::new(
"load extension failed, create extension occur an error".to_string(),
)
.into())
}
}
}
}

pub(crate) struct LoadExtensionPromise<T> {
extension: Arc<Option<T>>,
fut: Option<Pin<Box<dyn Future<Output = Result<T, StdError>> + Send + 'static>>>,
semaphore: Arc<Semaphore>,
resolver: Arc<RwLock<ExtensionPromiseResolver<T>>>,
}

impl<T> LoadExtensionPromise<T>
where
T: Send + Clone + 'static,
{
pub(crate) fn new(
fut: Pin<Box<dyn Future<Output = Result<T, StdError>> + Send + 'static>>,
) -> Self {
pub(crate) fn new(creator: ExtensionCreator<T>, url: Url) -> Self {
let resolver = ExtensionPromiseResolver::new(creator, url);
LoadExtensionPromise {
extension: Arc::new(None),
fut: Some(fut),
semaphore: Arc::new(Semaphore::new(0)),
resolver: Arc::new(RwLock::new(resolver)),
}
}

fn get_extension(&self) -> Option<T> {
self.extension.as_ref().as_ref().map(|a| a.clone())
}

pub(crate) async fn resolve(&mut self) -> Result<T, StdError> {
let extension = self.get_extension();
if let Some(extension) = extension {
// get read lock
let resolver_read_lock = self.resolver.read().await;
// if extension is not None, return it
if let Some(extension) = resolver_read_lock.resolved_data() {
return Ok(extension);
}
drop(resolver_read_lock);

let fut = self.fut.take();
let Some(mut fut) = fut else {
let _ = self.semaphore.acquire().await;
let resolver_write_lock = self.resolver.try_write();
let Ok(mut resolver_write_lock) = resolver_write_lock else {
// can not get write lock
// wait until this extension is created
let resolver_read_lock = self.resolver.read().await;
// check it again
let extension = self.get_extension();
if let Some(extension) = extension {
info!("promise has been resolved.");
if let Some(extension) = resolver_read_lock.resolved_data() {
return Ok(extension);
}

return Err(LoadExtensionError::new("load extension canceled ".to_string()).into());
};

match fut.as_mut().await {
match resolver_write_lock.resolve().await {
Ok(extension) => {
info!("create extension success");
let ptr = Arc::as_ptr(&self.extension) as *mut Option<T>;
unsafe {
*ptr = Some(extension.clone());
}
self.semaphore.close();
Ok(extension)
}
Err(err) => {
error!("create extension failed: {}", err);
self.semaphore.close();
Err(LoadExtensionError::new(
"load extension failed, create extension occur an error".to_string(),
)
Expand All @@ -208,9 +242,7 @@ where
impl<T> Clone for LoadExtensionPromise<T> {
fn clone(&self) -> Self {
LoadExtensionPromise {
extension: self.extension.clone(),
fut: None,
semaphore: self.semaphore.clone(),
resolver: self.resolver.clone(),
}
}
}
Expand Down
51 changes: 23 additions & 28 deletions dubbo/src/extension/registry_extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,26 +79,9 @@ where
T: Extension<Target = Box<dyn Registry + Send + 'static>>,
{
fn convert_to_extension_factories() -> ExtensionFactories {
fn constrain<F>(f: F) -> F
where
F: Fn(
Url,
) -> Pin<
Box<
dyn Future<Output = Result<Box<dyn Registry + Send + 'static>, StdError>>
+ Send,
>,
>,
{
f
}

let constructor = constrain(|url: Url| {
let f = <T as Extension>::create(url);
Box::pin(f)
});

ExtensionFactories::RegistryExtensionFactory(RegistryExtensionFactory::new(constructor))
ExtensionFactories::RegistryExtensionFactory(RegistryExtensionFactory::new(
<T as Extension>::create,
))
}
}

Expand Down Expand Up @@ -166,14 +149,26 @@ impl RegistryExtensionFactory {
Ok(proxy)
}
None => {
let registry = (self.constructor)(url);
let fut = Box::pin(async move {
let registry = registry.await?;
let proxy = <RegistryProxy as From<Box<dyn Registry + Send>>>::from(registry);
Ok(proxy)
});

let promise = LoadExtensionPromise::new(fut);
let constructor = self.constructor;

let creator = move |url: Url| {
let registry = constructor(url);
Box::pin(async move {
let registry = registry.await?;
let proxy =
<RegistryProxy as From<Box<dyn Registry + Send>>>::from(registry);
Ok(proxy)
})
as Pin<
Box<
dyn Future<Output = Result<RegistryProxy, StdError>>
+ Send
+ 'static,
>,
>
};

let promise = LoadExtensionPromise::new(Box::new(creator), url);
self.instances.insert(url_str, promise.clone());
Ok(promise)
}
Expand Down

0 comments on commit 659e66f

Please sign in to comment.