Skip to content

Commit

Permalink
Replace hard-coded batch size of 1 with a `batch_size` variable
Browse files: browse the repository at this point in the history
  • Loading branch information
xenova committed Nov 3, 2023
1 parent bdfebfc commit daa667c
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -1263,12 +1263,14 @@ export class PreTrainedModel extends Callable {
Object.assign(decoderFeeds, pastKeyValues)
} else {
// TODO support batches (i.e., batch_size > 1)
const batch_size = 1;

// @ts-ignore
if (this.config.is_encoder_decoder && (this.add_encoder_pkv ?? true)) {
// @ts-ignore
let encoder_dims = [1, this.num_encoder_heads, 0, this.encoder_dim_kv];
let encoder_dims = [batch_size, this.num_encoder_heads, 0, this.encoder_dim_kv];
// @ts-ignore
let decoder_dims = [1, this.num_decoder_heads, 0, this.decoder_dim_kv];
let decoder_dims = [batch_size, this.num_decoder_heads, 0, this.decoder_dim_kv];
// @ts-ignore
for (let i = 0; i < this.num_decoder_layers; ++i) {
decoderFeeds[`past_key_values.${i}.encoder.key`] = new Tensor('float32', [], encoder_dims)
Expand All @@ -1279,15 +1281,15 @@ export class PreTrainedModel extends Callable {
} else if (this.config.model_type === 'falcon') {
// NOTE: Custom implementation for Falcon
// @ts-ignore
let dims = [1 * this.num_heads, 0, this.dim_kv]
let dims = [batch_size * this.num_heads, 0, this.dim_kv]
// @ts-ignore
for (let i = 0; i < this.num_layers; ++i) {
decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims)
decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims)
}
} else if (this.config.multi_query) { // e.g., for `gpt_bigcode`
// @ts-ignore
let dims = [1 * this.num_heads, 0, 2 * this.dim_kv]
let dims = [batch_size * this.num_heads, 0, 2 * this.dim_kv]
// @ts-ignore
for (let i = 0; i < this.num_layers; ++i) {
decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor('float32', [], dims)
Expand All @@ -1296,17 +1298,17 @@ export class PreTrainedModel extends Callable {
// NOTE: Custom implementation for Bloom

// @ts-ignore
let keyDims = [1 * this.num_heads, this.dim_kv, 0] // [batch_size x num_heads,64,past_sequence_length]
let keyDims = [batch_size * this.num_heads, this.dim_kv, 0] // [batch_size x num_heads,64,past_sequence_length]
// @ts-ignore
let valueDims = [1 * this.num_heads, 0, this.dim_kv] // [batch_size x num_heads,past_sequence_length,64]
let valueDims = [batch_size * this.num_heads, 0, this.dim_kv] // [batch_size x num_heads,past_sequence_length,64]
// @ts-ignore
for (let i = 0; i < this.num_layers; ++i) {
decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], keyDims)
decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], valueDims)
}
} else { // Decoder-only
// @ts-ignore
let dims = [1, this.num_heads, 0, this.dim_kv]
let dims = [batch_size, this.num_heads, 0, this.dim_kv]
// @ts-ignore
for (let i = 0; i < this.num_layers; ++i) {
decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims)
Expand Down

0 comments on commit daa667c

Please sign in to comment.