Skip to content

Commit

Permalink
Made tests work
Browse files Browse the repository at this point in the history
  • Loading branch information
annaelisalappe committed Jan 10, 2025
1 parent ef77442 commit 6723936
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 206 deletions.
97 changes: 36 additions & 61 deletions src/itwinai/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from jsonargparse import ActionConfigFile
from jsonargparse import ArgumentParser as JAPArgumentParser
from jsonargparse._formatters import DefaultHelpFormatter
from omegaconf import OmegaConf, dictconfig
from omegaconf import OmegaConf, dictconfig, errors, listconfig

from .pipeline import Pipeline

Expand Down Expand Up @@ -81,29 +81,23 @@ def __init__(

def build_from_config(
self,
type: Literal["pipeline", "step"] = "pipeline",
pipeline_nested_key: str = "pipeline",
override_keys: Dict | None = None,
steps: str | None = None,
step_idx: Union[str, int] | None = None,
override_keys: Dict = {},
steps: List[str] | List[int] | None = None,
verbose: bool = False,
) -> Pipeline:
"""Parses the pipeline and instantiated all classes defined within it.
Args:
type (Literal["pipeline", "step"]): The type of object to build. If set to "step",
only the object(s) defined by the step given by 'step_idx' is built.
Defaults to "pipeline".
pipeline_nested_key (str, optional): Nested key in the configuration file
identifying the pipeline object. Defaults to "pipeline".
override_keys ([Dict[str, Any]], optional): A dict mapping
nested keys to the value to override. Defaults to None.
steps (str, optional): If building a pipeline, allows you to select which step(s)
to include in the pipeline. Accepted values are indices,
python slices (e.g., 0:3 or 2:10:100), and string names of steps.
override_keys ([Dict[str, Any]]): A dict mapping
nested keys to the value to override. Defaults to {}.
steps ((list(str) | list(int)), optional): If building a pipeline, allows you to
select which step(s) to include in the pipeline. Accepted values are lists
of indices for config files where the steps are defined as lists, or lists of
names of steps, if they are defined in the configuration as a dict.
Defaults to None.
step_idx (Union[str, int], optional): If building only a step, used to identify
which one. Must be set if 'type' is set to 'step'. Defaults to None.
verbose (bool): if True, prints the assembled pipeline
to console formatted as JSON.
Expand All @@ -113,11 +107,8 @@ def build_from_config(
conf = self.parse_pipeline(pipeline_nested_key, override_keys)

# Select steps
if type == "pipeline":
if steps:
conf.steps = self._get_selected_steps(steps, conf.steps)
else:
conf = conf.steps[step_idx]
if steps:
conf = self._select_steps(conf, steps)

# Resolve interpolated parameters
OmegaConf.resolve(conf)
Expand All @@ -131,15 +122,15 @@ def build_from_config(
def parse_pipeline(
self,
pipeline_nested_key: str = "pipeline",
override_keys: Dict | None = None,
override_keys: Dict = {},
) -> dictconfig.DictConfig:
"""Parses the pipeline from a yaml file into an OmegaConf DictConfig.
Args:
pipeline_nested_key (str, optional): Nested key in the configuration file
identifying the pipeline object. Defaults to "pipeline".
override_keys (Dict | None, optional): A dict mapping
nested keys to the value to override. Defaults to None.
override_keys (Dict): A dict mapping
nested keys to the value to override. Defaults to {}.
Raises:
e: Failed to load config from yaml. Most likely due to a badly structured
Expand All @@ -161,9 +152,6 @@ def parse_pipeline(
)
raise e

if pipeline_nested_key not in raw_conf:
raise ValueError(f"Pipeline key {pipeline_nested_key} not found.")

# Override keys
for override_key, override_value in override_keys.items():
inferred_type = ast.literal_eval(override_value)
Expand All @@ -173,49 +161,36 @@ def parse_pipeline(
f"Successfully overrode key {override_key}."
f"It now has the value {inferred_type} of type {type(inferred_type)}."
)

conf = raw_conf[pipeline_nested_key]
try:
conf = OmegaConf.select(raw_conf, pipeline_nested_key, throw_on_missing=True)
except Exception as e:
e.add_note(f"Could not find pipeline key {pipeline_nested_key} in config.")
raise e

return conf

def _get_selected_steps(self, steps: str, conf_steps: list):
"""Selects the steps of the pipeline to be executed.
def _select_steps(
self, conf: listconfig.Listconfig | dictconfig.DictConfig, steps: List[int] | List[str]
):
"""Selects the steps given from the configuration object.
If only one step is selected, returns a configuration with only that step. Otherwise
returns a pipeline with all the selected steps as a list.
Args:
steps (str): Selects the steps of the pipeline. Accepted values are indices,
python slices (e.g., 0:3 or 2:10:100), and string names of steps.
conf_steps (list): A list of all the steps in the pipeline configuration.
Raises:
ValueError: Invalid slice notation
IndexError: Index out of range
ValueError: Invalid step name given
conf (listconfig.Listconfig | dictconfig.DictConfig):
The configuration of the pipeline
steps (List[int] | List[str]): The list of steps
Returns:
list: The steps selected from the pipeline.
listconfig.Listconfig | dictconfig.DictConfig: The updated configuration
"""
# If steps is given as a slice
if ":" in steps:
try:
slice_obj = slice(*[int(x) if x else None for x in steps.split(":")])
return conf_steps[slice_obj]
except ValueError:
raise ValueError(f"Invalid slice notation: {steps}")

# If steps is given as a single index
elif steps.isdigit():
index = int(steps)
if 0 <= index < len(conf_steps):
return [conf_steps[index]]
else:
raise IndexError(f"Step index out of range: {index}")

# If steps is given as a name
else:
selected_steps = [step for step in conf_steps if step.get("_target_") == steps]
if not selected_steps:
raise ValueError(f"No steps found with name: {steps}")
return selected_steps
if len(steps) == 1:
return OmegaConf.create(conf.steps[steps[0]])

selected_steps = [conf.steps[step] for step in steps]
OmegaConf.update(conf, conf.steps, selected_steps)

return conf


class ArgumentParser(JAPArgumentParser):
Expand Down
102 changes: 53 additions & 49 deletions tests/components/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,71 +11,75 @@

pytest.PIPE_LIST_YAML = """
my-list-pipeline:
class_path: itwinai.pipeline.Pipeline
init_args:
steps:
- class_path: itwinai.tests.dummy_components.FakePreproc
init_args:
max_items: 32
name: my-preproc
_target_: itwinai.pipeline.Pipeline
steps:
- _target_: itwinai.tests.dummy_components.FakePreproc
max_items: 32
name: my-preproc
- class_path: itwinai.tests.dummy_components.FakeTrainer
init_args:
lr: 0.001
batch_size: 32
name: my-trainer
- _target_: itwinai.tests.dummy_components.FakeTrainer
lr: 0.001
batch_size: 32
name: my-trainer
- class_path: itwinai.tests.dummy_components.FakeSaver
init_args:
save_path: ./some/path
name: my-saver
- _target_: itwinai.tests.dummy_components.FakeSaver
save_path: ./some/path
name: my-saver
"""

pytest.PIPE_DICT_YAML = """
my-dict-pipeline:
class_path: itwinai.pipeline.Pipeline
init_args:
steps:
preproc-step:
class_path: itwinai.tests.dummy_components.FakePreproc
init_args:
_target_: itwinai.pipeline.Pipeline
steps:
preproc-step:
_target_: itwinai.tests.dummy_components.FakePreproc
max_items: 33
name: my-preproc
train-step:
_target_: itwinai.tests.dummy_components.FakeTrainer
lr: 0.001
batch_size: 32
name: my-trainer
save-step:
_target_: itwinai.tests.dummy_components.FakeSaver
save_path: ./some/path
name: my-saver
"""

pytest.NESTED_PIPELINE = """
some:
field:
my-nested-pipeline:
_target_: itwinai.pipeline.Pipeline
steps:
- _target_: itwinai.tests.FakePreproc
max_items: 32
name: my-preproc
train-step:
class_path: itwinai.tests.dummy_components.FakeTrainer
init_args:
- _target_: itwinai.tests.FakeTrainer
lr: 0.001
batch_size: 32
name: my-trainer
save-step:
class_path: itwinai.tests.dummy_components.FakeSaver
init_args:
- _target_: itwinai.tests.FakeSaver
save_path: ./some/path
name: my-saver
"""

pytest.NESTED_PIPELINE = """
some:
field:
nst-pipeline:
class_path: itwinai.pipeline.Pipeline
init_args:
steps:
- class_path: itwinai.tests.FakePreproc
init_args:
max_items: 32
name: my-preproc
- class_path: itwinai.tests.FakeTrainer
init_args:
lr: 0.001
batch_size: 32
name: my-trainer
pytest.INTERPOLATED_VALUES_PIPELINE = """
max_items: 33
name: my-trainer
my-interpolation-pipeline:
_target_: itwinai.pipeline.Pipeline
steps:
- _target_: itwinai.tests.dummy_components.FakePreproc
max_items: ${max_items}
name: my-preproc
- class_path: itwinai.tests.FakeSaver
init_args:
save_path: ./some/path
name: my-saver
- _target_: itwinai.tests.dummy_components.FakeTrainer
lr: 0.001
batch_size: 32
name: ${name}
"""
Loading

0 comments on commit 6723936

Please sign in to comment.