Simplify 3DGAN model config
matbun committed Nov 23, 2023
1 parent 33de0b4 commit f2ccfae
Showing 2 changed files with 28 additions and 31 deletions.
55 changes: 27 additions & 28 deletions use-cases/3dgan/model.py
@@ -1,6 +1,6 @@
 import sys
-import os
-import pickle
+# import os
+# import pickle
 from collections import defaultdict
 import math
 from typing import Any
@@ -306,18 +306,16 @@ class ThreeDGAN(pl.LightningModule):
     def __init__(
         self,
         latent_size=256,
-        batch_size=64,
         loss_weights=[3, 0.1, 25, 0.1],
         power=0.85,
         lr=0.001,
-        checkpoint_path: str = '3Dgan.pth'
+        # checkpoint_path: str = '3Dgan.pth'
     ):
         super().__init__()
         self.save_hyperparameters()
         self.automatic_optimization = False

         self.latent_size = latent_size
-        self.batch_size = batch_size
         self.loss_weights = loss_weights
         self.lr = lr
         self.power = power
@@ -332,9 +330,9 @@ def __init__(
         self.index = 0
         self.train_history = defaultdict(list)
         self.test_history = defaultdict(list)
-        self.pklfile = checkpoint_path
-        checkpoint_dir = os.path.dirname(checkpoint_path)
-        os.makedirs(checkpoint_dir, exist_ok=True)
+        # self.pklfile = checkpoint_path
+        # checkpoint_dir = os.path.dirname(checkpoint_path)
+        # os.makedirs(checkpoint_dir, exist_ok=True)

     def BitFlip(self, x, prob=0.05):
         """
@@ -411,10 +409,10 @@ def training_step(self, batch, batch_idx):
         ecal_batch = ecal_batch.to(self.device)

         optimizer_discriminator, optimizer_generator = self.optimizers()
+        batch_size = energy_batch.shape[0]

         noise = torch.randn(
-            (energy_batch.shape[0], self.latent_size - 2),
-            # (self.batch_size, self.latent_size - 2),
+            (batch_size, self.latent_size - 2),
             dtype=torch.float32,
             device=self.device
         )
@@ -427,7 +425,7 @@ def training_step(self, batch, batch_idx):
         generated_images = self.generator(generator_ip)

         # Train discriminator first on real batch
-        fake_batch = self.BitFlip(np.ones(self.batch_size).astype(np.float32))
+        fake_batch = self.BitFlip(np.ones(batch_size).astype(np.float32))
         fake_batch = torch.tensor([[el] for el in fake_batch]).to(self.device)
         labels = [fake_batch, energy_batch, ang_batch, ecal_batch]

@@ -450,7 +448,7 @@ def training_step(self, batch, batch_idx):
         optimizer_discriminator.step()

         # Train discriminator on the fake batch
-        fake_batch = self.BitFlip(np.zeros(self.batch_size).astype(np.float32))
+        fake_batch = self.BitFlip(np.zeros(batch_size).astype(np.float32))
         fake_batch = torch.tensor([[el] for el in fake_batch]).to(self.device)
         labels = [fake_batch, energy_batch, ang_batch, ecal_batch]

@@ -473,15 +471,15 @@ def training_step(self, batch, batch_idx):

         # avg_disc_loss = (sum(real_batch_loss) + sum(fake_batch_loss)) / 2

-        trick = np.ones(self.batch_size).astype(np.float32)
+        trick = np.ones(batch_size).astype(np.float32)
         fake_batch = torch.tensor([[el] for el in trick]).to(self.device)
         labels = [fake_batch, energy_batch.view(-1, 1), ang_batch, ecal_batch]

         gen_losses_train = []
         # Train generator twice using combined model
         for _ in range(2):
             noise = torch.randn(
-                (self.batch_size, self.latent_size - 2)).to(self.device)
+                (batch_size, self.latent_size - 2)).to(self.device)
             generator_ip = torch.cat(
                 (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise),
                 dim=1
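A side note on the BitFlip calls above: the method, declared earlier as BitFlip(self, x, prob=0.05), randomly flips a fraction of the 0/1 discriminator labels as regularizing label noise. A minimal standalone sketch of that behaviour, assuming the 5% default; the project's actual implementation lives in model.py:

import numpy as np

def bit_flip(x: np.ndarray, prob: float = 0.05) -> np.ndarray:
    # Flip each binary label with probability `prob`; noisy labels
    # keep the discriminator from becoming overconfident.
    x = np.copy(x)
    flip_mask = np.random.random_sample(x.shape) < prob
    x[flip_mask] = 1.0 - x[flip_mask]
    return x

labels = bit_flip(np.ones(64, dtype=np.float32))  # mostly 1s, roughly 5% flipped to 0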
@@ -615,9 +613,9 @@ def on_train_epoch_end(self): # outputs
         torch.save(self.discriminator.state_dict(),
                    "discriminator_weights.pth")

-        with open(self.pklfile, "wb") as f:
-            pickle.dump({"train": self.train_history,
-                         "test": self.test_history}, f)
+        # with open(self.pklfile, "wb") as f:
+        #     pickle.dump({"train": self.train_history,
+        #                  "test": self.test_history}, f)

         # pickle.dump({"train": self.train_history}, open(self.pklfile, "wb"))
         print("train-loss:" + str(self.train_history["generator"][-1][0]))
@@ -633,10 +631,11 @@ def validation_step(self, batch, batch_idx):
         ang_batch = ang_batch.to(self.device)
         ecal_batch = ecal_batch.to(self.device)

+        batch_size = energy_batch.shape[0]
+
         # Generate Fake events with same energy and angle as data batch
         noise = torch.randn(
-            (energy_batch.shape[0], self.latent_size - 2),
-            # (self.batch_size, self.latent_size - 2),
+            (batch_size, self.latent_size - 2),
             dtype=torch.float32,
             device=self.device
         )
@@ -648,10 +647,10 @@ def validation_step(self, batch, batch_idx):
         # concatenate to fake and real batches
         X = torch.cat((image_batch, generated_images), dim=0)

-        # y = np.array([1] * self.batch_size \
-        #     + [0] * self.batch_size).astype(np.float32)
-        y = torch.tensor([1] * self.batch_size + [0] *
-                         self.batch_size, dtype=torch.float32).to(self.device)
+        # y = np.array([1] * batch_size \
+        #     + [0] * batch_size).astype(np.float32)
+        y = torch.tensor([1] * batch_size + [0] *
+                         batch_size, dtype=torch.float32).to(self.device)
         y = y.view(-1, 1)

         ang = torch.cat((ang_batch, ang_batch), dim=0)
@@ -667,7 +666,7 @@ def validation_step(self, batch, batch_idx):
             labels, disc_eval, self.loss_weights)

         # Calculate generator loss
-        trick = np.ones(self.batch_size).astype(np.float32)
+        trick = np.ones(batch_size).astype(np.float32)
         fake_batch = torch.tensor([[el] for el in trick]).to(self.device)
         # fake_batch = [[el] for el in trick]
         labels = [fake_batch, energy_batch, ang_batch, ecal_batch]
@@ -728,10 +727,10 @@ def on_validation_epoch_end(self):
         print(ROW_FMT.format("discriminator (test)",
                              *self.test_history["discriminator"][-1]))

-        # save loss dict to pkl file
-        with open(self.pklfile, "wb") as f:
-            pickle.dump({"train": self.train_history,
-                         "test": self.test_history}, f)
+        # # save loss dict to pkl file
+        # with open(self.pklfile, "wb") as f:
+        #     pickle.dump({"train": self.train_history,
+        #                  "test": self.test_history}, f)
         # pickle.dump({"test": self.test_history}, open(self.pklfile, "wb"))
         # print("train-loss:" + str(self.train_history["generator"][-1][0]))

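The recurring edit in model.py replaces the configured self.batch_size with a batch_size read from the incoming tensors at each step. A minimal sketch of the pattern (hypothetical standalone code, not the project's module): sizing labels from tensor.shape[0] keeps the last, possibly partial, batch of an epoch consistent with the labels built for it.

import torch

def real_labels_for(energy_batch: torch.Tensor) -> torch.Tensor:
    # Derive the batch size per step instead of trusting a config value.
    batch_size = energy_batch.shape[0]
    return torch.ones(batch_size, 1)  # label tensor always matches the batch

full = real_labels_for(torch.randn(64, 1))     # shape (64, 1)
partial = real_labels_for(torch.randn(37, 1))  # final short batch: shape (37, 1)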
4 changes: 1 addition & 3 deletions use-cases/3dgan/pipeline.yaml
@@ -71,7 +71,7 @@ executor:
       plugins: null
       profiler: null
       reload_dataloaders_every_n_epochs: 0
-      strategy: auto #ddp_find_unused_parameters_true #auto
+      strategy: auto #ddp_find_unused_parameters_true
       sync_batchnorm: false
       use_distributed_sampler: true
       val_check_interval: null
@@ -81,11 +81,9 @@ executor:
       class_path: model.ThreeDGAN
       init_args:
         latent_size: 256
-        batch_size: 4
         loss_weights: [3, 0.1, 25, 0.1]
         power: 0.85
-        lr: 0.001
-        checkpoint_path: exp_data/3dgan.pth
+        lr: 0.001

   # Lightning data module configuration
   data:
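With batch_size and checkpoint_path gone from init_args, the YAML block must stay in sync with ThreeDGAN.__init__ (batch size presumably now lives with the Lightning data module's dataloaders). A quick, hypothetical sanity check, assuming model.py is importable from the use-cases/3dgan directory:

import inspect
from model import ThreeDGAN  # assumes use-cases/3dgan is the working directory

init_args = {
    "latent_size": 256,
    "loss_weights": [3, 0.1, 25, 0.1],
    "power": 0.85,
    "lr": 0.001,
}
accepted = set(inspect.signature(ThreeDGAN.__init__).parameters) - {"self"}
assert set(init_args) <= accepted  # every YAML key is still a constructor arg
model = ThreeDGAN(**init_args)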
