From 8dee86ee65a0e7435d267f386effc5c9e588fbc3 Mon Sep 17 00:00:00 2001
From: cornzz <39997278+cornzz@users.noreply.github.com>
Date: Sun, 8 Sep 2024 18:40:45 +0200
Subject: [PATCH] Rename context to prompt

Fix parallel inference
---
 src/app.py   | 47 ++++++++++++++++++++++++-----------------------
 src/utils.py | 11 ++++++++---
 2 files changed, 32 insertions(+), 26 deletions(-)

diff --git a/src/app.py b/src/app.py
index a8e927d..bf39c13 100644
--- a/src/app.py
+++ b/src/app.py
@@ -153,16 +153,17 @@ def compress_prompt(prompt: str, rate: float, force_tokens: list[str], force_dig
     compression_time = time.time() - start
     word_sep, label_sep = "\t\t|\t\t", " "
-    diff = []
-    for line in result["fn_labeled_original_prompt"].split(word_sep):
-        word, label = line.split(label_sep)
-        diff.append((word, "+") if label == "1" else (word, None))
+    diff = [
+        (word, (None, "+")[int(label)])
+        for line in result["fn_labeled_original_prompt"].split(word_sep)
+        for word, label in [line.split(label_sep)]
+    ]
 
     return result["compressed_prompt"], diff, create_metrics_df(result), compression_time
 
 
 def run_demo(
+    question: str,
     prompt: str,
-    context: str,
     rate: float,
     target_model: str,
     force_tokens: list[str],
@@ -170,20 +171,20 @@ def run_demo(
     request: gr.Request,
 ):
     print(
-        f"RUN DEMO - prompt: {len(prompt.split())}, context: {len(context.split())}, rate: {rate},",
+        f"RUN DEMO - question: {len(question.split())}, prompt: {len(prompt.split())}, rate: {rate},",
         f"model: {target_model.split('/')[-1]} - from {request.cookies['session']}",
     )
     if target_model == "Compress only":
-        compressed, diff, metrics, compression_time = compress_prompt(context, rate, force_tokens, bool(force_digits))
+        compressed, diff, metrics, compression_time = compress_prompt(prompt, rate, force_tokens, bool(force_digits))
         metrics["Compression"] = [f"{compression_time:.2f}s"]
         return compressed, diff, metrics, None, None, None, None
 
-    get_query = lambda ctx: f"{prompt}\n\n{ctx}" if prompt else ctx
     with ThreadPoolExecutor() as executor:
-        future_original = executor.submit(call_llm_api, get_query(context), target_model)
-    compressed, diff, metrics, compression_time = compress_prompt(context, rate, force_tokens, bool(force_digits))
-    res_compressed = call_llm_api(get_query(compressed), target_model, True)
-    res_original = future_original.result()
+        get_query = lambda p: f"{question}\n\n{p}" if question else p
+        future_original = executor.submit(call_llm_api, get_query(prompt), target_model)
+        compressed, diff, metrics, compression_time = compress_prompt(prompt, rate, force_tokens, bool(force_digits))
+        res_compressed = call_llm_api(get_query(compressed), target_model, True)
+        res_original = future_original.result()
 
     end_to_end_original = res_original["obj"]["call_time"]
     end_to_end_compressed = res_compressed["obj"]["call_time"] + compression_time
@@ -219,9 +220,9 @@ def run_demo(
         with gr.Column():
             gr.Markdown("UI Settings")
             ui_settings = gr.CheckboxGroup(
-                ["Show Metrics", "Show Separate Context Field", "Show Compressed Prompt"],
+                ["Show Metrics", "Show Question Field", "Show Compressed Prompt"],
                 container=False,
-                value=["Show Metrics", "Show Separate Context Field"],
+                value=["Show Metrics", "Show Question Field"],
                 elem_classes="ui-settings",
             )
         with gr.Column():
@@ -242,15 +243,15 @@ def run_demo(
     )
 
     # Inputs
-    prompt = gr.Textbox(
+    question = gr.Textbox(
         label="Question",
         lines=1,
         max_lines=1,
         placeholder=example_dataset[1]["QA_pairs"][6][0],
         elem_classes="question-target",
     )
-    context = gr.Textbox(
-        label="Context",
+    prompt = gr.Textbox(
+        label="Prompt (Context)",
         lines=8,
         max_lines=8,
         autoscroll=False,
@@ -311,17 +312,17 @@ def run_demo(
     )
 
     # Event handlers
-    context.change(activate_button, inputs=context, outputs=submit)
+    prompt.change(activate_button, inputs=prompt, outputs=submit)
     submit.click(
         run_demo,
-        inputs=[prompt, context, rate, target_model, force_tokens, force_digits],
+        inputs=[question, prompt, rate, target_model, force_tokens, force_digits],
        outputs=[compressed, compressedDiff, metrics, response_a, response_a_obj, response_b, response_b_obj],
     )
     clear.click(
         lambda: [None] * 8 + [0.5, create_metrics_df(), gr.DataFrame(visible=False)],
         outputs=[
+            question,
             prompt,
-            context,
             compressed,
             compressedDiff,
             response_a_obj,
@@ -333,7 +334,7 @@ def run_demo(
             qa_pairs,
         ],
     )
-    ui_settings.change(handle_ui_settings, inputs=ui_settings, outputs=[prompt, context, compressedDiff, metrics])
+    ui_settings.change(handle_ui_settings, inputs=ui_settings, outputs=[question, prompt, compressedDiff, metrics])
     target_model.change(handle_model_change, inputs=[target_model, ui_settings], outputs=[compressedDiff, responses])
     compressed.change(lambda x: update_label(x, compressedDiff), inputs=compressed, outputs=compressedDiff)
     response_a.change(lambda x: update_label(x, response_a), inputs=response_a, outputs=response_a)
@@ -352,7 +353,7 @@ def run_demo(
             ),
         ),
         inputs=examples,
-        outputs=[prompt, context, qa_pairs],
+        outputs=[question, prompt, qa_pairs],
     )

     # Flagging
@@ -366,7 +367,7 @@ def flag(prompt, context, compr_prompt, rate, metrics, res_a_obj, res_b_obj, fla
         gr.Info("Preference saved. Thank you for your feedback.")
         return [gr.Button(interactive=False)] * 3
 
-    FLAG_COMPONENTS = [prompt, context, compressed, rate, metrics, response_a_obj, response_b_obj]
+    FLAG_COMPONENTS = [question, prompt, compressed, rate, metrics, response_a_obj, response_b_obj]
     flagging_callback.setup(FLAG_COMPONENTS, FLAG_DIRECTORY)
     flag_a.click(flag, inputs=FLAG_COMPONENTS + [flag_a], outputs=[flag_a, flag_n, flag_b], preprocess=False)
     flag_n.click(flag, inputs=FLAG_COMPONENTS + [flag_n], outputs=[flag_a, flag_n, flag_b], preprocess=False)
diff --git a/src/utils.py b/src/utils.py
index a540041..bd12a66 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -65,10 +65,10 @@ def activate_button(value: str) -> gr.Button:
 
 
 def handle_ui_settings(value: list[str]) -> tuple[gr.Textbox, gr.Textbox, gr.HighlightedText, gr.DataFrame]:
-    show_question = "Show Separate Context Field" in value
+    show_question = "Show Question Field" in value
     return (
         gr.Textbox(visible=True) if show_question else gr.Textbox(visible=False, value=None),
-        gr.Textbox(label="Context" if show_question else "Prompt"),
+        gr.Textbox(label="Prompt (Context)" if show_question else "Prompt"),
         gr.HighlightedText(visible="Show Compressed Prompt" in value),
         gr.DataFrame(visible="Show Metrics" in value),
     )
@@ -149,7 +149,12 @@ def prepare_flagged_data(data: pd.DataFrame):
     )
     data["Metrics"] = data["Metrics"].apply(lambda x: metrics_to_df(json.loads(x)).to_html(index=False))
     data = data.rename(
-        columns={"Response A": "Compressed", "Response B": "Uncompressed", "username": "user", "timestamp": "time"}
+        columns={
+            "Response A": "Compressed Response",
+            "Response B": "Uncompressed Response",
+            "username": "user",
+            "timestamp": "time",
+        }
     )
 
     return data.iloc[::-1].to_html(table_id="table")