From 9414803b759e9d9af6d6d20a43dc2fa7dbebd87d Mon Sep 17 00:00:00 2001
From: cryscan
Date: Mon, 1 Jan 2024 13:15:15 +0800
Subject: [PATCH] Allow embed on gpu.

---
 examples/batch.rs | 27 ++++++++++++++++++++++--
 examples/chat.rs  | 28 +++++++++++++++++++++++--
 examples/gen.rs   | 27 ++++++++++++++++++++++--
 src/context.rs    | 47 ++++++++++++++++++++++--------------------
 src/model/mod.rs  | 13 ++++++++----
 src/model/v4.rs   | 39 ++++++++++++++++++++++++++---------
 src/model/v5.rs   | 39 +++++++++++++++++++++++++++--------
 src/model/v6.rs   | 52 +++++++++++++++++++++++++++++++++--------------
 8 files changed, 206 insertions(+), 66 deletions(-)

diff --git a/examples/batch.rs b/examples/batch.rs
index fc1f730..268b246 100644
--- a/examples/batch.rs
+++ b/examples/batch.rs
@@ -1,5 +1,5 @@
 use anyhow::Result;
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 #[cfg(not(debug_assertions))]
 use crossterm::terminal::{
     disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen,
@@ -106,6 +106,7 @@ fn load_model(
     lora: Option<PathBuf>,
     quant: Option<usize>,
     quant_nf4: Option<usize>,
+    embed_device: Option<EmbedDevice>,
     turbo: bool,
 ) -> Result<M> {
     let quant = quant
@@ -117,7 +118,8 @@
     let quant = quant.into_iter().chain(quant_nf4).collect();
     let model = ModelBuilder::new(context, data)
         .with_quant(quant)
-        .with_turbo(turbo);
+        .with_turbo(turbo)
+        .with_embed_device(embed_device.unwrap_or_default().into());
     match lora {
         Some(lora) => {
             let file = File::open(lora)?;
@@ -175,6 +177,7 @@ async fn run(cli: Cli) -> Result<()> {
                 cli.lora,
                 cli.quant,
                 cli.quant_nf4,
+                cli.embed_device,
                 cli.turbo,
             )?;
             // The model state should keep the same batch as input.
@@ -192,6 +195,7 @@ async fn run(cli: Cli) -> Result<()> {
                 cli.lora,
                 cli.quant,
                 cli.quant_nf4,
+                cli.embed_device,
                 cli.turbo,
             )?;
             // The model state should keep the same batch as input.
@@ -209,6 +213,7 @@ async fn run(cli: Cli) -> Result<()> {
                 cli.lora,
                 cli.quant,
                 cli.quant_nf4,
+                cli.embed_device,
                 cli.turbo,
             )?;
             // The model state should keep the same batch as input.
@@ -334,6 +339,22 @@ where
     Ok(())
 }
 
+#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, ValueEnum)]
+enum EmbedDevice {
+    #[default]
+    Cpu,
+    Gpu,
+}
+
+impl From<EmbedDevice> for web_rwkv::model::EmbedDevice {
+    fn from(value: EmbedDevice) -> Self {
+        match value {
+            EmbedDevice::Cpu => Self::Cpu,
+            EmbedDevice::Gpu => Self::Gpu,
+        }
+    }
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Cli {
@@ -345,6 +366,8 @@ struct Cli {
     quant: Option<usize>,
     #[arg(long, value_name = "LAYERS")]
     quant_nf4: Option<usize>,
+    #[arg(short, long)]
+    embed_device: Option<EmbedDevice>,
     #[arg(short, long, action)]
     turbo: bool,
     #[arg(short, long, default_value_t = 4)]
diff --git a/examples/chat.rs b/examples/chat.rs
index 1b25745..a5f3241 100644
--- a/examples/chat.rs
+++ b/examples/chat.rs
@@ -1,5 +1,5 @@
 use anyhow::Result;
-use clap::{Args, Parser};
+use clap::{Args, Parser, ValueEnum};
 #[cfg(not(debug_assertions))]
 use dialoguer::{theme::ColorfulTheme, Select};
 use itertools::Itertools;
@@ -74,6 +74,7 @@ async fn create_context(info: &ModelInfo) -> Result<Context> {
     let instance = Instance::new();
     let limits = wgpu::Limits {
         max_storage_buffer_binding_size: info.max_buffer_size() as u32,
+        max_buffer_size: info.max_buffer_size() as u64,
         ..Default::default()
     };
     #[cfg(not(debug_assertions))]
@@ -118,6 +119,7 @@ fn load_model(
     lora: Option<PathBuf>,
     quant: Option<usize>,
     quant_nf4: Option<usize>,
+    embed_device: Option<EmbedDevice>,
     turbo: bool,
 ) -> Result<M> {
     let quant = quant
@@ -129,7 +131,8 @@
     let quant = quant.into_iter().chain(quant_nf4).collect();
     let model = ModelBuilder::new(context, data)
         .with_quant(quant)
-        .with_turbo(turbo);
+        .with_turbo(turbo)
+        .with_embed_device(embed_device.unwrap_or_default().into());
     match lora {
         Some(lora) => {
             let file = File::open(lora)?;
@@ -197,6 +200,7 @@ async fn run(cli: Cli) -> Result<()> {
                 cli.lora,
                 cli.quant,
                 cli.quant_nf4,
+                cli.embed_device,
                 cli.turbo,
             )?;
             let state: v4::ModelState = StateBuilder::new(&context, model.info()).build();
@@ -209,6 +213,7 @@ async fn run(cli: Cli) -> Result<()> {
                 cli.lora,
                 cli.quant,
                 cli.quant_nf4,
+                cli.embed_device,
                 cli.turbo,
             )?;
             let state: v5::ModelState = StateBuilder::new(&context, model.info()).build();
@@ -221,6 +226,7 @@ async fn run(cli: Cli) -> Result<()> {
                 cli.lora,
                 cli.quant,
                 cli.quant_nf4,
+                cli.embed_device,
                 cli.turbo,
             )?;
             let state: v6::ModelState = StateBuilder::new(&context, model.info()).build();
@@ -336,6 +342,22 @@ where
     Ok(())
 }
 
+#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, ValueEnum)]
+enum EmbedDevice {
+    #[default]
+    Cpu,
+    Gpu,
+}
+
+impl From<EmbedDevice> for web_rwkv::model::EmbedDevice {
+    fn from(value: EmbedDevice) -> Self {
+        match value {
+            EmbedDevice::Cpu => Self::Cpu,
+            EmbedDevice::Gpu => Self::Gpu,
+        }
+    }
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Cli {
@@ -349,6 +371,8 @@ struct Cli {
     quant: Option<usize>,
     #[arg(long, value_name = "LAYERS")]
     quant_nf4: Option<usize>,
+    #[arg(short, long)]
+    embed_device: Option<EmbedDevice>,
     #[arg(short, long, action)]
     turbo: bool,
     #[command(flatten)]
diff --git a/examples/gen.rs b/examples/gen.rs
index a0c9410..eb3f7a8 100644
--- a/examples/gen.rs
+++ b/examples/gen.rs
@@ -1,5 +1,5 @@
 use anyhow::Result;
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 #[cfg(not(debug_assertions))]
 use dialoguer::{theme::ColorfulTheme, Select};
 use itertools::Itertools;
@@ -95,6 +95,7 @@ fn load_model(
     lora: Option<PathBuf>,
     quant: Option<usize>,
     quant_nf4: Option<usize>,
+    embed_device: Option<EmbedDevice>,
     turbo: bool,
 ) -> Result<M> {
     let quant = quant
@@ -106,7 +107,8 @@
     let quant = quant.into_iter().chain(quant_nf4).collect();
     let model = ModelBuilder::new(context, data)
         .with_quant(quant)
-        .with_turbo(turbo);
+        .with_turbo(turbo)
+        .with_embed_device(embed_device.unwrap_or_default().into());
     match lora {
         Some(lora) => {
             let file = File::open(lora)?;
@@ -149,6 +151,7 @@ async fn run(cli: Cli) -> Result<()> {
                 cli.lora,
                 cli.quant,
                 cli.quant_nf4,
+                cli.embed_device,
                 cli.turbo,
             )?;
             let state: v4::ModelState = StateBuilder::new(&context, model.info()).build();
@@ -161,6 +164,7 @@ async fn run(cli: Cli) -> Result<()> {
                 cli.lora,
                 cli.quant,
                 cli.quant_nf4,
+                cli.embed_device,
                 cli.turbo,
             )?;
             let state: v5::ModelState = StateBuilder::new(&context, model.info()).build();
@@ -173,6 +177,7 @@ async fn run(cli: Cli) -> Result<()> {
                 cli.lora,
                 cli.quant,
                 cli.quant_nf4,
+                cli.embed_device,
                 cli.turbo,
             )?;
             let state: v6::ModelState = StateBuilder::new(&context, model.info()).build();
@@ -217,6 +222,22 @@ where
     Ok(())
 }
 
+#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, ValueEnum)]
+enum EmbedDevice {
+    #[default]
+    Cpu,
+    Gpu,
+}
+
+impl From<EmbedDevice> for web_rwkv::model::EmbedDevice {
+    fn from(value: EmbedDevice) -> Self {
+        match value {
+            EmbedDevice::Cpu => Self::Cpu,
+            EmbedDevice::Gpu => Self::Gpu,
+        }
+    }
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Cli {
@@ -228,6 +249,8 @@ struct Cli {
     quant: Option<usize>,
     #[arg(long, value_name = "LAYERS")]
     quant_nf4: Option<usize>,
+    #[arg(short, long)]
+    embed_device: Option<EmbedDevice>,
     #[arg(short, long, action)]
     turbo: bool,
 }
diff --git a/src/context.rs b/src/context.rs
index 11479c4..b5dc69a 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -115,20 +115,26 @@ impl<'a> ContextBuilder<'a> {
     }
 
     pub async fn build(self) -> Result<Context, CreateEnvironmentError> {
-        let (device, queue) = self
-            .adapter
+        let Self {
+            adapter,
+            features,
+            limits,
+            pipelines,
+        } = self;
+
+        let (device, queue) = adapter
             .request_device(
                 &DeviceDescriptor {
                     label: None,
-                    features: self.features,
-                    limits: self.limits,
+                    features,
+                    limits,
                 },
                 None,
             )
             .await
             .map_err(|_| CreateEnvironmentError::RequestDeviceFailed)?;
-        let pipelines = self
-            .pipelines
+
+        let pipelines = pipelines
             .into_iter()
             .map(|(name, (shader, entry_point, layout))| {
                 let module = &device.create_shader_module(ShaderModuleDescriptor {
@@ -158,10 +164,11 @@ impl<'a> ContextBuilder<'a> {
                 )
             })
             .collect();
+
         Ok(Context(
             ContextInner {
                 id: ContextId::new(),
-                adapter: self.adapter,
+                adapter,
                 device,
                 queue,
                 pipelines,
@@ -172,30 +179,29 @@ impl<'a> ContextBuilder<'a> {
         ))
     }
 
-    pub fn with_limits(self, limits: Limits) -> Self {
-        Self { limits, ..self }
+    pub fn with_limits(mut self, limits: Limits) -> Self {
+        self.limits = limits;
+        self
     }
 
-    pub fn with_features(self, features: Features) -> Self {
-        Self { features, ..self }
+    pub fn with_features(mut self, features: Features) -> Self {
+        self.features = features;
+        self
     }
 
     pub fn with_pipeline(
-        self,
+        mut self,
         name: &'a str,
         shader: &'a str,
         entry_point: &'a str,
         layout: Option<&'a [BindGroupLayoutEntry]>,
     ) -> Self {
-        let mut pipelines = self.pipelines;
-        pipelines.insert(name, (shader, entry_point, layout));
-        Self { pipelines, ..self }
+        self.pipelines.insert(name, (shader, entry_point, layout));
+        self
     }
 
     pub fn with_default_pipelines(self) -> Self {
-        self.with_core_pipelines()
-            .with_util_pipelines()
-            .with_quant_pipelines()
+        self.with_core_pipelines().with_quant_pipelines()
     }
 
     fn with_core_pipelines(self) -> Self {
@@ -333,10 +339,7 @@ impl<'a> ContextBuilder<'a> {
             "softmax",
             None,
         )
-    }
-
-    fn with_util_pipelines(self) -> Self {
self.with_pipeline("blit", include_str!("shaders/blit.wgsl"), "blit", None) + .with_pipeline("blit", include_str!("shaders/blit.wgsl"), "blit", None) .with_pipeline( "transpose", include_str!("shaders/blit.wgsl"), diff --git a/src/model/mod.rs b/src/model/mod.rs index dc4721c..b1b3128 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -25,8 +25,8 @@ pub mod v6; pub const RESCALE_LAYER: usize = 6; pub const MIN_TOKEN_CHUNK_SIZE: usize = 32; -pub const HEAD_CHUNK_SIZES: [usize; 8] = [ - 0x4000, 0x3000, 0x2000, 0x1800, 0x1600, 0x1400, 0x1200, 0x1000, +pub const HEAD_CHUNK_SIZES: [usize; 10] = [ + 0x10000, 0x5000, 0x4000, 0x3000, 0x2000, 0x1800, 0x1600, 0x1400, 0x1200, 0x1000, ]; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] @@ -72,9 +72,14 @@ pub struct ModelInfo { } impl ModelInfo { - /// Computes the required storage buffer size. + /// Computes the required storage buffer size, not including head. + pub fn max_non_head_buffer_size(&self) -> usize { + (self.num_emb * self.num_hidden * f16::size()).max(256 << 20) + } + + /// Computes the required storage buffer size, including head. pub fn max_buffer_size(&self) -> usize { - (self.num_emb * self.num_hidden * f16::size()).max(128 << 20) + (self.num_emb * self.num_vocab * f16::size()).max(256 << 20) } } diff --git a/src/model/v4.rs b/src/model/v4.rs index 6c771a6..03a1ace 100644 --- a/src/model/v4.rs +++ b/src/model/v4.rs @@ -96,7 +96,7 @@ struct Layer { struct Embed<'a> { layer_norm: LayerNorm, w: TensorCpu<'a, f16>, - _u: Option>, + u: Option>, } #[derive(Debug)] @@ -135,11 +135,12 @@ pub struct Runtime { impl Runtime { pub fn new(context: &Context, info: &ModelInfo, num_token: usize, max_token: usize) -> Self { let shape = Shape::new(info.num_emb, num_token, 1, 1); + let tokens_shape = Shape::new(num_token, 1, 1, 1); let cursors_shape = Shape::new(max_token, 1, 1, 1); let hidden_shape = Shape::new(info.num_hidden, num_token, 1, 1); Self { - tokens: context.tensor_init(cursors_shape), + tokens: context.tensor_init(tokens_shape), cursors: context.tensor_init(cursors_shape), input: context.tensor_init(shape), att_x: context.tensor_init(shape), @@ -454,7 +455,7 @@ impl<'a> FromBuilder for Model<'a> { b: loader.load_vector_f16("blocks.0.ln0.bias")?, }, w: loader.load_embed()?, - _u: match embed_device { + u: match embed_device { super::EmbedDevice::Cpu => None, super::EmbedDevice::Gpu => Some(loader.load_matrix_f16("emb.weight")?), }, @@ -726,18 +727,30 @@ impl ModelRunInner for Model<'_> { // }) // .try_collect()?; + let mut ops = vec![]; + let mut cursors = input.cursors.into_cursors(); cursors.resize(self.token_chunk_size, 0); - let cursors = context.tensor_from_data(buffer.cursors.shape(), cursors)?; - buffer.input.load(&input.tensor)?; + let cursors = context.tensor_from_data(buffer.cursors.shape(), cursors)?; buffer.cursors.load(&cursors)?; - let mut encoder = context - .device - .create_command_encoder(&CommandEncoderDescriptor::default()); + match &tensor.embed.u { + Some(u) => { + let tokens = tokens + .concat() + .into_iter() + .map(|token| token as u32) + .collect_vec(); + + let tokens = context.tensor_from_data(buffer.tokens.shape(), tokens)?; + buffer.tokens.load(&tokens)?; - let op = TensorOp::List(vec![ + ops.push(TensorOp::embed(&buffer.tokens, u, &buffer.input)?); + } + None => buffer.input.load(&input.tensor)?, + } + ops.append(&mut vec![ hook_op(Hook::PostEmbedLoaded), TensorOp::layer_norm( &tensor.embed.layer_norm.w, @@ -746,8 +759,14 @@ impl ModelRunInner for Model<'_> { )?, 
             hook_op(Hook::PostEmbedLayerNorm),
         ]);
+
+        let mut encoder = context
+            .device
+            .create_command_encoder(&CommandEncoderDescriptor::default());
+
+        let ops = TensorOp::List(ops);
         let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor::default());
-        pass.execute_tensor_op(&op);
+        pass.execute_tensor_op(&ops);
         drop(pass);
 
         for (index, layer) in tensor.layers.iter().enumerate() {
diff --git a/src/model/v5.rs b/src/model/v5.rs
index 881ae30..fa05d4a 100644
--- a/src/model/v5.rs
+++ b/src/model/v5.rs
@@ -99,7 +99,7 @@ struct Layer {
 struct Embed<'a> {
     layer_norm: LayerNorm,
     w: TensorCpu<'a, f16>,
-    _u: Option<TensorGpu<f16, ReadWrite>>,
+    u: Option<TensorGpu<f16, ReadWrite>>,
 }
 
 #[derive(Debug)]
@@ -111,6 +111,7 @@ struct Head {
 /// Runtime buffers.
 #[derive(Debug)]
 pub struct Runtime {
+    pub tokens: TensorGpu<u32, ReadWrite>,
     pub cursors: TensorGpu<u32, ReadWrite>,
     pub input: TensorGpu<f32, ReadWrite>,
 
@@ -139,10 +140,12 @@
 impl Runtime {
     pub fn new(context: &Context, info: &ModelInfo, num_token: usize, max_token: usize) -> Self {
         let shape = Shape::new(info.num_emb, num_token, 1, 1);
+        let tokens_shape = Shape::new(num_token, 1, 1, 1);
         let cursors_shape = Shape::new(max_token, 1, 1, 1);
         let hidden_shape = Shape::new(info.num_hidden, num_token, 1, 1);
 
         Self {
+            tokens: context.tensor_init(tokens_shape),
             cursors: context.tensor_init(cursors_shape),
             input: context.tensor_init(shape),
             att_x: context.tensor_init(shape),
@@ -529,7 +532,7 @@ impl<'a> FromBuilder for Model<'a> {
                 b: loader.load_vector_f16("blocks.0.ln0.bias")?,
             },
             w: loader.load_embed()?,
-            _u: match embed_device {
+            u: match embed_device {
                 super::EmbedDevice::Cpu => None,
                 super::EmbedDevice::Gpu => Some(loader.load_matrix_f16("emb.weight")?),
             },
@@ -825,18 +828,30 @@ impl ModelRunInner for Model<'_> {
         //     })
         //     .try_collect()?;
 
+        let mut ops = vec![];
+
         let mut cursors = input.cursors.into_cursors();
         cursors.resize(self.token_chunk_size, 0);
-        let cursors = context.tensor_from_data(buffer.cursors.shape(), cursors)?;
 
-        buffer.input.load(&input.tensor)?;
+        let cursors = context.tensor_from_data(buffer.cursors.shape(), cursors)?;
         buffer.cursors.load(&cursors)?;
 
-        let mut encoder = context
-            .device
-            .create_command_encoder(&CommandEncoderDescriptor::default());
+        match &tensor.embed.u {
+            Some(u) => {
+                let tokens = tokens
+                    .concat()
+                    .into_iter()
+                    .map(|token| token as u32)
+                    .collect_vec();
+
+                let tokens = context.tensor_from_data(buffer.tokens.shape(), tokens)?;
+                buffer.tokens.load(&tokens)?;
 
-        let op = TensorOp::List(vec![
+                ops.push(TensorOp::embed(&buffer.tokens, u, &buffer.input)?);
+            }
+            None => buffer.input.load(&input.tensor)?,
+        }
+        ops.append(&mut vec![
             hook_op(Hook::PostEmbedLoaded),
             TensorOp::layer_norm(
                 &tensor.embed.layer_norm.w,
@@ -845,8 +860,14 @@
             )?,
             hook_op(Hook::PostEmbedLayerNorm),
         ]);
+
+        let mut encoder = context
+            .device
+            .create_command_encoder(&CommandEncoderDescriptor::default());
+
+        let ops = TensorOp::List(ops);
         let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor::default());
-        pass.execute_tensor_op(&op);
+        pass.execute_tensor_op(&ops);
         drop(pass);
 
         for (index, layer) in tensor.layers.iter().enumerate() {
diff --git a/src/model/v6.rs b/src/model/v6.rs
index 845029c..d230654 100644
--- a/src/model/v6.rs
+++ b/src/model/v6.rs
@@ -106,7 +106,7 @@ struct Layer {
 struct Embed<'a> {
     layer_norm: LayerNorm,
     w: TensorCpu<'a, f16>,
-    _u: Option<TensorGpu<f16, ReadWrite>>,
+    u: Option<TensorGpu<f16, ReadWrite>>,
 }
 
 #[derive(Debug)]
@@ -118,6 +118,7 @@ struct Head {
 /// Runtime buffers.
 #[derive(Debug)]
 pub struct Runtime {
+    pub tokens: TensorGpu<u32, ReadWrite>,
     pub cursors: TensorGpu<u32, ReadWrite>,
     pub input: TensorGpu<f32, ReadWrite>,
 
@@ -161,6 +162,7 @@
 impl Runtime {
     pub fn new(context: &Context, info: &ModelInfo, num_token: usize, max_token: usize) -> Self {
         let shape = Shape::new(info.num_emb, num_token, 1, 1);
+        let tokens_shape = Shape::new(num_token, 1, 1, 1);
         let cursors_shape = Shape::new(max_token, 1, 1, 1);
         let hidden_shape = Shape::new(info.num_hidden, num_token, 1, 1);
         let time_mix_shape = Shape::new(info.num_emb, num_token, 5, 1);
@@ -169,6 +171,7 @@ impl Runtime {
         let time_decay_shape = Shape::new(Model::TIME_DECAY_ADAPTER_SIZE, num_token, 1, 1);
 
         Self {
+            tokens: context.tensor_init(tokens_shape),
             cursors: context.tensor_init(cursors_shape),
             input: context.tensor_init(shape),
             att_x: context.tensor_init(shape),
@@ -579,7 +582,7 @@ impl<'a> FromBuilder for Model<'a> {
                 b: loader.load_vector_f16("blocks.0.ln0.bias")?,
             },
             w: loader.load_embed()?,
-            _u: match embed_device {
+            u: match embed_device {
                 super::EmbedDevice::Cpu => None,
                 super::EmbedDevice::Gpu => Some(loader.load_matrix_f16("emb.weight")?),
             },
@@ -806,6 +809,7 @@ impl ModelRunInner for Model<'_> {
             Output::new(&self.context, &self.info, num_batch)
         })
     }
+
     fn run_internal(
         &self,
         tokens: Vec<Vec<u16>>,
@@ -842,6 +846,13 @@ impl ModelRunInner for Model<'_> {
         let buffer = self.request_runtime(num_token);
         let output = self.request_output(num_header.max(1));
 
+        let hook_op = |hook: Hook| -> TensorOp {
+            hooks
+                .get(&hook)
+                .map(|f| f(state, &buffer))
+                .unwrap_or(TensorOp::List(vec![]))
+        };
+
         // gather and group copy operations
         let (head_ops, head_x) = if num_token == 1 || num_token == num_header {
             (TensorOp::List(vec![]), &buffer.ffn_x)
@@ -881,25 +892,30 @@ impl ModelRunInner for Model<'_> {
         //     })
         //     .try_collect()?;
 
+        let mut ops = vec![];
+
         let mut cursors = input.cursors.into_cursors();
         cursors.resize(self.token_chunk_size, 0);
-        let cursors = context.tensor_from_data(buffer.cursors.shape(), cursors)?;
 
-        buffer.input.load(&input.tensor)?;
+        let cursors = context.tensor_from_data(buffer.cursors.shape(), cursors)?;
         buffer.cursors.load(&cursors)?;
 
-        let hook_op = |hook: Hook| -> TensorOp {
-            hooks
-                .get(&hook)
-                .map(|f| f(state, &buffer))
-                .unwrap_or(TensorOp::List(vec![]))
-        };
+        match &tensor.embed.u {
+            Some(u) => {
+                let tokens = tokens
+                    .concat()
+                    .into_iter()
+                    .map(|token| token as u32)
+                    .collect_vec();
 
-        let mut encoder = context
-            .device
-            .create_command_encoder(&CommandEncoderDescriptor::default());
+                let tokens = context.tensor_from_data(buffer.tokens.shape(), tokens)?;
+                buffer.tokens.load(&tokens)?;
 
-        let op = TensorOp::List(vec![
+                ops.push(TensorOp::embed(&buffer.tokens, u, &buffer.input)?);
+            }
+            None => buffer.input.load(&input.tensor)?,
+        }
+        ops.append(&mut vec![
             hook_op(Hook::PostEmbedLoaded),
             TensorOp::layer_norm(
                 &tensor.embed.layer_norm.w,
@@ -908,8 +924,14 @@
             )?,
             hook_op(Hook::PostEmbedLayerNorm),
         ]);
+
+        let mut encoder = context
+            .device
+            .create_command_encoder(&CommandEncoderDescriptor::default());
+
+        let ops = TensorOp::List(ops);
         let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor::default());
-        pass.execute_tensor_op(&op);
+        pass.execute_tensor_op(&ops);
         drop(pass);
 
         for (index, layer) in tensor.layers.iter().enumerate() {
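
Usage sketch (not part of the patch; assumes the examples' existing --model argument and a model path of your own): with `embed_device` wired through as above, the GPU embedding path is selected on the command line via the new clap flag, e.g.

    cargo run --release --example chat -- --model /path/to/model.st --embed-device gpu

Omitting the flag falls back to `EmbedDevice::Cpu` through `unwrap_or_default()`, i.e. the embedding matrix stays in CPU memory and only the looked-up rows are uploaded per token chunk, exactly as before this patch.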
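`TensorOp::embed(&buffer.tokens, u, &buffer.input)` is defined outside this patch; conceptually it is a row gather from the GPU-resident embedding matrix, indexed by the token ids just loaded into `buffer.tokens`. A minimal CPU-side sketch of that operation (hypothetical helper name; uses the `half` crate the tensors are based on):

    use half::f16;

    /// Row gather: out[i, ..] = w[tokens[i], ..].
    /// The real op runs as a WGSL compute shader over TensorGpu buffers.
    fn embed_gather(tokens: &[u32], w: &[f16], num_emb: usize) -> Vec<f16> {
        let mut out = vec![f16::ZERO; tokens.len() * num_emb];
        for (i, &token) in tokens.iter().enumerate() {
            let row = token as usize * num_emb;
            out[i * num_emb..(i + 1) * num_emb].copy_from_slice(&w[row..row + num_emb]);
        }
        out
    }

This also motivates the `ModelInfo::max_buffer_size` change above: sizing by `num_emb * num_vocab` f16 elements is consistent with the whole embedding (and head) matrix having to fit in a single storage buffer once it lives on the GPU.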