diff --git a/forest/drivers/unified_model.py b/forest/drivers/unified_model.py index 308ff168..5ee64fb8 100644 --- a/forest/drivers/unified_model.py +++ b/forest/drivers/unified_model.py @@ -90,6 +90,27 @@ def full_path(self, name): return os.path.join(self.directory, name) +class Facade: + """Translation layer between dataset labels and file system patterns + """ + + def __init__(self, pattern, database): + self.pattern = pattern + self.database = database + + def variables(self, _label): + return self.database.variables(self.pattern) + + def initial_times(self, _label, variable): + return self.database.initial_times(self.pattern, variable) + + def valid_times(self, _label, variable, initial_time): + return self.database.valid_times(self.pattern, variable, initial_time) + + def pressures(self, _label, variable, initial_time): + return self.database.pressures(self.pattern, variable, initial_time) + + class Dataset: def __init__( self, @@ -114,7 +135,7 @@ def __init__( def navigator(self): if self.use_database: - return self.database + return Facade(self.pattern, self.database) else: return Navigator(self.pattern) diff --git a/forest/tutorial/core.py b/forest/tutorial/core.py index 94816ac4..9ae48bfc 100644 --- a/forest/tutorial/core.py +++ b/forest/tutorial/core.py @@ -111,9 +111,9 @@ def build_um_config(build_dir): pattern: "*{}" directory: {} locator: database - database_path: database.db + database_path: {} """.format( - UM_FILE, build_dir + UM_FILE, build_dir, db_path(build_dir) ) print("writing: {}".format(path)) with open(path, "w") as stream: @@ -181,13 +181,15 @@ def build_um(build_dir): var[1] = Z_1.T +def db_path(build_dir): + return os.path.join(build_dir, DB_FILE) + def build_database(build_dir): - db_path = os.path.join(build_dir, DB_FILE) um_path = os.path.join(build_dir, UM_FILE) if not os.path.exists(um_path): build_um(build_dir) - print("building: {}".format(db_path)) - database = forest.db.database.Database.connect(db_path) + print("building: {}".format(db_path(build_dir))) + database = forest.db.database.Database.connect(db_path(build_dir)) database.insert_netcdf(um_path) database.close() diff --git a/test/test_drivers_unified_model.py b/test/test_drivers_unified_model.py index a3a3b756..079dcebc 100644 --- a/test/test_drivers_unified_model.py +++ b/test/test_drivers_unified_model.py @@ -49,7 +49,10 @@ def test_navigator_use_database(tmpdir): } dataset = forest.drivers.get_dataset("unified_model", settings) navigator = dataset.navigator() - assert isinstance(navigator, forest.db.database.Database) + assert hasattr(navigator, "variables") + assert hasattr(navigator, "initial_times") + assert hasattr(navigator, "valid_times") + assert hasattr(navigator, "pressures") def test_loader_use_database(tmpdir): diff --git a/test/test_tutorial.py b/test/test_tutorial.py index 3cc37652..6a04e9c0 100644 --- a/test/test_tutorial.py +++ b/test/test_tutorial.py @@ -101,7 +101,7 @@ def test_build_all_builds_um_config_file(build_dir): "pattern": "*" + forest.tutorial.core.UM_FILE, "directory": build_dir, "locator": "database", - "database_path": "database.db", + "database_path": f"{build_dir}/database.db", }, }, }