From 54aa31ebfa974f990297ed52f44556ee7225cad3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Tue, 17 Oct 2023 18:18:19 +0200
Subject: [PATCH] address review

---
 optimum/exporters/onnx/model_patcher.py | 47 +++++++++++++++++++++----
 1 file changed, 40 insertions(+), 7 deletions(-)

diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py
index e965b4df722..827d1578874 100644
--- a/optimum/exporters/onnx/model_patcher.py
+++ b/optimum/exporters/onnx/model_patcher.py
@@ -15,6 +15,7 @@
 import dataclasses
 import functools
 import inspect
+import types
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
 
 import transformers
@@ -378,6 +379,42 @@ def falcon_model_forward_without_kv_reformatting(
 
 
 class FalconModelPatcher(ModelPatcher):
+    def __enter__(self):
+        self.patch_ops()
+
+        transformers.models.falcon.modeling_falcon._make_causal_mask = _make_causal_mask_falcon_patched
+
+        if self.real_config.task == "text-generation":
+            self._model.transformer.forward = types.MethodType(
+                falcon_model_forward_without_kv_reformatting, self._model.transformer
+            )
+
+        # In order to use a single decoder, we need to patch the _prepare_attn_mask function to behave independently of the sequence length.
+        if isinstance(self._model, FalconModel):
+            self._model._prepare_attn_mask = _falcon_prepare_attn_mask
+        else:
+            self._model.transformer._prepare_attn_mask = _falcon_prepare_attn_mask
+
+        setattr(self._model, self.orig_forward_name, self.patched_forward)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.restore_ops()
+
+        setattr(self._model, self.orig_forward_name, self.orig_forward)
+
+        if self.real_config.task == "text-generation":
+            self._model.transformer.forward = types.MethodType(
+                self.original_model_transformer_forward, self._model.transformer
+            )
+
+        transformers.models.falcon.modeling_falcon._make_causal_mask = self.original_make_causal
+
+        # In order to use a single decoder, we need to patch the _prepare_attn_mask function to behave independently of the sequence length.
+        if isinstance(self._model, FalconModel):
+            self._model._prepare_attn_mask = self.original_falcon_prepare_attn_mask
+        else:
+            self._model.transformer._prepare_attn_mask = self.original_falcon_prepare_attn_mask
+
     def __init__(
         self,
         config: "OnnxConfig",
@@ -386,19 +423,15 @@ def __init__(
     ):
         super().__init__(config, model, model_kwargs)
 
-        # This is kind of ugly and bug prone if other FalconModel are instantiated.
         if config.task == "text-generation":
-            model.transformer.__class__.forward = falcon_model_forward_without_kv_reformatting
+            self.original_model_transformer_forward = model.transformer.forward
 
         self.original_make_causal = transformers.models.falcon.modeling_falcon._make_causal_mask
-        transformers.models.falcon.modeling_falcon._make_causal_mask = _make_causal_mask_falcon_patched
-
-        # In order to use a single decoder, we need to patch the _prepare_attn_mask function to behave independently of the sequence length.
         if isinstance(model, FalconModel):
-            model._prepare_attn_mask = _falcon_prepare_attn_mask
+            self.original_falcon_prepare_attn_mask = model._prepare_attn_mask
        else:
-            model.transformer._prepare_attn_mask = _falcon_prepare_attn_mask
+            self.original_falcon_prepare_attn_mask = model.transformer._prepare_attn_mask
 
         self._model = model
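
A minimal, self-contained sketch of the instance-level patching pattern this diff switches to: save the bound forward, re-bind a replacement with types.MethodType on a single instance, and restore the original afterwards, instead of mutating __class__.forward (which the removed comment called bug prone when other FalconModel instances exist). The names TinyTransformer and forward_without_kv_reformatting below are illustrative stand-ins, not part of the patch.

import types


class TinyTransformer:
    """Illustrative stand-in for model.transformer; not from the patch."""

    def forward(self, x):
        return x + 1


def forward_without_kv_reformatting(self, x):
    # Stand-in for falcon_model_forward_without_kv_reformatting.
    return x + 2


m = TinyTransformer()
original_forward = m.forward  # save the bound method, as original_model_transformer_forward does

# Patch only this instance; the class (and any other instance) keeps its own forward.
m.forward = types.MethodType(forward_without_kv_reformatting, m)
assert m.forward(0) == 2
assert TinyTransformer().forward(0) == 1

# Restore the saved bound method once patching is no longer needed.
m.forward = original_forward
assert m.forward(0) == 1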