Skip to content

Commit

Permalink
Format op namings.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Dec 14, 2023
1 parent 13c6b8f commit 520b355
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 48 deletions.
19 changes: 12 additions & 7 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,15 +245,15 @@ impl<'a> ContextBuilder<'a> {
None,
)
.with_pipeline(
"token_shift",
"token_shift_fp16",
include_str!("shaders/token_shift.wgsl"),
"token_shift",
"token_shift_fp16",
None,
)
.with_pipeline(
"token_shift_rev",
"token_shift_rev_fp16",
include_str!("shaders/token_shift.wgsl"),
"token_shift_rev",
"token_shift_rev_fp16",
None,
)
.with_pipeline(
Expand All @@ -269,8 +269,8 @@ impl<'a> ContextBuilder<'a> {
None,
)
.with_pipeline(
"time_mix",
include_str!("shaders/time_mix.wgsl"),
"time_mix_v4",
include_str!("shaders/time_mix_v4.wgsl"),
"time_mix",
None,
)
Expand All @@ -286,7 +286,12 @@ impl<'a> ContextBuilder<'a> {
"time_mix",
None,
)
.with_pipeline("add", include_str!("shaders/add.wgsl"), "add", None)
.with_pipeline(
"add_fp32",
include_str!("shaders/add.wgsl"),
"add_fp32",
None,
)
.with_pipeline(
"add_fp16",
include_str!("shaders/add.wgsl"),
Expand Down
16 changes: 8 additions & 8 deletions src/model/v4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -736,23 +736,23 @@ impl ModelRun for Model<'_> {
&layer.att_layer_norm.b,
&buffer.att_x,
)?,
TensorOp::token_shift(
TensorOp::token_shift_fp16(
&buffer.cursors,
layer.att.time_mix_k.view(.., .., .., ..)?,
&buffer.att_x,
state.att(index)?,
&buffer.att_kx,
false,
)?,
TensorOp::token_shift(
TensorOp::token_shift_fp16(
&buffer.cursors,
layer.att.time_mix_v.view(.., .., .., ..)?,
&buffer.att_x,
state.att(index)?,
&buffer.att_vx,
false,
)?,
TensorOp::token_shift(
TensorOp::token_shift_fp16(
&buffer.cursors,
layer.att.time_mix_r.view(.., .., .., ..)?,
&buffer.att_x,
Expand All @@ -778,7 +778,7 @@ impl ModelRun for Model<'_> {
buffer.att_r.view(.., .., .., ..)?,
turbo,
)?,
TensorOp::time_mix(
TensorOp::time_mix_v4(
&buffer.cursors,
&layer.att.time_decay,
&layer.att.time_first,
Expand All @@ -793,7 +793,7 @@ impl ModelRun for Model<'_> {
buffer.att_x.view(.., .., .., ..)?,
buffer.att_o.view(.., .., .., ..)?,
)?,
TensorOp::add(
TensorOp::add_fp32(
buffer.input.view(.., .., .., ..)?,
buffer.att_o.view(.., .., .., ..)?,
)?,
Expand All @@ -811,15 +811,15 @@ impl ModelRun for Model<'_> {
&layer.ffn_layer_norm.b,
&buffer.ffn_x,
)?,
TensorOp::token_shift(
TensorOp::token_shift_fp16(
&buffer.cursors,
layer.ffn.time_mix_k.view(.., .., .., ..)?,
&buffer.ffn_x,
state.ffn(index)?,
&buffer.ffn_kx,
false,
)?,
TensorOp::token_shift(
TensorOp::token_shift_fp16(
&buffer.cursors,
layer.ffn.time_mix_r.view(.., .., .., ..)?,
&buffer.ffn_x,
Expand Down Expand Up @@ -853,7 +853,7 @@ impl ModelRun for Model<'_> {
&buffer.ffn_x,
state.ffn(index)?,
)?,
TensorOp::add(
TensorOp::add_fp32(
buffer.att_o.view(.., .., .., ..)?,
buffer.ffn_x.view(.., .., .., ..)?,
)?,
Expand Down
16 changes: 8 additions & 8 deletions src/model/v5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -872,31 +872,31 @@ impl ModelRun for Model<'_> {
&layer.att_layer_norm.b,
&buffer.att_x,
)?,
TensorOp::token_shift(
TensorOp::token_shift_fp16(
&buffer.cursors,
layer.att.time_mix_k.view(.., .., .., ..)?,
&buffer.att_x,
state.att(index)?,
&buffer.att_kx,
false,
)?,
TensorOp::token_shift(
TensorOp::token_shift_fp16(
&buffer.cursors,
layer.att.time_mix_v.view(.., .., .., ..)?,
&buffer.att_x,
state.att(index)?,
&buffer.att_vx,
false,
)?,
TensorOp::token_shift(
TensorOp::token_shift_fp16(
&buffer.cursors,
layer.att.time_mix_r.view(.., .., .., ..)?,
&buffer.att_x,
state.att(index)?,
&buffer.att_rx,
false,
)?,
TensorOp::token_shift(
TensorOp::token_shift_fp16(
&buffer.cursors,
layer.att.time_mix_g.view(.., .., .., ..)?,
&buffer.att_x,
Expand Down Expand Up @@ -945,7 +945,7 @@ impl ModelRun for Model<'_> {
buffer.att_x.view(.., .., .., ..)?,
buffer.att_o.view(.., .., .., ..)?,
)?,
TensorOp::add(
TensorOp::add_fp32(
buffer.input.view(.., .., .., ..)?,
buffer.att_o.view(.., .., .., ..)?,
)?,
Expand All @@ -963,15 +963,15 @@ impl ModelRun for Model<'_> {
&layer.ffn_layer_norm.b,
&buffer.ffn_x,
)?,
TensorOp::token_shift(
TensorOp::token_shift_fp16(
&buffer.cursors,
layer.ffn.time_mix_k.view(.., .., .., ..)?,
&buffer.ffn_x,
state.ffn(index)?,
&buffer.ffn_kx,
false,
)?,
TensorOp::token_shift(
TensorOp::token_shift_fp16(
&buffer.cursors,
layer.ffn.time_mix_r.view(.., .., .., ..)?,
&buffer.ffn_x,
Expand Down Expand Up @@ -1005,7 +1005,7 @@ impl ModelRun for Model<'_> {
&buffer.ffn_x,
state.ffn(index)?,
)?,
TensorOp::add(
TensorOp::add_fp32(
buffer.att_o.view(.., .., .., ..)?,
buffer.ffn_x.view(.., .., .., ..)?,
)?,
Expand Down
10 changes: 5 additions & 5 deletions src/model/v6.rs
Original file line number Diff line number Diff line change
Expand Up @@ -929,7 +929,7 @@ impl ModelRun for Model<'_> {
&layer.att_layer_norm.b,
&buffer.att_x,
)?,
TensorOp::token_shift(
TensorOp::token_shift_fp16(
&buffer.cursors,
layer.att.time_mix_x.view(.., .., .., ..)?,
&buffer.att_x,
Expand Down Expand Up @@ -1073,7 +1073,7 @@ impl ModelRun for Model<'_> {
buffer.att_x.view(.., .., .., ..)?,
buffer.att_o.view(.., .., .., ..)?,
)?,
TensorOp::add(
TensorOp::add_fp32(
buffer.input.view(.., .., .., ..)?,
buffer.att_o.view(.., .., .., ..)?,
)?,
Expand All @@ -1091,15 +1091,15 @@ impl ModelRun for Model<'_> {
&layer.ffn_layer_norm.b,
&buffer.ffn_x,
)?,
TensorOp::token_shift(
TensorOp::token_shift_fp16(
&buffer.cursors,
layer.ffn.time_mix_k.view(.., .., .., ..)?,
&buffer.ffn_x,
state.ffn(index)?,
&buffer.ffn_kx,
true,
)?,
TensorOp::token_shift(
TensorOp::token_shift_fp16(
&buffer.cursors,
layer.ffn.time_mix_r.view(.., .., .., ..)?,
&buffer.ffn_x,
Expand Down Expand Up @@ -1133,7 +1133,7 @@ impl ModelRun for Model<'_> {
&buffer.ffn_x,
state.ffn(index)?,
)?,
TensorOp::add(
TensorOp::add_fp32(
buffer.att_o.view(.., .., .., ..)?,
buffer.ffn_x.view(.., .., .., ..)?,
)?,
Expand Down
10 changes: 5 additions & 5 deletions src/shaders/add.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ struct View {
@group(0) @binding(0) var<uniform> source: View;
@group(0) @binding(1) var<uniform> destination: View;

@group(0) @binding(2) var<storage, read> input: array<vec4<f32>>; // (B, T, C)
@group(0) @binding(2) var<storage, read> input_fp32: array<vec4<f32>>; // (B, T, C)
@group(0) @binding(3) var<storage, read> input_fp16: array<vec2<u32>>; // (B, T, C)
@group(0) @binding(4) var<storage, read_write> output: array<vec4<f32>>; // (B, T, C)

Expand All @@ -23,20 +23,20 @@ fn compute_index(view: View, batch: u32, token: u32, index: u32) -> u32 {
return ((view.offset.z + batch) * view.stride.y + view.offset.y + token) * stride + offset + index;
}

fn fetch_input(batch: u32, token: u32, index: u32) -> vec4<f32> {
fn fetch_input_fp32(batch: u32, token: u32, index: u32) -> vec4<f32> {
let _token = select(token, 0u, source.shape.y == 1u);
return input[compute_index(source, batch, _token, index)];
return input_fp32[compute_index(source, batch, _token, index)];
}

@compute @workgroup_size(128, 1, 1)
fn add(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
fn add_fp32(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
let stride = destination.shape.x / 4u;
let index = invocation_id.x;
let token = invocation_id.y;
let batch = invocation_id.z;

if index < stride {
let x = fetch_input(batch, token, index);
let x = fetch_input_fp32(batch, token, index);
let bti = compute_index(destination, batch, token, index);
output[bti] = x + output[bti];
}
Expand Down
File renamed without changes.
14 changes: 7 additions & 7 deletions src/shaders/token_shift.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ struct Cursor {
@group(0) @binding(1) var<uniform> vx: View; // [C, _, B] / [C, 5L, B]
@group(0) @binding(2) var<storage, read> cursors: array<u32>; // [A]

@group(0) @binding(3) var<storage, read> time_mix: array<vec2<u32>>; // (C) | (A, C)
@group(0) @binding(3) var<storage, read> time_mix_fp16: array<vec2<u32>>; // (C) | (A, C)
@group(0) @binding(4) var<storage, read> time_mix_fp32: array<vec4<f32>>; // (C) | (A, C)

@group(0) @binding(5) var<storage, read> x: array<vec4<f32>>; // (1, A, C)
Expand Down Expand Up @@ -42,13 +42,13 @@ fn unpack4x16float(x: vec2<u32>) -> vec4<f32> {
return vec4<f32>(unpack2x16float(x.x), unpack2x16float(x.y));
}

fn fetch_time_mix(stack: u32, index: u32) -> vec4<f32> {
fn fetch_time_mix_fp16(stack: u32, index: u32) -> vec4<f32> {
let token = select(stack, 0u, vt.shape.y == 1u);
return unpack4x16float(time_mix[compute_index(vt, 0u, token, index)]);
return unpack4x16float(time_mix_fp16[compute_index(vt, 0u, token, index)]);
}

@compute @workgroup_size(128, 1, 1)
fn token_shift(@builtin(global_invocation_id) invocation_id: vec3<u32>, @builtin(num_workgroups) num_blocks: vec3<u32>) {
fn token_shift_fp16(@builtin(global_invocation_id) invocation_id: vec3<u32>, @builtin(num_workgroups) num_blocks: vec3<u32>) {
let stride = vx.shape.x / 4u;
let index = invocation_id.x;
let stack = invocation_id.y;
Expand All @@ -60,7 +60,7 @@ fn token_shift(@builtin(global_invocation_id) invocation_id: vec3<u32>, @builtin
}

let bti = stack * stride + index;
let factor = fetch_time_mix(stack, index);
let factor = fetch_time_mix_fp16(stack, index);
if token == 0u {
output[bti] = mix(sx[compute_index(vx, cursor.batch, 0u, index)], x[bti], factor);
} else {
Expand All @@ -69,7 +69,7 @@ fn token_shift(@builtin(global_invocation_id) invocation_id: vec3<u32>, @builtin
}

@compute @workgroup_size(128, 1, 1)
fn token_shift_rev(@builtin(global_invocation_id) invocation_id: vec3<u32>, @builtin(num_workgroups) num_blocks: vec3<u32>) {
fn token_shift_rev_fp16(@builtin(global_invocation_id) invocation_id: vec3<u32>, @builtin(num_workgroups) num_blocks: vec3<u32>) {
let stride = vx.shape.x / 4u;
let index = invocation_id.x;
let stack = invocation_id.y;
Expand All @@ -81,7 +81,7 @@ fn token_shift_rev(@builtin(global_invocation_id) invocation_id: vec3<u32>, @bui
}

let bti = stack * stride + index;
let factor = fetch_time_mix(stack, index);
let factor = fetch_time_mix_fp16(stack, index);
if token == 0u {
output[bti] = mix(x[bti], sx[compute_index(vx, cursor.batch, 0u, index)], factor);
} else {
Expand Down
16 changes: 8 additions & 8 deletions src/tensor/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ impl<'a> TensorOp<'a> {
}

/// Add `input` onto `output`.
pub fn add(
pub fn add_fp32(
input: TensorView<'a, f32>,
output: TensorView<'a, f32>,
) -> Result<Self, TensorError> {
Expand All @@ -642,7 +642,7 @@ impl<'a> TensorOp<'a> {
.or(input.check_shape(shape))?;

let context = &output.tensor.context;
let pipeline = context.pipeline("add")?;
let pipeline = context.pipeline("add_fp32")?;
let bindings = vec![context.device.create_bind_group(&BindGroupDescriptor {
label: None,
layout: &pipeline.get_bind_group_layout(0),
Expand Down Expand Up @@ -723,7 +723,7 @@ impl<'a> TensorOp<'a> {
})
}

pub fn token_shift(
pub fn token_shift_fp16(
cursors: &'a TensorGpu<u32, ReadWrite>,
time_mix: TensorView<'a, f16>,
x: &'a TensorGpu<f32, ReadWrite>,
Expand All @@ -741,8 +741,8 @@ impl<'a> TensorOp<'a> {

let context = &output.context;
let pipeline = match reversed {
true => context.pipeline("token_shift_rev")?,
false => context.pipeline("token_shift")?,
false => context.pipeline("token_shift_fp16")?,
true => context.pipeline("token_shift_rev_fp16")?,
};
let bindings = vec![context.device.create_bind_group(&BindGroupDescriptor {
label: None,
Expand Down Expand Up @@ -804,8 +804,8 @@ impl<'a> TensorOp<'a> {

let context = &output.context;
let pipeline = match reversed {
true => context.pipeline("token_shift_rev_fp32")?,
false => context.pipeline("token_shift_fp32")?,
true => context.pipeline("token_shift_rev_fp32")?,
};
let bindings = vec![context.device.create_bind_group(&BindGroupDescriptor {
label: None,
Expand Down Expand Up @@ -850,7 +850,7 @@ impl<'a> TensorOp<'a> {
}

#[allow(clippy::too_many_arguments)]
pub fn time_mix(
pub fn time_mix_v4(
cursors: &'a TensorGpu<u32, ReadWrite>,
time_decay: &'a TensorGpu<f32, ReadWrite>,
time_first: &'a TensorGpu<f32, ReadWrite>,
Expand All @@ -871,7 +871,7 @@ impl<'a> TensorOp<'a> {
state.check_shape(Shape::new(shape[0], 4, num_batch, 1))?;

let context = &x.context;
let pipeline = context.pipeline("time_mix")?;
let pipeline = context.pipeline("time_mix_v4")?;
let bindings = vec![context.device.create_bind_group(&BindGroupDescriptor {
label: None,
layout: &pipeline.get_bind_group_layout(0),
Expand Down

0 comments on commit 520b355

Please sign in to comment.