From d771285bfd03b5971c2411b414dfab3065922397 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A0=94=E7=A9=B6=E7=A4=BE=E4=BA=A4?= Date: Sat, 11 May 2024 01:16:00 +0800 Subject: [PATCH] Tensor op execute (#31) * Switch to the new tensor op encoding API. * Reduce sep frequency. --- examples/inspector.rs | 19 ++---- src/model/loader.rs | 95 +++++++++++----------------- src/model/softmax.rs | 17 +---- src/model/v4.rs | 81 +++++++++--------------- src/model/v5.rs | 80 +++++++++-------------- src/model/v6.rs | 80 +++++++++-------------- src/runtime/infer.rs | 1 + src/runtime/loader.rs | 96 +++++++++++----------------- src/runtime/softmax.rs | 20 +----- src/runtime/v4.rs | 23 +++---- src/runtime/v5.rs | 36 +++++------ src/runtime/v6.rs | 44 +++++-------- src/tensor/matrix.rs | 20 +----- src/tensor/mod.rs | 8 +-- src/tensor/ops.rs | 140 +++++++++++++++++++++-------------------- 15 files changed, 288 insertions(+), 472 deletions(-) diff --git a/examples/inspector.rs b/examples/inspector.rs index 03ad66f..0d9e621 100644 --- a/examples/inspector.rs +++ b/examples/inspector.rs @@ -22,11 +22,7 @@ use web_rwkv::{ v5, Build, BuildFuture, ContextAutoLimits, Model, ModelBuilder, ModelInfo, ModelInput, ModelOutput, ModelState, ModelVersion, Quant, StateBuilder, }, - tensor::{ - kind::ReadWrite, - ops::{TensorOp, TensorPass}, - TensorError, TensorGpu, TensorShape, - }, + tensor::{kind::ReadWrite, ops::TensorOp, TensorError, TensorGpu, TensorShape}, tokenizer::Tokenizer, }; @@ -244,10 +240,8 @@ async fn run(cli: Cli) -> Result<()> { } // map the activations into vocab space - let mut encoder = context.device.create_command_encoder(&Default::default()); - let tensor = model.tensor(); - let ops = TensorOp::List(vec![ + let ops = vec![ TensorOp::layer_norm( &tensor.head.layer_norm.w, &tensor.head.layer_norm.b, @@ -259,13 +253,8 @@ async fn run(cli: Cli) -> Result<()> { buffer.out.view(.., .., .., ..)?, Default::default(), )?, - ]); - - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&ops); - drop(pass); - - context.queue.submit(Some(encoder.finish())); + ]; + context.queue.submit(context.encode(&TensorOp::List(ops))); // for each layer, choose the top 5 tokens let backed = buffer.out.back().await.to_vec(); diff --git a/src/model/loader.rs b/src/model/loader.rs index 6f02015..e48dabf 100644 --- a/src/model/loader.rs +++ b/src/model/loader.rs @@ -14,7 +14,7 @@ use crate::{ tensor::{ kind::ReadWrite, matrix::Matrix, - ops::{TensorCommand, TensorOp, TensorPass}, + ops::TensorOp, shape::{Shape, TensorDimension}, TensorCpu, TensorError, TensorGpu, TensorInit, TensorInto, TensorReshape, TensorShape, }, @@ -339,7 +339,7 @@ impl Loader { .reshape(Auto, Dimension(1), Dimension(1), Dimension(1))? .transfer_into(context); - let mut encoder = context.device.create_command_encoder(&Default::default()); + let mut ops = vec![]; for lora in self.lora_vectors(name).await? { let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; @@ -353,10 +353,10 @@ impl Loader { )?; let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } - self.context.queue.submit(Some(encoder.finish())); + + context.queue.submit(context.encode(&TensorOp::List(ops))); Ok(tensor) } @@ -373,7 +373,7 @@ impl Loader { .reshape(Auto, Dimension(1), Dimension(1), Dimension(1))? 
.transfer_into(context); - let mut encoder = context.device.create_command_encoder(&Default::default()); + let mut ops = vec![]; for lora in self.lora_vectors(name).await? { let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; @@ -387,16 +387,13 @@ impl Loader { )?; let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } let op = TensorOp::opposite_exp(&tensor)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); - drop(pass); + ops.push(op); - self.context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&TensorOp::List(ops))); Ok(tensor) } @@ -414,7 +411,7 @@ impl Loader { .reshape(Auto, Dimension(1), Dimension(1), Dimension(1))? .transfer_into(context); - let mut encoder = context.device.create_command_encoder(&Default::default()); + let mut ops = vec![]; for lora in self.lora_vectors(name).await? { let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; @@ -428,16 +425,13 @@ impl Loader { )?; let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } let op = TensorOp::stable_exp(&tensor)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); - drop(pass); + ops.push(op); - self.context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&TensorOp::List(ops))); Ok(tensor) } @@ -460,7 +454,7 @@ impl Loader { .transfer_into(context); let tensor_f16: TensorGpu = context.tensor_init(tensor_f32.shape()); - let mut encoder = context.device.create_command_encoder(&Default::default()); + let mut ops = vec![]; for lora in lora { let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; @@ -474,19 +468,16 @@ impl Loader { )?; let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } let op = TensorOp::blit( tensor_f32.view(.., .., .., ..)?, tensor_f16.view(.., .., .., ..)?, )?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); - drop(pass); + ops.push(op); - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&TensorOp::List(ops))); tensor_f16 }; Ok(tensor) @@ -500,7 +491,7 @@ impl Loader { let tensor = self.model.tensor(name.as_ref()).await?; let tensor: TensorGpu<_, _> = TensorCpu::from_reader(tensor)?.transfer_into(context); - let mut encoder = context.device.create_command_encoder(&Default::default()); + let mut ops = vec![]; for lora in self.lora_matrices(name.as_ref()).await? { let factor = vec![lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; @@ -510,18 +501,16 @@ impl Loader { lora.y.view(.., .., .., ..)?, tensor.view(.., .., .., ..)?, )?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } for lora in self.lora_vectors(name.as_ref()).await? 
{ let factor = vec![lora.alpha, 1.0, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&TensorOp::List(ops))); Ok(tensor) } @@ -536,7 +525,7 @@ impl Loader { .map(|x| f16::from_f32(discount * x.to_f32())) .transfer_into(context); - let mut encoder = context.device.create_command_encoder(&Default::default()); + let mut ops = vec![]; for lora in self.lora_matrices(name.as_ref()).await? { let factor = vec![discount * lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; @@ -546,18 +535,16 @@ impl Loader { lora.y.view(.., .., .., ..)?, tensor.view(.., .., .., ..)?, )?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } for lora in self.lora_vectors(name.as_ref()).await? { let factor = vec![discount * lora.alpha, 1.0, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&TensorOp::List(ops))); Ok(tensor) } @@ -571,7 +558,7 @@ impl Loader { let tensor = TensorCpu::from_reader(tensor)?; matrix.load(&tensor)?; - let mut encoder = context.device.create_command_encoder(&Default::default()); + let mut ops = vec![]; for lora in self.lora_matrices(name.as_ref()).await? { let factor = vec![lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; @@ -581,18 +568,16 @@ impl Loader { lora.y.view(.., .., .., ..)?, matrix.view(.., .., .., ..)?, )?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } for lora in self.lora_vectors(name.as_ref()).await? { let factor = vec![lora.alpha, 1.0, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; let op = TensorOp::blend(&factor, &lora.tensor, matrix)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&TensorOp::List(ops))); Ok(()) } @@ -611,7 +596,7 @@ impl Loader { .reshape(Full, Full, Dimension(1), Dimension(1))?; matrix.load(&tensor)?; - let mut encoder = context.device.create_command_encoder(&Default::default()); + let mut ops = vec![]; for lora in self.lora_matrices(name.as_ref()).await? { let factor = vec![discount * lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; @@ -621,18 +606,16 @@ impl Loader { lora.y.view(.., .., .., ..)?, matrix.view(.., .., .., ..)?, )?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } for lora in self.lora_vectors(name.as_ref()).await? 
{ let factor = vec![discount * lora.alpha, 1.0, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; let op = TensorOp::blend(&factor, &lora.tensor, matrix)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&TensorOp::List(ops))); Ok(()) } @@ -648,20 +631,16 @@ impl Loader { Ok(tensor) } else { let tensor = TensorCpu::from_reader((dt, shape, tensor))?.transfer_into(context); - let mut encoder = context.device.create_command_encoder(&Default::default()); + let mut ops = vec![]; for lora in lora { let factor = vec![lora.alpha, 1.0, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } - let map = context.tensor_init(tensor.shape()); - encoder.copy_tensor(&tensor, &map)?; - - context.queue.submit(Some(encoder.finish())); - Ok(map.back().await) + context.queue.submit(context.encode(&TensorOp::List(ops))); + Ok(tensor.back().await) } } diff --git a/src/model/softmax.rs b/src/model/softmax.rs index 4c7b14a..2657356 100644 --- a/src/model/softmax.rs +++ b/src/model/softmax.rs @@ -7,10 +7,8 @@ use super::{ModelBase, ModelInfo, ModelOutput}; use crate::{ context::Context, tensor::{ - kind::ReadWrite, - ops::{TensorOp, TensorPass}, - shape::Shape, - TensorCpu, TensorError, TensorGpu, TensorInit, TensorShape, + kind::ReadWrite, ops::TensorOp, shape::Shape, TensorCpu, TensorError, TensorGpu, + TensorInit, TensorShape, }, }; @@ -75,16 +73,7 @@ impl ModelSoftmax for M { softmax.buffer.load(&input)?; let op = TensorOp::softmax(&softmax.buffer)?; - - let mut encoder = self - .context() - .device - .create_command_encoder(&Default::default()); - - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); - drop(pass); - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&op)); let output = softmax.buffer.back().await; Ok(redirect diff --git a/src/model/v4.rs b/src/model/v4.rs index afbb257..7a8218c 100644 --- a/src/model/v4.rs +++ b/src/model/v4.rs @@ -19,7 +19,7 @@ use crate::{ tensor::{ kind::ReadWrite, matrix::Matrix, - ops::{Activation, TensorCommand, TensorOp, TensorPass}, + ops::{Activation, TensorCommand, TensorOp}, shape::Shape, DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorGpuView, TensorShape, @@ -106,6 +106,9 @@ pub struct Runtime { pub cursors: TensorGpu, pub input: TensorGpu, + pub x: TensorGpu, + pub aux_x: TensorGpu, + pub att_x: TensorGpu, pub att_kx: TensorGpu, pub att_vx: TensorGpu, @@ -121,8 +124,6 @@ pub struct Runtime { pub ffn_k: TensorGpu, pub ffn_v: TensorGpu, pub ffn_r: TensorGpu, - - pub aux_x: TensorGpu, } impl Runtime { @@ -136,6 +137,8 @@ impl Runtime { tokens: context.tensor_init(tokens_shape), cursors: context.tensor_init(cursors_shape), input: context.tensor_init(shape), + x: context.tensor_init(shape), + aux_x: context.tensor_init(shape), att_x: context.tensor_init(shape), att_kx: context.tensor_init(shape), att_vx: context.tensor_init(shape), @@ -150,7 +153,6 @@ impl Runtime { ffn_k: context.tensor_init(hidden_shape), ffn_v: context.tensor_init(shape), ffn_r: context.tensor_init(shape), - aux_x: context.tensor_init(shape), } } } @@ -322,18 +324,11 @@ impl super::ModelState for ModelState { to_batch: usize, ) -> 
Result<(), TensorError> { let context = self.context(); - let mut encoder = context.device.create_command_encoder(&Default::default()); - let op = TensorOp::blit( self.view(.., .., from_batch, ..)?, other.view(.., .., to_batch, ..)?, )?; - - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); - drop(pass); - - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&op)); Ok(()) } } @@ -626,7 +621,7 @@ impl> ModelRunInternal for Model { // collect and group copy operations let (head_ops, head_x) = if num_token == 1 || num_token == num_header { - (TensorOp::empty(), &buffer.ffn_x) + (TensorOp::empty(), &buffer.x) } else { let mut start = 0; let mut end = 1; @@ -637,7 +632,7 @@ impl> ModelRunInternal for Model { let last = headers[end - 1]; assert_eq!(last - first + 1, end - start); - let input = buffer.ffn_x.view(.., first..=last, .., ..)?; + let input = buffer.x.view(.., first..=last, .., ..)?; let output = header.head_x.view(.., start..end, .., ..)?; ops.push(TensorOp::blit(input, output)?); @@ -679,20 +674,19 @@ impl> ModelRunInternal for Model { &buffer.input, Self::LN_EPS, )?, + TensorOp::blit( + buffer.input.view(.., .., .., ..)?, + buffer.x.view(.., .., .., ..)?, + )?, hook_op(Hook::PostEmbedLayerNorm)?, ]); - let mut encoder = context.device.create_command_encoder(&Default::default()); - - let ops = TensorOp::List(ops); - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&ops); - drop(pass); - for (index, layer) in tensor.layers.iter().enumerate() { - encoder.copy_tensor(&buffer.input, &buffer.att_x)?; - - let ops = TensorOp::List(vec![ + ops.append(&mut vec![ + TensorOp::blit( + buffer.x.view(.., .., .., ..)?, + buffer.att_x.view(.., .., .., ..)?, + )?, hook_op(Hook::PreAtt(index))?, TensorOp::layer_norm( &layer.att_layer_norm.w, @@ -776,19 +770,17 @@ impl> ModelRunInternal for Model { )?, hook_op(Hook::PostAttOut(index))?, TensorOp::add( - buffer.input.view(.., .., .., ..)?, buffer.att_o.view(.., .., .., ..)?, + buffer.x.view(.., .., .., ..)?, )?, hook_op(Hook::PostAtt(index))?, ]); - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&ops); - drop(pass); - - encoder.copy_tensor(&buffer.att_o, &buffer.ffn_x)?; - - let ops = TensorOp::List(vec![ + ops.append(&mut vec![ + TensorOp::blit( + buffer.x.view(.., .., .., ..)?, + buffer.ffn_x.view(.., .., .., ..)?, + )?, hook_op(Hook::PreFfn(index))?, TensorOp::layer_norm( &layer.ffn_layer_norm.w, @@ -846,30 +838,20 @@ impl> ModelRunInternal for Model { )?, hook_op(Hook::PostFfnChannelMix(index))?, TensorOp::add( - buffer.att_o.view(.., .., .., ..)?, buffer.ffn_x.view(.., .., .., ..)?, + buffer.x.view(.., .., .., ..)?, )?, hook_op(Hook::PostFfn(index))?, ]); - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&ops); - drop(pass); - if (index + 1) % RESCALE_LAYER == 0 { - let op = TensorOp::discount(&buffer.ffn_x, 0.5, 0.0)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); - drop(pass); - } - - if index != self.info.num_layer - 1 { - encoder.copy_tensor(&buffer.ffn_x, &buffer.input)?; + ops.push(TensorOp::discount(&buffer.x, 0.5, 0.0)?); } } if num_header > 0 { - let ops = TensorOp::List(vec![ + ops.append(&mut vec![ + head_ops, hook_op(Hook::PreHead)?, TensorOp::layer_norm( &tensor.head.layer_norm.w, @@ -886,14 +868,9 @@ impl> ModelRunInternal for Model { )?, hook_op(Hook::PostHead)?, ]); - - let mut pass = 
encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&head_ops); - pass.execute_tensor_op(&ops); - drop(pass); } - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&TensorOp::List(ops))); Ok((header.head_o.clone(), redirect)) } } diff --git a/src/model/v5.rs b/src/model/v5.rs index e72c5b3..e64ae33 100644 --- a/src/model/v5.rs +++ b/src/model/v5.rs @@ -19,7 +19,7 @@ use crate::{ tensor::{ kind::ReadWrite, matrix::Matrix, - ops::{Activation, TensorCommand, TensorOp, TensorPass}, + ops::{Activation, TensorCommand, TensorOp}, shape::{Shape, TensorDimension}, DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorGpuView, TensorReshape, TensorShape, @@ -110,6 +110,9 @@ pub struct Runtime { pub cursors: TensorGpu, pub input: TensorGpu, + pub x: TensorGpu, + pub aux_x: TensorGpu, + pub att_x: TensorGpu, pub att_kx: TensorGpu, pub att_vx: TensorGpu, @@ -127,8 +130,6 @@ pub struct Runtime { pub ffn_k: TensorGpu, pub ffn_v: TensorGpu, pub ffn_r: TensorGpu, - - pub aux_x: TensorGpu, } impl Runtime { @@ -142,6 +143,8 @@ impl Runtime { tokens: context.tensor_init(tokens_shape), cursors: context.tensor_init(cursors_shape), input: context.tensor_init(shape), + x: context.tensor_init(shape), + aux_x: context.tensor_init(shape), att_x: context.tensor_init(shape), att_kx: context.tensor_init(shape), att_vx: context.tensor_init(shape), @@ -158,7 +161,6 @@ impl Runtime { ffn_k: context.tensor_init(hidden_shape), ffn_v: context.tensor_init(shape), ffn_r: context.tensor_init(shape), - aux_x: context.tensor_init(shape), } } } @@ -381,17 +383,11 @@ impl super::ModelState for ModelState { ) -> Result<(), TensorError> { for (state, other) in self.state.iter().zip(other.state.iter()) { let context = state.context(); - let mut encoder = context.device.create_command_encoder(&Default::default()); - let op = TensorOp::blit( state.view(.., .., from_batch, ..)?, other.view(.., .., to_batch, ..)?, )?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); - drop(pass); - - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&op)); } Ok(()) } @@ -716,7 +712,7 @@ impl ModelRunInternal for Model { // collect and group copy operations let (head_ops, head_x) = if num_token == 1 || num_token == num_header { - (TensorOp::empty(), &buffer.ffn_x) + (TensorOp::empty(), &buffer.x) } else { let mut start = 0; let mut end = 1; @@ -727,7 +723,7 @@ impl ModelRunInternal for Model { let last = headers[end - 1]; assert_eq!(last - first + 1, end - start); - let input = buffer.ffn_x.view(.., first..=last, .., ..)?; + let input = buffer.x.view(.., first..=last, .., ..)?; let output = header.head_x.view(.., start..end, .., ..)?; ops.push(TensorOp::blit(input, output)?); @@ -769,16 +765,13 @@ impl ModelRunInternal for Model { &buffer.input, Self::LN_EPS, )?, + TensorOp::blit( + buffer.input.view(.., .., .., ..)?, + buffer.x.view(.., .., .., ..)?, + )?, hook_op(Hook::PostEmbedLayerNorm)?, ]); - let mut encoder = context.device.create_command_encoder(&Default::default()); - - let ops = TensorOp::List(ops); - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&ops); - drop(pass); - for (index, layer) in tensor.layers.iter().enumerate() { use TensorDimension::{Auto, Dimension}; let time_first = layer.att.time_first.reshape( @@ -818,9 +811,11 @@ impl ModelRunInternal for Model { Dimension(1), )?; - encoder.copy_tensor(&buffer.input, &buffer.att_x)?; - - let ops = 
TensorOp::List(vec![ + ops.append(&mut vec![ + TensorOp::blit( + buffer.x.view(.., .., .., ..)?, + buffer.att_x.view(.., .., .., ..)?, + )?, hook_op(Hook::PreAtt(index))?, TensorOp::layer_norm( &layer.att_layer_norm.w, @@ -927,19 +922,17 @@ impl ModelRunInternal for Model { )?, hook_op(Hook::PostAttOut(index))?, TensorOp::add( - buffer.input.view(.., .., .., ..)?, buffer.att_o.view(.., .., .., ..)?, + buffer.x.view(.., .., .., ..)?, )?, hook_op(Hook::PostAtt(index))?, ]); - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&ops); - drop(pass); - - encoder.copy_tensor(&buffer.att_o, &buffer.ffn_x)?; - - let ops = TensorOp::List(vec![ + ops.append(&mut vec![ + TensorOp::blit( + buffer.x.view(.., .., .., ..)?, + buffer.ffn_x.view(.., .., .., ..)?, + )?, hook_op(Hook::PreFfn(index))?, TensorOp::layer_norm( &layer.ffn_layer_norm.w, @@ -997,30 +990,20 @@ impl ModelRunInternal for Model { )?, hook_op(Hook::PostFfnChannelMix(index))?, TensorOp::add( - buffer.att_o.view(.., .., .., ..)?, buffer.ffn_x.view(.., .., .., ..)?, + buffer.x.view(.., .., .., ..)?, )?, hook_op(Hook::PostFfn(index))?, ]); - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&ops); - drop(pass); - if (index + 1) % RESCALE_LAYER == 0 { - let op = TensorOp::discount(&buffer.ffn_x, 0.5, 0.0)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); - drop(pass); - } - - if index != self.info.num_layer - 1 { - encoder.copy_tensor(&buffer.ffn_x, &buffer.input)?; + ops.push(TensorOp::discount(&buffer.x, 0.5, 0.0)?); } } if num_header > 0 { - let ops = TensorOp::List(vec![ + ops.append(&mut vec![ + head_ops, hook_op(Hook::PreHead)?, TensorOp::layer_norm( &tensor.head.layer_norm.w, @@ -1037,14 +1020,9 @@ impl ModelRunInternal for Model { )?, hook_op(Hook::PostHead)?, ]); - - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&head_ops); - pass.execute_tensor_op(&ops); - drop(pass); } - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&TensorOp::List(ops))); Ok((header.head_o.clone(), redirect)) } } diff --git a/src/model/v6.rs b/src/model/v6.rs index 582c0a5..cf88185 100644 --- a/src/model/v6.rs +++ b/src/model/v6.rs @@ -19,7 +19,7 @@ use crate::{ tensor::{ kind::ReadWrite, matrix::Matrix, - ops::{Activation, TensorCommand, TensorOp, TensorPass}, + ops::{Activation, TensorCommand, TensorOp}, shape::{Shape, TensorDimension}, DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorGpuView, TensorReshape, TensorShape, @@ -117,6 +117,9 @@ pub struct Runtime { pub cursors: TensorGpu, pub input: TensorGpu, + pub x: TensorGpu, + pub aux_x: TensorGpu, + pub att_x: TensorGpu, pub att_xx: TensorGpu, /// Token shifted time decay input, `[C, T]`. 
@@ -147,8 +150,6 @@ pub struct Runtime { pub ffn_k: TensorGpu, pub ffn_v: TensorGpu, pub ffn_r: TensorGpu, - - pub aux_x: TensorGpu, } impl Runtime { @@ -166,6 +167,8 @@ impl Runtime { tokens: context.tensor_init(tokens_shape), cursors: context.tensor_init(cursors_shape), input: context.tensor_init(shape), + x: context.tensor_init(shape), + aux_x: context.tensor_init(shape), att_x: context.tensor_init(shape), att_xx: context.tensor_init(shape), att_wx: context.tensor_init(shape), @@ -189,7 +192,6 @@ impl Runtime { ffn_k: context.tensor_init(hidden_shape), ffn_v: context.tensor_init(shape), ffn_r: context.tensor_init(shape), - aux_x: context.tensor_init(shape), } } } @@ -422,17 +424,11 @@ impl super::ModelState for ModelState { ) -> Result<(), TensorError> { for (state, other) in self.state.iter().zip(other.state.iter()) { let context = state.context(); - let mut encoder = context.device.create_command_encoder(&Default::default()); - let op = TensorOp::blit( state.view(.., .., from_batch, ..)?, other.view(.., .., to_batch, ..)?, )?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); - drop(pass); - - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&op)); } Ok(()) } @@ -773,7 +769,7 @@ impl ModelRunInternal for Model { // collect and group copy operations let (head_ops, head_x) = if num_token == 1 || num_token == num_header { - (TensorOp::empty(), &buffer.ffn_x) + (TensorOp::empty(), &buffer.x) } else { let mut start = 0; let mut end = 1; @@ -784,7 +780,7 @@ impl ModelRunInternal for Model { let last = headers[end - 1]; assert_eq!(last - first + 1, end - start); - let input = buffer.ffn_x.view(.., first..=last, .., ..)?; + let input = buffer.x.view(.., first..=last, .., ..)?; let output = header.head_x.view(.., start..end, .., ..)?; ops.push(TensorOp::blit(input, output)?); @@ -826,16 +822,13 @@ impl ModelRunInternal for Model { &buffer.input, Self::LN_EPS, )?, + TensorOp::blit( + buffer.input.view(.., .., .., ..)?, + buffer.x.view(.., .., .., ..)?, + )?, hook_op(Hook::PostEmbedLayerNorm)?, ]); - let mut encoder = context.device.create_command_encoder(&Default::default()); - - let ops = TensorOp::List(ops); - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&ops); - drop(pass); - for (index, layer) in tensor.layers.iter().enumerate() { use TensorDimension::{Auto, Dimension}; let time_first = layer.att.time_first.reshape( @@ -881,9 +874,11 @@ impl ModelRunInternal for Model { Dimension(1), )?; - encoder.copy_tensor(&buffer.input, &buffer.att_x)?; - - let ops = TensorOp::List(vec![ + ops.append(&mut vec![ + TensorOp::blit( + buffer.x.view(.., .., .., ..)?, + buffer.att_x.view(.., .., .., ..)?, + )?, hook_op(Hook::PreAtt(index))?, TensorOp::layer_norm( &layer.att_layer_norm.w, @@ -1069,19 +1064,17 @@ impl ModelRunInternal for Model { )?, hook_op(Hook::PostAttOut(index))?, TensorOp::add( - buffer.input.view(.., .., .., ..)?, buffer.att_o.view(.., .., .., ..)?, + buffer.x.view(.., .., .., ..)?, )?, hook_op(Hook::PostAtt(index))?, ]); - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&ops); - drop(pass); - - encoder.copy_tensor(&buffer.att_o, &buffer.ffn_x)?; - - let ops = TensorOp::List(vec![ + ops.append(&mut vec![ + TensorOp::blit( + buffer.x.view(.., .., .., ..)?, + buffer.ffn_x.view(.., .., .., ..)?, + )?, hook_op(Hook::PreFfn(index))?, TensorOp::layer_norm( &layer.ffn_layer_norm.w, @@ -1139,30 +1132,20 @@ impl ModelRunInternal for 
Model { )?, hook_op(Hook::PostFfnChannelMix(index))?, TensorOp::add( - buffer.att_o.view(.., .., .., ..)?, buffer.ffn_x.view(.., .., .., ..)?, + buffer.x.view(.., .., .., ..)?, )?, hook_op(Hook::PostFfn(index))?, ]); - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&ops); - drop(pass); - if (index + 1) % RESCALE_LAYER == 0 { - let op = TensorOp::discount(&buffer.ffn_x, 0.5, 0.0)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); - drop(pass); - } - - if index != self.info.num_layer - 1 { - encoder.copy_tensor(&buffer.ffn_x, &buffer.input)?; + ops.push(TensorOp::discount(&buffer.x, 0.5, 0.0)?); } } if num_header > 0 { - let ops = TensorOp::List(vec![ + ops.append(&mut vec![ + head_ops, hook_op(Hook::PreHead)?, TensorOp::layer_norm( &tensor.head.layer_norm.w, @@ -1179,14 +1162,9 @@ impl ModelRunInternal for Model { )?, hook_op(Hook::PostHead)?, ]); - - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&head_ops); - pass.execute_tensor_op(&ops); - drop(pass); } - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&TensorOp::List(ops))); Ok((header.head_o.clone(), redirect)) } } diff --git a/src/runtime/infer.rs b/src/runtime/infer.rs index ef8b127..be6ba74 100644 --- a/src/runtime/infer.rs +++ b/src/runtime/infer.rs @@ -5,6 +5,7 @@ use super::{JobInfo, JobInput}; use crate::tensor::TensorCpu; pub const MIN_TOKEN_CHUNK_SIZE: usize = 32; +pub const NUM_LAYER_CHUNK: usize = 4; #[derive(Debug, Clone, Deref, DerefMut, PartialEq, Eq)] pub struct InferInfo(pub Vec); diff --git a/src/runtime/loader.rs b/src/runtime/loader.rs index bb7f00b..f73edb5 100644 --- a/src/runtime/loader.rs +++ b/src/runtime/loader.rs @@ -14,7 +14,7 @@ use crate::{ tensor::{ kind::ReadWrite, matrix::Matrix, - ops::{TensorCommand, TensorOp, TensorPass}, + ops::TensorOp, shape::{Shape, TensorDimension}, TensorCpu, TensorError, TensorGpu, TensorInit, TensorInto, TensorReshape, TensorShape, }, @@ -338,7 +338,7 @@ impl Loader { .reshape(Auto, Dimension(1), Dimension(1), Dimension(1))? .transfer_into(context); - let mut encoder = context.device.create_command_encoder(&Default::default()); + let mut ops = vec![]; for lora in self.lora_vectors(name).await? { let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; @@ -352,10 +352,10 @@ impl Loader { )?; let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } - self.context.queue.submit(Some(encoder.finish())); + + context.queue.submit(context.encode(&TensorOp::List(ops))); Ok(tensor) } @@ -372,7 +372,7 @@ impl Loader { .reshape(Auto, Dimension(1), Dimension(1), Dimension(1))? .transfer_into(context); - let mut encoder = context.device.create_command_encoder(&Default::default()); + let mut ops = vec![]; for lora in self.lora_vectors(name).await? 
{ let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; @@ -386,16 +386,13 @@ impl Loader { )?; let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } let op = TensorOp::opposite_exp(&tensor)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); - drop(pass); + ops.push(op); - self.context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&TensorOp::List(ops))); Ok(tensor) } @@ -413,7 +410,7 @@ impl Loader { .reshape(Auto, Dimension(1), Dimension(1), Dimension(1))? .transfer_into(context); - let mut encoder = context.device.create_command_encoder(&Default::default()); + let mut ops = vec![]; for lora in self.lora_vectors(name).await? { let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; @@ -427,16 +424,13 @@ impl Loader { )?; let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } let op = TensorOp::stable_exp(&tensor)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); - drop(pass); + ops.push(op); - self.context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&TensorOp::List(ops))); Ok(tensor) } @@ -459,7 +453,7 @@ impl Loader { .transfer_into(context); let tensor_f16: TensorGpu = context.tensor_init(tensor_f32.shape()); - let mut encoder = context.device.create_command_encoder(&Default::default()); + let mut ops = vec![]; for lora in lora { let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; @@ -473,19 +467,16 @@ impl Loader { )?; let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } let op = TensorOp::blit( tensor_f32.view(.., .., .., ..)?, tensor_f16.view(.., .., .., ..)?, )?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); - drop(pass); + ops.push(op); - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&TensorOp::List(ops))); tensor_f16 }; Ok(tensor) @@ -499,7 +490,7 @@ impl Loader { let tensor = self.model.tensor(name.as_ref()).await?; let tensor: TensorGpu<_, _> = TensorCpu::from_reader(tensor)?.transfer_into(context); - let mut encoder = context.device.create_command_encoder(&Default::default()); + let mut ops = vec![]; for lora in self.lora_matrices(name.as_ref()).await? { let factor = vec![lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; @@ -509,18 +500,16 @@ impl Loader { lora.y.view(.., .., .., ..)?, tensor.view(.., .., .., ..)?, )?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } for lora in self.lora_vectors(name.as_ref()).await? 
{ let factor = vec![lora.alpha, 1.0, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&TensorOp::List(ops))); Ok(tensor) } @@ -535,7 +524,7 @@ impl Loader { .map(|x| f16::from_f32(discount * x.to_f32())) .transfer_into(context); - let mut encoder = context.device.create_command_encoder(&Default::default()); + let mut ops = vec![]; for lora in self.lora_matrices(name.as_ref()).await? { let factor = vec![discount * lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; @@ -545,18 +534,16 @@ impl Loader { lora.y.view(.., .., .., ..)?, tensor.view(.., .., .., ..)?, )?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } for lora in self.lora_vectors(name.as_ref()).await? { let factor = vec![discount * lora.alpha, 1.0, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&TensorOp::List(ops))); Ok(tensor) } @@ -570,7 +557,7 @@ impl Loader { let tensor = TensorCpu::from_reader(tensor)?; matrix.load(&tensor)?; - let mut encoder = context.device.create_command_encoder(&Default::default()); + let mut ops = vec![]; for lora in self.lora_matrices(name.as_ref()).await? { let factor = vec![lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; @@ -580,18 +567,16 @@ impl Loader { lora.y.view(.., .., .., ..)?, matrix.view(.., .., .., ..)?, )?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } for lora in self.lora_vectors(name.as_ref()).await? { let factor = vec![lora.alpha, 1.0, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; let op = TensorOp::blend(&factor, &lora.tensor, matrix)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&TensorOp::List(ops))); Ok(()) } @@ -610,7 +595,7 @@ impl Loader { .reshape(Full, Full, Dimension(1), Dimension(1))?; matrix.load(&tensor)?; - let mut encoder = context.device.create_command_encoder(&Default::default()); + let mut ops = vec![]; for lora in self.lora_matrices(name.as_ref()).await? { let factor = vec![discount * lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; @@ -620,18 +605,16 @@ impl Loader { lora.y.view(.., .., .., ..)?, matrix.view(.., .., .., ..)?, )?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } for lora in self.lora_vectors(name.as_ref()).await? 
{ let factor = vec![discount * lora.alpha, 1.0, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; let op = TensorOp::blend(&factor, &lora.tensor, matrix)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&TensorOp::List(ops))); Ok(()) } @@ -647,20 +630,15 @@ impl Loader { Ok(tensor) } else { let tensor = TensorCpu::from_reader((dt, shape, tensor))?.transfer_into(context); - let mut encoder = context.device.create_command_encoder(&Default::default()); + let mut ops = vec![]; for lora in lora { let factor = vec![lora.alpha, 1.0, 0.0, 0.0]; let factor = context.tensor_from_data([4, 1, 1, 1], factor)?; let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?; - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); + ops.push(op); } - - let map = context.tensor_init(tensor.shape()); - encoder.copy_tensor(&tensor, &map)?; - - context.queue.submit(Some(encoder.finish())); - Ok(map.back().await) + context.queue.submit(context.encode(&TensorOp::List(ops))); + Ok(tensor.back().await) } } diff --git a/src/runtime/softmax.rs b/src/runtime/softmax.rs index afcbfbb..b6eb622 100644 --- a/src/runtime/softmax.rs +++ b/src/runtime/softmax.rs @@ -1,10 +1,7 @@ use crate::{ context::Context, num::Float, - tensor::{ - ops::{TensorOp, TensorPass}, - TensorCpu, TensorError, TensorGpu, TensorInto, - }, + tensor::{ops::TensorOp, TensorCpu, TensorError, TensorGpu, TensorInto}, }; pub async fn softmax_one( @@ -17,12 +14,7 @@ pub async fn softmax_one( let tensor: TensorGpu<_, _> = input.transfer_into(context); let op = TensorOp::softmax(&tensor)?; - - let mut encoder = context.device.create_command_encoder(&Default::default()); - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); - drop(pass); - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&op)); let output = tensor.back().await; Ok(output) @@ -35,7 +27,6 @@ pub async fn softmax( let mut tensors = Vec::with_capacity(input.len()); let mut ops = Vec::with_capacity(input.len()); - let mut encoder = context.device.create_command_encoder(&Default::default()); for input in input.into_iter() { let tensor: TensorGpu<_, _> = input.transfer_into(context); if tensor.size() > 0 { @@ -43,12 +34,7 @@ pub async fn softmax( } tensors.push(tensor); } - - let ops = TensorOp::List(ops); - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&ops); - drop(pass); - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&TensorOp::List(ops))); let mut output = Vec::with_capacity(tensors.len()); for tensor in tensors.into_iter() { diff --git a/src/runtime/v4.rs b/src/runtime/v4.rs index 9988160..da115f4 100644 --- a/src/runtime/v4.rs +++ b/src/runtime/v4.rs @@ -9,9 +9,7 @@ use web_rwkv_derive::DeserializeSeed; use wgpu::CommandBuffer; use super::{ - infer::{ - InferChunk, InferInfo, InferOutput, InferOutputBatch, InferRedirect, MIN_TOKEN_CHUNK_SIZE, - }, + infer::{InferChunk, InferInfo, InferOutput, InferOutputBatch, InferRedirect}, loader::{Loader, Reader}, model::{AsAny, Build, EmbedDevice, ModelBuilder, ModelInfo, Quant, State as _}, Job, JobBuilder, @@ -22,7 +20,7 @@ use crate::{ tensor::{ kind::ReadWrite, matrix::Matrix, - ops::{Activation, TensorCommand, TensorOp, TensorPass}, + ops::{Activation, TensorCommand, TensorOp}, 
shape::Shape, DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorGpuView, TensorInit, TensorShape, TensorStack, @@ -457,7 +455,7 @@ impl ModelRuntime { } fn turbo(num_token: usize) -> bool { - num_token % MIN_TOKEN_CHUNK_SIZE == 0 + num_token % super::infer::MIN_TOKEN_CHUNK_SIZE == 0 } fn hook_op( @@ -579,6 +577,10 @@ impl JobBuilder for ModelRuntime { let op = build_layer(hooks, frame, layer, index, num_token)?; ops.push(op); + + if (index + 1) % (info.num_layer / super::infer::NUM_LAYER_CHUNK) == 0 { + ops.push(TensorOp::Sep); + } } { @@ -593,16 +595,11 @@ impl JobBuilder for ModelRuntime { ops.push(op); } - let mut encoder = context.device.create_command_encoder(&Default::default()); - { + let commands = { #[cfg(feature = "trace")] let _span = tracing::trace_span!("encode").entered(); - let op = TensorOp::List(ops); - - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); - } - let commands = vec![encoder.finish()]; + context.encode(&TensorOp::List(ops)) + }; Ok(InferJob { commands, diff --git a/src/runtime/v5.rs b/src/runtime/v5.rs index e5394b3..18475f1 100644 --- a/src/runtime/v5.rs +++ b/src/runtime/v5.rs @@ -9,9 +9,7 @@ use web_rwkv_derive::DeserializeSeed; use wgpu::CommandBuffer; use super::{ - infer::{ - InferChunk, InferInfo, InferOutput, InferOutputBatch, InferRedirect, MIN_TOKEN_CHUNK_SIZE, - }, + infer::{InferChunk, InferInfo, InferOutput, InferOutputBatch, InferRedirect}, loader::{Loader, Reader}, model::{AsAny, Build, EmbedDevice, ModelBuilder, ModelInfo, Quant, State as _}, Job, JobBuilder, @@ -22,7 +20,7 @@ use crate::{ tensor::{ kind::ReadWrite, matrix::Matrix, - ops::{Activation, TensorCommand, TensorOp, TensorPass}, + ops::{Activation, TensorCommand, TensorOp}, shape::{Shape, TensorDimension}, DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorGpuView, TensorInit, TensorInto, TensorReshape, TensorShape, TensorStack, @@ -459,7 +457,7 @@ impl super::model::ModelRuntime for ModelRuntime { } fn turbo(num_token: usize) -> bool { - num_token % MIN_TOKEN_CHUNK_SIZE == 0 + num_token % super::infer::MIN_TOKEN_CHUNK_SIZE == 0 } fn hook_op( @@ -582,6 +580,10 @@ impl JobBuilder for ModelRuntime { let op = build_layer(hooks, frame, layer, index, num_token, head_size)?; ops.push(op); + + if (index + 1) % (info.num_layer / super::infer::NUM_LAYER_CHUNK) == 0 { + ops.push(TensorOp::Sep); + } } { @@ -596,16 +598,11 @@ impl JobBuilder for ModelRuntime { ops.push(op); } - let mut encoder = context.device.create_command_encoder(&Default::default()); - { + let commands = { #[cfg(feature = "trace")] let _span = tracing::trace_span!("encode").entered(); - let op = TensorOp::List(ops); - - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); - } - let commands = vec![encoder.finish()]; + context.encode(&TensorOp::List(ops)) + }; Ok(InferJob { commands, @@ -1071,8 +1068,7 @@ pub async fn read_state( let head_size = info.num_emb / info.num_head; let data: TensorGpu = context.zeros([info.num_emb, head_size + 2, info.num_layer, 1]); - let mut encoder = context.device.create_command_encoder(&Default::default()); - + let mut ops = vec![]; for layer in 0..info.num_layer { let matrix = loader .load_matrix_f16(format!("blocks.{layer}.att.time_state")) @@ -1084,19 +1080,15 @@ pub async fn read_state( Dimension(1), Auto, )?; - let ops = vec![ + ops.append(&mut vec![ TensorOp::transpose(matrix.view(.., .., .., ..)?, state.view(.., .., .., ..)?)?, TensorOp::blit( reshaped.view(.., 
.., .., ..)?, data.view(.., 1..head_size + 1, layer, ..)?, )?, - ]; - let ops = TensorOp::List(ops); - - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&ops); + ]); } + context.queue.submit(context.encode(&TensorOp::List(ops))); - context.queue.submit(Some(encoder.finish())); Ok(data.back().await) } diff --git a/src/runtime/v6.rs b/src/runtime/v6.rs index bd48673..f9d405f 100644 --- a/src/runtime/v6.rs +++ b/src/runtime/v6.rs @@ -9,9 +9,7 @@ use web_rwkv_derive::DeserializeSeed; use wgpu::CommandBuffer; use super::{ - infer::{ - InferChunk, InferInfo, InferOutput, InferOutputBatch, InferRedirect, MIN_TOKEN_CHUNK_SIZE, - }, + infer::{InferChunk, InferInfo, InferOutput, InferOutputBatch, InferRedirect}, loader::{Loader, Reader}, model::{AsAny, Build, EmbedDevice, ModelBuilder, ModelInfo, Quant, State as _}, Job, JobBuilder, @@ -22,7 +20,7 @@ use crate::{ tensor::{ kind::ReadWrite, matrix::Matrix, - ops::{Activation, TensorCommand, TensorOp, TensorPass}, + ops::{Activation, TensorCommand, TensorOp}, shape::{Shape, TensorDimension}, DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorGpuView, TensorInit, TensorInto, TensorReshape, TensorShape, TensorStack, @@ -489,7 +487,7 @@ impl super::model::ModelRuntime for ModelRuntime { } fn turbo(num_token: usize) -> bool { - num_token % MIN_TOKEN_CHUNK_SIZE == 0 + num_token % super::infer::MIN_TOKEN_CHUNK_SIZE == 0 } fn hook_op( @@ -612,6 +610,10 @@ impl JobBuilder for ModelRuntime { let op = build_layer(hooks, frame, layer, index, num_token, head_size)?; ops.push(op); + + if (index + 1) % (info.num_layer / super::infer::NUM_LAYER_CHUNK) == 0 { + ops.push(TensorOp::Sep); + } } { @@ -626,16 +628,11 @@ impl JobBuilder for ModelRuntime { ops.push(op); } - let mut encoder = context.device.create_command_encoder(&Default::default()); - { + let commands = { #[cfg(feature = "trace")] let _span = tracing::trace_span!("encode").entered(); - let op = TensorOp::List(ops); - - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&op); - } - let commands = vec![encoder.finish()]; + context.encode(&TensorOp::List(ops)) + }; Ok(InferJob { commands, @@ -1032,7 +1029,6 @@ impl Build for ModelBuilder { let time_mix_r = loader.load_vector_f16(format!("{att}.time_mix_r")).await?; let time_mix_g = loader.load_vector_f16(format!("{att}.time_mix_g")).await?; - let mut encoder = context.device.create_command_encoder(&Default::default()); let ops = TensorOp::List(vec![ TensorOp::blit( time_mix_w.view(.., .., .., ..)?, @@ -1055,12 +1051,7 @@ impl Build for ModelBuilder { time_mix.view(.., .., 4, ..)?, )?, ]); - - let mut pass = encoder.begin_compute_pass(&Default::default()); - pass.execute_tensor_op(&ops); - drop(pass); - - context.queue.submit(Some(encoder.finish())); + context.queue.submit(context.encode(&ops)); time_mix }; @@ -1183,8 +1174,7 @@ pub async fn read_state( let head_size = info.num_emb / info.num_head; let data: TensorGpu = context.zeros([info.num_emb, head_size + 2, info.num_layer, 1]); - let mut encoder = context.device.create_command_encoder(&Default::default()); - + let mut ops = vec![]; for layer in 0..info.num_layer { let matrix = loader .load_matrix_f16(format!("blocks.{layer}.att.time_state")) @@ -1196,19 +1186,15 @@ pub async fn read_state( Dimension(1), Auto, )?; - let ops = vec![ + ops.append(&mut vec![ TensorOp::transpose(matrix.view(.., .., .., ..)?, state.view(.., .., .., ..)?)?, TensorOp::blit( reshaped.view(.., .., .., ..)?, data.view(.., 1..head_size 
+ 1, layer, ..)?,
             )?,
-        ];
-        let ops = TensorOp::List(ops);
-
-        let mut pass = encoder.begin_compute_pass(&Default::default());
-        pass.execute_tensor_op(&ops);
+        ]);
     }
+    context.queue.submit(context.encode(&TensorOp::List(ops)));
 
-    context.queue.submit(Some(encoder.finish()));
     Ok(data.back().await)
 }
diff --git a/src/tensor/matrix.rs b/src/tensor/matrix.rs
index 876f038..0b14bd2 100644
--- a/src/tensor/matrix.rs
+++ b/src/tensor/matrix.rs
@@ -7,7 +7,7 @@ use crate::{
     num::Float,
     tensor::{
         kind::{ReadWrite, Uniform},
-        ops::{TensorOp, TensorPass},
+        ops::TensorOp,
         shape::Shape,
         TensorError, TensorGpu, TensorGpuView, TensorShape,
     },
@@ -115,14 +115,7 @@ impl Matrix {
         ));
 
         let op = TensorOp::quantize_mat_int8(matrix, &m, &w)?;
-
-        let mut encoder = context.device.create_command_encoder(&Default::default());
-
-        let mut pass = encoder.begin_compute_pass(&Default::default());
-        pass.execute_tensor_op(&op);
-        drop(pass);
-
-        context.queue.submit(Some(encoder.finish()));
+        context.queue.submit(context.encode(&op));
 
         Ok(Matrix::Int8 { w, m })
     }
@@ -144,14 +137,7 @@ impl Matrix {
         let m = context.tensor_init(absmax_shape);
 
         let op = TensorOp::quantize_mat_nf4(matrix, &q, &m, &w)?;
-
-        let mut encoder = context.device.create_command_encoder(&Default::default());
-
-        let mut pass = encoder.begin_compute_pass(&Default::default());
-        pass.execute_tensor_op(&op);
-        drop(pass);
-
-        context.queue.submit(Some(encoder.finish()));
+        context.queue.submit(context.encode(&op));
 
         Ok(Matrix::NF4 { w, q, m })
     }
diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs
index 36d144b..1ce227b 100644
--- a/src/tensor/mod.rs
+++ b/src/tensor/mod.rs
@@ -7,7 +7,6 @@ use wgpu::{BindingResource, Buffer, BufferBinding, BufferUsages};
 
 use self::{
     kind::{Kind, ReadWrite, Uniform},
-    ops::TensorCommand,
     shape::{IntoBytes, Shape, TensorAxis, TensorDimension, TensorSlice},
 };
 use crate::{
@@ -843,12 +842,12 @@ impl DeepClone for TensorGpu {
     fn deep_clone(&self) -> Self {
         let context = &self.context;
         let shape = self.shape;
-        let cloned = context.tensor_init(shape);
+        // `Shape::len` counts elements, not bytes, so take the copy size from the buffer itself.
+        let size = self.buffer.size();
+        let cloned: TensorGpu<_, _> = context.tensor_init(shape);
 
         let mut encoder = context.device.create_command_encoder(&Default::default());
-        encoder
-            .copy_tensor(self, &cloned)
-            .expect("tensor deep clone");
+        encoder.copy_buffer_to_buffer(&self.buffer, 0, &cloned.buffer, 0, size);
         context.queue.submit(Some(encoder.finish()));
 
         cloned
diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs
index 89a953f..7d59ae4 100644
--- a/src/tensor/ops.rs
+++ b/src/tensor/ops.rs
@@ -1,7 +1,9 @@
 use std::{hash::Hash, sync::Arc};
 
 use half::f16;
-use wgpu::{BindGroup, BindGroupDescriptor, BindGroupEntry, CommandEncoder, ComputePass};
+use wgpu::{
+    BindGroup, BindGroupDescriptor, BindGroupEntry, CommandBuffer, CommandEncoder, ComputePass,
+};
 
 use super::{
     kind::{Kind, ReadWrite, Uniform},
@@ -72,28 +74,71 @@ impl TensorCommand for CommandEncoder {
     }
 }
 
-pub trait TensorPass<'a> {
-    fn execute_tensor_op(&mut self, op: &'a TensorOp);
-}
+impl crate::context::Context {
+    pub fn encode(&self, op: &TensorOp) -> Vec<CommandBuffer> {
+        struct Atom<'a> {
+            pipeline: &'a CachedPipeline,
+            bindings: &'a [BindGroup],
+            dispatch: &'a [u32; 3],
+        }
 
-impl<'b, 'a: 'b> TensorPass<'a> for ComputePass<'b> {
-    fn execute_tensor_op(&mut self, op: &'a TensorOp) {
-        match op {
-            TensorOp::Atom {
+        fn dispatch<'b, 'a: 'b>(
+            pass: &'b mut ComputePass<'a>,
+            Atom {
                 pipeline,
                 bindings,
                 dispatch,
-            } => {
-                self.set_pipeline(&pipeline.pipeline);
-                for (index, bind_group) in bindings.iter().enumerate() {
-                    self.set_bind_group(index as u32, bind_group, &[])
-                }
-                self.dispatch_workgroups(dispatch[0], dispatch[1], dispatch[2]);
+            }: Atom<'a>,
+        ) {
+            pass.set_pipeline(&pipeline.pipeline);
+            for (index, bind) in bindings.iter().enumerate() {
+                pass.set_bind_group(index as u32, bind, &[]);
             }
-            TensorOp::List(ops) => {
-                ops.iter().for_each(|op| self.execute_tensor_op(op));
+            pass.dispatch_workgroups(dispatch[0], dispatch[1], dispatch[2]);
+        }
+
+        fn flatten<'b, 'a: 'b>(
+            commands: &'b mut Vec<Vec<Atom<'a>>>,
+            passes: &'b mut Vec<Atom<'a>>,
+            op: &'a TensorOp,
+        ) {
+            match op {
+                TensorOp::Atom {
+                    pipeline,
+                    bindings,
+                    dispatch,
+                } => passes.push(Atom {
+                    pipeline,
+                    bindings,
+                    dispatch,
+                }),
+                TensorOp::List(ops) => ops.iter().for_each(|op| flatten(commands, passes, op)),
+                TensorOp::Sep => {
+                    let mut temp = vec![];
+                    std::mem::swap(&mut temp, passes);
+                    commands.push(temp);
+                }
             }
         }
+
+        let mut commands = vec![];
+        let mut passes = vec![];
+        flatten(&mut commands, &mut passes, op);
+        commands.push(passes);
+
+        commands
+            .into_iter()
+            .filter(|atoms| !atoms.is_empty())
+            .map(|atoms| {
+                let mut encoder = self.device.create_command_encoder(&Default::default());
+                let mut pass = encoder.begin_compute_pass(&Default::default());
+                for atom in atoms {
+                    dispatch(&mut pass, atom);
+                }
+                drop(pass);
+                encoder.finish()
+            })
+            .collect()
     }
 }
 
@@ -197,6 +242,8 @@ pub enum TensorOp {
         dispatch: [u32; 3],
     },
     List(Vec<TensorOp>),
+    /// Separator: flushes the ops encoded so far into their own compute pass and command buffer.
+    Sep,
 }
 
 impl TensorOp {
@@ -2407,7 +2453,7 @@ mod tests {
     use wgpu::{Instance, PowerPreference};
     // use wgpu_profiler::GpuProfiler;
 
-    use super::{TensorOp, TensorPass};
+    use super::TensorOp;
     use crate::{
         context::{Context, ContextBuilder, InstanceExt},
         tensor::{ops::Activation, Shape, TensorGpu},
@@ -2451,11 +2497,7 @@ mod tests {
         let x_dev: TensorGpu<_, _> = context.tensor_from_data(shape, x.clone())?;
 
         let softmax = TensorOp::softmax(&x_dev)?;
-        let mut encoder = context.device.create_command_encoder(&Default::default());
-        let mut pass = encoder.begin_compute_pass(&Default::default());
-        pass.execute_tensor_op(&softmax);
-        drop(pass);
-        context.queue.submit(Some(encoder.finish()));
+        context.queue.submit(context.encode(&softmax));
 
         let x_host = x_dev.back_in_place().to_vec();
 
@@ -2517,12 +2559,7 @@ mod tests {
         // let s_dev = context.tensor_init(shape);
 
         let layer_norm = TensorOp::layer_norm(&w_dev, &b_dev, &x_dev, EPS)?;
-
-        let mut encoder = context.device.create_command_encoder(&Default::default());
-        let mut pass = encoder.begin_compute_pass(&Default::default());
-        pass.execute_tensor_op(&layer_norm);
-        drop(pass);
-        context.queue.submit(Some(encoder.finish()));
+        context.queue.submit(context.encode(&layer_norm));
 
         let x_host = x_dev.back_in_place().to_vec();
         // let s_host = s_dev.back_in_place().to_vec();
@@ -2534,12 +2571,7 @@ mod tests {
         let ops = TensorOp::List(vec![
             TensorOp::recenter(&x_dev)?,
             TensorOp::rms_norm(&w_dev, &b_dev, &x_dev, EPS)?,
         ]);
-
-        let mut encoder = context.device.create_command_encoder(&Default::default());
-        let mut pass = encoder.begin_compute_pass(&Default::default());
-        pass.execute_tensor_op(&ops);
-        drop(pass);
-        context.queue.submit(Some(encoder.finish()));
+        context.queue.submit(context.encode(&ops));
 
         let x_rms_host = x_dev.back_in_place().to_vec();
 
@@ -2660,14 +2692,8 @@ mod tests {
             )?,
         ]);
 
-        let mut encoder = context.device.create_command_encoder(&Default::default());
-
-        let mut pass = encoder.begin_compute_pass(&Default::default());
-        pass.execute_tensor_op(&ops);
-        drop(pass);
-
-        // profiler.resolve_queries(&mut encoder);
-        context.queue.submit(Some(encoder.finish()));
+        context.queue.submit(context.encode(&ops));
 
         let output_host = output_dev.back_in_place();
         let output_host = Vec::from(output_host);
 
@@ -2789,13 +2815,7 @@ mod tests {
                 Activation::None,
             )?,
         ]);
-
-        let mut encoder = context.device.create_command_encoder(&Default::default());
-
-        let mut pass = encoder.begin_compute_pass(&Default::default());
-        pass.execute_tensor_op(&ops);
-        drop(pass);
-        context.queue.submit(Some(encoder.finish()));
+        context.queue.submit(context.encode(&ops));
 
         let matrix_u8_host = matrix_u8_dev.back_in_place().to_vec();
         let output_host = output_dev.back_in_place().to_vec();
@@ -2963,13 +2983,7 @@ mod tests {
                 Activation::None,
             )?,
         ]);
-
-        let mut encoder = context.device.create_command_encoder(&Default::default());
-
-        let mut pass = encoder.begin_compute_pass(&Default::default());
-        pass.execute_tensor_op(&ops);
-        drop(pass);
-        context.queue.submit(Some(encoder.finish()));
+        context.queue.submit(context.encode(&ops));
 
         let matrix_u4_host = matrix_u4_dev.back_in_place().to_vec();
         let absmax_host = absmax_dev.back_in_place().to_vec();
@@ -3062,13 +3076,7 @@ mod tests {
         ops.push(TensorOp::blit(input, output.view(.., 2.., 1..2, ..)?)?);
 
         let ops = TensorOp::List(ops);
-
-        let mut encoder = context.device.create_command_encoder(&Default::default());
-
-        let mut pass = encoder.begin_compute_pass(&Default::default());
-        pass.execute_tensor_op(&ops);
-        drop(pass);
-        context.queue.submit(Some(encoder.finish()));
+        context.queue.submit(context.encode(&ops));
 
         let output_host = output.back_in_place();
         let output_host = Vec::from(output_host);
@@ -3099,13 +3107,7 @@ mod tests {
         let input: TensorGpu<_, _> = context.tensor_from_data([4, 3, 2, 1], input)?;
 
         let ops = TensorOp::transpose(input.view(.., .., .., ..)?, output.view(.., ..2, .., ..)?)?;
-
-        let mut encoder = context.device.create_command_encoder(&Default::default());
-
-        let mut pass = encoder.begin_compute_pass(&Default::default());
-        pass.execute_tensor_op(&ops);
-        drop(pass);
-        context.queue.submit(Some(encoder.finish()));
+        context.queue.submit(context.encode(&ops));
 
         let output_host = output.back_in_place();
         let output_host: Vec<f32> = Vec::from(output_host);
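
Note on the new encoding API (editor's sketch, not part of the patch): `Context::encode` replaces the old `TensorPass` trait. It flattens a `TensorOp` tree into `Atom` dispatches, cuts the sequence at every `TensorOp::Sep`, and returns one `wgpu::CommandBuffer` per non-empty segment, so a whole batch goes through `Queue::submit` in one call. A minimal usage sketch, assuming a built `context: Context` and the import paths used in `examples/inspector.rs`; the helper name `recenter_softmax` and the use of `anyhow` are illustrative, not from this patch:

    use anyhow::Result;
    use web_rwkv::{
        context::Context,
        tensor::{kind::ReadWrite, ops::TensorOp, TensorGpu},
    };

    async fn recenter_softmax(context: &Context, x: &TensorGpu<f32, ReadWrite>) -> Result<()> {
        let ops = TensorOp::List(vec![
            TensorOp::recenter(x)?, // encoded into the first command buffer
            TensorOp::Sep,          // boundary: flush the atoms collected so far
            TensorOp::softmax(x)?,  // encoded into the second command buffer
        ]);
        // `encode` yields Vec<CommandBuffer>; submit the whole batch at once.
        context.queue.submit(context.encode(&ops));
        let _host = x.back().await; // read the result back to the CPU
        Ok(())
    }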
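
On "reduce sep frequency": the runtimes no longer end a compute pass after every layer; they push a `TensorOp::Sep` only once per layer chunk. Restating the check added in `src/runtime/v4.rs`, `v5.rs`, and `v6.rs`, with the arithmetic spelled out (comments are editorial):

    // With info.num_layer = 32 and NUM_LAYER_CHUNK = 4 the divisor is 8, so a
    // `Sep` lands after layers 8, 16, 24, and 32: four layer chunks per pass.
    // Note the integer division: a model with fewer than NUM_LAYER_CHUNK layers
    // would make the divisor zero, so this assumes num_layer >= NUM_LAYER_CHUNK.
    if (index + 1) % (info.num_layer / super::infer::NUM_LAYER_CHUNK) == 0 {
        ops.push(TensorOp::Sep);
    }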