condition strategy
lucidtronix committed Dec 13, 2024
1 parent 1eee793 commit 5928cd7
Showing 2 changed files with 38 additions and 12 deletions.
4 changes: 4 additions & 0 deletions ml4h/arguments.py
@@ -232,6 +232,10 @@ def parse_args():
help='For diffusion models, this controls how frequently Cross-Attention is applied. '
'2 means every other residual block, 3 would mean every third.',
)
parser.add_argument(
'--diffusion_condition_strategy', default='concat', choices=['cross_attention', 'concat', 'film'],
help='For diffusion models, this controls how conditional embeddings are integrated into the U-Net.',
)
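As an aside (not part of the diff), the new flag can be exercised in isolation; the sketch below recreates just this argument and shows that argparse confines it to the three supported strategies:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--diffusion_condition_strategy', default='concat',
        choices=['cross_attention', 'concat', 'film'],
    )

    args = parser.parse_args(['--diffusion_condition_strategy', 'film'])
    print(args.diffusion_condition_strategy)  # film
    # Any value outside the choices list, e.g. 'gated', makes argparse exit with an error.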
parser.add_argument(
'--diffusion_loss', default='sigmoid',
help='Loss function to use for diffusion models. Can be sigmoid, mean_absolute_error, or mean_squared_error',
46 changes: 34 additions & 12 deletions ml4h/models/diffusion_blocks.py
@@ -102,15 +102,35 @@ def apply(x):
return apply


def residual_block_control(width, conv, kernel_size, attention_heads):
def condition_layer_film(input_tensor, control_vector, filters):
# Transform control into gamma and beta
gamma = layers.Dense(filters, activation="linear")(control_vector)
beta = layers.Dense(filters, activation="linear")(control_vector)

# Reshape gamma and beta to match the spatial dimensions
gamma = tf.reshape(gamma, (-1, 1, 1, filters))
beta = tf.reshape(beta, (-1, 1, 1, filters))

# Apply FiLM (Feature-wise Linear Modulation)
return input_tensor * gamma + beta
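
For reference, a self-contained sketch (assuming TensorFlow 2.x; batch size and shapes are illustrative) of what this FiLM layer does to a batch of 2D feature maps:

    import tensorflow as tf
    from tensorflow.keras import layers

    feature_maps = tf.random.normal((2, 8, 8, 16))  # (batch, height, width, filters)
    control_vector = tf.random.normal((2, 32))      # embedded condition per example

    gamma = layers.Dense(16, activation="linear")(control_vector)  # per-channel scale
    beta = layers.Dense(16, activation="linear")(control_vector)   # per-channel shift
    gamma = tf.reshape(gamma, (-1, 1, 1, 16))  # broadcast over height and width
    beta = tf.reshape(beta, (-1, 1, 1, 16))

    modulated = feature_maps * gamma + beta
    print(modulated.shape)  # (2, 8, 8, 16): channel count unchanged

Note that the (-1, 1, 1, filters) reshape assumes 2D image tensors; 1D or 3D inputs would need a different broadcast shape.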


def residual_block_control(width, conv, kernel_size, attention_heads, condition_strategy):
def apply(x):
x, control = x
input_width = x.shape[-1]
if input_width == width:
residual = x
else:
residual = conv(width, kernel_size=1)(x)
x = keras.layers.MultiHeadAttention(num_heads=attention_heads, key_dim=width)(x, control)

if 'cross_attention' == condition_strategy:
x = layers.MultiHeadAttention(num_heads=attention_heads, key_dim=width)(x, control)
elif 'concat' == condition_strategy:
x = layers.Concatenate()([x, control])
elif 'film' == condition_strategy:
x = condition_layer_film(x, control, width)

x = layers.BatchNormalization(center=False, scale=False)(x)
x = conv(
width, kernel_size=kernel_size, padding="same", activation=keras.activations.swish
@@ -123,33 +143,33 @@ def apply(x):
return apply
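
To see how the strategies differ in their effect on the feature maps, here is a standalone sketch (TensorFlow 2.x assumed; shapes are illustrative) of the 'cross_attention' and 'concat' branches, with the control already upsampled to the spatial size of x:

    import tensorflow as tf
    from tensorflow.keras import layers

    width = 16
    x = tf.random.normal((2, 8, 8, width))     # U-Net feature maps
    control = tf.random.normal((2, 8, 8, 32))  # upsampled control embedding

    attended = layers.MultiHeadAttention(num_heads=4, key_dim=width)(x, control)  # 'cross_attention'
    concatenated = layers.Concatenate()([x, control])                             # 'concat'
    print(attended.shape)      # (2, 8, 8, 16): channel count preserved
    print(concatenated.shape)  # (2, 8, 8, 48): channels grow, absorbed by the next conv

The 'film' branch (condition_layer_film above) also preserves the channel count, since gamma and beta are projected to exactly width filters.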


def down_block_control(width, block_depth, conv, pool, kernel_size, attention_heads):
def down_block_control(width, block_depth, conv, pool, kernel_size, attention_heads, condition_strategy):
def apply(x):
x, skips, control = x
for _ in range(block_depth):
x = residual_block_control(width, conv, kernel_size, attention_heads)([x, control])
x = residual_block_control(width, conv, kernel_size, attention_heads, condition_strategy)([x, control])
skips.append(x)
x = pool(pool_size=2)(x)
return x

return apply


def up_block_control(width, block_depth, conv, upsample, kernel_size, attention_heads):
def up_block_control(width, block_depth, conv, upsample, kernel_size, attention_heads, condition_strategy):
def apply(x):
x, skips, control = x
# x = upsample(size=2, interpolation="bilinear")(x)
x = upsample(size=2)(x)
for _ in range(block_depth):
x = layers.Concatenate()([x, skips.pop()])
x = residual_block_control(width, conv, kernel_size, attention_heads)([x, control])
x = residual_block_control(width, conv, kernel_size, attention_heads, condition_strategy)([x, control])
return x

return apply


def get_control_network(input_shape, widths, block_depth, kernel_size, control_size,
attention_start, attention_heads, attention_modulo):
attention_window, attention_heads, attention_modulo, condition_strategy):
noisy_images = keras.Input(shape=input_shape)
noise_variances = keras.Input(shape=[1] * len(input_shape))

@@ -177,7 +197,8 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s
c2 = upsample(size=x.shape[1:-1])(control[control_idxs])
else:
c2 = upsample(size=x.shape[-2])(control[control_idxs])
x = down_block_control(width, block_depth, conv, pool, kernel_size, attention_heads)([x, skips, c2])
x = down_block_control(width, block_depth, conv, pool,
kernel_size, attention_heads, condition_strategy)([x, skips, c2])
else:
x = down_block(width, block_depth, conv, pool, kernel_size)([x, skips])

@@ -187,15 +208,16 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s
c2 = upsample(size=x.shape[-2])(control[control_idxs])

for i in range(block_depth):
x = residual_block_control(widths[-1], conv, kernel_size, attention_heads)([x, c2])
x = residual_block_control(widths[-1], conv, kernel_size, attention_heads, condition_strategy)([x, c2])

for i, width in enumerate(reversed(widths[:-1])):
if attention_modulo > 1 and i % attention_modulo == 0:
if len(input_shape) > 2:
c2 = upsample(size=x.shape[1:-1])(control[control_idxs])
else:
c2 = upsample(size=x.shape[-2])(control[control_idxs])
x = up_block_control(width, block_depth, conv, upsample, kernel_size, attention_heads)([x, skips, c2])
x = up_block_control(width, block_depth, conv, upsample,
kernel_size, attention_heads, condition_strategy)([x, skips, c2])
else:
x = up_block(width, block_depth, conv, upsample, kernel_size)([x, skips])
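
A small standalone sketch (TensorFlow 2.x assumed; shapes and the 1x1 spatial size of the control tensor are assumptions for illustration) of the upsampling pattern used in the 2D branches above, where the control embedding is stretched to the current feature-map resolution before it reaches the conditioned blocks:

    import tensorflow as tf
    from tensorflow.keras import layers

    x = tf.random.normal((2, 16, 16, 64))      # current U-Net feature map
    control = tf.random.normal((2, 1, 1, 32))  # control embedding at 1x1 spatial resolution

    # With 1x1 spatial input, an upsampling factor equal to x's spatial shape
    # stretches the control to exactly x's resolution.
    c2 = layers.UpSampling2D(size=tuple(x.shape[1:-1]))(control)
    print(c2.shape)  # (2, 16, 16, 32): ready for concat, cross-attention, or FiLM against x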

@@ -217,7 +239,7 @@ def get_control_embed_model(output_maps, control_size):
class DiffusionController(keras.Model):
def __init__(
self, tensor_map, output_maps, batch_size, widths, block_depth, conv_x, control_size,
attention_start, attention_heads, attention_modulo, diffusion_loss, sigmoid_beta
attention_start, attention_heads, attention_modulo, diffusion_loss, sigmoid_beta, condition_strategy,
):
super().__init__()

@@ -227,7 +249,7 @@ def __init__(
self.control_embed_model = get_control_embed_model(self.output_maps, control_size)
self.normalizer = layers.Normalization()
self.network = get_control_network(self.input_map.shape, widths, block_depth, conv_x, control_size,
attention_start, attention_heads, attention_modulo)
attention_start, attention_heads, attention_modulo, condition_strategy)
self.ema_network = keras.models.clone_model(self.network)
self.use_sigmoid_loss = diffusion_loss == 'sigmoid'
self.beta = sigmoid_beta
