Skip to content

Commit

Permalink
poc: PooledPlanner
Browse files Browse the repository at this point in the history
  • Loading branch information
Marc-Andre Giroux committed Mar 12, 2024
1 parent 35052d8 commit 95b3440
Show file tree
Hide file tree
Showing 4 changed files with 353 additions and 0 deletions.
5 changes: 5 additions & 0 deletions federation-2/router-bridge/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,9 @@ pub enum Error {
/// The deno response id we tried to deserialize.
id: String,
},
/// An internal error occurred inside the bridge.
///
/// This contains the internal error message.
#[error("internal error: `{0}`")]
Internal(&'static str),
}
1 change: 1 addition & 0 deletions federation-2/router-bridge/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ pub mod introspect;
mod js;
pub mod planner;
mod worker;
mod pool;
222 changes: 222 additions & 0 deletions federation-2/router-bridge/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::fmt::Debug;
use std::fmt::Display;
use std::fmt::Formatter;
use std::marker::PhantomData;
use std::num::NonZeroUsize;
use std::sync::Arc;

use serde::de::DeserializeOwned;
Expand All @@ -14,6 +15,7 @@ use serde::Serialize;
use thiserror::Error;

use crate::introspect::IntrospectionResponse;
use crate::pool::JsWorkerPool;
use crate::worker::JsWorker;

// ------------------------------------
Expand Down Expand Up @@ -398,6 +400,226 @@ where
}
}

/// A Deno worker backed query Planner,
/// using a pool of JsRuntimes load balanced
/// using Power of Two Choices.
pub struct PooledPlanner<T>
where
    T: DeserializeOwned + Send + Debug + 'static,
{
    // Worker pool shared (via Arc) between this planner and any planners
    // returned by `update`.
    pool: Arc<JsWorkerPool>,
    // Random id identifying the schema generation loaded into the workers;
    // sent with every command so workers can address the right schema.
    schema_id: u64,
    // Marker tying this planner to the plan-result payload type `T`.
    t: PhantomData<T>,
}

impl<T> Debug for PooledPlanner<T>
where
    T: DeserializeOwned + Send + Debug + 'static,
{
    /// Formats the planner for diagnostics; only the schema id is shown.
    fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
        formatter
            .debug_struct("PooledPlanner")
            .field("schema_id", &self.schema_id)
            .finish()
    }
}

