Skip to content

Commit

Permalink
Configure the limits automatically.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Jan 1, 2024
1 parent ece33db commit 1856547
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 26 deletions.
10 changes: 3 additions & 7 deletions examples/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,8 @@ fn sample(probs: Vec<f32>, _top_p: f32) -> u16 {
token as u16
}

async fn create_context(info: &ModelInfo) -> Result<Context> {
async fn create_context(info: &ModelInfo, embed_device: Option<EmbedDevice>) -> Result<Context> {
let instance = Instance::new();
let limits = wgpu::Limits {
max_storage_buffer_binding_size: info.max_buffer_size() as u32,
..Default::default()
};
#[cfg(not(debug_assertions))]
let adapter = {
let backends = wgpu::Backends::all();
Expand All @@ -85,7 +81,7 @@ async fn create_context(info: &ModelInfo) -> Result<Context> {
.await?;
let context = ContextBuilder::new(adapter)
.with_default_pipelines()
.with_limits(limits)
.with_auto_limits(info, embed_device.unwrap_or_default().into())
.build()
.await?;
println!("{:#?}", context.adapter.get_info());
Expand Down Expand Up @@ -167,7 +163,7 @@ async fn run(cli: Cli) -> Result<()> {
let info = Loader::info(&map)?;
println!("{:#?}", info);

let context = create_context(&info).await?;
let context = create_context(&info, cli.embed_device).await?;

match info.version {
ModelVersion::V4 => {
Expand Down
11 changes: 3 additions & 8 deletions examples/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,8 @@ impl Sampler {
}
}

async fn create_context(info: &ModelInfo) -> Result<Context> {
async fn create_context(info: &ModelInfo, embed_device: Option<EmbedDevice>) -> Result<Context> {
let instance = Instance::new();
let limits = wgpu::Limits {
max_storage_buffer_binding_size: info.max_buffer_size() as u32,
max_buffer_size: info.max_buffer_size() as u64,
..Default::default()
};
#[cfg(not(debug_assertions))]
let adapter = {
let backends = wgpu::Backends::all();
Expand All @@ -98,7 +93,7 @@ async fn create_context(info: &ModelInfo) -> Result<Context> {
.await?;
let context = ContextBuilder::new(adapter)
.with_default_pipelines()
.with_limits(limits)
.with_auto_limits(info, embed_device.unwrap_or_default().into())
.build()
.await?;
println!("{:#?}", context.adapter.get_info());
Expand Down Expand Up @@ -190,7 +185,7 @@ async fn run(cli: Cli) -> Result<()> {
let info = Loader::info(&map)?;
println!("{:#?}", info);

let context = create_context(&info).await?;
let context = create_context(&info, cli.embed_device).await?;

match info.version {
ModelVersion::V4 => {
Expand Down
10 changes: 3 additions & 7 deletions examples/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,8 @@ fn sample(probs: &[f32], _top_p: f32) -> u16 {
token as u16
}

async fn create_context(info: &ModelInfo) -> Result<Context> {
async fn create_context(info: &ModelInfo, embed_device: Option<EmbedDevice>) -> Result<Context> {
let instance = Instance::new();
let limits = wgpu::Limits {
max_storage_buffer_binding_size: info.max_buffer_size() as u32,
..Default::default()
};
#[cfg(not(debug_assertions))]
let adapter = {
let backends = wgpu::Backends::all();
Expand All @@ -74,7 +70,7 @@ async fn create_context(info: &ModelInfo) -> Result<Context> {
.await?;
let context = ContextBuilder::new(adapter)
.with_default_pipelines()
.with_limits(limits)
.with_auto_limits(info, embed_device.unwrap_or_default().into())
.build()
.await?;
println!("{:#?}", context.adapter.get_info());
Expand Down Expand Up @@ -141,7 +137,7 @@ async fn run(cli: Cli) -> Result<()> {
let info = Loader::info(&map)?;
println!("{:#?}", info);

let context = create_context(&info).await?;
let context = create_context(&info, cli.embed_device).await?;

match info.version {
ModelVersion::V4 => {
Expand Down
22 changes: 18 additions & 4 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ use wgpu::{
ShaderModuleDescriptor, ShaderStages,
};

use crate::tensor::{
cache::ResourceCache,
shape::{IntoBytes, Shape},
TensorError, View,
use crate::{
model::{EmbedDevice, ModelInfo},
tensor::{
cache::ResourceCache,
shape::{IntoBytes, Shape},
TensorError, View,
},
};

#[derive(Deref)]
Expand Down Expand Up @@ -184,6 +187,17 @@ impl<'a> ContextBuilder<'a> {
self
}

/// Compute the device limits automatically from the given model build info.
///
/// When the embedding matrix stays on the CPU (`EmbedDevice::Cpu`), the head
/// buffer never has to fit on the GPU, so only the largest non-head buffer
/// constrains the limits; with `EmbedDevice::Gpu` the full maximum applies.
pub fn with_auto_limits(mut self, info: &ModelInfo, embed_device: EmbedDevice) -> Self {
    // Largest single buffer the model will allocate on the GPU, in bytes.
    let max_buffer_size = match embed_device {
        EmbedDevice::Cpu => info.max_non_head_buffer_size(),
        EmbedDevice::Gpu => info.max_buffer_size(),
    } as u64;
    // Keep sane floors (256 MiB / 128 MiB) so small models still get usable limits.
    self.limits.max_buffer_size = (256 << 20).max(max_buffer_size);
    // `max_storage_buffer_binding_size` is a u32 in wgpu; saturate instead of
    // silently truncating with `as u32` when the model needs more than 4 GiB,
    // which would otherwise wrap around to a tiny, bogus limit.
    let binding_size = u32::try_from(max_buffer_size).unwrap_or(u32::MAX);
    self.limits.max_storage_buffer_binding_size = (128 << 20).max(binding_size);
    self
}

pub fn with_features(mut self, features: Features) -> Self {
self.features = features;
self
Expand Down

0 comments on commit 1856547

Please sign in to comment.