Skip to content

Commit

Permalink
Multi async build (#29)
Browse files Browse the repository at this point in the history
* Simplify resource cleaning.

* Basic implementation of multi async build.

* Has `MAX_QUEUE_SIZE`.

* Tune `MAX_QUEUE_SIZE` to be 2.

* Less pipeline miss.

* Select the first finished job candidate.

* Improve cache performance.

* Optimize cache x2.

* Adjust the number of running tasks on the fly.

* Revert "Adjust the number of running tasks on the fly."

This reverts commit d5fab85.
  • Loading branch information
cryscan authored May 8, 2024
1 parent e517f59 commit 58275ff
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 75 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ instant = { version = "0.1", features = ["inaccurate", "wasm-bindgen"] }
itertools = "0.12"
log = "0.4"
regex = "1.10"
rustc-hash = "1.1.0"
safetensors = "0.4"
serde = { version = "1.0", features = ["derive", "rc"] }
serde_bytes = "0.11.14"
Expand Down
17 changes: 7 additions & 10 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,8 @@ impl<'a> ContextBuilder {
pub struct Macros(Vec<(String, String)>);

impl Macros {
// pub fn new(block_size: u32) -> Self {
// Self(vec![("BLOCK_SIZE".into(), format!("{}u", block_size))])
// }
pub fn new() -> Self {
Self(vec![])
Default::default()
}
}

Expand All @@ -201,8 +198,8 @@ struct PipelineKey {
}

impl PipelineKey {
fn new(name: String, entry_point: String, mut macros: Macros) -> Self {
macros.0.sort();
fn new(name: String, entry_point: String, macros: Macros) -> Self {
// macros.0.sort();
Self {
name,
entry_point,
Expand Down Expand Up @@ -238,11 +235,11 @@ impl ContextInternal {
let entry_point = entry_point.as_ref();
let key = PipelineKey::new(name.into(), entry_point.into(), macros.clone());

use gpp::{process_str, Context};
let mut context = Context::new();
context.macros = macros.0.into_iter().collect();
self.pipeline_cache.checkout(key, move || {
use gpp::{process_str, Context};
let mut context = Context::new();
context.macros = macros.0.into_iter().collect();

self.pipeline_cache.checkout(key, || {
let shader = process_str(source.as_ref(), &mut context).unwrap();
let module = &self.device.create_shader_module(ShaderModuleDescriptor {
label: Some(name),
Expand Down
10 changes: 9 additions & 1 deletion src/runtime/infer.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use itertools::Itertools;
use web_rwkv_derive::{Deref, DerefMut};

use super::JobInput;
use super::{JobInfo, JobInput};
use crate::tensor::TensorCpu;

pub const MIN_TOKEN_CHUNK_SIZE: usize = 32;
Expand Down Expand Up @@ -69,6 +69,13 @@ impl InferInfo {
}
}

impl JobInfo for InferInfo {
#[inline]
fn check(&self, info: &Self) -> bool {
self.num_token() == info.num_token() && self.redirect() == info.redirect()
}
}

#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct InferRedirect {
/// Indices in the *input* tensor that are included in the output.
Expand Down Expand Up @@ -199,6 +206,7 @@ impl IntoIterator for &InferInput {
}
}

#[derive(Debug, Clone)]
pub struct InferIter {
batches: Vec<(BatchState, InferOption)>,
token_chunk_size: usize,
Expand Down
78 changes: 53 additions & 25 deletions src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,19 @@ pub mod v4;
pub mod v5;
pub mod v6;

const MAX_QUEUE_SIZE: usize = 2;

pub trait JobInfo: Send + Clone + 'static {
/// Check if the info are compatible.
fn check(&self, info: &Self) -> bool;
}

/// A [`Job`] to be executed on GPU.
pub trait Job: Sized + Send + 'static {
type Info;
type Info: JobInfo;
type Input;
type Output;

/// Check if the input is compatible.
fn check(&self, input: &Self::Input, info: &Self::Info) -> bool;
/// Load the data from CPU to GPU.
fn load(self, input: &Self::Input) -> Result<Self>;
/// Submit the job to GPU and execute it immediately.
Expand All @@ -27,12 +32,12 @@ pub trait Job: Sized + Send + 'static {
fn back(self) -> impl Future<Output = Result<Self::Output>> + Send;
}

pub trait JobBuilder<J: Job>: Send + 'static {
pub trait JobBuilder<J: Job>: Send + Clone + 'static {
type Info;

/// Build a [`Job`] from the given info.
/// This usually involves creating a list of GPU commands (but not actually execution).
fn build(&self, info: Self::Info) -> impl Future<Output = Result<J>> + Send;
fn build(&self, info: Self::Info) -> Result<J>;
}

#[derive(Debug)]
Expand All @@ -57,7 +62,7 @@ pub struct JobRuntime<I, O>(tokio::sync::mpsc::Sender<Submission<I, O>>);
#[allow(clippy::type_complexity)]
impl<I, O, T, F> JobRuntime<I, O>
where
T: Send + 'static,
T: JobInfo,
F: Iterator<Item = T> + Send + 'static,
I: JobInput,
O: Send + 'static,
Expand Down Expand Up @@ -85,23 +90,52 @@ where
where
J: Job<Info = T, Input = I::Chunk, Output = O>,
{
let mut predict: Option<J> = None;
let mut queue: Vec<(T, tokio::task::JoinHandle<Result<J>>)> = vec![];
let mut iter: Option<F> = None;

while let Some(Submission { input, sender }) = receiver.recv().await {
let mut iter = (&input).into_iter();
let Some(info) = iter.next() else {
let Some(info) = (&input).into_iter().next() else {
continue;
};
let next = iter.next();
drop(iter);

fn check<J: Job>(job: J, input: &J::Input, info: &J::Info) -> Option<J> {
job.check(input, info).then_some(job)
}

let chunk = input.chunk();
let mut job = match predict.take().and_then(|job| check(job, &chunk, &info)) {
Some(job) => job,
None => builder.build(info).await?,

let mut job = loop {
let mut candidates = vec![];
let mut remain = vec![];
for (key, handle) in queue.drain(..) {
match (candidates.is_empty(), info.check(&key)) {
(true, false) => handle.abort(),
(false, false) => remain.push((key, handle)),
(_, true) => candidates.push(handle),
}
}
queue = remain;

if candidates.is_empty() || iter.is_none() {
iter = Some((&input).into_iter());
}
let iter = iter.as_mut().expect("iter should be assigned");

let remain = queue.len() + candidates.len().max(1) - 1;
let predict = MAX_QUEUE_SIZE - MAX_QUEUE_SIZE.min(remain);
for info in iter.take(predict) {
let key = info.clone();
let builder = builder.clone();
let handle = tokio::task::spawn_blocking(move || builder.build(key));
queue.push((info.clone(), handle));
}

if !candidates.is_empty() {
let (job, _, remain) = futures::future::select_all(candidates).await;
let mut remain = remain
.into_iter()
.map(|handle| (info.clone(), handle))
.collect();
std::mem::swap(&mut queue, &mut remain);
queue.append(&mut remain);
break job??;
}
}
.load(&chunk)?;

Expand All @@ -117,13 +151,7 @@ where
}

job.submit();
let handle = tokio::spawn(back(job, input, sender));

predict = match next {
Some(info) => Some(builder.build(info).await?),
None => None,
};
handle.await??;
tokio::spawn(back(job, input, sender));
}
Ok(())
}
Expand Down
7 changes: 2 additions & 5 deletions src/runtime/v4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,10 +313,6 @@ impl Job for InferJob {
type Input = InferChunk;
type Output = InferOutput;

fn check(&self, input: &Self::Input, info: &Self::Info) -> bool {
input.num_token() == self.cursors.shape()[0] && info.redirect() == self.redirect
}

fn load(self, input: &Self::Input) -> Result<Self> {
if input.num_token() == 0 {
return Ok(self);
Expand Down Expand Up @@ -393,6 +389,7 @@ pub struct Frame<F: Float> {
pub type HookFn<F> = Box<dyn Fn(Frame<F>) -> Result<TensorOp, TensorError> + Send + Sync>;
pub type HookMap<F> = HashMap<Hook, HookFn<F>>;

#[derive(Clone)]
pub struct ModelRuntime<F: Float> {
model: Model,
state: State,
Expand Down Expand Up @@ -477,7 +474,7 @@ fn hook_op<F: Float>(
impl<F: Float> JobBuilder<InferJob> for ModelRuntime<F> {
type Info = InferInfo;

async fn build(&self, seed: Self::Info) -> Result<InferJob> {
fn build(&self, seed: Self::Info) -> Result<InferJob> {
let model = &self.model;
let state = &self.state;
let context = &model.context;
Expand Down
7 changes: 2 additions & 5 deletions src/runtime/v5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,10 +328,6 @@ impl Job for InferJob {
type Input = InferChunk;
type Output = InferOutput;

fn check(&self, input: &Self::Input, info: &Self::Info) -> bool {
input.num_token() == self.cursors.shape()[0] && info.redirect() == self.redirect
}

fn load(self, input: &Self::Input) -> Result<Self> {
if input.num_token() == 0 {
return Ok(self);
Expand Down Expand Up @@ -408,6 +404,7 @@ pub struct Frame<F: Float> {
pub type HookFn<F> = Box<dyn Fn(Frame<F>) -> Result<TensorOp, TensorError> + Send + Sync>;
pub type HookMap<F> = HashMap<Hook, HookFn<F>>;

#[derive(Clone)]
pub struct ModelRuntime<F: Float> {
model: Model,
state: State,
Expand Down Expand Up @@ -479,7 +476,7 @@ fn hook_op<F: Float>(
impl<F: Float> JobBuilder<InferJob> for ModelRuntime<F> {
type Info = InferInfo;

async fn build(&self, seed: Self::Info) -> Result<InferJob> {
fn build(&self, seed: Self::Info) -> Result<InferJob> {
let model = &self.model;
let state = &self.state;
let context = &model.context;
Expand Down
7 changes: 2 additions & 5 deletions src/runtime/v6.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,6 @@ impl Job for InferJob {
type Input = InferChunk;
type Output = InferOutput;

fn check(&self, input: &Self::Input, info: &Self::Info) -> bool {
input.num_token() == self.cursors.shape()[0] && info.redirect() == self.redirect
}

fn load(self, input: &Self::Input) -> Result<Self> {
if input.num_token() == 0 {
return Ok(self);
Expand Down Expand Up @@ -437,6 +433,7 @@ pub struct Frame<F: Float> {
pub type HookFn<F> = Box<dyn Fn(Frame<F>) -> Result<TensorOp, TensorError> + Send + Sync>;
pub type HookMap<F> = HashMap<Hook, HookFn<F>>;

#[derive(Clone)]
pub struct ModelRuntime<F: Float> {
model: Model,
state: State,
Expand Down Expand Up @@ -509,7 +506,7 @@ fn hook_op<F: Float>(
impl<F: Float> JobBuilder<InferJob> for ModelRuntime<F> {
type Info = InferInfo;

async fn build(&self, seed: Self::Info) -> Result<InferJob> {
fn build(&self, seed: Self::Info) -> Result<InferJob> {
let model = &self.model;
let state = &self.state;
let context = &model.context;
Expand Down
52 changes: 28 additions & 24 deletions src/tensor/cache.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
use std::{
collections::HashMap,
hash::Hash,
sync::{Arc, Mutex},
sync::{Arc, RwLock},
};

use itertools::Itertools;
use uid::Id;
use rustc_hash::FxHashMap as HashMap;

#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
struct CacheId;

type CachedItem<V> = (Arc<V>, Id<CacheId>);
type CachedItem<V> = (Arc<V>, uid::Id<CacheId>);

#[derive(Debug)]
pub struct ResourceCache<K, V> {
map: Mutex<HashMap<K, CachedItem<V>>>,
map: RwLock<HashMap<K, CachedItem<V>>>,
#[allow(unused)]
limit: usize,
}

Expand All @@ -40,28 +39,33 @@ where

/// Checkout the item with the given key. If the item doesn't exist, `f` is called to construct it.
pub fn checkout(&self, key: K, miss: impl FnOnce() -> V) -> Arc<V> {
let mut map = self.map.lock().unwrap();
let map = self.map.read().unwrap();
let value = match map.get(&key) {
Some((value, _)) => value.clone(),
None => {
let value = Arc::new(miss());
drop(map);

let value = match map.remove(&key) {
Some((value, _)) => value,
None => Arc::new(miss()),
let mut map = self.map.write().unwrap();
map.insert(key, (value.clone(), uid::Id::new()));
value
}
};

if self.limit > 0 {
let remove_count = map.len() - self.limit.min(map.len());
let remove = map
.iter()
.sorted_unstable_by_key(|(_, (_, id))| id.get())
.map(|(key, _)| key)
.take(remove_count)
.cloned()
.collect_vec();
for key in remove {
map.remove(&key);
}
}
// if self.limit > 0 {
// let remove_count = map.len() - self.limit.min(map.len());
// let remove = map
// .iter()
// .sorted_unstable_by_key(|(_, (_, id))| id.get())
// .map(|(key, _)| key)
// .take(remove_count)
// .cloned()
// .collect_vec();
// for key in remove {
// map.remove(&key);
// }
// }

map.insert(key, (value.clone(), Id::new()));
value
}
}

0 comments on commit 58275ff

Please sign in to comment.