Skip to content

Commit

Permalink
Tensor op execute (#31)
Browse files (browse the repository at this point in the history)
* Switch to the new tensor op encoding API.

* Reduce sep frequency.
Loading branch information
cryscan authored May 10, 2024
1 parent 0f9ae19 commit d771285
Show file tree
Hide file tree
Showing 15 changed files with 288 additions and 472 deletions.
19 changes: 4 additions & 15 deletions examples/inspector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,7 @@ use web_rwkv::{
v5, Build, BuildFuture, ContextAutoLimits, Model, ModelBuilder, ModelInfo, ModelInput,
ModelOutput, ModelState, ModelVersion, Quant, StateBuilder,
},
tensor::{
kind::ReadWrite,
ops::{TensorOp, TensorPass},
TensorError, TensorGpu, TensorShape,
},
tensor::{kind::ReadWrite, ops::TensorOp, TensorError, TensorGpu, TensorShape},
tokenizer::Tokenizer,
};

Expand Down Expand Up @@ -244,10 +240,8 @@ async fn run(cli: Cli) -> Result<()> {
}

// map the activations into vocab space
let mut encoder = context.device.create_command_encoder(&Default::default());

let tensor = model.tensor();
let ops = TensorOp::List(vec![
let ops = vec![
TensorOp::layer_norm(
&tensor.head.layer_norm.w,
&tensor.head.layer_norm.b,
Expand All @@ -259,13 +253,8 @@ async fn run(cli: Cli) -> Result<()> {
buffer.out.view(.., .., .., ..)?,
Default::default(),
)?,
]);

let mut pass = encoder.begin_compute_pass(&Default::default());
pass.execute_tensor_op(&ops);
drop(pass);

context.queue.submit(Some(encoder.finish()));
];
context.queue.submit(context.encode(&TensorOp::List(ops)));

// for each layer, choose the top 5 tokens
let backed = buffer.out.back().await.to_vec();
Expand Down
95 changes: 37 additions & 58 deletions src/model/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
tensor::{
kind::ReadWrite,
matrix::Matrix,
ops::{TensorCommand, TensorOp, TensorPass},
ops::TensorOp,
shape::{Shape, TensorDimension},
TensorCpu, TensorError, TensorGpu, TensorInit, TensorInto, TensorReshape, TensorShape,
},
Expand Down Expand Up @@ -339,7 +339,7 @@ impl<R: Reader> Loader<R> {
.reshape(Auto, Dimension(1), Dimension(1), Dimension(1))?
.transfer_into(context);

let mut encoder = context.device.create_command_encoder(&Default::default());
let mut ops = vec![];
for lora in self.lora_vectors(name).await? {
let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0];
let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
Expand All @@ -353,10 +353,10 @@ impl<R: Reader> Loader<R> {
)?;

let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.execute_tensor_op(&op);
ops.push(op);
}
self.context.queue.submit(Some(encoder.finish()));

context.queue.submit(context.encode(&TensorOp::List(ops)));
Ok(tensor)
}

Expand All @@ -373,7 +373,7 @@ impl<R: Reader> Loader<R> {
.reshape(Auto, Dimension(1), Dimension(1), Dimension(1))?
.transfer_into(context);

let mut encoder = context.device.create_command_encoder(&Default::default());
let mut ops = vec![];
for lora in self.lora_vectors(name).await? {
let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0];
let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
Expand All @@ -387,16 +387,13 @@ impl<R: Reader> Loader<R> {
)?;

let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.execute_tensor_op(&op);
ops.push(op);
}

let op = TensorOp::opposite_exp(&tensor)?;
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.execute_tensor_op(&op);
drop(pass);
ops.push(op);

self.context.queue.submit(Some(encoder.finish()));
context.queue.submit(context.encode(&TensorOp::List(ops)));
Ok(tensor)
}

