Skip to content

Commit

Permalink
Added config option to hpo script, styling (#235)
Browse files Browse the repository at this point in the history
  • Loading branch information
annaelisalappe authored Oct 22, 2024
1 parent a2290c1 commit c5ea0aa
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
1 change: 0 additions & 1 deletion use-cases/eurac/hpo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse
import os
from pathlib import Path
from typing import Dict

import matplotlib.pyplot as plt
Expand Down
4 changes: 2 additions & 2 deletions use-cases/eurac/slurm_ray.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Job configuration
#SBATCH --job-name=ray_tune_hpo
#SBATCH --account=intertwin
#SBATCH --time 01:00:00
#SBATCH --time 02:30:00

# Resources allocation
#SBATCH --cpus-per-task=24
Expand Down Expand Up @@ -88,7 +88,7 @@ echo All Ray workers started.
# Run the Python script using Ray
echo 'Starting HPO.'

python hpo.py --num_samples 8 --max_iterations 2 --ngpus $num_gpus --ncpus $num_cpus
python hpo.py --num_samples 4 --max_iterations 2 --ngpus $num_gpus --ncpus $num_cpus --pipeline_name rnn_training_pipeline # NOTE: conv_training_pipeline has not been tested

# Shutdown Ray after completion
ray stop
7 changes: 7 additions & 0 deletions use-cases/eurac/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def train(self):
self.lr_scheduler.step(avg_val_loss)
loss_history["train"].append(train_loss)
loss_history["val"].append(avg_val_loss)

self.log(
item=train_loss.item(),
identifier="train_loss_per_epoch",
Expand Down Expand Up @@ -496,6 +497,12 @@ def train(self):
best_loss = avg_val_loss
# self.model.load_state_dict(best_model_weights)

# Report training metrics of last epoch to Ray
train.report(
{"loss": avg_val_loss.item(),
"train_loss": train_loss.item()}
)

return loss_history, metric_history

def create_dataloaders(self, train_dataset, validation_dataset, test_dataset):
Expand Down

0 comments on commit c5ea0aa

Please sign in to comment.