Skip to content

Commit

Permalink
Fix a bug about head gathering.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Sep 7, 2023
1 parent 9f6a0e2 commit e61023f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions examples/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ async fn run(cli: Cli) -> Result<()> {
// [`BackedState::repeat`] is helpful if you want to create batch of states from the same input.
let state = ModelState::new(&context, model.info(), tokens.len());

let mut num_tokens = [100usize, 200, 300, 400]
let mut num_tokens = [100usize, 400, 200, 300]
.to_vec()
.repeat((cli.batch + prompts.len() - 1) / prompts.len())[..cli.batch]
.repeat((cli.batch + 3) / 4)[..cli.batch]
.to_vec();
loop {
#[cfg(not(debug_assertions))]
Expand Down
2 changes: 1 addition & 1 deletion src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,7 @@ impl<'a> Model<'a> {
let stack = self.request_stack(num_batch);

// gather and group copy operations
let (head_ops, head_x) = if num_token == 1 || num_token == max_batch {
let (head_ops, head_x) = if num_token == 1 || num_token == num_header {
(vec![], &buffer.ffn_x)
} else {
let mut start = 0;
Expand Down

0 comments on commit e61023f

Please sign in to comment.