Skip to content

Commit

Permalink
Add support for Falcon models
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Nov 3, 2023
1 parent 2206ffb commit bdfebfc
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 1 deletion.
4 changes: 4 additions & 0 deletions scripts/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@
'per_channel': False,
'reduce_range': False,
},
'falcon': {
'per_channel': False,
'reduce_range': False,
},

# Encoder-decoder models
'whisper': {
Expand Down
5 changes: 5 additions & 0 deletions scripts/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,11 @@
# Document Question Answering
'naver-clova-ix/donut-base-finetuned-docvqa',
],
'falcon': [
# Text generation
'Rocketknight1/tiny-random-falcon-7b',
'fxmarty/really-tiny-falcon-testing',
],
'gpt_neo': [
# Text generation
'EleutherAI/gpt-neo-125M',
Expand Down
43 changes: 42 additions & 1 deletion src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -1276,9 +1276,18 @@ export class PreTrainedModel extends Callable {
decoderFeeds[`past_key_values.${i}.decoder.key`] = new Tensor('float32', [], decoder_dims)
decoderFeeds[`past_key_values.${i}.decoder.value`] = new Tensor('float32', [], decoder_dims)
}
} else if (this.config.model_type === 'falcon') {
// NOTE: Custom implementation for Falcon
// @ts-ignore
let dims = [1 * this.num_heads, 0, this.dim_kv]
// @ts-ignore
for (let i = 0; i < this.num_layers; ++i) {
decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims)
decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims)
}
} else if (this.config.multi_query) { // e.g., for `gpt_bigcode`
// @ts-ignore
let dims = [1, 0, 2 * this.dim_kv]
let dims = [1 * this.num_heads, 0, 2 * this.dim_kv]
// @ts-ignore
for (let i = 0; i < this.num_layers; ++i) {
decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor('float32', [], dims)
Expand Down Expand Up @@ -3784,6 +3793,36 @@ export class MistralModel extends MistralPreTrainedModel { }
export class MistralForCausalLM extends MistralPreTrainedModel { }
//////////////////////////////////////////////////

//////////////////////////////////////////////////
// Falcon models
/**
* The bare Falcon Model outputting raw hidden-states without any specific head on top.
*/
export class FalconPreTrainedModel extends PreTrainedModel {
/**
* Creates a new instance of the `FalconPreTrainedModel` class.
* @param {Object} config The configuration of the model.
* @param {any} session The ONNX session containing the model weights.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id

this.num_heads = this.config.num_attention_heads;
this.num_layers = this.config.num_hidden_layers;
this.dim_kv = this.config.hidden_size / this.config.num_attention_heads;
}
}

export class FalconModel extends FalconPreTrainedModel { }

export class FalconForCausalLM extends FalconPreTrainedModel { }
//////////////////////////////////////////////////


//////////////////////////////////////////////////
// AutoModels, used to simplify construction of PreTrainedModels
Expand Down Expand Up @@ -3912,6 +3951,7 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
['mpt', ['MptModel', MptModel]],
['opt', ['OPTModel', OPTModel]],
['mistral', ['MistralModel', MistralModel]],
['falcon', ['FalconModel', FalconModel]],
]);

const MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = new Map([
Expand Down Expand Up @@ -3977,6 +4017,7 @@ const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([
['opt', ['OPTForCausalLM', OPTForCausalLM]],
['mbart', ['MBartForCausalLM', MBartForCausalLM]],
['mistral', ['MistralForCausalLM', MistralForCausalLM]],
['falcon', ['FalconForCausalLM', FalconForCausalLM]],
]);

const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([
Expand Down

0 comments on commit bdfebfc

Please sign in to comment.