Skip to content

Commit

Permalink
Merge branch 'main' of github.com:cryscan/web-rwkv
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Nov 29, 2023
2 parents 7af12d5 + bf43239 commit 782343c
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 32 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "web-rwkv"
version = "0.4.0"
version = "0.4.1"
edition = "2021"
authors = ["Zhenyuan Zhang <[email protected]>"]
license = "MIT OR Apache-2.0"
Expand Down
6 changes: 3 additions & 3 deletions src/model/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,15 @@ pub trait FromBuilder: Sized {
fn from_builder(builder: Self::Builder<'_>) -> Result<Self, Self::Error>;
}

pub trait BackedState {
pub trait BackedState: Send {
fn max_batch(&self) -> usize;
fn num_layer(&self) -> usize;

/// Extract the embedding from a given layer of the state.
fn embed(&self, batch: usize, layer: usize) -> Vec<f32>;
}

#[async_trait(?Send)]
#[async_trait]
pub trait ModelState {
type BackedState: BackedState;

Expand All @@ -99,7 +99,7 @@ pub trait ModelState {
) -> Result<(), TensorError>;
}

#[async_trait(?Send)]
#[async_trait]
pub trait Model {
type ModelState: ModelState;

Expand Down
27 changes: 11 additions & 16 deletions src/model/v4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ impl FromBuilder for ModelState {
}
}

#[async_trait(?Send)]
#[async_trait]
impl super::ModelState for ModelState {
type BackedState = BackedState;

Expand All @@ -272,7 +272,7 @@ impl super::ModelState for ModelState {
if backed.max_batch() != self.max_batch() {
return Err(ModelError::BatchSize(backed.max_batch(), self.max_batch()).into());
}
let host = self.context.tensor_from_data(self.shape(), &backed.data)?;
let host = self.context.tensor_from_data(self.shape(), &*backed.data)?;
self.0.load(&host).map_err(|err| err.into())
}

Expand All @@ -283,7 +283,7 @@ impl super::ModelState for ModelState {
}
let shape = self.shape();
let shape = Shape::new(shape[0], shape[1], 1, 1);
let host = self.context.tensor_from_data(shape, &backed.data)?;
let host = self.context.tensor_from_data(shape, &*backed.data)?;
self.0.load_batch(&host, batch).map_err(|err| err.into())
}

Expand All @@ -298,11 +298,8 @@ impl super::ModelState for ModelState {
encoder.copy_tensor(self, &map).expect("back entire state");
self.context.queue.submit(Some(encoder.finish()));

let host = map.back_async().await;
BackedState {
shape,
data: host.to_vec(),
}
let data = map.back_async().await.to_vec().into();
BackedState { shape, data }
}

async fn back_batch(&self, batch: usize) -> Result<Self::BackedState> {
Expand All @@ -325,11 +322,8 @@ impl super::ModelState for ModelState {
encoder.copy_tensor_batch(self, &map, batch)?;
self.context.queue.submit(Some(encoder.finish()));

let host = map.back_async().await;
Ok(BackedState {
shape,
data: host.to_vec(),
})
let data = map.back_async().await.to_vec().into();
Ok(BackedState { shape, data })
}

fn blit(&self, other: &Self) -> Result<(), TensorError> {
Expand Down Expand Up @@ -369,7 +363,7 @@ impl super::ModelState for ModelState {
#[derive(Debug, Clone)]
pub struct BackedState {
pub shape: Shape,
pub data: Vec<f32>,
pub data: Arc<Vec<f32>>,
}

impl FromBuilder for BackedState {
Expand Down Expand Up @@ -398,7 +392,8 @@ impl FromBuilder for BackedState {
.concat()
})
.collect_vec()
.concat();
.concat()
.into();
Ok(Self { shape, data })
}
}
Expand Down Expand Up @@ -910,7 +905,7 @@ impl<'a> FromBuilder for Model<'a> {
}
}

#[async_trait(?Send)]
#[async_trait]
impl super::Model for Model<'_> {
type ModelState = ModelState;

Expand Down
15 changes: 9 additions & 6 deletions src/model/v5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,8 @@ impl FromBuilder for ModelState {
})
.collect();
Ok(Self {
context: context.clone(),
info: info.clone(),
context,
info,
max_batch,
chunk_size,
head_size,
Expand All @@ -282,7 +282,7 @@ impl FromBuilder for ModelState {
}
}

#[async_trait(?Send)]
#[async_trait]
impl super::ModelState for ModelState {
type BackedState = BackedState;

Expand Down Expand Up @@ -342,6 +342,7 @@ impl super::ModelState for ModelState {
let host = map.back_async().await;
data.push((shape, host.to_vec()))
}
let data = data.into();

BackedState {
max_batch,
Expand Down Expand Up @@ -380,6 +381,7 @@ impl super::ModelState for ModelState {
let host = map.back_async().await;
data.push((shape, host.to_vec()));
}
let data = data.into();

Ok(BackedState {
max_batch: 1,
Expand Down Expand Up @@ -433,7 +435,7 @@ pub struct BackedState {
pub max_batch: usize,
pub chunk_size: usize,
pub head_size: usize,
pub data: Vec<(Shape, Vec<f32>)>,
pub data: Arc<Vec<(Shape, Vec<f32>)>>,
}

impl FromBuilder for BackedState {
Expand All @@ -457,7 +459,8 @@ impl FromBuilder for BackedState {
.concat()
})
.map(|x| (shape, x))
.collect();
.collect_vec()
.into();
Ok(Self {
max_batch,
chunk_size,
Expand Down Expand Up @@ -1055,7 +1058,7 @@ impl<'a> FromBuilder for Model<'a> {
}
}

#[async_trait(?Send)]
#[async_trait]
impl super::Model for Model<'_> {
type ModelState = ModelState;

Expand Down
15 changes: 9 additions & 6 deletions src/model/v6.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@ impl FromBuilder for ModelState {
})
.collect();
Ok(Self {
context: context.clone(),
info: info.clone(),
context,
info,
max_batch,
chunk_size,
head_size,
Expand All @@ -317,7 +317,7 @@ impl FromBuilder for ModelState {
}
}

#[async_trait(?Send)]
#[async_trait]
impl super::ModelState for ModelState {
type BackedState = BackedState;

Expand Down Expand Up @@ -377,6 +377,7 @@ impl super::ModelState for ModelState {
let host = map.back_async().await;
data.push((shape, host.to_vec()))
}
let data = data.into();

BackedState {
max_batch,
Expand Down Expand Up @@ -415,6 +416,7 @@ impl super::ModelState for ModelState {
let host = map.back_async().await;
data.push((shape, host.to_vec()));
}
let data = data.into();

Ok(BackedState {
max_batch: 1,
Expand Down Expand Up @@ -468,7 +470,7 @@ pub struct BackedState {
pub max_batch: usize,
pub chunk_size: usize,
pub head_size: usize,
pub data: Vec<(Shape, Vec<f32>)>,
pub data: Arc<Vec<(Shape, Vec<f32>)>>,
}

impl FromBuilder for BackedState {
Expand All @@ -492,7 +494,8 @@ impl FromBuilder for BackedState {
.concat()
})
.map(|x| (shape, x))
.collect();
.collect_vec()
.into();
Ok(Self {
max_batch,
chunk_size,
Expand Down Expand Up @@ -1184,7 +1187,7 @@ impl<'a> FromBuilder for Model<'a> {
}
}

#[async_trait(?Send)]
#[async_trait]
impl super::Model for Model<'_> {
type ModelState = ModelState;

Expand Down

0 comments on commit 782343c

Please sign in to comment.