Commit

more convenient way of getting meta-information (dimensions) from transformers model
Jeronymous committed Jan 29, 2024
1 parent 987c5ff commit dabd52c
Showing 1 changed file with 12 additions and 11 deletions.
whisper_timestamped/transcribe.py (file mode changed from 100755 to 100644)
@@ -2483,17 +2483,18 @@ def __init__(self, model, processor, generation_config):
         self.device = model.device
 
         # Dimensions
+        model_config = model.config
         self.dims = whisper.model.ModelDimensions(
-            n_mels = model.get_encoder().get_input_embeddings().in_channels,
-            n_audio_ctx = 1500,
-            n_audio_state = model.get_encoder().get_input_embeddings().out_channels,
-            n_audio_head = model.get_encoder().layers[0].self_attn.num_heads,
-            n_audio_layer = len(model.get_encoder().layers),
-            n_vocab = model.get_decoder().get_input_embeddings().num_embeddings,
-            n_text_ctx = 448,
-            n_text_state = model.get_decoder().get_input_embeddings().embedding_dim,
-            n_text_head = model.get_decoder().layers[0].self_attn.num_heads,
-            n_text_layer = len(model.get_decoder().layers),
+            n_mels = model_config.num_mel_bins, # model.get_encoder().get_input_embeddings().in_channels, # 80
+            n_audio_ctx = model_config.max_source_positions, # 1500
+            n_audio_state = model_config.d_model, # model.get_encoder().get_input_embeddings().out_channels, # 768
+            n_audio_head = model_config.encoder_attention_heads, # model.get_encoder().layers[0].self_attn.num_heads,
+            n_audio_layer = model_config.encoder_layers, # len(model.get_encoder().layers),
+            n_vocab = model_config.vocab_size, # model.get_decoder().get_input_embeddings().num_embeddings, # ~51865
+            n_text_ctx = model_config.max_length, # 448
+            n_text_state = model_config.d_model, # model.get_decoder().get_input_embeddings().embedding_dim, # 768
+            n_text_head = model_config.decoder_attention_heads, # model.get_decoder().layers[0].self_attn.num_heads,
+            n_text_layer = model_config.decoder_layers, # len(model.get_decoder().layers),
         )
 
         # Tokenization
@@ -3092,4 +3093,4 @@ def filtered_keys(result, keys = [
 
 
 if __name__ == "__main__":
-    cli()
+    cli()
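
For context (not part of the diff): the change replaces introspection of encoder/decoder sub-modules with direct reads from the Hugging Face model config. A minimal sketch of the same mapping, assuming a standard transformers Whisper checkpoint ("openai/whisper-small" is only an illustrative choice here):

    # Sketch, not part of the patch: build openai-whisper ModelDimensions
    # from a Hugging Face transformers Whisper checkpoint's config.
    import whisper  # openai-whisper package, provides whisper.model.ModelDimensions
    from transformers import WhisperForConditionalGeneration

    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")  # illustrative checkpoint
    cfg = model.config

    # Same field-by-field mapping as in the commit: each ModelDimensions field
    # is read from the config instead of inspecting layers and embeddings.
    dims = whisper.model.ModelDimensions(
        n_mels=cfg.num_mel_bins,                  # 80 for released Whisper models
        n_audio_ctx=cfg.max_source_positions,     # 1500
        n_audio_state=cfg.d_model,                # 768 for "small"
        n_audio_head=cfg.encoder_attention_heads,
        n_audio_layer=cfg.encoder_layers,
        n_vocab=cfg.vocab_size,                   # ~51865
        n_text_ctx=cfg.max_length,                # 448
        n_text_state=cfg.d_model,
        n_text_head=cfg.decoder_attention_heads,
        n_text_layer=cfg.decoder_layers,
    )
    print(dims)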

0 comments on commit dabd52c
