diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index c43ba1c2..fb8ab1bc 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -433,7 +433,9 @@ def save_sharded_safetensors( ".safetensors", "{suffix}.safetensors" ) state_dict_split = split_torch_state_dict_into_shards( - input_state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size + input_state_dict, + filename_pattern=filename_pattern, + max_shard_size=max_shard_size, ) index = { "metadata": state_dict_split.metadata,