Skip to content

Commit

Permalink
Do not load LoRA for embed.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Sep 22, 2024
1 parent d3377b3 commit 47d72bd
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 45 deletions.
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ trait-variant = "0.1"
uid = "0.1"
wasm-bindgen = "0.2"
wgpu = "22.1.0"
pollster = "0.3"

[dependencies.web-rwkv-derive]
path = "crates/web-rwkv-derive"
Expand All @@ -47,7 +46,7 @@ version = "0.2.5"
[dependencies.tokio]
default-features = false
features = ["macros", "rt", "sync", "time"]
version = "1.37"
version = "1.40"

[dev-dependencies]
cbor4ii = { version = "0.3.2", features = ["half-f16", "serde1"] }
Expand Down
25 changes: 3 additions & 22 deletions src/model/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -608,28 +608,9 @@ impl<R: Reader> Loader<R> {
}

/// Loads the `emb.weight` tensor from the model and returns it as a CPU tensor.
///
/// NOTE(review): this listing is a flattened diff with no `+`/`-` markers.
/// Going by the hunk summary for this file (3 additions, 22 deletions), the
/// first 22 body lines appear to be the implementation removed by the commit
/// and the final 3 body lines the implementation that replaces it.
pub fn load_embed(&self) -> Result<TensorCpu<f16>> {
// --- presumably the removed (pre-commit) body starts here ---
let context = &self.context;
let name = "emb.weight";

let (dt, shape, tensor) = self.model.tensor(name)?;
// Looks up LoRA entries for this tensor name; an empty result selects the plain path below.
let lora = self.lora_vectors(name)?;

if lora.is_empty() {
// No LoRA entries: build the CPU tensor straight from the reader data.
let tensor = TensorCpu::from_reader((dt, shape, tensor))?;
Ok(tensor)
} else {
// LoRA entries present: transfer the tensor into `context` so blend ops can be applied to it.
let tensor = TensorCpu::from_reader((dt, shape, tensor))?.transfer_into(context);
let mut ops = vec![];
for lora in lora {
// Four-element factor uploaded as a [4, 1, 1, 1] tensor.
// NOTE(review): presumably `alpha` weights the LoRA tensor and `1.0` the existing
// tensor — confirm against `TensorOp::blend`.
let factor = vec![lora.alpha, 1.0, 0.0, 0.0];
let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
ops.push(op);
}

// Submit all blend ops as one list, then block the calling thread until `tensor.back()` resolves.
context.queue.submit(context.encode(&TensorOp::List(ops)));
Ok(pollster::block_on(tensor.back()))
}
// --- presumably the added (post-commit) body: no LoRA lookup, no `pollster` call ---
let (dt, shape, tensor) = self.model.tensor("emb.weight")?;
let tensor = TensorCpu::from_reader((dt, shape, tensor))?;
Ok(tensor)
}

pub fn load_head(&self, chunk_size: usize) -> Result<Vec<TensorGpu<f16, ReadWrite>>> {
Expand Down
24 changes: 3 additions & 21 deletions src/runtime/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -607,27 +607,9 @@ impl<R: Reader> Loader<R> {
}

/// Loads the `emb.weight` tensor from the model and returns it as a CPU tensor.
///
/// NOTE(review): this listing is a flattened diff with no `+`/`-` markers.
/// Going by the hunk summary for this file (3 additions, 21 deletions), the
/// first 21 body lines appear to be the implementation removed by the commit
/// and the final 3 body lines the implementation that replaces it.
pub fn load_embed(&self) -> Result<TensorCpu<f16>> {
// --- presumably the removed (pre-commit) body starts here ---
let context = &self.context;
let name = "emb.weight";

let (dt, shape, tensor) = self.model.tensor(name)?;
// Looks up LoRA entries for this tensor name; an empty result selects the plain path below.
let lora = self.lora_vectors(name)?;

if lora.is_empty() {
// No LoRA entries: build the CPU tensor straight from the reader data.
let tensor = TensorCpu::from_reader((dt, shape, tensor))?;
Ok(tensor)
} else {
// LoRA entries present: transfer the tensor into `context` so blend ops can be applied to it.
let tensor = TensorCpu::from_reader((dt, shape, tensor))?.transfer_into(context);
let mut ops = vec![];
for lora in lora {
// Four-element factor uploaded as a [4, 1, 1, 1] tensor.
// NOTE(review): presumably `alpha` weights the LoRA tensor and `1.0` the existing
// tensor — confirm against `TensorOp::blend`.
let factor = vec![lora.alpha, 1.0, 0.0, 0.0];
let factor = context.tensor_from_data([4, 1, 1, 1], factor)?;
let op = TensorOp::blend(&factor, &lora.tensor, &tensor)?;
ops.push(op);
}
// Submit all blend ops as one list, then block the calling thread until `tensor.back()` resolves.
context.queue.submit(context.encode(&TensorOp::List(ops)));
Ok(pollster::block_on(tensor.back()))
}
// --- presumably the added (post-commit) body: no LoRA lookup, no `pollster` call ---
let (dt, shape, tensor) = self.model.tensor("emb.weight")?;
let tensor = TensorCpu::from_reader((dt, shape, tensor))?;
Ok(tensor)
}

pub fn load_head(&self, chunk_size: usize) -> Result<Vec<TensorGpu<f16, ReadWrite>>> {
Expand Down

0 comments on commit 47d72bd

Please sign in to comment.