Expand All @@ -414,7 +411,7 @@ impl<R: Reader> Loader<R> {
.reshape(Auto, Dimension(1), Dimension(1), Dimension(1))?
.transfer_into(context);

let mut encoder = context.device.create_command_encoder(&Default::default());
let mut ops = vec![];
for lora in self.lora_vectors(name).await? {
let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0];
let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
Expand All @@ -428,16 +425,13 @@ impl<R: Reader> Loader<R> {
)?;

let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.execute_tensor_op(&op);
ops.push(op);
}

let op = TensorOp::stable_exp(&tensor)?;
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.execute_tensor_op(&op);
drop(pass);
ops.push(op);

self.context.queue.submit(Some(encoder.finish()));
context.queue.submit(context.encode(&TensorOp::List(ops)));
Ok(tensor)
}

Expand All @@ -460,7 +454,7 @@ impl<R: Reader> Loader<R> {
.transfer_into(context);
let tensor_f16: TensorGpu<f16, _> = context.tensor_init(tensor_f32.shape());

let mut encoder = context.device.create_command_encoder(&Default::default());
let mut ops = vec![];
for lora in lora {
let factor = vec![lora.alpha, 1.0 - lora.alpha, 0.0, 0.0];
let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
Expand All @@ -474,19 +468,16 @@ impl<R: Reader> Loader<R> {
)?;

let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.execute_tensor_op(&op);
ops.push(op);
}

let op = TensorOp::blit(
tensor_f32.view(.., .., .., ..)?,
tensor_f16.view(.., .., .., ..)?,
)?;
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.execute_tensor_op(&op);
drop(pass);
ops.push(op);

context.queue.submit(Some(encoder.finish()));
context.queue.submit(context.encode(&TensorOp::List(ops)));
tensor_f16
};
Ok(tensor)
Expand All @@ -500,7 +491,7 @@ impl<R: Reader> Loader<R> {
let tensor = self.model.tensor(name.as_ref()).await?;
let tensor: TensorGpu<_, _> = TensorCpu::from_reader(tensor)?.transfer_into(context);

let mut encoder = context.device.create_command_encoder(&Default::default());
let mut ops = vec![];
for lora in self.lora_matrices(name.as_ref()).await? {
let factor = vec![lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0];
let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
Expand All @@ -510,18 +501,16 @@ impl<R: Reader> Loader<R> {
lora.y.view(.., .., .., ..)?,
tensor.view(.., .., .., ..)?,
)?;
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.execute_tensor_op(&op);
ops.push(op);
}
for lora in self.lora_vectors(name.as_ref()).await? {
let factor = vec![lora.alpha, 1.0, 0.0, 0.0];
let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.execute_tensor_op(&op);
ops.push(op);
}
context.queue.submit(Some(encoder.finish()));

context.queue.submit(context.encode(&TensorOp::List(ops)));
Ok(tensor)
}

Expand All @@ -536,7 +525,7 @@ impl<R: Reader> Loader<R> {
.map(|x| f16::from_f32(discount * x.to_f32()))
.transfer_into(context);

let mut encoder = context.device.create_command_encoder(&Default::default());
let mut ops = vec![];
for lora in self.lora_matrices(name.as_ref()).await? {
let factor = vec![discount * lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0];
let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
Expand All @@ -546,18 +535,16 @@ impl<R: Reader> Loader<R> {
lora.y.view(.., .., .., ..)?,
tensor.view(.., .., .., ..)?,
)?;
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.execute_tensor_op(&op);
ops.push(op);
}
for lora in self.lora_vectors(name.as_ref()).await? {
let factor = vec![discount * lora.alpha, 1.0, 0.0, 0.0];
let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.execute_tensor_op(&op);
ops.push(op);
}
context.queue.submit(Some(encoder.finish()));

context.queue.submit(context.encode(&TensorOp::List(ops)));
Ok(tensor)
}

