diff --git a/helpers/models/flux/transformer.py b/helpers/models/flux/transformer.py
index 77097648..b738c8f3 100644
--- a/helpers/models/flux/transformer.py
+++ b/helpers/models/flux/transformer.py
@@ -210,11 +210,8 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
         processor = FluxAttnProcessor2_0()
         if torch.cuda.is_available():
-            rank = (
-                torch.distributed.get_rank()
-                if torch.distributed.is_initialized()
-                else 0
-            )
+            # let's assume that the box only ever has H100s.
+            rank = 0
             primary_device = torch.cuda.get_device_properties(rank)
             if primary_device.major == 9 and primary_device.minor == 0:
                 if is_flash_attn_available: