Skip to content

Commit

Permalink
Rename context to prompt
Browse files Browse the repository at this point in the history
Fix parallel inference
  • Loading branch information
cornzz committed Sep 8, 2024
1 parent 8db6615 commit 8dee86e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 26 deletions.
47 changes: 24 additions & 23 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,37 +153,38 @@ def compress_prompt(prompt: str, rate: float, force_tokens: list[str], force_dig
compression_time = time.time() - start

word_sep, label_sep = "\t\t|\t\t", " "
diff = []
for line in result["fn_labeled_original_prompt"].split(word_sep):
word, label = line.split(label_sep)
diff.append((word, "+") if label == "1" else (word, None))
diff = [
(word, (None, "+")[int(label)])
for line in result["fn_labeled_original_prompt"].split(word_sep)
for word, label in [line.split(label_sep)]
]
return result["compressed_prompt"], diff, create_metrics_df(result), compression_time


def run_demo(
question: str,
prompt: str,
context: str,
rate: float,
target_model: str,
force_tokens: list[str],
force_digits: list[str],
request: gr.Request,
):
print(
f"RUN DEMO - prompt: {len(prompt.split())}, context: {len(context.split())}, rate: {rate},",
f"RUN DEMO - question: {len(question.split())}, prompt: {len(prompt.split())}, rate: {rate},",
f"model: {target_model.split('/')[-1]} - from {request.cookies['session']}",
)
if target_model == "Compress only":
compressed, diff, metrics, compression_time = compress_prompt(context, rate, force_tokens, bool(force_digits))
compressed, diff, metrics, compression_time = compress_prompt(prompt, rate, force_tokens, bool(force_digits))
metrics["Compression"] = [f"{compression_time:.2f}s"]
return compressed, diff, metrics, None, None, None, None

get_query = lambda ctx: f"{prompt}\n\n{ctx}" if prompt else ctx
with ThreadPoolExecutor() as executor:
future_original = executor.submit(call_llm_api, get_query(context), target_model)
compressed, diff, metrics, compression_time = compress_prompt(context, rate, force_tokens, bool(force_digits))
res_compressed = call_llm_api(get_query(compressed), target_model, True)
res_original = future_original.result()
get_query = lambda p: f"{question}\n\n{p}" if question else p
future_original = executor.submit(call_llm_api, get_query(prompt), target_model)
compressed, diff, metrics, compression_time = compress_prompt(prompt, rate, force_tokens, bool(force_digits))
res_compressed = call_llm_api(get_query(compressed), target_model, True)
res_original = future_original.result()

end_to_end_original = res_original["obj"]["call_time"]
end_to_end_compressed = res_compressed["obj"]["call_time"] + compression_time
Expand Down Expand Up @@ -219,9 +220,9 @@ def run_demo(
with gr.Column():
gr.Markdown("UI Settings")
ui_settings = gr.CheckboxGroup(
["Show Metrics", "Show Separate Context Field", "Show Compressed Prompt"],
["Show Metrics", "Show Question Field", "Show Compressed Prompt"],
container=False,
value=["Show Metrics", "Show Separate Context Field"],
value=["Show Metrics", "Show Question Field"],
elem_classes="ui-settings",
)
with gr.Column():
Expand All @@ -242,15 +243,15 @@ def run_demo(
)

# Inputs
prompt = gr.Textbox(
question = gr.Textbox(
label="Question",
lines=1,
max_lines=1,
placeholder=example_dataset[1]["QA_pairs"][6][0],
elem_classes="question-target",
)
context = gr.Textbox(
label="Context",
prompt = gr.Textbox(
label="Prompt (Context)",
lines=8,
max_lines=8,
autoscroll=False,
Expand Down Expand Up @@ -311,17 +312,17 @@ def run_demo(
)

# Event handlers
context.change(activate_button, inputs=context, outputs=submit)
prompt.change(activate_button, inputs=prompt, outputs=submit)
submit.click(
run_demo,
inputs=[prompt, context, rate, target_model, force_tokens, force_digits],
inputs=[question, prompt, rate, target_model, force_tokens, force_digits],
outputs=[compressed, compressedDiff, metrics, response_a, response_a_obj, response_b, response_b_obj],
)
clear.click(
lambda: [None] * 8 + [0.5, create_metrics_df(), gr.DataFrame(visible=False)],
outputs=[
question,
prompt,
context,
compressed,
compressedDiff,
response_a_obj,
Expand All @@ -333,7 +334,7 @@ def run_demo(
qa_pairs,
],
)
ui_settings.change(handle_ui_settings, inputs=ui_settings, outputs=[prompt, context, compressedDiff, metrics])
ui_settings.change(handle_ui_settings, inputs=ui_settings, outputs=[question, prompt, compressedDiff, metrics])
target_model.change(handle_model_change, inputs=[target_model, ui_settings], outputs=[compressedDiff, responses])
compressed.change(lambda x: update_label(x, compressedDiff), inputs=compressed, outputs=compressedDiff)
response_a.change(lambda x: update_label(x, response_a), inputs=response_a, outputs=response_a)
Expand All @@ -352,7 +353,7 @@ def run_demo(
),
),
inputs=examples,
outputs=[prompt, context, qa_pairs],
outputs=[question, prompt, qa_pairs],
)

# Flagging
Expand All @@ -366,7 +367,7 @@ def flag(prompt, context, compr_prompt, rate, metrics, res_a_obj, res_b_obj, fla
gr.Info("Preference saved. Thank you for your feedback.")
return [gr.Button(interactive=False)] * 3

FLAG_COMPONENTS = [prompt, context, compressed, rate, metrics, response_a_obj, response_b_obj]
FLAG_COMPONENTS = [question, prompt, compressed, rate, metrics, response_a_obj, response_b_obj]
flagging_callback.setup(FLAG_COMPONENTS, FLAG_DIRECTORY)
flag_a.click(flag, inputs=FLAG_COMPONENTS + [flag_a], outputs=[flag_a, flag_n, flag_b], preprocess=False)
flag_n.click(flag, inputs=FLAG_COMPONENTS + [flag_n], outputs=[flag_a, flag_n, flag_b], preprocess=False)
Expand Down
11 changes: 8 additions & 3 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ def activate_button(value: str) -> gr.Button:


def handle_ui_settings(value: list[str]) -> tuple[gr.Textbox, gr.Textbox, gr.HighlightedText, gr.DataFrame]:
show_question = "Show Separate Context Field" in value
show_question = "Show Question Field" in value
return (
gr.Textbox(visible=True) if show_question else gr.Textbox(visible=False, value=None),
gr.Textbox(label="Context" if show_question else "Prompt"),
gr.Textbox(label="Prompt (Context)" if show_question else "Prompt"),
gr.HighlightedText(visible="Show Compressed Prompt" in value),
gr.DataFrame(visible="Show Metrics" in value),
)
Expand Down Expand Up @@ -149,7 +149,12 @@ def prepare_flagged_data(data: pd.DataFrame):
)
data["Metrics"] = data["Metrics"].apply(lambda x: metrics_to_df(json.loads(x)).to_html(index=False))
data = data.rename(
columns={"Response A": "Compressed", "Response B": "Uncompressed", "username": "user", "timestamp": "time"}
columns={
"Response A": "Compressed Response",
"Response B": "Uncompressed Response",
"username": "user",
"timestamp": "time",
}
)
return data.iloc[::-1].to_html(table_id="table")

Expand Down

0 comments on commit 8dee86e

Please sign in to comment.