fixes for cloud fetch - part un (#356)
* fixes for cloud fetch

Signed-off-by: Andre Furlan <[email protected]>
---------

Signed-off-by: Andre Furlan <[email protected]>
Co-authored-by: Raymond Cypher <[email protected]>
andrefurlan-db and rcypher-databricks authored Feb 16, 2024
1 parent a737ef3 commit 6a348ec
Showing 9 changed files with 229 additions and 143 deletions.
33 changes: 33 additions & 0 deletions examples/custom_logger.py
@@ -0,0 +1,33 @@
from databricks import sql
import os
import logging


logger = logging.getLogger("databricks.sql")
logger.setLevel(logging.DEBUG)
fh = logging.FileHandler("pysqllogs.log")
fh.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(process)d %(thread)d %(message)s"))
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)

with sql.connect(
    server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
    http_path=os.getenv("DATABRICKS_HTTP_PATH"),
    access_token=os.getenv("DATABRICKS_TOKEN"),
    use_cloud_fetch=True,
    max_download_threads=2,
) as connection:

    with connection.cursor(arraysize=1000, buffer_size_bytes=54857600) as cursor:
        print(
            "executing query: SELECT * FROM range(0, 20000000) AS t1 LEFT JOIN (SELECT 1) AS t2"
        )
        cursor.execute("SELECT * FROM range(0, 20000000) AS t1 LEFT JOIN (SELECT 1) AS t2")
        try:
            while True:
                row = cursor.fetchone()
                if row is None:
                    break
                print(f"row: {row}")
        except sql.exc.ResultSetDownloadError as e:
            print(f"error: {e}")
32 changes: 32 additions & 0 deletions src/databricks/sql/__init__.py
@@ -7,6 +7,38 @@
threadsafety = 1 # Threads may share the module, but not connections.
paramstyle = "pyformat" # Python extended format codes, e.g. ...WHERE name=%(name)s

import re


class RedactUrlQueryParamsFilter(logging.Filter):
    pattern = re.compile(r"(\?|&)([\w-]+)=([^&\s]+)")
    mask = r"\1\2=<REDACTED>"

    def __init__(self):
        super().__init__()

    def redact(self, string):
        return re.sub(self.pattern, self.mask, str(string))

    def filter(self, record):
        record.msg = self.redact(str(record.msg))
        if isinstance(record.args, dict):
            for k in record.args.keys():
                record.args[k] = (
                    self.redact(record.args[k])
                    if isinstance(record.args[k], str)
                    else record.args[k]
                )
        else:
            record.args = tuple(
                (self.redact(arg) if isinstance(arg, str) else arg)
                for arg in record.args
            )

        return True


logging.getLogger("urllib3.connectionpool").addFilter(RedactUrlQueryParamsFilter())

class DBAPITypeObject(object):
    def __init__(self, *values):
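A quick way to sanity-check the new filter is to run it against a hand-built log record. This is an illustrative sketch, not part of the commit; the sample URL and query parameters are invented:

import logging

from databricks.sql import RedactUrlQueryParamsFilter

redactor = RedactUrlQueryParamsFilter()
record = logging.LogRecord(
    name="urllib3.connectionpool", level=logging.DEBUG, pathname="", lineno=0,
    msg="GET /files?X-Amz-Signature=abc123&X-Amz-Expires=900 HTTP/1.1",
    args=(), exc_info=None,
)
redactor.filter(record)
# Both query parameter values are masked before the record is emitted:
print(record.msg)  # GET /files?X-Amz-Signature=<REDACTED>&X-Amz-Expires=<REDACTED> HTTP/1.1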
40 changes: 8 additions & 32 deletions src/databricks/sql/cloudfetch/download_manager.py
@@ -8,6 +8,7 @@
    ResultSetDownloadHandler,
    DownloadableResultSettings,
)
from databricks.sql.exc import ResultSetDownloadError
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

logger = logging.getLogger(__name__)
@@ -34,8 +35,6 @@ def __init__(self, max_download_threads: int, lz4_compressed: bool):
        self.download_handlers: List[ResultSetDownloadHandler] = []
        self.thread_pool = ThreadPoolExecutor(max_workers=max_download_threads + 1)
        self.downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
        self.fetch_need_retry = False
        self.num_consecutive_result_file_download_retries = 0

    def add_file_links(
        self, t_spark_arrow_result_links: List[TSparkArrowResultLink]
@@ -81,13 +80,15 @@ def get_next_downloaded_file(

        # Find next file
        idx = self._find_next_file_index(next_row_offset)
        # No handler owns the next row offset, so there is nothing left to download
        if idx is None:
            self._shutdown_manager()
            logger.debug("could not find next file index")
            return None
        handler = self.download_handlers[idx]

        # Check (and wait) for download status
        if self._check_if_download_successful(handler):
        if handler.is_file_download_successful():
            # Buffer should be empty so set buffer to new ArrowQueue with result_file
            result = DownloadedFile(
                handler.result_file,
@@ -97,9 +98,11 @@ def get_next_downloaded_file(
            self.download_handlers.pop(idx)
            # Return the downloaded file so the fetch loop can continue
            return result
        # Download was not successful for next download item, force a retry
        # Download was not successful for next download item. Fail
        self._shutdown_manager()
        return None
        raise ResultSetDownloadError(
            f"Download failed for result set starting at {next_row_offset}"
        )

    def _remove_past_handlers(self, next_row_offset: int):
        # Any link whose start-to-end range does not include the next row to be fetched does not need downloading
@@ -133,33 +136,6 @@ def _find_next_file_index(self, next_row_offset: int):
        ]
        return next_indices[0] if len(next_indices) > 0 else None

    def _check_if_download_successful(self, handler: ResultSetDownloadHandler):
        # Check (and wait until download finishes) if download was successful
        if not handler.is_file_download_successful():
            if handler.is_link_expired:
                self.fetch_need_retry = True
                return False
            elif handler.is_download_timedout:
                # Consecutive file retries should not exceed threshold in settings
                if (
                    self.num_consecutive_result_file_download_retries
                    >= self.downloadable_result_settings.max_consecutive_file_download_retries
                ):
                    self.fetch_need_retry = True
                    return False
                self.num_consecutive_result_file_download_retries += 1

                # Re-submit handler run to thread pool and recursively check download status
                self.thread_pool.submit(handler.run)
                return self._check_if_download_successful(handler)
            else:
                self.fetch_need_retry = True
                return False

        self.num_consecutive_result_file_download_retries = 0
        self.fetch_need_retry = False
        return True

    def _shutdown_manager(self):
        # Clear download handlers and shutdown the thread pool
        self.download_handlers = []
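With _check_if_download_successful gone, a failed or timed-out download now surfaces to the caller as ResultSetDownloadError instead of feeding the old silent retry loop. A sketch of caller-side handling, assuming a cursor obtained as in examples/custom_logger.py above:

from databricks.sql.exc import ResultSetDownloadError

try:
    rows = cursor.fetchall()
except ResultSetDownloadError as e:
    # The download manager has already shut down its thread pool at this point;
    # whether to re-run the query is the caller's decision.
    print(f"cloud fetch failed: {e}")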
120 changes: 93 additions & 27 deletions src/databricks/sql/cloudfetch/downloader.py
@@ -1,15 +1,17 @@
import logging
from dataclasses import dataclass

import requests
import lz4.frame
import threading
import time

import os
import re
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

logger = logging.getLogger(__name__)

DEFAULT_CLOUD_FILE_TIMEOUT = int(os.getenv("DATABRICKS_CLOUD_FILE_TIMEOUT", 60))


@dataclass
class DownloadableResultSettings:
@@ -20,13 +22,17 @@ class DownloadableResultSettings:
        is_lz4_compressed (bool): Whether file is expected to be lz4 compressed.
        link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
        download_timeout (int): Timeout for download requests. Default 60 secs.
        max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down.
        download_max_retries (int): Number of consecutive download retries before shutting down.
        max_retries (int): Number of consecutive download retries before shutting down.
        backoff_factor (int): Factor to increase wait time between retries.
    """

    is_lz4_compressed: bool
    link_expiry_buffer_secs: int = 0
    download_timeout: int = 60
    max_consecutive_file_download_retries: int = 0
    download_timeout: int = DEFAULT_CLOUD_FILE_TIMEOUT
    max_retries: int = 5
    backoff_factor: int = 2
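    # With these defaults the http_get_with_retry helper added at the bottom of
    # this file sleeps backoff_factor**attempt seconds between tries: 1, 2, 4, 8,
    # then 16, i.e. roughly 31 seconds of backoff in the worst case, on top of
    # the per-request download_timeout.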


class ResultSetDownloadHandler(threading.Thread):
@@ -57,16 +63,21 @@ def is_file_download_successful(self) -> bool:
            else None
        )
        try:
            logger.debug(
                f"waiting for at most {timeout} seconds for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
            )

            if not self.is_download_finished.wait(timeout=timeout):
                self.is_download_timedout = True
                logger.debug(
                    "Cloud fetch download timed out after {} seconds for link representing rows {} to {}".format(
                        self.settings.download_timeout,
                        self.result_link.startRowOffset,
                        self.result_link.startRowOffset + self.result_link.rowCount,
                    )
                logger.error(
                    f"cloud fetch download timed out after {self.settings.download_timeout} seconds for link representing rows {self.result_link.startRowOffset} to {self.result_link.startRowOffset + self.result_link.rowCount}"
                )
                return False
                # There are some rare cases where is_download_finished is not set even though the file downloaded successfully
                return self.is_file_downloaded_successfully

            logger.debug(
                f"finish waiting for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
            )
        except Exception as e:
            logger.error(e)
            return False
@@ -81,24 +92,36 @@ def run(self):
"""
self._reset()

# Check if link is already expired or is expiring
if ResultSetDownloadHandler.check_link_expired(
self.result_link, self.settings.link_expiry_buffer_secs
):
self.is_link_expired = True
return
try:
# Check if link is already expired or is expiring
if ResultSetDownloadHandler.check_link_expired(
self.result_link, self.settings.link_expiry_buffer_secs
):
self.is_link_expired = True
return

session = requests.Session()
session.timeout = self.settings.download_timeout
logger.debug(
f"started to download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)

try:
# Get the file via HTTP request
response = session.get(self.result_link.fileLink)
response = http_get_with_retry(
url=self.result_link.fileLink,
max_retries=self.settings.max_retries,
backoff_factor=self.settings.backoff_factor,
download_timeout=self.settings.download_timeout,
)

if not response.ok:
self.is_file_downloaded_successfully = False
if not response:
logger.error(
f"failed downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
return

logger.debug(
f"success downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)

# Save (and decompress if needed) the downloaded file
compressed_data = response.content
decompressed_data = (
@@ -109,15 +132,22 @@ def run(self):
            self.result_file = decompressed_data

            # The size of the downloaded file should match the size specified from TSparkArrowResultLink
            self.is_file_downloaded_successfully = (
                len(self.result_file) == self.result_link.bytesNum
            success = len(self.result_file) == self.result_link.bytesNum
            logger.debug(
                f"download successful file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
            )
            self.is_file_downloaded_successfully = success
        except Exception as e:
            logger.error(
                f"exception downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
            )
            logger.error(e)
            self.is_file_downloaded_successfully = False

        finally:
            session and session.close()
            logger.debug(
                f"signal finished file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
            )
            # Awaken threads waiting for this to be true which signals the run is complete
            self.is_download_finished.set()

@@ -145,6 +175,7 @@ def check_link_expired(
            link.expiryTime < current_time
            or link.expiryTime - current_time < expiry_buffer_secs
        ):
            logger.debug("link expired")
            return True
        return False

@@ -171,3 +202,38 @@ def decompress_data(compressed_data: bytes) -> bytes:
            uncompressed_data += data
            start += num_bytes
        return uncompressed_data


def http_get_with_retry(url, max_retries=5, backoff_factor=2, download_timeout=60):
    attempts = 0
    pattern = re.compile(r"(\?|&)([\w-]+)=([^&\s]+)")
    mask = r"\1\2=<REDACTED>"

    # TODO: introduce connection pooling. I am seeing weird errors without it.
    while attempts < max_retries:
        try:
            session = requests.Session()
            session.timeout = download_timeout
            response = session.get(url)

            # A 200 status code indicates success
            if response.status_code == 200:
                return response
            else:
                logger.error(response)
        except requests.RequestException as e:
            # If this is not redacted, it will print the pre-signed URL
            logger.error(f"request failed with exception: {re.sub(pattern, mask, str(e))}")
        finally:
            session.close()
        # Exponential backoff before the next attempt
        wait_time = backoff_factor**attempts
        logger.info(f"retrying in {wait_time} seconds...")
        time.sleep(wait_time)

        attempts += 1

    logger.error(
        f"exceeded maximum number of retries ({max_retries}) while downloading result."
    )
    return None
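The helper is importable on its own, which makes it easy to exercise the retry and backoff behavior in isolation; the URL below is a placeholder, not a real pre-signed link:

from databricks.sql.cloudfetch.downloader import http_get_with_retry

response = http_get_with_retry(
    "https://example.com/presigned/result-chunk",  # hypothetical URL
    max_retries=3,
    backoff_factor=2,
    download_timeout=30,
)
if response is None:
    print("all attempts failed")
else:
    print(f"downloaded {len(response.content)} bytes")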
4 changes: 4 additions & 0 deletions src/databricks/sql/exc.py
@@ -115,3 +115,7 @@ class SessionAlreadyClosedError(RequestError):

class CursorAlreadyClosedError(RequestError):
"""Thrown if CancelOperation receives a code 404. ThriftBackend should gracefully proceed as this is expected."""


class ResultSetDownloadError(RequestError):
"""Thrown if there was an error during the download of a result set"""