From 10cbde6d3cd3aed94e1efc76a5d0aefb1ae6e87b Mon Sep 17 00:00:00 2001
From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com>
Date: Thu, 1 Aug 2024 14:27:43 +0000
Subject: [PATCH 1/3] Rename configs to their respective models

---
 src/fairseq2/recipes/eval/__init__.py | 4 ++--
 src/fairseq2/recipes/eval/asr.py      | 2 +-
 src/fairseq2/recipes/eval/configs.py  | 3 ++-
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/fairseq2/recipes/eval/__init__.py b/src/fairseq2/recipes/eval/__init__.py
index e1b97cb0a..c21e2ca32 100644
--- a/src/fairseq2/recipes/eval/__init__.py
+++ b/src/fairseq2/recipes/eval/__init__.py
@@ -6,7 +6,7 @@
 from fairseq2.logging import get_log_writer
 from fairseq2.recipes.cli import Cli, CliGroup, RecipeCommandHandler
-from fairseq2.recipes.eval.configs import hf_presets
+from fairseq2.recipes.eval.configs import wav2vec2_presets, whisper_presets
 
 log = get_log_writer(__name__)
 
 
@@ -16,7 +16,7 @@ def _add_wav2vev2_asr_eval_cli(group: CliGroup) -> None:
 
     handler = RecipeCommandHandler(
         load_wav2vec2_asr_evaluator,
-        preset_configs=hf_presets,
+        preset_configs=wav2vec2_presets,
         default_preset="librispeech_asr",
     )
     group.add_command(
diff --git a/src/fairseq2/recipes/eval/asr.py b/src/fairseq2/recipes/eval/asr.py
index 9aefb84b1..7b297c8fe 100644
--- a/src/fairseq2/recipes/eval/asr.py
+++ b/src/fairseq2/recipes/eval/asr.py
@@ -26,7 +26,7 @@
 from fairseq2.models.sequence import SequenceBatch
 from fairseq2.models.wav2vec2.asr import load_wav2vec2_asr_model
 from fairseq2.nn.padding import get_seqs_and_padding_mask
-from fairseq2.recipes.eval.configs import HFEvalConfig, hf_presets
+from fairseq2.recipes.eval.configs import HFEvalConfig, wav2vec2_presets, whisper_presets
 from fairseq2.recipes.evaluator import HFEvaluator
 from fairseq2.recipes.utils.setup import setup_root_gang
 from fairseq2.typing import META, DataType
diff --git a/src/fairseq2/recipes/eval/configs.py b/src/fairseq2/recipes/eval/configs.py
index 4880bf846..736198566 100644
--- a/src/fairseq2/recipes/eval/configs.py
+++ b/src/fairseq2/recipes/eval/configs.py
@@ -20,4 +20,5 @@ class HFEvalConfig:
     """The name of the model to evaluate."""
 
 
-hf_presets = ConfigRegistry[HFEvalConfig]()
+wav2vec2_presets = ConfigRegistry[HFEvalConfig]()
+whisper_presets = ConfigRegistry[HFEvalConfig]()

From c3be4959e0a9bc1edc83410a7f1ca2b9ea2396e1 Mon Sep 17 00:00:00 2001
From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com>
Date: Thu, 1 Aug 2024 15:10:36 +0000
Subject: [PATCH 2/3] Refactor asr.py into ASREvaluator

---
 src/fairseq2/recipes/eval/__init__.py |  24 ++-
 src/fairseq2/recipes/eval/asr.py      | 270 ++++++++++++--------------
 2 files changed, 150 insertions(+), 144 deletions(-)

diff --git a/src/fairseq2/recipes/eval/__init__.py b/src/fairseq2/recipes/eval/__init__.py
index c21e2ca32..1e1dd478c 100644
--- a/src/fairseq2/recipes/eval/__init__.py
+++ b/src/fairseq2/recipes/eval/__init__.py
@@ -4,18 +4,22 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from pathlib import Path +from typing import Any, Callable + from fairseq2.logging import get_log_writer from fairseq2.recipes.cli import Cli, CliGroup, RecipeCommandHandler +from fairseq2.recipes.eval.asr import AsrEvalConfig from fairseq2.recipes.eval.configs import wav2vec2_presets, whisper_presets log = get_log_writer(__name__) def _add_wav2vev2_asr_eval_cli(group: CliGroup) -> None: - from fairseq2.recipes.eval.asr import load_wav2vec2_asr_evaluator + from fairseq2.recipes.eval.asr import ASREvaluator handler = RecipeCommandHandler( - load_wav2vec2_asr_evaluator, + ASREvaluator(), preset_configs=wav2vec2_presets, default_preset="librispeech_asr", ) @@ -26,6 +30,21 @@ def _add_wav2vev2_asr_eval_cli(group: CliGroup) -> None: ) +def _add_whisper_asr_eval_cli(group: CliGroup) -> None: + from fairseq2.recipes.eval.asr import ASREvaluator + + handler = RecipeCommandHandler( + ASREvaluator(), + preset_configs=whisper_presets, + default_preset="librispeech_asr", + ) + group.add_command( + "whisper-asr", + handler, + help="evaluate a whisper ASR model in downstream benchmark", + ) + + def has_datasets() -> bool: try: import datasets # type: ignore[attr-defined,import-untyped,import-not-found] @@ -57,3 +76,4 @@ def _setup_eval_cli(cli: Cli) -> None: if all((has_datasets(), has_evaluate())): _add_wav2vev2_asr_eval_cli(group) + # _add_whisper_asr_eval_cli(group) diff --git a/src/fairseq2/recipes/eval/asr.py b/src/fairseq2/recipes/eval/asr.py index 7b297c8fe..de11bad50 100644 --- a/src/fairseq2/recipes/eval/asr.py +++ b/src/fairseq2/recipes/eval/asr.py @@ -8,7 +8,7 @@ import math from dataclasses import dataclass from pathlib import Path -from typing import Any, List, Optional, Tuple, cast +from typing import Any, Callable, List, Optional, Tuple, cast import torch from datasets import ( # type: ignore[attr-defined,import-untyped,import-not-found] @@ -18,7 +18,6 @@ from fairseq2.data.data_pipeline import SequenceData from fairseq2.data.text import load_text_tokenizer -from fairseq2.data.text.text_tokenizer import TextTokenEncoder, TextTokenizer from fairseq2.datasets.batching import StaticBatching from fairseq2.datasets.huggingface import Example, create_hf_reader from fairseq2.logging import get_log_writer @@ -26,7 +25,11 @@ from fairseq2.models.sequence import SequenceBatch from fairseq2.models.wav2vec2.asr import load_wav2vec2_asr_model from fairseq2.nn.padding import get_seqs_and_padding_mask -from fairseq2.recipes.eval.configs import HFEvalConfig, wav2vec2_presets, whisper_presets +from fairseq2.recipes.eval.configs import ( + HFEvalConfig, + wav2vec2_presets, + whisper_presets, +) from fairseq2.recipes.evaluator import HFEvaluator from fairseq2.recipes.utils.setup import setup_root_gang from fairseq2.typing import META, DataType @@ -39,9 +42,6 @@ class AsrEvalConfig(HFEvalConfig): """Holds the configuration of a ASR evaluation recipe.""" - # converter: Callable[[Example], Seq2SeqBatch] - # """The converter function to convert collated data into Seq2SeqBatch""" - tokenizer_name: str = "librispeech_asr" """The tokenizer to use.""" @@ -74,148 +74,134 @@ class AsrEvalConfig(HFEvalConfig): """The data type of the model.""" -def _librispeech_asr_to_batch(examples: Example) -> Seq2SeqBatch: - """ - Converts a collated batch of examples into a Seq2SeqBatch. - - Args: - examples (dict): A dictionary containing "audio" and "text" keys. - - Returns: - Seq2SeqBatch: A batch of audio and text sequences. 
- """ - source_data = cast(SequenceData, examples["audio"]) - target_data = cast(SequenceData, examples["text"]) - - source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data) - target_seqs, target_padding_mask = get_seqs_and_padding_mask(target_data) - - return Seq2SeqBatch( - source_seqs, - source_padding_mask, - target_seqs, - target_padding_mask, - examples, - ) - - -def _preprocess_example( - example: Example, encoder: TextTokenEncoder, device: torch.device -) -> Example: - """ - Preprocesses an individual example by converting the audio array to a PyTorch tensor - and encoding the text. - - Args: - example (dict): A dictionary containing "audio" and "text" keys. - - Returns: - dict: A dictionary with "audio" and "text" as PyTorch tensors. - """ - audio_tensor = ( - torch.from_numpy(example["audio"]["array"]).to(torch.float16).to(device) - ) - text_tensor = encoder(example["text"].lower()).to(device) - return {"audio": audio_tensor, "text": text_tensor} - - -def seq2seq_preprocessor(batch: Seq2SeqBatch) -> Tuple[SequenceBatch, SequenceBatch]: - return SequenceBatch(batch.source_seqs, batch.source_padding_mask), SequenceBatch( - batch.target_seqs, batch.target_padding_mask - ) - - -def postprocesser( - outputs: Any, targets: SequenceBatch, tokenizer: TextTokenizer -) -> Tuple[List[str], List[str]]: - decoder = tokenizer.create_decoder() - pad_idx = tokenizer.vocab_info.pad_idx - - hypotheses, _ = outputs.generate_hypotheses(pad_idx=pad_idx) - predictions = [decoder(item) for item in hypotheses] - references = [decoder(item) for item in targets.seqs.to(torch.int32)] - - return predictions, references - - -@hf_presets.decorator("librispeech_asr") -def _librispeech_asr_config() -> AsrEvalConfig: +@wav2vec2_presets.decorator("librispeech_asr") +def _wav2vec2_librispeech_asr_config() -> AsrEvalConfig: return AsrEvalConfig( dataset_name="librispeech_asr", model_name="wav2vec2_asr_base_10h", - split="test.other" - # converter=librispeech_asr_to_batch, + split="test.other", ) -def load_wav2vec2_asr_evaluator( - config: HFEvalConfig, output_dir: Path -) -> HFEvaluator[Seq2SeqBatch]: - """ - Load the evaluator used for downstream evaluation of the model - in a downstream dataset and report BLEU scores - - Args: - config (HFEvalConfig): The configuration for the evaluation. - output_dir (Path): The output directory to store the evaluation results. - - Returns: - HFEvaluator: Evaluation process results. 
- """ - if not isinstance(config, AsrEvalConfig): - raise ValueError(f"Expect AsrEvalConfig, get {type(config)}") - - iterable_ds = load_dataset(config.dataset_name, split=config.split, streaming=True) - max_samples = config.max_samples if config.max_samples is not None else math.inf - # Load a subset of the dataset if max_samples is set - ds = Dataset.from_generator( - lambda: ( - yield from ( - item for idx, item in enumerate(iterable_ds) if idx < max_samples - ) - ), - features=iterable_ds.features, - ) - - gang = setup_root_gang(log) - - if gang.rank == 0: - init_device = gang.device - else: - init_device = META - - tokenizer = load_text_tokenizer(config.tokenizer_name) - encoder = tokenizer.create_encoder(device=init_device) - - ds = ds.map(lambda x: _preprocess_example(x, encoder, init_device)) - format = { - "type": "torch", - "format_kwargs": {"dtype": torch.float16, "device": init_device}, - } - ds.set_format(**format, columns=["audio", "text"]) - - pipeline_reader = create_hf_reader( - dataset=ds, - gang=gang, - converter=_librispeech_asr_to_batch, - batching=StaticBatching(config.max_num_elements), - num_prefetch=config.num_prefetch, - pad_value=tokenizer.vocab_info.pad_idx, - max_seq_len=config.max_audio_len, - ) - - model = load_wav2vec2_asr_model( - config.model_name, device=init_device, dtype=config.dtype +@whisper_presets.decorator("librispeech_asr") +def _whisper_librispeech_asr_config() -> AsrEvalConfig: + return AsrEvalConfig( + dataset_name="librispeech_asr", model_name="whisper", split="test.other" ) - wall_watch = Stopwatch(start=True, device=init_device) - return HFEvaluator[Seq2SeqBatch]( - model=model, - metrics=["bleu"], - gang=gang, - data_reader=pipeline_reader, - wall_watch=wall_watch, - preprocessor=seq2seq_preprocessor, - postprocessor=lambda x, y: postprocesser(x, y, tokenizer), - ) +class ASREvaluator: + def __init__(self) -> None: + self.gang = setup_root_gang(log) + self.init_device = self.gang.device if self.gang.rank == 0 else META + + def _librispeech_asr_to_batch(self, examples: Example) -> Seq2SeqBatch: + source_data = cast(SequenceData, examples["audio"]) + target_data = cast(SequenceData, examples["text"]) + + source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data) + target_seqs, target_padding_mask = get_seqs_and_padding_mask(target_data) + + return Seq2SeqBatch( + source_seqs, + source_padding_mask, + target_seqs, + target_padding_mask, + examples, + ) + + def _preprocess_example(self, example: Example) -> Example: + audio_tensor = ( + torch.from_numpy(example["audio"]["array"]) + .to(torch.float16) + .to(self.init_device) + ) + text_tensor = self.encoder(example["text"].lower()).to(self.init_device) + return {"audio": audio_tensor, "text": text_tensor} + + def seq2seq_preprocessor( + self, batch: Seq2SeqBatch + ) -> Tuple[SequenceBatch, SequenceBatch]: + return SequenceBatch( + batch.source_seqs, batch.source_padding_mask + ), SequenceBatch(batch.target_seqs, batch.target_padding_mask) + + def postprocesser( + self, outputs: Any, targets: SequenceBatch + ) -> Tuple[List[str], List[str]]: + decoder = self.tokenizer.create_decoder() + pad_idx = self.tokenizer.vocab_info.pad_idx + + hypotheses, _ = outputs.generate_hypotheses(pad_idx=pad_idx) + predictions = [decoder(item) for item in hypotheses] + references = [decoder(item) for item in targets.seqs.to(torch.int32)] + + return predictions, references + + def _load_evaluator(self) -> HFEvaluator[Seq2SeqBatch]: + iterable_ds = load_dataset( + self.config.dataset_name, 
split=self.config.split, streaming=True + ) + max_samples = ( + self.config.max_samples if self.config.max_samples is not None else math.inf + ) + + ds = Dataset.from_generator( + lambda: ( + yield from ( + item for idx, item in enumerate(iterable_ds) if idx < max_samples + ) + ), + features=iterable_ds.features, + ) + + ds = ds.map(self._preprocess_example) + format = { + "type": "torch", + "format_kwargs": {"dtype": torch.float16, "device": self.init_device}, + } + ds.set_format(**format, columns=["audio", "text"]) + + pipeline_reader = create_hf_reader( + dataset=ds, + gang=self.gang, + converter=self._librispeech_asr_to_batch, + batching=StaticBatching(self.config.max_num_elements), + num_prefetch=self.config.num_prefetch, + pad_value=self.tokenizer.vocab_info.pad_idx, + max_seq_len=self.config.max_audio_len, + ) + + model = load_wav2vec2_asr_model( + self.config.model_name, device=self.init_device, dtype=self.config.dtype + ) + + wall_watch = Stopwatch(start=True, device=self.init_device) + + return HFEvaluator[Seq2SeqBatch]( + model=model, + metrics=["bleu"], + gang=self.gang, + data_reader=pipeline_reader, + wall_watch=wall_watch, + preprocessor=self.seq2seq_preprocessor, + postprocessor=lambda x, y: self.postprocesser(x, y), + ) + + def __call__(self, config: HFEvalConfig, output_dir: Path) -> Callable[[], None]: + """ + This method will run the evaluation process. + + Returns: + A callable that will run the evaluation process + """ + if not isinstance(config, AsrEvalConfig): + raise ValueError(f"Expect AsrEvalConfig, get {type(config)}") + + self.config = config + self.output_dir = output_dir + + self.tokenizer = load_text_tokenizer(config.tokenizer_name) + self.encoder = self.tokenizer.create_encoder(device=self.init_device) + + return self._load_evaluator() From 4a91e4be9d1574689fe117d7bca27e893bf347f8 Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Thu, 1 Aug 2024 16:07:59 +0000 Subject: [PATCH 3/3] Refactor ASREvaluator with _load_model and _load_dataset --- src/fairseq2/recipes/eval/asr.py | 111 ++++++++++++++++++++++++------- 1 file changed, 86 insertions(+), 25 deletions(-) diff --git a/src/fairseq2/recipes/eval/asr.py b/src/fairseq2/recipes/eval/asr.py index de11bad50..dfd7b1657 100644 --- a/src/fairseq2/recipes/eval/asr.py +++ b/src/fairseq2/recipes/eval/asr.py @@ -21,9 +21,10 @@ from fairseq2.datasets.batching import StaticBatching from fairseq2.datasets.huggingface import Example, create_hf_reader from fairseq2.logging import get_log_writer +from fairseq2.models import load_model +from fairseq2.models.model import Model from fairseq2.models.seq2seq import Seq2SeqBatch from fairseq2.models.sequence import SequenceBatch -from fairseq2.models.wav2vec2.asr import load_wav2vec2_asr_model from fairseq2.nn.padding import get_seqs_and_padding_mask from fairseq2.recipes.eval.configs import ( HFEvalConfig, @@ -94,8 +95,15 @@ class ASREvaluator: def __init__(self) -> None: self.gang = setup_root_gang(log) self.init_device = self.gang.device if self.gang.rank == 0 else META + self.wall_watch = Stopwatch(device=self.init_device) - def _librispeech_asr_to_batch(self, examples: Example) -> Seq2SeqBatch: + def to_batch(self, examples: Example) -> Seq2SeqBatch: + """ + Convert the example data to a batch. + + Args: + examples: Collated and padded examples. 
+ """ source_data = cast(SequenceData, examples["audio"]) target_data = cast(SequenceData, examples["text"]) @@ -111,17 +119,35 @@ def _librispeech_asr_to_batch(self, examples: Example) -> Seq2SeqBatch: ) def _preprocess_example(self, example: Example) -> Example: + """ + Preprocess the example data. + + Note: should be refactored and removed from this class. + + Args: + example: The example data. + + Returns: + audio and text tensors. + """ audio_tensor = ( torch.from_numpy(example["audio"]["array"]) - .to(torch.float16) + .to(self.config.dtype) .to(self.init_device) ) text_tensor = self.encoder(example["text"].lower()).to(self.init_device) return {"audio": audio_tensor, "text": text_tensor} - def seq2seq_preprocessor( - self, batch: Seq2SeqBatch - ) -> Tuple[SequenceBatch, SequenceBatch]: + def preprocessor(self, batch: Seq2SeqBatch) -> Tuple[SequenceBatch, SequenceBatch]: + """ + Preprocess the batch data. + + Args: + batch: The batch data. + + Returns: + A tuple of source and target sequences in the form of SequenceBatch. + """ return SequenceBatch( batch.source_seqs, batch.source_padding_mask ), SequenceBatch(batch.target_seqs, batch.target_padding_mask) @@ -129,6 +155,13 @@ def seq2seq_preprocessor( def postprocesser( self, outputs: Any, targets: SequenceBatch ) -> Tuple[List[str], List[str]]: + """ + Postprocess the outputs and targets to get the predictions and references. + + Args: + outputs: The model outputs. + targets: The target sequences. + """ decoder = self.tokenizer.create_decoder() pad_idx = self.tokenizer.vocab_info.pad_idx @@ -138,13 +171,19 @@ def postprocesser( return predictions, references - def _load_evaluator(self) -> HFEvaluator[Seq2SeqBatch]: - iterable_ds = load_dataset( - self.config.dataset_name, split=self.config.split, streaming=True - ) - max_samples = ( - self.config.max_samples if self.config.max_samples is not None else math.inf - ) + def _load_dataset( + self, dataset_name: str, split: str, max_samples: Optional[int] + ) -> Dataset: + """ + Load a huggingface dataset. + + Args: + dataset_name: The name of the dataset to load. + split: The split of the dataset to load. + max_samples: The maximum number of samples to load. + """ + iterable_ds = load_dataset(dataset_name, split=split, streaming=True) + max_samples = cast(int, max_samples if max_samples is not None else math.inf) ds = Dataset.from_generator( lambda: ( @@ -158,39 +197,57 @@ def _load_evaluator(self) -> HFEvaluator[Seq2SeqBatch]: ds = ds.map(self._preprocess_example) format = { "type": "torch", - "format_kwargs": {"dtype": torch.float16, "device": self.init_device}, + "format_kwargs": {"dtype": self.config.dtype, "device": self.init_device}, } ds.set_format(**format, columns=["audio", "text"]) + return ds + + def _load_model(self, model_name: str) -> Model: + """ + Load the model. + + Args: + model_name: The name of the model to load. + """ + model = load_model(model_name, device=self.init_device, dtype=self.config.dtype) + return cast(Model, model) + def _load_evaluator(self) -> HFEvaluator[Seq2SeqBatch]: + """ + Load the HFEvaluator for ASR. + + Returns: + The evaluator for ASR. 
+ """ pipeline_reader = create_hf_reader( - dataset=ds, + dataset=self.dataset, gang=self.gang, - converter=self._librispeech_asr_to_batch, + converter=self.to_batch, batching=StaticBatching(self.config.max_num_elements), num_prefetch=self.config.num_prefetch, pad_value=self.tokenizer.vocab_info.pad_idx, max_seq_len=self.config.max_audio_len, ) - model = load_wav2vec2_asr_model( - self.config.model_name, device=self.init_device, dtype=self.config.dtype - ) - - wall_watch = Stopwatch(start=True, device=self.init_device) + self.wall_watch.start() return HFEvaluator[Seq2SeqBatch]( - model=model, + model=self.model, metrics=["bleu"], gang=self.gang, data_reader=pipeline_reader, - wall_watch=wall_watch, - preprocessor=self.seq2seq_preprocessor, + wall_watch=self.wall_watch, + preprocessor=self.preprocessor, postprocessor=lambda x, y: self.postprocesser(x, y), ) def __call__(self, config: HFEvalConfig, output_dir: Path) -> Callable[[], None]: """ - This method will run the evaluation process. + Create an evaluation process for ASR. + + Args: + config: The configuration of the evaluation process. + output_dir: The directory to store the evaluation results. Returns: A callable that will run the evaluation process @@ -201,7 +258,11 @@ def __call__(self, config: HFEvalConfig, output_dir: Path) -> Callable[[], None] self.config = config self.output_dir = output_dir + self.model = self._load_model(config.model_name) self.tokenizer = load_text_tokenizer(config.tokenizer_name) self.encoder = self.tokenizer.create_encoder(device=self.init_device) + self.dataset = self._load_dataset( + config.dataset_name, config.split, config.max_samples + ) return self._load_evaluator()
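
Usage sketch (illustrative, not taken from the patches themselves): a minimal example of driving the refactored ASREvaluator outside the fairseq2 CLI, assuming the three patches above are applied, that a single-process gang can be initialized by setup_root_gang, and that the HFEvaluator returned by __call__ is itself the callable that runs the evaluation loop; the output directory and max_samples value below are hypothetical.

from pathlib import Path

from fairseq2.recipes.eval.asr import ASREvaluator, AsrEvalConfig

# Mirror the "librispeech_asr" preset registered under wav2vec2_presets, but cap the
# number of streamed samples so the run doubles as a quick smoke test (hypothetical value).
config = AsrEvalConfig(
    dataset_name="librispeech_asr",
    model_name="wav2vec2_asr_base_10h",
    split="test.other",
    max_samples=32,
)

# __call__ loads the model, tokenizer, and dataset, then builds and returns the evaluator.
evaluator = ASREvaluator()(config, output_dir=Path("eval-output"))

# Run the evaluation loop and report the configured metrics.
evaluator()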