Skip to content

Commit

Permalink
Refactor input tensor creation.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Dec 31, 2023
1 parent 500a6ac commit d3a0d9a
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 68 deletions.
35 changes: 33 additions & 2 deletions src/model/run.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use std::{collections::HashMap, future::Future, hash::Hash, sync::Arc};

use anyhow::Result;
use half::f16;
use itertools::Itertools;

use super::{ModelBase, ModelError, ModelInfo, ModelState};
use crate::{
context::Context,
tensor::{
ops::{TensorOp, TensorOpHook},
shape::Shape,
ReadBack, ReadWrite, TensorGpu,
shape::{Shape, TensorDimension},
ReadBack, ReadWrite, TensorCpu, TensorError, TensorGpu, TensorReshape, TensorStack,
},
};

Expand Down Expand Up @@ -50,6 +52,35 @@ pub(crate) trait ModelRunInner: ModelBase {
should_output: Vec<bool>,
hooks: &HookMap<Self::Hook, Self::ModelState, Self::Runtime>,
) -> Result<(TensorGpu<f32, ReadBack>, Vec<Option<usize>>)>;

fn create_input<'a>(
&self,
embed: &TensorCpu<'a, f16>,
tokens: &[Vec<u16>],
) -> Result<TensorStack<'a, f32>, TensorError> {
let info = self.info();
let context = self.context();

let input: Vec<_> = tokens
.iter()
.map(|tokens| -> Result<_, TensorError> {
let stack = TensorCpu::stack(
tokens
.iter()
.map(|&token| embed.slice(.., token as usize, .., ..))
.try_collect()?,
)
.unwrap_or_else(|_| context.zeros(Shape::new(info.num_emb, 1, 0, 1)));
stack.map(|x| x.to_f32()).reshape(
TensorDimension::Full,
TensorDimension::Auto,
TensorDimension::Dimension(1),
TensorDimension::Full,
)
})
.try_collect()?;
TensorStack::try_from(input)
}
}