Expand All @@ -571,7 +558,7 @@ impl<R: Reader> Loader<R> {
let tensor = TensorCpu::from_reader(tensor)?;
matrix.load(&tensor)?;

let mut encoder = context.device.create_command_encoder(&Default::default());
let mut ops = vec![];
for lora in self.lora_matrices(name.as_ref()).await? {
let factor = vec![lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0];
let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
Expand All @@ -581,18 +568,16 @@ impl<R: Reader> Loader<R> {
lora.y.view(.., .., .., ..)?,
matrix.view(.., .., .., ..)?,
)?;
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.execute_tensor_op(&op);
ops.push(op);
}
for lora in self.lora_vectors(name.as_ref()).await? {
let factor = vec![lora.alpha, 1.0, 0.0, 0.0];
let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
let op = TensorOp::blend(&factor, &lora.tensor, matrix)?;
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.execute_tensor_op(&op);
ops.push(op);
}
context.queue.submit(Some(encoder.finish()));

context.queue.submit(context.encode(&TensorOp::List(ops)));
Ok(())
}

Expand All @@ -611,7 +596,7 @@ impl<R: Reader> Loader<R> {
.reshape(Full, Full, Dimension(1), Dimension(1))?;
matrix.load(&tensor)?;

let mut encoder = context.device.create_command_encoder(&Default::default());
let mut ops = vec![];
for lora in self.lora_matrices(name.as_ref()).await? {
let factor = vec![discount * lora.alpha / lora.rank as f32, 1.0, 0.0, 0.0];
let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
Expand All @@ -621,18 +606,16 @@ impl<R: Reader> Loader<R> {
lora.y.view(.., .., .., ..)?,
matrix.view(.., .., .., ..)?,
)?;
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.execute_tensor_op(&op);
ops.push(op);
}
for lora in self.lora_vectors(name.as_ref()).await? {
let factor = vec![discount * lora.alpha, 1.0, 0.0, 0.0];
let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
let op = TensorOp::blend(&factor, &lora.tensor, matrix)?;
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.execute_tensor_op(&op);
ops.push(op);
}
context.queue.submit(Some(encoder.finish()));

context.queue.submit(context.encode(&TensorOp::List(ops)));
Ok(())
}

Expand All @@ -648,20 +631,16 @@ impl<R: Reader> Loader<R> {
Ok(tensor)
} else {
let tensor = TensorCpu::from_reader((dt, shape, tensor))?.transfer_into(context);
let mut encoder = context.device.create_command_encoder(&Default::default());
let mut ops = vec![];
for lora in lora {
let factor = vec![lora.alpha, 1.0, 0.0, 0.0];
let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.execute_tensor_op(&op);
ops.push(op);
}

let map = context.tensor_init(tensor.shape());
encoder.copy_tensor(&tensor, &map)?;

context.queue.submit(Some(encoder.finish()));
Ok(map.back().await)
context.queue.submit(context.encode(&TensorOp::List(ops)));
Ok(tensor.back().await)
}
}

Expand Down
17 changes: 3 additions & 14 deletions src/model/softmax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ use super::{ModelBase, ModelInfo, ModelOutput};
use crate::{
context::Context,
tensor::{
kind::ReadWrite,
ops::{TensorOp, TensorPass},
shape::Shape,
TensorCpu, TensorError, TensorGpu, TensorInit, TensorShape,
kind::ReadWrite, ops::TensorOp, shape::Shape, TensorCpu, TensorError, TensorGpu,
TensorInit, TensorShape,
},
};

Expand Down Expand Up @@ -75,16 +73,7 @@ impl<M: ModelBase> ModelSoftmax for M {
softmax.buffer.load(&input)?;

let op = TensorOp::softmax(&softmax.buffer)?;

let mut encoder = self
.context()
.device
.create_command_encoder(&Default::default());

let mut pass = encoder.begin_compute_pass(&Default::default());
pass.execute_tensor_op(&op);
drop(pass);
context.queue.submit(Some(encoder.finish()));
context.queue.submit(context.encode(&op));

let output = softmax.buffer.back().await;
Ok(redirect
Expand Down
Loading

0 comments on commit d771285

Please sign in to comment.