Integrate whisper model with eval interface #2

Draft: wants to merge 3 commits into base: cli/eval/interface

28 changes: 24 additions & 4 deletions src/fairseq2/recipes/eval/__init__.py
@@ -4,19 +4,23 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from pathlib import Path
+from typing import Any, Callable
+
 from fairseq2.logging import get_log_writer
 from fairseq2.recipes.cli import Cli, CliGroup, RecipeCommandHandler
-from fairseq2.recipes.eval.configs import hf_presets
+from fairseq2.recipes.eval.asr import AsrEvalConfig
+from fairseq2.recipes.eval.configs import wav2vec2_presets, whisper_presets

 log = get_log_writer(__name__)


 def _add_wav2vev2_asr_eval_cli(group: CliGroup) -> None:
-    from fairseq2.recipes.eval.asr import load_wav2vec2_asr_evaluator
+    from fairseq2.recipes.eval.asr import ASREvaluator

@antoine-tran antoine-tran Aug 2, 2024

I don't know if changing the function "load_wav2vec2_asr_evaluator" to ASREvaluator is the best way. I'm not picky about having a function versus a callable class, but the problem is that Whisper is an end-to-end model, while wav2vec2 is, as the name suggests, only an encoder that generates a vector. For wav2vec2 we need a text tokenizer and a decoder, while Whisper requires neither. So "ASREvaluator" is still not abstract enough (at least in your current proposal, with self.tokenizer and self.decoder).

Basically, I think we just need two functions, each generating the HFEvaluator from its config (see my comments above).

Owner Author

This makes a lot of sense. I was just afraid of having too many functions just to support different models and their requirements, and that's why I created the class.
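
For illustration, the split discussed in this thread could look roughly like the sketch below: one loader per model family, where only the wav2vec2 path asks for a tokenizer. This is a minimal sketch and not the fairseq2 code; `SketchEvalConfig`, `SketchEvaluator`, both loader functions, and all field names are hypothetical stand-ins introduced here.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchEvalConfig:
    """Hypothetical config; every field name here is an assumption."""

    model_name: str
    tokenizer_name: Optional[str] = None  # only used by the wav2vec2 path


@dataclass
class SketchEvaluator:
    """Hypothetical stand-in for the evaluator a real loader would return."""

    model_name: str
    uses_external_decoder: bool
    tokenizer_name: Optional[str] = None


def load_wav2vec2_evaluator(config: SketchEvalConfig) -> SketchEvaluator:
    # Encoder-only model: a text tokenizer/decoder must be supplied to get text.
    if config.tokenizer_name is None:
        raise ValueError("the wav2vec2 path needs a tokenizer_name")
    return SketchEvaluator(config.model_name, True, config.tokenizer_name)


def load_whisper_evaluator(config: SketchEvalConfig) -> SketchEvaluator:
    # End-to-end model: it emits text itself, so nothing extra is loaded.
    return SketchEvaluator(config.model_name, False)
```

With this shape, neither loader carries `self.tokenizer` or `self.decoder` state that the other model family does not need.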


     handler = RecipeCommandHandler(
-        load_wav2vec2_asr_evaluator,
-        preset_configs=hf_presets,
+        ASREvaluator(),
+        preset_configs=wav2vec2_presets,


We don't need two registries for wav2vec2 and whisper.

         default_preset="librispeech_asr",
     )
     group.add_command(
@@ -26,6 +30,21 @@ def _add_wav2vev2_asr_eval_cli(group: CliGroup) -> None:
     )


+def _add_whisper_asr_eval_cli(group: CliGroup) -> None:


This is highly redundant, I think. We can try to parameterize the evaluator setup function; the presets can be customized at runtime.

+    from fairseq2.recipes.eval.asr import ASREvaluator
+
+    handler = RecipeCommandHandler(
+        ASREvaluator(),
+        preset_configs=whisper_presets,
+        default_preset="librispeech_asr",
+    )
+    group.add_command(
+        "whisper-asr",
+        handler,
+        help="evaluate a whisper ASR model in downstream benchmark",
+    )
+
+
 def has_datasets() -> bool:
     try:
         import datasets  # type: ignore[attr-defined,import-untyped,import-not-found]
@@ -57,3 +76,4 @@ def _setup_eval_cli(cli: Cli) -> None:

     if all((has_datasets(), has_evaluate())):
         _add_wav2vev2_asr_eval_cli(group)
+        # _add_whisper_asr_eval_cli(group)
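
One way to read the review comment about redundancy between the two `_add_*_asr_eval_cli` functions is a single registration helper that takes the command name and loader as parameters. The sketch below is an assumption-laden illustration, not the PR's code: `SketchGroup` stands in for `CliGroup`, the command name "wav2vec2-asr" is a guess, and the loaders are placeholder callables.

```python
from typing import Any, Callable, Dict, List, Tuple


class SketchGroup:
    """Hypothetical stand-in for CliGroup; it only records registered commands."""

    def __init__(self) -> None:
        self.commands: Dict[str, Tuple[Callable[..., Any], str]] = {}

    def add_command(self, name: str, handler: Callable[..., Any], help: str) -> None:
        self.commands[name] = (handler, help)


def add_asr_eval_cli(
    group: SketchGroup,
    command_name: str,
    loader: Callable[..., Any],
    model_label: str,
) -> None:
    # Everything that differed between the two near-identical functions
    # (command name, loader, help text) is now a parameter.
    group.add_command(
        command_name,
        loader,
        help=f"evaluate a {model_label} ASR model in downstream benchmark",
    )


def _load_wav2vec2(config: Any) -> str:  # placeholder loader
    return "wav2vec2 evaluator"


def _load_whisper(config: Any) -> str:  # placeholder loader
    return "whisper evaluator"


# Table-driven registration: adding a model family means adding one row.
ASR_COMMANDS: List[Tuple[str, Callable[..., Any], str]] = [
    ("wav2vec2-asr", _load_wav2vec2, "wav2vec2"),
    ("whisper-asr", _load_whisper, "whisper"),
]

group = SketchGroup()
for command_name, loader, model_label in ASR_COMMANDS:
    add_asr_eval_cli(group, command_name, loader, model_label)
```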