Skip to content

Commit

Permalink
Merge pull request #13 from erwanvivien/main
Browse files Browse the repository at this point in the history
Global improvement for wasm_thread
  • Loading branch information
chemicstry authored Mar 9, 2023
2 parents 1984228 + 99ee72a commit 0751fa2
Showing 1 changed file with 188 additions and 48 deletions.
236 changes: 188 additions & 48 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,17 @@ use std::any::Any;
use std::fmt;
use std::mem;

use std::sync::Mutex;
pub use std::thread::{current, sleep, Result, Thread, ThreadId};
use std::{
marker::PhantomData,
panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc,
},
time::Duration,
};

use wasm_bindgen::prelude::*;
use wasm_bindgen::*;
Expand All @@ -20,6 +30,9 @@ extern "C" {
fn load_module_workers_polyfill();
}

type DefaultBuilder = Mutex<Option<Builder>>;
static DEFAULT_BUILDER: DefaultBuilder = Mutex::new(None);

/// Extracts path of the `wasm_bindgen` generated .js shim script
pub fn get_wasm_bindgen_shim_script_path() -> String {
js_sys::eval(include_str!("script_path.js"))
Expand All @@ -30,42 +43,40 @@ pub fn get_wasm_bindgen_shim_script_path() -> String {

/// Generates worker entry script as URL encoded blob
pub fn get_worker_script(wasm_bindgen_shim_url: Option<String>) -> String {
unsafe {
static mut SCRIPT_URL: Option<String> = None;
static mut SCRIPT_URL: Option<String> = None;

if let Some(url) = SCRIPT_URL.as_ref() {
url.clone()
} else {
// If wasm bindgen shim url is not provided, try to obtain one automatically
let wasm_bindgen_shim_url =
wasm_bindgen_shim_url.unwrap_or_else(get_wasm_bindgen_shim_script_path);

// Generate script from template
let template;
#[cfg(feature = "es_modules")]
{
template = include_str!("web_worker_module.js");
}
#[cfg(not(feature = "es_modules"))]
{
template = include_str!("web_worker.js");
}
let script = template.replace("WASM_BINDGEN_SHIM_URL", &wasm_bindgen_shim_url);

// Create url encoded blob
let arr = js_sys::Array::new();
arr.set(0, JsValue::from_str(&script));
let blob = Blob::new_with_str_sequence(&arr).unwrap();
let url = Url::create_object_url_with_blob(
&blob
.slice_with_f64_and_f64_and_content_type(0.0, blob.size(), "text/javascript")
.unwrap(),
)
.unwrap();
SCRIPT_URL = Some(url.clone());
if let Some(url) = unsafe { SCRIPT_URL.as_ref() } {
url.clone()
} else {
// If wasm bindgen shim url is not provided, try to obtain one automatically
let wasm_bindgen_shim_url =
wasm_bindgen_shim_url.unwrap_or_else(get_wasm_bindgen_shim_script_path);

url
// Generate script from template
let template;
#[cfg(feature = "es_modules")]
{
template = include_str!("web_worker_module.js");
}
#[cfg(not(feature = "es_modules"))]
{
template = include_str!("web_worker.js");
}
let script = template.replace("WASM_BINDGEN_SHIM_URL", &wasm_bindgen_shim_url);

// Create url encoded blob
let arr = js_sys::Array::new();
arr.set(0, JsValue::from_str(&script));
let blob = Blob::new_with_str_sequence(&arr).unwrap();
let url = Url::create_object_url_with_blob(
&blob
.slice_with_f64_and_f64_and_content_type(0.0, blob.size(), "text/javascript")
.unwrap(),
)
.unwrap();
unsafe { SCRIPT_URL = Some(url.clone()) };

url
}
}

Expand Down Expand Up @@ -98,24 +109,24 @@ enum WorkerMessage {
impl WorkerMessage {
pub fn post(self) {
let req = Box::new(self);
unsafe {
js_sys::eval("self")
.unwrap()
.dyn_into::<DedicatedWorkerGlobalScope>()
.unwrap()
.post_message(&JsValue::from(std::mem::transmute::<_, f64>(
Box::into_raw(req) as u64,
)))
.unwrap();
}
let req = unsafe { std::mem::transmute::<_, f64>(Box::into_raw(req) as u64) };

js_sys::eval("self")
.unwrap()
.dyn_into::<DedicatedWorkerGlobalScope>()
.unwrap()
.post_message(&JsValue::from(req))
.unwrap();
}
}

/// Thread factory, which can be used in order to configure the properties of a new thread.
#[derive(Debug, Default)]
#[derive(Debug, Default, Clone)]
pub struct Builder {
// A name for the thread-to-be, for identification in panic messages
name: Option<String>,
// A prefix for the thread-to-be, for identification in panic messages
prefix: Option<String>,
// The size of the stack for the spawned thread in bytes
stack_size: Option<usize>,
// Url of the `wasm_bindgen` generated shim `.js` script to use as web worker entry point
Expand All @@ -126,7 +137,18 @@ impl Builder {
/// Generates the base configuration for spawning a thread, from which
/// configuration methods can be chained.
pub fn new() -> Builder {
Builder::default()
let default_builder = DEFAULT_BUILDER.lock().unwrap().clone();
default_builder.unwrap_or(Builder::default())
}

pub fn set_default(self) {
*DEFAULT_BUILDER.lock().unwrap() = Some(self);
}

/// Sets the prefix of the thread-to-be.
pub fn prefix(mut self, prefix: String) -> Builder {
self.prefix = Some(prefix);
self
}

/// Names the thread-to-be.
Expand Down Expand Up @@ -158,6 +180,21 @@ impl Builder {
unsafe { self.spawn_unchecked(f) }
}

pub fn spawn_scoped<'scope, 'env, F, T>(
self,
scope: &'scope Scope<'scope, 'env>,
f: F,
) -> std::io::Result<ScopedJoinHandle<'scope, T>>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
Ok(ScopedJoinHandle(
unsafe { self.spawn_unchecked(f) }?,
PhantomData,
))
}

/// Spawns a new thread without any lifetime restrictions by taking ownership
/// of the `Builder`, and returns an [`io::Result`] to its [`JoinHandle`].
///
Expand Down Expand Up @@ -210,6 +247,7 @@ impl Builder {
unsafe fn spawn_for_context(self, ctx: WebWorkerContext) {
let Builder {
name,
prefix,
wasm_bindgen_shim_url,
..
} = self;
Expand All @@ -219,9 +257,20 @@ impl Builder {

// Todo: figure out how to set stack size
let mut options = WorkerOptions::new();
if let Some(name) = name {
options.name(&name);
}
match (name, prefix) {
(Some(name), Some(prefix)) => {
options.name(&format!("{}:{}", prefix, name));
}
(Some(name), None) => {
options.name(&name);
}
(None, Some(prefix)) => {
let random = (js_sys::Math::random() * 10e10) as u64;
options.name(&format!("{}:{}", prefix, random));
}
(None, None) => {}
};

#[cfg(feature = "es_modules")]
{
load_module_workers_polyfill();
Expand Down Expand Up @@ -328,3 +377,94 @@ where
{
Builder::new().spawn(f).expect("failed to spawn thread")
}

use core::num::NonZeroUsize;
pub fn available_parallelism() -> std::io::Result<NonZeroUsize> {
// TODO: Use [Navigator::hardware_concurrency](https://rustwasm.github.io/wasm-bindgen/api/web_sys/struct.Navigator.html#method.hardware_concurrency)
Ok(NonZeroUsize::new(8).unwrap())
}

pub struct ScopeData {
num_running_threads: AtomicUsize,
a_thread_panicked: AtomicBool,
main_thread: Thread,
}

pub struct Scope<'scope, 'env: 'scope> {
data: Arc<ScopeData>,
/// Invariance over 'scope, to make sure 'scope cannot shrink,
/// which is necessary for soundness.
///
/// Without invariance, this would compile fine but be unsound:
///
/// ```compile_fail,E0373
/// std::thread::scope(|s| {
/// s.spawn(|| {
/// let a = String::from("abcd");
/// s.spawn(|| println!("{a:?}")); // might run after `a` is dropped
/// });
/// });
/// ```
scope: PhantomData<&'scope mut &'scope ()>,
env: PhantomData<&'env mut &'env ()>,
}

pub fn scope<'env, F, T>(f: F) -> T
where
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
{
// We put the `ScopeData` into an `Arc` so that other threads can finish their
// `decrement_num_running_threads` even after this function returns.
let scope = Scope {
data: Arc::new(ScopeData {
num_running_threads: AtomicUsize::new(0),
main_thread: current(),
a_thread_panicked: AtomicBool::new(false),
}),
env: PhantomData,
scope: PhantomData,
};

// Run `f`, but catch panics so we can make sure to wait for all the threads to join.
let result = catch_unwind(AssertUnwindSafe(|| f(&scope)));

// Wait until all the threads are finished.
while scope.data.num_running_threads.load(Ordering::Acquire) != 0 {
// park();
// TODO: Replaced by a wasm-friendly version of park()
sleep(Duration::from_millis(1));
}

// Throw any panic from `f`, or the return value of `f` if no thread panicked.
match result {
Err(e) => resume_unwind(e),
Ok(_) if scope.data.a_thread_panicked.load(Ordering::Relaxed) => {
panic!("a scoped thread panicked")
}
Ok(result) => result,
}
}

pub struct ScopedJoinHandle<'scope, T>(crate::JoinHandle<T>, PhantomData<&'scope ()>);
impl<'scope, T> ScopedJoinHandle<'scope, T> {
pub fn join(self) -> std::io::Result<T> {
self.0
.join()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, ""))
}
}

pub fn spawn_scoped<'scope, 'env, F, T>(
builder: crate::Builder,
scope: &'scope Scope<'scope, 'env>,
f: F,
) -> std::io::Result<ScopedJoinHandle<'scope, T>>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
Ok(ScopedJoinHandle(
unsafe { builder.spawn_unchecked(f) }?,
PhantomData,
))
}

0 comments on commit 0751fa2

Please sign in to comment.