Skip to content

Commit

Permalink
DumpHDFJob: write dump config into explicit file (#375)
Browse files Browse the repository at this point in the history
* write dump config into explicit file

* WIP

* dump + test

* older black...

* https://www.pinterest.com/pin/511158626426177729/

* fix python

* move data

---------

Co-authored-by: Benedikt Hilmes <[email protected]>
  • Loading branch information
JackTemaki and Atticus1806 authored Oct 23, 2023
1 parent 02bd426 commit 453aeb8
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 6 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/job_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ jobs:
with:
repository: "rwth-i6/sisyphus"
path: "sisyphus"
- uses: actions/checkout@v2
with:
repository: "rwth-i6/returnn"
path: "returnn"
- uses: actions/setup-python@v2
with:
python-version: 3.8
Expand Down
23 changes: 17 additions & 6 deletions returnn/hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,26 @@ def __init__(
self.out_hdf = self.output_path("data.hdf")

def tasks(self):
if isinstance(self.data, (dict, str)):
yield Task("write_config", mini_task=True)
yield Task("run", resume="run", rqmt=self.rqmt)

def run(self):
def write_config(self):
"""
Optionally writes a config if self.data is either of type str or a dict, i.e.g not a tk.Path
"""
data = self.data
if isinstance(data, dict):
instanciate_delayed(data)
data = str(data)
elif isinstance(data, tk.Path):
data = data.get_path()
instanciate_delayed(data)
data = str(data)
with open("dataset.config", "wt") as dataset_file:
dataset_file.write("#!rnn.py\n")
dataset_file.write("train = %s\n" % str(data))

def run(self):
if isinstance(self.data, tk.Path):
data = self.data.get_path()
else:
data = "dataset.config"

(fd, tmp_hdf_file) = tempfile.mkstemp(prefix=gs.TMP_PREFIX, suffix=".hdf")
os.close(fd)
Expand Down
47 changes: 47 additions & 0 deletions tests/job_tests/returnn/test_dump.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import os
import tempfile

from sisyphus import tk, setup_path

from i6_core.returnn.hdf import ReturnnDumpHDFJob

rel_path = setup_path(__package__)


def test_hdf_dump():
with tempfile.TemporaryDirectory() as tmpdir:
from sisyphus import gs

gs.WORK_DIR = tmpdir

# Case 1: tk.Path
with open(f"{tmpdir}/tmp.config", "wt") as f:
f.write("#!rnn.py\n")
f.write("train = {'class': 'DummyDataset', 'input_dim': 3, 'output_dim': 4, 'num_seqs': 2}")
data = rel_path(f"{tmpdir}/tmp.config")
job = ReturnnDumpHDFJob(data=data, returnn_root=tk.Path("returnn/"))
assert [task.name() for task in job.tasks()] == ["run"]
job._sis_setup_directory()
job.run()

# Case 2: dict
data2 = {"class": "DummyDataset", "input_dim": 3, "output_dim": 4, "num_seqs": 2}
job2 = ReturnnDumpHDFJob(data=data2, returnn_root=tk.Path("returnn/"))
assert [task.name() for task in job2.tasks()] == ["write_config", "run"]
job2._sis_setup_directory()
job2.write_config()
job2.run()

# Case 3: str
data3 = "{'class': 'DummyDataset', 'input_dim': 3, 'output_dim': 4, 'num_seqs': 2}"
job3 = ReturnnDumpHDFJob(data=data3, returnn_root=tk.Path("returnn/"))
assert [task.name() for task in job3.tasks()] == ["write_config", "run"]
job3._sis_setup_directory()
job3.write_config()
job3.run()

assert os.path.getsize(job.out_hdf) == os.path.getsize(job2.out_hdf) == os.path.getsize(job3.out_hdf), (
os.path.getsize(job.out_hdf),
os.path.getsize(job2.out_hdf),
os.path.getsize(job3.out_hdf),
)

0 comments on commit 453aeb8

Please sign in to comment.