From 0157035be2031257c4f9e37804b5a04e7064cabd Mon Sep 17 00:00:00 2001 From: kugesan1105 Date: Mon, 18 Nov 2024 16:48:30 +0530 Subject: [PATCH 1/2] support dual llm ; one for processing and one for type check and type resolver --- jac-mtllm/examples/inherit_basellm.jac | 5 ++++- jac-mtllm/mtllm/plugin.py | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/jac-mtllm/examples/inherit_basellm.jac b/jac-mtllm/examples/inherit_basellm.jac index 66ba1289b1..1b9cd5b24b 100644 --- a/jac-mtllm/examples/inherit_basellm.jac +++ b/jac-mtllm/examples/inherit_basellm.jac @@ -1,4 +1,5 @@ import from mtllm.llms.base { BaseLLM } +import:py from mtllm.llms { OpenAI } import:py from PIL { Image } import torch; import from transformers { AutoModelForCausalLM, AutoProcessor } @@ -85,6 +86,7 @@ obj Florence :BaseLLM: { } glob llm = Florence('microsoft/Florence-2-base'); +glob llm2 = OpenAI(verbose=True, model_name="gpt-4o-mini"); enum DamageType { NoDamage, @@ -94,7 +96,8 @@ enum DamageType { } can "" -predict_vehicle_damage(img: Image) -> DamageType by llm(is_custom=True,raw_output=True); +# predict_vehicle_damage(img: Image) -> DamageType by llm(is_custom=True,raw_output=True); +predict_vehicle_damage(img: Image) -> DamageType by llm(is_custom=True,resolve_with =llm2); with entry { img = 'car_scratch.jpg'; diff --git a/jac-mtllm/mtllm/plugin.py b/jac-mtllm/mtllm/plugin.py index 1b2c13de1a..19db0c22ad 100644 --- a/jac-mtllm/mtllm/plugin.py +++ b/jac-mtllm/mtllm/plugin.py @@ -121,12 +121,12 @@ def with_llm( _globals, _locals, ) + resolver_model = model_params.pop("resolve_with") if "resolve_with" in model_params else model _output = ( - model.resolve_output( + meaning_out if raw_output else + resolver_model.resolve_output( meaning_out, output_hint, output_type_explanations, _globals, _locals ) - if not raw_output - else meaning_out ) return _output From 5af632d086c72ba0a485862c68bf071cfd81828f Mon Sep 17 00:00:00 2001 From: kugesan1105 Date: Mon, 18 Nov 2024 17:56:52 +0530 Subject: [PATCH 2/2] linting fixed --- jac-mtllm/mtllm/plugin.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/jac-mtllm/mtllm/plugin.py b/jac-mtllm/mtllm/plugin.py index 19db0c22ad..f4f6287bc8 100644 --- a/jac-mtllm/mtllm/plugin.py +++ b/jac-mtllm/mtllm/plugin.py @@ -121,10 +121,15 @@ def with_llm( _globals, _locals, ) - resolver_model = model_params.pop("resolve_with") if "resolve_with" in model_params else model + resolver_model = ( + model_params.pop("resolve_with") + if "resolve_with" in model_params + else model + ) _output = ( - meaning_out if raw_output else - resolver_model.resolve_output( + meaning_out + if raw_output + else resolver_model.resolve_output( meaning_out, output_hint, output_type_explanations, _globals, _locals ) )