ExtractSeqLensJob (#522)
albertz authored Jul 4, 2024
1 parent 5d44735 commit 9e62747
Showing 1 changed file with 113 additions and 1 deletion.
114 changes: 113 additions & 1 deletion returnn/dataset.py
@@ -5,7 +5,7 @@
import pickle
import shutil
import subprocess
-from typing import Optional
+from typing import Optional, Any, Dict

import numpy

@@ -149,3 +149,115 @@ def run(self):
hdf_writer.insert_batch(numpy.asarray([[speaker_index]], dtype="int32"), [1], [segment_name])

hdf_writer.close()


class ExtractSeqLensJob(Job):
"""
Extracts sequence lengths from a dataset for one specific key.
"""

def __init__(
self,
dataset: Dict[str, Any],
post_dataset: Optional[Dict[str, Any]] = None,
*,
key: str,
output_format: str,
returnn_config: Optional[ReturnnConfig] = None,
returnn_root: Optional[tk.Path] = None,
):
"""
:param dataset: dict for :func:`returnn.datasets.init_dataset`
:param post_dataset: extension of the dataset dict, which is not hashed
:param key: e.g. "data", "classes" or whatever the dataset provides
:param output_format: "py" or "txt".
"py" will write a Python dict seq_tag -> seq_len.
"txt" will write one seq_len per line.
:param returnn_config: for the RETURNN global config.
This is optional and only needed if you use any custom functions (e.g. audio pre_process)
which expect some configuration in the global config.
:param returnn_root: inserted to ``sys.path`` for the RETURNN import.
"""
super().__init__()
self.dataset = dataset
self.post_dataset = post_dataset
self.key = key
assert output_format in {"py", "txt"}
self.output_format = output_format
self.returnn_config = returnn_config
self.returnn_root = returnn_root

self.out_returnn_config_file = self.output_path("returnn.config")
self.out_file = self.output_path(f"seq_lens.{output_format}")

self.rqmt = {"gpu": 0, "cpu": 1, "mem": 4, "time": 1}

@classmethod
def hash(cls, parsed_args):
"""hash"""
parsed_args = parsed_args.copy()
parsed_args.pop("post_dataset")
return super().hash(parsed_args)

def tasks(self):
"""tasks"""
yield Task("create_files", mini_task=True)
yield Task("run", rqmt=self.rqmt)

def create_files(self):
"""create files"""
config = self.returnn_config or ReturnnConfig({})
assert "dataset" not in config.config and "dataset" not in config.post_config
dataset_dict = self.dataset.copy()
if self.post_dataset:
# The modification to the config here is not part of the hash anymore,
# so merge dataset and post_dataset now.
dataset_dict.update(self.post_dataset)
config.config["dataset"] = dataset_dict
config.write(self.out_returnn_config_file.get_path())

def run(self):
"""run"""
import tempfile
import shutil
import sys

if self.returnn_root is not None:
sys.path.insert(0, self.returnn_root.get_path())

from returnn.config import set_global_config, Config
from returnn.datasets import init_dataset

config = Config()
config.load_file(self.out_returnn_config_file.get_path())
set_global_config(config)

dataset_dict = config.typed_value("dataset")
assert isinstance(dataset_dict, dict)
dataset = init_dataset(dataset_dict)
dataset.init_seq_order(epoch=1)

with tempfile.NamedTemporaryFile("w") as tmp_file:
if self.output_format == "py":
tmp_file.write("{\n")

seq_idx = 0
while dataset.is_less_than_num_seqs(seq_idx):
dataset.load_seqs(seq_idx, seq_idx + 1)
seq_tag = dataset.get_tag(seq_idx)
seq_len = dataset.get_seq_length(seq_idx)
assert self.key in seq_len.keys()
seq_len_ = seq_len[self.key]
if self.output_format == "py":
tmp_file.write(f"{seq_tag!r}: {seq_len_},\n")
elif self.output_format == "txt":
tmp_file.write(f"{seq_len_}\n")
else:
raise ValueError(f"{self}: invalid output_format {self.output_format!r}")
seq_idx += 1

if self.output_format == "py":
tmp_file.write("}\n")
tmp_file.flush()

shutil.copyfile(tmp_file.name, self.out_file.get_path())
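
For context, a minimal usage sketch of the new job in a Sisyphus recipe. The dataset dict, the HDF path, and the import path are assumptions for illustration and are not part of this commit:

# Hypothetical usage sketch; dataset dict, file path and import path are assumptions.
from sisyphus import tk
from i6_core.returnn.dataset import ExtractSeqLensJob  # assumed module path

# Any dict accepted by returnn.datasets.init_dataset can be used here.
dataset = {
    "class": "HDFDataset",
    "files": ["/path/to/train.hdf"],  # placeholder path
}

job = ExtractSeqLensJob(
    dataset=dataset,
    key="classes",        # the data key whose lengths should be extracted
    output_format="py",   # writes a Python dict seq_tag -> seq_len
)
tk.register_output("datasets/train_seq_lens.py", job.out_file)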

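The "py" output can then be read back as a Python dict literal. A small sketch of a consumer, assuming the sequence lengths were written as plain integer literals and using a placeholder path:

# Hypothetical consumer of the "py" output; path is a placeholder.
import ast

with open("output/datasets/train_seq_lens.py", "rt") as f:
    seq_lens = ast.literal_eval(f.read())  # dict: seq_tag -> seq_len

total = sum(seq_lens.values())
print(f"{len(seq_lens)} sequences, {total} frames in total")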