Skip to content

Commit

Permalink
Move 'compress only' option to target model selection
Browse files Browse the repository at this point in the history
  • Loading branch information
cornzz committed Aug 10, 2024
1 parent dbf1eda commit 425fdfe
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 24 deletions.
5 changes: 1 addition & 4 deletions src/app.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@
}, 300));
});

// Show compressed prompt if "compress only" is checked
const ui_settings = document.querySelectorAll('.ui-settings input');
ui_settings[3].addEventListener('change', (event) => event.target.checked && !ui_settings[2].checked && ui_settings[2].click());

// Hide diff button
const diff = document.getElementById('compressed-diff');
const diffButton = document.createElement('button');
Expand All @@ -35,6 +31,7 @@

// Question click handler
const handleQuestionClick = (event) => {
const ui_settings = document.querySelectorAll('.ui-settings input');
const promptCheckbox = ui_settings[1];
if (!promptCheckbox.checked) promptCheckbox.click();
const promptInput = document.querySelector('.question-target input');
Expand Down
33 changes: 18 additions & 15 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException
from fastapi.responses import FileResponse, HTMLResponse, PlainTextResponse, StreamingResponse
from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from llmlingua import PromptCompressor

Expand All @@ -23,7 +23,8 @@
check_password,
create_llm_response,
create_metrics_df,
handle_ui_options,
handle_model_change,
handle_ui_settings,
metrics_to_df,
prepare_flagged_data,
shuffle_and_flatten,
Expand All @@ -36,7 +37,12 @@

LLM_ENDPOINT = os.getenv("LLM_ENDPOINT")
LLM_TOKEN = os.getenv("LLM_TOKEN")
LLM_MODELS = ["meta-llama/Meta-Llama-3.1-70B-Instruct", "mistral-7b-q4", "CohereForAI/c4ai-command-r-plus"]
LLM_LIST = [
"meta-llama/Meta-Llama-3.1-70B-Instruct",
"mistral-7b-q4",
"CohereForAI/c4ai-command-r-plus",
"Compress only",
]
MPS_AVAILABLE = torch.backends.mps.is_available()
CUDA_AVAILABLE = torch.cuda.is_available()
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -143,13 +149,12 @@ def compress_prompt(prompt: str, rate: float):
return result["compressed_prompt"], diff, create_metrics_df(result), compression_time


def run_demo(prompt: str, context: str, rate: float, target_model: str, ui_settings: list[str], request: gr.Request):
# TODO: allow selecting parallel / sequential processing (?)
def run_demo(prompt: str, context: str, rate: float, target_model: str, request: gr.Request):
print(
f"RUN DEMO - prompt: {len(prompt.split())}, context: {len(context.split())}, rate: {rate}, model: {target_model.split('/')[-1]}",
f"{'(compress only) ' if 'Compress only' in ui_settings else ''}- from {request.cookies['session']}",
f"RUN DEMO - prompt: {len(prompt.split())}, context: {len(context.split())}, rate: {rate},",
f"model: {target_model.split('/')[-1]} - from {request.cookies['session']}",
)
if "Compress only" in ui_settings:
if target_model == "Compress only":
compressed, diff, metrics, compression_time = compress_prompt(context, rate)
metrics["Compression"] = [f"{compression_time:.2f}s"]
return compressed, diff, metrics, None, None, None, None
Expand Down Expand Up @@ -189,7 +194,7 @@ def run_demo(prompt: str, context: str, rate: float, target_model: str, ui_setti
"""
)
ui_settings = gr.CheckboxGroup(
["Show Metrics", "Show Separate Context Field", "Show Compressed Prompt", "Compress only"],
["Show Metrics", "Show Separate Context Field", "Show Compressed Prompt"],
label="UI Settings",
value=["Show Metrics", "Show Separate Context Field"],
elem_classes="ui-settings",
Expand All @@ -199,8 +204,7 @@ def run_demo(prompt: str, context: str, rate: float, target_model: str, ui_setti
prompt = gr.Textbox(label="Question", lines=1, max_lines=1, elem_classes="question-target")
context = gr.Textbox(label="Context", lines=8, max_lines=8, autoscroll=False, elem_classes="word-count")
rate = gr.Slider(0.1, 1, 0.5, step=0.05, label="Rate")
# TODO: move "compress only" here
target_model = gr.Radio(label="Target LLM Model", choices=LLM_MODELS, value=LLM_MODELS[0])
target_model = gr.Radio(label="Target LLM", choices=LLM_LIST, value=LLM_LIST[0])
with gr.Row():
clear = gr.Button("Clear", elem_classes="clear")
submit = gr.Button("Submit", variant="primary", interactive=False)
Expand Down Expand Up @@ -256,7 +260,7 @@ def run_demo(prompt: str, context: str, rate: float, target_model: str, ui_setti
context.change(activate_button, inputs=[prompt, context], outputs=submit)
submit.click(
run_demo,
inputs=[prompt, context, rate, target_model, ui_settings],
inputs=[prompt, context, rate, target_model],
outputs=[compressed, compressedDiff, metrics, response_a, response_a_obj, response_b, response_b_obj],
)
clear.click(
Expand All @@ -275,9 +279,8 @@ def run_demo(prompt: str, context: str, rate: float, target_model: str, ui_setti
qa_pairs,
],
)
ui_settings.change(
handle_ui_options, inputs=ui_settings, outputs=[prompt, context, compressedDiff, metrics, responses]
)
ui_settings.change(handle_ui_settings, inputs=ui_settings, outputs=[prompt, context, compressedDiff, metrics])
target_model.change(handle_model_change, inputs=[target_model, ui_settings], outputs=[compressedDiff, responses])
compressed.change(lambda x: update_label(x, compressedDiff), inputs=compressed, outputs=compressedDiff)
response_a.change(lambda x: update_label(x, response_a), inputs=response_a, outputs=response_a)
response_b.change(lambda x: update_label(x, response_b), inputs=response_b, outputs=response_b)
Expand Down
17 changes: 12 additions & 5 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,21 @@ def activate_button(*values):
return gr.Button(interactive=any(bool(value) for value in values))


def handle_ui_options(value: list[str]):
show_prompt = "Show Separate Context Field" in value
def handle_ui_settings(value: list[str]):
    """Toggle visibility of UI elements according to the checked UI settings.

    Args:
        value: the currently checked labels of the "UI Settings" CheckboxGroup.

    Returns:
        A 4-tuple of Gradio component updates for (question textbox, context
        textbox, compressed-prompt diff, metrics dataframe) — matching the
        four outputs wired to ``ui_settings.change`` by the caller.
    """
    show_question = "Show Separate Context Field" in value
    return (
        # Question box only makes sense with a separate context field; clear
        # its value when hiding so a stale question is not submitted.
        gr.Textbox(visible=True) if show_question else gr.Textbox(visible=False, value=None),
        # With no separate question field, the context box is the whole prompt.
        gr.Textbox(label="Context" if show_question else "Prompt"),
        gr.HighlightedText(visible="Show Compressed Prompt" in value),
        gr.DataFrame(visible="Show Metrics" in value),
    )


def handle_model_change(value: str, options: list[str]):
    """React to a change of the target-LLM radio selection.

    Args:
        value: the newly selected entry of the target-model Radio.
        options: the currently checked labels of the "UI Settings" CheckboxGroup.

    Returns:
        A 2-tuple of Gradio updates: (compressed-prompt diff visibility,
        responses column visibility). The diff is forced visible in
        compress-only mode; the responses column is hidden in that mode.
    """
    is_compress_only = value == "Compress only"
    show_diff = is_compress_only or "Show Compressed Prompt" in options
    return (
        gr.HighlightedText(visible=show_diff),
        gr.Column(visible=not is_compress_only),
    )


Expand Down

0 comments on commit 425fdfe

Please sign in to comment.