diff --git a/src/model/run.rs b/src/model/run.rs index 8ab57c5..67291ee 100644 --- a/src/model/run.rs +++ b/src/model/run.rs @@ -4,7 +4,10 @@ use anyhow::Result; use half::f16; use itertools::Itertools; -use super::{ModelBase, ModelError, ModelInfo, ModelInput, ModelOutput, ModelState, OutputType}; +use super::{ + ModelBase, ModelError, ModelInfo, ModelInput, ModelOutput, ModelState, OutputType, + MIN_TOKEN_CHUNK_SIZE, +}; use crate::{ context::Context, tensor::{ @@ -152,7 +155,12 @@ where } // we only infer at most `token_chunk_size` tokens at a time - let mut num_token = num_token.min(self.token_chunk_size()); + let num_token = num_token.min(self.token_chunk_size()); + let mut num_token = match num_token > MIN_TOKEN_CHUNK_SIZE { + true => num_token - num_token % MIN_TOKEN_CHUNK_SIZE, + false => num_token, + }; + let mut inputs = vec![vec![]; max_batch]; let mut outputs: Vec> = vec![None; max_batch]; diff --git a/src/model/v4.rs b/src/model/v4.rs index 4ffe90f..92f767a 100644 --- a/src/model/v4.rs +++ b/src/model/v4.rs @@ -10,7 +10,7 @@ use super::{ run::{Header, HookMap, ModelRunInternal}, softmax::{ModelSoftmaxInternal, Softmax}, FromBuilder, ModelBase, ModelBuilder, ModelError, ModelInfo, OutputType, PreparedModelBuilder, - Quant, StateBuilder, + Quant, StateBuilder, MIN_TOKEN_CHUNK_SIZE, }; use crate::{ context::Context, @@ -625,7 +625,7 @@ impl ModelRunInternal for Model<'_> { #[inline] fn turbo(&self, num_token: usize) -> bool { - self.turbo && num_token == self.token_chunk_size + self.turbo && num_token % MIN_TOKEN_CHUNK_SIZE == 0 } fn run_internal( diff --git a/src/model/v5.rs b/src/model/v5.rs index 8000e58..011cfca 100644 --- a/src/model/v5.rs +++ b/src/model/v5.rs @@ -9,7 +9,7 @@ use super::{ run::{Header, HookMap, ModelRunInternal}, softmax::{ModelSoftmaxInternal, Softmax}, FromBuilder, ModelBase, ModelBuilder, ModelError, ModelInfo, PreparedModelBuilder, Quant, - StateBuilder, + StateBuilder, MIN_TOKEN_CHUNK_SIZE, }; use crate::{ context::Context, @@ -718,7 +718,7 @@ impl ModelRunInternal for Model<'_> { #[inline] fn turbo(&self, num_token: usize) -> bool { - self.turbo && num_token == self.token_chunk_size + self.turbo && num_token % MIN_TOKEN_CHUNK_SIZE == 0 } fn run_internal( diff --git a/src/model/v6.rs b/src/model/v6.rs index 8d5204a..ecb9cb9 100644 --- a/src/model/v6.rs +++ b/src/model/v6.rs @@ -9,7 +9,7 @@ use super::{ run::{Header, HookMap, ModelRunInternal}, softmax::{ModelSoftmaxInternal, Softmax}, FromBuilder, ModelBase, ModelBuilder, ModelError, ModelInfo, OutputType, PreparedModelBuilder, - Quant, StateBuilder, + Quant, StateBuilder, MIN_TOKEN_CHUNK_SIZE, }; use crate::{ context::Context, @@ -776,7 +776,7 @@ impl ModelRunInternal for Model<'_> { #[inline] fn turbo(&self, num_token: usize) -> bool { - self.turbo && num_token == self.token_chunk_size + self.turbo && num_token % MIN_TOKEN_CHUNK_SIZE == 0 } fn run_internal(