diff --git a/jac-mtllm/examples/inherit_basellm.jac b/jac-mtllm/examples/inherit_basellm.jac index b55d03ab47..66ba1289b1 100644 --- a/jac-mtllm/examples/inherit_basellm.jac +++ b/jac-mtllm/examples/inherit_basellm.jac @@ -7,6 +7,9 @@ glob PROMPT_TEMPLATE: str = """ [Information] {information} +[Context] +{context} + [Output Information] {output_information} @@ -15,7 +18,7 @@ glob PROMPT_TEMPLATE: str = """ [Action] {action} -"""; +"""; # [Context] will not be appear in the prompt obj Florence :BaseLLM: { with entry { @@ -91,7 +94,7 @@ enum DamageType { } can "" -predict_vehicle_damage(img: Image) -> DamageType by llm(is_custom=True); +predict_vehicle_damage(img: Image) -> DamageType by llm(is_custom=True,raw_output=True); with entry { img = 'car_scratch.jpg'; diff --git a/jac-mtllm/mtllm/aott.py b/jac-mtllm/mtllm/aott.py index 881f3b858f..6d056ddc00 100644 --- a/jac-mtllm/mtllm/aott.py +++ b/jac-mtllm/mtllm/aott.py @@ -128,7 +128,7 @@ def aott_raise( return f"[Output] {result}" meaning_typed_input = ( "\n".join(meaning_typed_input_list) # type: ignore - if not contains_media + if not (contains_media and not is_custom) else meaning_typed_input_list ) return model(meaning_typed_input, media=media, **model_params) # type: ignore diff --git a/jac-mtllm/mtllm/plugin.py b/jac-mtllm/mtllm/plugin.py index 536dbeb99f..1b2c13de1a 100644 --- a/jac-mtllm/mtllm/plugin.py +++ b/jac-mtllm/mtllm/plugin.py @@ -62,6 +62,9 @@ def with_llm( is_custom = ( model_params.pop("is_custom") if "is_custom" in model_params else False ) + raw_output = ( + model_params.pop("raw_output") if "raw_output" in model_params else False + ) available_methods = model.MTLLM_METHOD_PROMPTS.keys() assert ( method in available_methods @@ -118,8 +121,12 @@ def with_llm( _globals, _locals, ) - _output = model.resolve_output( - meaning_out, output_hint, output_type_explanations, _globals, _locals + _output = ( + model.resolve_output( + meaning_out, output_hint, output_type_explanations, _globals, _locals + ) + if not raw_output + else meaning_out ) return _output diff --git a/jac-mtllm/mtllm/utils.py b/jac-mtllm/mtllm/utils.py index cbbf9d36ed..2e1e060ef9 100644 --- a/jac-mtllm/mtllm/utils.py +++ b/jac-mtllm/mtllm/utils.py @@ -107,7 +107,9 @@ def extract_template_placeholders(template: str) -> list: def format_template_section(template_section: str, values_dict: dict) -> str: """Format a template section with given values.""" placeholders = extract_template_placeholders(template_section) - filtered_values = { - key: values_dict[key] for key in placeholders if key in values_dict - } - return template_section.format(**filtered_values).strip() + formatted_sections = [] + for placeholder in placeholders: + if placeholder in values_dict and values_dict[placeholder]: + section_template = f"[{placeholder.title()}]\n{values_dict[placeholder]}" + formatted_sections.append(section_template) + return "\n\n".join(formatted_sections).strip()