diff --git a/jac-mtllm/mtllm/aott.py b/jac-mtllm/mtllm/aott.py index 6d056ddc00..32bcfa2f8d 100644 --- a/jac-mtllm/mtllm/aott.py +++ b/jac-mtllm/mtllm/aott.py @@ -131,8 +131,27 @@ def aott_raise( if not (contains_media and not is_custom) else meaning_typed_input_list ) - return model(meaning_typed_input, media=media, **model_params) # type: ignore - + if is_custom: + try: + # This is a temporary solution to enable passing in custom + # parameters to custom models + # custom model should override the __call__ method to + # accept function_inputs parameter + return model( + meaning_typed_input, # type: ignore + media=media, # type: ignore + function_inputs=inputs_information, # type: ignore + **model_params, + ) + except TypeError: + # this is for backward compatibility, + # for any existing custom models that do not have the + # function_inputs parameter + return model( + meaning_typed_input, media=media, **model_params # type: ignore + ) + else: + return model(meaning_typed_input, media=media, **model_params) # type: ignore def execute_react( model: BaseLLM,