diff --git a/examples/batch.rs b/examples/batch.rs
index 268b246..1dced52 100644
--- a/examples/batch.rs
+++ b/examples/batch.rs
@@ -58,12 +58,8 @@ fn sample(probs: Vec<f32>, _top_p: f32) -> u16 {
     token as u16
 }
 
-async fn create_context(info: &ModelInfo) -> Result<Context> {
+async fn create_context(info: &ModelInfo, embed_device: Option<EmbedDevice>) -> Result<Context> {
     let instance = Instance::new();
-    let limits = wgpu::Limits {
-        max_storage_buffer_binding_size: info.max_buffer_size() as u32,
-        ..Default::default()
-    };
     #[cfg(not(debug_assertions))]
     let adapter = {
         let backends = wgpu::Backends::all();
@@ -85,7 +81,7 @@ async fn create_context(info: &ModelInfo) -> Result<Context> {
         .await?;
     let context = ContextBuilder::new(adapter)
         .with_default_pipelines()
-        .with_limits(limits)
+        .with_auto_limits(info, embed_device.unwrap_or_default().into())
         .build()
         .await?;
     println!("{:#?}", context.adapter.get_info());
@@ -167,7 +163,7 @@ async fn run(cli: Cli) -> Result<()> {
     let info = Loader::info(&map)?;
     println!("{:#?}", info);
 
-    let context = create_context(&info).await?;
+    let context = create_context(&info, cli.embed_device).await?;
 
     match info.version {
         ModelVersion::V4 => {
diff --git a/examples/chat.rs b/examples/chat.rs
index a5f3241..b173812 100644
--- a/examples/chat.rs
+++ b/examples/chat.rs
@@ -70,13 +70,8 @@ impl Sampler {
     }
 }
 
-async fn create_context(info: &ModelInfo) -> Result<Context> {
+async fn create_context(info: &ModelInfo, embed_device: Option<EmbedDevice>) -> Result<Context> {
     let instance = Instance::new();
-    let limits = wgpu::Limits {
-        max_storage_buffer_binding_size: info.max_buffer_size() as u32,
-        max_buffer_size: info.max_buffer_size() as u64,
-        ..Default::default()
-    };
     #[cfg(not(debug_assertions))]
     let adapter = {
         let backends = wgpu::Backends::all();
@@ -98,7 +93,7 @@ async fn create_context(info: &ModelInfo) -> Result<Context> {
         .await?;
     let context = ContextBuilder::new(adapter)
         .with_default_pipelines()
-        .with_limits(limits)
+        .with_auto_limits(info, embed_device.unwrap_or_default().into())
         .build()
         .await?;
     println!("{:#?}", context.adapter.get_info());
@@ -190,7 +185,7 @@ async fn run(cli: Cli) -> Result<()> {
     let info = Loader::info(&map)?;
     println!("{:#?}", info);
 
-    let context = create_context(&info).await?;
+    let context = create_context(&info, cli.embed_device).await?;
 
     match info.version {
         ModelVersion::V4 => {
diff --git a/examples/gen.rs b/examples/gen.rs
index eb3f7a8..ce8326a 100644
--- a/examples/gen.rs
+++ b/examples/gen.rs
@@ -47,12 +47,8 @@ fn sample(probs: &[f32], _top_p: f32) -> u16 {
     token as u16
 }
 
-async fn create_context(info: &ModelInfo) -> Result<Context> {
+async fn create_context(info: &ModelInfo, embed_device: Option<EmbedDevice>) -> Result<Context> {
     let instance = Instance::new();
-    let limits = wgpu::Limits {
-        max_storage_buffer_binding_size: info.max_buffer_size() as u32,
-        ..Default::default()
-    };
     #[cfg(not(debug_assertions))]
     let adapter = {
         let backends = wgpu::Backends::all();
@@ -74,7 +70,7 @@ async fn create_context(info: &ModelInfo) -> Result<Context> {
         .await?;
     let context = ContextBuilder::new(adapter)
         .with_default_pipelines()
-        .with_limits(limits)
+        .with_auto_limits(info, embed_device.unwrap_or_default().into())
         .build()
         .await?;
     println!("{:#?}", context.adapter.get_info());
@@ -141,7 +137,7 @@ async fn run(cli: Cli) -> Result<()> {
     let info = Loader::info(&map)?;
     println!("{:#?}", info);
 
-    let context = create_context(&info).await?;
+    let context = create_context(&info, cli.embed_device).await?;
 
     match info.version {
         ModelVersion::V4 => {
diff --git a/src/context.rs b/src/context.rs
index b5dc69a..59ca362 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -9,10 +9,13 @@ use wgpu::{
     ShaderModuleDescriptor, ShaderStages,
 };
 
-use crate::tensor::{
-    cache::ResourceCache,
-    shape::{IntoBytes, Shape},
-    TensorError, View,
+use crate::{
+    model::{EmbedDevice, ModelInfo},
+    tensor::{
+        cache::ResourceCache,
+        shape::{IntoBytes, Shape},
+        TensorError, View,
+    },
 };
 
 #[derive(Deref)]
@@ -184,6 +187,17 @@ impl<'a> ContextBuilder<'a> {
         self
     }
 
+    /// Compute the limits automatically based on given model build info.
+    pub fn with_auto_limits(mut self, info: &ModelInfo, embed_device: EmbedDevice) -> Self {
+        let max_buffer_size = match embed_device {
+            EmbedDevice::Cpu => info.max_non_head_buffer_size(),
+            EmbedDevice::Gpu => info.max_buffer_size(),
+        };
+        self.limits.max_buffer_size = (256 << 20).max(max_buffer_size as u64);
+        self.limits.max_storage_buffer_binding_size = (128 << 20).max(max_buffer_size as u32);
+        self
+    }
+
     pub fn with_features(mut self, features: Features) -> Self {
         self.features = features;
         self