diff --git a/docs/source/asr/datasets.rst b/docs/source/asr/datasets.rst index 8298567ff7cc..586bedc03c32 100644 --- a/docs/source/asr/datasets.rst +++ b/docs/source/asr/datasets.rst @@ -1079,23 +1079,22 @@ To run 2D bucketing with 30 buckets sub-divided into 5 sub-buckets each (150 buc # The script's output: Use the following options in your config: + use_bucketing=1 num_buckets=30 bucket_duration_bins=[[1.91,10],[1.91,17],[1.91,25],... - max_duration=... - max_tps=... - + The max_tps setting below is optional, use it if your data has low quality long transcript outliers: + max_tps=[13.2,13.2,11.8,11.8,...] Note that the output in ``bucket_duration_bins`` is a nested list, where every bin specifies the maximum duration and the maximum number of tokens that go into the bucket. Passing this option to Lhotse dataloader will automatically enable 2D bucketing. -Note the presence of ``max_duration`` and ``max_tps`` (token-per-second) options: -these need to be included in dataloader's configuration to ensure we can use the buckets correctly at runtime -in case of outliers. -In general, if you change your data in training, it is highly advisable to re-estimate the duration bins. - -Note that reasonable values for tokens-per-second rarely exceed 12tps with reasonably good tokenizers. -If you find your dataset's TPS is much higher than that, you may have some bad data outliers. -In that case you may specify ``--max_tps`` option to discard those both in bin estimation and dataloading. + +Note the presence of the ``max_tps`` (tokens-per-second) option. +It is optional to include it in the dataloader configuration: if you do, we will apply an extra filter +that discards examples that have more tokens per second than the threshold value. +The threshold is determined separately for each bucket based on the data distribution, and can be controlled +with the ``--token_outlier_threshold`` option. +This filtering is useful primarily for noisy datasets, to discard low-quality examples / outliers.
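A minimal sketch (not part of this patch) of wiring the estimated options into a Lhotse dataloader from Python; ``my_dataset`` and ``my_tokenizer`` are hypothetical placeholders, and the bins, batch sizes, and thresholds below are truncated stand-ins for the full output of the estimation scripts::

    from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config

    config = {
        "manifest_filepath": "train_manifest.json",  # hypothetical manifest path
        "use_bucketing": True,
        # In practice, paste the full bucket_duration_bins / bucket_batch_size / max_tps
        # lists printed by estimate_duration_bins_2d.py and oomptimizer.py here.
        "bucket_duration_bins": [[1.91, 10], [1.91, 17], [1.91, 25]],
        "bucket_batch_size": [256, 224, 192],
        "max_tps": [13.2, 13.2, 11.8],  # optional per-bucket outlier threshold
    }
    dataloader = get_lhotse_dataloader_from_config(
        config,
        global_rank=0,
        world_size=1,
        dataset=my_dataset,      # your map-style or iterable dataset wrapper
        tokenizer=my_tokenizer,  # lets the sampler measure token counts for 2D bucketing
    )
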
We also support aggregate tokenizers for 2D bucketing estimation: diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py index f40dffb79467..1bb1410555d6 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py @@ -71,7 +71,7 @@ def __init__( super().__init__() self.tokenizer = tokenizer self.load_audio = AudioSamples(fault_tolerant=True) - self.padding_value = self.tokenizer.pad + self.padding_value = self.tokenizer.pad_id self.prompt = prompt def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch: diff --git a/nemo/collections/asr/parts/mixins/mixins.py b/nemo/collections/asr/parts/mixins/mixins.py index 25ade32fffd8..47fbddae5edc 100644 --- a/nemo/collections/asr/parts/mixins/mixins.py +++ b/nemo/collections/asr/parts/mixins/mixins.py @@ -110,8 +110,12 @@ def _setup_monolingual_tokenizer(self, tokenizer_cfg: DictConfig): if special_tokens is not None: raise ValueError("`special_tokens` are no longer supported for SentencePiece based tokenizers.") - # Update special tokens - self.tokenizer = tokenizers.SentencePieceTokenizer(model_path=model_path) + if "custom_tokenizer" in self.tokenizer_cfg: + self.tokenizer = self.from_config_dict( + {"_target_": tokenizer_cfg["custom_tokenizer"]["_target_"], "model_path": model_path} + ) + else: + self.tokenizer = tokenizers.SentencePieceTokenizer(model_path=model_path) if 'vocab_path' in self.tokenizer_cfg: vocab_path = self.tokenizer_cfg.get('vocab_path') diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 63e93d8cf860..b2c74c16065a 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -190,7 +190,7 @@ def read_dataset_config(config) -> tuple[CutSet, bool]: "force_finite": config.get("force_finite", False), "max_open_streams": config.get("max_open_streams", None), "token_equivalent_duration": config.get("token_equivalent_duration", None), - "tarred_random_access": config.get("tarred_random_access", False), + "skip_missing_manifest_entries": config.get("skip_missing_manifest_entries", False), } input_cfg = config.input_cfg if isinstance(input_cfg, (str, Path)): @@ -510,11 +510,11 @@ def read_nemo_manifest(config) -> tuple[CutSet, bool]: LazyNeMoTarredIterator( config.manifest_filepath, tar_paths=config.tarred_audio_filepaths, - tarred_random_access=config.tarred_random_access, + skip_missing_manifest_entries=config.skip_missing_manifest_entries, **common_kwargs, ) ) - if not config.tarred_random_access and not force_finite: + if not force_finite: cuts = cuts.repeat() else: cuts = CutSet(LazyNeMoIterator(config.manifest_filepath, **notar_kwargs, **common_kwargs)) @@ -552,7 +552,7 @@ def read_nemo_manifest(config) -> tuple[CutSet, bool]: nemo_iter = LazyNeMoTarredIterator( manifest_path=manifest_path, tar_paths=tar_path, - tarred_random_access=config.tarred_random_access, + skip_missing_manifest_entries=config.skip_missing_manifest_entries, **common_kwargs, ) else: diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index bad866e6dac9..b17fa51a2660 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -33,7 +33,7 @@ make_worker_init_fn, ) from lhotse.dataset.dataloading import resolve_seed -from lhotse.dataset.sampling.base import 
CutSampler, TimeConstraint +from lhotse.dataset.sampling.base import CutSampler, SamplingConstraint, TimeConstraint from lhotse.lazy import LazyFlattener from lhotse.utils import fastcopy, fix_random_seed from omegaconf import DictConfig, OmegaConf @@ -44,6 +44,7 @@ read_cutset_from_config, ) from nemo.collections.common.data.lhotse.sampling import ( + BucketingFilter, DurationFilter, FixedBucketBatchSizeConstraint2D, MultimodalFixedBucketBatchSizeConstraint2D, @@ -76,7 +77,8 @@ class LhotseDataLoadingConfig: cuts_path: str | None = None shar_path: Any = None # str | list[str | tuple[str, float | int]] | None = None # Enable this to support dataloading from JSON manifests that reference subsets of audio tar files. - tarred_random_access: bool = False + skip_missing_manifest_entries: bool = False + tarred_random_access: bool = False # deprecated, replaced by: skip_missing_manifest_entries # 2. Batch size. # a. Existing NeMo options. batch_size: int | None = None @@ -91,6 +93,7 @@ class LhotseDataLoadingConfig: bucket_duration_bins: Any = None # list[float] | list[list[float]] | None = None bucket_buffer_size: int = 10000 concurrent_bucketing: bool = True # fetches data in a background thread + bucketing_2d_strict_mode: bool = True # reduces padding by discarding significant outliers # d. Other Lhotse sampling options. shuffle_buffer_size: int | None = 10000 drop_last: bool = False @@ -117,7 +120,7 @@ class LhotseDataLoadingConfig: min_duration: float | None = -1 max_duration: float | None = float("inf") min_tps: int = -1 # allowed tokens per second (audio-only) - max_tps: float = float("inf") + max_tps: Any = float("inf") # float | list[float] # * Text input min_tokens: int | None = None max_tokens: int | None = None @@ -125,7 +128,7 @@ class LhotseDataLoadingConfig: # For 2D bucketing it's always false, as we report a tuple of (context_len, answer_len). measure_total_length: bool = True min_tpt: int = -1 # allowed tokens per token (text-only) - max_tpt: float = float("inf") + max_tpt: Any = float("inf") # float | list[float] # 3. Supported existing NeMo options. shuffle: bool = False @@ -530,7 +533,7 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No # Select the strategy customizing Lhotse sampler behaviour. # Provides support for dynamic batch sizes, multimodal dataloading, 2D bucketing, etc. bucket_duration_bins = determine_bucket_duration_bins(config) - constraint = determine_sampling_constraint(bucket_duration_bins, config) + cuts, constraint = determine_sampling_constraint(cuts, bucket_duration_bins, config) # 3. The sampler. if config.use_bucketing: @@ -608,13 +611,15 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No return sampler, use_iterable_dataset -def determine_sampling_constraint(bucket_duration_bins, config): +def determine_sampling_constraint(cuts: CutSet, bucket_duration_bins, config) -> tuple[CutSet, SamplingConstraint]: """ Select an appropriate sampling strategy (constraint) for Lhotse samplers based on the configuration. Sampling constraint affects the batch size (static/dynamic) and bucketing behaviour (1D/2D). It is the appropriate customization point to introduce support of other modalities, as it defines a method for example sequence length measurement (audio duration, text tokens, etc.). + Some constraints apply extra filter on ``cuts`` which is why we accept and return the ``CutSet``. 
+ Lhotse's default is :class:`TimeConstraint` for regular audio data, other available options are multimodal constraints (joint text + audio) and their 2D bucketing extensions. """ @@ -627,7 +632,10 @@ def determine_sampling_constraint(bucket_duration_bins, config): max_seq_len_buckets=bucket_duration_bins, batch_sizes=config.bucket_batch_size, token_equivalent_duration=config.token_equivalent_duration, + strict_2d=config.bucketing_2d_strict_mode, + max_ratio=config.max_tpt if isinstance(config.max_tpt, Sequence) else None, ) + cuts = cuts.filter(BucketingFilter(constraint)) else: constraint = MultimodalSamplingConstraint( token_equivalent_duration=config.token_equivalent_duration, @@ -643,14 +651,17 @@ def determine_sampling_constraint(bucket_duration_bins, config): constraint = FixedBucketBatchSizeConstraint2D( max_seq_len_buckets=bucket_duration_bins, batch_sizes=config.bucket_batch_size, + strict_2d=config.bucketing_2d_strict_mode, + max_ratio=config.max_tps if isinstance(config.max_tps, Sequence) else None, ) + cuts = cuts.filter(BucketingFilter(constraint)) else: constraint = TimeConstraint( max_cuts=config.batch_size, max_duration=config.batch_duration, quadratic_duration=config.quadratic_duration, ) - return constraint + return cuts, constraint def determine_bucket_duration_bins(config): @@ -702,22 +713,28 @@ def make_structured_with_schema_warnings(config: DictConfig | dict) -> DictConfi supported_keys = set(OmegaConf.to_container(default).keys()) received_keys = set(OmegaConf.to_container(config).keys()) unsupported_keys = received_keys - supported_keys + unsupported_keys.discard("use_lhotse") if unsupported_keys: - warnings.warn( - f"The following configuration keys are no longer supported " f"and ignored: {','.join(unsupported_keys)}", - category=DeprecationWarning, + logging.warning( + f"The following configuration keys are ignored by Lhotse dataloader: {','.join(unsupported_keys)}", ) config = OmegaConf.masked_copy(config, list(supported_keys)) - return OmegaConf.merge(default, config) + config = OmegaConf.merge(default, config) + if config.get("tarred_random_access", False): + logging.warning( + "Option 'tarred_random_access' is deprecated and replaced with 'skip_missing_manifest_entries'.", + ) + config.skip_missing_manifest_entries = True + if config.skip_missing_manifest_entries: + logging.warning( + "Note: skip_missing_manifest_entries is set to True. " + "If any of your manifests and tar files are mismatched, the entire tar file will be skipped without warning. " + "It's your responsibility to ensure data integrity with this setting." + ) -def determine_use_iterable_dataset(use_iterable_dataset: bool, config: DictConfig) -> bool: - assert not ( - config.force_map_dataset and config.force_iterable_dataset - ), "Conflicting options: force_map_dataset=True and force_iterable_dataset=True" - use_iterable_dataset = (use_iterable_dataset or config.force_iterable_dataset) and not config.force_map_dataset - return use_iterable_dataset + return config def tokenize(example, tokenizer): diff --git a/nemo/collections/common/data/lhotse/nemo_adapters.py b/nemo/collections/common/data/lhotse/nemo_adapters.py index a34a2c074a11..ce05c177154e 100644 --- a/nemo/collections/common/data/lhotse/nemo_adapters.py +++ b/nemo/collections/common/data/lhotse/nemo_adapters.py @@ -223,6 +223,11 @@ class LazyNeMoTarredIterator: This can be used for other cloud storage APIs such as S3, GCS, etc. The same mechanism applies to ``manifest_path``. 
+ If your data has been filtered so that the JSON manifests refer to just a subset of recordings, + set ``skip_missing_manifest_entries` to ``True``. + This will still read the tar files sequentially (very fast) and discard the audio files that + are not present in the corresponding manifest. + The ``shard_seed`` argument is used to seed the RNG shuffling the shards. By default, it's ``trng`` which samples a seed number from OS-provided TRNG (see Python ``secrets`` module). Seed is resolved lazily so that every dataloading worker may sample a different one. @@ -264,10 +269,10 @@ def __init__( shard_seed: int | Literal["trng", "randomized"] = "trng", text_field: str = "text", lang_field: str = "lang", - tarred_random_access: bool = False, + skip_missing_manifest_entries: bool = False, extra_fields: list[dict[str, str]] | None = None, ) -> None: - self.tarred_random_access = tarred_random_access + self.skip_missing_manifest_entries = skip_missing_manifest_entries self.shard_id_to_manifest: dict[int, Iterable[dict]] self.paths = expand_sharded_filepaths(manifest_path) if len(self.paths) == 1: @@ -346,29 +351,21 @@ def _validate(self) -> None: def shard_ids(self) -> List[int]: return sorted(self.shard_id_to_manifest.keys()) - def _iter_random_read(self, tar_path, shard_manifest, manifest_path) -> Generator[tuple[dict, bytes], None, None]: - with tarfile.open(fileobj=BytesIO(open_best(tar_path, mode="rb").read()), mode="r") as tar: - for data in shard_manifest: + def _iter_sequential(self, tar_path, shard_manifest, manifest_path) -> Generator[tuple[dict, bytes], None, None]: + with tarfile.open(fileobj=open_best(tar_path, mode="rb"), mode="r|*") as tar: + for tar_info in tar: try: - tar_info = tar.getmember(data) + data = shard_manifest[tar_info.name] raw_audio = tar.extractfile(tar_info).read() yield data, raw_audio, tar_info except KeyError as e: - raise RuntimeError( - f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). " - f"The following audio_filepath='{data['audio_filepath']}' was not found in the tar file." - ) from e - - def _iter_sequential(self, tar_path, shard_manifest, manifest_path) -> Generator[tuple[dict, bytes], None, None]: - with tarfile.open(fileobj=open_best(tar_path, mode="rb"), mode="r|*") as tar: - for tar_info in tar: - assert tar_info.name in shard_manifest, ( - f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). " - f"Cannot locate JSON entry for tar file '{tar_info.name}'" - ) - data = shard_manifest[tar_info.name] - raw_audio = tar.extractfile(tar_info).read() - yield data, raw_audio, tar_info + if self.skip_missing_manifest_entries: + continue + else: + raise RuntimeError( + f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). " + f"Cannot locate JSON entry for tar file '{tar_info.name}'" + ) from e def __iter__(self) -> Generator[Cut, None, None]: shard_ids = self.shard_ids @@ -384,7 +381,6 @@ def __iter__(self) -> Generator[Cut, None, None]: # They have multiple JSONL entries where audio paths end with '-sub1', '-sub2', etc. for each offset. 
offset_pattern = re.compile(r'^(?P.+)(?P-sub\d+)(?P\.\w+)?$') - iter_fn = self._iter_random_read if self.tarred_random_access else self._iter_sequential for sid in shard_ids: manifest_path = self.paths[sid] if len(self.paths) > 1 else self.paths[0] @@ -398,7 +394,7 @@ def basename(d: dict) -> str: shard_manifest: dict[str, list[dict]] = groupby(basename, self.shard_id_to_manifest[sid]) tar_path = self.shard_id_to_tar_path[sid] try: - for data, raw_audio, tar_info in iter_fn(tar_path, shard_manifest, manifest_path): + for data, raw_audio, tar_info in self._iter_sequential(tar_path, shard_manifest, manifest_path): meta = soundfile.info(BytesIO(raw_audio)) recording = Recording( id=tar_info.path, diff --git a/nemo/collections/common/data/lhotse/sampling.py b/nemo/collections/common/data/lhotse/sampling.py index d645e3816300..f5b1a2987754 100644 --- a/nemo/collections/common/data/lhotse/sampling.py +++ b/nemo/collections/common/data/lhotse/sampling.py @@ -15,9 +15,11 @@ import bisect import logging import math +from bisect import bisect_left, bisect_right from dataclasses import dataclass from typing import Any, Sequence +import numpy as np from lhotse.cut import Cut from lhotse.dataset import SamplingConstraint, TokenConstraint from lhotse.dataset.sampling.dynamic_bucketing import FixedBucketBatchSizeConstraint @@ -110,11 +112,30 @@ class FixedBucketBatchSizeConstraint2D(FixedBucketBatchSizeConstraint): """ Sampling strategy that customizes Lhotse samplers to support 2D bucket selection (it also supports 1D). It is intended only for audio examples (i.e., Lhotse Cut objects). + + When ``strict_2d`` is set, we only consider sub-buckets for a single bucket that is the best match. + When set to ``False``, we'll promote an example to buckets with larger 1st dim if they can accommodate the 2nd dim. + + When ``max_ratio`` is set, it discards the examples that exceed a specific output-to-input length ratio. + ``max_ratio`` must be a list with the same length as the number of buckets. + ``max_ratio`` is only applied when ``strict_2d`` is set to ``True``. """ + strict_2d: bool = True + max_ratio: list[float] | None = None + + def __post_init__(self): + if isinstance(self.max_seq_len_buckets[0], Sequence): + self.max_seq_len_buckets = np.asarray(self.max_seq_len_buckets) + if self.max_ratio is not None: + assert isinstance(self.max_ratio, Sequence), f"self.max_ratio must be a list, but we got: {self.max_ratio}" + assert len(self.max_ratio) == len( + self.max_seq_len_buckets + ), f"{len(self.max_ratio)=} != {len(self.max_seq_len_buckets)=}" + @property def bucketing_2d_enabled(self) -> bool: - return isinstance(self.max_seq_len_buckets[0], Sequence) and len(self.max_seq_len_buckets[0]) == 2 + return isinstance(self.max_seq_len_buckets, np.ndarray) def measure_length(self, example: Cut) -> tuple[float, float] | float: if self.bucketing_2d_enabled: @@ -123,41 +144,66 @@ def measure_length(self, example: Cut) -> tuple[float, float] | float: return example.duration def select_bucket(self, buckets: Any, example: Any = None, example_len: Any = None) -> int: - if not self.bucketing_2d_enabled: - return super().select_bucket(buckets=buckets, example=example, example_len=example_len) if example_len is None: example_len = self.measure_length(example) - bucket_idx = bisect.bisect_left(buckets, example_len) - # For 2D bucketing we have to refine the initially found bucket_idx, as bisect - # looks primarily at the first index of a tuple (i.e. duration). 
- # For example, with buckets [(1, 1), (1, 2), (2, 2), (2, 4)] and example (1.5, 3) - # bisect would allocate it to bucket_idx=2 instead of bucket_idx=3. - # To refine, we'll try to push the example to as many buckets to the right as possible, - # as long as they have the same dim0 length (e.g. audio duration) and the example's dim1 - # is smaller than the bin's dim1 (e.g., output token sequence length). - bin_dim0, bin_dim1 = self.max_seq_len_buckets[bucket_idx] - num_buckets = len(self.max_seq_len_buckets) - while ( - (next_idx := bucket_idx + 1) < num_buckets # There is a next bucket - and (bin := self.max_seq_len_buckets[next_idx])[0] == bin_dim0 # The next bucket has the same 1st dim. - # The example's 2nd dim is between that of the current and the next bucket; or, - # the next bucket's 2nd dim is still smaller than example. - and (bin_dim1 < example_len[1] <= bin[1] or bin[1] < example_len[1]) - ): - bucket_idx = next_idx - bin_dim0, bin_dim1 = self.max_seq_len_buckets[bucket_idx] - - if example_len[0] > bin_dim0 or example_len[1] > bin_dim1: - logging.warning( - f"Data sample exceeds 2D bucket specification: lengths={example_len} bucket=({bin_dim0}, {bin_dim1}) " - f"(there is no larger bucket that would fit this example). " - f"We will keep it but expect OutOfMemoryError to happen during the training. " - f"You can fix this by stricter filtering with max_duration, max_tokens, max_tps, max_tpt; " - f"or re-estimating your bucket bins to match the actual data length distribution. " - f"Details: {example=}" - ) - - return bucket_idx + return find_smallest_bucket( + self.max_seq_len_buckets, example_len, strict=self.strict_2d, max_ratio=self.max_ratio + ) + + +def find_smallest_bucket( + buckets: np.ndarray, + example_lens: float | Sequence[float], + strict: bool = True, + max_ratio: Sequence[float] | None = None, +) -> int | None: + """ + Find the smallest bucket that fits a given example. + Each bucket and ``example_lens`` are floats (1-D bucketing) + or tuples of (dim0, dim1, dim2, ...) (N-D bucketing, typically 2-D). + Assumes the buckets have been sorted in ascending order. + Returns the index of the smallest bucket that fits the example, or ``None`` if no bucket fits. + """ + # 1D bucketing - binary search. + if isinstance(example_lens, (float, int)): # 1-D + idx = bisect_left(buckets, example_lens) + if idx == len(buckets): + return None + return idx + + # 2D bucketing 'strict' mode: only consider sub-buckets for the specific bucket that matches this example. + # E.g. for buckets = [(10, 5), (10, 10), (20, 12), (20, 18)] + # and example_lens = (8, 11) + # we will return None because we only consider the first two buckets based on dim0 (=8).
+ if strict: + # Find the first 2D bucket that accepts this example + dim0_begin = bisect_left(buckets[:, 0], example_lens[0]) + if dim0_begin == buckets.shape[0]: + return None + # Find the last 2D bucket that accepts this example + dim0_end = dim0_begin + while dim0_end < buckets.shape[0] and buckets[dim0_end, 0] == buckets[dim0_begin, 0]: + dim0_end += 1 + # Find the smallest 2D bucket in this range that accepts this example + dim1_begin = bisect_left(buckets[dim0_begin:dim0_end, 1], example_lens[1]) + if dim1_begin == dim0_end - dim0_begin: + return None + fit_idx = dim0_begin + dim1_begin + # Apply max_ratio (token-per-second/token-per-token) filtering if requested + if max_ratio is not None and example_lens[1] / example_lens[0] > max_ratio[fit_idx]: + return None + return fit_idx + + # 2D bucketing 'lenient' mode - linear search (as 2nd dim may not be growing monotonically). + # E.g. for buckets = [(10, 5), (10, 10), (20, 12), (20, 18)] + # and example_lens = (8, 11) + # we will return bucket_idx=2 because (20, 12) fits (8, 11) at the cost of more padding. + does_fit = np.all(np.asarray(example_lens) <= buckets, axis=1) + min_fit_idx = np.argmax(does_fit) + if min_fit_idx or does_fit[min_fit_idx]: + return min_fit_idx.item() + else: + return None @dataclass @@ -270,6 +316,8 @@ class TokenPerSecondFilter: def __init__(self, tps_min: float | None, tps_max: float | None) -> None: self.tps_min = ifnone(tps_min, -1) + if isinstance(tps_max, Sequence): + tps_max = float("inf") # filtering handled in bucketing filter self.tps_max = ifnone(tps_max, float("inf")) assert tps_min <= tps_max, f"{tps_min=} {tps_max=}" self.enabled = tps_min > 0 or tps_max < float("inf") @@ -290,6 +338,8 @@ class TokenPerTokenFilter: def __init__(self, tpt_min: float | None, tpt_max: float | None) -> None: self.tpt_min = ifnone(tpt_min, -1) + if isinstance(tpt_max, Sequence): + tpt_max = float("inf") # filtering handled in bucketing filter self.tpt_max = ifnone(tpt_max, float("inf")) assert tpt_min <= tpt_max, f"{tpt_min=} {tpt_max=}" self.enabled = tpt_min > 0 or tpt_max < float("inf") @@ -301,6 +351,24 @@ def __call__(self, example) -> bool: return self.tpt_min <= tpt <= self.tpt_max +class BucketingFilter: + """ + Filters out examples that did not fit into any of the buckets. + Intended mainly for 2D bucketing. This filter is only active when + the constraint passed to it is of type ``FixedBucketBatchSizeConstraint2D``, + and is otherwise disabled. 
+ """ + + def __init__(self, sampling_constraint: SamplingConstraint) -> None: + self.constraint = sampling_constraint + self.enabled = isinstance(self.constraint, FixedBucketBatchSizeConstraint2D) + + def __call__(self, example) -> bool: + if not self.enabled: + return True + return self.constraint.select_bucket(self.constraint.max_seq_len_buckets, example) is not None + + def _measure_tokens(cut: Cut) -> int: if hasattr(cut, "input_ids"): return len(cut.input_ids) # tokenized with prompt formatter diff --git a/nemo/collections/common/prompts/canary2.py b/nemo/collections/common/prompts/canary2.py index 3aed7a3bfa10..2aa657d294cc 100644 --- a/nemo/collections/common/prompts/canary2.py +++ b/nemo/collections/common/prompts/canary2.py @@ -26,6 +26,7 @@ CANARY_BOS, CANARY_EOS, CANARY_SPECIAL_TOKENIZER, + CanaryTokenizer, ) @@ -196,8 +197,13 @@ def canary2(cut: Cut, prompt: Canary2PromptFormatter) -> dict[str, torch.Tensor] ), ) ans = prompt.encode_dialog(turns) + if isinstance(prompt.tokenizer, CanaryTokenizer): + eos = prompt.tokenizer.eos + else: # SPE + eos = prompt.tokenizer.token_to_id(CANARY_EOS) + assert eos > -1, "Invalid tokenizer: tokenizer.token_to_id('{CANARY_EOS}') returned {eos}" assert ( - ans["answer_ids"][-1].item() == prompt.tokenizer.eos + ans["answer_ids"][-1].item() == eos ), f"Expected the last token in answer_ids to be EOS, but we got {ans['answer_ids']}" ans["answer_ids"] = ans["answer_ids"][:-1] # Strip Canary's EOS return ans diff --git a/nemo/collections/common/tokenizers/canary_tokenizer.py b/nemo/collections/common/tokenizers/canary_tokenizer.py index 04dc6e3a68a9..c0972e5c8c63 100644 --- a/nemo/collections/common/tokenizers/canary_tokenizer.py +++ b/nemo/collections/common/tokenizers/canary_tokenizer.py @@ -191,6 +191,29 @@ def build_special_tokenizer( return spl_tokenizer +class CanaryBPETokenizer(SentencePieceTokenizer): + """ + Thin wrapper around SPE tokenizer that overwrites SPE's BOS/EOS/PAD with Canary's special tokens + for compatibility with CanaryTokenizer (aggregate). 
+ """ + + @cached_property + def eos_id(self) -> int: + return self.token_to_id(CANARY_EOS) + + @cached_property + def bos_id(self) -> int: + return self.token_to_id(CANARY_BOS) + + @cached_property + def nospeech_id(self) -> int: + return self.token_to_id(CANARY_NOSPEECH) + + @cached_property + def pad_id(self) -> int: + return self.token_to_id(CANARY_PAD) + + def _map_canary1_to_canary2_lang(lang: str, available_langs: list[str]) -> str: if len(lang) != 2 or lang in available_langs: return lang diff --git a/scripts/speech_recognition/estimate_duration_bins_2d.py b/scripts/speech_recognition/estimate_duration_bins_2d.py index 0f4a021e09cc..5f4f0f0a1c11 100644 --- a/scripts/speech_recognition/estimate_duration_bins_2d.py +++ b/scripts/speech_recognition/estimate_duration_bins_2d.py @@ -15,6 +15,7 @@ import argparse import ast import math +import warnings from functools import partial from itertools import islice from pathlib import Path @@ -25,16 +26,17 @@ from lhotse.cut import Cut from omegaconf import OmegaConf -from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper +from nemo.collections.common.data import apply_prompt_format_fn from nemo.collections.common.data.lhotse.cutset import read_cutset_from_config from nemo.collections.common.data.lhotse.dataloader import LhotseDataLoadingConfig, tokenize -from nemo.collections.common.data.lhotse.sampling import ( - DurationFilter, - FixedBucketBatchSizeConstraint2D, - TokenPerSecondFilter, -) +from nemo.collections.common.data.lhotse.sampling import DurationFilter, FixedBucketBatchSizeConstraint2D from nemo.collections.common.prompts.formatter import PromptFormatter -from nemo.collections.common.tokenizers import AggregateTokenizer, SentencePieceTokenizer +from nemo.collections.common.tokenizers import ( + AggregateTokenizer, + CanaryTokenizer, + SentencePieceTokenizer, + TokenizerSpec, +) def parse_args(): @@ -107,11 +109,14 @@ def parse_args(): help="If specified, we'll filter out utterances longer than this.", ) parser.add_argument( - "--max_tps", + "--max_tps", type=float, default=None, help="Deprecated. TPS is automatically determined per bucket." + ) + parser.add_argument( + "--token_outlier_threshold", type=float, - default=float("inf"), - help="If specified, we'll filter out utterances with more tokens/second than this. " - "On regular utterances and BPE tokenizers with 1024 tokens 10-12tps is generally a reasonable limit.", + default=4.0, + help="The lower this is, the more outliers in transcript token count will be filtered out. " + "By default allow token counts at 4 sigma away from distribution mean, computed separately for every bucket.", ) parser.add_argument( "-q", "--quiet", type=bool, default=False, help="When specified, only print the estimated duration bins." 
@@ -134,12 +139,19 @@ def parse_args(): return parser.parse_args() +def sort_two_arrays(A, B): + joint = np.rec.fromarrays([A, B]) + joint.sort() + return joint.f0, joint.f1 + + def estimate_duration_buckets( cuts: Iterable[Cut], num_buckets: int, num_subbuckets: int, max_tps: float, max_duration: float, + token_outlier_threshold: float, quiet: bool, ) -> list[tuple[float, float]]: """ @@ -159,10 +171,7 @@ def estimate_duration_buckets( num_tokens.append(toks) sizes = np.array(sizes, dtype=np.float32) num_tokens = np.array(num_tokens, dtype=np.int32) - joint = np.rec.fromarrays([sizes, num_tokens]) - joint.sort() - sizes = joint.f0 - num_tokens = joint.f1 + sizes, num_tokens = sort_two_arrays(sizes, num_tokens) # We are building buckets with equal duration (empirically leads to more even bucket exhaustion over time). # We need to determine how much duration to allocate per bucket. @@ -170,19 +179,12 @@ def estimate_duration_buckets( if not quiet: print("Duration distribution:") - print(pd.Series(sizes).describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])) + print(pd.Series(sizes).describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99, 0.995, 0.999])) if math.isinf(max_duration): max_duration = sizes[-1] - tps = num_tokens / sizes - if not quiet: - print("Token per second distribution:") - print(pd.Series(tps).describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])) - if math.isinf(max_tps): - max_tps = tps.max() - del tps - bins = [] + tps_thresholds = [] bin_indexes = [0] tot = 0.0 @@ -193,55 +195,84 @@ def _estimate_token_buckets(max_bucket_duration): # Note that this estimation is biased towards more padding if you have # a lot of zero-token examples (e.g. non-speech). nonlocal bins - num_tokens_bucket = num_tokens[bin_indexes[-1] : binidx] - num_tokens_bucket.sort() + + # Start by discarding outlier examples as defined by token-per-second (TPS) attribute. + # We empirically determined high TPS examples to cause severe OOMs limiting batch sizes. + # We cap the TPS for each top-level bucket at 4 standard deviations of TPS. + # Examples exceeding that TPS value will be discarded during sampling at training time. + num_tokens_bucket_all = num_tokens[bin_indexes[-1] : binidx] + sizes_bucket_all = sizes[bin_indexes[-1] : binidx] + non_outlier_indexes = find_non_outliers_z_score( + num_tokens_bucket_all / sizes_bucket_all, threshold=token_outlier_threshold + ) + num_tokens_bucket = num_tokens_bucket_all[non_outlier_indexes] + sizes_bucket = sizes_bucket_all[non_outlier_indexes] + max_tps_bucket = (num_tokens_bucket / sizes_bucket).max() + num_tokens_bucket, sizes_bucket = sort_two_arrays(num_tokens_bucket, sizes_bucket) + if not quiet: + outlier_tps = np.delete(num_tokens_bucket_all / sizes_bucket_all, non_outlier_indexes) + print( + f"[bucket <= {max_bucket_duration:.2f}s] [{num_tokens_bucket.min()} - {num_tokens_bucket.max()}] [approx-max-tps: {max_tps_bucket:.2f}] Discarded {binidx - bin_indexes[-1] - len(num_tokens_bucket)} max token outliers", + end=" ", + ) + if len(outlier_tps) > 0: + print(f"min-outlier: {outlier_tps.min():.2f}, max-outlier: {outlier_tps.max():.2f}).", end="") + print() + tokens_per_subbucket = num_tokens_bucket.sum() / num_subbuckets tot_toks = 0 # Iterate over token counts, and whenever we hit tokens_per_subbucket, create a new 2D bucket bin. 
- for num_toks in num_tokens_bucket: + for num_toks, size in zip(num_tokens_bucket, sizes_bucket): # Threshold hit: we are creating a new (max_duration, max_num_tokens) bin. if tot_toks > tokens_per_subbucket: bins.append((max_bucket_duration, num_toks)) + tps_thresholds.append(max_tps_bucket) tot_toks = 0 tot_toks += num_toks - bins.append((size, math.ceil(size * max_tps))) + bins.append((max_bucket_duration, num_toks)) + tps_thresholds.append(max_tps_bucket) # Iterate over data, and whenever we hit size_per_bucket, create a new bucket bin. for binidx, size in enumerate(sizes): if tot > size_per_bucket: # Threshold hit: we are creating a new duration bin (multiplied by number of token bins). _estimate_token_buckets(max_bucket_duration=size) + bin_indexes.append(binidx) tot = 0.0 tot += size # Estimate an extra 2D bin set for global max duration. _estimate_token_buckets(max_bucket_duration=max_duration) - return bins + return bins, tps_thresholds + + +def find_non_outliers_z_score(data, threshold=4): + # Note: we don't apply abs() here because we only filter the upper end of the distribution. + # We don't mind low-token-counts for bucketing purposes. + z_scores = (data - np.mean(data)) / np.std(data) + return np.where(z_scores <= threshold) -def load_tokenizer(paths: list[str], langs: list[str] = None) -> TokenizerWrapper: +def load_tokenizer(paths: list[str], langs: list[str] = None, is_canary: bool = True) -> TokenizerSpec: if len(paths) == 1: tok = SentencePieceTokenizer(paths[0]) else: assert langs is not None and len(paths) == len( langs ), f"Cannot create AggregateTokenizer; each tokenizer must have assigned a language via --langs option (we got --tokenizers={paths} and --langs={langs})" - tok = AggregateTokenizer({lang: SentencePieceTokenizer(p) for lang, p in zip(langs, paths)}) - return TokenizerWrapper(tok) + if is_canary: + tokcls = CanaryTokenizer + else: + tokcls = AggregateTokenizer + tok = tokcls({lang: SentencePieceTokenizer(p) for lang, p in zip(langs, paths)}) + return tok def apply_tokenizer(cut, tokenizer=None, prompt: PromptFormatter = None): if prompt is not None: - turns = prompt.get_default_dialog_slots() - last_turn = {"role": prompt.OUTPUT_ROLE, "slots": prompt.get_slots(prompt.OUTPUT_ROLE)} - assert len(last_turn["slots"]) == 1 # TODO: not sure how to handle multi-slot for system output here - for key in last_turn["slots"]: - last_turn["slots"][key] = cut.supervisions[0].text - last_turn["slots"][prompt.PROMPT_LANGUAGE_SLOT] = cut.supervisions[0].language - turns.append(last_turn) - ans = prompt.encode_dialog(turns) - cut.supervisions[0].tokens = ans["input_ids"] + encoded = apply_prompt_format_fn(cut, prompt) + cut.supervisions[0].tokens = encoded["input_ids"] elif tokenizer is not None: cut = tokenize(cut, tokenizer) @@ -274,15 +305,25 @@ def main(): if not args.quiet: pd.set_option('display.float_format', lambda x: '%.2f' % x) + if args.max_tps is not None: + warnings.warn( + "The option --max_tps has been deprecated in favor of " + "automatic TPS determination that's variable across buckets." 
+ ) + tokenizer = None prompt = None if args.tokenizer is not None: - tokenizer = load_tokenizer(args.tokenizer, args.langs) + tokenizer = load_tokenizer( + paths=args.tokenizer, + langs=args.langs, + is_canary=args.prompt_format is not None and 'canary' in args.prompt_format, + ) if args.prompt_format is not None: prompt_defaults = None if args.prompt is not None: prompt_defaults = ast.literal_eval(args.prompt) - prompt = PromptFormatter.resolve(args.prompt_format)(tokenizer._tokenizer, defaults=prompt_defaults) + prompt = PromptFormatter.resolve(args.prompt_format)(tokenizer, defaults=prompt_defaults) if '=' in args.input: inp_arg = args.input @@ -302,28 +343,28 @@ def main(): duration_filter = RejectionsCounter(DurationFilter(args.min_duration, args.max_duration), "Duration filtering") cuts = cuts.filter(duration_filter) cuts = cuts.map(partial(apply_tokenizer, tokenizer=tokenizer, prompt=prompt)) - tps_filter = RejectionsCounter(TokenPerSecondFilter(-1, args.max_tps), "Token per second filtering") - cuts = cuts.filter(tps_filter) if (N := args.num_examples) > 0: cuts = islice(cuts, N) - duration_bins = estimate_duration_buckets( + duration_bins, tps_thresholds = estimate_duration_buckets( cuts, num_buckets=args.buckets, num_subbuckets=args.sub_buckets, - max_tps=args.max_tps, max_duration=args.max_duration, + max_tps=args.max_tps, + token_outlier_threshold=args.token_outlier_threshold, quiet=args.quiet, ) duration_bins = "[" + ','.join(f"[{b:.3f},{sb:d}]" for b, sb in duration_bins) + "]" - if args.quiet: - print(duration_bins) - return - duration_filter.print_report() - tps_filter.print_report() + tps_thresholds = "[" + ",".join(f"{t:.2f}" for t in tps_thresholds) + "]" + if not args.quiet: + duration_filter.print_report() print("Use the following options in your config:") + print(f"\tuse_bucketing=1") print(f"\tnum_buckets={args.buckets}") print(f"\tbucket_duration_bins={duration_bins}") + print(f"The max_tps setting below is optional, use it if your data has low quality long transcript outliers:") + print(f"\tmax_tps={tps_thresholds}") if __name__ == "__main__": diff --git a/scripts/speech_recognition/oomptimizer.py b/scripts/speech_recognition/oomptimizer.py index b44c2c46c629..d46179742ff8 100755 --- a/scripts/speech_recognition/oomptimizer.py +++ b/scripts/speech_recognition/oomptimizer.py @@ -408,7 +408,9 @@ def oomptimizer( ( "text" if any( - isinstance(item["type"].elements_type, LabelsType) and item["seq_length"] == direction + isinstance(item["type"], NeuralType) + and isinstance(item["type"].elements_type, LabelsType) + and item["seq_length"] == direction for item in schema["inputs"] if item["type"] != "dummy" ) @@ -518,8 +520,6 @@ def step(): if is_2d_bucketing: # 2D bucketing doesn't support bucket merging. 
final_profile = [["[" + ",".join(map(str, b)) + "]", bs] for (b, _, __), bs in profile.items()] - max_input_len, max_output_len = buckets[-1] - ratio = max_output_len / max_input_len else: click.echo("Bucket merging stage...") final_profile = [] @@ -532,7 +532,6 @@ def step(): final_profile[-1][0] = bucket continue final_profile.append([bucket, bs]) - max_input_len = final_profile[-1][0] click.secho(f"The profile was created with the following settings:") click.secho(f"* using {memory_fraction:.1%} of available GPU RAM.") @@ -541,9 +540,6 @@ def step(): click.secho("The final profile is:", bold=True) click.secho("\tbucket_duration_bins=[" + ",".join(str(seqlen) for seqlen, bs in final_profile) + "]", bold=True) click.secho("\tbucket_batch_size=[" + ",".join(str(bs) for seqlen, bs in final_profile) + "]", bold=True) - click.secho("\t(The following flags are suitable for ASR/speech-to-text models):") - click.secho(f"\tmax_tps={ratio}", bold=True) - click.secho(f"\tmax_duration={max_input_len}", bold=True) if __name__ == "__main__": diff --git a/tests/collections/common/test_2d_bucketing_constraint.py b/tests/collections/common/test_2d_bucketing_constraint.py index 36cb9825ac5b..285df28d4ab8 100644 --- a/tests/collections/common/test_2d_bucketing_constraint.py +++ b/tests/collections/common/test_2d_bucketing_constraint.py @@ -14,29 +14,37 @@ import numpy as np import pytest +import torch.utils.data from lhotse import CutSet, Seconds, SupervisionSegment from lhotse.dataset import DynamicBucketingSampler -from lhotse.testing.dummies import DummyManifest, dummy_cut -from nemo.collections.common.data.lhotse.sampling import FixedBucketBatchSizeConstraint2D +from lhotse.testing.dummies import dummy_cut +from lhotse.testing.random import deterministic_rng + +from nemo.collections.common.data.lhotse.dataloader import ( + BucketingFilter, + FixedBucketBatchSizeConstraint2D, + get_lhotse_dataloader_from_config, +) + + +def make_cut(id_: int = 0, duration: Seconds = 1.0, num_tokens: int = 10): + supervision = SupervisionSegment(f"blah-{id_}", f"blah-{id_}", 0.0, duration, text="a" * num_tokens) + supervision.tokens = np.zeros((num_tokens,), dtype=np.int32).tolist() + return dummy_cut(id_, duration=duration, supervisions=[supervision]) @pytest.fixture def cuts(): - def _cut(id_: int, duration: Seconds, num_tokens: int): - supervision = SupervisionSegment(f"blah-{id_}", f"blah-{id_}", 0.0, duration, text="a" * num_tokens) - supervision.tokens = np.zeros((num_tokens,), dtype=np.int32) - return dummy_cut(id_, duration=duration, supervisions=[supervision]) - return CutSet( - [_cut(i, duration=2.0, num_tokens=4) for i in range(20)] - + [_cut(i, duration=2.0, num_tokens=8) for i in range(20)] - + [_cut(i, duration=2.0, num_tokens=12) for i in range(20)] - + [_cut(i, duration=8.0, num_tokens=8) for i in range(20)] - + [_cut(i, duration=8.0, num_tokens=12) for i in range(20)] - + [_cut(i, duration=8.0, num_tokens=16) for i in range(20)] - + [_cut(i, duration=14.0, num_tokens=12) for i in range(20)] - + [_cut(i, duration=14.0, num_tokens=16) for i in range(20)] - + [_cut(i, duration=14.0, num_tokens=20) for i in range(20)] + [make_cut(i, duration=2.0, num_tokens=4) for i in range(20)] + + [make_cut(i, duration=2.0, num_tokens=8) for i in range(20)] + + [make_cut(i, duration=2.0, num_tokens=12) for i in range(20)] + + [make_cut(i, duration=8.0, num_tokens=8) for i in range(20)] + + [make_cut(i, duration=8.0, num_tokens=12) for i in range(20)] + + [make_cut(i, duration=8.0, num_tokens=16) for i in range(20)] + + 
[make_cut(i, duration=14.0, num_tokens=12) for i in range(20)] + + [make_cut(i, duration=14.0, num_tokens=16) for i in range(20)] + + [make_cut(i, duration=14.0, num_tokens=20) for i in range(20)] ) @@ -63,6 +71,7 @@ def test_2d_bucketing_expected_bucket_allocation(cuts): constraint=FixedBucketBatchSizeConstraint2D( max_seq_len_buckets=duration_bins, batch_sizes=batch_sizes, + strict_2d=False, ), buffer_size=1000, seed=0, @@ -79,7 +88,7 @@ def test_2d_bucketing_expected_bucket_allocation(cuts): for cut in batch: # First, check that the sampled examples are indeed below the max duration/num_tokens for its bucket. assert cut.duration <= max_duration - assert cut.supervisions[0].tokens.shape[0] <= max_num_tokens + assert len(cut.supervisions[0].tokens) <= max_num_tokens # Then, find the previous compatible bucket for each of training example's dimensions, # and verify that it was not possible to assign the example to that smaller bucket. # We should skip this for bucket_idx==0 (no previous buckets available). @@ -97,7 +106,176 @@ def test_2d_bucketing_expected_bucket_allocation(cuts): prev_max_num_tokens = max( tok for dur, tok in duration_bins[:bin_index] if dur == max_duration and tok < max_num_tokens ) - assert cut.supervisions[0].tokens.shape[0] > prev_max_num_tokens + assert len(cut.supervisions[0].tokens) > prev_max_num_tokens except ValueError as e: if "max() arg is an empty sequence" not in str(e): raise + + +@pytest.mark.parametrize( + ["duration", "num_tokens", "should_keep", "bucket_idx"], + [ + # Buckets for duration range [0.0-5.0]: + # * Sweep num_tokens + (2.0, 0, True, 0), + (2.0, 5, True, 0), + (2.0, 10, True, 0), + (2.0, 11, True, 1), + (2.0, 20, True, 1), + (2.0, 21, True, 3), + (2.0, 30, True, 3), + (2.0, 31, False, None), + # * Check the upper bound duration 5.0 + (5.0, 0, True, 0), + (5.0, 5, True, 0), + (5.0, 10, True, 0), + (5.0, 11, True, 1), + (5.0, 20, True, 1), + (5.0, 21, True, 3), + (5.0, 30, True, 3), + (5.0, 31, False, None), + # Buckets for duration range [5.0, 10.0] + # * Sweep num_tokens + (8.0, 0, True, 2), + (8.0, 15, True, 2), + (8.0, 16, True, 3), + (8.0, 30, True, 3), + (8.0, 31, False, None), + # * Check the upper bound duration 10.0 + (10.0, 0, True, 2), + (10.0, 15, True, 2), + (10.0, 16, True, 3), + (10.0, 30, True, 3), + (10.0, 31, False, None), + # Durations above max duration + (20.0, 0, False, None), + (20.0, 1000, False, None), + ], +) +def test_2d_bucketing_filter_lenient(duration, num_tokens, should_keep, bucket_idx): + buckets = [(5.0, 10), (5.0, 20), (10.0, 15), (10.0, 30)] + batch_sizes = [4, 3, 2, 1] + constraint = FixedBucketBatchSizeConstraint2D(buckets, batch_sizes, strict_2d=False) + filter_2d = BucketingFilter(constraint) + + cut = make_cut(duration=duration, num_tokens=num_tokens) + assert filter_2d(cut) == should_keep + assert constraint.select_bucket(constraint.max_seq_len_buckets, cut) == bucket_idx + + +@pytest.mark.parametrize( + ["duration", "num_tokens", "should_keep", "bucket_idx"], + [ + # Buckets for duration range [0.0-5.0]: + # * Sweep num_tokens + (2.0, 0, True, 0), + (2.0, 5, True, 0), + (2.0, 10, True, 0), + (2.0, 11, True, 1), + (2.0, 20, True, 1), + (2.0, 21, False, None), # <-- strict + (2.0, 30, False, None), # <-- strict + (2.0, 31, False, None), + # * Check the upper bound duration 5.0 + (5.0, 0, True, 0), + (5.0, 5, True, 0), + (5.0, 10, True, 0), + (5.0, 11, True, 1), + (5.0, 20, True, 1), + (5.0, 21, False, None), # <-- strict + (5.0, 30, False, None), # <-- strict + (5.0, 31, False, None), + # Buckets for 
duration range [5.0, 10.0] + # * Sweep num_tokens + (8.0, 0, True, 2), + (8.0, 15, True, 2), + (8.0, 16, True, 3), + (8.0, 30, True, 3), + (8.0, 31, False, None), + # * Check the upper bound duration 10.0 + (10.0, 0, True, 2), + (10.0, 15, True, 2), + (10.0, 16, True, 3), + (10.0, 30, True, 3), + (10.0, 31, False, None), + # Durations above max duration + (20.0, 0, False, None), + (20.0, 1000, False, None), + ], +) +def test_2d_bucketing_filter_strict(duration, num_tokens, should_keep, bucket_idx): + buckets = [(5.0, 10), (5.0, 20), (10.0, 15), (10.0, 30)] + batch_sizes = [4, 3, 2, 1] + constraint = FixedBucketBatchSizeConstraint2D(buckets, batch_sizes, strict_2d=True) + filter_2d = BucketingFilter(constraint) + + cut = make_cut(duration=duration, num_tokens=num_tokens) + assert filter_2d(cut) == should_keep + assert constraint.select_bucket(constraint.max_seq_len_buckets, cut) == bucket_idx + + +def test_2d_bucketing_filter_strict_max_ratio(): + buckets = [(5.0, 10), (5.0, 20), (10.0, 15), (10.0, 30)] + max_ratio = [4.0, 4.0, 3.0, 3.0] + batch_sizes = [4, 3, 2, 1] + + # Without max_ratio it works because both dims fit bucket at idx 1 + constraint = FixedBucketBatchSizeConstraint2D(buckets, batch_sizes, strict_2d=True) + filter_2d = BucketingFilter(constraint) + cut = make_cut(duration=2.0, num_tokens=20) + assert filter_2d(cut) == True + assert constraint.select_bucket(constraint.max_seq_len_buckets, cut) == 1 + + # With max_ratio it's filtered out because 20 / 2.0 = 10.0 but max_ratio is 4.0 + constraint = FixedBucketBatchSizeConstraint2D(buckets, batch_sizes, strict_2d=True, max_ratio=max_ratio) + filter_2d = BucketingFilter(constraint) + cut = make_cut(duration=2.0, num_tokens=20) + assert filter_2d(cut) == False + assert constraint.select_bucket(constraint.max_seq_len_buckets, cut) == None + + +class _Identity(torch.utils.data.Dataset): + def __getitem__(self, item): + return item + + +def test_2d_bucketing_strict_mode_flag_works(deterministic_rng, tmp_path): + cuts_path = tmp_path / "cuts.jsonl" + CutSet([make_cut(0, duration=1.0, num_tokens=10), make_cut(0, duration=1.0, num_tokens=100)]).to_file(cuts_path) + + # Strict mode enabled + dloader = get_lhotse_dataloader_from_config( + { + "cuts_path": cuts_path, + "use_bucketing": True, + "bucket_duration_bins": [(5.0, 10), (5.0, 20), (10.0, 150), (10.0, 300)], + "bucket_batch_size": [1, 1, 1, 1], + "bucketing_2d_strict_mode": True, + }, + global_rank=0, + world_size=1, + dataset=_Identity(), + ) + batches = [b for b in dloader] + assert len(batches) == 1 + assert len(batches[0]) == 1 + assert len(batches[0][0].supervisions[0].tokens) == 10 + + # Strict mode disabled + dloader = get_lhotse_dataloader_from_config( + { + "cuts_path": cuts_path, + "use_bucketing": True, + "bucket_duration_bins": [(5.0, 10), (5.0, 20), (10.0, 150), (10.0, 300)], + "bucket_batch_size": [1, 1, 1, 1], + "bucketing_2d_strict_mode": False, + }, + global_rank=0, + world_size=1, + dataset=_Identity(), + ) + batches = [b for b in dloader] + assert len(batches) == 2 + assert len(batches[0]) == 1 + assert len(batches[0][0].supervisions[0].tokens) == 100 + assert len(batches[1][0].supervisions[0].tokens) == 10
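For reference, a small self-contained sketch (not part of the patch, using synthetic numbers) of how ``estimate_duration_bins_2d.py`` derives the per-bucket ``max_tps`` thresholds: within each duration bucket, token-per-second outliers beyond ``--token_outlier_threshold`` standard deviations are dropped, and the maximum TPS of the surviving examples becomes that bucket's threshold::

    import numpy as np

    def find_non_outliers_z_score(data, threshold=4.0):
        # Upper-tail filter only: low token-per-second values are harmless for bucketing,
        # so no abs() is applied (mirrors the helper added in estimate_duration_bins_2d.py).
        z_scores = (data - np.mean(data)) / np.std(data)
        return np.where(z_scores <= threshold)

    # Synthetic TPS values for one duration bucket: mostly ~10 tokens/s plus two corrupted transcripts.
    rng = np.random.default_rng(0)
    tps = np.concatenate([rng.normal(loc=10.0, scale=1.0, size=1000), [40.0, 55.0]])

    keep = find_non_outliers_z_score(tps, threshold=4.0)  # --token_outlier_threshold
    max_tps_bucket = tps[keep].max()  # becomes this bucket's entry in max_tps=[...]
    print(f"kept {keep[0].size}/{tps.size} examples, bucket max_tps ~ {max_tps_bucket:.2f}")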
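Similarly, a hedged sketch of migrating a tarred-manifest dataloader config from the deprecated flag to its replacement (paths are placeholders)::

    config = {
        "manifest_filepath": "tarred/sharded_manifests/manifest__OP_0..127_CL_.json",
        "tarred_audio_filepaths": "tarred/audio__OP_0..127_CL_.tar",
        # "tarred_random_access": True,          # deprecated: read the whole tar into memory for random access
        "skip_missing_manifest_entries": True,   # replacement: read tars sequentially, skip unmatched members
        "batch_duration": 600.0,
    }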