UPDATE save parameters
matbun committed Dec 12, 2023
1 parent c0cfd1e commit 390a911
Showing 17 changed files with 96 additions and 37 deletions.
62 changes: 55 additions & 7 deletions src/itwinai/components.py
@@ -134,7 +134,7 @@ class BaseComponent(ABC, Serializable):
name (Optional[str], optional): unique identifier for a step.
Defaults to None.
"""
_name: str = 'unnamed'
_name: str = None
parameters: Dict[Any, Any] = None

def __init__(
@@ -144,6 +144,7 @@ def __init__(
# debug: bool = False,
) -> None:
self.save_parameters(name=name)
self.name = name

@property
def name(self) -> str:
@@ -329,6 +330,8 @@ class Adapter(BaseComponent):
def __init__(self, policy: List[Any], name: Optional[str] = None) -> None:
super().__init__(name=name)
self.save_parameters(policy=policy, name=name)
self.name = name
self.policy = policy

@monitor_exec
def execute(self, *args) -> Tuple:
@@ -370,15 +373,15 @@ def execute(self, *args) -> Tuple:

class DataSplitter(BaseComponent):
"""Splits a dataset into train, validation, and test splits."""
train_proportion: float
validation_proportion: float
test_proportion: float
_train_proportion: Union[int, float]
_validation_proportion: Union[int, float]
_test_proportion: Union[int, float]

def __init__(
self,
train_proportion: float,
validation_proportion: float,
test_proportion: float,
train_proportion: Union[int, float],
validation_proportion: Union[int, float],
test_proportion: Union[int, float],
name: Optional[str] = None
) -> None:
super().__init__(name)
@@ -388,6 +391,51 @@ def __init__(
test_proportion=test_proportion,
name=name
)
self.train_proportion = train_proportion
self.validation_proportion = validation_proportion
self.test_proportion = test_proportion

@property
def train_proportion(self) -> Union[int, float]:
"""Training set proportion."""
return self._train_proportion

@train_proportion.setter
def train_proportion(self, prop: Union[int, float]) -> None:
if isinstance(prop, float) and not 0.0 <= prop <= 1.0:
raise ValueError(
"Train proportion should be in the interval [0.0, 1.0] "
f"if given as float. Received {prop}"
)
self._train_proportion = prop

@property
def validation_proportion(self) -> Union[int, float]:
"""Validation set proportion."""
return self._validation_proportion

@validation_proportion.setter
def validation_proportion(self, prop: Union[int, float]) -> None:
if isinstance(prop, float) and not 0.0 <= prop <= 1.0:
raise ValueError(
"Validation proportion should be in the interval [0.0, 1.0] "
f"if given as float. Received {prop}"
)
self._validation_proportion = prop

@property
def test_proportion(self) -> Union[int, float]:
"""Test set proportion."""
return self._test_proportion

@test_proportion.setter
def test_proportion(self, prop: Union[int, float]) -> None:
if isinstance(prop, float) and not 0.0 <= prop <= 1.0:
raise ValueError(
"Test proportion should be in the interval [0.0, 1.0] "
f"if given as float. Received {prop}"
)
self._test_proportion = prop

@abstractmethod
@monitor_exec
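The proportion setters added above accept either an absolute integer count or a float fraction, and reject floats outside [0.0, 1.0]. A minimal, standalone sketch of that validation pattern (the _ProportionHolder class and the example values are illustrative, not part of itwinai):

from typing import Union


class _ProportionHolder:
    """Distilled sketch of the DataSplitter proportion validation."""

    def __init__(self, train_proportion: Union[int, float]) -> None:
        self.train_proportion = train_proportion  # routed through the setter

    @property
    def train_proportion(self) -> Union[int, float]:
        return self._train_proportion

    @train_proportion.setter
    def train_proportion(self, prop: Union[int, float]) -> None:
        # Floats are fractions and must lie in [0.0, 1.0]; ints are treated
        # as absolute sample counts and are not range-checked.
        if isinstance(prop, float) and not 0.0 <= prop <= 1.0:
            raise ValueError(
                "Train proportion should be in the interval [0.0, 1.0] "
                f"if given as float. Received {prop}"
            )
        self._train_proportion = prop


_ProportionHolder(0.8)  # valid fraction
_ProportionHolder(500)  # valid absolute count
try:
    _ProportionHolder(1.5)  # out-of-range float
except ValueError as err:
    print(err)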
17 changes: 17 additions & 0 deletions src/itwinai/serialization.py
@@ -44,6 +44,23 @@ def save_parameters(self, **kwargs) -> None:
# for k, v in kwargs.items():
# self.__setattr__(k, v)

@staticmethod
def locals2params(locals: Dict, pop_self: bool = True) -> Dict:
"""Remove ``self`` from the output of ``locals()``.
Args:
locals (Dict): output of ``locals()`` called in the constructor
of a class.
pop_self (bool, optional): whether to remove ``self``.
Defaults to True.
Returns:
Dict: cleaned ``locals()``.
"""
if pop_self:
locals.pop('self', None)
return locals

def update_parameters(self, **kwargs) -> None:
"""Updates stored parameters."""
self.save_parameters(**kwargs)
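locals2params pairs with save_parameters so a constructor can record all of its arguments in one call instead of listing them twice, which is the pattern applied to the constructors in the files below. A minimal, self-contained sketch of that usage (TinySerializable and TinyComponent are illustrative stand-ins, not itwinai classes):

from typing import Dict, Optional


class TinySerializable:
    """Illustrative stand-in for the Serializable mixin."""

    parameters: Dict = None

    def save_parameters(self, **kwargs) -> None:
        # Record constructor arguments for later serialization.
        self.parameters = dict(kwargs)

    @staticmethod
    def locals2params(locals: Dict, pop_self: bool = True) -> Dict:
        # Same behaviour as the method added above: drop ``self``.
        if pop_self:
            locals.pop('self', None)
        return locals


class TinyComponent(TinySerializable):
    def __init__(self, lr: float, name: Optional[str] = None) -> None:
        # One call captures every constructor argument by name.
        self.save_parameters(**self.locals2params(locals()))
        self.lr = lr
        self.name = name


comp = TinyComponent(lr=1e-3, name="demo")
print(comp.parameters)  # {'lr': 0.001, 'name': 'demo'}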
6 changes: 1 addition & 5 deletions src/itwinai/tensorflow/trainer.py
@@ -38,11 +38,7 @@ def __init__(
strategy
):
super().__init__()
self.save_parameters(
strategy=strategy, epochs=epochs, batch_size=batch_size,
callbacks=callbacks, model_dict=model_dict,
compile_conf=compile_conf, strategy=strategy
)
self.save_parameters(**self.locals2params(locals()))
self.strategy = strategy
self.epochs = epochs
self.batch_size = batch_size
6 changes: 6 additions & 0 deletions src/itwinai/tests/dummy_components.py
@@ -7,6 +7,7 @@ def __init__(self, data_uri: str, name: Optional[str] = None
) -> None:
super().__init__(name)
self.save_parameters(data_uri=data_uri, name=name)
self.data_uri = data_uri

def execute(self):
...
@@ -25,6 +26,7 @@ def __init__(self, train_prop: float, name: Optional[str] = None
) -> None:
super().__init__(name)
self.save_parameters(train_prop=train_prop, name=name)
self.train_prop = train_prop

def execute(self):
...
@@ -43,6 +45,7 @@ def __init__(self, max_items: int, name: Optional[str] = None
) -> None:
super().__init__(name)
self.save_parameters(max_items=max_items, name=name)
self.max_items = max_items

def execute(self):
...
@@ -59,6 +62,8 @@ def __init__(self, lr: float, batch_size: int, name: Optional[str] = None
) -> None:
super().__init__(name)
self.save_parameters(lr=lr, batch_size=batch_size, name=name)
self.lr = lr
self.batch_size = batch_size

def execute(self):
...
@@ -76,6 +81,7 @@ class FakeSaver(BaseComponent):
def __init__(self, save_path: str, name: Optional[str] = None) -> None:
super().__init__(name)
self.save_parameters(save_path=save_path, name=name)
self.save_path = save_path

def execute(self):
...
2 changes: 1 addition & 1 deletion src/itwinai/torch/inference.py
@@ -93,7 +93,7 @@ def __init__(
name: str = None
) -> None:
super().__init__(model=model, name=name)
self.save_parameters(**locals())
self.save_parameters(**self.locals2params(locals()))
self.model = self.model.eval()
# self.seed = seed
# self.strategy = strategy
2 changes: 1 addition & 1 deletion src/itwinai/torch/trainer.py
@@ -205,7 +205,7 @@ def __init__(
Makes the model a DDP model.
"""
super().__init__()
self.save_parameters(**locals())
self.save_parameters(**self.locals2params(locals()))
self.model = model
self.loss = loss
self.epochs = epochs
2 changes: 1 addition & 1 deletion use-cases/3dgan/dataloader.py
@@ -20,7 +20,7 @@ def __init__(
data_url: Optional[str] = None,
name: Optional[str] = None,
) -> None:
self.save_parameters(**locals())
self.save_parameters(**self.locals2params(locals()))
super().__init__(name)
self.data_path = data_path
self.data_url = data_url
2 changes: 1 addition & 1 deletion use-cases/3dgan/saver.py
@@ -17,7 +17,7 @@ def __init__(
self,
save_dir: str = '3dgan-generated'
) -> None:
self.save_parameters(**locals())
self.save_parameters(**self.locals2params(locals()))
super().__init__()
self.save_dir = save_dir

4 changes: 2 additions & 2 deletions use-cases/3dgan/trainer.py
@@ -19,7 +19,7 @@

class Lightning3DGANTrainer(Trainer):
def __init__(self, config: Union[Dict, str]):
self.save_parameters(**locals())
self.save_parameters(**self.locals2params(locals()))
super().__init__()
if isinstance(config, str) and os.path.isfile(config):
# Load from YAML
@@ -88,7 +88,7 @@ def __init__(
config: Union[Dict, str],
name: Optional[str] = None
):
self.save_parameters(**locals())
self.save_parameters(**self.locals2params(locals()))
super().__init__(model, name)
if isinstance(config, str) and os.path.isfile(config):
# Load from YAML
1 change: 1 addition & 0 deletions use-cases/cyclones/dataloader.py
@@ -43,6 +43,7 @@ def __init__(
data_path: str = "tmp_data"
):
super().__init__()
self.save_parameters(**self.locals2params(locals()))
self.batch_size = batch_size
self.split_ratio = split_ratio
self.epochs = epochs
13 changes: 2 additions & 11 deletions use-cases/cyclones/trainer.py
@@ -29,6 +29,7 @@ def __init__(
cores: int = None,
):
super().__init__()
self.save_parameters(**self.locals2params(locals()))
# Configurable
self.cores = cores
self.model_backup = model_backup
@@ -43,7 +44,7 @@ def __init__(
# Optimizers, Losses
self.optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

def train(self, train_data, validation_data):
def execute(self, train_data, validation_data):
train_dataset, n_train = train_data
valid_dataset, n_valid = validation_data

@@ -103,16 +104,6 @@ def train(self, train_data, validation_data):
model.save(self.last_model_name)
logging.debug("Saved training history")

def execute(
self,
train_dataset,
validation_dataset,
config: Optional[Dict] = None,
) -> Tuple[Optional[Tuple], Optional[Dict]]:
config = self.setup_config(config)
train_result = self.train(train_dataset, validation_dataset)
return (train_result,), config

def setup_config(self, config: Optional[Dict] = None) -> Dict:
config = config if config is not None else {}
self.experiment_dir = config["experiment_dir"]
4 changes: 2 additions & 2 deletions use-cases/mnist/tensorflow/dataloader.py
@@ -8,7 +8,7 @@
class MNISTDataGetter(DataGetter):
def __init__(self):
super().__init__()
self.save_parameters(**locals())
self.save_parameters(**self.locals2params(locals()))

@monitor_exec
def execute(self) -> Tuple:
@@ -19,7 +19,7 @@ def execute(self) -> Tuple:
class MNISTDataPreproc(DataPreproc):
def __init__(self, classes: int):
super().__init__()
self.save_parameters(**locals())
self.save_parameters(**self.locals2params(locals()))
self.classes = classes

@monitor_exec
2 changes: 1 addition & 1 deletion use-cases/mnist/tensorflow/trainer.py
@@ -28,7 +28,7 @@ def __init__(
compile_conf=dict(loss=loss, optimizer=optimizer),
strategy=strategy
)
self.save_parameters(**locals())
self.save_parameters(**self.locals2params(locals()))
print(f'STRATEGY: {strategy}')
self.logger = logger if logger is not None else []

2 changes: 1 addition & 1 deletion use-cases/mnist/torch-lightning/dataloader.py
@@ -15,7 +15,7 @@ def __init__(
name: Optional[str] = None
) -> None:
super().__init__(name)
self.save_parameters(**locals())
self.save_parameters(**self.locals2params(locals()))
self.data_path = data_path
self._downloader = MNISTDataModule(
data_path=self.data_path, download=True,
2 changes: 1 addition & 1 deletion use-cases/mnist/torch-lightning/trainer.py
@@ -11,7 +11,7 @@
class LightningMNISTTrainer(Trainer):
def __init__(self, config: Union[Dict, str]):
super().__init__()
self.save_parameters(**locals())
self.save_parameters(**self.locals2params(locals()))
if isinstance(config, str) and os.path.isfile(config):
# Load from YAML
config = load_yaml(config)
4 changes: 2 additions & 2 deletions use-cases/mnist/torch/dataloader.py
@@ -16,7 +16,7 @@ class MNISTDataModuleTorch(DataGetter):

def __init__(self, save_path: str = '.tmp/',) -> None:
super().__init__()
self.save_parameters(**locals())
self.save_parameters(**self.locals2params(locals()))
self.save_path = save_path

@monitor_exec
@@ -107,7 +107,7 @@ def generate_jpg_sample(
class MNISTPredictLoader(DataGetter):
def __init__(self, test_data_path: str) -> None:
super().__init__()
self.save_parameters(**locals())
self.save_parameters(**self.locals2params(locals()))
self.test_data_path = test_data_path

@monitor_exec
2 changes: 1 addition & 1 deletion use-cases/mnist/torch/saver.py
@@ -20,7 +20,7 @@ def __init__(
class_labels: Optional[List] = None
) -> None:
super().__init__()
self.save_parameters(**locals())
self.save_parameters(**self.locals2params(locals()))
self.save_dir = save_dir
self.predictions_file = predictions_file
self.class_labels = (
