From cc23838ef5aa9c241e98539a9e855be4475ac0c1 Mon Sep 17 00:00:00 2001 From: "pixeebot[bot]" <104101892+pixeebot[bot]@users.noreply.github.com> Date: Sun, 30 Jun 2024 09:39:54 -0400 Subject: [PATCH] Add timeout to `requests` calls (#2) Co-authored-by: pixeebot[bot] <104101892+pixeebot[bot]@users.noreply.github.com> --- src/vanna/__init__.py | 6 +++--- src/vanna/base/base.py | 4 ++-- src/vanna/flask/__init__.py | 2 +- src/vanna/vannadb/vannadb_vector.py | 12 ++++++------ src/vanna/vllm/vllm.py | 4 ++-- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/vanna/__init__.py b/src/vanna/__init__.py index 3d60a9e2..4bcaad1f 100644 --- a/src/vanna/__init__.py +++ b/src/vanna/__init__.py @@ -62,8 +62,8 @@ def __unauthenticated_rpc_call(method, params): data = {"method": method, "params": [__dataclass_to_dict(obj) for obj in params]} response = requests.post( - _unauthenticated_endpoint, headers=headers, data=json.dumps(data) - ) + _unauthenticated_endpoint, headers=headers, data=json.dumps(data), + timeout=60) return response.json() @@ -397,4 +397,4 @@ def connect_to_bigquery(cred_file_path: str = None, project_id: str = None): error_deprecation() def connect_to_duckdb(url: str="memory", init_sql: str = None): - error_deprecation() \ No newline at end of file + error_deprecation() diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 56360adc..0c4b7d76 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -808,7 +808,7 @@ def connect_to_sqlite(self, url: str): # Download the database if it doesn't exist if not os.path.exists(url): - response = requests.get(url) + response = requests.get(url, timeout=60) response.raise_for_status() # Check that the request was successful with open(path, "wb") as f: f.write(response.content) @@ -1297,7 +1297,7 @@ def connect_to_duckdb(self, url: str, init_sql: str = None): path = os.path.basename(urlparse(url).path) # Download the database if it doesn't exist if not os.path.exists(path): - response = requests.get(url) + response = requests.get(url, timeout=60) response.raise_for_status() # Check that the request was successful with open(path, "wb") as f: f.write(response.content) diff --git a/src/vanna/flask/__init__.py b/src/vanna/flask/__init__.py index 83c2f8f4..01931ae9 100644 --- a/src/vanna/flask/__init__.py +++ b/src/vanna/flask/__init__.py @@ -736,7 +736,7 @@ def proxy_assets(filename): @self.flask_app.route("/vanna.svg") def proxy_vanna_svg(): remote_url = "https://vanna.ai/img/vanna.svg" - response = requests.get(remote_url, stream=True) + response = requests.get(remote_url, stream=True, timeout=60) # Check if the request to the remote URL was successful if response.status_code == 200: diff --git a/src/vanna/vannadb/vannadb_vector.py b/src/vanna/vannadb/vannadb_vector.py index eec372e5..17f14af8 100644 --- a/src/vanna/vannadb/vannadb_vector.py +++ b/src/vanna/vannadb/vannadb_vector.py @@ -60,7 +60,7 @@ def _rpc_call(self, method, params): "params": [self._dataclass_to_dict(obj) for obj in params], } - response = requests.post(self._endpoint, headers=headers, data=json.dumps(data)) + response = requests.post(self._endpoint, headers=headers, data=json.dumps(data), timeout=60) return response.json() def _dataclass_to_dict(self, obj): @@ -85,7 +85,7 @@ def get_all_functions(self) -> list: } """ - response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query}) + response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query}, timeout=60) response_json = response.json() if response.status_code == 200 and 'data' in response_json and 'get_all_sql_functions' in response_json['data']: self.log(response_json['data']['get_all_sql_functions']) @@ -123,7 +123,7 @@ def get_function(self, question: str, additional_data: dict = {}) -> dict: """ static_function_arguments = [{"name": key, "value": str(value)} for key, value in additional_data.items()] variables = {"question": question, "staticFunctionArguments": static_function_arguments} - response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables}) + response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables}, timeout=60) response_json = response.json() if response.status_code == 200 and 'data' in response_json and 'get_and_instantiate_function' in response_json['data']: self.log(response_json['data']['get_and_instantiate_function']) @@ -153,7 +153,7 @@ def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) - } """ variables = {"question": question, "sql": sql, "plotly_code": plotly_code} - response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables}) + response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables}, timeout=60) response_json = response.json() if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'generate_and_create_sql_function' in response_json['data']: resp = response_json['data']['generate_and_create_sql_function'] @@ -216,7 +216,7 @@ def validate_arguments(args): print("variables", variables) - response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables}) + response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables}, timeout=60) response_json = response.json() if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'update_sql_function' in response_json['data']: return response_json['data']['update_sql_function'] @@ -230,7 +230,7 @@ def delete_function(self, function_name: str) -> bool: } """ variables = {"function_name": function_name} - response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables}) + response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables}, timeout=60) response_json = response.json() if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'delete_sql_function' in response_json['data']: return response_json['data']['delete_sql_function'] diff --git a/src/vanna/vllm/vllm.py b/src/vanna/vllm/vllm.py index 53990821..70335692 100644 --- a/src/vanna/vllm/vllm.py +++ b/src/vanna/vllm/vllm.py @@ -78,11 +78,11 @@ def submit_prompt(self, prompt, **kwargs) -> str: 'Authorization': f'Bearer {self.auth_key}' } - response = requests.post(url, headers=headers,json=data) + response = requests.post(url, headers=headers,json=data, timeout=60) else: - response = requests.post(url, json=data) + response = requests.post(url, json=data, timeout=60) response_dict = response.json()