Refine Prompt Structure: Add Section Headers Conditionally (#1441)
* initial commit

* Add a new flag to return the direct output and pass the meaning as a string to the LLM

* refactored

* Linting

---------

Co-authored-by: Yiping Kang <[email protected]>
kugesan1105 and ypkang authored Nov 12, 2024
1 parent 9d0ae28 commit 04aa620
Showing 4 changed files with 21 additions and 9 deletions.
7 changes: 5 additions & 2 deletions jac-mtllm/examples/inherit_basellm.jac
@@ -7,6 +7,9 @@ glob PROMPT_TEMPLATE: str = """
[Information]
{information}

[Context]
{context}

[Output Information]
{output_information}

@@ -15,7 +18,7 @@ glob PROMPT_TEMPLATE: str = """

[Action]
{action}
""";
"""; # [Context] will not be appear in the prompt

obj Florence :BaseLLM: {
with entry {
@@ -91,7 +94,7 @@ enum DamageType {
}

can ""
predict_vehicle_damage(img: Image) -> DamageType by llm(is_custom=True);
predict_vehicle_damage(img: Image) -> DamageType by llm(is_custom=True, raw_output=True);

with entry {
img = 'car_scratch.jpg';
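The comment above notes that [Context] will not appear in the rendered prompt: the Florence example never supplies a context value, and with this change empty sections are dropped instead of being emitted with a blank body. A minimal Python sketch of that behavior, mirroring the conditional assembly added in mtllm/utils.py further down; the values below are made up for illustration:

# Illustrative only: simplified mirror of the conditional section assembly in utils.py.
values = {
    "information": "img (Image) = car_scratch.jpg",
    "context": "",                      # never filled in for the Florence example
    "output_information": "DamageType (Enum)",
    "action": "predict_vehicle_damage",
}
sections = []
for name in ("information", "context", "output_information", "action"):
    if values.get(name):                # falsy values are skipped, so [Context] is dropped
        sections.append(f"[{name.title()}]\n{values[name]}")
print("\n\n".join(sections))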
2 changes: 1 addition & 1 deletion jac-mtllm/mtllm/aott.py
@@ -128,7 +128,7 @@ def aott_raise(
return f"[Output] {result}"
meaning_typed_input = (
"\n".join(meaning_typed_input_list) # type: ignore
if not contains_media
if not (contains_media and not is_custom)
else meaning_typed_input_list
)
return model(meaning_typed_input, media=media, **model_params) # type: ignore
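The widened condition changes when the typed input list is joined into a single string: it is now joined unless the prompt contains media and the model is not custom, so a custom model always receives a plain string. A small, self-contained sketch of the resulting behavior (the helper below is illustrative, not part of the module):

# Illustrative truth table for the new join condition in aott_raise.
def joins_to_string(contains_media: bool, is_custom: bool) -> bool:
    return not (contains_media and not is_custom)

assert joins_to_string(contains_media=False, is_custom=False)      # text-only prompt -> string
assert joins_to_string(contains_media=False, is_custom=True)       # custom model, no media -> string
assert not joins_to_string(contains_media=True, is_custom=False)   # media with a stock model -> list
assert joins_to_string(contains_media=True, is_custom=True)        # media with a custom model -> string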
11 changes: 9 additions & 2 deletions jac-mtllm/mtllm/plugin.py
@@ -62,6 +62,9 @@ def with_llm(
is_custom = (
model_params.pop("is_custom") if "is_custom" in model_params else False
)
raw_output = (
model_params.pop("raw_output") if "raw_output" in model_params else False
)
available_methods = model.MTLLM_METHOD_PROMPTS.keys()
assert (
method in available_methods
@@ -118,8 +121,12 @@
_globals,
_locals,
)
_output = model.resolve_output(
meaning_out, output_hint, output_type_explanations, _globals, _locals
_output = (
model.resolve_output(
meaning_out, output_hint, output_type_explanations, _globals, _locals
)
if not raw_output
else meaning_out
)
return _output

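with_llm now pops raw_output from the model parameters, so the flag is never forwarded to the model itself, and when the flag is set the LLM response is returned as-is instead of being resolved into the annotated output type. A minimal sketch of that pattern; the resolver below is a stand-in, not the real resolve_output signature:

# Illustrative only: pop-with-default plus conditional bypass, as used in with_llm.
def finish(meaning_out: str, **model_params) -> object:
    raw_output = model_params.pop("raw_output", False)

    def resolve_output(text: str) -> dict:
        # Stand-in for resolving the raw text into the declared return type.
        return {"resolved": text}

    return meaning_out if raw_output else resolve_output(meaning_out)

print(finish("DamageType.SCRATCH"))                   # {'resolved': 'DamageType.SCRATCH'}
print(finish("DamageType.SCRATCH", raw_output=True))  # DamageType.SCRATCH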
10 changes: 6 additions & 4 deletions jac-mtllm/mtllm/utils.py
@@ -107,7 +107,9 @@ def extract_template_placeholders(template: str) -> list:
def format_template_section(template_section: str, values_dict: dict) -> str:
"""Format a template section with given values."""
placeholders = extract_template_placeholders(template_section)
filtered_values = {
key: values_dict[key] for key in placeholders if key in values_dict
}
return template_section.format(**filtered_values).strip()
formatted_sections = []
for placeholder in placeholders:
if placeholder in values_dict and values_dict[placeholder]:
section_template = f"[{placeholder.title()}]\n{values_dict[placeholder]}"
formatted_sections.append(section_template)
return "\n\n".join(formatted_sections).strip()
