Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TPS-free 2D bucket estimation and filtering #11738

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions docs/source/asr/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1079,23 +1079,22 @@ To run 2D bucketing with 30 buckets sub-divided into 5 sub-buckets each (150 buc

# The script's output:
Use the following options in your config:
use_bucketing=1
num_buckets=30
bucket_duration_bins=[[1.91,10],[1.91,17],[1.91,25],...
max_duration=...
max_tps=...
<other diagnostic information about the dataset>
The max_tps setting below is optional, use it if your data has low quality long transcript outliers:
max_tps=[13.2,13.2,11.8,11.8,...]

Note that the output in ``bucket_duration_bins`` is a nested list, where every bin specifies
the maximum duration and the maximum number of tokens that go into the bucket.
Passing this option to Lhotse dataloader will automatically enable 2D bucketing.
Note the presence of ``max_duration`` and ``max_tps`` (token-per-second) options:
these need to be included in dataloader's configuration to ensure we can use the buckets correctly at runtime
in case of outliers.
In general, if you change your data in training, it is highly advisable to re-estimate the duration bins.

Note that reasonable values for tokens-per-second rarely exceed 12tps with reasonably good tokenizers.
If you find your dataset's TPS is much higher than that, you may have some bad data outliers.
In that case you may specify ``--max_tps`` option to discard those both in bin estimation and dataloading.

Note the presence of ``max_tps`` (token-per-second) option.
It is optional to include it in the dataloader configuration: if you do, we will apply an extra filter
that discards examples which have more tokens per second than the threshold value.
The threshold is determined for each bucket separately based on data distribution, and can be controlled
with the option ``--token_outlier_threshold``.
This filtering is useful primarily for noisy datasets to discard low quality examples / outliers.

We also support aggregate tokenizers for 2D bucketing estimation:

Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/data/audio_to_text_lhotse_prompted.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
super().__init__()
self.tokenizer = tokenizer
self.load_audio = AudioSamples(fault_tolerant=True)
self.padding_value = self.tokenizer.pad
self.padding_value = self.tokenizer.pad_id
self.prompt = prompt

def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch:
Expand Down
8 changes: 6 additions & 2 deletions nemo/collections/asr/parts/mixins/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,12 @@ def _setup_monolingual_tokenizer(self, tokenizer_cfg: DictConfig):
if special_tokens is not None:
raise ValueError("`special_tokens` are no longer supported for SentencePiece based tokenizers.")

# Update special tokens
self.tokenizer = tokenizers.SentencePieceTokenizer(model_path=model_path)
if "custom_tokenizer" in self.tokenizer_cfg:
self.tokenizer = self.from_config_dict(
{"_target_": tokenizer_cfg["custom_tokenizer"]["_target_"], "model_path": model_path}
)
else:
self.tokenizer = tokenizers.SentencePieceTokenizer(model_path=model_path)

if 'vocab_path' in self.tokenizer_cfg:
vocab_path = self.tokenizer_cfg.get('vocab_path')
Expand Down
8 changes: 4 additions & 4 deletions nemo/collections/common/data/lhotse/cutset.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def read_dataset_config(config) -> tuple[CutSet, bool]:
"force_finite": config.get("force_finite", False),
"max_open_streams": config.get("max_open_streams", None),
"token_equivalent_duration": config.get("token_equivalent_duration", None),
"tarred_random_access": config.get("tarred_random_access", False),
"skip_missing_manifest_entries": config.get("skip_missing_manifest_entries", False),
}
input_cfg = config.input_cfg
if isinstance(input_cfg, (str, Path)):
Expand Down Expand Up @@ -510,11 +510,11 @@ def read_nemo_manifest(config) -> tuple[CutSet, bool]:
LazyNeMoTarredIterator(
config.manifest_filepath,
tar_paths=config.tarred_audio_filepaths,
tarred_random_access=config.tarred_random_access,
skip_missing_manifest_entries=config.skip_missing_manifest_entries,
**common_kwargs,
)
)
if not config.tarred_random_access and not force_finite:
if not force_finite:
cuts = cuts.repeat()
else:
cuts = CutSet(LazyNeMoIterator(config.manifest_filepath, **notar_kwargs, **common_kwargs))
Expand Down Expand Up @@ -552,7 +552,7 @@ def read_nemo_manifest(config) -> tuple[CutSet, bool]:
nemo_iter = LazyNeMoTarredIterator(
manifest_path=manifest_path,
tar_paths=tar_path,
tarred_random_access=config.tarred_random_access,
skip_missing_manifest_entries=config.skip_missing_manifest_entries,
**common_kwargs,
)
else:
Expand Down
51 changes: 34 additions & 17 deletions nemo/collections/common/data/lhotse/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
make_worker_init_fn,
)
from lhotse.dataset.dataloading import resolve_seed
from lhotse.dataset.sampling.base import CutSampler, TimeConstraint
from lhotse.dataset.sampling.base import CutSampler, SamplingConstraint, TimeConstraint
from lhotse.lazy import LazyFlattener
from lhotse.utils import fastcopy, fix_random_seed
from omegaconf import DictConfig, OmegaConf
Expand All @@ -44,6 +44,7 @@
read_cutset_from_config,
)
from nemo.collections.common.data.lhotse.sampling import (
BucketingFilter,
DurationFilter,
FixedBucketBatchSizeConstraint2D,
MultimodalFixedBucketBatchSizeConstraint2D,
Expand Down Expand Up @@ -76,7 +77,8 @@ class LhotseDataLoadingConfig:
cuts_path: str | None = None
shar_path: Any = None # str | list[str | tuple[str, float | int]] | None = None
# Enable this to support dataloading from JSON manifests that reference subsets of audio tar files.
tarred_random_access: bool = False
skip_missing_manifest_entries: bool = False
tarred_random_access: bool = False # deprecated, replaced by: skip_missing_manifest_entries
# 2. Batch size.
# a. Existing NeMo options.
batch_size: int | None = None
Expand All @@ -91,6 +93,7 @@ class LhotseDataLoadingConfig:
bucket_duration_bins: Any = None # list[float] | list[list[float]] | None = None
bucket_buffer_size: int = 10000
concurrent_bucketing: bool = True # fetches data in a background thread
bucketing_2d_strict_mode: bool = True # reduces padding by discarding significant outliers
# d. Other Lhotse sampling options.
shuffle_buffer_size: int | None = 10000
drop_last: bool = False
Expand All @@ -117,15 +120,15 @@ class LhotseDataLoadingConfig:
min_duration: float | None = -1
max_duration: float | None = float("inf")
min_tps: int = -1 # allowed tokens per second (audio-only)
max_tps: float = float("inf")
max_tps: Any = float("inf") # float | list[float]
# * Text input
min_tokens: int | None = None
max_tokens: int | None = None
# When true, combine context+answer lengths into a total length; otherwise report context length.
# For 2D bucketing it's always false, as we report a tuple of (context_len, answer_len).
measure_total_length: bool = True
min_tpt: int = -1 # allowed tokens per token (text-only)
max_tpt: float = float("inf")
max_tpt: Any = float("inf") # float | list[float]

# 3. Supported existing NeMo options.
shuffle: bool = False
Expand Down Expand Up @@ -530,7 +533,7 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No
# Select the strategy customizing Lhotse sampler behaviour.
# Provides support for dynamic batch sizes, multimodal dataloading, 2D bucketing, etc.
bucket_duration_bins = determine_bucket_duration_bins(config)
constraint = determine_sampling_constraint(bucket_duration_bins, config)
cuts, constraint = determine_sampling_constraint(cuts, bucket_duration_bins, config)

# 3. The sampler.
if config.use_bucketing:
Expand Down Expand Up @@ -608,13 +611,15 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No
return sampler, use_iterable_dataset


def determine_sampling_constraint(bucket_duration_bins, config):
def determine_sampling_constraint(cuts: CutSet, bucket_duration_bins, config) -> tuple[CutSet, SamplingConstraint]:
"""
Select an appropriate sampling strategy (constraint) for Lhotse samplers based on the configuration.
Sampling constraint affects the batch size (static/dynamic) and bucketing behaviour (1D/2D).
It is the appropriate customization point to introduce support of other modalities,
as it defines a method for example sequence length measurement (audio duration, text tokens, etc.).

Some constraints apply extra filter on ``cuts`` which is why we accept and return the ``CutSet``.

Lhotse's default is :class:`TimeConstraint` for regular audio data, other available options are
multimodal constraints (joint text + audio) and their 2D bucketing extensions.
"""
Expand All @@ -627,7 +632,10 @@ def determine_sampling_constraint(bucket_duration_bins, config):
max_seq_len_buckets=bucket_duration_bins,
batch_sizes=config.bucket_batch_size,
token_equivalent_duration=config.token_equivalent_duration,
strict_2d=config.bucketing_2d_strict_mode,
max_ratio=config.max_tpt if isinstance(config.max_tpt, Sequence) else None,
)
cuts = cuts.filter(BucketingFilter(constraint))
else:
constraint = MultimodalSamplingConstraint(
token_equivalent_duration=config.token_equivalent_duration,
Expand All @@ -643,14 +651,17 @@ def determine_sampling_constraint(bucket_duration_bins, config):
constraint = FixedBucketBatchSizeConstraint2D(
max_seq_len_buckets=bucket_duration_bins,
batch_sizes=config.bucket_batch_size,
strict_2d=config.bucketing_2d_strict_mode,
max_ratio=config.max_tps if isinstance(config.max_tps, Sequence) else None,
)
cuts = cuts.filter(BucketingFilter(constraint))
else:
constraint = TimeConstraint(
max_cuts=config.batch_size,
max_duration=config.batch_duration,
quadratic_duration=config.quadratic_duration,
)
return constraint
return cuts, constraint


def determine_bucket_duration_bins(config):
Expand Down Expand Up @@ -702,22 +713,28 @@ def make_structured_with_schema_warnings(config: DictConfig | dict) -> DictConfi
supported_keys = set(OmegaConf.to_container(default).keys())
received_keys = set(OmegaConf.to_container(config).keys())
unsupported_keys = received_keys - supported_keys
unsupported_keys.discard("use_lhotse")
if unsupported_keys:
warnings.warn(
f"The following configuration keys are no longer supported " f"and ignored: {','.join(unsupported_keys)}",
category=DeprecationWarning,
logging.warning(
f"The following configuration keys are ignored by Lhotse dataloader: {','.join(unsupported_keys)}",
)
config = OmegaConf.masked_copy(config, list(supported_keys))

return OmegaConf.merge(default, config)
config = OmegaConf.merge(default, config)

if config.get("tarred_random_access", False):
logging.warning(
"Option 'tarred_random_access' is deprecated and replaced with 'skip_missing_manifest_entries'.",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

may be also add version from which this would be removed

)
config.skip_missing_manifest_entries = True
if config.skip_missing_manifest_entries:
logging.warning(
"Note: skip_missing_manifest_entries is set to True. "
"If any of your manifests and tar files are mismatched, the entire tar file will be skipped without warning. "
"It's your responsibility to ensure data integrity with this setting."
)

def determine_use_iterable_dataset(use_iterable_dataset: bool, config: DictConfig) -> bool:
assert not (
config.force_map_dataset and config.force_iterable_dataset
), "Conflicting options: force_map_dataset=True and force_iterable_dataset=True"
use_iterable_dataset = (use_iterable_dataset or config.force_iterable_dataset) and not config.force_map_dataset
return use_iterable_dataset
return config


def tokenize(example, tokenizer):
Expand Down
42 changes: 19 additions & 23 deletions nemo/collections/common/data/lhotse/nemo_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,11 @@ class LazyNeMoTarredIterator:
This can be used for other cloud storage APIs such as S3, GCS, etc.
The same mechanism applies to ``manifest_path``.

If your data has been filtered so that the JSON manifests refer to just a subset of recordings,
set ``skip_missing_manifest_entries` to ``True``.
This will still read the tar files sequentially (very fast) and discard the audio files that
are not present in the corresponding manifest.

The ``shard_seed`` argument is used to seed the RNG shuffling the shards.
By default, it's ``trng`` which samples a seed number from OS-provided TRNG (see Python ``secrets`` module).
Seed is resolved lazily so that every dataloading worker may sample a different one.
Expand Down Expand Up @@ -264,10 +269,10 @@ def __init__(
shard_seed: int | Literal["trng", "randomized"] = "trng",
text_field: str = "text",
lang_field: str = "lang",
tarred_random_access: bool = False,
skip_missing_manifest_entries: bool = False,
extra_fields: list[dict[str, str]] | None = None,
) -> None:
self.tarred_random_access = tarred_random_access
self.skip_missing_manifest_entries = skip_missing_manifest_entries
self.shard_id_to_manifest: dict[int, Iterable[dict]]
self.paths = expand_sharded_filepaths(manifest_path)
if len(self.paths) == 1:
Expand Down Expand Up @@ -346,29 +351,21 @@ def _validate(self) -> None:
def shard_ids(self) -> List[int]:
return sorted(self.shard_id_to_manifest.keys())

def _iter_random_read(self, tar_path, shard_manifest, manifest_path) -> Generator[tuple[dict, bytes], None, None]:
with tarfile.open(fileobj=BytesIO(open_best(tar_path, mode="rb").read()), mode="r") as tar:
for data in shard_manifest:
def _iter_sequential(self, tar_path, shard_manifest, manifest_path) -> Generator[tuple[dict, bytes], None, None]:
with tarfile.open(fileobj=open_best(tar_path, mode="rb"), mode="r|*") as tar:
for tar_info in tar:
try:
tar_info = tar.getmember(data)
data = shard_manifest[tar_info.name]
raw_audio = tar.extractfile(tar_info).read()
yield data, raw_audio, tar_info
except KeyError as e:
raise RuntimeError(
f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). "
f"The following audio_filepath='{data['audio_filepath']}' was not found in the tar file."
) from e

def _iter_sequential(self, tar_path, shard_manifest, manifest_path) -> Generator[tuple[dict, bytes], None, None]:
with tarfile.open(fileobj=open_best(tar_path, mode="rb"), mode="r|*") as tar:
for tar_info in tar:
assert tar_info.name in shard_manifest, (
f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). "
f"Cannot locate JSON entry for tar file '{tar_info.name}'"
)
data = shard_manifest[tar_info.name]
raw_audio = tar.extractfile(tar_info).read()
yield data, raw_audio, tar_info
if self.skip_missing_manifest_entries:
continue
else:
raise RuntimeError(
f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). "
f"Cannot locate JSON entry for tar file '{tar_info.name}'"
) from e

def __iter__(self) -> Generator[Cut, None, None]:
shard_ids = self.shard_ids
Expand All @@ -384,7 +381,6 @@ def __iter__(self) -> Generator[Cut, None, None]:
# They have multiple JSONL entries where audio paths end with '-sub1', '-sub2', etc. for each offset.
offset_pattern = re.compile(r'^(?P<stem>.+)(?P<sub>-sub\d+)(?P<ext>\.\w+)?$')

iter_fn = self._iter_random_read if self.tarred_random_access else self._iter_sequential
for sid in shard_ids:
manifest_path = self.paths[sid] if len(self.paths) > 1 else self.paths[0]

Expand All @@ -398,7 +394,7 @@ def basename(d: dict) -> str:
shard_manifest: dict[str, list[dict]] = groupby(basename, self.shard_id_to_manifest[sid])
tar_path = self.shard_id_to_tar_path[sid]
try:
for data, raw_audio, tar_info in iter_fn(tar_path, shard_manifest, manifest_path):
for data, raw_audio, tar_info in self._iter_sequential(tar_path, shard_manifest, manifest_path):
meta = soundfile.info(BytesIO(raw_audio))
recording = Recording(
id=tar_info.path,
Expand Down
Loading
Loading