From ff81dee23880e3992aa64f5e604a209ccfef91ee Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 3 Jan 2025 12:32:43 -0500 Subject: [PATCH] kernel inception distance --- ml4h/models/diffusion_blocks.py | 54 ++++++++++++++++++++++++++++----- ml4h/models/train.py | 20 ++++++++---- ml4h/recipes.py | 2 ++ 3 files changed, 62 insertions(+), 14 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 32cda7659..74f4e04cd 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -303,11 +303,15 @@ def compile(self, **kwargs): self.noise_loss_tracker = keras.metrics.Mean(name="n_loss") self.image_loss_tracker = keras.metrics.Mean(name="i_loss") - self.kid = KernelInceptionDistance(name = "kid", input_shape = self.tensor_map.shape, kernel_image_size=75) + if self.tensor_map.axes() == 3: + self.kid = KernelInceptionDistance(name = "kid", input_shape = self.tensor_map.shape, kernel_image_size=75) @property def metrics(self): - return [self.noise_loss_tracker, self.image_loss_tracker, self.kid] + m = [self.noise_loss_tracker, self.image_loss_tracker] + if self.tensor_map.axes() == 3: + m.append(self.kid) + return m def denormalize(self, images): # convert the pixel values back to 0-1 range @@ -469,11 +473,12 @@ def test_step(self, images_original): # measure KID between real and generated images # this is computationally demanding, kid_diffusion_steps has to be small - images = self.denormalize(images) - generated_images = self.generate( - num_images=self.batch_size, diffusion_steps=20 - ) - self.kid.update_state(images, generated_images) + if self.tensor_map.axes() == 3: + images = self.denormalize(images) + generated_images = self.generate( + num_images=self.batch_size, diffusion_steps=20 + ) + self.kid.update_state(images, generated_images) return {m.name: m.result() for m in self.metrics} @@ -607,6 +612,7 @@ class DiffusionController(keras.Model): def __init__( self, tensor_map, output_maps, batch_size, widths, block_depth, conv_x, control_size, attention_start, attention_heads, attention_modulo, diffusion_loss, sigmoid_beta, condition_strategy, + supervisor = None, ): super().__init__() @@ -620,6 +626,7 @@ def __init__( self.ema_network = keras.models.clone_model(self.network) self.use_sigmoid_loss = diffusion_loss == 'sigmoid' self.beta = sigmoid_beta + self.supervisor = supervisor def compile(self, **kwargs): @@ -627,11 +634,16 @@ def compile(self, **kwargs): self.noise_loss_tracker = keras.metrics.Mean(name="n_loss") self.image_loss_tracker = keras.metrics.Mean(name="i_loss") + if self.supervisor is not None: + self.supervised_loss_tracker = keras.metrics.Mean(name="supervised_loss") # self.kid = KID(name = "kid", input_shape = self.tensor_map.shape) @property def metrics(self): - return [self.noise_loss_tracker, self.image_loss_tracker] + m = [self.noise_loss_tracker, self.image_loss_tracker] + if self.supervisor is not None: + m.append(self.supervised_loss_tracker) + return m def denormalize(self, images): # convert the pixel values back to 0-1 range @@ -751,6 +763,17 @@ def train_step(self, batch): lambda_t = tf.math.log(signal_rates_squared / noise_rates_squared) weight = tf.math.sigmoid(self.beta - lambda_t) noise_loss = weight * noise_loss + if self.supervisor is not None: + loss_fn = tf.keras.losses.MeanSquaredError() + supervised_preds = self.supervisor(pred_images, training=True) + supervised_loss = loss_fn(batch[1][self.output_maps[0].output_name()], supervised_preds) + self.supervised_loss_tracker.update_state(supervised_loss) + # Combine losses: add noise_loss and supervised_loss + noise_loss += 0.01 * supervised_loss + + # Gradients for self.supervised_model + supervised_gradients = tape.gradient(supervised_loss, self.supervisor.trainable_weights) + self.optimizer.apply_gradients(zip(supervised_gradients, self.supervisor.trainable_weights)) gradients = tape.gradient(noise_loss, self.network.trainable_weights) self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights)) @@ -827,6 +850,21 @@ def test_step(self, batch): noise_loss = self.loss(noises, pred_noises) image_loss = self.loss(images, pred_images) + if self.use_sigmoid_loss: + signal_rates_squared = tf.square(signal_rates) + noise_rates_squared = tf.square(noise_rates) + + # Compute log-SNR (lambda_t) + lambda_t = tf.math.log(signal_rates_squared / noise_rates_squared) + weight = tf.math.sigmoid(self.beta - lambda_t) + noise_loss = weight * noise_loss + if self.supervisor is not None: + loss_fn = tf.keras.losses.MeanSquaredError() + supervised_preds = self.supervisor(pred_images, training=True) + supervised_loss = loss_fn(batch[1][self.output_maps[0].output_name()], supervised_preds) + self.supervised_loss_tracker.update_state(supervised_loss) + # Combine losses: add noise_loss and supervised_loss + noise_loss += 0.01*supervised_loss self.image_loss_tracker.update_state(image_loss) self.noise_loss_tracker.update_state(noise_loss) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 18a9c3fe0..9c9f4f3a9 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -279,13 +279,21 @@ def interpolate_controlled_generations(diffuser, tensor_maps_out, control_tm, ba plt.close() -def train_diffusion_control_model(args): +def train_diffusion_control_model(args, supervised=False): generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__) - model = DiffusionController( - args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, args.dense_blocks, args.block_size, args.conv_x, - args.dense_layers[0], args.attention_window, args.attention_heads, args.attention_modulo, args.diffusion_loss, - args.sigmoid_beta, args.diffusion_condition_strategy, - ) + if supervised: + supervised_model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) + model = DiffusionController( + args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, args.dense_blocks, args.block_size, args.conv_x, + args.dense_layers[0], args.attention_window, args.attention_heads, args.attention_modulo, args.diffusion_loss, + args.sigmoid_beta, args.diffusion_condition_strategy, supervised_model, + ) + else: + model = DiffusionController( + args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, args.dense_blocks, args.block_size, args.conv_x, + args.dense_layers[0], args.attention_window, args.attention_heads, args.attention_modulo, args.diffusion_loss, + args.sigmoid_beta, args.diffusion_condition_strategy, + ) loss = keras.losses.mean_absolute_error if args.diffusion_loss == 'mean_absolute_error' else keras.losses.mean_squared_error model.compile( diff --git a/ml4h/recipes.py b/ml4h/recipes.py index 753b27ad4..4642303b7 100755 --- a/ml4h/recipes.py +++ b/ml4h/recipes.py @@ -116,6 +116,8 @@ def run(args): train_diffusion_model(args) elif 'train_diffusion_control' == args.mode: train_diffusion_control_model(args) + elif 'train_diffusion_supervised' == args.mode: + train_diffusion_control_model(args, supervised=True) elif 'train_siamese' == args.mode: train_siamese_model(args) elif 'write_tensor_maps' == args.mode: