diff --git a/ml4h/arguments.py b/ml4h/arguments.py index 01c20e804..f1111f3c0 100755 --- a/ml4h/arguments.py +++ b/ml4h/arguments.py @@ -232,6 +232,10 @@ def parse_args(): help='For diffusion models, this controls how frequently Cross-Attention is applied. ' '2 means every other residual block, 3 would mean every third.', ) + parser.add_argument( + '--diffusion_condition_strategy', default='concat', choices=['cross_attention', 'concat', 'film'], + help='For diffusion models, this controls how conditional embeddings are integrated into the U-NET', + ) parser.add_argument( '--diffusion_loss', default='sigmoid', help='Loss function to use for diffusion models. Can be sigmoid, mean_absolute_error, or mean_squared_error', diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index d90048ff7..f1afea6d8 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -102,7 +102,20 @@ def apply(x): return apply -def residual_block_control(width, conv, kernel_size, attention_heads): +def condition_layer_film(input_tensor, control_vector, filters): + # Transform control into gamma and beta + gamma = layers.Dense(filters, activation="linear")(control_vector) + beta = layers.Dense(filters, activation="linear")(control_vector) + + # Reshape gamma and beta to match the spatial dimensions + gamma = tf.reshape(gamma, (-1, 1, 1, filters)) + beta = tf.reshape(beta, (-1, 1, 1, filters)) + + # Apply FiLM (Feature-wise Linear Modulation) + return input_tensor * gamma + beta + + +def residual_block_control(width, conv, kernel_size, attention_heads, condition_strategy): def apply(x): x, control = x input_width = x.shape[-1] @@ -110,7 +123,14 @@ def apply(x): residual = x else: residual = conv(width, kernel_size=1)(x) - x = keras.layers.MultiHeadAttention(num_heads=attention_heads, key_dim=width)(x, control) + + if 'cross_attention' == condition_strategy: + x = layers.MultiHeadAttention(num_heads=attention_heads, key_dim=width)(x, control) + elif 
'concat' == condition_strategy: + x = layers.Concatenate()([x, control]) + elif 'film' == condition_strategy: + x = condition_layer_film(x, control, width) + x = layers.BatchNormalization(center=False, scale=False)(x) x = conv( width, kernel_size=kernel_size, padding="same", activation=keras.activations.swish @@ -123,11 +143,11 @@ def apply(x): return apply -def down_block_control(width, block_depth, conv, pool, kernel_size, attention_heads): +def down_block_control(width, block_depth, conv, pool, kernel_size, attention_heads, condition_strategy): def apply(x): x, skips, control = x for _ in range(block_depth): - x = residual_block_control(width, conv, kernel_size, attention_heads)([x, control]) + x = residual_block_control(width, conv, kernel_size, attention_heads, condition_strategy)([x, control]) skips.append(x) x = pool(pool_size=2)(x) return x @@ -135,21 +155,21 @@ def apply(x): return apply -def up_block_control(width, block_depth, conv, upsample, kernel_size, attention_heads): +def up_block_control(width, block_depth, conv, upsample, kernel_size, attention_heads, condition_strategy): def apply(x): x, skips, control = x # x = upsample(size=2, interpolation="bilinear")(x) x = upsample(size=2)(x) for _ in range(block_depth): x = layers.Concatenate()([x, skips.pop()]) - x = residual_block_control(width, conv, kernel_size, attention_heads)([x, control]) + x = residual_block_control(width, conv, kernel_size, attention_heads, condition_strategy)([x, control]) return x return apply def get_control_network(input_shape, widths, block_depth, kernel_size, control_size, - attention_start, attention_heads, attention_modulo): + attention_window, attention_heads, attention_modulo, condition_strategy): noisy_images = keras.Input(shape=input_shape) noise_variances = keras.Input(shape=[1] * len(input_shape)) @@ -177,7 +197,8 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: c2 = 
upsample(size=x.shape[-2])(control[control_idxs]) - x = down_block_control(width, block_depth, conv, pool, kernel_size, attention_heads)([x, skips, c2]) + x = down_block_control(width, block_depth, conv, pool, + kernel_size, attention_heads, condition_strategy)([x, skips, c2]) else: x = down_block(width, block_depth, conv, pool, kernel_size)([x, skips]) @@ -187,7 +208,7 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s c2 = upsample(size=x.shape[-2])(control[control_idxs]) for i in range(block_depth): - x = residual_block_control(widths[-1], conv, kernel_size, attention_heads)([x, c2]) + x = residual_block_control(widths[-1], conv, kernel_size, attention_heads, condition_strategy)([x, c2]) for i, width in enumerate(reversed(widths[:-1])): if attention_modulo > 1 and i % attention_modulo == 0: @@ -195,7 +216,8 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: c2 = upsample(size=x.shape[-2])(control[control_idxs]) - x = up_block_control(width, block_depth, conv, upsample, kernel_size, attention_heads)([x, skips, c2]) + x = up_block_control(width, block_depth, conv, upsample, + kernel_size, attention_heads, condition_strategy)([x, skips, c2]) else: x = up_block(width, block_depth, conv, upsample, kernel_size)([x, skips]) @@ -217,7 +239,7 @@ def get_control_embed_model(output_maps, control_size): class DiffusionController(keras.Model): def __init__( self, tensor_map, output_maps, batch_size, widths, block_depth, conv_x, control_size, - attention_start, attention_heads, attention_modulo, diffusion_loss, sigmoid_beta + attention_start, attention_heads, attention_modulo, diffusion_loss, sigmoid_beta, condition_strategy, ): super().__init__() @@ -227,7 +249,7 @@ def __init__( self.control_embed_model = get_control_embed_model(self.output_maps, control_size) self.normalizer = layers.Normalization() self.network = 
get_control_network(self.input_map.shape, widths, block_depth, conv_x, control_size, - attention_start, attention_heads, attention_modulo) + attention_start, attention_heads, attention_modulo, condition_strategy) self.ema_network = keras.models.clone_model(self.network) self.use_sigmoid_loss = diffusion_loss == 'sigmoid' self.beta = sigmoid_beta