Skip to content

Commit

Permalink
Reorganize code.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Dec 13, 2023
1 parent c6d8109 commit 45268cf
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 20 deletions.
2 changes: 1 addition & 1 deletion examples/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ where
"The Eiffel Tower is located in the city of",
"The name of the capital of Italy is",
"The Space Needle is located in downtown",
"User: 水是什么?\n\nAssistant: ",
"人们发现",
];
let mut prompts = prompts
.to_vec()
Expand Down
38 changes: 25 additions & 13 deletions src/model/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub trait ModelRun: ModelBase {
&self,
tokens: Vec<Vec<u16>>,
state: &Self::ModelState,
compute_head: Vec<bool>,
should_output: Vec<bool>,
) -> Result<(Arc<Output>, Vec<Option<usize>>)>;

/// Run the model for a batch of tokens as input.
Expand All @@ -63,32 +63,44 @@ pub trait ModelRun: ModelBase {
// we only infer at most `token_chunk_size` tokens at a time
let mut num_token = num_token.min(self.token_chunk_size());
let mut inputs = vec![vec![]; max_batch];
let mut compute_head = vec![false; max_batch];
let mut should_output = vec![false; max_batch];

// take `num_token` tokens out of all the inputs and put into `input`
// first pass, make sure each slot computes at least one token
for (index, (remain, input)) in tokens.iter_mut().zip(inputs.iter_mut()).enumerate() {
for (output, input, remain) in itertools::multizip((
should_output.iter_mut(),
inputs.iter_mut(),
tokens.iter_mut(),
)) {
let mid = 1.min(remain.len()).min(num_token);
num_token -= mid;

let (head, tail) = remain.split_at(mid);
compute_head[index] = tail.is_empty();
input.append(&mut head.to_vec());
*remain = tail.to_vec();
if mid > 0 {
let (head, tail) = remain.split_at(mid);
*output = tail.is_empty();
*input = [&input, head].concat();
*remain = tail.to_vec();
}
}

// second pass, assign rest token budgets from left to right
for (index, (remain, input)) in tokens.iter_mut().zip(inputs.iter_mut()).enumerate() {
for (output, input, remain) in itertools::multizip((
should_output.iter_mut(),
inputs.iter_mut(),
tokens.iter_mut(),
)) {
let mid = remain.len().min(num_token);
num_token -= mid;

let (head, tail) = remain.split_at(mid);
compute_head[index] = tail.is_empty();
input.append(&mut head.to_vec());
*remain = tail.to_vec();
if mid > 0 {
let (head, tail) = remain.split_at(mid);
*output = tail.is_empty();
*input = [&input, head].concat();
*remain = tail.to_vec();
}
}

let (output, redirect) = self.run_internal(inputs, state, compute_head)?;
let (output, redirect) = self.run_internal(inputs, state, should_output)?;
let output = output.map.clone().back_async().await;

Ok(redirect
Expand Down
4 changes: 2 additions & 2 deletions src/model/v4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ impl ModelRun for Model<'_> {
&self,
tokens: Vec<Vec<u16>>,
state: &ModelState,
compute_head: Vec<bool>,
should_output: Vec<bool>,
) -> Result<(Arc<Output>, Vec<Option<usize>>)> {
let context = &self.context;
let tensor = &self.tensor;
Expand Down Expand Up @@ -656,7 +656,7 @@ impl ModelRun for Model<'_> {
.cursors
.iter()
.filter(|cursor| cursor.len > 0)
.filter(|cursor| compute_head[cursor.batch])
.filter(|cursor| should_output[cursor.batch])
.enumerate()
.map(|(index, cursor)| {
redirect[cursor.batch] = Some(index);
Expand Down
4 changes: 2 additions & 2 deletions src/model/v5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ impl ModelRun for Model<'_> {
&self,
tokens: Vec<Vec<u16>>,
state: &ModelState,
compute_head: Vec<bool>,
should_output: Vec<bool>,
) -> Result<(Arc<Output>, Vec<Option<usize>>)> {
let context = &self.context;
let tensor = &self.tensor;
Expand Down Expand Up @@ -754,7 +754,7 @@ impl ModelRun for Model<'_> {
.cursors
.iter()
.filter(|cursor| cursor.len > 0)
.filter(|cursor| compute_head[cursor.batch])
.filter(|cursor| should_output[cursor.batch])
.enumerate()
.map(|(index, cursor)| {
redirect[cursor.batch] = Some(index);
Expand Down
4 changes: 2 additions & 2 deletions src/model/v6.rs
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ impl ModelRun for Model<'_> {
&self,
tokens: Vec<Vec<u16>>,
state: &ModelState,
compute_head: Vec<bool>,
should_output: Vec<bool>,
) -> Result<(Arc<Output>, Vec<Option<usize>>)> {
let context = &self.context;
let tensor = &self.tensor;
Expand Down Expand Up @@ -805,7 +805,7 @@ impl ModelRun for Model<'_> {
.cursors
.iter()
.filter(|cursor| cursor.len > 0)
.filter(|cursor| compute_head[cursor.batch])
.filter(|cursor| should_output[cursor.batch])
.enumerate()
.map(|(index, cursor)| {
redirect[cursor.batch] = Some(index);
Expand Down

0 comments on commit 45268cf

Please sign in to comment.