Commit 9414803: Allow embed on gpu.

cryscan committed Jan 1, 2024
1 parent d3a0d9a
Showing 8 changed files with 206 additions and 66 deletions.
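This commit lets the token embedding be kept on the GPU instead of the CPU: `ModelBuilder` gains a `with_embed_device` option, each example exposes it as an optional `--embed-device` flag that defaults to `cpu`, the chat example raises the wgpu `max_buffer_size` limit so the larger GPU-resident buffers fit, and `ModelInfo` now distinguishes the required buffer size with and without the head matrix. Below is a minimal sketch of the new builder option, mirroring the examples' `load_model` but eliding quantization, turbo, and LoRA; the final `build()` step is assumed from the crate and does not itself appear in this diff.

use anyhow::Result;
use web_rwkv::{
    context::Context,
    model::{EmbedDevice, Model, ModelBuilder},
};

// Load a model with the embedding matrix resident on the GPU.
// `context` and `data` are assumed to come from the usual example setup
// (adapter request and memory-mapped safetensors), elided here.
fn load_model_gpu_embed<M: Model>(context: &Context, data: &[u8]) -> Result<M> {
    let model = ModelBuilder::new(context, data)
        .with_embed_device(EmbedDevice::Gpu)
        .build()?;
    Ok(model)
}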
27 changes: 25 additions & 2 deletions examples/batch.rs
@@ -1,5 +1,5 @@
 use anyhow::Result;
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 #[cfg(not(debug_assertions))]
 use crossterm::terminal::{
     disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen,
@@ -106,6 +106,7 @@ fn load_model<M: Model>(
     lora: Option<PathBuf>,
     quant: Option<usize>,
     quant_nf4: Option<usize>,
+    embed_device: Option<EmbedDevice>,
     turbo: bool,
 ) -> Result<M> {
     let quant = quant
@@ -117,7 +118,8 @@ fn load_model<M: Model>(
     let quant = quant.into_iter().chain(quant_nf4).collect();
     let model = ModelBuilder::new(context, data)
         .with_quant(quant)
-        .with_turbo(turbo);
+        .with_turbo(turbo)
+        .with_embed_device(embed_device.unwrap_or_default().into());
     match lora {
         Some(lora) => {
             let file = File::open(lora)?;
@@ -175,6 +177,7 @@ async fn run(cli: Cli) -> Result<()> {
         cli.lora,
         cli.quant,
         cli.quant_nf4,
+        cli.embed_device,
         cli.turbo,
     )?;
     // The model state should keep the same batch as input.
@@ -192,6 +195,7 @@ async fn run(cli: Cli) -> Result<()> {
         cli.lora,
         cli.quant,
         cli.quant_nf4,
+        cli.embed_device,
         cli.turbo,
     )?;
     // The model state should keep the same batch as input.
@@ -209,6 +213,7 @@ async fn run(cli: Cli) -> Result<()> {
         cli.lora,
         cli.quant,
         cli.quant_nf4,
+        cli.embed_device,
         cli.turbo,
     )?;
     // The model state should keep the same batch as input.
@@ -334,6 +339,22 @@ where
     Ok(())
 }
 
+#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, ValueEnum)]
+enum EmbedDevice {
+    #[default]
+    Cpu,
+    Gpu,
+}
+
+impl From<EmbedDevice> for web_rwkv::model::EmbedDevice {
+    fn from(value: EmbedDevice) -> Self {
+        match value {
+            EmbedDevice::Cpu => Self::Cpu,
+            EmbedDevice::Gpu => Self::Gpu,
+        }
+    }
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Cli {
@@ -345,6 +366,8 @@ struct Cli {
     quant: Option<usize>,
     #[arg(long, value_name = "LAYERS")]
     quant_nf4: Option<usize>,
+    #[arg(short, long)]
+    embed_device: Option<EmbedDevice>,
     #[arg(short, long, action)]
     turbo: bool,
     #[arg(short, long, default_value_t = 4)]
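For reference, a standalone sketch of how the new flag parses with clap's `ValueEnum` derive (variant names map to lower-case values on the command line, so `--embed-device gpu` selects `EmbedDevice::Gpu`):

use clap::{Parser, ValueEnum};

#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, ValueEnum)]
enum EmbedDevice {
    #[default]
    Cpu,
    Gpu,
}

#[derive(Parser, Debug)]
struct Cli {
    /// Device on which to place the embedding matrix.
    #[arg(short, long)]
    embed_device: Option<EmbedDevice>,
}

fn main() {
    // e.g. `cargo run -- --embed-device gpu` prints `Gpu`;
    // omitting the flag falls back to the default, `Cpu`.
    let cli = Cli::parse();
    println!("{:?}", cli.embed_device.unwrap_or_default());
}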
28 changes: 26 additions & 2 deletions examples/chat.rs
@@ -1,5 +1,5 @@
 use anyhow::Result;
-use clap::{Args, Parser};
+use clap::{Args, Parser, ValueEnum};
 #[cfg(not(debug_assertions))]
 use dialoguer::{theme::ColorfulTheme, Select};
 use itertools::Itertools;
@@ -74,6 +74,7 @@ async fn create_context(info: &ModelInfo) -> Result<Context> {
     let instance = Instance::new();
     let limits = wgpu::Limits {
         max_storage_buffer_binding_size: info.max_buffer_size() as u32,
+        max_buffer_size: info.max_buffer_size() as u64,
         ..Default::default()
     };
     #[cfg(not(debug_assertions))]
@@ -118,6 +119,7 @@ fn load_model<M: Model>(
     lora: Option<PathBuf>,
     quant: Option<usize>,
     quant_nf4: Option<usize>,
+    embed_device: Option<EmbedDevice>,
     turbo: bool,
 ) -> Result<M> {
     let quant = quant
@@ -129,7 +131,8 @@ fn load_model<M: Model>(
     let quant = quant.into_iter().chain(quant_nf4).collect();
     let model = ModelBuilder::new(context, data)
         .with_quant(quant)
-        .with_turbo(turbo);
+        .with_turbo(turbo)
+        .with_embed_device(embed_device.unwrap_or_default().into());
     match lora {
         Some(lora) => {
             let file = File::open(lora)?;
@@ -197,6 +200,7 @@ async fn run(cli: Cli) -> Result<()> {
         cli.lora,
         cli.quant,
         cli.quant_nf4,
+        cli.embed_device,
         cli.turbo,
     )?;
     let state: v4::ModelState = StateBuilder::new(&context, model.info()).build();
@@ -209,6 +213,7 @@ async fn run(cli: Cli) -> Result<()> {
         cli.lora,
         cli.quant,
         cli.quant_nf4,
+        cli.embed_device,
         cli.turbo,
     )?;
     let state: v5::ModelState = StateBuilder::new(&context, model.info()).build();
@@ -221,6 +226,7 @@ async fn run(cli: Cli) -> Result<()> {
         cli.lora,
         cli.quant,
         cli.quant_nf4,
+        cli.embed_device,
         cli.turbo,
     )?;
     let state: v6::ModelState = StateBuilder::new(&context, model.info()).build();
@@ -336,6 +342,22 @@ where
     Ok(())
 }
 
+#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, ValueEnum)]
+enum EmbedDevice {
+    #[default]
+    Cpu,
+    Gpu,
+}
+
+impl From<EmbedDevice> for web_rwkv::model::EmbedDevice {
+    fn from(value: EmbedDevice) -> Self {
+        match value {
+            EmbedDevice::Cpu => Self::Cpu,
+            EmbedDevice::Gpu => Self::Gpu,
+        }
+    }
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Cli {
@@ -349,6 +371,8 @@ struct Cli {
     quant: Option<usize>,
     #[arg(long, value_name = "LAYERS")]
     quant_nf4: Option<usize>,
+    #[arg(short, long)]
+    embed_device: Option<EmbedDevice>,
     #[arg(short, long, action)]
     turbo: bool,
     #[command(flatten)]
27 changes: 25 additions & 2 deletions examples/gen.rs
@@ -1,5 +1,5 @@
 use anyhow::Result;
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 #[cfg(not(debug_assertions))]
 use dialoguer::{theme::ColorfulTheme, Select};
 use itertools::Itertools;
@@ -95,6 +95,7 @@ fn load_model<M: Model>(
     lora: Option<PathBuf>,
     quant: Option<usize>,
     quant_nf4: Option<usize>,
+    embed_device: Option<EmbedDevice>,
     turbo: bool,
 ) -> Result<M> {
     let quant = quant
@@ -106,7 +107,8 @@ fn load_model<M: Model>(
     let quant = quant.into_iter().chain(quant_nf4).collect();
     let model = ModelBuilder::new(context, data)
         .with_quant(quant)
-        .with_turbo(turbo);
+        .with_turbo(turbo)
+        .with_embed_device(embed_device.unwrap_or_default().into());
     match lora {
         Some(lora) => {
             let file = File::open(lora)?;
@@ -149,6 +151,7 @@ async fn run(cli: Cli) -> Result<()> {
         cli.lora,
         cli.quant,
         cli.quant_nf4,
+        cli.embed_device,
         cli.turbo,
     )?;
     let state: v4::ModelState = StateBuilder::new(&context, model.info()).build();
@@ -161,6 +164,7 @@ async fn run(cli: Cli) -> Result<()> {
         cli.lora,
         cli.quant,
         cli.quant_nf4,
+        cli.embed_device,
         cli.turbo,
     )?;
     let state: v5::ModelState = StateBuilder::new(&context, model.info()).build();
@@ -173,6 +177,7 @@ async fn run(cli: Cli) -> Result<()> {
         cli.lora,
         cli.quant,
         cli.quant_nf4,
+        cli.embed_device,
         cli.turbo,
     )?;
     let state: v6::ModelState = StateBuilder::new(&context, model.info()).build();
@@ -217,6 +222,22 @@ where
     Ok(())
 }
 
+#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, ValueEnum)]
+enum EmbedDevice {
+    #[default]
+    Cpu,
+    Gpu,
+}
+
+impl From<EmbedDevice> for web_rwkv::model::EmbedDevice {
+    fn from(value: EmbedDevice) -> Self {
+        match value {
+            EmbedDevice::Cpu => Self::Cpu,
+            EmbedDevice::Gpu => Self::Gpu,
+        }
+    }
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Cli {
@@ -228,6 +249,8 @@ struct Cli {
     quant: Option<usize>,
     #[arg(long, value_name = "LAYERS")]
     quant_nf4: Option<usize>,
+    #[arg(short, long)]
+    embed_device: Option<EmbedDevice>,
     #[arg(short, long, action)]
     turbo: bool,
 }
47 changes: 25 additions & 22 deletions src/context.rs
@@ -115,20 +115,26 @@ impl<'a> ContextBuilder<'a> {
     }
 
     pub async fn build(self) -> Result<Context, CreateEnvironmentError> {
-        let (device, queue) = self
-            .adapter
+        let Self {
+            adapter,
+            features,
+            limits,
+            pipelines,
+        } = self;
+
+        let (device, queue) = adapter
             .request_device(
                 &DeviceDescriptor {
                     label: None,
-                    features: self.features,
-                    limits: self.limits,
+                    features,
+                    limits,
                 },
                 None,
             )
             .await
            .map_err(|_| CreateEnvironmentError::RequestDeviceFailed)?;
-        let pipelines = self
-            .pipelines
+
+        let pipelines = pipelines
             .into_iter()
             .map(|(name, (shader, entry_point, layout))| {
                 let module = &device.create_shader_module(ShaderModuleDescriptor {
@@ -158,10 +164,11 @@
                 )
             })
             .collect();
+
         Ok(Context(
             ContextInner {
                 id: ContextId::new(),
-                adapter: self.adapter,
+                adapter,
                 device,
                 queue,
                 pipelines,
@@ -172,30 +179,29 @@
         ))
     }
 
-    pub fn with_limits(self, limits: Limits) -> Self {
-        Self { limits, ..self }
+    pub fn with_limits(mut self, limits: Limits) -> Self {
+        self.limits = limits;
+        self
     }
 
-    pub fn with_features(self, features: Features) -> Self {
-        Self { features, ..self }
+    pub fn with_features(mut self, features: Features) -> Self {
+        self.features = features;
+        self
     }
 
     pub fn with_pipeline(
-        self,
+        mut self,
         name: &'a str,
         shader: &'a str,
         entry_point: &'a str,
         layout: Option<&'a [BindGroupLayoutEntry]>,
     ) -> Self {
-        let mut pipelines = self.pipelines;
-        pipelines.insert(name, (shader, entry_point, layout));
-        Self { pipelines, ..self }
+        self.pipelines.insert(name, (shader, entry_point, layout));
+        self
     }
 
     pub fn with_default_pipelines(self) -> Self {
-        self.with_core_pipelines()
-            .with_util_pipelines()
-            .with_quant_pipelines()
+        self.with_core_pipelines().with_quant_pipelines()
     }
 
     fn with_core_pipelines(self) -> Self {
@@ -333,10 +339,7 @@ impl<'a> ContextBuilder<'a> {
                 "softmax",
                 None,
             )
-    }
-
-    fn with_util_pipelines(self) -> Self {
-        self.with_pipeline("blit", include_str!("shaders/blit.wgsl"), "blit", None)
+            .with_pipeline("blit", include_str!("shaders/blit.wgsl"), "blit", None)
             .with_pipeline(
                 "transpose",
                 include_str!("shaders/blit.wgsl"),
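The `ContextBuilder` changes above also swap struct-update setters (`Self { limits, ..self }`) for `mut self` setters that assign in place and return `self`; call sites chain exactly as before. A standalone sketch of the pattern on a hypothetical builder, not a type from this crate:

// Hypothetical builder illustrating the `mut self` setter style.
#[derive(Debug, Default)]
struct Builder {
    turbo: bool,
    quant: Vec<usize>,
}

impl Builder {
    // Take `mut self` by value, mutate the field, and hand `self` back,
    // instead of rebuilding the struct with `Self { turbo, ..self }`.
    fn with_turbo(mut self, turbo: bool) -> Self {
        self.turbo = turbo;
        self
    }

    fn with_quant(mut self, quant: Vec<usize>) -> Self {
        self.quant = quant;
        self
    }
}

fn main() {
    let builder = Builder::default().with_turbo(true).with_quant(vec![0, 1]);
    println!("{builder:?}");
}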
13 changes: 9 additions & 4 deletions src/model/mod.rs
@@ -25,8 +25,8 @@ pub mod v6;
 pub const RESCALE_LAYER: usize = 6;
 
 pub const MIN_TOKEN_CHUNK_SIZE: usize = 32;
-pub const HEAD_CHUNK_SIZES: [usize; 8] = [
-    0x4000, 0x3000, 0x2000, 0x1800, 0x1600, 0x1400, 0x1200, 0x1000,
+pub const HEAD_CHUNK_SIZES: [usize; 10] = [
+    0x10000, 0x5000, 0x4000, 0x3000, 0x2000, 0x1800, 0x1600, 0x1400, 0x1200, 0x1000,
 ];
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
@@ -72,9 +72,14 @@ pub struct ModelInfo {
 }
 
 impl ModelInfo {
-    /// Computes the required storage buffer size.
+    /// Computes the required storage buffer size, not including head.
+    pub fn max_non_head_buffer_size(&self) -> usize {
+        (self.num_emb * self.num_hidden * f16::size()).max(256 << 20)
+    }
+
+    /// Computes the required storage buffer size, including head.
     pub fn max_buffer_size(&self) -> usize {
-        (self.num_emb * self.num_hidden * f16::size()).max(128 << 20)
+        (self.num_emb * self.num_vocab * f16::size()).max(256 << 20)
     }
 }
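A worked example of the new sizing rule, with hypothetical dimensions (num_emb = 4096, num_vocab = 65536; an f16 element is 2 bytes; real values depend on the loaded model):

fn main() {
    let num_emb = 4096usize;
    let num_vocab = 65536usize;
    let f16_size = 2usize;

    // Same arithmetic as the new `ModelInfo::max_buffer_size`: the f16
    // head matrix, with a floor of 256 MiB.
    let max_buffer_size = (num_emb * num_vocab * f16_size).max(256 << 20);
    assert_eq!(max_buffer_size, 512 << 20); // 512 MiB, above the floor

    println!("max_buffer_size = {} MiB", max_buffer_size >> 20);
}

This is also why chat.rs above raises both `max_storage_buffer_binding_size` and wgpu's `max_buffer_size` limit to `info.max_buffer_size()`: a GPU-resident head matrix of this size can exceed wgpu's default limits.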

(The remaining 3 of the 8 changed files are not shown.)
