Skip to content

Commit

Permalink
Better batching for turbo.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Feb 6, 2024
1 parent 9f28bf9 commit 816162c
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 8 deletions.
12 changes: 10 additions & 2 deletions src/model/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use anyhow::Result;
use half::f16;
use itertools::Itertools;

use super::{ModelBase, ModelError, ModelInfo, ModelInput, ModelOutput, ModelState, OutputType};
use super::{
ModelBase, ModelError, ModelInfo, ModelInput, ModelOutput, ModelState, OutputType,
MIN_TOKEN_CHUNK_SIZE,
};
use crate::{
context::Context,
tensor::{
Expand Down Expand Up @@ -152,7 +155,12 @@ where
}

// we only infer at most `token_chunk_size` tokens at a time
let mut num_token = num_token.min(self.token_chunk_size());
let num_token = num_token.min(self.token_chunk_size());
let mut num_token = match num_token > MIN_TOKEN_CHUNK_SIZE {
true => num_token - num_token % MIN_TOKEN_CHUNK_SIZE,
false => num_token,
};

let mut inputs = vec![vec![]; max_batch];
let mut outputs: Vec<Option<OutputType>> = vec![None; max_batch];

Expand Down
4 changes: 2 additions & 2 deletions src/model/v4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use super::{
run::{Header, HookMap, ModelRunInternal},
softmax::{ModelSoftmaxInternal, Softmax},
FromBuilder, ModelBase, ModelBuilder, ModelError, ModelInfo, OutputType, PreparedModelBuilder,
Quant, StateBuilder,
Quant, StateBuilder, MIN_TOKEN_CHUNK_SIZE,
};
use crate::{
context::Context,
Expand Down Expand Up @@ -625,7 +625,7 @@ impl ModelRunInternal for Model<'_> {

#[inline]
fn turbo(&self, num_token: usize) -> bool {
self.turbo && num_token == self.token_chunk_size
self.turbo && num_token % MIN_TOKEN_CHUNK_SIZE == 0
}

fn run_internal(
Expand Down
4 changes: 2 additions & 2 deletions src/model/v5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use super::{
run::{Header, HookMap, ModelRunInternal},
softmax::{ModelSoftmaxInternal, Softmax},
FromBuilder, ModelBase, ModelBuilder, ModelError, ModelInfo, PreparedModelBuilder, Quant,
StateBuilder,
StateBuilder, MIN_TOKEN_CHUNK_SIZE,
};
use crate::{
context::Context,
Expand Down Expand Up @@ -718,7 +718,7 @@ impl ModelRunInternal for Model<'_> {

#[inline]
fn turbo(&self, num_token: usize) -> bool {
self.turbo && num_token == self.token_chunk_size
self.turbo && num_token % MIN_TOKEN_CHUNK_SIZE == 0
}

fn run_internal(
Expand Down
4 changes: 2 additions & 2 deletions src/model/v6.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use super::{
run::{Header, HookMap, ModelRunInternal},
softmax::{ModelSoftmaxInternal, Softmax},
FromBuilder, ModelBase, ModelBuilder, ModelError, ModelInfo, OutputType, PreparedModelBuilder,
Quant, StateBuilder,
Quant, StateBuilder, MIN_TOKEN_CHUNK_SIZE,
};
use crate::{
context::Context,
Expand Down Expand Up @@ -776,7 +776,7 @@ impl ModelRunInternal for Model<'_> {

#[inline]
fn turbo(&self, num_token: usize) -> bool {
self.turbo && num_token == self.token_chunk_size
self.turbo && num_token % MIN_TOKEN_CHUNK_SIZE == 0
}

fn run_internal(
Expand Down

0 comments on commit 816162c

Please sign in to comment.