Skip to content

Commit

Permalink
find_lr script finished advances #26
Browse files Browse the repository at this point in the history
  • Loading branch information
pab1s committed Jun 24, 2024
1 parent cb9d27b commit f88a6e4
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 13 deletions.
29 changes: 17 additions & 12 deletions scripts/find_learning_rates.sh
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
#!/bin/bash

#SBATCH --job-name LR_EfficientNetB0 # Nombre del proceso

#SBATCH --partition dios # Cola para ejecutar

#SBATCH --gres=gpu:1 # Numero de gpus a usar


#SBATCH --job-name=train_EfficientNetB0 # Process name
#SBATCH --partition=dios # Queue for execution
#SBATCH --gres=gpu:1 # Number of GPUs to use
#SBATCH --mail-type=END,FAIL # Notifications for job done & fail
#SBATCH [email protected] # Where to send notification

# Load necessary paths
export PATH="/opt/anaconda/anaconda3/bin:$PATH"

export PATH="/opt/anaconda/bin:$PATH"

# Setup Conda environment
eval "$(conda shell.bash hook)"
conda activate /mnt/homeGPU/polivares/tda-nn/tda-nn-separability
export TFHUB_CACHE_DIR=.

conda activate /mnt/homeGPU/polivares/tda-nn-separability
# Check if correct number of arguments is passed
if [ "$#" -ne 1 ]; then
echo "Usage: $0 <config_file>"
exit 1
fi

export TFHUB_CACHE_DIR=.
config_file=$1

python find_learning_rates.py
python find_learning_rates.py $config_file

mail -s "Proceso finalizado" [email protected] <<< "El proceso ha finalizado"
# mail -s "Proceso finalizado" [email protected] <<<"El proceso ha finalizado"
57 changes: 56 additions & 1 deletion utils/plotting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,52 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def plot_loss_statistics_with_seaborn(loss_lists):
"""
Plots the mean loss per epoch and the standard deviation as a shaded area for multiple experiments with variable epoch lengths using seaborn for improved aesthetics.
Args:
loss_lists (list of list of lists): Each element is a list of lists where each sublist represents
the loss per epoch for a single experiment, not necessarily all the same length.
"""
# Setting up seaborn for better aesthetics
sns.set(style="whitegrid")

# Determine the maximum number of epochs across all experiments
max_epochs = max(max(len(losses) for losses in experiment) for experiment in loss_lists)

num_experiments = len(loss_lists)
mean_losses = []
std_losses = []

for experiment_losses in loss_lists:
# Create an array with shape (num_of_experiments, max_epochs), initialized with NaN
losses_array = np.full((len(experiment_losses), max_epochs), np.nan)

# Fill the array with loss values
for i, losses in enumerate(experiment_losses):
losses_array[i, :len(losses)] = losses

# Compute mean and std deviation along the experiment axis, ignoring NaNs
mean_losses.append(np.nanmean(losses_array, axis=0))
std_losses.append(np.nanstd(losses_array, axis=0))

# Plotting with seaborn
plt.figure(figsize=(10, 6))
epochs_x = np.arange(1, max_epochs + 1)
colors = sns.color_palette("hsv", num_experiments) # Using seaborn color palette

for i, (mean_loss, std_loss) in enumerate(zip(mean_losses, std_losses)):
plt.plot(epochs_x, mean_loss, label=f'Experiment Group {i+1}', color=colors[i])
plt.fill_between(epochs_x, mean_loss - std_loss, mean_loss + std_loss, color=colors[i], alpha=0.3)

plt.title('Mean Loss Per Epoch With Standard Deviation')
plt.xlabel('Epoch')
plt.ylabel('Mean Loss')
plt.xticks(epochs_x) # Set x-ticks to show integer values for epochs
plt.legend()
plt.show()

def plot_loss(training_epoch_losses, validation_epoch_losses, plot_path):
"""
Expand Down Expand Up @@ -45,7 +92,7 @@ def plot_lr_vs_loss(log_lrs, losses, plot_path):
"""
plt.figure(figsize=(10, 5))

plt.plot(log_lrs, losses, label="Learning Rate vs Loss", marker='o')
plt.plot(log_lrs, losses, label="Learning Rate vs Loss")

plt.title("Learning Rate vs Loss")
plt.xlabel("Log Learning Rate")
Expand All @@ -54,3 +101,11 @@ def plot_lr_vs_loss(log_lrs, losses, plot_path):
plt.grid(True)
plt.savefig(plot_path)
plt.close()

# Deriving CSV path from the plot path
csv_path = plot_path.replace(".png", ".csv")

# Saving data to a CSV file
data = {'Log Learning Rate': log_lrs, 'Loss': losses}
df = pd.DataFrame(data)
df.to_csv(csv_path, index=False)

0 comments on commit f88a6e4

Please sign in to comment.