diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index f7c22a1e..3dd52673 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -1,7 +1,7 @@ import base64 import logging import urllib.parse -from typing import Dict, Union +from typing import Dict, Optional, Union import six import thrift @@ -31,6 +31,7 @@ def __init__( ssl_context=None, max_connections: int = 1, retry_policy: Union[DatabricksRetryPolicy, int] = 0, + proxies: Optional[Dict[str, str]] = None, ): if port is not None: warnings.warn( @@ -60,8 +61,11 @@ def __init__( self.path = parsed.path if parsed.query: self.path += "?%s" % parsed.query + + if proxies is None: + proxies = urllib.request.getproxies() try: - proxy = urllib.request.getproxies()[self.scheme] + proxy = proxies[self.scheme] except KeyError: proxy = None else: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index e9e2978f..f4fcd6ef 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -158,6 +158,7 @@ def read(self) -> Optional[OAuthToken]: STRUCT is returned as Dict[str, Any] ARRAY is returned as numpy.ndarray When False, complex types are returned as a strings. These are generally deserializable as JSON. + :param proxies: An optional dictionary mapping protocol to the URL of the proxy. """ # Internal arguments in **kwargs: @@ -206,6 +207,7 @@ def read(self) -> Optional[OAuthToken]: self.port = kwargs.get("_port", 443) self.disable_pandas = kwargs.get("_disable_pandas", False) self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True) + self.proxies = kwargs.get("proxies") auth_provider = get_python_sql_connector_auth_provider( server_hostname, **kwargs @@ -651,7 +653,12 @@ def _handle_staging_put( raise Error("Cannot perform PUT without specifying a local_file") with open(local_file, "rb") as fh: - r = requests.put(url=presigned_url, data=fh, headers=headers) + r = requests.put( + url=presigned_url, + data=fh, + headers=headers, + proxies=self.connection.proxies, + ) # fmt: off # Design borrowed from: https://stackoverflow.com/a/2342589/5093960 @@ -685,7 +692,9 @@ def _handle_staging_get( if local_file is None: raise Error("Cannot perform GET without specifying a local_file") - r = requests.get(url=presigned_url, headers=headers) + r = requests.get( + url=presigned_url, headers=headers, proxies=self.connection.proxies + ) # response.ok verifies the status code is not between 400-600. # Any 2xx or 3xx will evaluate r.ok == True @@ -700,7 +709,9 @@ def _handle_staging_get( def _handle_staging_remove(self, presigned_url: str, headers: dict = None): """Make an HTTP DELETE request to the presigned_url""" - r = requests.delete(url=presigned_url, headers=headers) + r = requests.delete( + url=presigned_url, headers=headers, proxies=self.connection.proxies + ) if not r.ok: raise Error( diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index e7b6dfd1..123df29d 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -6,7 +6,7 @@ import uuid import threading from ssl import CERT_NONE, CERT_REQUIRED, create_default_context -from typing import List, Union +from typing import List, Union, Optional, Dict import pyarrow import thrift.transport.THttpClient @@ -220,6 +220,9 @@ def __init__( additional_transport_args["retry_policy"] = self.retry_policy + if "proxies" in kwargs: + additional_transport_args["proxies"] = kwargs["proxies"] + self._transport = databricks.sql.auth.thrift_http_client.THttpClient( auth_provider=self._auth_provider, uri_or_host=uri,