Skip to content

Commit

Permalink
ADD: draft inference wf
Browse files Browse the repository at this point in the history
  • Loading branch information
matbun committed Nov 7, 2023
1 parent 0a0f56e commit 95661c1
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 8 deletions.
21 changes: 21 additions & 0 deletions use-cases/3dgan/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,24 @@ micromamba run -p ../../.venv-pytorch mlflow ui --backend-store-uri ml_logs/mlfl
```

And select the "3DGAN" experiment.

## Inference

The following is preliminary and not 100% ML/scientifically sound.

1. As inference dataset we can reuse training/validation dataset
2. As model, we can create a dummy version of it with:

```python
import torch
from model import ThreeDGAN
# Same params as in the training config file!
my_gan = ThreeDGAN()
torch.save(my_gan, '3dgan-inference.pth')
```

3. Run inference with the following command:

```bash
TODO
```
2 changes: 1 addition & 1 deletion use-cases/3dgan/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def GetDataAngleParallel(
return final_dataset


class MyDataModule(pl.LightningDataModule):
class ParticlesDataModule(pl.LightningDataModule):
def __init__(self, batch_size: int, datapath):
super().__init__()
self.batch_size = batch_size
Expand Down
12 changes: 10 additions & 2 deletions use-cases/3dgan/inference-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@ executor:
data_path: exp_data/
data_url: https://drive.google.com/drive/folders/1uPpz0tquokepptIfJenTzGpiENfo2xRX

- class_path: trainer.Lightning3DGANTrainer
- class_path: trainer.Lightning3DGANPredictor
init_args:
model:
class_path: trainer.LightningModelLoader
init_args:
model_uri: 3dgan-inference.pth
# Pytorch lightning config for training
config:
seed_everything: 4231162351
Expand Down Expand Up @@ -87,7 +91,11 @@ executor:

# Lightning data module configuration
data:
class_path: dataloader.MyDataModule
class_path: dataloader.ParticlesDataModule
init_args:
datapath: exp_data/*/*.h5
batch_size: 64

- class_path: saver.ParticleImagesSaver
init_args:
save_dir: 3dgan-generated
2 changes: 1 addition & 1 deletion use-cases/3dgan/pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ executor:

# Lightning data module configuration
data:
class_path: dataloader.MyDataModule
class_path: dataloader.ParticlesDataModule
init_args:
datapath: exp_data/*/*.h5
batch_size: 64
55 changes: 55 additions & 0 deletions use-cases/3dgan/saver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Dict, Tuple, Optional
import os
import shutil

import torch
from torch import Tensor

from itwinai.components import Saver


class ParticleImagesSaver(Saver):
"""Saves generated particle trajectories to disk."""

def __init__(
self,
save_dir: str = '3dgan-generated',
) -> None:
super().__init__()
self.save_dir = save_dir

def execute(
self,
generated_images: Dict[str, Tensor],
config: Optional[Dict] = None
) -> Tuple[Optional[Tuple], Optional[Dict]]:
"""Saves generated images to disk.
Args:
generated_images (Dict[str, Tensor]): maps unique item ID to
the generated image.
config (Optional[Dict], optional): inherited configuration.
Defaults to None.
Returns:
Tuple[Optional[Tuple], Optional[Dict]]: propagation of inherited
configuration and saver return value.
"""
result = self.save(generated_images)
return ((result,), config)

def save(self, generated_images: Dict[str, Tensor]) -> None:
"""Saves generated images to disk.
Args:
generated_images (Dict[str, Tensor]): maps unique item ID to
the generated image.
"""
if os.path.exists(self.save_dir):
shutil.rmtree(self.save_dir)
os.makedirs(self.save_dir)

# TODO: save as 3D plot image
for img_id, img in generated_images.items():
img_path = os.path.join(self.save_dir, img_id + '.pth')
torch.save(img, img_path)
109 changes: 105 additions & 4 deletions use-cases/3dgan/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,17 @@
import sys
from typing import Union, Dict, Tuple, Optional, Any

from itwinai.components import Trainer
from model import ThreeDGAN
from dataloader import MyDataModule
import torch
from torch import Tensor
import lightning as pl
from lightning.pytorch.cli import LightningCLI

from itwinai.components import Trainer, Predictor
from itwinai.serialization import ModelLoader
from itwinai.torch.inference import TorchModelLoader

from model import ThreeDGAN
from dataloader import ParticlesDataModule
from utils import load_yaml


Expand All @@ -23,7 +30,7 @@ def train(self) -> Any:
cli = LightningCLI(
args=self.conf,
model_class=ThreeDGAN,
datamodule_class=MyDataModule,
datamodule_class=ParticlesDataModule,
run=False,
save_config_kwargs={
"overwrite": True,
Expand All @@ -47,3 +54,97 @@ def save_state(self):

def load_state(self):
return super().load_state()


class LightningModelLoader(TorchModelLoader):
"""Loads a torch lightning model from somewhere.
Args:
model_uri (str): Can be a path on local filesystem
or an mlflow 'locator' in the form:
'mlflow+MLFLOW_TRACKING_URI+RUN_ID+ARTIFACT_PATH'
"""

def __call__(self) -> pl.LightningModule:
""""Loads model from model URI.
Raises:
ValueError: if the model URI is not recognized
or the model is not found.
Returns:
pl.LightningModule: torch lightning module.
"""
# TODO: improve
# # Load best model
# loaded_model = cli.model.load_from_checkpoint(
# ckpt_path,
# lightning_conf['model']['init_args']
# )
return super().__call__()


class Lightning3DGANPredictor(Predictor):

def __init__(
self,
model: Union[ModelLoader, pl.LightningModule],
config: Union[Dict, str],
name: Optional[str] = None
):
super().__init__(model, name)
if isinstance(config, str) and os.path.isfile(config):
# Load from YAML
config = load_yaml(config)
self.conf = config

def predict(
self,
datamodule: Optional[pl.LightningDataModule] = None,
model: Optional[pl.LightningModule] = None
) -> Dict[str, Tensor]:
old_argv = sys.argv
sys.argv = ['some_script_placeholder.py']
cli = LightningCLI(
args=self.conf,
model_class=ThreeDGAN,
datamodule_class=ParticlesDataModule,
run=False,
save_config_kwargs={
"overwrite": True,
"config_filename": "pl-training.yml",
},
subclass_mode_model=True,
subclass_mode_data=True,
)
sys.argv = old_argv

# Override config file with inline arguments, if given
if datamodule is None:
datamodule = cli.datamodule
if model is None:
model = cli.model

predictions = cli.trainer.predict(model, datamodule=datamodule)

predictions_dict = dict()
# TODO: postprocess predictions
for idx, generated_img in enumerate(torch.cat(predictions)):
predictions_dict[str(idx)] = generated_img
return predictions_dict

def execute(
self,
config: Optional[Dict] = None,
) -> Tuple[Optional[Tuple], Optional[Dict]]:
""""Execute some operations.
Args:
config (Dict, optional): key-value configuration.
Defaults to None.
Returns:
Tuple[Optional[Tuple], Optional[Dict]]: tuple structured as
(results, config).
"""
return self.predict(), config

0 comments on commit 95661c1

Please sign in to comment.