kernel inception distance
lucidtronix committed Jan 3, 2025
1 parent a5d518d commit ff81dee
Showing 3 changed files with 62 additions and 14 deletions.
ml4h/models/diffusion_blocks.py (46 additions & 8 deletions)
@@ -303,11 +303,15 @@ def compile(self, **kwargs):
 
         self.noise_loss_tracker = keras.metrics.Mean(name="n_loss")
         self.image_loss_tracker = keras.metrics.Mean(name="i_loss")
-        self.kid = KernelInceptionDistance(name = "kid", input_shape = self.tensor_map.shape, kernel_image_size=75)
+        if self.tensor_map.axes() == 3:
+            self.kid = KernelInceptionDistance(name = "kid", input_shape = self.tensor_map.shape, kernel_image_size=75)
 
     @property
     def metrics(self):
-        return [self.noise_loss_tracker, self.image_loss_tracker, self.kid]
+        m = [self.noise_loss_tracker, self.image_loss_tracker]
+        if self.tensor_map.axes() == 3:
+            m.append(self.kid)
+        return m
 
     def denormalize(self, images):
         # convert the pixel values back to 0-1 range
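For readers unfamiliar with the metric: KID (Binkowski et al., 2018) is the squared maximum mean discrepancy between InceptionV3 features of real and generated images under a cubic polynomial kernel, and the new `axes() == 3` guard restricts it to 2-D image tensors (height, width, channels). The `kernel_image_size=75` argument plausibly reflects the 75x75 minimum input size Keras' InceptionV3 accepts. The repository's `KernelInceptionDistance` class is not shown in this diff; the sketch below is the standard formulation with illustrative names, not ml4h code.

```python
# Minimal sketch of Kernel Inception Distance: unbiased polynomial-kernel
# MMD^2 between real and generated feature batches. Feature extraction
# (resize + InceptionV3) is assumed to have happened already.
import tensorflow as tf

def polynomial_kernel(a, b):
    # k(x, y) = (x . y / d + 1)^3, with d the feature dimension
    d = tf.cast(tf.shape(a)[1], tf.float32)
    return (tf.matmul(a, b, transpose_b=True) / d + 1.0) ** 3.0

def kid(real_features, generated_features):
    # unbiased MMD^2 estimate: diagonal terms are excluded from k_rr and k_gg
    n = tf.cast(tf.shape(real_features)[0], tf.float32)
    k_rr = polynomial_kernel(real_features, real_features)
    k_gg = polynomial_kernel(generated_features, generated_features)
    k_rg = polynomial_kernel(real_features, generated_features)
    mean_rr = (tf.reduce_sum(k_rr) - tf.linalg.trace(k_rr)) / (n * (n - 1.0))
    mean_gg = (tf.reduce_sum(k_gg) - tf.linalg.trace(k_gg)) / (n * (n - 1.0))
    mean_rg = tf.reduce_mean(k_rg)
    return mean_rr + mean_gg - 2.0 * mean_rg
```

Unlike FID, the unbiased estimator keeps KID meaningful on small batches, which suits the per-test-step updates below.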
@@ -469,11 +473,12 @@ def test_step(self, images_original):
 
         # measure KID between real and generated images
         # this is computationally demanding, kid_diffusion_steps has to be small
-        images = self.denormalize(images)
-        generated_images = self.generate(
-            num_images=self.batch_size, diffusion_steps=20
-        )
-        self.kid.update_state(images, generated_images)
+        if self.tensor_map.axes() == 3:
+            images = self.denormalize(images)
+            generated_images = self.generate(
+                num_images=self.batch_size, diffusion_steps=20
+            )
+            self.kid.update_state(images, generated_images)
 
         return {m.name: m.result() for m in self.metrics}

Expand Down Expand Up @@ -607,6 +612,7 @@ class DiffusionController(keras.Model):
def __init__(
self, tensor_map, output_maps, batch_size, widths, block_depth, conv_x, control_size,
attention_start, attention_heads, attention_modulo, diffusion_loss, sigmoid_beta, condition_strategy,
supervisor = None,
):
super().__init__()

@@ -620,18 +626,24 @@ def __init__(
         self.ema_network = keras.models.clone_model(self.network)
         self.use_sigmoid_loss = diffusion_loss == 'sigmoid'
         self.beta = sigmoid_beta
+        self.supervisor = supervisor
 
 
     def compile(self, **kwargs):
         super().compile(**kwargs)
 
         self.noise_loss_tracker = keras.metrics.Mean(name="n_loss")
         self.image_loss_tracker = keras.metrics.Mean(name="i_loss")
+        if self.supervisor is not None:
+            self.supervised_loss_tracker = keras.metrics.Mean(name="supervised_loss")
         # self.kid = KID(name = "kid", input_shape = self.tensor_map.shape)
 
     @property
     def metrics(self):
-        return [self.noise_loss_tracker, self.image_loss_tracker]
+        m = [self.noise_loss_tracker, self.image_loss_tracker]
+        if self.supervisor is not None:
+            m.append(self.supervised_loss_tracker)
+        return m
 
     def denormalize(self, images):
         # convert the pixel values back to 0-1 range
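A note on the pattern above: Keras resets and logs exactly the trackers returned by the `metrics` property, so gating the list keeps `supervised_loss` out of epoch logs whenever no supervisor is attached. A minimal, self-contained illustration of the same pattern (hypothetical `TinyModel`, not repository code):

```python
# Hypothetical minimal model showing conditional metric registration;
# Keras only resets and reports what the `metrics` property returns.
import keras

class TinyModel(keras.Model):
    def __init__(self, extra=False):
        super().__init__()
        self.extra = extra

    def compile(self, **kwargs):
        super().compile(**kwargs)
        self.a_tracker = keras.metrics.Mean(name="a")
        if self.extra:
            self.b_tracker = keras.metrics.Mean(name="b")

    @property
    def metrics(self):
        m = [self.a_tracker]
        if self.extra:
            m.append(self.b_tracker)
        return m

model = TinyModel(extra=True)
model.compile(optimizer="adam")
print([m.name for m in model.metrics])  # ['a', 'b']
```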
@@ -751,6 +763,17 @@ def train_step(self, batch):
                 lambda_t = tf.math.log(signal_rates_squared / noise_rates_squared)
                 weight = tf.math.sigmoid(self.beta - lambda_t)
                 noise_loss = weight * noise_loss
+            if self.supervisor is not None:
+                loss_fn = tf.keras.losses.MeanSquaredError()
+                supervised_preds = self.supervisor(pred_images, training=True)
+                supervised_loss = loss_fn(batch[1][self.output_maps[0].output_name()], supervised_preds)
+                self.supervised_loss_tracker.update_state(supervised_loss)
+                # Combine losses: add noise_loss and supervised_loss
+                noise_loss += 0.01 * supervised_loss
+
+                # Gradients for self.supervisor
+                supervised_gradients = tape.gradient(supervised_loss, self.supervisor.trainable_weights)
+                self.optimizer.apply_gradients(zip(supervised_gradients, self.supervisor.trainable_weights))
 
         gradients = tape.gradient(noise_loss, self.network.trainable_weights)
         self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))
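One subtlety in the block above: `tape.gradient` is invoked twice (once for the supervisor, once for the diffusion network), which a non-persistent `tf.GradientTape` does not allow, so the surrounding tape presumably carries `persistent=True`. A condensed sketch of the joint update under that assumption; `pred_fn` and `targets` are stand-ins for the surrounding `train_step` machinery, not ml4h names:

```python
# Sketch of the combined update, assuming a persistent tape; the 0.01
# weight mirrors the diff, the other names are illustrative.
import tensorflow as tf

def joint_train_step(network, supervisor, optimizer, noises, images, targets, pred_fn, loss_fn):
    with tf.GradientTape(persistent=True) as tape:
        pred_noises, pred_images = pred_fn(images, noises)
        noise_loss = loss_fn(noises, pred_noises)
        supervised_loss = tf.keras.losses.MeanSquaredError()(targets, supervisor(pred_images, training=True))
        total_loss = noise_loss + 0.01 * supervised_loss  # small weight keeps denoising dominant
    # one optimizer step per sub-network, mirroring the two apply_gradients calls above
    optimizer.apply_gradients(zip(tape.gradient(total_loss, network.trainable_weights), network.trainable_weights))
    optimizer.apply_gradients(zip(tape.gradient(supervised_loss, supervisor.trainable_weights), supervisor.trainable_weights))
    del tape  # release persistent-tape resources
```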
Expand Down Expand Up @@ -827,6 +850,21 @@ def test_step(self, batch):

noise_loss = self.loss(noises, pred_noises)
image_loss = self.loss(images, pred_images)
if self.use_sigmoid_loss:
signal_rates_squared = tf.square(signal_rates)
noise_rates_squared = tf.square(noise_rates)

# Compute log-SNR (lambda_t)
lambda_t = tf.math.log(signal_rates_squared / noise_rates_squared)
weight = tf.math.sigmoid(self.beta - lambda_t)
noise_loss = weight * noise_loss
if self.supervisor is not None:
loss_fn = tf.keras.losses.MeanSquaredError()
supervised_preds = self.supervisor(pred_images, training=True)
supervised_loss = loss_fn(batch[1][self.output_maps[0].output_name()], supervised_preds)
self.supervised_loss_tracker.update_state(supervised_loss)
# Combine losses: add noise_loss and supervised_loss
noise_loss += 0.01*supervised_loss

self.image_loss_tracker.update_state(image_loss)
self.noise_loss_tracker.update_state(noise_loss)
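This change applies the sigmoid weighting at evaluation time too, keeping test losses comparable to training losses. With lambda_t = log(signal^2 / noise^2), the weight sigmoid(beta - lambda_t) suppresses high-SNR (nearly clean) timesteps and emphasizes noisy ones; a quick numeric illustration (beta = -3.0 is an arbitrary example value, not a repository default):

```python
# How sigmoid(beta - lambda_t) behaves across SNR levels: the weight
# shrinks as SNR grows, down-weighting nearly denoised steps.
import math

beta = -3.0
for snr in (100.0, 1.0, 0.01):
    lambda_t = math.log(snr)  # log-SNR
    weight = 1.0 / (1.0 + math.exp(-(beta - lambda_t)))
    print(f"snr={snr:>6}: weight={weight:.4f}")
```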
ml4h/models/train.py (14 additions & 6 deletions)
@@ -279,13 +279,21 @@ def interpolate_controlled_generations(diffuser, tensor_maps_out, control_tm, ba
     plt.close()
 
 
-def train_diffusion_control_model(args):
+def train_diffusion_control_model(args, supervised=False):
     generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__)
-    model = DiffusionController(
-        args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, args.dense_blocks, args.block_size, args.conv_x,
-        args.dense_layers[0], args.attention_window, args.attention_heads, args.attention_modulo, args.diffusion_loss,
-        args.sigmoid_beta, args.diffusion_condition_strategy,
-    )
+    if supervised:
+        supervised_model, _, _, _ = make_multimodal_multitask_model(**args.__dict__)
+        model = DiffusionController(
+            args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, args.dense_blocks, args.block_size, args.conv_x,
+            args.dense_layers[0], args.attention_window, args.attention_heads, args.attention_modulo, args.diffusion_loss,
+            args.sigmoid_beta, args.diffusion_condition_strategy, supervised_model,
+        )
+    else:
+        model = DiffusionController(
+            args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, args.dense_blocks, args.block_size, args.conv_x,
+            args.dense_layers[0], args.attention_window, args.attention_heads, args.attention_modulo, args.diffusion_loss,
+            args.sigmoid_beta, args.diffusion_condition_strategy,
+        )
 
     loss = keras.losses.mean_absolute_error if args.diffusion_loss == 'mean_absolute_error' else keras.losses.mean_squared_error
     model.compile(
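Since `DiffusionController` now defaults `supervisor` to `None`, the two nearly identical constructor calls above could collapse into one; a sketch of that equivalent formulation (not what the commit does):

```python
# Equivalent single-construction sketch: build the optional supervisor
# first, then pass it straight through.
supervised_model = None
if supervised:
    supervised_model, _, _, _ = make_multimodal_multitask_model(**args.__dict__)
model = DiffusionController(
    args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, args.dense_blocks, args.block_size,
    args.conv_x, args.dense_layers[0], args.attention_window, args.attention_heads, args.attention_modulo,
    args.diffusion_loss, args.sigmoid_beta, args.diffusion_condition_strategy, supervisor=supervised_model,
)
```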
ml4h/recipes.py (2 additions & 0 deletions)
@@ -116,6 +116,8 @@ def run(args):
         train_diffusion_model(args)
     elif 'train_diffusion_control' == args.mode:
         train_diffusion_control_model(args)
+    elif 'train_diffusion_supervised' == args.mode:
+        train_diffusion_control_model(args, supervised=True)
     elif 'train_siamese' == args.mode:
         train_siamese_model(args)
     elif 'write_tensor_maps' == args.mode:
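With the dispatch in place, selecting the new mode routes training through the supervised path. A hypothetical driver, assuming ml4h's usual `parse_args`/`run` entry points (the argument plumbing itself is untouched by this commit):

```python
# Hypothetical driver for the new mode; parse_args and run are assumed
# to be ml4h's standard entry points, not additions from this commit.
from ml4h.arguments import parse_args
from ml4h.recipes import run

if __name__ == '__main__':
    args = parse_args()
    args.mode = 'train_diffusion_supervised'  # routes to train_diffusion_control_model(args, supervised=True)
    run(args)
```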