From 520b3550b709e74d810e3bb96a9fc931b4cee388 Mon Sep 17 00:00:00 2001 From: cryscan Date: Fri, 15 Dec 2023 00:47:51 +0800 Subject: [PATCH] Format op namings. --- src/context.rs | 19 ++++++++++++------- src/model/v4.rs | 16 ++++++++-------- src/model/v5.rs | 16 ++++++++-------- src/model/v6.rs | 10 +++++----- src/shaders/add.wgsl | 10 +++++----- .../{time_mix.wgsl => time_mix_v4.wgsl} | 0 src/shaders/token_shift.wgsl | 14 +++++++------- src/tensor/ops.rs | 16 ++++++++-------- 8 files changed, 53 insertions(+), 48 deletions(-) rename src/shaders/{time_mix.wgsl => time_mix_v4.wgsl} (100%) diff --git a/src/context.rs b/src/context.rs index 83f1565..fd4b9ac 100644 --- a/src/context.rs +++ b/src/context.rs @@ -245,15 +245,15 @@ impl<'a> ContextBuilder<'a> { None, ) .with_pipeline( - "token_shift", + "token_shift_fp16", include_str!("shaders/token_shift.wgsl"), - "token_shift", + "token_shift_fp16", None, ) .with_pipeline( - "token_shift_rev", + "token_shift_rev_fp16", include_str!("shaders/token_shift.wgsl"), - "token_shift_rev", + "token_shift_rev_fp16", None, ) .with_pipeline( @@ -269,8 +269,8 @@ impl<'a> ContextBuilder<'a> { None, ) .with_pipeline( - "time_mix", - include_str!("shaders/time_mix.wgsl"), + "time_mix_v4", + include_str!("shaders/time_mix_v4.wgsl"), "time_mix", None, ) @@ -286,7 +286,12 @@ impl<'a> ContextBuilder<'a> { "time_mix", None, ) - .with_pipeline("add", include_str!("shaders/add.wgsl"), "add", None) + .with_pipeline( + "add_fp32", + include_str!("shaders/add.wgsl"), + "add_fp32", + None, + ) .with_pipeline( "add_fp16", include_str!("shaders/add.wgsl"), diff --git a/src/model/v4.rs b/src/model/v4.rs index dedb616..95f7259 100644 --- a/src/model/v4.rs +++ b/src/model/v4.rs @@ -736,7 +736,7 @@ impl ModelRun for Model<'_> { &layer.att_layer_norm.b, &buffer.att_x, )?, - TensorOp::token_shift( + TensorOp::token_shift_fp16( &buffer.cursors, layer.att.time_mix_k.view(.., .., .., ..)?, &buffer.att_x, @@ -744,7 +744,7 @@ impl ModelRun for Model<'_> { &buffer.att_kx, false, )?, - TensorOp::token_shift( + TensorOp::token_shift_fp16( &buffer.cursors, layer.att.time_mix_v.view(.., .., .., ..)?, &buffer.att_x, @@ -752,7 +752,7 @@ impl ModelRun for Model<'_> { &buffer.att_vx, false, )?, - TensorOp::token_shift( + TensorOp::token_shift_fp16( &buffer.cursors, layer.att.time_mix_r.view(.., .., .., ..)?, &buffer.att_x, @@ -778,7 +778,7 @@ impl ModelRun for Model<'_> { buffer.att_r.view(.., .., .., ..)?, turbo, )?, - TensorOp::time_mix( + TensorOp::time_mix_v4( &buffer.cursors, &layer.att.time_decay, &layer.att.time_first, @@ -793,7 +793,7 @@ impl ModelRun for Model<'_> { buffer.att_x.view(.., .., .., ..)?, buffer.att_o.view(.., .., .., ..)?, )?, - TensorOp::add( + TensorOp::add_fp32( buffer.input.view(.., .., .., ..)?, buffer.att_o.view(.., .., .., ..)?, )?, @@ -811,7 +811,7 @@ impl ModelRun for Model<'_> { &layer.ffn_layer_norm.b, &buffer.ffn_x, )?, - TensorOp::token_shift( + TensorOp::token_shift_fp16( &buffer.cursors, layer.ffn.time_mix_k.view(.., .., .., ..)?, &buffer.ffn_x, @@ -819,7 +819,7 @@ impl ModelRun for Model<'_> { &buffer.ffn_kx, false, )?, - TensorOp::token_shift( + TensorOp::token_shift_fp16( &buffer.cursors, layer.ffn.time_mix_r.view(.., .., .., ..)?, &buffer.ffn_x, @@ -853,7 +853,7 @@ impl ModelRun for Model<'_> { &buffer.ffn_x, state.ffn(index)?, )?, - TensorOp::add( + TensorOp::add_fp32( buffer.att_o.view(.., .., .., ..)?, buffer.ffn_x.view(.., .., .., ..)?, )?, diff --git a/src/model/v5.rs b/src/model/v5.rs index 0582199..f18e9ca 100644 --- a/src/model/v5.rs +++ b/src/model/v5.rs @@ -872,7 +872,7 @@ impl ModelRun for Model<'_> { &layer.att_layer_norm.b, &buffer.att_x, )?, - TensorOp::token_shift( + TensorOp::token_shift_fp16( &buffer.cursors, layer.att.time_mix_k.view(.., .., .., ..)?, &buffer.att_x, @@ -880,7 +880,7 @@ impl ModelRun for Model<'_> { &buffer.att_kx, false, )?, - TensorOp::token_shift( + TensorOp::token_shift_fp16( &buffer.cursors, layer.att.time_mix_v.view(.., .., .., ..)?, &buffer.att_x, @@ -888,7 +888,7 @@ impl ModelRun for Model<'_> { &buffer.att_vx, false, )?, - TensorOp::token_shift( + TensorOp::token_shift_fp16( &buffer.cursors, layer.att.time_mix_r.view(.., .., .., ..)?, &buffer.att_x, @@ -896,7 +896,7 @@ impl ModelRun for Model<'_> { &buffer.att_rx, false, )?, - TensorOp::token_shift( + TensorOp::token_shift_fp16( &buffer.cursors, layer.att.time_mix_g.view(.., .., .., ..)?, &buffer.att_x, @@ -945,7 +945,7 @@ impl ModelRun for Model<'_> { buffer.att_x.view(.., .., .., ..)?, buffer.att_o.view(.., .., .., ..)?, )?, - TensorOp::add( + TensorOp::add_fp32( buffer.input.view(.., .., .., ..)?, buffer.att_o.view(.., .., .., ..)?, )?, @@ -963,7 +963,7 @@ impl ModelRun for Model<'_> { &layer.ffn_layer_norm.b, &buffer.ffn_x, )?, - TensorOp::token_shift( + TensorOp::token_shift_fp16( &buffer.cursors, layer.ffn.time_mix_k.view(.., .., .., ..)?, &buffer.ffn_x, @@ -971,7 +971,7 @@ impl ModelRun for Model<'_> { &buffer.ffn_kx, false, )?, - TensorOp::token_shift( + TensorOp::token_shift_fp16( &buffer.cursors, layer.ffn.time_mix_r.view(.., .., .., ..)?, &buffer.ffn_x, @@ -1005,7 +1005,7 @@ impl ModelRun for Model<'_> { &buffer.ffn_x, state.ffn(index)?, )?, - TensorOp::add( + TensorOp::add_fp32( buffer.att_o.view(.., .., .., ..)?, buffer.ffn_x.view(.., .., .., ..)?, )?, diff --git a/src/model/v6.rs b/src/model/v6.rs index 2e8c579..dcacef4 100644 --- a/src/model/v6.rs +++ b/src/model/v6.rs @@ -929,7 +929,7 @@ impl ModelRun for Model<'_> { &layer.att_layer_norm.b, &buffer.att_x, )?, - TensorOp::token_shift( + TensorOp::token_shift_fp16( &buffer.cursors, layer.att.time_mix_x.view(.., .., .., ..)?, &buffer.att_x, @@ -1073,7 +1073,7 @@ impl ModelRun for Model<'_> { buffer.att_x.view(.., .., .., ..)?, buffer.att_o.view(.., .., .., ..)?, )?, - TensorOp::add( + TensorOp::add_fp32( buffer.input.view(.., .., .., ..)?, buffer.att_o.view(.., .., .., ..)?, )?, @@ -1091,7 +1091,7 @@ impl ModelRun for Model<'_> { &layer.ffn_layer_norm.b, &buffer.ffn_x, )?, - TensorOp::token_shift( + TensorOp::token_shift_fp16( &buffer.cursors, layer.ffn.time_mix_k.view(.., .., .., ..)?, &buffer.ffn_x, @@ -1099,7 +1099,7 @@ impl ModelRun for Model<'_> { &buffer.ffn_kx, true, )?, - TensorOp::token_shift( + TensorOp::token_shift_fp16( &buffer.cursors, layer.ffn.time_mix_r.view(.., .., .., ..)?, &buffer.ffn_x, @@ -1133,7 +1133,7 @@ impl ModelRun for Model<'_> { &buffer.ffn_x, state.ffn(index)?, )?, - TensorOp::add( + TensorOp::add_fp32( buffer.att_o.view(.., .., .., ..)?, buffer.ffn_x.view(.., .., .., ..)?, )?, diff --git a/src/shaders/add.wgsl b/src/shaders/add.wgsl index 74e33d9..b9f4be4 100644 --- a/src/shaders/add.wgsl +++ b/src/shaders/add.wgsl @@ -7,7 +7,7 @@ struct View { @group(0) @binding(0) var source: View; @group(0) @binding(1) var destination: View; -@group(0) @binding(2) var input: array>; // (B, T, C) +@group(0) @binding(2) var input_fp32: array>; // (B, T, C) @group(0) @binding(3) var input_fp16: array>; // (B, T, C) @group(0) @binding(4) var output: array>; // (B, T, C) @@ -23,20 +23,20 @@ fn compute_index(view: View, batch: u32, token: u32, index: u32) -> u32 { return ((view.offset.z + batch) * view.stride.y + view.offset.y + token) * stride + offset + index; } -fn fetch_input(batch: u32, token: u32, index: u32) -> vec4 { +fn fetch_input_fp32(batch: u32, token: u32, index: u32) -> vec4 { let _token = select(token, 0u, source.shape.y == 1u); - return input[compute_index(source, batch, _token, index)]; + return input_fp32[compute_index(source, batch, _token, index)]; } @compute @workgroup_size(128, 1, 1) -fn add(@builtin(global_invocation_id) invocation_id: vec3) { +fn add_fp32(@builtin(global_invocation_id) invocation_id: vec3) { let stride = destination.shape.x / 4u; let index = invocation_id.x; let token = invocation_id.y; let batch = invocation_id.z; if index < stride { - let x = fetch_input(batch, token, index); + let x = fetch_input_fp32(batch, token, index); let bti = compute_index(destination, batch, token, index); output[bti] = x + output[bti]; } diff --git a/src/shaders/time_mix.wgsl b/src/shaders/time_mix_v4.wgsl similarity index 100% rename from src/shaders/time_mix.wgsl rename to src/shaders/time_mix_v4.wgsl diff --git a/src/shaders/token_shift.wgsl b/src/shaders/token_shift.wgsl index bd591cd..df08304 100644 --- a/src/shaders/token_shift.wgsl +++ b/src/shaders/token_shift.wgsl @@ -14,7 +14,7 @@ struct Cursor { @group(0) @binding(1) var vx: View; // [C, _, B] / [C, 5L, B] @group(0) @binding(2) var cursors: array; // [A] -@group(0) @binding(3) var time_mix: array>; // (C) | (A, C) +@group(0) @binding(3) var time_mix_fp16: array>; // (C) | (A, C) @group(0) @binding(4) var time_mix_fp32: array>; // (C) | (A, C) @group(0) @binding(5) var x: array>; // (1, A, C) @@ -42,13 +42,13 @@ fn unpack4x16float(x: vec2) -> vec4 { return vec4(unpack2x16float(x.x), unpack2x16float(x.y)); } -fn fetch_time_mix(stack: u32, index: u32) -> vec4 { +fn fetch_time_mix_fp16(stack: u32, index: u32) -> vec4 { let token = select(stack, 0u, vt.shape.y == 1u); - return unpack4x16float(time_mix[compute_index(vt, 0u, token, index)]); + return unpack4x16float(time_mix_fp16[compute_index(vt, 0u, token, index)]); } @compute @workgroup_size(128, 1, 1) -fn token_shift(@builtin(global_invocation_id) invocation_id: vec3, @builtin(num_workgroups) num_blocks: vec3) { +fn token_shift_fp16(@builtin(global_invocation_id) invocation_id: vec3, @builtin(num_workgroups) num_blocks: vec3) { let stride = vx.shape.x / 4u; let index = invocation_id.x; let stack = invocation_id.y; @@ -60,7 +60,7 @@ fn token_shift(@builtin(global_invocation_id) invocation_id: vec3, @builtin } let bti = stack * stride + index; - let factor = fetch_time_mix(stack, index); + let factor = fetch_time_mix_fp16(stack, index); if token == 0u { output[bti] = mix(sx[compute_index(vx, cursor.batch, 0u, index)], x[bti], factor); } else { @@ -69,7 +69,7 @@ fn token_shift(@builtin(global_invocation_id) invocation_id: vec3, @builtin } @compute @workgroup_size(128, 1, 1) -fn token_shift_rev(@builtin(global_invocation_id) invocation_id: vec3, @builtin(num_workgroups) num_blocks: vec3) { +fn token_shift_rev_fp16(@builtin(global_invocation_id) invocation_id: vec3, @builtin(num_workgroups) num_blocks: vec3) { let stride = vx.shape.x / 4u; let index = invocation_id.x; let stack = invocation_id.y; @@ -81,7 +81,7 @@ fn token_shift_rev(@builtin(global_invocation_id) invocation_id: vec3, @bui } let bti = stack * stride + index; - let factor = fetch_time_mix(stack, index); + let factor = fetch_time_mix_fp16(stack, index); if token == 0u { output[bti] = mix(x[bti], sx[compute_index(vx, cursor.batch, 0u, index)], factor); } else { diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs index 8ab67e0..3bd4f81 100644 --- a/src/tensor/ops.rs +++ b/src/tensor/ops.rs @@ -632,7 +632,7 @@ impl<'a> TensorOp<'a> { } /// Add `input` onto `output`. - pub fn add( + pub fn add_fp32( input: TensorView<'a, f32>, output: TensorView<'a, f32>, ) -> Result { @@ -642,7 +642,7 @@ impl<'a> TensorOp<'a> { .or(input.check_shape(shape))?; let context = &output.tensor.context; - let pipeline = context.pipeline("add")?; + let pipeline = context.pipeline("add_fp32")?; let bindings = vec![context.device.create_bind_group(&BindGroupDescriptor { label: None, layout: &pipeline.get_bind_group_layout(0), @@ -723,7 +723,7 @@ impl<'a> TensorOp<'a> { }) } - pub fn token_shift( + pub fn token_shift_fp16( cursors: &'a TensorGpu, time_mix: TensorView<'a, f16>, x: &'a TensorGpu, @@ -741,8 +741,8 @@ impl<'a> TensorOp<'a> { let context = &output.context; let pipeline = match reversed { - true => context.pipeline("token_shift_rev")?, - false => context.pipeline("token_shift")?, + false => context.pipeline("token_shift_fp16")?, + true => context.pipeline("token_shift_rev_fp16")?, }; let bindings = vec![context.device.create_bind_group(&BindGroupDescriptor { label: None, @@ -804,8 +804,8 @@ impl<'a> TensorOp<'a> { let context = &output.context; let pipeline = match reversed { - true => context.pipeline("token_shift_rev_fp32")?, false => context.pipeline("token_shift_fp32")?, + true => context.pipeline("token_shift_rev_fp32")?, }; let bindings = vec![context.device.create_bind_group(&BindGroupDescriptor { label: None, @@ -850,7 +850,7 @@ impl<'a> TensorOp<'a> { } #[allow(clippy::too_many_arguments)] - pub fn time_mix( + pub fn time_mix_v4( cursors: &'a TensorGpu, time_decay: &'a TensorGpu, time_first: &'a TensorGpu, @@ -871,7 +871,7 @@ impl<'a> TensorOp<'a> { state.check_shape(Shape::new(shape[0], 4, num_batch, 1))?; let context = &x.context; - let pipeline = context.pipeline("time_mix")?; + let pipeline = context.pipeline("time_mix_v4")?; let bindings = vec![context.device.create_bind_group(&BindGroupDescriptor { label: None, layout: &pipeline.get_bind_group_layout(0),