pub trait ModelRun: ModelBase {
Expand Down
27 changes: 5 additions & 22 deletions src/model/v4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ use crate::{
tensor::{
cache::ResourceCache,
ops::{TensorCommand, TensorOp, TensorOpHook, TensorPass},
shape::{Shape, TensorDimension},
shape::Shape,
DeepClone, IntoPackedCursors, ReadBack, ReadWrite, TensorCpu, TensorError, TensorGpu,
TensorReshape, TensorShape, TensorStack, TensorView,
TensorShape, TensorView,
},
};

Expand Down Expand Up @@ -108,6 +108,7 @@ struct Head {
/// Runtime buffers.
#[derive(Debug)]
pub struct Runtime {
pub tokens: TensorGpu<u32, ReadWrite>,
pub cursors: TensorGpu<u32, ReadWrite>,
pub input: TensorGpu<f32, ReadWrite>,

Expand Down Expand Up @@ -138,6 +139,7 @@ impl Runtime {
let hidden_shape = Shape::new(info.num_hidden, num_token, 1, 1);

Self {
tokens: context.tensor_init(cursors_shape),
cursors: context.tensor_init(cursors_shape),
input: context.tensor_init(shape),
att_x: context.tensor_init(shape),
Expand Down Expand Up @@ -653,26 +655,7 @@ impl ModelRunInner for Model<'_> {
let context = &self.context;
let tensor = &self.tensor;

let input: Vec<_> = tokens
.into_iter()
.map(|tokens| -> Result<_, TensorError> {
let stack = TensorCpu::stack(
tokens
.into_iter()
.map(|token| tensor.embed.w.slice(.., token as usize, .., ..))
.try_collect()?,
)
.unwrap_or_else(|_| context.zeros(Shape::new(self.info.num_emb, 1, 0, 1)));
stack.map(|x| x.to_f32()).reshape(
TensorDimension::Full,
TensorDimension::Auto,
TensorDimension::Dimension(1),
TensorDimension::Full,
)
})
.try_collect()?;

let input = TensorStack::try_from(input)?;
let input = self.create_input(&tensor.embed.w, &tokens)?;
let num_batch = input.num_batch();
let num_token = input.num_token();
assert_ne!(num_token, 0);
Expand Down
23 changes: 2 additions & 21 deletions src/model/v5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::{
ops::{TensorCommand, TensorOp, TensorOpHook, TensorPass},
shape::{Shape, TensorDimension},
DeepClone, IntoPackedCursors, ReadBack, ReadWrite, TensorCpu, TensorError, TensorGpu,
TensorReshape, TensorShape, TensorStack, TensorView,
TensorReshape, TensorShape, TensorView,
},
};

Expand Down Expand Up @@ -753,26 +753,7 @@ impl ModelRunInner for Model<'_> {
let context = &self.context;
let tensor = &self.tensor;

let input: Vec<_> = tokens
.into_iter()
.map(|tokens| -> Result<_, TensorError> {
let stack = TensorCpu::stack(
tokens
.into_iter()
.map(|token| tensor.embed.w.slice(.., token as usize, .., ..))
.try_collect()?,
)
.unwrap_or_else(|_| context.zeros(Shape::new(self.info.num_emb, 1, 0, 1)));
stack.map(|x| x.to_f32()).reshape(
TensorDimension::Full,
TensorDimension::Auto,
TensorDimension::Dimension(1),
TensorDimension::Full,
)
})
.try_collect()?;

let input = TensorStack::try_from(input)?;
let input = self.create_input(&tensor.embed.w, &tokens)?;
let num_batch = input.num_batch();
let num_token = input.num_token();
let head_size = self.info.num_emb / self.info.num_head;
Expand Down
23 changes: 2 additions & 21 deletions src/model/v6.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::{
ops::{TensorCommand, TensorOp, TensorOpHook, TensorPass},
shape::{Shape, TensorDimension},
DeepClone, IntoPackedCursors, ReadBack, ReadWrite, TensorCpu, TensorError, TensorGpu,
TensorReshape, TensorShape, TensorStack, TensorView,
TensorReshape, TensorShape, TensorView,
},
};

Expand Down Expand Up @@ -816,26 +816,7 @@ impl ModelRunInner for Model<'_> {
let context = &self.context;
let tensor = &self.tensor;

let input: Vec<_> = tokens
.into_iter()
.map(|tokens| -> Result<_, TensorError> {
let stack = TensorCpu::stack(
tokens
.into_iter()
.map(|token| tensor.embed.w.slice(.., token as usize, .., ..))
.try_collect()?,
)
.unwrap_or_else(|_| context.zeros(Shape::new(self.info.num_emb, 1, 0, 1)));
stack.map(|x| x.to_f32()).reshape(
TensorDimension::Full,
TensorDimension::Auto,
TensorDimension::Dimension(1),
TensorDimension::Full,
)
})
.try_collect()?;

let input = TensorStack::try_from(input)?;
let input = self.create_input(&tensor.embed.w, &tokens)?;
let num_batch = input.num_batch();
let num_token = input.num_token();
let head_size = self.info.num_emb / self.info.num_head;
Expand Down
8 changes: 6 additions & 2 deletions src/shaders/embed.wgsl
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
@group(0) @binding(0) var<uniform> shape: vec4<u32>; // [C, T, B]

@group(0) @binding(1) var<storage, read> tokens: array<u32>; // (B, T)
@group(0) @binding(2) var<storage, read> input: array<vec4<f32>>; // (V, C)
@group(0) @binding(2) var<storage, read> input: array<vec2<u32>>; // (V, C)
@group(0) @binding(3) var<storage, read_write> output: array<vec4<f32>>; // (B, T, C)

const BLOCK_SIZE: u32 = 128u;

fn unpack4x16float(x: vec2<u32>) -> vec4<f32> {
return vec4<f32>(unpack2x16float(x.x), unpack2x16float(x.y));
}

@compute @workgroup_size(128, 1, 1)
fn embed(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
let stride = shape[0] / 4u;
Expand All @@ -19,6 +23,6 @@ fn embed(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
let bti = (batch * shape[1] + token) * stride + index;
let bei = fetch * stride + index;

output[bti] = input[bei];
output[bti] = unpack4x16float(input[bei]);
}
}

0 comments on commit d3a0d9a

Please sign in to comment.