From 0d65332b7897df3fa92a451ac82b30892048ef2d Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Tue, 30 May 2023 13:23:35 -0700
Subject: [PATCH] Better error reporting in frontend

Signed-off-by: Antoni Baum
---
 aviary/common/backend.py | 30 +++++++++++++++++---------
 aviary/frontend/app.py   | 46 ++++++++++++++++++++++++++++++++--------
 2 files changed, 57 insertions(+), 19 deletions(-)

diff --git a/aviary/common/backend.py b/aviary/common/backend.py
index bea2539a..89a9716f 100644
--- a/aviary/common/backend.py
+++ b/aviary/common/backend.py
@@ -7,6 +7,12 @@
 from aviary.common.constants import TIMEOUT
 
 
+class BackendError(RuntimeError):
+    def __init__(self, *args: object, **kwargs) -> None:
+        self.response = kwargs.pop("response", None)
+        super().__init__(*args)
+
+
 def get_aviary_backend():
     """
     Establishes a connection to the Aviary backed after establishing
@@ -97,23 +103,25 @@ def __init__(self, backend_url: str, bearer: str):
 
     def models(self) -> List[str]:
         url = self.backend_url + "models"
-        resp = requests.get(url, headers=self.header, timeout=120)
+        response = requests.get(url, headers=self.header, timeout=TIMEOUT)
         try:
-            result = resp.json()
+            result = response.json()
         except requests.JSONDecodeError as e:
-            raise RuntimeError(
-                f"Error decoding JSON from {url}. Text response: {resp.text}",
+            raise BackendError(
+                f"Error decoding JSON from {url}. Text response: {response.text}",
+                response=response,
             ) from e
         return result
 
     def metadata(self, llm: str) -> Dict[str, Dict[str, Any]]:
         url = self.backend_url + "metadata/" + llm.replace("/", "--")
-        resp = requests.get(url, headers=self.header, timeout=120)
+        response = requests.get(url, headers=self.header, timeout=TIMEOUT)
        try:
-            result = resp.json()
+            result = response.json()
         except requests.JSONDecodeError as e:
-            raise RuntimeError(
-                f"Error decoding JSON from {url}. Text response: {resp.text}",
+            raise BackendError(
+                f"Error decoding JSON from {url}. Text response: {response.text}",
+                response=response,
             ) from e
         return result
 
@@ -128,8 +136,9 @@ def completions(self, prompt: str, llm: str) -> Dict[str, Union[str, float, int]
         try:
             return response.json()[llm]
         except requests.JSONDecodeError as e:
-            raise RuntimeError(
+            raise BackendError(
                 f"Error decoding JSON from {url}. Text response: {response.text}",
+                response=response,
             ) from e
 
     def batch_completions(
@@ -145,8 +154,9 @@ def batch_completions(
         try:
             return response.json()[llm]
         except requests.JSONDecodeError as e:
-            raise RuntimeError(
+            raise BackendError(
                 f"Error decoding JSON from {url}. Text response: {response.text}",
+                response=response,
             ) from e
 
 
diff --git a/aviary/frontend/app.py b/aviary/frontend/app.py
index 00a6ad50..fc0cb7df 100644
--- a/aviary/frontend/app.py
+++ b/aviary/frontend/app.py
@@ -5,10 +5,11 @@
 
 import gradio as gr
 import ray
+import requests
 from ray import serve
 from ray.serve.gradio_integrations import GradioIngress
 
-from aviary.common.backend import get_aviary_backend
+from aviary.common.backend import BackendError, get_aviary_backend
 from aviary.common.constants import (
     AVIARY_DESC,
     CSS,
@@ -75,14 +76,41 @@ def completions(prompt, llm):
 
 
 def do_query(prompt, model1, model2, model3, unused_raw=None):
-    models = [model1, model2, model3]
-    futures = [completions.remote(prompt, model) for model in models]
-    outs = ray.get(futures)
-
-    text_output = [o["generated_text"] for o in outs]
-    stats = [gen_stats(o) for o in outs]
-
-    return [*text_output, *stats, "", outs]
+    try:
+        models = [model1, model2, model3]
+        futures = [completions.remote(prompt, model) for model in models]
+        outs = ray.get(futures)
+
+        text_output = [o["generated_text"] for o in outs]
+        stats = [gen_stats(o) for o in outs]
+
+        return [*text_output, *stats, "", outs]
+    except requests.ReadTimeout as e:
+        raise gr.Error(
+            (
+                "The request timed out. This usually means the server is experiencing a higher than usual load. "
+                "Please try again in a few minutes."
+            ).replace("\n", " ")
+        ) from e
+    except BackendError as e:
+        if "timeout" in e.response.text or e.response.status_code in (408, 504):
+            raise gr.Error(
+                (
+                    f"The request timed out. This usually means the server is experiencing a higher than usual load. "
+                    f"Please try again in a few minutes.\nStatus code: {e.response.status_code}"
+                    f"\nResponse: {e.response.text.split('raise ')[-1]}"
+                ).replace("\n", " ")
+            ) from e
+        else:
+            raise gr.Error(
+                (
+                    f"Backend returned an error. "
+                    f"Status code: {e.response.status_code}"
+                    f"\nResponse: {e.response.text.split('raise ')[-1]}"
+                ).replace("\n", " ")
+            ) from e
+    except Exception as e:
+        raise gr.Error(f"An error occurred. Please try again.\nError: {e}") from e
 
 
 def show_results(buttons, llm_text_boxes, llm_stats):