Commit

Do not split head from now on.
cryscan committed Jan 1, 2024
1 parent 1856547 commit 026b583
Showing 8 changed files with 45 additions and 103 deletions.
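
The caller-facing effect of this commit is that ContextBuilder::with_auto_limits now derives buffer limits from the ModelInfo alone, so the EmbedDevice argument disappears from the examples' create_context helpers. A minimal sketch of the new call shape (the wrapper function, module paths, and anyhow::Result are assumptions for illustration; the builder calls themselves are taken from the diff below):

    use anyhow::Result;
    use web_rwkv::{
        context::{Context, ContextBuilder},
        model::ModelInfo,
    };

    // Hypothetical helper: `adapter` is a wgpu::Adapter obtained elsewhere,
    // e.g. the way the examples' `create_context` selects one.
    async fn build_context(adapter: wgpu::Adapter, info: &ModelInfo) -> Result<Context> {
        let context = ContextBuilder::new(adapter)
            .with_default_pipelines()
            // Before this commit:
            // .with_auto_limits(info, embed_device.unwrap_or_default().into())
            .with_auto_limits(info)
            .build()
            .await?;
        Ok(context)
    }
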
examples/batch.rs: 6 changes (3 additions & 3 deletions)
@@ -58,7 +58,7 @@ fn sample(probs: Vec<f32>, _top_p: f32) -> u16 {
token as u16
}

- async fn create_context(info: &ModelInfo, embed_device: Option<EmbedDevice>) -> Result<Context> {
+ async fn create_context(info: &ModelInfo) -> Result<Context> {
let instance = Instance::new();
#[cfg(not(debug_assertions))]
let adapter = {
@@ -81,7 +81,7 @@ async fn create_context(info: &ModelInfo, embed_device: Option<EmbedDevice>) ->
.await?;
let context = ContextBuilder::new(adapter)
.with_default_pipelines()
- .with_auto_limits(info, embed_device.unwrap_or_default().into())
+ .with_auto_limits(info)
.build()
.await?;
println!("{:#?}", context.adapter.get_info());
@@ -163,7 +163,7 @@ async fn run(cli: Cli) -> Result<()> {
let info = Loader::info(&map)?;
println!("{:#?}", info);

- let context = create_context(&info, cli.embed_device).await?;
+ let context = create_context(&info).await?;

match info.version {
ModelVersion::V4 => {
examples/chat.rs: 6 changes (3 additions & 3 deletions)
@@ -70,7 +70,7 @@ impl Sampler {
}
}

- async fn create_context(info: &ModelInfo, embed_device: Option<EmbedDevice>) -> Result<Context> {
+ async fn create_context(info: &ModelInfo) -> Result<Context> {
let instance = Instance::new();
#[cfg(not(debug_assertions))]
let adapter = {
@@ -93,7 +93,7 @@ async fn create_context(info: &ModelInfo, embed_device: Option<EmbedDevice>) ->
.await?;
let context = ContextBuilder::new(adapter)
.with_default_pipelines()
- .with_auto_limits(info, embed_device.unwrap_or_default().into())
+ .with_auto_limits(info)
.build()
.await?;
println!("{:#?}", context.adapter.get_info());
@@ -185,7 +185,7 @@ async fn run(cli: Cli) -> Result<()> {
let info = Loader::info(&map)?;
println!("{:#?}", info);

- let context = create_context(&info, cli.embed_device).await?;
+ let context = create_context(&info).await?;

match info.version {
ModelVersion::V4 => {
examples/gen.rs: 6 changes (3 additions & 3 deletions)
@@ -47,7 +47,7 @@ fn sample(probs: &[f32], _top_p: f32) -> u16 {
token as u16
}

- async fn create_context(info: &ModelInfo, embed_device: Option<EmbedDevice>) -> Result<Context> {
+ async fn create_context(info: &ModelInfo) -> Result<Context> {
let instance = Instance::new();
#[cfg(not(debug_assertions))]
let adapter = {
@@ -70,7 +70,7 @@ async fn create_context(info: &ModelInfo, embed_device: Option<EmbedDevice>) ->
.await?;
let context = ContextBuilder::new(adapter)
.with_default_pipelines()
- .with_auto_limits(info, embed_device.unwrap_or_default().into())
+ .with_auto_limits(info)
.build()
.await?;
println!("{:#?}", context.adapter.get_info());
@@ -137,7 +137,7 @@ async fn run(cli: Cli) -> Result<()> {
let info = Loader::info(&map)?;
println!("{:#?}", info);

- let context = create_context(&info, cli.embed_device).await?;
+ let context = create_context(&info).await?;

match info.version {
ModelVersion::V4 => {
src/context.rs: 9 changes (3 additions & 6 deletions)
@@ -10,7 +10,7 @@ use wgpu::{
};

use crate::{
- model::{EmbedDevice, ModelInfo},
+ model::ModelInfo,
tensor::{
cache::ResourceCache,
shape::{IntoBytes, Shape},
@@ -188,11 +188,8 @@ impl<'a> ContextBuilder<'a> {
}

/// Compute the limits automatically based on given model build info.
- pub fn with_auto_limits(mut self, info: &ModelInfo, embed_device: EmbedDevice) -> Self {
- let max_buffer_size = match embed_device {
- EmbedDevice::Cpu => info.max_non_head_buffer_size(),
- EmbedDevice::Gpu => info.max_buffer_size(),
- };
+ pub fn with_auto_limits(mut self, info: &ModelInfo) -> Self {
+ let max_buffer_size = info.max_buffer_size();
self.limits.max_buffer_size = (256 << 20).max(max_buffer_size as u64);
self.limits.max_storage_buffer_binding_size = (128 << 20).max(max_buffer_size as u32);
self
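
Since the head is no longer split, its weight must fit in a single storage buffer, so with_auto_limits now always sizes the limits from info.max_buffer_size() instead of optionally excluding the head via max_non_head_buffer_size(). A rough worked example of the arithmetic, mirroring the two limit lines kept above (the 4096 x 65536 head shape is an assumed example, not taken from this diff):

    fn main() {
        // Assumed example dimensions for an f16 head of shape num_emb x num_vocab.
        let (num_emb, num_vocab): (u64, u64) = (4096, 65536);
        let head_bytes = num_emb * num_vocab * 2; // 2 bytes per f16 => 512 MiB

        // Mirrors the new with_auto_limits: raise both limits to at least the
        // largest buffer, since the 256 MiB / 128 MiB floors are too small here.
        let max_buffer_size = (256u64 << 20).max(head_bytes);
        let max_storage_buffer_binding_size = (128u64 << 20).max(head_bytes);
        println!("{max_buffer_size} {max_storage_buffer_binding_size}"); // 536870912 536870912
    }
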
src/model/mod.rs: 16 changes (0 additions & 16 deletions)
@@ -23,11 +23,7 @@ pub mod v5;
pub mod v6;

pub const RESCALE_LAYER: usize = 6;

pub const MIN_TOKEN_CHUNK_SIZE: usize = 32;
- pub const HEAD_CHUNK_SIZES: [usize; 10] = [
- 0x10000, 0x5000, 0x4000, 0x3000, 0x2000, 0x1800, 0x1600, 0x1400, 0x1200, 0x1000,
- ];

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelVersion {
@@ -132,10 +128,7 @@ pub trait ModelBase: Sync {

fn context(&self) -> &Context;
fn info(&self) -> &ModelInfo;

fn token_chunk_size(&self) -> usize;
- fn head_chunk_size(&self) -> usize;

fn head_shape(&self, num_batch: usize) -> Shape;
}

@@ -246,7 +239,6 @@ struct PreparedModelBuilder<'a> {
turbo: bool,
rescale: bool,
token_chunk_size: usize,
- head_chunk_size: usize,
}

impl<'a> ModelBuilder<'a> {
Expand Down Expand Up @@ -281,13 +273,6 @@ impl<'a> ModelBuilder<'a> {
.next_power_of_two();
log::info!("token chunk size: {token_chunk_size}");

- let max_chunk_size = context.device.limits().max_storage_buffer_binding_size as usize;
- let head_chunk_size = HEAD_CHUNK_SIZES
- .into_iter()
- .find(|&x| info.num_emb * x * f16::size() <= max_chunk_size)
- .ok_or(ModelError::NoViableChunkSize)?;
- log::info!("head chunk size: {head_chunk_size}");

let rescale = turbo || quant.iter().any(|(_, quant)| matches!(quant, Quant::NF4));

Ok(PreparedModelBuilder {
@@ -299,7 +284,6 @@ impl<'a> ModelBuilder<'a> {
turbo,
rescale,
token_chunk_size,
- head_chunk_size,
})
}

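
For reference, the deleted HEAD_CHUNK_SIZES search picked the largest listed chunk whose f16 slice of the head fit under the storage-buffer binding limit; raising the limits in with_auto_limits makes that search unnecessary. A standalone sketch of the old selection (the 2048 embedding width and 128 MiB limit are assumed example values):

    const HEAD_CHUNK_SIZES: [usize; 10] = [
        0x10000, 0x5000, 0x4000, 0x3000, 0x2000, 0x1800, 0x1600, 0x1400, 0x1200, 0x1000,
    ];

    fn main() {
        // Assumed example values, not taken from this diff.
        let num_emb: usize = 2048;
        let max_chunk_size: usize = 128 << 20; // 128 MiB binding limit

        // Each chunk holds num_emb * chunk f16 values at 2 bytes each.
        let head_chunk_size = HEAD_CHUNK_SIZES
            .into_iter()
            .find(|&x| num_emb * x * 2 <= max_chunk_size);

        // 2048 * 0x10000 * 2 = 256 MiB is over the limit, so the first fit is 0x5000.
        println!("{head_chunk_size:?}"); // Some(20480)
    }
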
src/model/v4.rs: 35 changes (11 additions & 24 deletions)
@@ -35,8 +35,6 @@ pub struct Model<'a> {
rescale: bool,
/// Whether to use fp16 GEMM for matmul computations.
turbo: bool,
- /// The head matrix is too big for a storage buffer so it's divided into chunks.
- head_chunk_size: usize,
/// To prevent the GPU device from lost, this limits the maximum batch-token it processes one time.
token_chunk_size: usize,

@@ -102,7 +100,7 @@ struct Embed<'a> {
#[derive(Debug)]
struct Head {
layer_norm: LayerNorm,
- w: Vec<TensorGpu<f16, ReadWrite>>,
+ w: Matrix,
}

/// Runtime buffers.
@@ -446,7 +444,6 @@ impl<'a> FromBuilder for Model<'a> {
turbo,
rescale,
token_chunk_size,
- head_chunk_size,
} = builder.prepare()?;

let embed = Embed {
@@ -466,7 +463,7 @@ impl<'a> FromBuilder for Model<'a> {
w: loader.load_vector_f16("ln_out.weight")?,
b: loader.load_vector_f16("ln_out.bias")?,
},
- w: loader.load_head(head_chunk_size)?,
+ w: Matrix::Fp16(loader.load_matrix_f16("head.weight")?),
};

context.queue.submit(None);
@@ -587,7 +584,6 @@ impl<'a> FromBuilder for Model<'a> {
info,
rescale,
turbo,
- head_chunk_size,
token_chunk_size,
tensor,
runtime_cache: ResourceCache::new(1),
@@ -615,11 +611,6 @@ impl ModelBase for Model<'_> {
self.token_chunk_size
}

- #[inline]
- fn head_chunk_size(&self) -> usize {
- self.head_chunk_size
- }

#[inline]
fn head_shape(&self, num_batch: usize) -> Shape {
Shape::new(self.info.num_vocab, 1, num_batch, 1)
@@ -940,22 +931,18 @@ impl ModelRunInner for Model<'_> {
}

if num_header > 0 {
- let mut ops = vec![
+ let ops = TensorOp::List(vec![
hook_op(Hook::PreHead),
TensorOp::layer_norm(&tensor.head.layer_norm.w, &tensor.head.layer_norm.b, head_x)?,
hook_op(Hook::PostHeadLayerNorm),
- ];
-
- for (chunk, matrix) in tensor.head.w.iter().enumerate() {
- let start = chunk * self.head_chunk_size;
- let end = start + matrix.shape()[1];
- let input = head_x.view(.., .., .., ..)?;
- let output = output.head_o.view(start..end, .., .., ..)?;
- ops.push(TensorOp::matmul_vec_fp16(matrix, input, output)?);
- }
-
- ops.push(hook_op(Hook::PostHead));
- let ops = TensorOp::List(ops);
+ tensor.head.w.matmul_op(
+ buffer.half_x.view(.., .., .., ..)?,
+ head_x.view(.., .., .., ..)?,
+ output.head_o.view(.., .., .., ..)?,
+ turbo,
+ )?,
+ hook_op(Hook::PostHead),
+ ]);

let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor::default());
pass.execute_tensor_op(&head_ops);
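
The run-path change above removes the per-chunk fp16 matrix-vector loop over tensor.head.w and issues one matmul_op over the whole head instead, with the turbo flag choosing whether fp16 GEMM is used. A small sketch of the row bookkeeping the old loop performed, for comparison (the 65536-row head and 20480-row chunks are assumed example values):

    fn main() {
        // Assumed example values: a 65536-row head split into 20480-row chunks.
        let (num_vocab, head_chunk_size): (usize, usize) = (65536, 20480);
        let chunks = (num_vocab + head_chunk_size - 1) / head_chunk_size;

        for chunk in 0..chunks {
            let start = chunk * head_chunk_size;
            let end = (start + head_chunk_size).min(num_vocab);
            // Old code: one matmul_vec_fp16 per chunk, writing rows start..end of head_o.
            println!("chunk {chunk}: rows {start}..{end}");
        }
        // New code: a single matmul over rows 0..num_vocab of head_o.
    }
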
src/model/v5.rs: 35 changes (11 additions & 24 deletions)
@@ -34,8 +34,6 @@ pub struct Model<'a> {
rescale: bool,
/// Whether to use fp16 GEMM for matmul computations.
turbo: bool,
- /// The head matrix is too big for a storage buffer so it's divided into chunks.
- head_chunk_size: usize,
/// To prevent the GPU device from lost, this limits the maximum batch-token it processes one time.
token_chunk_size: usize,

@@ -105,7 +103,7 @@ struct Embed<'a> {
#[derive(Debug)]
struct Head {
layer_norm: LayerNorm,
- w: Vec<TensorGpu<f16, ReadWrite>>,
+ w: Matrix,
}

/// Runtime buffers.
@@ -523,7 +521,6 @@ impl<'a> FromBuilder for Model<'a> {
turbo,
rescale,
token_chunk_size,
- head_chunk_size,
} = builder.prepare()?;

let embed = Embed {
@@ -543,7 +540,7 @@ impl<'a> FromBuilder for Model<'a> {
w: loader.load_vector_f16("ln_out.weight")?,
b: loader.load_vector_f16("ln_out.bias")?,
},
- w: loader.load_head(head_chunk_size)?,
+ w: Matrix::Fp16(loader.load_matrix_f16("head.weight")?),
};

context.queue.submit(None);
@@ -687,7 +684,6 @@ impl<'a> FromBuilder for Model<'a> {
info,
rescale,
turbo,
- head_chunk_size,
token_chunk_size,
tensor,
runtime_cache: ResourceCache::new(1),
@@ -715,11 +711,6 @@ impl ModelBase for Model<'_> {
self.token_chunk_size
}

- #[inline]
- fn head_chunk_size(&self) -> usize {
- self.head_chunk_size
- }

#[inline]
fn head_shape(&self, num_batch: usize) -> Shape {
Shape::new(self.info.num_vocab, 1, num_batch, 1)
@@ -1097,22 +1088,18 @@ impl ModelRunInner for Model<'_> {
}

if num_header > 0 {
- let mut ops = vec![
+ let ops = TensorOp::List(vec![
hook_op(Hook::PreHead),
TensorOp::layer_norm(&tensor.head.layer_norm.w, &tensor.head.layer_norm.b, head_x)?,
hook_op(Hook::PostHeadLayerNorm),
- ];
-
- for (chunk, matrix) in tensor.head.w.iter().enumerate() {
- let start = chunk * self.head_chunk_size;
- let end = start + matrix.shape()[1];
- let input = head_x.view(.., .., .., ..)?;
- let output = output.head_o.view(start..end, .., .., ..)?;
- ops.push(TensorOp::matmul_vec_fp16(matrix, input, output)?);
- }
-
- ops.push(hook_op(Hook::PostHead));
- let ops = TensorOp::List(ops);
+ tensor.head.w.matmul_op(
+ buffer.half_x.view(.., .., .., ..)?,
+ head_x.view(.., .., .., ..)?,
+ output.head_o.view(.., .., .., ..)?,
+ turbo,
+ )?,
+ hook_op(Hook::PostHead),
+ ]);

let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor::default());
pass.execute_tensor_op(&head_ops);
(The diff for the remaining changed file was not loaded.)
