Skip to content

Commit

Permalink
condition strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidtronix committed Dec 13, 2024
1 parent 5928cd7 commit c53e449
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion ml4h/models/diffusion_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def apply(x):
residual = x
else:
residual = conv(width, kernel_size=1)(x)

if 'cross_attention' == condition_strategy:
x = layers.MultiHeadAttention(num_heads=attention_heads, key_dim=width)(x, control)
elif 'concat' == condition_strategy:
Expand Down
2 changes: 1 addition & 1 deletion ml4h/models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def train_diffusion_control_model(args):
model = DiffusionController(
args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, args.dense_blocks, args.block_size, args.conv_x,
args.dense_layers[0], args.attention_window, args.attention_heads, args.attention_modulo, args.diffusion_loss,
args.sigmoid_beta,
args.sigmoid_beta, args.diffusion_condition_strategy,
)

loss = keras.losses.mean_absolute_error if args.diffusion_loss == 'mean_absolute_error' else keras.losses.mean_squared_error
Expand Down

0 comments on commit c53e449

Please sign in to comment.