Enhance Compatibility for Custom HF Models (#1433)
* Initial commit to support custom huggingface models

* fixed the issue with openai

* updated customizable prompt template

* refactored
kugesan1105 authored Nov 12, 2024
1 parent 56e6ecb commit 9d0ae28
Showing 7 changed files with 194 additions and 46 deletions.
3 changes: 3 additions & 0 deletions jac-mtllm/.pre-commit-config.yaml
@@ -10,17 +10,20 @@ repos:
rev: 24.1.1
hooks:
- id: black
exclude: '__jac_gen__'
- repo: https://github.com/PyCQA/flake8
rev: 6.1.0
hooks:
- id: flake8
exclude: '__jac_gen__'
args: ["--config=jac-mtllm/.flake8"]
additional_dependencies: [pep8-naming, flake8_import_order, flake8_docstrings, flake8_comprehensions, flake8_bugbear, flake8_annotations, flake8_simplify]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
hooks:
- id: mypy
additional_dependencies: [types-PyYAML, types-requests]
exclude: '__jac_gen__'
args:
- --follow-imports=silent
- --ignore-missing-imports
Binary file added jac-mtllm/examples/car_scratch.jpg
100 changes: 100 additions & 0 deletions jac-mtllm/examples/inherit_basellm.jac
@@ -0,0 +1,100 @@
import from mtllm.llms.base { BaseLLM }
import:py from PIL { Image }
import torch;
import from transformers { AutoModelForCausalLM, AutoProcessor }

glob PROMPT_TEMPLATE: str = """
[Information]
{information}

[Output Information]
{output_information}

[Type Explanations]
{type_explanations}

[Action]
{action}
""";

obj Florence :BaseLLM: {
with entry {
MTLLM_PROMPT = PROMPT_TEMPLATE;
}
can init(model_id: str) {
self.verbose = True;
self.max_tries = 0;
self.type_check = False;
self.model = AutoModelForCausalLM.from_pretrained(
model_id,
trust_remote_code=True, torch_dtype='auto'
).eval().cuda();
self.processor = AutoProcessor.from_pretrained(
model_id,
trust_remote_code=True
);
}

can __infer__(meaning: str, image: Image, **kwargs: tuple) -> str {
can run_example(task_prompt: str, image: Image) -> str {
prompt = task_prompt;
inputs = self.processor(
text=prompt,
images=image,
return_tensors="pt"
).to(
'cuda',
torch.float16
);
generated_ids = self.model.generate(
input_ids=inputs["input_ids"].cuda(),
pixel_values=inputs["pixel_values"].cuda(),
max_new_tokens=1024,
early_stopping=False,
do_sample=False,
num_beams=3,

);
generated_text = self.processor.batch_decode(
generated_ids,
skip_special_tokens=False
)[0];
parsed_answer = self.processor.post_process_generation(
generated_text,
task=task_prompt, image_size=(image.width, image.height)
);

return parsed_answer;
}

result = run_example('<MORE_DETAILED_CAPTION>', image=image);
return str(next(iter(result.values())));
}

can __call__(meaning: str, media: list, **kwargs: tuple) {
if self.verbose {
print(f'MEANING_IN:\n{meaning}');
print('MEDIA:\n', media);
}
image = media[0].value;
return self.__infer__(meaning, image, **kwargs);
}
}

glob llm = Florence('microsoft/Florence-2-base');

enum DamageType {
NoDamage,
MinorDamage,
MajorDamage,
Destroyed
}

can ""
predict_vehicle_damage(img: Image) -> DamageType by llm(is_custom=True);

with entry {
img = 'car_scratch.jpg';
image = Image.open(img);
print(predict_vehicle_damage(image));
}
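
For readers more at home in Python than Jac, the same pattern looks roughly like the sketch below. It relies only on the BaseLLM hooks visible in this commit (MTLLM_PROMPT, verbose, max_tries, type_check, __infer__, __call__); the class name and the omitted model-loading code are illustrative placeholders, not part of the library.

```python
# Rough Python counterpart of the Jac example above (sketch only; assumptions noted).
from PIL import Image

from mtllm.llms.base import BaseLLM

CUSTOM_PROMPT = """
[Information]
{information}

[Output Information]
{output_information}

[Type Explanations]
{type_explanations}

[Action]
{action}
"""


class CustomVisionLLM(BaseLLM):
    MTLLM_PROMPT = CUSTOM_PROMPT  # customizable prompt template, as in the Jac example

    def __init__(self) -> None:
        # Same knobs the Jac example sets: verbose logging, no retries, no type check.
        self.verbose = True
        self.max_tries = 0
        self.type_check = False
        # Load the Hugging Face model/processor here (omitted for brevity).

    def __infer__(self, meaning_in: str, image: Image.Image, **kwargs: dict) -> str:
        # Run the underlying vision-language model and return its decoded text.
        raise NotImplementedError("plug in the HF generate/decode call here")

    def __call__(self, input_text: str, media: list, **kwargs: dict) -> str:
        # With `by llm(is_custom=True)`, aott_raise forwards PIL images in `media`;
        # each entry wraps the image in `.value`, exactly as the Jac example reads it.
        image = media[0].value
        return self.__infer__(input_text, image, **kwargs)
```

The essential contract shown in this diff is the `__call__` signature: with `is_custom=True`, `aott_raise` invokes the model as `model(prompt, media=..., **model_params)`, so a custom wrapper has to accept `media` and unwrap the PIL image from each entry's `.value`.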
58 changes: 34 additions & 24 deletions jac-mtllm/mtllm/aott.py
@@ -6,6 +6,8 @@

from typing import Mapping

from PIL import Image as PILImage

from jaclang.compiler.semtable import SemRegistry

from loguru import logger
@@ -22,6 +24,7 @@
TypeExplanation,
Video,
)
from mtllm.utils import format_template_section


def aott_raise(
@@ -33,6 +36,7 @@ def aott_raise(
action: str,
context: str,
method: str,
is_custom: bool,
tools: list[Tool],
model_params: dict,
_globals: dict,
@@ -41,14 +45,18 @@
"""AOTT Raise uses the information (Meanings types values) provided to generate a prompt(meaning in)."""
_globals["finish_tool"] = finish_tool
contains_media: bool = any(
isinstance(x.value, (Image, Video)) for x in inputs_information
isinstance(x.value, (Image, Video, PILImage.Image)) for x in inputs_information
)
informations_str = "\n".join([str(x) for x in informations])
inputs_information_repr: list[dict] | str
if contains_media:
media = []
if contains_media and not is_custom:
inputs_information_repr = []
for x in inputs_information:
inputs_information_repr.extend(x.to_list_dict())
for input_info in inputs_information:
inputs_information_repr.extend(input_info.to_list_dict())
elif is_custom:
media = [x for x in inputs_information if isinstance(x.value, PILImage.Image)]
inputs_information_repr = ""
else:
inputs_information_repr = "\n".join([str(x) for x in inputs_information])

@@ -60,14 +68,15 @@
tools.append(finish_tool)
method_prompt = model.MTLLM_METHOD_PROMPTS[method]
if isinstance(inputs_information_repr, str):
mtllm_prompt = model.MTLLM_PROMPT.format(
information=informations_str,
inputs_information=inputs_information_repr,
output_information=str(output_hint),
type_explanations=type_explanations_str,
action=action,
context=context,
).strip()
all_values = {
"information": informations_str,
"inputs_information": inputs_information_repr,
"output_information": str(output_hint),
"type_explanations": type_explanations_str,
"action": action,
"context": context,
}
mtllm_prompt = format_template_section(model.MTLLM_PROMPT, all_values)
if not is_react:
meaning_typed_input_list = [system_prompt, mtllm_prompt, method_prompt]
else:
@@ -79,17 +88,18 @@
method_prompt,
]
else:
upper_half = model.MTLLM_PROMPT.split("{inputs_information}")[0]
lower_half = model.MTLLM_PROMPT.split("{inputs_information}")[1]
upper_half = upper_half.format(
information=informations_str,
context=context,
)
lower_half = lower_half.format(
output_information=str(output_hint),
type_explanations=type_explanations_str,
action=action,
)
upper_half, lower_half = model.MTLLM_PROMPT.split("{inputs_information}")
upper_values = {
"information": informations_str,
"context": context,
}
lower_values = {
"output_information": str(output_hint),
"type_explanations": type_explanations_str,
"action": action,
}
upper_half = format_template_section(upper_half, upper_values)
lower_half = format_template_section(lower_half, lower_values)
meaning_typed_input_list = [
{"type": "text", "text": system_prompt},
{"type": "text", "text": upper_half},
@@ -121,7 +131,7 @@ def aott_raise(
if not contains_media
else meaning_typed_input_list
)
return model(meaning_typed_input, **model_params) # type: ignore
return model(meaning_typed_input, media=media, **model_params) # type: ignore


def execute_react(
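
To see the refactored multimodal prompt assembly in isolation, here is a small sketch (the template text and values are invented): the prompt is split once around the `{inputs_information}` placeholder and each half is formatted independently with `format_template_section`, so image message parts can be interleaved between the two text halves.

```python
# Isolated sketch of the refactored template handling; template and values are invented.
from mtllm.utils import format_template_section

MTLLM_PROMPT = """
[Information]
{information}

[Inputs Information]
{inputs_information}

[Output Information]
{output_information}

[Action]
{action}
"""

upper_half, lower_half = MTLLM_PROMPT.split("{inputs_information}")

# Each half is formatted on its own; keys with no matching placeholder
# (here: "context" and "type_explanations") are simply filtered out.
upper_half = format_template_section(
    upper_half, {"information": "A car image is provided.", "context": ""}
)
lower_half = format_template_section(
    lower_half,
    {"output_information": "DamageType", "type_explanations": "", "action": "classify"},
)

# Image content is then interleaved between the two text halves as message parts.
meaning_typed_input = [
    {"type": "text", "text": "SYSTEM PROMPT"},
    {"type": "text", "text": upper_half},
    # ... image dict(s) produced by InputInformation.to_list_dict() go here ...
    {"type": "text", "text": lower_half},
]
```

Note that the tuple unpacking assumes the placeholder occurs exactly once; a custom template that drops `{inputs_information}` entirely would raise a ValueError on this branch, which is presumably why the `is_custom=True` path bypasses it and keeps the prompt as plain text.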
61 changes: 39 additions & 22 deletions jac-mtllm/mtllm/llms/base.py
@@ -6,7 +6,8 @@

from loguru import logger

from mtllm.types import OutputHint, ReActOutput, TypeExplanation
from mtllm.types import InputInformation, OutputHint, ReActOutput, TypeExplanation
from mtllm.utils import format_template_section


httpx_logger = logging.getLogger("httpx")
@@ -161,7 +162,12 @@ def __infer__(self, meaning_in: str | list[dict], **kwargs: dict) -> str:
"""Infer a response from the input meaning."""
raise NotImplementedError

def __call__(self, input_text: str | list[dict], **kwargs: dict) -> str:
def __call__(
self,
input_text: str | list[dict],
media: list[Optional[InputInformation]],
**kwargs: dict,
) -> str:
"""Infer a response from the input text."""
if self.verbose:
logger.info(f"Meaning In\n{input_text}")
@@ -251,11 +257,14 @@ def _fix_react_output(
"""Fix the output string."""
if self.verbose:
logger.info(f"Error: {error}, Fixing the output.")
react_output_fix_prompt = self.REACT_OUTPUT_FIX_PROMPT.format(
model_output=meaning_out,
error=str(error),
tool_explanations=tool_explanations,
type_explanations=type_explanations,
react_output_fix_values = {
"model_output": meaning_out,
"error": str(error),
"tool_explanations": tool_explanations,
"type_explanations": type_explanations,
}
react_output_fix_prompt = format_template_section(
self.REACT_OUTPUT_FIX_PROMPT, react_output_fix_values
)
return self.__infer__(react_output_fix_prompt)

@@ -266,12 +275,15 @@ def _check_output(
output_type_explanations: list[TypeExplanation],
) -> bool:
"""Check if the output is in the desired format."""
output_check_prompt = self.OUTPUT_CHECK_PROMPT.format(
model_output=output,
output_type=output_type,
output_type_info="\n".join(
react_values = {
"model_output": output,
"output_type": output_type,
"output_type_info": "\n".join(
[str(info) for info in output_type_explanations]
),
}
output_check_prompt = format_template_section(
self.OUTPUT_CHECK_PROMPT, react_values
)
llm_output = self.__infer__(output_check_prompt)
return "yes" in llm_output.lower()
@@ -298,14 +310,16 @@ def _extract_output(
)
else:
logger.info("Extracting output from the meaning out string.")

output_extract_prompt = self.OUTPUT_EXTRACT_PROMPT.format(
model_output=meaning_out,
previous_output=previous_output,
output_info=str(output_hint),
output_type_info="\n".join(
output_check_values = {
"model_output": meaning_out,
"previous_output": previous_output,
"output_info": str(output_hint),
"output_type_info": "\n".join(
[str(info) for info in output_type_explanations]
),
}
output_extract_prompt = format_template_section(
self.OUTPUT_EXTRACT_PROMPT, output_check_values
)
llm_output = self.__infer__(output_extract_prompt)
is_in_desired_format = self._check_output(
@@ -376,12 +390,15 @@ def _fix_output(
"""Fix the output string."""
if self.verbose:
logger.info(f"Error: {error}, Fixing the output.")
output_fix_prompt = self.OUTPUT_FIX_PROMPT.format(
model_output=output,
output_type=output_hint.type,
output_type_info="\n".join(
output_fix_values = {
"model_output": output,
"output_type": output_hint.type,
"output_type_info": "\n".join(
[str(info) for info in output_type_explanations]
),
error=error,
"error": str(error),
}
output_fix_prompt = format_template_section(
self.OUTPUT_FIX_PROMPT, output_fix_values
)
return self.__infer__(output_fix_prompt)
4 changes: 4 additions & 0 deletions jac-mtllm/mtllm/plugin.py
@@ -59,6 +59,9 @@ def with_llm(
assert _scope is not None, f"Invalid scope: {scope}"

method = model_params.pop("method") if "method" in model_params else "Normal"
is_custom = (
model_params.pop("is_custom") if "is_custom" in model_params else False
)
available_methods = model.MTLLM_METHOD_PROMPTS.keys()
assert (
method in available_methods
@@ -109,6 +112,7 @@ def with_llm(
action,
context,
method,
is_custom,
_tools,
model_params,
_globals,
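
The new flag is consumed the same way `method` already is: popped out of the call-site keyword arguments before the remainder is forwarded to the model as `model_params`. A minimal sketch of that pattern (the dict contents are invented):

```python
# Sketch of how the flag is separated from the remaining model kwargs.
model_params = {"is_custom": True, "temperature": 0.0}

is_custom = model_params.pop("is_custom") if "is_custom" in model_params else False
# Equivalent, slightly terser spelling: model_params.pop("is_custom", False)

assert is_custom is True
assert "is_custom" not in model_params  # only genuine model kwargs are forwarded
```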
14 changes: 14 additions & 0 deletions jac-mtllm/mtllm/utils.py
@@ -97,3 +97,17 @@ def get_filtered_registry(mod_registry: SemRegistry, scope: SemScope) -> SemRegistry:
filtered_registry.registry[_scope] = sem_info_list

return filtered_registry


def extract_template_placeholders(template: str) -> list:
"""Extract placeholders from the template."""
return re.findall(r"{(.*?)}", template)


def format_template_section(template_section: str, values_dict: dict) -> str:
"""Format a template section with given values."""
placeholders = extract_template_placeholders(template_section)
filtered_values = {
key: values_dict[key] for key in placeholders if key in values_dict
}
return template_section.format(**filtered_values).strip()
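
A quick usage sketch for the new helper (template and values are invented): it formats only the placeholders that actually occur in the given template section, drops unused keys from the values dict, and strips surrounding whitespace, which keeps formatting independent of which placeholders a customized template chooses to include.

```python
# Usage sketch for format_template_section; inputs are invented for illustration.
from mtllm.utils import format_template_section

template = """
[Information]
{information}

[Action]
{action}
"""

values = {
    "information": "The image shows a scratched rear door.",
    "action": "Classify the damage.",
    "context": "ignored",  # not referenced by the template, so it is filtered out
}

print(format_template_section(template, values))
# [Information]
# The image shows a scratched rear door.
#
# [Action]
# Classify the damage.
```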
