Fix initial state loading (when loading buffer from CPU, don't use cache).
cryscan committed Jun 21, 2024
1 parent 10bf58a commit 2b60c4c
Showing 5 changed files with 21 additions and 61 deletions.
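
The change in a nutshell: `checkout_buffer_init` used to obtain its buffer from `buffer_cache.checkout`, which keys allocations by `(size, usage)` alone and patches a recycled buffer with a deferred `queue.write_buffer`. The likely hazard — and the one the sketch below strips down to — is aliasing: a recycled allocation with the same size and usage can stand in for a logically distinct tensor, so the new tensor can observe stale or clobbered data. The commit bypasses the cache for initialized buffers and drops the per-tensor `id` that identity-based equality relied on. The following is a minimal, self-contained illustration in plain Rust; `BufferCache` and `checkout` here are illustrative stand-ins, not the wgpu-backed cache in `src/context.rs`.

```rust
use std::collections::HashMap;
use std::sync::Arc;

/// Illustrative stand-in for the real buffer cache: allocations are
/// keyed only by (size, usage), so two requests with equal keys can
/// be served by the same underlying storage.
struct BufferCache {
    pool: HashMap<(usize, u32), Arc<Vec<u8>>>,
}

impl BufferCache {
    fn checkout(&mut self, size: usize, usage: u32, init: impl FnOnce() -> Vec<u8>) -> Arc<Vec<u8>> {
        self.pool
            .entry((size, usage))
            .or_insert_with(|| Arc::new(init()))
            .clone()
    }
}

fn main() {
    let mut cache = BufferCache { pool: HashMap::new() };
    // First "tensor" uploads [1, 1, 1, 1] from the CPU.
    let a = cache.checkout(4, 0, || vec![1; 4]);
    // Same size and usage: the cache returns the SAME allocation, so
    // `b` aliases `a` and its intended contents [2, 2, 2, 2] never
    // land in a distinct buffer.
    let b = cache.checkout(4, 0, || vec![2; 4]);
    assert!(Arc::ptr_eq(&a, &b)); // aliased storage
    assert_eq!(*b, vec![1; 4]); // stale data, not [2, 2, 2, 2]
}
```
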
24 changes: 18 additions & 6 deletions src/context.rs
@@ -321,17 +321,18 @@ impl ContextInternal {

pub(crate) fn checkout_buffer_init(&self, contents: &[u8], usage: BufferUsages) -> Arc<Buffer> {
let size = std::mem::size_of_val(contents);
let key = BufferKey { size, usage };
let _key = BufferKey { size, usage };
let desc = BufferInitDescriptor {
label: None,
contents,
usage,
};
self.buffer_cache.checkout(
key,
|| self.device.create_buffer_init(&desc),
|buffer| self.queue.write_buffer(buffer, 0, contents),
)
// self.buffer_cache.checkout(
// key,
// || self.device.create_buffer_init(&desc),
// |buffer| self.queue.write_buffer(buffer, 0, contents),
// )
self.device.create_buffer_init(&desc).into()
}

pub(crate) fn checkout_buffer(&self, size: usize, usage: BufferUsages) -> Arc<Buffer> {
@@ -346,6 +347,17 @@ impl ContextInternal {
.checkout(key, || self.device.create_buffer(&desc), |_| {})
}

// pub(crate) fn checkout_buffer_uncached(&self, size: usize, usage: BufferUsages) -> Arc<Buffer> {
// self.device
// .create_buffer(&BufferDescriptor {
// label: None,
// size: size as u64,
// usage,
// mapped_at_creation: false,
// })
// .into()
// }

/// Maintain resource caches.
#[inline]
pub fn maintain(&self) {
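
For reference, the new uncached path in isolation — a sketch mirroring the diff above, built on `wgpu`'s `DeviceExt::create_buffer_init` (which the crate already uses), not a drop-in for the method on `ContextInternal`:

```rust
use std::sync::Arc;
use wgpu::util::{BufferInitDescriptor, DeviceExt};
use wgpu::{Buffer, BufferUsages, Device};

/// Create a freshly allocated, initialized buffer, bypassing any cache.
/// Contents are set at creation time, so no deferred `write_buffer` is
/// needed and no recycled allocation can be aliased.
fn create_buffer_init_uncached(device: &Device, contents: &[u8], usage: BufferUsages) -> Arc<Buffer> {
    device
        .create_buffer_init(&BufferInitDescriptor {
            label: None,
            contents,
            usage,
        })
        .into()
}
```

Scratch buffers allocated via `checkout_buffer` still go through the cache, since their initial contents do not matter; the commented-out `checkout_buffer_uncached` below suggests the same escape hatch was sketched for them as well.
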
3 changes: 1 addition & 2 deletions src/runtime/v5.rs
@@ -1088,7 +1088,6 @@ pub async fn read_state<R: Reader>(
info: &ModelInfo,
model: R,
) -> Result<TensorCpu<f32>> {
use crate::tensor::TensorInitContext;
use TensorDimension::{Auto, Dimension};

let loader = Loader {
@@ -1105,7 +1104,7 @@
let matrix = loader
.load_matrix_f16(format!("blocks.{layer}.att.time_state"))
.await?;
let state = TensorGpu::init(context, [head_size, info.num_head, head_size, 1]);
let state: TensorGpu<_, _> = context.tensor_init([head_size, info.num_head, head_size, 1]);
let reshaped: TensorGpu<f16, _> = state.reshape(
Dimension(info.num_emb),
Dimension(head_size),
3 changes: 1 addition & 2 deletions src/runtime/v6.rs
@@ -1194,7 +1194,6 @@ pub async fn read_state<R: Reader>(
info: &ModelInfo,
model: R,
) -> Result<TensorCpu<f32>> {
use crate::tensor::TensorInitContext;
use TensorDimension::{Auto, Dimension};

let loader = Loader {
@@ -1211,7 +1210,7 @@
let matrix = loader
.load_matrix_f16(format!("blocks.{layer}.att.time_state"))
.await?;
let state = TensorGpu::init(context, [head_size, info.num_head, head_size, 1]);
let state: TensorGpu<_, _> = context.tensor_init([head_size, info.num_head, head_size, 1]);
let reshaped: TensorGpu<f16, _> = state.reshape(
Dimension(info.num_emb),
Dimension(head_size),
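
The v5 and v6 hunks are identical in spirit: the state tensor is now constructed through the context (`context.tensor_init(...)`) instead of the associated function `TensorGpu::init(context, ...)`, which lets the `use crate::tensor::TensorInitContext;` import at the call site go away. A self-contained miniature of that API shape — every name here is illustrative, and the constructor is assumed to be reachable on the context without an extra trait import:

```rust
// Illustrative miniature of the call-site change; none of these types
// are the crate's own definitions.
#[derive(Debug)]
struct Context;

#[derive(Debug)]
struct TensorGpu {
    shape: [usize; 4],
}

impl Context {
    /// Constructor on the context itself: callers write
    /// `context.tensor_init(shape)` and need no tensor-side import.
    fn tensor_init(&self, shape: [usize; 4]) -> TensorGpu {
        TensorGpu { shape }
    }
}

fn main() {
    let context = Context;
    // Before: TensorGpu::init(&context, shape), which requires importing
    // whatever defines `init`. After: the context is the entry point.
    let state = context.tensor_init([64, 32, 64, 1]);
    assert_eq!(state.shape, [64, 32, 64, 1]);
}
```
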
50 changes: 1 addition & 49 deletions src/tensor/mod.rs
@@ -228,11 +228,10 @@ pub trait TensorReshape: Sized {
}

/// A tensor on either CPU or GPU.
#[derive(Debug, Eq)]
#[derive(Debug)]
pub struct Tensor<D: Device, T: Scalar> {
shape: Shape,
data: D::Data,
id: uid::Id<TensorId>,
phantom: PhantomData<T>,
}

@@ -247,7 +246,6 @@ impl<D: Device, T: Scalar> Clone for Tensor<D, T> {
Self {
shape: self.shape,
data: self.data.clone(),
id: self.id,
phantom: PhantomData,
}
}
@@ -262,12 +260,6 @@ impl<D: Device, T: Scalar> std::ops::Deref for Tensor<D, T> {
}
}

impl<D: Device, T: Scalar> PartialEq for Tensor<D, T> {
fn eq(&self, other: &Self) -> bool {
self.id == other.id
}
}

impl<D: Device, T: Scalar> TensorScalar for Tensor<D, T> {
type T = T;
}
@@ -299,11 +291,6 @@ impl<D: Device, T: Scalar> Tensor<D, T> {
pub fn data(&self) -> &D::Data {
&self.data
}

#[inline]
pub fn id(&self) -> uid::Id<TensorId> {
self.id
}
}

impl<D: Device, F: Float> Tensor<D, F> {
@@ -317,14 +304,12 @@ impl<T: Scalar> TensorInit<T> for TensorCpu<T> {
fn from_data(shape: impl Into<Shape>, data: impl Into<Arc<[T]>>) -> Result<Self, TensorError> {
let shape = shape.into();
let data = data.into();
let id = uid::Id::new();
if shape.len() != data.len() {
return Err(TensorError::Size(shape.len(), data.len()));
}
Ok(Self {
shape,
data,
id,
phantom: PhantomData,
})
}
@@ -333,11 +318,9 @@ impl<T: Scalar> TensorInit<T> for TensorCpu<T> {
fn init(shape: impl Into<Shape>) -> Self {
let shape = shape.into();
let data = vec![T::zero(); shape.len()].into();
let id = uid::Id::new();
Self {
shape,
data,
id,
phantom: PhantomData,
}
}
@@ -407,20 +390,15 @@ impl<T: Scalar, K: Kind> TensorInitContext<T> for TensorGpu<T, K> {
let context = context.clone();
let shape = shape.into();
let meta = context.checkout_shape_uniform(shape);

let size = shape.len() * std::mem::size_of::<T>();
let buffer = context.checkout_buffer(size, K::buffer_usages());

let id = uid::Id::new();

Self {
shape,
data: TensorGpuData {
context,
meta,
buffer,
},
id,
phantom: PhantomData,
}
}
@@ -433,15 +411,13 @@ impl<T: Scalar, K: Kind> TensorInto<TensorGpu<T, K>> for TensorCpu<T> {
let meta = context.checkout_shape_uniform(shape);
let contents = bytemuck::cast_slice(&data);
let buffer = context.checkout_buffer_init(contents, K::buffer_usages());
let id = uid::Id::new();
TensorGpu {
shape,
data: TensorGpuData {
context,
meta,
buffer,
},
id,
phantom: PhantomData,
}
}
@@ -523,12 +499,9 @@ impl<T: Scalar, K: Kind> TensorGpu<T, K> {
let data = data.into_vec().into();
let shape = self.shape;

let id = uid::Id::new();

TensorCpu {
shape,
data,
id,
phantom: PhantomData,
}
}
@@ -559,12 +532,9 @@ impl<T: Scalar, K: Kind> TensorGpu<T, K> {
};
let data = data.into_vec().into();

let id = uid::Id::new();

TensorCpu {
shape: self.shape,
data,
id,
phantom: PhantomData,
}
}
@@ -596,12 +566,9 @@ impl<T: Scalar, K: Kind> TensorGpu<T, K> {
};
buffer.unmap();

let id = uid::Id::new();

TensorCpu {
shape: self.shape,
data,
id,
phantom: PhantomData,
}
}
@@ -698,12 +665,9 @@ impl<T: Scalar> TensorCpu<T> {
.concat()
.into();

let id = uid::Id::new();

Self {
shape,
data,
id,
phantom: PhantomData,
}
}
@@ -729,12 +693,9 @@ impl<T: Scalar> TensorCpu<T> {
.concat()
.into();

let id = uid::Id::new();

Ok(Self {
shape,
data,
id,
phantom: PhantomData,
})
}
@@ -770,12 +731,9 @@ impl<T: Scalar> TensorCpu<T> {
let (start, end) = slice.bounds(self.shape)?;
let data = self.data[start..end].into();

let id = uid::Id::new();

Ok(Self {
shape,
data,
id,
phantom: PhantomData,
})
}
@@ -794,12 +752,9 @@ impl<T: Scalar> TensorCpu<T> {
let (start, end) = slice.bounds(self.shape)?;
let data = self.data[start..end].into();

let id = uid::Id::new();

Ok(Self {
shape,
data,
id,
phantom: PhantomData,
})
}
@@ -974,13 +929,10 @@ impl<T: Scalar> TryFrom<Vec<TensorCpu<T>>> for TensorStack<T> {
);
let data = data.into();

let id = uid::Id::new();

Ok(Self {
tensor: Tensor {
shape,
data,
id,
phantom: PhantomData,
},
cursors,
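
The bulk of the commit: `Tensor` loses its per-instance `uid::Id` field, and with it the `Eq`/`PartialEq` impls that compared tensors by id, the public `id()` accessor, and every `let id = uid::Id::new();` at the construction sites above. Where a "same tensor" check is still wanted, pointer equality on the shared storage gives an equivalent answer; a hedged sketch of that fallback, not code from this commit:

```rust
use std::sync::Arc;

/// With the per-tensor id gone, identity checks can fall back to
/// pointer equality on the shared storage: two handles are the same
/// tensor iff they point at the same allocation.
fn same_storage<T>(a: &Arc<[T]>, b: &Arc<[T]>) -> bool {
    Arc::ptr_eq(a, b)
}

fn main() {
    let data: Arc<[u32]> = vec![0; 16].into();
    let alias = data.clone();
    let fresh: Arc<[u32]> = vec![0; 16].into();
    assert!(same_storage(&data, &alias)); // same allocation
    assert!(!same_storage(&data, &fresh)); // equal contents, distinct storage
}
```
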
2 changes: 0 additions & 2 deletions src/tensor/serialization.rs
@@ -28,11 +28,9 @@ impl<T: Scalar> From<TensorBlob> for TensorCpu<T> {
let TensorBlob { shape, data } = value;
let data: Vec<T> = bytemuck::cast_slice(&data).to_vec();
let data = data.into();
let id = uid::Id::new();
Self {
shape,
data,
id,
phantom: PhantomData,
}
}
