Change Reader interface to allow inputting owned data.
cryscan committed Feb 22, 2024
1 parent 342e383 commit 3881b39
Showing 4 changed files with 73 additions and 68 deletions.
Cargo.toml (1 addition, 1 deletion)

@@ -1,6 +1,6 @@
 [package]
 name = "web-rwkv"
-version = "0.6.16"
+version = "0.6.17"
 edition = "2021"
 authors = ["Zhenyuan Zhang <[email protected]>"]
 license = "MIT OR Apache-2.0"

src/model/loader.rs (49 additions, 55 deletions)

@@ -1,10 +1,10 @@
-use std::{fmt::Debug, future::Future, pin::Pin};
+use std::{borrow::Cow, fmt::Debug, future::Future, pin::Pin};
 
 use anyhow::Result;
 use half::f16;
 use itertools::Itertools;
 use regex::Regex;
-use safetensors::{tensor::TensorView, SafeTensorError, SafeTensors};
+use safetensors::{SafeTensorError, SafeTensors};
 use web_rwkv_derive::{Deref, DerefMut};
 
 use super::{ModelError, ModelInfo, ModelVersion, Quant};
@@ -16,7 +16,7 @@ use crate::{
         matrix::Matrix,
         ops::{TensorOp, TensorPass},
         shape::{Shape, TensorDimension},
-        TensorCpu, TensorError, TensorGpu, TensorInit, TensorReshape, TensorShape,
+        TensorCpu, TensorGpu, TensorInit, TensorReshape, TensorShape,
     },
 };

@@ -25,10 +25,11 @@ pub trait Reader {
     fn names(&self) -> Vec<&str>;
     fn contains(&self, name: &str) -> bool;
 
+    #[allow(clippy::type_complexity)]
     fn tensor<'a>(
         &'a self,
         name: &str,
-    ) -> Pin<Box<dyn Future<Output = Result<TensorView<'a>, SafeTensorError>> + 'a>>;
+    ) -> Pin<Box<dyn Future<Output = Result<(Vec<usize>, Cow<'a, [u8]>), SafeTensorError>> + 'a>>;
 }
 
 impl Reader for SafeTensors<'_> {
@@ -46,9 +47,15 @@ impl Reader for SafeTensors<'_> {
     fn tensor(
         &self,
         name: &str,
-    ) -> Pin<Box<dyn Future<Output = Result<TensorView<'_>, SafeTensorError>> + '_>> {
+    ) -> Pin<Box<dyn Future<Output = Result<(Vec<usize>, Cow<'_, [u8]>), SafeTensorError>> + '_>>
+    {
         let name = name.to_string();
-        Box::pin(async move { self.tensor(&name) })
+        Box::pin(async move {
+            let tensor = self.tensor(&name)?;
+            let shape = tensor.shape().to_vec();
+            let data = Cow::from(tensor.data());
+            Ok((shape, data))
+        })
     }
 }
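
The switch from `TensorView<'a>` to `(Vec<usize>, Cow<'a, [u8]>)` is what makes owned data possible: a reader can hand back bytes it borrows from a memory-mapped safetensors file, or bytes it owns outright (decompressed, downloaded, or generated on demand). As a minimal sketch, a hypothetical reader that keeps all of its tensors in memory could implement the new trait as follows; the `OwnedTensorMap` type is illustrative and not part of this commit, and the `web_rwkv::model::loader::Reader` import path is assumed from the file edited here.

use std::{borrow::Cow, collections::HashMap, future::Future, pin::Pin};

use safetensors::SafeTensorError;
use web_rwkv::model::loader::Reader; // assumed path, from src/model/loader.rs

/// Hypothetical reader that owns all of its tensor data.
struct OwnedTensorMap {
    tensors: HashMap<String, (Vec<usize>, Vec<u8>)>,
}

impl Reader for OwnedTensorMap {
    fn names(&self) -> Vec<&str> {
        self.tensors.keys().map(|name| name.as_str()).collect()
    }

    fn contains(&self, name: &str) -> bool {
        self.tensors.contains_key(name)
    }

    #[allow(clippy::type_complexity)]
    fn tensor<'a>(
        &'a self,
        name: &str,
    ) -> Pin<Box<dyn Future<Output = Result<(Vec<usize>, Cow<'a, [u8]>), SafeTensorError>> + 'a>>
    {
        let name = name.to_string();
        Box::pin(async move {
            let (shape, data) = self
                .tensors
                .get(&name)
                .ok_or(SafeTensorError::TensorNotFound(name))?;
            // Borrowing from `self` still works; a reader that decodes or
            // fetches bytes on demand would return `Cow::Owned` here instead.
            Ok((shape.clone(), Cow::Borrowed(&data[..])))
        })
    }
}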

@@ -156,9 +163,9 @@ impl<'a> Loader<'a> {
             r + 1
         };
 
-        let embed = model.tensor("emb.weight").await?;
-        let ffn = model.tensor("blocks.0.ffn.key.weight").await?;
-        let time_first = model.tensor("blocks.0.att.time_first").await?;
+        let embed = model.tensor("emb.weight").await?.0;
+        let ffn = model.tensor("blocks.0.ffn.key.weight").await?.0;
+        let time_first = model.tensor("blocks.0.att.time_first").await?.0;
 
         let v5 = [
             "blocks.0.att.gate.weight",
@@ -191,10 +198,10 @@
             _ => return Err(ModelError::InvalidVersion.into()),
         };
 
-        let num_emb = embed.shape()[1];
-        let num_hidden = ffn.shape()[0];
-        let num_vocab = embed.shape()[0];
-        let num_head = time_first.shape()[0];
+        let num_emb = embed[1];
+        let num_hidden = ffn[0];
+        let num_vocab = embed[0];
+        let num_head = time_first[0];
 
         Ok(ModelInfo {
             version,
@@ -222,10 +229,10 @@
                 continue;
             };
 
-            let Ok(tensor) = lora.data.tensor(name).await else {
+            let Ok((shape, data)) = lora.data.tensor(name).await else {
                 continue;
            };
-            let tensor = TensorCpu::<f16>::from_safetensors(&self.context, tensor)?
+            let tensor = TensorCpu::<f16>::from_safetensors(&self.context, shape, data)?
                 .map(|x| x.to_f32())
                 .into();
             let alpha = blend.alpha;
@@ -259,8 +266,8 @@
                 continue;
             };
 
-            let a = TensorGpu::from_safetensors(&self.context, a)?;
-            let b = TensorGpu::from_safetensors(&self.context, b)?;
+            let a = TensorGpu::from_safetensors(&self.context, a.0, a.1)?;
+            let b = TensorGpu::from_safetensors(&self.context, b.0, b.1)?;
             let rank = a.shape()[0];
             let alpha = blend.alpha;
             matrices.push(LoraMatrix { a, b, rank, alpha });
@@ -271,25 +278,17 @@
     }
 
     pub async fn tensor_shape(&self, name: impl AsRef<str>) -> Result<Shape> {
-        let tensor = self.model.tensor(name.as_ref()).await?;
-        let shape = match *tensor.shape() {
-            [] => Shape::new(0, 0, 0, 0),
-            [x] => Shape::new(x, 1, 1, 1),
-            [y, x] => Shape::new(x, y, 1, 1),
-            [z, y, x] => Shape::new(x, y, z, 1),
-            [w, z, y, x] => Shape::new(x, y, z, w),
-            _ => return Err(TensorError::Deduce.into()),
-        };
-        Ok(shape)
+        let (shape, _) = self.model.tensor(name.as_ref()).await?;
+        Ok(Shape::from_safetensors(&shape)?)
     }
 
     pub async fn load_vector_f32(
         &self,
         name: impl AsRef<str>,
     ) -> Result<TensorGpu<f32, ReadWrite>> {
         use TensorDimension::{Auto, Dimension};
-        let tensor = self.model.tensor(name.as_ref()).await?;
-        let tensor = TensorCpu::<f16>::from_safetensors(&self.context, tensor)?
+        let (shape, tensor) = self.model.tensor(name.as_ref()).await?;
+        let tensor = TensorCpu::<f16>::from_safetensors(&self.context, shape, tensor)?
             .map(|x| x.to_f32())
             .reshape(Auto, Dimension(1), Dimension(1), Dimension(1))?
             .into();
@@ -316,8 +315,8 @@
         name: impl AsRef<str>,
     ) -> Result<TensorGpu<f32, ReadWrite>> {
         use TensorDimension::{Auto, Dimension};
-        let tensor = self.model.tensor(name.as_ref()).await?;
-        let tensor = TensorCpu::<f16>::from_safetensors(&self.context, tensor)?
+        let (shape, tensor) = self.model.tensor(name.as_ref()).await?;
+        let tensor = TensorCpu::<f16>::from_safetensors(&self.context, shape, tensor)?
             .map(|x| -x.to_f32().exp())
             .reshape(Auto, Dimension(1), Dimension(1), Dimension(1))?
             .into();
@@ -344,8 +343,8 @@
         name: impl AsRef<str>,
     ) -> Result<TensorGpu<f32, ReadWrite>> {
         use TensorDimension::{Auto, Dimension};
-        let tensor = self.model.tensor(name.as_ref()).await?;
-        let tensor = TensorCpu::<f16>::from_safetensors(&self.context, tensor)?
+        let (shape, tensor) = self.model.tensor(name.as_ref()).await?;
+        let tensor = TensorCpu::<f16>::from_safetensors(&self.context, shape, tensor)?
             .map(|x| -x.to_f32().exp())
             .map(|x| x.exp())
             .reshape(Auto, Dimension(1), Dimension(1), Dimension(1))?
@@ -375,16 +374,16 @@
         use TensorDimension::{Auto, Dimension};
         let context = &self.context;
         let lora = self.lora_vectors(name.as_ref()).await?;
-        let tensor = self.model.tensor(name.as_ref()).await?;
+        let (shape, tensor) = self.model.tensor(name.as_ref()).await?;
         let tensor = if lora.is_empty() {
-            TensorGpu::from_safetensors(context, tensor)?.reshape(
+            TensorGpu::from_safetensors(context, shape, tensor)?.reshape(
                 Auto,
                 Dimension(1),
                 Dimension(1),
                 Dimension(1),
             )?
         } else {
-            let tensor_f32 = TensorCpu::<f16>::from_safetensors(context, tensor)?
+            let tensor_f32 = TensorCpu::<f16>::from_safetensors(context, shape, tensor)?
                 .map(|x| x.to_f32())
                 .reshape(Auto, Dimension(1), Dimension(1), Dimension(1))?;
             let tensor_f32 = TensorGpu::from(tensor_f32);
@@ -420,9 +419,8 @@
     ) -> Result<TensorGpu<f16, ReadWrite>> {
         let context = &self.context;
         let lora = self.lora_matrices(name.as_ref()).await?;
-        let tensor = self.model.tensor(name.as_ref()).await?;
-
-        let tensor = TensorGpu::from_safetensors(context, tensor)?;
+        let (shape, tensor) = self.model.tensor(name.as_ref()).await?;
+        let tensor = TensorGpu::from_safetensors(&self.context, shape, tensor)?;
 
         if !lora.is_empty() {
             let mut encoder = context.device.create_command_encoder(&Default::default());
@@ -451,9 +449,8 @@
         let context = &self.context;
 
         let lora = self.lora_matrices(name.as_ref()).await?;
-        let tensor = self.model.tensor(name.as_ref()).await?;
-
-        let tensor = TensorCpu::<f16>::from_safetensors(context, tensor)?
+        let (shape, tensor) = self.model.tensor(name.as_ref()).await?;
+        let tensor = TensorCpu::<f16>::from_safetensors(context, shape, tensor)?
             .map(|x| f16::from_f32(discount * x.to_f32()));
         let tensor = TensorGpu::from(tensor);
 
@@ -484,9 +481,8 @@
     ) -> Result<()> {
         let context = &self.context;
         let lora = self.lora_matrices(name.as_ref()).await?;
-        let tensor = self.model.tensor(name.as_ref()).await?;
-
-        let tensor = TensorCpu::from_safetensors(context, tensor)?;
+        let (shape, tensor) = self.model.tensor(name.as_ref()).await?;
+        let tensor = TensorCpu::from_safetensors(context, shape, tensor)?;
         matrix.load(&tensor)?;
 
         if !lora.is_empty() {
@@ -519,9 +515,8 @@
         let context = &self.context;
 
         let lora = self.lora_matrices(name.as_ref()).await?;
-        let tensor = self.model.tensor(name.as_ref()).await?;
-
-        let tensor = TensorCpu::<f16>::from_safetensors(context, tensor)?
+        let (shape, tensor) = self.model.tensor(name.as_ref()).await?;
+        let tensor = TensorCpu::<f16>::from_safetensors(context, shape, tensor)?
             .map(|x| f16::from_f32(discount * x.to_f32()))
             .reshape(Full, Full, Dimension(1), Dimension(1))?;
         matrix.load(&tensor)?;
@@ -547,23 +542,22 @@
     }
 
     pub async fn load_embed<'b>(&self) -> Result<TensorCpu<'b, f16>> {
-        let embed = self.model.tensor("emb.weight").await?;
-        let num_emb = embed.shape()[1];
-        let num_vocab = embed.shape()[0];
+        let (shape, tensor) = self.model.tensor("emb.weight").await?;
+        let num_emb = shape[1];
+        let num_vocab = shape[0];
         let tensor = self.context.tensor_from_data(
             Shape::new(num_emb, num_vocab, 1, 1),
-            bytemuck::pod_collect_to_vec(embed.data()),
+            bytemuck::pod_collect_to_vec(&tensor),
         )?;
         Ok(tensor)
     }
 
     pub async fn load_head(&self, chunk_size: usize) -> Result<Vec<TensorGpu<f16, ReadWrite>>> {
         let context = &self.context;
-        let tensor = self.model.tensor("head.weight").await?;
-        let shape = tensor.shape();
+        let (shape, tensor) = self.model.tensor("head.weight").await?;
         let shape = Shape::new(shape[1], shape[0], 1, 1);
         let chunks = (shape[1] + chunk_size - 1) / chunk_size;
-        let data = bytemuck::cast_slice(tensor.data());
+        let data = bytemuck::cast_slice(&tensor);
 
         let head = (0..chunks)
             .map(|chunk| {
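
Every `Loader` method above follows the same mechanical rewrite: destructure the `(shape, data)` tuple and pass both parts to `from_safetensors`. A standalone sketch of the pattern; the helper function and import paths are illustrative, and error plumbing via `anyhow` follows the crate's `Result` usage above.

use anyhow::Result;
use half::f16;
use web_rwkv::{
    context::Context,
    model::loader::Reader,
    tensor::{TensorCpu, TensorInit},
};

/// Sketch: read one f16 tensor through the new Reader interface.
async fn read_tensor_f16<'a>(
    context: &Context,
    reader: &'a impl Reader,
    name: &str,
) -> Result<TensorCpu<'a, f16>> {
    // The tuple replaces the old TensorView: shape and bytes arrive separately.
    let (shape, data) = reader.tensor(name).await?;
    Ok(TensorCpu::from_safetensors(context, shape, data)?)
}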

src/tensor/mod.rs (11 additions, 12 deletions)

@@ -218,20 +218,19 @@ pub trait TensorInit<'a, T: Scalar>: Sized {
     /// Create a tensor from safetensors.
     fn from_safetensors(
         context: &Context,
-        tensor: safetensors::tensor::TensorView<'a>,
+        shape: Vec<usize>,
+        data: impl Into<Cow<'a, [u8]>>,
     ) -> Result<Self, TensorError> {
-        if tensor.dtype() != T::DATA_TYPE {
-            return Err(TensorError::Type);
+        let shape = Shape::from_safetensors(&shape)?;
+        let data: Cow<'_, [u8]> = data.into();
+        match data {
+            Cow::Borrowed(data) => Self::from_data(context, shape, bytemuck::cast_slice(data)),
+            Cow::Owned(data) => {
+                let data = bytemuck::cast_slice(&data);
+                let data = Cow::Owned(data.to_vec());
+                Self::from_data(context, shape, data)
+            }
         }
-        let shape = match *tensor.shape() {
-            [] => Shape::new(0, 0, 0, 0),
-            [x] => Shape::new(x, 1, 1, 1),
-            [y, x] => Shape::new(x, y, 1, 1),
-            [z, y, x] => Shape::new(x, y, z, 1),
-            [w, z, y, x] => Shape::new(x, y, z, w),
-            _ => return Err(TensorError::Deduce),
-        };
-        Self::from_data(context, shape, bytemuck::cast_slice(tensor.data()))
     }
 }
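
On the consuming side, `from_safetensors` now accepts anything convertible into a `Cow<'a, [u8]>`, so both zero-copy and owning call sites type-check against the same signature. A sketch of the two call shapes, assuming the module paths used above (`web_rwkv::tensor::{TensorCpu, TensorError, TensorInit}`) and that `TensorCpu` implements `TensorInit` as the loader code suggests:

use half::f16;
use web_rwkv::{
    context::Context,
    tensor::{TensorCpu, TensorError, TensorInit},
};

/// Zero-copy path: `&[u8]` becomes `Cow::Borrowed`, so the tensor borrows
/// the caller's buffer (e.g. a memory-mapped safetensors file).
fn load_borrowed<'a>(
    context: &Context,
    shape: Vec<usize>,
    data: &'a [u8],
) -> Result<TensorCpu<'a, f16>, TensorError> {
    TensorCpu::from_safetensors(context, shape, data)
}

/// Owning path: `Vec<u8>` becomes `Cow::Owned`, so the tensor can outlive
/// the source buffer, which is the point of this commit.
fn load_owned(
    context: &Context,
    shape: Vec<usize>,
    data: Vec<u8>,
) -> Result<TensorCpu<'static, f16>, TensorError> {
    TensorCpu::from_safetensors(context, shape, data)
}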

src/tensor/shape.rs (12 additions, 0 deletions)

@@ -30,6 +30,18 @@
         shape
     }
 
+    pub fn from_safetensors(shape: &[usize]) -> Result<Self, TensorError> {
+        let shape = match shape[..] {
+            [] => Shape::new(0, 0, 0, 0),
+            [x] => Shape::new(x, 1, 1, 1),
+            [y, x] => Shape::new(x, y, 1, 1),
+            [z, y, x] => Shape::new(x, y, z, 1),
+            [w, z, y, x] => Shape::new(x, y, z, w),
+            _ => return Err(TensorError::Deduce),
+        };
+        Ok(shape)
+    }
+
     pub fn len(&self) -> usize {
         self.0.into_iter().product()
     }
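
`Shape::from_safetensors` centralizes the dimension mapping that `loader.rs` and `mod.rs` previously duplicated. Safetensors lists dimensions outermost first, while `Shape` is ordered innermost first (x, y, z, w), so the axes come out reversed and missing ones are padded with 1. A small illustration, assuming `Shape` is reachable at `web_rwkv::tensor::shape::Shape` and implements `Index<usize>` as its uses in `loader.rs` suggest:

use web_rwkv::tensor::{shape::Shape, TensorError};

fn main() -> Result<(), TensorError> {
    // A 2-D safetensors shape [rows, cols] maps to Shape::new(cols, rows, 1, 1).
    let shape = Shape::from_safetensors(&[1024, 768])?;
    assert_eq!(shape[0], 768);  // x: innermost (fastest-moving) dimension
    assert_eq!(shape[1], 1024); // y: next dimension out
    assert_eq!(shape[2], 1);    // z: padded
    assert_eq!(shape[3], 1);    // w: padded
    Ok(())
}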
