Skip to content

Commit

Permalink
Make converter a lib feature.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Dec 10, 2023
1 parent ac0b450 commit 8fe45cd
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 108 deletions.
25 changes: 4 additions & 21 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ env:

# Space separated paths to include in the archive.
RELEASE_ADDS: README.md LICENSE assets
RELEASE_BINS: web-rwkv-converter
EXAMPLE_BINS: gen chat batch
EXAMPLE_BINS: gen chat batch converter

jobs:
build:
Expand Down Expand Up @@ -64,19 +63,15 @@ jobs:
if: matrix.build == 'linux'
run: |
rustup target add ${{ env.LINUX_TARGET }}
cargo build --release --workspace --target ${{ env.LINUX_TARGET }}
cargo build --release --examples --target ${{ env.LINUX_TARGET }}
cargo build --release --examples --target ${{ env.LINUX_TARGET }} --all-features
- name: Build (MacOS)
if: matrix.build == 'macos'
run: |
cargo build --release --workspace
cargo build --release --examples
cargo build --release --examples --all-features
- name: Build (Windows)
if: matrix.build == 'windows'
run: |
cargo build --release --workspace
cargo build --release --examples
cargo build --release --examples --all-features
- name: Create artifact directory
run: |
Expand All @@ -85,10 +80,6 @@ jobs:
- name: Create tarball (Linux)
if: matrix.build == 'linux'
run: |
for bin in ${{ env.RELEASE_BINS }}
do
mv ./target/${{ env.LINUX_TARGET }}/release/${bin} ./dist/${bin}
done
for bin in ${{ env.EXAMPLE_BINS }}
do
mv ./target/${{ env.LINUX_TARGET }}/release/examples/${bin} ./dist/${bin}
Expand All @@ -99,10 +90,6 @@ jobs:
if: matrix.build == 'windows'
shell: bash
run: |
for bin in ${{ env.RELEASE_BINS }}
do
mv ./target/release/${bin}.exe ./dist/${bin}.exe
done
for bin in ${{ env.EXAMPLE_BINS }}
do
mv ./target/release/examples/${bin}.exe ./dist/${bin}.exe
Expand All @@ -112,10 +99,6 @@ jobs:
- name: Create tarball (MacOS)
if: matrix.build == 'macos'
run: |
for bin in ${{ env.RELEASE_BINS }}
do
mv ./target/release/${bin} ./dist/${bin}
done
for bin in ${{ env.EXAMPLE_BINS }}
do
mv ./target/release/examples/${bin} ./dist/${bin}
Expand Down
18 changes: 14 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "web-rwkv"
version = "0.4.4"
version = "0.4.5"
edition = "2021"
authors = ["Zhenyuan Zhang <[email protected]>"]
license = "MIT OR Apache-2.0"
Expand All @@ -13,9 +13,6 @@ exclude = ["assets/", "crates/", "screenshots/"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[workspace]
members = ["crates/*"]

[dependencies]
wgpu = "0.18"
bytemuck = { version = "1.13", features = ["extern_crate_alloc"] }
Expand All @@ -36,6 +33,19 @@ itertools = "0.11"
log = "0.4"
web-rwkv-derive = { version = "0.2.0", path = "crates/web-rwkv-derive" }

[dependencies.repugnant-pickle]
git = "https://github.com/KerfuffleV2/repugnant-pickle"
tag = "v0.0.1"
features = ["torch"]
optional = true

[features]
converter = ["dep:repugnant-pickle"]

[[example]]
name = "converter"
required-features = ["converter"]

[dev-dependencies]
pollster = "0.3.0"
memmap2 = "0.7"
Expand Down
4 changes: 0 additions & 4 deletions crates/web-rwkv-converter/.gitignore

This file was deleted.

18 changes: 0 additions & 18 deletions crates/web-rwkv-converter/Cargo.toml

This file was deleted.

6 changes: 3 additions & 3 deletions examples/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ fn load_tokenizer() -> Result<Tokenizer> {
Ok(Tokenizer::new(&contents)?)
}

fn load_model<'a, M: Model>(
fn load_model<M: Model>(
context: &Context,
data: &'a [u8],
data: &[u8],
lora: Option<PathBuf>,
quant: Option<usize>,
quant_nf4: Option<usize>,
Expand All @@ -108,7 +108,7 @@ fn load_model<'a, M: Model>(
let quant_nf4 = quant_nf4
.map(|layer| (0..layer).map(|layer| (layer, Quant::NF4)).collect_vec())
.unwrap_or_default();
let quant = quant.into_iter().chain(quant_nf4.into_iter()).collect();
let quant = quant.into_iter().chain(quant_nf4).collect();
let model = ModelBuilder::new(context, data)
.with_quant(quant)
.with_turbo(turbo);
Expand Down
6 changes: 3 additions & 3 deletions examples/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ fn load_tokenizer() -> Result<Tokenizer> {
Ok(Tokenizer::new(&contents)?)
}

fn load_model<'a, M: Model>(
fn load_model<M: Model>(
context: &Context,
data: &'a [u8],
data: &[u8],
lora: Option<PathBuf>,
quant: Option<usize>,
quant_nf4: Option<usize>,
Expand All @@ -121,7 +121,7 @@ fn load_model<'a, M: Model>(
let quant_nf4 = quant_nf4
.map(|layer| (0..layer).map(|layer| (layer, Quant::NF4)).collect_vec())
.unwrap_or_default();
let quant = quant.into_iter().chain(quant_nf4.into_iter()).collect();
let quant = quant.into_iter().chain(quant_nf4).collect();
let model = ModelBuilder::new(context, data)
.with_quant(quant)
.with_turbo(turbo);
Expand Down
36 changes: 36 additions & 0 deletions examples/converter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use std::{fs::File, path::PathBuf};

use anyhow::Result;
use clap::Parser;
use memmap2::Mmap;
use web_rwkv::converter::{convert_safetensors, RENAME, TRANSPOSE};

// Command-line arguments of the converter example, parsed with clap.
// NOTE: plain `//` comments are used on purpose; `///` doc comments on a clap
// derive struct are picked up as help text and would change the CLI output.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Cli {
// Path of the file to convert (`-i` / `--input`); required.
#[arg(short, long, value_name = "FILE")]
input: PathBuf,
// Optional destination path (`-o` / `--output`). When absent, `main` derives
// it from the input: same parent directory, file stem plus ".st".
#[arg(short, long, value_name = "FILE")]
output: Option<PathBuf>,
}

fn main() -> Result<()> {
let cli = Cli::parse();

let file = File::open(&cli.input)?;
let map = unsafe { Mmap::map(&file)? };

let output = cli.output.unwrap_or_else(|| {
let path = cli
.input
.parent()
.map(|p| p.to_path_buf())
.unwrap_or_default();
let stem = cli.input.file_stem().expect("please name the file");
let name: PathBuf = [&stem.to_string_lossy(), "st"].join(".").into();
path.join(name)
});
convert_safetensors(cli.input, &map, output, RENAME, TRANSPOSE)?;

Ok(())
}
6 changes: 3 additions & 3 deletions examples/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ fn load_tokenizer() -> Result<Tokenizer> {
Ok(Tokenizer::new(&contents)?)
}

fn load_model<'a, M: Model>(
fn load_model<M: Model>(
context: &Context,
data: &'a [u8],
data: &[u8],
lora: Option<PathBuf>,
quant: Option<usize>,
quant_nf4: Option<usize>,
Expand All @@ -97,7 +97,7 @@ fn load_model<'a, M: Model>(
let quant_nf4 = quant_nf4
.map(|layer| (0..layer).map(|layer| (layer, Quant::NF4)).collect_vec())
.unwrap_or_default();
let quant = quant.into_iter().chain(quant_nf4.into_iter()).collect();
let quant = quant.into_iter().chain(quant_nf4).collect();
let model = ModelBuilder::new(context, data)
.with_quant(quant)
.with_turbo(turbo);
Expand Down
86 changes: 34 additions & 52 deletions crates/web-rwkv-converter/src/main.rs → src/converter.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,40 @@
use std::{collections::HashMap, fs::File, path::PathBuf};
use std::{collections::HashMap, path::Path};

use anyhow::Result;
use clap::Parser;
use half::{bf16, f16};
use memmap2::Mmap;
use itertools::Itertools;
use repugnant_pickle::{RepugnantTorchTensors as TorchTensors, TensorType};
use safetensors::{tensor::TensorView, Dtype};

/// Default `(from, to)` name pairs for the `rename` argument of the conversion
/// functions; the converter example passes this table unchanged.
///
/// NOTE(review): presumably each `from` fragment found in a tensor name is
/// replaced by `to`; the code that applies the table is not shown in this
/// hunk, so confirm against `load_tensors`.
pub const RENAME: [(&str, &str); 4] = [
("time_faaaa", "time_first"),
("time_maa", "time_mix"),
("lora_A", "lora.0"),
("lora_B", "lora.1"),
];

/// Default name list for the `transpose` argument of the conversion
/// functions; the converter example passes this table unchanged.
///
/// NOTE(review): presumably tensors whose name matches an entry are transposed
/// during conversion; the matching logic is not shown in this hunk, so confirm
/// against `load_tensors`.
pub const TRANSPOSE: [&str; 4] = [
"time_mix_w1",
"time_mix_w2",
"time_decay_w1",
"time_decay_w2",
];

/// One converted tensor held in memory before serialization.
struct Tensor {
// Key under which the tensor is stored in the output map.
name: String,
// Dimensions handed to `TensorView::new` when building the output view.
shape: Vec<usize>,
// Element data as half-precision floats; reinterpreted as bytes and written
// with `Dtype::F16`.
data: Vec<f16>,
}

fn load_tensors(
data: &[u8],
fn load_tensors<'a, 'b, 'c>(
data: &'a [u8],
torch: TorchTensors,
rename: &[(&str, &str)],
transpose: &[&str],
rename: impl IntoIterator<Item = (&'b str, &'b str)>,
transpose: impl IntoIterator<Item = &'c str>,
) -> Vec<Tensor> {
let mut tensors = vec![];
let rename = rename.into_iter().collect_vec();
let transpose = transpose.into_iter().collect_vec();

for tensor in torch.into_iter() {
let name = rename
Expand Down Expand Up @@ -68,58 +83,25 @@ fn load_tensors(
tensors
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Cli {
#[arg(short, long, value_name = "FILE")]
input: PathBuf,
#[arg(short, long, value_name = "FILE")]
output: Option<PathBuf>,
}

fn main() -> Result<()> {
let cli = Cli::parse();
let tensors = TorchTensors::new_from_file(&cli.input)?;
// print!("{:#?}", tensors);

let file = File::open(&cli.input)?;
let map = unsafe { Mmap::map(&file)? };

let rename = [
("time_faaaa", "time_first"),
("time_maa", "time_mix"),
("lora_A", "lora.0"),
("lora_B", "lora.1"),
];
let transpose = [
"time_mix_w1",
"time_mix_w2",
"time_decay_w1",
"time_decay_w2",
];

let tensors = load_tensors(&map, tensors, &rename, &transpose);
pub fn convert_safetensors<'a, 'b, 'c>(
input: impl AsRef<Path>,
data: &'a [u8],
output: impl AsRef<Path>,
rename: impl IntoIterator<Item = (&'b str, &'b str)>,
transpose: impl IntoIterator<Item = &'c str>,
) -> Result<()> {
let torch = TorchTensors::new_from_file(input)?;
let tensors = load_tensors(data, torch, rename, transpose);
let views = tensors
.iter()
.map(|x| TensorView::new(Dtype::F16, x.shape.clone(), bytemuck::cast_slice(&x.data)))
.collect::<Result<Vec<_>, _>>()?;
let metadata: HashMap<String, TensorView> = tensors
let data = tensors
.iter()
.zip(views)
.map(|(tensor, view)| (tensor.name.clone(), view))
.collect();

let output = cli.output.unwrap_or_else(|| {
let path = cli
.input
.parent()
.map(|p| p.to_path_buf())
.unwrap_or_default();
let stem = cli.input.file_stem().expect("please name the file");
let name: PathBuf = [&stem.to_string_lossy(), "st"].join(".").into();
path.join(name)
});
safetensors::serialize_to_file(&metadata, &None, &output)?;
.collect::<HashMap<_, _>>();

safetensors::serialize_to_file(&data, &None, output.as_ref())?;
Ok(())
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
pub mod context;
#[cfg(feature = "converter")]
pub mod converter;
pub mod model;
pub mod num;
pub mod tensor;
Expand Down

0 comments on commit 8fe45cd

Please sign in to comment.