impl<T> PooledPlanner<T>
where
    T: DeserializeOwned + Send + Debug + 'static,
{
    /// Instantiate a `PooledPlanner` from a schema string.
    ///
    /// Creates a pool of `pool_size` JS workers and broadcasts the schema to
    /// all of them. If any worker fails setup, the workers are told to exit
    /// and the setup errors are returned.
    pub async fn new(
        schema: String,
        config: QueryPlannerConfig,
        pool_size: NonZeroUsize,
    ) -> Result<Self, Vec<PlannerError>> {
        // Random id used to address this schema generation in later commands.
        let schema_id: u64 = rand::random();

        let pool = JsWorkerPool::new(include_str!("../bundled/plan_worker.js"), pool_size);

        // Ask every worker to load the schema; a transport-level failure is
        // converted into a single WorkerError wrapped as a PlannerError list.
        let workers_are_setup = pool
            .broadcast_request::<PlanCmd, BridgeSetupResult<serde_json::Value>>(PlanCmd::UpdateSchema {
                schema,
                config,
                schema_id,
            })
            .await
            .map_err(|e| {
                vec![WorkerError {
                    name: Some("planner setup error".to_string()),
                    message: Some(e.to_string()),
                    stack: None,
                    extensions: None,
                    locations: Default::default(),
                }
                .into()]
            });

        // Both error cases below mean the schema update failed.
        // We need to pay attention here:
        // returning early will drop the pool, which will join the jsruntime
        // threads. However the event loops would run forever, so we must let
        // the workers know they need to exit before we drop the pool.
        match workers_are_setup {
            Err(setup_error) => {
                let _ = pool
                    .broadcast_request::<PlanCmd, serde_json::Value>(PlanCmd::Exit { schema_id })
                    .await;
                return Err(setup_error);
            }
            Ok(responses) => {
                // Transport succeeded; still check each worker's reply for
                // setup errors reported by the JS side.
                for r in responses {
                    if let Some(error) = r.errors {
                        let _ = pool.broadcast_send(None, PlanCmd::Exit { schema_id }).await;
                        return Err(error);
                    }
                }
            }
        }

        let pool = Arc::new(pool);

        Ok(Self {
            pool,
            schema_id,
            t: PhantomData,
        })
    }

    /// Update `Planner` from a schema string.
    ///
    /// Broadcasts the new schema to the existing worker pool and, on success,
    /// returns a new planner handle sharing the same pool under a fresh
    /// schema id. On failure the previously loaded schema stays in place.
    pub async fn update(
        &self,
        schema: String,
        config: QueryPlannerConfig,
    ) -> Result<Self, Vec<PlannerError>> {
        // Fresh id so commands for the old and new schema don't collide.
        let schema_id: u64 = rand::random();

        let workers_are_setup = self
            .pool
            .broadcast_request::<PlanCmd, BridgeSetupResult<serde_json::Value>>(PlanCmd::UpdateSchema {
                schema,
                config,
                schema_id,
            })
            .await
            .map_err(|e| {
                vec![WorkerError {
                    name: Some("planner setup error".to_string()),
                    message: Some(e.to_string()),
                    stack: None,
                    extensions: None,
                    locations: Default::default(),
                }
                .into()]
            });

        // If the update failed, we keep the existing schema in place.
        match workers_are_setup {
            Err(setup_error) => {
                return Err(setup_error);
            }
            Ok(responses) => {
                for r in responses {
                    if let Some(error) = r.errors {
                        // NOTE(review): this sends Exit for the *new* schema_id
                        // only, leaving the old schema loaded — presumably
                        // intentional; confirm against the JS worker's handling.
                        let _ = self.pool.broadcast_send(None, PlanCmd::Exit { schema_id }).await;
                        return Err(error);
                    }
                }
            }
        }

        Ok(Self {
            pool: self.pool.clone(),
            schema_id,
            t: PhantomData,
        })
    }

    /// Plan a query against an instantiated query planner.
    pub async fn plan(
        &self,
        query: String,
        operation_name: Option<String>,
        options: PlanOptions,
    ) -> Result<PlanResult<T>, crate::error::Error> {
        self.pool
            .request(PlanCmd::Plan {
                query,
                operation_name,
                schema_id: self.schema_id,
                options,
            })
            .await
    }

    /// Generate the API schema from the current schema.
    pub async fn api_schema(&self) -> Result<ApiSchema, crate::error::Error> {
        self.pool
            .request(PlanCmd::ApiSchema {
                schema_id: self.schema_id,
            })
            .await
    }

    /// Generate the introspection response for this query.
    pub async fn introspect(
        &self,
        query: String,
    ) -> Result<IntrospectionResponse, crate::error::Error> {
        self.pool
            .request(PlanCmd::Introspect {
                query,
                schema_id: self.schema_id,
            })
            .await
    }

    /// Get the operation signature for a query.
    pub async fn operation_signature(
        &self,
        query: String,
        operation_name: Option<String>,
    ) -> Result<String, crate::error::Error> {
        self.pool
            .request(PlanCmd::Signature {
                query,
                operation_name,
                schema_id: self.schema_id,
            })
            .await
    }

    /// Extract the subgraph schemas from the supergraph schema.
    pub async fn subgraphs(&self) -> Result<HashMap<String, String>, crate::error::Error> {
        self.pool
            .request(PlanCmd::Subgraphs {
                schema_id: self.schema_id,
            })
            .await
    }
}

impl<T> Drop for PooledPlanner<T>
where
    T: DeserializeOwned + Send + Debug + 'static,
{
    /// Broadcast a `PlanCmd::Exit` for this planner's schema id before the
    /// pool reference is released.
    fn drop(&mut self) {
        let pool = self.pool.clone();
        let schema_id = self.schema_id;
        // Block on the async send from a dedicated thread with its own
        // single-threaded runtime, then wait for it to finish.
        let exit_signal = std::thread::spawn(move || {
            let runtime = tokio::runtime::Builder::new_current_thread()
                .build()
                .unwrap();

            let _ = runtime.block_on(async move {
                pool.broadcast_send(None, PlanCmd::Exit { schema_id }).await
            });
        });
        let _ = exit_signal.join();
    }
}

