Do not quantize twice.
cryscan committed Aug 28, 2023
1 parent 24cd561 commit 43e1beb
Showing 6 changed files with 59 additions and 82 deletions.
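In short: `TensorOp::quantize_mat_int8` already emits the per-column factors `mx`/`rx` and per-row factors `my`/`ry` in `f32`, but the old build path then ran `quantize_vec_fp16` to squeeze them into packed `f16`, which the matmul shader unpacked back to `f32` on every read. This commit keeps the factors in `f32` end to end (enum field, op signature, and WGSL bindings) and drops the second quantization pass. It also turns `ModelState` and `BackedState` into newtypes deriving `Deref`/`DerefMut` and makes `context` a public field rather than a getter.

A minimal sketch (not code from this repository, using the `half` crate) of the precision cost of that extra `f16` round-trip:

use half::f16; // external crate: half = "2"

// Sketch of the old path: a scale factor computed in f32 was
// re-quantized to f16 before being used in the matmul.
fn main() {
    let scale: f32 = 0.2; // not exactly representable in f16
    let scale_through_f16 = f16::from_f32(scale).to_f32();

    let w = 200u8 as f32;
    let old = w * scale_through_f16; // factor quantized twice
    let new = w * scale; // factor kept in f32, as after this commit

    // The f16 round-trip perturbs every product by the rounding error.
    assert!((old - new).abs() > 0.0);
    println!("old = {old}, new = {new}");
}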
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "web-rwkv"
-version = "0.2.2"
+version = "0.2.3"
 edition = "2021"
 authors = ["Zhenyuan Zhang <[email protected]>"]
 license = "MIT OR Apache-2.0"
4 changes: 2 additions & 2 deletions src/context.rs
@@ -1,6 +1,6 @@
 use std::{borrow::Cow, collections::HashMap, str::FromStr, sync::Arc};
 
-use web_rwkv_derive::{Deref, Id};
+use web_rwkv_derive::{Deref, DerefMut, Id};
 use wgpu::{
     util::{BufferInitDescriptor, DeviceExt},
     Adapter, Backends, BindGroupLayoutDescriptor, BindGroupLayoutEntry, Buffer, BufferUsages,
@@ -67,7 +67,7 @@ impl Instance {
     }
 }
 
-#[derive(Debug, Clone, Copy, Deref, Id, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone, Copy, Deref, DerefMut, Id, PartialEq, Eq, Hash)]
 pub struct ContextId(usize);
 
 #[derive(Debug)]
91 changes: 37 additions & 54 deletions src/model.rs
@@ -5,6 +5,7 @@ use bitflags::bitflags;
 use derive_getters::Getters;
 use half::f16;
 use safetensors::SafeTensors;
+use web_rwkv_derive::{Deref, DerefMut};
 use wgpu::{CommandEncoderDescriptor, ComputePassDescriptor};
 
 use crate::{
@@ -71,10 +72,10 @@ enum Matrix<'a> {
     Fp16(TensorGpu<'a, f16, ReadWrite>),
     Int8 {
         w: Box<TensorGpu<'a, u8, ReadWrite>>,
-        mx: Box<TensorGpu<'a, f16, ReadWrite>>,
-        rx: Box<TensorGpu<'a, f16, ReadWrite>>,
-        my: Box<TensorGpu<'a, f16, ReadWrite>>,
-        ry: Box<TensorGpu<'a, f16, ReadWrite>>,
+        mx: Box<TensorGpu<'a, f32, ReadWrite>>,
+        rx: Box<TensorGpu<'a, f32, ReadWrite>>,
+        my: Box<TensorGpu<'a, f32, ReadWrite>>,
+        ry: Box<TensorGpu<'a, f32, ReadWrite>>,
     },
 }

@@ -237,11 +238,8 @@ impl<'a> ModelBuffer<'a> {
     }
 }
 
-#[derive(Debug, Clone)]
-pub struct ModelState<'a> {
-    pub context: &'a Context,
-    pub state: TensorGpu<'a, f32, ReadWrite>,
-}
+#[derive(Debug, Clone, Deref, DerefMut)]
+pub struct ModelState<'a>(pub TensorGpu<'a, f32, ReadWrite>);
 
 impl<'a> ModelState<'a> {
     pub fn new(context: &'a Context, info: &ModelInfo, num_batches: usize) -> Self {
@@ -269,88 +267,74 @@ impl<'a> ModelState<'a> {
             data,
         )
         .unwrap();
-        Self { context, state }
+        Self(state)
     }
 
     pub fn load(&self, backed: &BackedState<'a, '_>) -> Result<(), TensorError> {
-        self.state.load(&backed.state)
+        self.0.load(backed)
    }
 
     fn att(&self, layer: usize) -> Result<TensorView<f32>, TensorError> {
         let start = 5 * layer;
         let end = start + 4;
-        self.state.as_view((.., start..end, ..))
+        self.as_view((.., start..end, ..))
     }
 
     fn ffn(&self, layer: usize) -> Result<TensorView<f32>, TensorError> {
         let start = 5 * layer + 4;
-        self.state.as_view((.., start..=start, ..))
+        self.as_view((.., start..=start, ..))
     }
 }
 
-#[derive(Debug, Clone)]
-pub struct BackedState<'a, 'b> {
-    pub context: &'a Context,
-    pub state: TensorCpu<'a, 'b, f32>,
-}
+#[derive(Debug, Clone, Deref, DerefMut)]
+pub struct BackedState<'a, 'b>(pub TensorCpu<'a, 'b, f32>);
 
 impl<'a, 'b> BackedState<'a, 'b> {
     pub fn repeat(self, repeat: usize) -> Self {
-        let BackedState { context, state } = self;
-        let state = state.repeat(2, repeat);
-        Self { context, state }
+        let state = self.0.repeat(2, repeat);
+        Self(state)
     }
 
     pub fn take(self, batch: usize) -> Result<Self, TensorError> {
-        let state = self.state.into_slice((.., .., batch))?;
-        Ok(Self { state, ..self })
+        let state = self.0.into_slice((.., .., batch))?;
+        Ok(Self(state))
     }
 
     pub fn split(self) -> Vec<Self> {
-        if self.state.shape()[2] <= 1 {
+        if self.shape()[2] <= 1 {
             return vec![self];
         }
-        let Self { context, state } = self;
-        state
-            .split()
-            .into_iter()
-            .map(|state| Self { context, state })
-            .collect()
+        self.0.split().into_iter().map(Self).collect()
     }
 
     pub fn concat(batches: Vec<Self>) -> Result<Self, TensorError> {
         if batches.is_empty() {
             return Err(TensorError::Empty);
         }
-        let context = batches[0].context;
-        let states: Vec<_> = batches.into_iter().map(|batch| batch.state).collect();
-        Ok(Self {
-            context,
-            state: TensorCpu::concat(states)?,
-        })
+        let states: Vec<_> = batches.into_iter().map(|batch| batch.0).collect();
+        Ok(Self(TensorCpu::concat(states)?))
     }
 }
 
 impl<'a> From<ModelState<'a>> for BackedState<'a, '_> {
     fn from(value: ModelState<'a>) -> Self {
-        let ModelState { context, state } = value;
-        let map = context.init_tensor(state.shape());
+        let context = value.context;
+        let map = context.init_tensor(value.shape());
         let mut encoder = context
             .device
             .create_command_encoder(&CommandEncoderDescriptor::default());
-        encoder.copy_tensor(&state, &map).unwrap();
+        encoder.copy_tensor(&value, &map).unwrap();
         context.queue.submit(Some(encoder.finish()));
 
         let state = TensorCpu::from(map);
-        Self { context, state }
+        Self(state)
     }
 }
 
 impl<'a> From<BackedState<'a, '_>> for ModelState<'a> {
     fn from(value: BackedState<'a, '_>) -> Self {
-        let BackedState { context, state } = value;
-        let state = TensorGpu::from(state);
-        Self { context, state }
+        let state = TensorGpu::from(value.0);
+        Self(state)
     }
 }

@@ -466,25 +450,24 @@ impl<'a, 'b> ModelBuilder<'a, 'b> {
                 context.tensor_from_data(shape, bytemuck::cast_slice(tensor.data()))?;
             let shape = matrix.shape();
 
-            let mx_f32 = context.init_tensor(Shape::new(shape[0], 1, 1));
-            let rx_f32 = context.init_tensor(Shape::new(shape[0], 1, 1));
-            let my_f32 = context.init_tensor(Shape::new(shape[1], 1, 1));
-            let ry_f32 = context.init_tensor(Shape::new(shape[1], 1, 1));
+            // let mx_f32 = context.init_tensor(Shape::new(shape[0], 1, 1));
+            // let rx_f32 = context.init_tensor(Shape::new(shape[0], 1, 1));
+            // let my_f32 = context.init_tensor(Shape::new(shape[1], 1, 1));
+            // let ry_f32 = context.init_tensor(Shape::new(shape[1], 1, 1));
 
             let w = Box::new(context.init_tensor(matrix.shape()));
 
-            let mut ops =
-                TensorOp::quantize_mat_int8(&matrix, &mx_f32, &rx_f32, &my_f32, &ry_f32, &w)?;
-
             let mx = Box::new(context.init_tensor(Shape::new(shape[0], 1, 1)));
             let rx = Box::new(context.init_tensor(Shape::new(shape[0], 1, 1)));
             let my = Box::new(context.init_tensor(Shape::new(shape[1], 1, 1)));
             let ry = Box::new(context.init_tensor(Shape::new(shape[1], 1, 1)));
 
-            ops.push(TensorOp::quantize_vec_fp16(&mx_f32, &mx)?);
-            ops.push(TensorOp::quantize_vec_fp16(&rx_f32, &rx)?);
-            ops.push(TensorOp::quantize_vec_fp16(&my_f32, &my)?);
-            ops.push(TensorOp::quantize_vec_fp16(&ry_f32, &ry)?);
+            let ops = TensorOp::quantize_mat_int8(&matrix, &mx, &rx, &my, &ry, &w)?;
+
+            // ops.push(TensorOp::quantize_vec_fp16(&mx_f32, &mx)?);
+            // ops.push(TensorOp::quantize_vec_fp16(&rx_f32, &rx)?);
+            // ops.push(TensorOp::quantize_vec_fp16(&my_f32, &my)?);
+            // ops.push(TensorOp::quantize_vec_fp16(&ry_f32, &ry)?);
 
             let mut encoder = context
                 .device
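The `ModelState` and `BackedState` changes above lean on the newtype-plus-`Deref` pattern, which is what lets `self.state.shape()` shrink to `self.shape()` and `&backed.state` to plain `backed`. A self-contained sketch, with hand-written impls standing in for the `web_rwkv_derive::{Deref, DerefMut}` derive macros and a hypothetical `Tensor` stand-in for the GPU tensor types:

use std::ops::{Deref, DerefMut};

// Hypothetical stand-in for the crate's tensor types.
#[derive(Debug, Clone)]
struct Tensor(Vec<f32>);

impl Tensor {
    fn shape(&self) -> usize {
        self.0.len()
    }
}

// Roughly what `#[derive(Deref, DerefMut)]` expands to on the newtype.
#[derive(Debug, Clone)]
struct ModelState(pub Tensor);

impl Deref for ModelState {
    type Target = Tensor;
    fn deref(&self) -> &Tensor {
        &self.0
    }
}

impl DerefMut for ModelState {
    fn deref_mut(&mut self) -> &mut Tensor {
        &mut self.0
    }
}

fn main() {
    let state = ModelState(Tensor(vec![0.0; 8]));
    // Auto-deref forwards method calls to the inner tensor, which is
    // why the diff can drop the explicit `.state` field accesses.
    assert_eq!(state.shape(), 8);
}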
20 changes: 12 additions & 8 deletions src/shaders/matmul_int8.wgsl
@@ -9,10 +9,10 @@ struct View {
 @group(0) @binding(2) var<uniform> destination: View; // [R, T, B]
 
 @group(0) @binding(3) var<storage, read> matrix: array<u32>; // (R, C)
-@group(0) @binding(4) var<storage, read> mx: array<vec2<u32>>; // (C)
-@group(0) @binding(5) var<storage, read> rx: array<vec2<u32>>; // (C)
-@group(0) @binding(6) var<storage, read> my: array<vec2<u32>>; // (R)
-@group(0) @binding(7) var<storage, read> ry: array<vec2<u32>>; // (R)
+@group(0) @binding(4) var<storage, read> mx: array<vec4<f32>>; // (C)
+@group(0) @binding(5) var<storage, read> rx: array<vec4<f32>>; // (C)
+@group(0) @binding(6) var<storage, read> my: array<vec4<f32>>; // (R)
+@group(0) @binding(7) var<storage, read> ry: array<vec4<f32>>; // (R)
 
 @group(0) @binding(8) var<storage, read> input: array<vec4<f32>>; // (B, T, C)
 @group(0) @binding(9) var<storage, read_write> output: array<vec4<f32>>; // (B, T, R)
@@ -53,8 +53,10 @@ fn matmul(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
     let bb = compute_index(source, batch, token, 0u);
     let cb = channel * 4u * stride.x;
 
-    let myc = unpack4x16float(my[channel]);
-    let ryc = unpack4x16float(ry[channel]);
+    // let myc = unpack4x16float(my[channel]);
+    // let ryc = unpack4x16float(ry[channel]);
+    let myc = my[channel];
+    let ryc = ry[channel];
 
     var local_sum = vec4<f32>(0.0);
     for (var i = index; i < stride.x; i += BLOCK_SIZE) {
@@ -64,8 +66,10 @@ fn matmul(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
         // read 4 elements from the input
         let x = input[bti];
 
-        let mxi = unpack4x16float(mx[i]);
-        let rxi = unpack4x16float(rx[i]);
+        // let mxi = unpack4x16float(mx[i]);
+        // let rxi = unpack4x16float(rx[i]);
+        let mxi = mx[i];
+        let rxi = rx[i];
 
         // read 4 rows from the matrix, each with 4 unpacked floats, forming a 4x4 sub-block
         var m: mat4x4<f32>;
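The shader change above replaces the packed-`f16` factor buffers (`array<vec2<u32>>`, four halves per entry) with plain `array<vec4<f32>>`. Assuming `unpack4x16float` behaves like two WGSL `unpack2x16float` calls back to back, a Rust sketch (using the `half` crate) of what each read used to do, and what rounding it baked in:

use half::f16; // external crate: half = "2"

// Assumed behavior of the shader's `unpack4x16float` helper: a
// `vec2<u32>` carries four packed f16 values, expanded to f32 on every
// read. With the bindings switched to `vec4<f32>`, the shader skips
// both this unpacking and the f16 rounding that preceded it.
fn unpack4x16float(v: [u32; 2]) -> [f32; 4] {
    let mut out = [0.0f32; 4];
    for (i, word) in v.iter().enumerate() {
        out[2 * i] = f16::from_bits((word & 0xffff) as u16).to_f32();
        out[2 * i + 1] = f16::from_bits((word >> 16) as u16).to_f32();
    }
    out
}

fn main() {
    let lo = f16::from_f32(0.2).to_bits() as u32;
    let hi = f16::from_f32(1.0).to_bits() as u32;
    let unpacked = unpack4x16float([lo | (hi << 16), 0]);
    assert_ne!(unpacked[0], 0.2f32); // 0.2 is not exact in f16
    assert_eq!(unpacked[1], 1.0f32); // 1.0 survives the round-trip
}

The trade-off: each group of four factors now occupies 16 bytes (`vec4<f32>`) instead of 8 (a packed `vec2<u32>`), in exchange for full precision and a cheaper inner loop.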
14 changes: 2 additions & 12 deletions src/tensor/mod.rs
@@ -129,7 +129,7 @@ impl IntoBytes for View {
 
 #[derive(Debug)]
 pub struct Tensor<'a, D: Device, T: Scalar> {
-    context: &'a Context,
+    pub context: &'a Context,
     shape: Shape,
     data: D::Data,
     phantom: PhantomData<(D, T)>,
@@ -197,11 +197,6 @@ impl<D: Device, T: Scalar> Tensor<'_, D, T> {
         index * T::size()
     }
 
-    #[inline]
-    pub fn context(&self) -> &Context {
-        self.context
-    }
-
     #[inline]
     pub fn data(&self) -> &D::Data {
         &self.data
@@ -493,7 +488,7 @@ impl<'a, 'b, T: Scalar> TensorCpu<'a, 'b, T> {
 
 #[derive(Debug, Clone)]
 pub struct TensorView<'a, T: Scalar> {
-    context: &'a Context,
+    pub context: &'a Context,
     view: View,
     data: TensorBuffer,
     phantom: PhantomData<T>,
@@ -532,11 +527,6 @@ impl<'a, 'b, T: Scalar> TensorExt<'a, 'b, T> for TensorView<'a, T> {
 }
 
 impl<'a, T: Scalar> TensorView<'a, T> {
-    #[inline]
-    pub fn context(&self) -> &Context {
-        self.context
-    }
-
     #[inline]
     pub fn data(&self) -> &TensorBuffer {
         &self.data
10 changes: 5 additions & 5 deletions src/tensor/ops.rs
@@ -59,7 +59,7 @@ impl<'a> TensorOp<'a> {
     /// Softmax operator applied on `x`.
     pub fn softmax(x: &'a TensorGpu<f32, ReadWrite>) -> Result<Self, TensorError> {
         let shape = x.shape();
-        let context = x.context();
+        let context = x.context;
         let pipeline = context.pipeline("softmax")?;
         let bindings = vec![context.device.create_bind_group(&BindGroupDescriptor {
             label: None,
@@ -186,10 +186,10 @@
     /// - `output` shape: `[R, T, B]`.
     pub fn matmul_int8(
         matrix: &'a TensorGpu<u8, ReadWrite>,
-        mx: &'a TensorGpu<f16, ReadWrite>,
-        rx: &'a TensorGpu<f16, ReadWrite>,
-        my: &'a TensorGpu<f16, ReadWrite>,
-        ry: &'a TensorGpu<f16, ReadWrite>,
+        mx: &'a TensorGpu<f32, ReadWrite>,
+        rx: &'a TensorGpu<f32, ReadWrite>,
+        my: &'a TensorGpu<f32, ReadWrite>,
+        ry: &'a TensorGpu<f32, ReadWrite>,
         input: TensorView<'a, f32>,
         output: TensorView<'a, f32>,
     ) -> Result<Self, TensorError> {
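To make the shape contract of `matmul_int8` concrete, here is a hypothetical CPU reference for a single token (`T = B = 1`): `matrix` is `(R, C)` in u8 with per-column factors `mx`/`rx` (length C) and per-row factors `my`/`ry` (length R), all `f32` after this commit. The affine dequantization `w * rx * ry + mx + my` is an assumption for illustration only; the actual scheme is defined by `TensorOp::quantize_mat_int8` and the WGSL kernel:

// Hypothetical CPU reference for one token of `matmul_int8` (T = B = 1).
// The dequantization formula is an assumption; see the note above.
fn matmul_int8_ref(
    matrix: &[u8], // (R, C), row-major
    mx: &[f32],    // (C) per-column offset
    rx: &[f32],    // (C) per-column range
    my: &[f32],    // (R) per-row offset
    ry: &[f32],    // (R) per-row range
    input: &[f32], // (C)
    r: usize,
    c: usize,
) -> Vec<f32> {
    (0..r)
        .map(|i| {
            (0..c)
                .map(|j| {
                    // Assumed affine dequantization of the u8 weight.
                    let w = matrix[i * c + j] as f32;
                    let dequant = w * rx[j] * ry[i] + mx[j] + my[i];
                    dequant * input[j]
                })
                .sum::<f32>()
        })
        .collect()
}

fn main() {
    let (r, c) = (2, 3);
    let matrix = vec![128u8; r * c];
    let out = matmul_int8_ref(
        &matrix,
        &[0.0; 3], &[1.0 / 255.0; 3], // mx, rx
        &[0.0; 2], &[1.0; 2],         // my, ry
        &[1.0, 2.0, 3.0],
        r, c,
    );
    assert_eq!(out.len(), r); // output has R entries per token
    println!("{out:?}");
}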
