Skip to content

Commit

Permalink
Simplify inference API.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed May 21, 2024
1 parent 44d3384 commit 7f78abd
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 35 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,20 @@ keywords = ["deep-learning", "language", "model", "rwkv"]
license = "MIT OR Apache-2.0"
name = "web-rwkv"
repository = "https://github.com/cryscan/web-rwkv"
version = "0.8.9"
version = "0.8.10"

[dependencies]
ahash = "0.8"
anyhow = "1.0"
bytemuck = { version = "1.13", features = ["extern_crate_alloc"] }
derive-getters = "0.3"
derive-getters = "0.4"
document-features = "0.2.8"
flume = { version = "0.11.0" }
futures = "0.3"
gpp = "0.6.2"
half = { version = "2.2", features = ["bytemuck", "serde"] }
instant = { version = "0.1", features = ["inaccurate", "wasm-bindgen"] }
itertools = "0.12"
itertools = "0.13"
log = "0.4"
regex = "1.10"
rustc-hash = "1.1.0"
Expand Down
8 changes: 2 additions & 6 deletions examples/rt-batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use web_rwkv::{
loader::{Loader, Lora},
model::{Build, ContextAutoLimits, ModelBuilder, ModelInfo, ModelVersion, Quant},
softmax::softmax,
v4, v5, v6, JobRuntime, Submission,
v4, v5, v6, JobRuntime,
},
tokenizer::Tokenizer,
};
Expand Down Expand Up @@ -294,11 +294,7 @@ async fn main() -> Result<()> {
}

let input = inference.clone();
let (sender, receiver) = tokio::sync::oneshot::channel();
let submission = Submission { input, sender };

let _ = runtime.send(submission).await;
let (input, output) = receiver.await.unwrap();
let (input, output) = runtime.infer(input).await;
inference = input;

let output = output.iter().map(|batch| batch.0.clone()).collect_vec();
Expand Down
14 changes: 3 additions & 11 deletions examples/rt-chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use web_rwkv::{
State,
},
softmax::softmax_one,
v4, v5, v6, JobRuntime, Submission,
v4, v5, v6, JobRuntime,
},
tensor::{TensorCpu, TensorInit, TensorShape},
tokenizer::Tokenizer,
Expand Down Expand Up @@ -293,11 +293,7 @@ async fn main() -> Result<()> {

loop {
let input = inference.clone();
let (sender, receiver) = tokio::sync::oneshot::channel();
let submission = Submission { input, sender };

let _ = runtime.send(submission).await;
let (input, output) = receiver.await.unwrap();
let (input, output) = runtime.infer(input).await;
inference = input;

if output[0].size() > 0 {
Expand Down Expand Up @@ -353,11 +349,7 @@ async fn main() -> Result<()> {

loop {
let input = inference.clone();
let (sender, receiver) = tokio::sync::oneshot::channel();
let submission = Submission { input, sender };

let _ = runtime.send(submission).await;
let (input, output) = receiver.await.unwrap();
let (input, output) = runtime.infer(input).await;
inference = input;

let output = output[0].0.clone();
Expand Down
10 changes: 2 additions & 8 deletions examples/rt-gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use web_rwkv::{
loader::{Loader, Lora},
model::{Build, ContextAutoLimits, ModelBuilder, ModelInfo, ModelVersion, Quant},
softmax::softmax_one,
v4, v5, v6, JobRuntime, Submission,
v4, v5, v6, JobRuntime,
},
tokenizer::Tokenizer,
};
Expand Down Expand Up @@ -205,14 +205,8 @@ async fn main() -> Result<()> {

let num_token = 500;
for _ in 0..num_token {
let (sender, receiver) = tokio::sync::oneshot::channel();
let input = prompt.clone();
let submission = Submission { input, sender };

let _ = runtime.send(submission).await;
let Ok((input, output)) = receiver.await else {
break;
};
let (input, output) = runtime.infer(input).await;
prompt = input;

let output = output[0].0.clone();
Expand Down
20 changes: 13 additions & 7 deletions src/runtime/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::future::Future;

use anyhow::Result;
use web_rwkv_derive::Deref;

pub mod infer;
pub mod loader;
Expand Down Expand Up @@ -41,9 +40,9 @@ pub trait JobBuilder<J: Job>: Send + Clone + 'static {
}

#[derive(Debug)]
pub struct Submission<I, O> {
pub input: I,
pub sender: tokio::sync::oneshot::Sender<(I, O)>,
struct Submission<I, O> {
input: I,
sender: tokio::sync::oneshot::Sender<(I, O)>,
}

pub trait JobInput: Send + 'static {
Expand All @@ -56,7 +55,7 @@ pub trait JobInput: Send + 'static {
fn chunk(&self) -> Self::Chunk;
}

#[derive(Debug, Clone, Deref)]
#[derive(Debug, Clone)]
pub struct JobRuntime<I, O>(tokio::sync::mpsc::Sender<Submission<I, O>>);

#[allow(clippy::type_complexity)]
Expand Down Expand Up @@ -127,8 +126,6 @@ where
}
let iter = iter.as_mut().expect("iter should be assigned");

// let remain = queue.len() + candidates.len().max(1) - 1;
// let predict = MAX_QUEUE_SIZE - MAX_QUEUE_SIZE.min(remain);
for info in iter.take(predict) {
#[cfg(feature = "trace")]
tracing::event!(
Expand Down Expand Up @@ -176,4 +173,13 @@ where
}
Ok(())
}

/// Runs one (possibly partial) inference step on `input`.
///
/// The input is bundled with a one-shot reply channel into a submission, pushed onto the
/// runtime's submission channel, and the call then waits for the reply. The reply pairs
/// the remaining input with the (perhaps partial) output. The amount of input processed
/// during one call is bound by the input chunk size.
///
/// # Panics
///
/// Panics if awaiting the reply yields an error. A failed send on the submission channel
/// is deliberately ignored here; presumably it surfaces as this same panic, because the
/// reply sender is dropped together with the rejected submission (confirm against the
/// tokio channel documentation).
pub async fn infer(&self, input: I) -> (I, O) {
    let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
    // The send result is discarded on purpose; the await below is the single failure point.
    let _ = self
        .0
        .send(Submission {
            input,
            sender: reply_tx,
        })
        .await;
    reply_rx.await.expect("receive infer output error")
}
}

0 comments on commit 7f78abd

Please sign in to comment.