/// A Deno worker backed query Planner.
pub struct Planner<T>
Expand Down
125 changes: 125 additions & 0 deletions federation-2/router-bridge/src/pool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
use std::sync::atomic::Ordering;
use std::{num::NonZeroUsize, sync::atomic::AtomicUsize};
use std::fmt::Debug;
use rand::Rng;
use serde::de::DeserializeOwned;
use serde::Serialize;

use tokio::task::JoinSet;
use std::sync::Arc;

use crate::{error::Error, worker::JsWorker};

/// A fixed-size pool of JS workers all running the same source code.
///
/// Single requests are load balanced across workers using the
/// "power of two choices" strategy, driven by per-worker pending-request
/// counts; broadcasts go to every worker.
pub(crate) struct JsWorkerPool {
    // One worker per slot; Arc so a worker can be moved into spawned tasks.
    workers: Vec<Arc<JsWorker>>,
    // In-flight request count per worker, indexed in lockstep with `workers`.
    pending_requests: Vec<AtomicUsize>
}

impl JsWorkerPool {
pub(crate) fn new(worker_source_code: &'static str, size: NonZeroUsize) -> Self {
let workers: Vec<Arc<JsWorker>> = (0..size.into())
.map(|_| Arc::new(JsWorker::new(worker_source_code)))
.collect();

let pending_requests: Vec<AtomicUsize> = (0..size.into()).map(|_| AtomicUsize::new(0)).collect();

Self { workers, pending_requests }
}

pub(crate) async fn request<Request, Response>(
&self,
command: Request,
) -> Result<Response, Error>
where
Request: std::hash::Hash + Serialize + Send + Debug + 'static,
Response: DeserializeOwned + Send + Debug + 'static,
{
let (i, worker) = self.choice_of_two();

self.pending_requests[i].fetch_add(1, Ordering::SeqCst);
let result = worker.request(command).await;
self.pending_requests[i].fetch_add(1, Ordering::SeqCst);

result
}

pub(crate) async fn broadcast_request<Request, Response>(
&self,
command: Request
) -> Result<Vec<Response>, Error>
where
Request: std::hash::Hash + Serialize + Send + Debug + Clone + 'static,
Response: DeserializeOwned + Send + Debug + 'static,
{
let mut join_set = JoinSet::new();

for worker in self.workers.iter().cloned() {
let command_clone = command.clone();

join_set.spawn(async move {
worker.request(command_clone).await
});
}

let mut responses = Vec::new();

while let Some(result) = join_set.join_next().await {
let response = result.map_err(|_e| Error::Internal("could not join spawned task"))?;
responses.push(response?);
}

Ok(responses)
}

pub(crate) async fn broadcast_send<Request>(
&self,
id_opt: Option<String>,
request: Request,
) -> Result<(), Error>
where
Request: std::hash::Hash + Serialize + Send + Debug + Clone + 'static,
{
let mut join_set = JoinSet::new();

for worker in self.workers.iter().cloned() {
let request_clone = request.clone();
let id_opt_clone = id_opt.clone();

join_set.spawn(async move {
worker.send(id_opt_clone, request_clone).await
});
}

let mut results = Vec::new();

while let Some(result) = join_set.join_next().await {
let result = result.map_err(|_e| Error::Internal("could not join spawned task"))?;
results.push(result?);
}

Ok(())
}

fn choice_of_two(&self) -> (usize, &JsWorker) {
let mut rng = rand::thread_rng();

let len = self.workers.len();

let index1 = rng.gen_range(0..len);
let mut index2 = rng.gen_range(0..len);
while index2 == index1 {
index2 = rng.gen_range(0..len);
}

let index1_load = &self.pending_requests[index1].load(Ordering::SeqCst);
let index2_load = &self.pending_requests[index2].load(Ordering::SeqCst);

let choice = if index1_load < index2_load {
index1
} else {
index2
};

(choice, &self.workers[choice])
}
}

0 comments on commit 95b3440

Please sign in to comment.