diff --git a/src/model/run.rs b/src/model/run.rs
index cb558dd..fb7dcbf 100644
--- a/src/model/run.rs
+++ b/src/model/run.rs
@@ -1,14 +1,16 @@
 use std::{collections::HashMap, future::Future, hash::Hash, sync::Arc};
 
 use anyhow::Result;
+use half::f16;
+use itertools::Itertools;
 
 use super::{ModelBase, ModelError, ModelInfo, ModelState};
 use crate::{
     context::Context,
     tensor::{
         ops::{TensorOp, TensorOpHook},
-        shape::Shape,
-        ReadBack, ReadWrite, TensorGpu,
+        shape::{Shape, TensorDimension},
+        ReadBack, ReadWrite, TensorCpu, TensorError, TensorGpu, TensorReshape, TensorStack,
     },
 };
@@ -50,6 +52,35 @@ pub(crate) trait ModelRunInner: ModelBase {
         should_output: Vec<bool>,
         hooks: &HookMap<Self::Hook>,
     ) -> Result<(TensorGpu<f32, ReadBack>, Vec<Option<usize>>)>;
+
+    fn create_input<'a>(
+        &self,
+        embed: &TensorCpu<'a, f16>,
+        tokens: &[Vec<u16>],
+    ) -> Result<TensorStack<'a, f32>, TensorError> {
+        let info = self.info();
+        let context = self.context();
+
+        let input: Vec<_> = tokens
+            .iter()
+            .map(|tokens| -> Result<_, TensorError> {
+                let stack = TensorCpu::stack(
+                    tokens
+                        .iter()
+                        .map(|&token| embed.slice(.., token as usize, .., ..))
+                        .try_collect()?,
+                )
+                .unwrap_or_else(|_| context.zeros(Shape::new(info.num_emb, 1, 0, 1)));
+                stack.map(|x| x.to_f32()).reshape(
+                    TensorDimension::Full,
+                    TensorDimension::Auto,
+                    TensorDimension::Dimension(1),
+                    TensorDimension::Full,
+                )
+            })
+            .try_collect()?;
+        TensorStack::try_from(input)
+    }
 }
 
 pub trait ModelRun: ModelBase {
diff --git a/src/model/v4.rs b/src/model/v4.rs
index fbfeb73..6c771a6 100644
--- a/src/model/v4.rs
+++ b/src/model/v4.rs
@@ -20,9 +20,9 @@ use crate::{
     tensor::{
         cache::ResourceCache,
         ops::{TensorCommand, TensorOp, TensorOpHook, TensorPass},
-        shape::{Shape, TensorDimension},
+        shape::Shape,
         DeepClone, IntoPackedCursors, ReadBack, ReadWrite, TensorCpu, TensorError, TensorGpu,
-        TensorReshape, TensorShape, TensorStack, TensorView,
+        TensorShape, TensorView,
     },
 };
@@ -108,6 +108,7 @@ struct Head {
 
 /// Runtime buffers.
 #[derive(Debug)]
 pub struct Runtime {
+    pub tokens: TensorGpu<u32, ReadWrite>,
     pub cursors: TensorGpu<u32, ReadWrite>,
     pub input: TensorGpu<f32, ReadWrite>,
@@ -138,6 +139,7 @@ impl Runtime {
         let hidden_shape = Shape::new(info.num_hidden, num_token, 1, 1);
 
         Self {
+            tokens: context.tensor_init(cursors_shape),
             cursors: context.tensor_init(cursors_shape),
             input: context.tensor_init(shape),
             att_x: context.tensor_init(shape),
@@ -653,26 +655,7 @@ impl ModelRunInner for Model<'_> {
         let context = &self.context;
         let tensor = &self.tensor;
 
-        let input: Vec<_> = tokens
-            .into_iter()
-            .map(|tokens| -> Result<_, TensorError> {
-                let stack = TensorCpu::stack(
-                    tokens
-                        .into_iter()
-                        .map(|token| tensor.embed.w.slice(.., token as usize, .., ..))
-                        .try_collect()?,
-                )
-                .unwrap_or_else(|_| context.zeros(Shape::new(self.info.num_emb, 1, 0, 1)));
-                stack.map(|x| x.to_f32()).reshape(
-                    TensorDimension::Full,
-                    TensorDimension::Auto,
-                    TensorDimension::Dimension(1),
-                    TensorDimension::Full,
-                )
-            })
-            .try_collect()?;
-
-        let input = TensorStack::try_from(input)?;
+        let input = self.create_input(&tensor.embed.w, &tokens)?;
         let num_batch = input.num_batch();
         let num_token = input.num_token();
         assert_ne!(num_token, 0);
diff --git a/src/model/v5.rs b/src/model/v5.rs
index cbca395..881ae30 100644
--- a/src/model/v5.rs
+++ b/src/model/v5.rs
@@ -21,7 +21,7 @@ use crate::{
         ops::{TensorCommand, TensorOp, TensorOpHook, TensorPass},
         shape::{Shape, TensorDimension},
         DeepClone, IntoPackedCursors, ReadBack, ReadWrite, TensorCpu, TensorError, TensorGpu,
-        TensorReshape, TensorShape, TensorStack, TensorView,
+        TensorReshape, TensorShape, TensorView,
     },
 };
@@ -753,26 +753,7 @@ impl ModelRunInner for Model<'_> {
         let context = &self.context;
         let tensor = &self.tensor;
 
-        let input: Vec<_> = tokens
-            .into_iter()
-            .map(|tokens| -> Result<_, TensorError> {
-                let stack = TensorCpu::stack(
-                    tokens
-                        .into_iter()
-                        .map(|token| tensor.embed.w.slice(.., token as usize, .., ..))
-                        .try_collect()?,
-                )
-                .unwrap_or_else(|_|
context.zeros(Shape::new(self.info.num_emb, 1, 0, 1)));
-                stack.map(|x| x.to_f32()).reshape(
-                    TensorDimension::Full,
-                    TensorDimension::Auto,
-                    TensorDimension::Dimension(1),
-                    TensorDimension::Full,
-                )
-            })
-            .try_collect()?;
-
-        let input = TensorStack::try_from(input)?;
+        let input = self.create_input(&tensor.embed.w, &tokens)?;
         let num_batch = input.num_batch();
         let num_token = input.num_token();
         let head_size = self.info.num_emb / self.info.num_head;
diff --git a/src/model/v6.rs b/src/model/v6.rs
index 731519c..845029c 100644
--- a/src/model/v6.rs
+++ b/src/model/v6.rs
@@ -21,7 +21,7 @@ use crate::{
         ops::{TensorCommand, TensorOp, TensorOpHook, TensorPass},
         shape::{Shape, TensorDimension},
         DeepClone, IntoPackedCursors, ReadBack, ReadWrite, TensorCpu, TensorError, TensorGpu,
-        TensorReshape, TensorShape, TensorStack, TensorView,
+        TensorReshape, TensorShape, TensorView,
     },
 };
@@ -816,26 +816,7 @@ impl ModelRunInner for Model<'_> {
         let context = &self.context;
         let tensor = &self.tensor;
 
-        let input: Vec<_> = tokens
-            .into_iter()
-            .map(|tokens| -> Result<_, TensorError> {
-                let stack = TensorCpu::stack(
-                    tokens
-                        .into_iter()
-                        .map(|token| tensor.embed.w.slice(.., token as usize, .., ..))
-                        .try_collect()?,
-                )
-                .unwrap_or_else(|_| context.zeros(Shape::new(self.info.num_emb, 1, 0, 1)));
-                stack.map(|x| x.to_f32()).reshape(
-                    TensorDimension::Full,
-                    TensorDimension::Auto,
-                    TensorDimension::Dimension(1),
-                    TensorDimension::Full,
-                )
-            })
-            .try_collect()?;
-
-        let input = TensorStack::try_from(input)?;
+        let input = self.create_input(&tensor.embed.w, &tokens)?;
         let num_batch = input.num_batch();
         let num_token = input.num_token();
         let head_size = self.info.num_emb / self.info.num_head;
diff --git a/src/shaders/embed.wgsl b/src/shaders/embed.wgsl
index d27c282..71520c0 100644
--- a/src/shaders/embed.wgsl
+++ b/src/shaders/embed.wgsl
@@ -1,11 +1,15 @@
 @group(0) @binding(0) var<uniform> shape: vec4<u32>; // [C, T, B]
 @group(0) @binding(1) var<storage, read> tokens: array<u32>; // (B, T)
-@group(0) @binding(2) var<storage, read> input: array<vec4<f32>>; // (V, C)
+@group(0) @binding(2) var<storage, read> input: array<vec2<u32>>; // (V, C)
 @group(0) @binding(3) var<storage, read_write> output: array<vec4<f32>>; // (B, T, C)
 
 const BLOCK_SIZE: u32 = 128u;
 
+fn unpack4x16float(x: vec2<u32>) -> vec4<f32> {
+    return vec4<f32>(unpack2x16float(x.x), unpack2x16float(x.y));
+}
+
 @compute @workgroup_size(128, 1, 1)
 fn embed(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
     let stride = shape[0] / 4u;
@@ -19,6 +23,6 @@ fn embed(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
         let bti = (batch * shape[1] + token) * stride + index;
         let bei = fetch * stride + index;
 
-        output[bti] = input[bei];
+        output[bti] = unpack4x16float(input[bei]);
     }
 }
\ No newline at end of file