From 3b5587569d2ad21d2ca53f375e1e958f16f67f4f Mon Sep 17 00:00:00 2001 From: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com> Date: Fri, 30 Aug 2024 17:15:44 +0200 Subject: [PATCH 1/5] Follow up the diffusers task refactoring (#1999) * fix * fix style --- optimum/exporters/tasks.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index f02f1769233..97053040879 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -1937,12 +1937,6 @@ def standardize_model_attributes(cls, model: Union["PreTrainedModel", "TFPreTrai if inferred_model_type is not None: break - if inferred_model_type is None: - raise ValueError( - f"The export of a DiffusionPipeline model with the class name {model.__class__.__name__} is currently not supported in Optimum. " - "Please open an issue or submit a PR to add the support." - ) - # `model_type` is a class attribute in Transformers, let's avoid modifying it. model.config.export_model_type = inferred_model_type @@ -2068,9 +2062,16 @@ def get_model_from_task( if original_task == "auto" and config.architectures is not None: model_class_name = config.architectures[0] - model_class = TasksManager.get_model_class_for_task( - task, framework, model_type=model_type, model_class_name=model_class_name, library=library_name - ) + if library_name == "diffusers": + config = DiffusionPipeline.load_config(model_name_or_path, **kwargs) + class_name = config.get("_class_name", None) + loaded_library = importlib.import_module(library_name) + model_class = getattr(loaded_library, class_name) + else: + model_class = TasksManager.get_model_class_for_task( + task, framework, model_type=model_type, model_class_name=model_class_name, library=library_name + ) + if library_name == "timm": model = model_class(f"hf_hub:{model_name_or_path}", pretrained=True, exportable=True) model = model.to(torch_dtype).to(device) From 7cc57e40f84e00f8ebc2849da303e40575fb23b4 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Mon, 2 Sep 2024 15:35:14 +0200 Subject: [PATCH 2/5] Transformers 4.44 support (#1996) * test * fix conll2003 dataset with remote code * sdpa for new bloom attention block * style * fix bloom modeling * better version ranges to reflect max and min transformers support * pin right version * use input dims --- optimum/bettertransformer/models/attention.py | 218 ++++++++++++------ .../models/decoder_models.py | 2 + optimum/exporters/onnx/model_configs.py | 38 +-- optimum/onnxruntime/modeling_decoder.py | 21 +- optimum/utils/input_generators.py | 36 +-- .../preprocessing/token_classification.py | 2 +- setup.py | 11 +- tests/bettertransformer/test_decoder.py | 6 +- 8 files changed, 210 insertions(+), 124 deletions(-) diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py index 6c8f16f057c..9dfa57844d4 100644 --- a/optimum/bettertransformer/models/attention.py +++ b/optimum/bettertransformer/models/attention.py @@ -15,6 +15,9 @@ from typing import Optional, Tuple import torch +import torch.nn.functional as F + +from ...utils import check_if_transformers_greater # TODO (CRITICAL): Layer-wise attention scaling is broken for several archs. 
@@ -23,7 +26,7 @@ def raise_on_head_mask(head_mask: Optional[torch.Tensor]): if head_mask is not None: raise ValueError( - "layer_head_mask different than None is unsupported for now with BetterTransformer, please" + "layer_head_mask (or head_mask) different than None is unsupported for now with BetterTransformer, please" "open a PR or an issue at https://github.com/huggingface/optimum." ) @@ -534,88 +537,159 @@ def bart_forward( return attn_output, None, past_key_value -# Adapted from transformers.models.bloom.modeling_bloom.BloomAttention.forward -def bloom_forward( - self, - hidden_states: torch.Tensor, - residual: torch.Tensor, - alibi: torch.Tensor, - attention_mask: torch.Tensor, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - **kwargs, -): - raise_on_head_mask(head_mask) +if check_if_transformers_greater("4.44"): + from transformers.cache_utils import Cache + from transformers.models.bloom.modeling_bloom import dropout_add + + # Adapted from transformers.models.bloom.modeling_bloom.BloomAttention.forward + def bloom_forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Cache] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ): + raise_on_head_mask(head_mask) + + if output_attentions is True: + raise ValueError("output_attentions=True can not be supported with BetterTransformer.") + + batch_size, q_length, _ = hidden_states.shape + # [batch_size, seq_length, 3 x hidden_size] + fused_qkv = self.query_key_value(hidden_states) + # 3 x [batch_size, num_heads, seq_length, head_dim] + query_layer, key_layer, value_layer = self._reshape(fused_qkv) + + if layer_past is not None: + cache_kwargs = {"cache_position": cache_position} + key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs) + + alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) + + if attention_mask is not None: # no matter the length, we just slice it + kv_length = cache_position[-1] + 1 # cache position is 0-indexed while length should start from 1 + causal_mask = attention_mask[:, :, :, :kv_length] + alibi = torch.masked_fill(alibi, causal_mask.bool(), torch.finfo(alibi.dtype).min) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=alibi, + dropout_p=self.dropout_prob_attn if self.training else 0.0, + ) - if output_attentions is True: - raise ValueError("output_attentions=True can not be supported with BetterTransformer.") + # Transform [batch_size, num_heads, seq_length, head_dim] to [batch_size, seq_length, num_heads * head_dim] + context_layer = context_layer.transpose(1, 2) + context_layer = context_layer.reshape(batch_size, q_length, self.hidden_size) + + # aggregate results across tp ranks. 
See here: https://github.com/pytorch/pytorch/issues/76232 + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + F.linear( + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + outputs = (output_tensor, layer_past) - batch_size, q_length, _, _ = query_layer.shape + return outputs - # Permute to [batch_size, num_heads, seq_length, head_dim] - query_layer = query_layer.transpose(1, 2) +else: + # Adapted from transformers.models.bloom.modeling_bloom.BloomAttention.forward + def bloom_forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + **kwargs, + ): + raise_on_head_mask(head_mask) - if layer_past is not None: - past_key, past_value = layer_past - past_key = past_key.transpose(1, 2) + if output_attentions is True: + raise ValueError("output_attentions=True can not be supported with BetterTransformer.") - key_layer = key_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) - value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + # [batch_size, seq_length, 3 x hidden_size] + fused_qkv = self.query_key_value(hidden_states) - # concatenate along seq_length dimension - key_layer = torch.cat((past_key, key_layer), dim=1) - value_layer = torch.cat((past_value, value_layer), dim=1) + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - # untangle batch_size from self.num_heads - key_layer = key_layer.reshape(batch_size, self.num_heads, *key_layer.shape[1:]) - value_layer = value_layer.reshape(batch_size, self.num_heads, *value_layer.shape[1:]) - else: - key_layer = key_layer.transpose(1, 2) - value_layer = value_layer.transpose(1, 2) - - alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) - alibi = torch.masked_fill(alibi, attention_mask, torch.finfo(alibi.dtype).min) - - context_layer = torch.nn.functional.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attn_mask=alibi, - dropout_p=self.dropout_prob_attn if self.training else 0.0, - ) + batch_size, q_length, _, _ = query_layer.shape - # Transform [batch_size, num_heads, seq_length, head_dim] to [batch_size, seq_length, num_heads * head_dim] - context_layer = context_layer.transpose(1, 2) - context_layer = context_layer.reshape(*context_layer.shape[:2], -1) - - # aggregate results across tp ranks. 
See here: https://github.com/pytorch/pytorch/issues/76232 - if self.pretraining_tp > 1 and self.slow_but_exact: - slices = self.hidden_size / self.pretraining_tp - output_tensor = torch.zeros_like(context_layer) - for i in range(self.pretraining_tp): - output_tensor = output_tensor + torch.nn.functional.linear( - context_layer[:, :, int(i * slices) : int((i + 1) * slices)], - self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], - ) - else: - output_tensor = self.dense(context_layer) + # Permute to [batch_size, num_heads, seq_length, head_dim] + query_layer = query_layer.transpose(1, 2) + + if layer_past is not None: + past_key, past_value = layer_past + past_key = past_key.transpose(1, 2) - output_tensor = torch.nn.functional.dropout(output_tensor, p=self.hidden_dropout, training=self.training) - output_tensor = residual + output_tensor + key_layer = key_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) - if use_cache is True: - present = ( - key_layer.reshape(-1, *key_layer.shape[2:]).transpose(1, 2), - value_layer.reshape(-1, *value_layer.shape[2:]), + # concatenate along seq_length dimension + key_layer = torch.cat((past_key, key_layer), dim=1) + value_layer = torch.cat((past_value, value_layer), dim=1) + + # untangle batch_size from self.num_heads + key_layer = key_layer.reshape(batch_size, self.num_heads, *key_layer.shape[1:]) + value_layer = value_layer.reshape(batch_size, self.num_heads, *value_layer.shape[1:]) + else: + key_layer = key_layer.transpose(1, 2) + value_layer = value_layer.transpose(1, 2) + + alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) + alibi = torch.masked_fill(alibi, attention_mask, torch.finfo(alibi.dtype).min) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=alibi, + dropout_p=self.dropout_prob_attn if self.training else 0.0, ) - else: - present = None - return (output_tensor, present) + # Transform [batch_size, num_heads, seq_length, head_dim] to [batch_size, seq_length, num_heads * head_dim] + context_layer = context_layer.transpose(1, 2) + context_layer = context_layer.reshape(*context_layer.shape[:2], -1) + + # aggregate results across tp ranks. 
See here: https://github.com/pytorch/pytorch/issues/76232 + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + torch.nn.functional.linear( + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + output_tensor = torch.nn.functional.dropout(output_tensor, p=self.hidden_dropout, training=self.training) + output_tensor = residual + output_tensor + + if use_cache is True: + present = ( + key_layer.reshape(-1, *key_layer.shape[2:]).transpose(1, 2), + value_layer.reshape(-1, *value_layer.shape[2:]), + ) + else: + present = None + + return (output_tensor, present) diff --git a/optimum/bettertransformer/models/decoder_models.py b/optimum/bettertransformer/models/decoder_models.py index 4bcc057373a..b64b7f5a1eb 100644 --- a/optimum/bettertransformer/models/decoder_models.py +++ b/optimum/bettertransformer/models/decoder_models.py @@ -216,6 +216,8 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"): self.dropout_prob_attn = config.attention_dropout self.module_mapping = None + self.layer_idx = getattr(layer, "layer_idx", None) + submodules = ["query_key_value", "dense", "attention_dropout"] for attr in submodules: setattr(self, attr, getattr(layer, attr)) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 3e11c7e614a..d4b15b2968b 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -338,27 +338,31 @@ class BloomOnnxConfig(TextDecoderOnnxConfig): ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES DUMMY_PKV_GENERATOR_CLASS = BloomDummyPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head") + DEFAULT_ONNX_OPSET = 14 # Bloom uses aten::triu that requires opset>=14, and F.scaled_dot_product_attention def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): - if direction not in ["inputs", "outputs"]: - raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') - - if direction == "inputs": - decoder_sequence_name = "past_sequence_length" - name = "past_key_values" + if check_if_transformers_greater("4.44"): + super().add_past_key_values(inputs_or_outputs, direction) else: - decoder_sequence_name = "past_sequence_length + 1" - name = "present" + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') - for i in range(self._normalized_config.num_layers): - inputs_or_outputs[f"{name}.{i}.key"] = { - 0: "batch_size x num_heads", - 2: decoder_sequence_name, - } - inputs_or_outputs[f"{name}.{i}.value"] = { - 0: "batch_size x num_heads", - 1: decoder_sequence_name, - } + if direction == "inputs": + decoder_sequence_name = "past_sequence_length" + name = "past_key_values" + else: + decoder_sequence_name = "past_sequence_length + 1" + name = "present" + + for i in range(self._normalized_config.num_layers): + inputs_or_outputs[f"{name}.{i}.key"] = { + 0: "batch_size x num_heads", + 2: decoder_sequence_name, + } + inputs_or_outputs[f"{name}.{i}.value"] = { + 0: "batch_size x num_heads", + 1: decoder_sequence_name, + } class 
GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 6a0dcbba2f0..f6d4b7e20ab 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -336,8 +336,7 @@ def prepare_past_key_values( dtype = constructor.float16 if self.use_fp16 else constructor.float32 # TODO: find a way to better handle this controlflow, this is EXTREMELY UGLY. - # "1" is the dummy sequence length - if self.model_type == "bloom": + if self.__class__.__name__ == "ORTBloomForCausalLM": shape_value = (batch_size * num_attention_heads, 0, embed_size_per_head) shape_key = (batch_size * num_attention_heads, embed_size_per_head, 0) key = constructor.zeros(shape_key, dtype=dtype) @@ -354,9 +353,9 @@ def prepare_past_key_values( for name, value in zip(self.key_value_output_names, past_key_values): shape = [*value.shape] index = 1 if "value" in name else 2 - shape[index] += sequence_length pkv_output_shape[name] = shape + elif self.model_type == "gpt_bigcode": # GPT BigCode uses muti-query attention, and has the specificity of putting both key and value in the same cache tensor. shape_key_and_value = (batch_size, 0, embed_size_per_head * 2) @@ -371,9 +370,9 @@ def prepare_past_key_values( shape = [*value.shape] shape[1] += sequence_length pkv_output_shape[name] = shape + else: num_key_value_heads = self.num_key_value_heads if self.model_type == "falcon" else num_attention_heads - shape = (batch_size, num_key_value_heads, 0, embed_size_per_head) key_or_value = constructor.zeros(shape, dtype=dtype) @@ -534,9 +533,9 @@ def _from_pretrained( # Since https://github.com/huggingface/optimum/pull/871/ # changed axis notation/naming during export, we need to update the dims - for dim in input_dims.keys(): - if "past" in dim and input_dims[dim][2] == "past_sequence_length + sequence_length": - input_dims[dim][2] = "past_sequence_length" + for input_name in input_dims.keys(): + if "past" in input_name and input_dims[input_name][2] == "past_sequence_length + sequence_length": + input_dims[input_name][2] = "past_sequence_length" override_dims = True if override_dims: @@ -559,6 +558,12 @@ def _from_pretrained( size_threshold=0, ) + # Since transformers 4.44, the bloom model has been updated to use the standard cache format + use_old_bloom_modeling = not check_if_transformers_greater("4.44") + for input_name in input_dims.keys(): + if input_dims[input_name][0] == "batch_size x num_heads": + use_old_bloom_modeling = True + del onnx_model model = ORTModel.load_model( @@ -568,7 +573,7 @@ def _from_pretrained( provider_options=provider_options, ) - if config.model_type == "bloom": + if config.model_type == "bloom" and use_old_bloom_modeling: init_cls = ORTBloomForCausalLM elif config.model_type == "falcon": init_cls = ORTFalconForCausalLM diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 36913f652a8..dac14a38114 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -22,6 +22,7 @@ import numpy as np from transformers.utils import is_tf_available, is_torch_available +from ..utils import check_if_transformers_greater from .normalized_config import ( NormalizedConfig, NormalizedEncoderDecoderConfig, @@ -1026,23 +1027,26 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int class BloomDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): def generate(self, input_name: str, framework: str = "pt", 
int_dtype: str = "int64", float_dtype: str = "fp32"): - past_key_shape = ( - self.batch_size * self.num_attention_heads, - self.hidden_size // self.num_attention_heads, - self.sequence_length, - ) - past_value_shape = ( - self.batch_size * self.num_attention_heads, - self.sequence_length, - self.hidden_size // self.num_attention_heads, - ) - return [ - ( - self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype), - self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype), + if check_if_transformers_greater("4.44"): + return super().generate(input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype) + else: + past_key_shape = ( + self.batch_size * self.num_attention_heads, + self.hidden_size // self.num_attention_heads, + self.sequence_length, ) - for _ in range(self.num_layers) - ] + past_value_shape = ( + self.batch_size * self.num_attention_heads, + self.sequence_length, + self.hidden_size // self.num_attention_heads, + ) + return [ + ( + self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype), + self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype), + ) + for _ in range(self.num_layers) + ] class MultiQueryPastKeyValuesGenerator(DummyPastKeyValuesGenerator): diff --git a/optimum/utils/preprocessing/token_classification.py b/optimum/utils/preprocessing/token_classification.py index 1c59aa2285b..64a0bf2da8a 100644 --- a/optimum/utils/preprocessing/token_classification.py +++ b/optimum/utils/preprocessing/token_classification.py @@ -28,7 +28,7 @@ class TokenClassificationProcessing(TaskProcessor): ACCEPTED_PREPROCESSOR_CLASSES = (PreTrainedTokenizerBase,) - DEFAULT_DATASET_ARGS = "conll2003" + DEFAULT_DATASET_ARGS = {"path": "conll2003", "trust_remote_code": True} DEFAUL_DATASET_DATA_KEYS = {"primary": "tokens"} ALLOWED_DATA_KEY_NAMES = {"primary"} DEFAULT_REF_KEYS = ["ner_tags", "pos_tags", "chunk_tags"] diff --git a/setup.py b/setup.py index 2e8c9489a89..3ac4315321b 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ REQUIRED_PKGS = [ "coloredlogs", "sympy", - "transformers[sentencepiece]>=4.29.0,<4.44.0", + "transformers[sentencepiece]>=4.29,<4.45.0", "torch>=1.11", "packaging", "numpy<2.0", # transformers requires numpy<2.0 https://github.com/huggingface/transformers/pull/31569 @@ -24,6 +24,7 @@ ] # TODO: unpin pytest once https://github.com/huggingface/transformers/pull/29154 is merged & released +# pytest>=8.0.0 also fails with the transformers version pinned for exporters-tf TESTS_REQUIRE = [ "accelerate", "pytest<=8.0.0", @@ -72,7 +73,7 @@ "timm", "h5py", "numpy<1.24.0", - "transformers[sentencepiece]>=4.26.0,<4.38.0", + "transformers[sentencepiece]>=4.26,<4.38", ], "diffusers": ["diffusers"], "intel": "optimum-intel>=1.18.0", @@ -80,9 +81,9 @@ "nncf": "optimum-intel[nncf]>=1.18.0", "neural-compressor": "optimum-intel[neural-compressor]>=1.18.0", "ipex": "optimum-intel[ipex]>=1.18.0", - "habana": ["optimum-habana", "transformers >= 4.43.0, < 4.44.0"], - "neuron": ["optimum-neuron[neuron]>=0.0.20", "transformers >= 4.36.2, < 4.42.0"], - "neuronx": ["optimum-neuron[neuronx]>=0.0.20", "transformers >= 4.36.2, < 4.42.0"], + "habana": ["optimum-habana", "transformers>=4.43.0,<4.44.0"], + "neuron": ["optimum-neuron[neuron]>=0.0.20", "transformers>=4.36.2,<4.42.0"], + "neuronx": ["optimum-neuron[neuronx]>=0.0.20", "transformers>=4.36.2,<4.42.0"], "graphcore": "optimum-graphcore", "furiosa": "optimum-furiosa", "amd": "optimum-amd", diff --git 
a/tests/bettertransformer/test_decoder.py b/tests/bettertransformer/test_decoder.py index 42340d3b3aa..bab8f376fcc 100644 --- a/tests/bettertransformer/test_decoder.py +++ b/tests/bettertransformer/test_decoder.py @@ -23,7 +23,6 @@ from optimum.bettertransformer import BetterTransformer from optimum.utils import ( - BloomDummyPastKeyValuesGenerator, DummyPastKeyValuesGenerator, NormalizedConfigManager, ) @@ -136,10 +135,7 @@ def test_logits_with_cache(self, test_name: str, model_type: str, batch_size: in normalized_config = NormalizedConfigManager.get_normalized_config_class(model.config.model_type)(model.config) - if model_type == "bloom": - pkv_generator_class = BloomDummyPastKeyValuesGenerator - else: - pkv_generator_class = DummyPastKeyValuesGenerator + pkv_generator_class = DummyPastKeyValuesGenerator pkv_generator = pkv_generator_class( task="", normalized_config=normalized_config, batch_size=batch_size, sequence_length=seq_length From ad98dc944be4308f405ab34e78fa85b16c7d3709 Mon Sep 17 00:00:00 2001 From: Longjie Zheng <32992656+zhenglongjiepheonix@users.noreply.github.com> Date: Mon, 2 Sep 2024 10:40:14 -0400 Subject: [PATCH 3/5] Modify Parallelization Strategy to Make it More General (#1988) * modify parallelization strategy * only support model id in api now * more comments * more comments * address comments * remove idle runner * fix * format * more comments * nit --- .../workflows/test_fx_automatic_parallel.yml | 2 +- optimum/fx/parallelization/api.py | 87 ++-- optimum/fx/parallelization/core.py | 5 + optimum/fx/parallelization/decomp.py | 225 +++++++++ .../parallelization/op_registry/__init__.py | 15 + .../op_registry/op_handlers.py | 450 ++++++++++++++++++ optimum/fx/parallelization/passes.py | 350 +++++--------- optimum/fx/parallelization/utils.py | 29 +- 8 files changed, 878 insertions(+), 285 deletions(-) create mode 100644 optimum/fx/parallelization/decomp.py create mode 100644 optimum/fx/parallelization/op_registry/__init__.py create mode 100644 optimum/fx/parallelization/op_registry/op_handlers.py diff --git a/.github/workflows/test_fx_automatic_parallel.yml b/.github/workflows/test_fx_automatic_parallel.yml index 3c913e3f7ed..d8af6e40caa 100644 --- a/.github/workflows/test_fx_automatic_parallel.yml +++ b/.github/workflows/test_fx_automatic_parallel.yml @@ -24,7 +24,7 @@ jobs: config: - name: GPU-enabled Optimum Test Suite image: nvidia/cuda:12.4.1-devel-ubuntu22.04 - gpu_target: ["nvidia-multi-gpu-l4-runners", "nvidia-multi-gpu-a10-runners"] + gpu_target: ["nvidia-multi-gpu-a10-runners"] name: ${{ matrix.config.name }} runs-on: diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index bd307bd93c1..9700b491e52 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -15,10 +15,11 @@ import importlib import os from functools import partial -from typing import List, Union +from typing import Callable, List import torch from torch.fx import GraphModule +from transformers import AutoConfig from .core import Config, ParallelExecutionCtx from .passes import build_parallel_pass_pipeline @@ -43,30 +44,31 @@ def parallelize_backend( def parallelize_model( - model: Union[torch.nn.Module, str], + model: str, parallel_ctx: ParallelExecutionCtx, *model_args, **kwargs, -): +) -> Callable: """ API for automatic model parallelism through Pytorch FX. Args: - model (Union[torch.nn.Module, str]): - Model to parallelize, could either be a module or a model id on the Huggingface Hub. 
- parallel_ctx (ParallelExecutionCtx): + model (`str`): + Model to parallelize, a model id on the Huggingface Hub or path to a local directory containing config and weights + of the model. + parallel_ctx (`ParallelExecutionCtx`): Parallel execution context containing process groups the current process belongs to. - *model_args (Any): + *model_args (`Any`): Additional postional arguments for intializing the model if a model id is passed. - revision (str, defaults to `main`): + revision (`str`, defaults to `main`): Model revision for weights downloading if a model id is passed. - cache_dir (Optional[str], defaults to `None`): + cache_dir (`Optional[str]`, defaults to `None`): Cache directory to store downloaded weights. Defaults to None. - local_files_only (bool, defaults to `False`): + local_files_only (`bool`, defaults to `False`): Whether to use local files only, will avoid downloading from remote if set to `True`. - skip_load_weights (bool, defaults to `False`): + skip_load_weights (`bool`, defaults to `False`): Whether to skip loading weights from disk to model. - **kwargs (Dict[str, Any]): + **kwargs (`Dict[str, Any]`): Addtional keyword arguments for overriding fields in parallel config, model config and `Model.__init__`. """ revision = kwargs.pop("revision", "main") @@ -80,44 +82,41 @@ def parallelize_model( setattr(parallel_config, k, v) kwargs.pop(k) - if isinstance(model, str): - from transformers import AutoConfig - - is_local = os.path.isdir(model) - if not is_local: - hf_folder = download_model_from_hf( - model_name_or_path=model, - cache_dir=cache_dir, - revision=revision, - local_files_only=local_files_only, - skip_download_weights=skip_load_weights, - ) - else: - hf_folder = model - - # should be able to load config using only local files - model_config, kwargs = AutoConfig.from_pretrained( - hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs + is_local = os.path.isdir(model) + if not is_local: + hf_folder = download_model_from_hf( + model_name_or_path=model, + cache_dir=cache_dir, + revision=revision, + local_files_only=local_files_only, + skip_download_weights=skip_load_weights, ) + else: + hf_folder = model - # try getting model class info from config - model_arch = model_config.architectures - model_cls = getattr(importlib.import_module("transformers"), model_arch[0]) + # should be able to load config using only local files + model_config, kwargs = AutoConfig.from_pretrained( + hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs + ) - if not skip_load_weights: - parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder) + # try getting model class info from config + model_arch = model_config.architectures + model_cls = getattr(importlib.import_module("transformers"), model_arch[0]) - torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None - if torch_dtype is not None: - dtype_orig = model_cls._set_default_torch_dtype(torch_dtype) + if not skip_load_weights: + parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder) - with MetaAwareMethodsPatcher(): - model = model_cls(model_config, *model_args, **kwargs) - # TODO: remove this once support training-time trace - model.eval() + torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None + if torch_dtype is not None: + dtype_orig = model_cls._set_default_torch_dtype(torch_dtype) - if dtype_orig is not None: - torch.set_default_dtype(dtype_orig) + with MetaAwareMethodsPatcher(): + model = 
model_cls(model_config, *model_args, **kwargs) + # TODO: remove this once support training-time trace + model.eval() + + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) move_model_to_device(model, device=parallel_ctx.current_device) initialize_parameter_meta(model) diff --git a/optimum/fx/parallelization/core.py b/optimum/fx/parallelization/core.py index 1d13b00b468..84737292f07 100644 --- a/optimum/fx/parallelization/core.py +++ b/optimum/fx/parallelization/core.py @@ -166,8 +166,13 @@ class Config: - weight_init_fn (`Callable`, defaults to `partial(nn.init.normal_, std=0.02)`) Initialization function of weights in `nn.Linear` and `nn.Embedding` layers, if not provided weights loading path. + + - enable_sequence_parallel (`bool`, defaults to `False`): + Whether to enable Megatron-style sequence parallelism in searching parallelization + strategies. """ lint_and_recompile: bool = True clean_markers_after_all_passes: bool = True weight_init_fn: Callable = partial(nn.init.normal_, std=0.02) + enable_sequence_parallel: bool = False diff --git a/optimum/fx/parallelization/decomp.py b/optimum/fx/parallelization/decomp.py new file mode 100644 index 00000000000..26258d451bf --- /dev/null +++ b/optimum/fx/parallelization/decomp.py @@ -0,0 +1,225 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +from typing import Callable, Dict, List + +import torch +import torch.nn.functional as F +import torch.utils._pytree as pytree +from torch import SymBool, SymFloat, SymInt +from torch._decomp import core_aten_decompositions +from torch._functorch._aot_autograd.functional_utils import from_fun, to_fun +from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode, disable_functional_mode +from torch.fx import Graph, GraphModule, Interpreter, Proxy, traceback +from torch.fx.experimental.proxy_tensor import ( + ProxyTorchDispatchMode, + _ProxyTensor, + _SymNodeDict, + decompose, + disable_proxy_modes_tracing, + fetch_object_proxy, + fetch_sym_proxy, + get_proxy_slot, + track_tensor_tree, +) +from torch.fx.proxy import GraphAppendingTracer +from torch.utils.weak import WeakTensorKeyDictionary + + +def is_leaf_module(m): + return (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn")) and not isinstance( + m, torch.nn.Sequential + ) + + +@contextlib.contextmanager +def trace_decomp_origin(): + creat_node = Graph.create_node + + def create_node_(*args, **kwargs): + node = creat_node(*args, **kwargs) + node.meta["traced_from"] = traceback.get_current_meta()["from_node"] + return node + + try: + Graph.create_node = create_node_ + yield + finally: + Graph.create_node = creat_node + + +class DecompTracer(GraphAppendingTracer): + """ + DecompTracer is a tracer class which works together with `DecompositionInterpreter`, it keeps track of tensors and their + corresponding proxy objects during execution process. 
When invoked with `create_proxy`, it creates a node in the containing + graph and associate the output tensor of the node with the created proxy. + + See https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/proxy_tensor.py for more details. + """ + + def __init__(self, graph: Graph): + super().__init__(graph) + self.tensor_tracker = WeakTensorKeyDictionary() + self.symnode_tracker = _SymNodeDict() + + +class DecompositionInterpreter(Interpreter): + """ + DecompositionInterpreter takes the high-level graph module, run the iternal nodes following the topo order, and decompose + high-level pytorch operators into core aten operators by utilizing torch dispatch infrastructure along the way. + + Notes: + - Certain primitive layers(like `nn.Linear`, `nn.Embedding`, and activation layers) are preserved because we have specific + heuristic based parallelization strategy for them so that we can conveniently replace them into their parallelized counterparts + in the orignal graph module. + + - The traced graph is a low-level equivalent representation of the original graph module, and is only used for + parallel axis propagation and analysis, the original graph module is still used for real execution. + """ + + def __init__( + self, module: GraphModule, new_graph: Graph, decomposition_table=None, leaf_function_targets=None, **kwargs + ): + super().__init__(module, **kwargs) + self.new_graph = new_graph + self.tracer = DecompTracer(new_graph) + + self.decomposition_table = decomposition_table + if self.decomposition_table is None: + self.decomposition_table = {} + + self.leaf_function_targets = leaf_function_targets + if self.leaf_function_targets is None: + self.leaf_function_targets = [] + + self.fun_mode = FunctionalTensorMode() + self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real") + + def placeholder(self, target, args, kwargs): + out = super().placeholder(target, args, kwargs) + out = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), out) + proxy = self.tracer.create_proxy("placeholder", target, args, kwargs) + + with disable_proxy_modes_tracing(): + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + + out = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), out) + return out + + def call_function(self, target, args, kwargs): + if target in self.leaf_function_targets: + args = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), args) + kwargs = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), kwargs) + + with disable_proxy_modes_tracing(), disable_functional_mode(): + out = target(*args, **kwargs) + + args, kwargs = pytree.tree_map_only((torch.Tensor,), fetch_object_proxy(self.tracer), (args, kwargs)) + proxy_args, proxy_kwargs = pytree.tree_map_only( + (SymInt, SymFloat, SymBool), + fetch_sym_proxy(self.tracer), + pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (args, kwargs)), + ) + proxy = self.tracer.create_proxy("call_function", target, proxy_args, proxy_kwargs) + + with disable_proxy_modes_tracing(): + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + + out = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), out) + return out + + return super().call_function(target, args, kwargs) + + def call_module(self, target, args, kwargs): + assert isinstance(target, str) + submod = self.fetch_attr(target) + if not is_leaf_module(submod): + return super().call_module(target, args, kwargs) + + args = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), args) + kwargs = 
pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), kwargs) + + with disable_proxy_modes_tracing(), disable_functional_mode(): + out = submod(*args, **kwargs) + + args, kwargs = pytree.tree_map_only((torch.Tensor,), fetch_object_proxy(self.tracer), (args, kwargs)) + proxy_args, proxy_kwargs = pytree.tree_map_only( + (SymInt, SymFloat, SymBool), + fetch_sym_proxy(self.tracer), + pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (args, kwargs)), + ) + proxy = self.tracer.create_proxy("call_module", target, proxy_args, proxy_kwargs) + + with disable_proxy_modes_tracing(): + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + + out = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), out) + return out + + def get_attr(self, target, args, kwargs): + out = super().get_attr(target, args, kwargs) + proxy = Proxy(self.new_graph.get_attr(target), self.tracer) + with disable_proxy_modes_tracing(): + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + return out + + def output(self, target, args, kwargs): + args = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), args) + kwargs = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), kwargs) + out = super().output(target, args, kwargs) + + def unwrap(e): + return get_proxy_slot(e, self.tracer, e, lambda x: x.proxy.node) + + self.new_graph.output(pytree.tree_map(unwrap, out)) + return out + + def run(self, *args, **kwargs): + with self.fun_mode: + args = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), args) + kwargs = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), kwargs) + with traceback.preserve_node_meta(), trace_decomp_origin(), decompose(self.decomposition_table), self.mode: + return super().run(*args, **kwargs) + + +def decompose_and_functionalize( + graph_module: GraphModule, + decomposition_table: Dict[torch._ops.OperatorBase, Callable] = core_aten_decompositions(), + leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention], +) -> Callable: + """ + API to decompose and functionalize a high-level graph module. + + Args: + graph_module (`GraphModule`): + The high-level graph module to be decomposed and functionalized. + decomposition_table (`Dict[torch._ops.OperatorBase, Callable]`, defaults to `core_aten_decompostions()`): + The lookup table which maps high-level torch op to their equivalent low-level implementation. + leaf_function_targets (`List[Callable]`, defaults to `[F.scaled_dot_product_attention]`): + Functions which will not be traced through for convenience, `F.scaled_dot_product_attention` is + treated as a leaf function by default so that we don't have to deal with all detailed version of + sdpas in the traced graph. + + Returns: + Callable: a wrapper which returns the traced low-level graph when called with concrete arguments. + """ + new_graph = Graph(owning_module=graph_module) + interp = DecompositionInterpreter(graph_module, new_graph, decomposition_table, leaf_function_targets) + + def wrapper(*args, **kwargs): + interp.run(*args, **kwargs) + return new_graph + + return wrapper diff --git a/optimum/fx/parallelization/op_registry/__init__.py b/optimum/fx/parallelization/op_registry/__init__.py new file mode 100644 index 00000000000..8f8df0f7bd0 --- /dev/null +++ b/optimum/fx/parallelization/op_registry/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .op_handlers import REGISTRY, FallbackParallelAxisPropagateHandler diff --git a/optimum/fx/parallelization/op_registry/op_handlers.py b/optimum/fx/parallelization/op_registry/op_handlers.py new file mode 100644 index 00000000000..56b8fc16bc0 --- /dev/null +++ b/optimum/fx/parallelization/op_registry/op_handlers.py @@ -0,0 +1,450 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import abstractmethod +from typing import Any, List, Optional + +import torch +from torch.fx import Node + +from ..core import Config +from ..utils import is_activation, is_embedding, is_linear + + +class Registry: + """ + Registry class handles registration of parallel axis propagation handlers of different aten ops. + To support a new aten op, you need to register the corresponding handler class by decorating it with `register` function. 
+ """ + + def __init__(self) -> None: + self.mapping = {} + + def register(self, op_types): + def wrapper(cls): + if isinstance(op_types, (list, tuple)): + for op_type in op_types: + self.mapping[op_type] = cls + else: + self.mapping[op_types] = cls + return cls + + return wrapper + + def is_supported(self, op_type) -> bool: + return op_type in self.mapping + + +REGISTRY = Registry() + + +class OpParallelAxisPropagateHandler: + def __init__(self, node: Node, meta_key: str, config: Config) -> None: + self.node = node + self.meta_key = meta_key + self.config = config + + def extract_axis(self, arg: Any) -> Optional[int]: + if not isinstance(arg, Node): + return None + return arg.meta[self.meta_key].get("parallel_axis", None) + + @abstractmethod + def propagate(self) -> List[int]: + raise NotImplementedError + + +@REGISTRY.register( + [ + torch.ops.aten.pow.Tensor_Scalar, + torch.ops.aten.rsqrt.default, + torch.ops.aten.clone.default, + torch.ops.aten.bitwise_not.default, + torch.ops.aten.abs.default, + torch.ops.aten._to_copy.default, + torch.ops.aten.acos.default, + torch.ops.aten.acosh.default, + torch.ops.aten.alias.default, + torch.ops.aten.asin.default, + torch.ops.aten.asinh.default, + torch.ops.aten.atan.default, + torch.ops.aten.atanh.default, + torch.ops.aten.ceil.default, + torch.ops.aten.clamp.default, + torch.ops.aten.cos.default, + torch.ops.aten.cosh.default, + torch.ops.aten.erf.default, + torch.ops.aten.exp.default, + torch.ops.aten.trunc.default, + torch.ops.aten.tanh.default, + torch.ops.aten.tan.default, + torch.ops.aten.add.Scalar, + torch.ops.aten.sub.Scalar, + torch.ops.aten.sqrt.default, + torch.ops.aten.sin.default, + torch.ops.aten.sinh.default, + torch.ops.aten.sign.default, + torch.ops.aten.sigmoid.default, + torch.ops.aten.round.default, + torch.ops.aten.remainder.Scalar, + torch.ops.aten.relu.default, + torch.ops.aten.reciprocal.default, + torch.ops.aten.neg.default, + torch.ops.aten.ne.Scalar, + torch.ops.aten.native_dropout.default, + torch.ops.aten.mul.Scalar, + torch.ops.aten.logical_not.default, + torch.ops.aten.lt.Scalar, + torch.ops.aten.le.Scalar, + torch.ops.aten.log.default, + torch.ops.aten.log10.default, + torch.ops.aten.log2.default, + torch.ops.aten.log1p.default, + torch.ops.aten.leaky_relu.default, + torch.ops.aten.isnan.default, + torch.ops.aten.isinf.default, + torch.ops.aten.hardtanh.default, + torch.ops.aten.gt.Scalar, + torch.ops.aten.gelu.default, + torch.ops.aten.ge.Scalar, + torch.ops.aten.fmod.Scalar, + torch.ops.aten.floor.default, + torch.ops.aten.fill.Scalar, + torch.ops.aten.div.Scalar_mode, + torch.ops.aten.div.Scalar, + torch.ops.aten.bitwise_and.Scalar, + torch.ops.aten.bitwise_or.Scalar, + torch.ops.aten.bitwise_xor.Scalar, + ] +) +class UnaryOpParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + arg = self.node.all_input_nodes[0] + axis = self.extract_axis(arg) + return [axis] + + +@REGISTRY.register( + [ + torch.ops.aten.atan2.default, + torch.ops.aten.add.Tensor, + torch.ops.aten.bitwise_and.Tensor, + torch.ops.aten.bitwise_or.Tensor, + torch.ops.aten.bitwise_xor.Tensor, + torch.ops.aten.div.Tensor, + torch.ops.aten.div.Tensor_mode, + torch.ops.aten.eq.Tensor, + torch.ops.aten.fmod.Tensor, + torch.ops.aten.ge.Tensor, + torch.ops.aten.gt.Tensor, + torch.ops.aten.le.Tensor, + torch.ops.aten.logical_and.default, + torch.ops.aten.logical_or.default, + torch.ops.aten.logical_xor.default, + torch.ops.aten.lt.Tensor, + torch.ops.aten.maximum.default, + torch.ops.aten.minimum.default, + 
torch.ops.aten.mul.Tensor,
+        torch.ops.aten.ne.Tensor,
+        torch.ops.aten.pow.Tensor_Tensor,
+        torch.ops.aten.remainder.Tensor,
+        torch.ops.aten.sub.Tensor,
+    ]
+)
+class BinaryOpParallelAxisPropagateHandler(OpParallelAxisPropagateHandler):
+    def propagate(self) -> List[int]:
+        input_nodes = self.node.all_input_nodes
+        # only one node
+        if len(input_nodes) == 1:
+            return UnaryOpParallelAxisPropagateHandler(self.node, self.meta_key, self.config).propagate()
+
+        assert len(input_nodes) == 2, "binary op should have exactly two nodes as inputs"
+        lhs_shape, rhs_shape = input_nodes[0].meta["val"].shape, input_nodes[1].meta["val"].shape
+        lhs_axis = self.extract_axis(input_nodes[0])
+        rhs_axis = self.extract_axis(input_nodes[1])
+        i, j = len(lhs_shape) - 1, len(rhs_shape) - 1
+        while i >= 0 and j >= 0:
+            k = max(lhs_shape[i], rhs_shape[j])
+            assert (
+                k % min(lhs_shape[i], rhs_shape[j]) == 0
+            ), f"shape {lhs_shape} and {rhs_shape} are not broadcastable!"
+            i -= 1
+            j -= 1
+
+        if i < 0 and lhs_axis is not None:
+            lhs_axis += j + 1
+        if j < 0 and rhs_axis is not None:
+            rhs_axis += i + 1
+
+        if lhs_axis is None:
+            return [rhs_axis]
+        elif rhs_axis is None:
+            return [lhs_axis]
+        elif lhs_axis != rhs_axis:
+            return []
+        return [lhs_axis]
+
+
+@REGISTRY.register(
+    [
+        torch.ops.aten.amax.default,
+        torch.ops.aten.amin.default,
+        torch.ops.aten.any.dim,
+        torch.ops.aten._log_softmax.default,
+        torch.ops.aten._softmax.default,
+        torch.ops.aten.cumsum.default,
+        torch.ops.aten.mean.dim,
+        # torch.ops.aten.min.dim,
+        # torch.ops.aten.max.dim,
+        torch.ops.aten.var.dim,
+        torch.ops.aten.sum.dim_IntList,
+        torch.ops.aten.prod.dim_int,
+    ]
+)
+class ReductionOpParallelAxisPropagateHandler(OpParallelAxisPropagateHandler):
+    def extract_dims(
+        self,
+    ) -> List[int]:
+        ndim = self.node.meta["val"].ndim
+        dims = None
+        if "dim" in self.node.kwargs:
+            dims = self.node.kwargs["dim"]
+        elif len(self.node.args) > 1 and isinstance(self.node.args[1], (int, list)):
+            dims = self.node.args[1]
+
+        if isinstance(dims, int):
+            dims = [dims]
+        if not dims:
+            dims = list(range(ndim))
+        dims = [(dim + ndim) % ndim for dim in dims]
+
+        keepdim = False
+        if "keepdim" in self.node.kwargs:
+            keepdim = self.node.kwargs["keepdim"]
+        elif len(self.node.args) > 2 and isinstance(self.node.args[2], bool):
+            keepdim = self.node.args[2]
+
+        return dims, keepdim
+
+    def propagate(self) -> List[int]:
+        dims, keepdim = self.extract_dims()
+        arg = self.node.all_input_nodes[0]
+        axis = self.extract_axis(arg)
+        if axis in dims:
+            return []
+        if axis is None:
+            return [None]
+        if keepdim:
+            return [axis]
+        return [axis - sum([1 if dim < axis else 0 for dim in dims])]
+
+
+@REGISTRY.register(torch.ops.aten.view.default)
+class ViewLikeOpParallelAxisPropagateHandler(OpParallelAxisPropagateHandler):
+    def propagate(self) -> List[int]:
+        arg = self.node.args[0]
+        axis = self.extract_axis(arg)
+        if axis is None:
+            return [None]
+        shape_before, shape_after = arg.meta["val"].shape, self.node.meta["val"].shape
+        size = 1
+        for i in range(len(shape_before) - 1, axis - 1, -1):
+            size *= shape_before[i]
+
+        cur, i, res = 1, len(shape_after) - 1, []
+        while cur <= size and i >= 0:
+            cur *= shape_after[i]
+            if cur == size:
+                res.append(i)
+            i -= 1
+
+        return res
+
+
+@REGISTRY.register(torch.ops.aten.unsqueeze.default)
+class UnsqueezeParallelAxisPropagateHandler(OpParallelAxisPropagateHandler):
+    def propagate(self) -> List[int]:
+        arg, dim = self.node.args[0], self.node.args[1]
+        ndim = arg.meta["val"].ndim
+        axis = self.extract_axis(arg)
+        if axis is None:
+            
return [None]
+        dim = (dim + ndim) % ndim
+        if dim <= axis:
+            return [axis + 1]
+        return [axis]
+
+
+@REGISTRY.register(
+    [
+        torch.ops.aten.squeeze.dim,
+        torch.ops.aten.squeeze.dims,
+    ]
+)
+class SqueezeParallelAxisPropagateHandler(OpParallelAxisPropagateHandler):
+    def propagate(self) -> List[int]:
+        arg, dims = self.node.args[0], self.node.args[1]
+        axis = self.extract_axis(arg)
+        if axis is None:
+            return [None]
+
+        ndim = self.node.args[0].meta["val"].ndim
+        if isinstance(dims, int):
+            dims = [dims]
+        dims = [(dim + ndim) % ndim for dim in dims]
+        if axis in dims:
+            # being conservative
+            return []
+        return [axis - sum([1 if dim < axis else 0 for dim in dims])]
+
+
+@REGISTRY.register(torch.ops.aten.permute.default)
+class PermuteParallelAxisPropagateHandler(OpParallelAxisPropagateHandler):
+    def propagate(self) -> List[int]:
+        arg, dims = self.node.args[0], self.node.args[1]
+        ndim = arg.meta["val"].ndim
+        axis = self.extract_axis(arg)
+        if axis is None:
+            return [None]
+
+        for i, dim in enumerate(dims):
+            if (dim + ndim) % ndim == axis:
+                return [i]
+        return []
+
+
+@REGISTRY.register(torch.ops.aten.slice.Tensor)
+class SliceParallelAxisPropagateHandler(OpParallelAxisPropagateHandler):
+    def propagate(self) -> List[int]:
+        arg, slice_dim = self.node.args[0], self.node.args[1]
+        axis = self.extract_axis(arg)
+        if axis is None:
+            return [None]
+        ndim = arg.meta["val"].ndim
+        slice_dim = (slice_dim + ndim) % ndim
+        if slice_dim == axis:
+            # slice on the parallel axis is not allowed
+            return []
+        return [axis]
+
+
+@REGISTRY.register(torch.ops.aten.expand.default)
+class ExpandParallelAxisPropagateHandler(OpParallelAxisPropagateHandler):
+    def propagate(self) -> List[int]:
+        arg, size = self.node.args[0], self.node.args[1]
+        axis = self.extract_axis(arg)
+        if axis is None:
+            return [None]
+        assert len(size) >= arg.meta["val"].ndim, "input size must be broadcastable to the target size in expand"
+        return [axis + len(size) - arg.meta["val"].ndim]
+
+
+@REGISTRY.register(torch.ops.aten.cat.default)
+class CatParallelAxisPropagateHandler(OpParallelAxisPropagateHandler):
+    def propagate(self) -> List[int]:
+        nodes, cat_axis = self.node.all_input_nodes, self.node.args[1]
+        axis, ndim = self.extract_axis(nodes[0]), nodes[0].meta["val"].ndim
+        cat_axis = (cat_axis + ndim) % ndim
+        if cat_axis == axis:
+            return []
+        for i in range(1, len(nodes)):
+            if self.extract_axis(nodes[i]) != axis:
+                return []
+        return [axis]
+
+
+@REGISTRY.register(torch.ops.aten.constant_pad_nd.default)
+class PadParallelAxisPropagateHandler(OpParallelAxisPropagateHandler):
+    def propagate(self) -> List[int]:
+        pad, ndim = self.node.args[1], self.node.args[0].meta["val"].ndim
+        axis = self.extract_axis(self.node.args[0])
+        if axis is None:
+            return [None]
+        if axis >= ndim - len(pad) // 2:
+            return []
+        return [axis]
+
+
+@REGISTRY.register(torch.ops.aten.copy.default)
+class CopyParallelAxisPropagateHandler(OpParallelAxisPropagateHandler):
+    def propagate(self) -> List[int]:
+        dst, src = self.node.all_input_nodes
+        axis_dst = self.extract_axis(dst)
+        axis_src = self.extract_axis(src)
+        if axis_dst != axis_src:
+            return []
+        return [axis_dst]
+
+
+@REGISTRY.register(torch.nn.functional.scaled_dot_product_attention)
+class SpdaAttnParallelAxisPropagateHandler(OpParallelAxisPropagateHandler):
+    def propagate(self) -> List[int]:
+        q, k, v = self.node.args[:3]
+        q_axis = self.extract_axis(q)
+        # parallel axis must be the head dimension if being parallelized
+        if q_axis != self.extract_axis(k) or q_axis != self.extract_axis(v) 
or q_axis not in {None, 1}: + return [] + return [q_axis] + + +class FallbackParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + # by default we don't parallelize inputs and constants(except parameters embeded in modules) + if self.node.op in ["placeholder", "get_attr"]: + return [None] + elif self.node.op == "output": + for node in self.node.all_input_nodes: + # TODO: allow parallelized nodes in output, and append comm ops in graph tp all-gather + # parallelized output if intructed + if self.extract_axis(node) is not None: + return [] + return [None] + elif is_linear(self.node): + input_arg = self.node.all_input_nodes[0] + axis = self.extract_axis(input_arg) + if axis is None: + # with input being not parallelized, output can be parallelized on the head dimension, + # i.e., `ColumnLinear`, or not being parallelized by all-gather at the end + return [2, None] + elif self.config.enable_sequence_parallel and axis == 1: + # with input being parallelized on sequence dimension, output can be parallelized on + # the head dimension, i.e., `ColumnLinear` with sequence parallel, or not being parallelized + # by all-gather at the end + return [2, None] + elif axis == 2: + # with input being parallelized on head dimension, output can be parallelized on the + # sequence dimension or not parallelized by all-reduce at the end, i.e., `RowLinear` + # when sp is not enabled + return [1, None] if self.config.enable_sequence_parallel else [None] + else: + return [] + elif is_embedding(self.node): + input_arg = self.node.all_input_nodes[0] + axis = self.extract_axis(input_arg) + if axis is None: + # only support the embedding parameter being parallelized on `vocab` dim or not parallelized for now, + # the output can be parallelized on sequence dim or not parallelized + return [1, None] if self.config.enable_sequence_parallel else [None] + else: + return [] + elif is_activation(self.node): + return UnaryOpParallelAxisPropagateHandler(self.node, self.meta_key, self.config).propagate() + + # last resort, if no input is being parallelized, then we make output also not parallelized, + # this will give us relief on writing policies for strange ops which don't actually need + # parallelization in most cases + if all(self.extract_axis(arg) is None for arg in self.node.all_input_nodes): + return [None] + + raise NotImplementedError(f"don't know how to propagate axis for {self.node.target}") diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index 379b027d400..14b652fff73 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -23,15 +23,14 @@ from torch.fx import Graph, GraphModule, Node from .core import Config, ParallelExecutionCtx, ParameterMeta +from .decomp import decompose_and_functionalize from .distributed import scatter +from .op_registry import REGISTRY, FallbackParallelAxisPropagateHandler from .parallel_layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding from .utils import ( is_embedding, is_linear, - is_permute, is_shape_consumer, - is_shape_generator, - is_transpose, stable_topological_sort, ) @@ -135,238 +134,151 @@ def clean_all(self, graph_module: GraphModule) -> None: self.clear_marker_per_node(node) -class ParallelLayerAnnotatePass(AnalyzeBase): +class ParallelAxisSolverPass(AnalyzeBase): """ - A pass which tries to automatically identify parallel layers in the graph. 
Note that for simplicity
-    we only consider classical ways of parallelizing layers in transformers architecture for now, we are not
-    solving an optimization problem which tries to give a best solution of parallelizing any model under
-    memory/hardware constraints.
-
-    For `nn.Embedding` layers, we parallelize them on the vocabulary dim by default, because they are often tied
-    to the `lm_head` of the model, which is usually a `ColumnLinear`(parallelized on vocab dim).
-
-    For `nn.Linear` layers, we parallelize them by grouping them as `upstream` nodes and `downstream` nodes, and
-    `upstream` nodes are marked as `ColumnLinear`, `downstream` nodes are marked as `RowLinear`.
-
-    Typical examples in transformer models:
-
-    Attention Bert-style MLP Llama-style MLP
-    __________________________________________________________________________
-    Linear Linear Linear Linear
-    \\ / | \\ --> upstream
-    Matmul Linear Activation Activation Linear
-    __________________________________________________________________________
-    \\ / | \\ /
-    \\ / ___________ \\ /
-    Matmul / Linear \ Mul
-    | / \ |
-    _______________________________/ \___________________________
-    Linear Linear --> downstream
-
-    Note that there are some patterns that can not be clearly marked, like this one:
-
-    Linear
-    | \\
-    | Linear <-- which label should we mark for the intermediate linear, `upstream` or `downstream`
-    | /
-    Add
-    |
-    Linear
-
-    For patterns like this we will be conservative and raise errors directly because we don't know how to parallelize
-    it. Another concern is about the correctness, it's possible that we might end up with a wrong parallelization solution
-    even if the pattern itself is clear, but for now we are mainly targeting on transformer models and the current solution
-    should work fairly well.
+    A pass which tries to automatically identify parallel layers in the graph. There are three steps
+    involved to find a possible parallel solution given the traced graph module and process group.
+
+    - Decomposition & Functionalization
+    The vanilla graph traced by the dynamo frontend is a high-level graph which contains high-level
+    pytorch ops, and there could be thousands of them, which makes it hard for graph analysis
+    to cover all cases. So we decompose the high-level graph into a low-level graph which only
+    contains core aten ops, which is a much smaller set. And functionalization is also needed
+    to remove inplace ops in the graph so that we get `aten.Add` instead of `aten.Add_` in the
+    graph, which further reduces the op set we need to consider.
+
+    - Parallel Axis Propagation
+    We need to write parallel axis propagation rules for aten ops in the decomposed and functionalized
+    graph. Note that we don't need to cover every possible parallelization strategy, because in general
+    only certain ops (usually involving computation) can be parallelized in transformer models. And we just
+    need to write rules for a subset of the core aten op set in order to support most of the transformer models.
+
+    - Backtracking Search
+    After we have defined parallel axis propagation rules for each op in the graph, we do a brute force
+    backtracking search to try to find a possible solution which respects the propagation rule of every
+    op in the graph.
+
+
+    Note that there are several practical concerns:
+
+    - Time Complexity. Although brute force backtracking introduces an exponential time complexity, we reduce
+    the search space by injecting human heuristics. First, we only consider parallelization on the head dimension
+    (for tensor parallel) or the sequence dimension (to support sequence parallel), so at any time the tensor is
+    parallelized on at most one dimension. Second, we only allow axis switches around certain layers (like `nn.Linear`
+    or `nn.Embedding`), and all other ops fall into their places by the parallel axis of their input and the rules we write.
+
+    - Optimal Solution. Note that since we return the first solution we find, it might not be optimal in terms of
+    memory consumption and communication overhead. But again we can adjust the order of search and try to parallelize
+    as much as we can first before falling back to non-parallelized search paths. And we don't pay too much attention
+    to calculating communication overhead, because in practice it is bounded under the constraint that only certain
+    layers are allowed to communicate.
+
+    Our goal is not to solve an optimization problem which tries to give a best solution of parallelizing any model under memory/hardware
+    constraints, but rather a cheap solution which relieves you from writing boilerplate code for parallelizing layers of different models.
     """
 
-    def try_form_parallel_linear_groups(self, linear: Node) -> None:
-        """
-        We try to form linears by forming closures in a greedy way, we start with an unmarked linear node, and traverses down
-        recusively to find all the potential `downstream` linears, note that once we have reached a linear, the recursion stops.
-        And the newly found `downstream` linears are used as new seeds to traverse upwards to find all the potential `upstream`
-        linears, the process goes on until number of linears on both sides converges.
-        Args:
-            linear (Node): the first linear node used as `upstream` node seed to form closure.
-
-        Raises:
-            RuntimeError:
-                raises runtime error when the pattern itself is not clear, there are no clear boundaries that can be drawn.
-        """
-        upstream_nodes, downstream_nodes = {linear}, set()
-
-        seeds, next_seeds = [(linear, "down")], []
-
-        def traverse(start: Node, cur: Node, direction: str = "down"):
-            if is_linear(cur) and cur is not start:
-                if direction == "up" and cur not in upstream_nodes:
-                    upstream_nodes.add(cur)
-                    next_seeds.append((cur, "down"))
-                elif direction == "down" and cur not in downstream_nodes:
-                    downstream_nodes.add(cur)
-                    next_seeds.append((cur, "up"))
-                return
-
-            next_nodes = cur.all_input_nodes if direction == "up" else cur.users
-            for node in next_nodes:
-                # we should ignore shape-related dependencies
-                if is_shape_generator(node):
-                    continue
-                traverse(start, node, direction)
-
-        while seeds:
-            next_seeds = []
-            for node, direction in seeds:
-                traverse(start=node, cur=node, direction=direction)
-            seeds = next_seeds
-
-        if any(self.already_executed_per_node(node) for node in (upstream_nodes | downstream_nodes)) or (
-            upstream_nodes & downstream_nodes
-        ):
-            raise RuntimeError(
-                "Failed to automatically group and parallelize ops in graph in greedy way: "
-                "no clear boudaries between `upstream` and `downstream` ops."
- ) - - for node in upstream_nodes: - self.place_marker_per_node(node, {"axis": "column", "gather_output": False if downstream_nodes else True}) - - for node in downstream_nodes: - self.place_marker_per_node(node, {"axis": "row", "input_is_parallel": True}) + def trace_back(self, graph_module: GraphModule, decomp_graph: Graph) -> None: + node_map = {node.name: node for node in graph_module.graph.nodes} + + for node in decomp_graph.nodes: + if "traced_from" in node.meta: + node_name, _ = node.meta["traced_from"][0] + assert node_name in node_map, f"un-recognized node origin {node_name} not in graph being traced" + orig_node = node_map[node_name] + self.clear_marker_per_node(orig_node) + self.place_marker_per_node( + orig_node, {"parallel_axis": self.get_stored_field_info(node, field="parallel_axis")} + ) def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: - graph: Graph = graph_module.graph + graph: Graph = decompose_and_functionalize(graph_module)(*ctx.example_inputs) stable_topological_sort(graph) - for node in graph.nodes: - if is_linear(node) and not self.already_executed_per_node(node): - self.try_form_parallel_linear_groups(node) - elif is_embedding(node): - # directly mark `nn.Embedding` layers - self.place_marker_per_node(node, {"axis": "vocab"}) - return graph_module + nodes = list(graph.nodes) + def search(idx: int): + if idx == len(nodes): + return True -class ParallelAxisPropagationPass(AnalyzeBase): - """ - A pass which tries to track which axis is being parallelized in the dataflow. For transformer models, the - axis being paralled for tensor parallism is almost always 2, i.e., the attention head axis, except for - Q and K matrices which need to swap the sequence length axis and head axis to do the attention computation, - so we focus on operations like `transpose` or `permute` which swaps axis, and try inducting the parallel - axis after these operations. 
- """ + node = nodes[idx] + if node.op == "call_function" and REGISTRY.is_supported(node.target): + prop_cls = REGISTRY.mapping[node.target] + else: + prop_cls = FallbackParallelAxisPropagateHandler - def propagate_transpose(self, node: Node, parallel_axis: int) -> bool: - dims = node.meta["example_value"].dim() - if "dim0" in node.kwargs and "dim1" in node.kwargs: - dim0, dim1 = node.kwargs["dim0"], node.kwargs["dim1"] - elif len(node.args) == 3: - dim0, dim1 = node.args[1:] - - dim0 = (dim0 + dims) % dims - dim1 = (dim1 + dims) % dims - - if dim0 == parallel_axis: - self.place_marker_per_node(node, {"parallel_axis": dim1}) - return True - elif dim1 == parallel_axis: - self.place_marker_per_node(node, {"parallel_axis": dim0}) - return True - return False - - def propagate_permute(self, node: Node, parallel_axis: int) -> bool: - if "dims" in node.kwargs: - dims = node.kwargs["dims"] - else: - dims = ( - list(node.args[1]) - if isinstance(node.args[1], tuple) - else [arg for arg in node.args if isinstance(arg, int)] - ) + prop = prop_cls(node, self.meta_key(), config) + axis_candidates = prop.propagate() + for axis in axis_candidates: + self.place_marker_per_node(node, {"parallel_axis": axis}) + if search(idx + 1): + return True + self.clear_marker_per_node(node) - dim_len = node.meta["example_value"].dim() - dims = [dim + dim_len if dim < 0 else dim for dim in dims] + return False - for i, dim in enumerate(dims): - if dim == parallel_axis: - self.place_marker_per_node(node, {"parallel_axis": i}) - return True - return False - - def propagate_getitem(self, node: Node, parallel_axis: int) -> bool: - slices = node.args[1] - dims = node.meta["example_value"].dim() - assert parallel_axis < dims - inc, i, j = 0, 0, 0 - - while i < parallel_axis and j < len(slices): - if isinstance(slices[j], int): - inc -= 1 - i += 1 - elif slices[j] is None: - inc += 1 - elif slices[j] is Ellipsis: - i = dims - k = j - while k < len(slices): - if slices[k] is not Ellipsis: - i -= 1 - k += 1 - else: - i += 1 - j += 1 + if not search(0): + raise RuntimeError("Failed to find a solution to automatically parallelize ops in graph in greedy way.") - if inc != 0: - assert parallel_axis + inc < dims and parallel_axis + inc >= 0 - self.place_marker_per_node(node, {"parallel_axis": parallel_axis + inc}) - return True - return False + self.trace_back(graph_module, graph) + return graph_module + + +class ParallelLayerAnnotatePass(AnalyzeBase): + """ + This pass annotates layers which have different parallel axis(requires communication inside the layer) in their + input and output tensors. Since heuristics applied during the searching process respect traditional classical ways of + parallelizing layers(like Megatron-style `ColumnLinear` or `RowLinear`), we are guaranteed to match a valid replacement + annotation according to parallelization strategy of input and output tensors. 
+ """ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: - g: Graph = graph_module.graph - stable_topological_sort(g) + for node in graph_module.graph.nodes: + if is_linear(node): + axis_before = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis") + axis_after = ParallelAxisSolverPass.get_stored_field_info(node, "parallel_axis") + info = {} + if axis_before is None: + info["axis"] = "column" + info["gather_output"] = True if axis_after is None else False + elif axis_before == 1: + assert ( + config.enable_sequence_parallel + ), "illegal parallel axis for sequence parallelism deactivated setting" + info["axis"] = "column" + info["sequence_parallel"] = True + info["gather_output"] = True if axis_after is None else False + elif axis_before == 2: + info["axis"] = "row" + info["input_is_parallel"] = True + if axis_after == 1: + assert ( + config.enable_sequence_parallel + ), "illegal parallel axis for sequence parallelism deactivated setting" + info["sequence_parallel"] = True + else: + info["sequence_parallel"] = False + self.place_marker_per_node(node, info) - for node in g.nodes: - if ParallelLayerAnnotatePass.already_executed_per_node(node): - # start propagating at ColumnLinear, marking the beginning of parallelized region - axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis", must_have=True) - gather_output = ParallelLayerAnnotatePass.get_stored_field_info(node, field="gather_output") - if axis == "column" and not gather_output: - self.place_marker_per_node(node, {"parallel_axis": 2}) - # stop propagating at RowLinear, concluding the ending of parallelized region + elif is_embedding(node): + axis_before = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis") + axis_after = ParallelAxisSolverPass.get_stored_field_info(node, "parallel_axis") + assert axis_before is None and axis_after in [1, None] + info = {"axis": "vocab"} + if axis_after == 1: + assert ( + config.enable_sequence_parallel + ), "illegal parallel axis for sequence parallelism deactivated setting" + info["sequence_parallel"] = True else: - continue - else: - already_marked_args, parallel_axis = [], None - for arg in node.all_input_nodes: - if not self.already_executed_per_node(arg): - continue - if parallel_axis is None: - parallel_axis = self.get_stored_field_info(arg, field="parallel_axis", must_have=True) - else: - assert parallel_axis == self.get_stored_field_info( - arg, field="parallel_axis", must_have=True - ), "`parallel_axis` should be equal for all arguments in any related ops" - already_marked_args.append(arg) - - if not already_marked_args: - continue - - marked = False - if is_transpose(node): - marked = self.propagate_transpose(node, parallel_axis) - elif is_permute(node): - marked = self.propagate_permute(node, parallel_axis) - - # fall back - if not marked: - self.place_marker_per_node(node, {"parallel_axis": parallel_axis}) + info["sequence_parallel"] = False + self.place_marker_per_node(node, info) + return graph_module class ParallelLayerReplacePass(PassBase): """ - A pass which modifies graph according to information provided by previous analytical passes, - in general it does two things for now: + A pass which modifies graph according to information provided by previous analytical passes, in general it does two things for now: 1. replaces linears and embedding layers with their parallel counterparts. 2. 
modifies hard-coded arguments like the number of attention heads in the graph by dividing it by parallelism level. """ @@ -453,7 +365,7 @@ def update(node: Node, new_shape: List[Any], parallel_axis: int): else: node.update_arg(parallel_axis + 1, shape[parallel_axis]) - parallel_axis = ParallelAxisPropagationPass.get_stored_field_info(node, field="parallel_axis") + parallel_axis = ParallelAxisSolverPass.get_stored_field_info(node, field="parallel_axis") if parallel_axis is None: return @@ -582,18 +494,18 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf def build_parallel_pass_pipeline() -> PassPipeline: """ Ensemble a pass pipeline which contains the following passes: - 1. `ParallelLayerAnnotatePass` to annoate which linears are `ColumnLinear`, which are `RowLinear` - 2. `ParallelAxisPropagationPass` to propate parallel axis along the data flow - 3. `ParallelLinearReplacePass` to do the actual replacement and modification of hard-coded attributes - 4. `InitializeOrLoadWeightsPass` to load or initialize weights for parameters + 1. `ParallelAxisSolverPass` to find a parallelization solution of tensors in the graph. + 2. `ParallelLayerAnnotatePass` to annotate parallelized layers according to the solution found in the first step. + 3. `ParallelLinearReplacePass` to do the actual replacement and modification of hard-coded attributes. + 4. `InitializeOrLoadWeightsPass` to load or initialize weights for parameters. Returns: PassPipeline: the pipeline used for automatic parallelism. """ return PassPipeline( [ + ParallelAxisSolverPass(), ParallelLayerAnnotatePass(), - ParallelAxisPropagationPass(), ParallelLayerReplacePass(), InitializeOrLoadWeightsPass(), ] diff --git a/optimum/fx/parallelization/utils.py b/optimum/fx/parallelization/utils.py index f129ffbd402..b7b1ccd41c8 100644 --- a/optimum/fx/parallelization/utils.py +++ b/optimum/fx/parallelization/utils.py @@ -17,7 +17,6 @@ import hashlib import importlib import json -import operator import os import re import tempfile @@ -45,6 +44,14 @@ def ensure_divisibility(numerator: int, denominator: int) -> None: ) +def is_activation(node: Node) -> bool: + # only consider leaf Module activations + if node.op != "call_module": + return False + mod = node.graph.owning_module + return getattr(mod.get_submodule(node.target), "__module__", "").startswith("torch.nn.modules.activation") + + def is_linear(node: Node) -> bool: if node.op != "call_module": return False @@ -67,26 +74,6 @@ def is_shape_consumer(node: Node) -> bool: return False -def is_transpose(node: Node) -> bool: - if node.op == "call_method": - return node.target in {"transpose", "transpose_"} - elif node.op == "call_function": - return node.target is torch.transpose - return False - - -def is_permute(node: Node) -> bool: - if node.op == "call_method": - return node.target in {"permute"} - elif node.op == "call_function": - return node.target is torch.permute - return False - - -def is_getitem(node: Node) -> bool: - return node.op == "call_function" and node.target is operator.getitem - - def is_output(node: Node) -> bool: return node.op == "output" From bb46ebea547a2545c33c36f77067406f687187b8 Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Mon, 2 Sep 2024 19:23:07 +0200 Subject: [PATCH 4/5] Modify token classification processor default dataset args (#2005) --- optimum/utils/preprocessing/token_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/optimum/utils/preprocessing/token_classification.py b/optimum/utils/preprocessing/token_classification.py index 64a0bf2da8a..1c59aa2285b 100644 --- a/optimum/utils/preprocessing/token_classification.py +++ b/optimum/utils/preprocessing/token_classification.py @@ -28,7 +28,7 @@ class TokenClassificationProcessing(TaskProcessor): ACCEPTED_PREPROCESSOR_CLASSES = (PreTrainedTokenizerBase,) - DEFAULT_DATASET_ARGS = {"path": "conll2003", "trust_remote_code": True} + DEFAULT_DATASET_ARGS = "conll2003" DEFAUL_DATASET_DATA_KEYS = {"primary": "tokens"} ALLOWED_DATA_KEY_NAMES = {"primary"} DEFAULT_REF_KEYS = ["ner_tags", "pos_tags", "chunk_tags"] From 8cb6832a2797f54ec1221ff5014a81d961016b6b Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Tue, 3 Sep 2024 11:54:59 +0200 Subject: [PATCH 5/5] Fix TFLite tests (#2007) downgrade datasets --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 3ac4315321b..98ee4f36a3f 100644 --- a/setup.py +++ b/setup.py @@ -73,6 +73,7 @@ "timm", "h5py", "numpy<1.24.0", + "datasets<=2.16", "transformers[sentencepiece]>=4.26,<4.38", ], "diffusers": ["diffusers"],
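
For reference, the two dataset-related changes above, the bare "conll2003" default in TokenClassificationProcessing and the datasets<=2.16 pin added to fix the TFLite tests, both come down to how the default dataset is loaded. The snippet below is a minimal sketch of what the old and new defaults roughly translate to; it only assumes the public datasets.load_dataset API, and the variable names are illustrative rather than taken from Optimum.

from datasets import load_dataset

# Old default args: a dict of load arguments unpacked into load_dataset,
# explicitly opting in to the dataset's loading script.
ds_old = load_dataset("conll2003", trust_remote_code=True)

# New default args: just the dataset name; how it is resolved is left to the
# installed `datasets` version (the last patch pins datasets<=2.16 for the
# TFLite test extras).
ds_new = load_dataset("conll2003")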