Skip to content

Commit

Permalink
Use background thread
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet committed Nov 26, 2024
1 parent d7f4ee3 commit 6ec4713
Showing 1 changed file with 37 additions and 18 deletions.
55 changes: 37 additions & 18 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import threading
import urllib.parse
import warnings
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime
from email.utils import parsedate_to_datetime
Expand Down Expand Up @@ -764,6 +765,7 @@ def __init__(
self._result: Optional[TrinoResult] = None
self._legacy_primitive_types = legacy_primitive_types
self._row_mapper: Optional[RowMapper] = None
self._executor = ThreadPoolExecutor(max_workers=1)

@property
def query_id(self) -> Optional[str]:
Expand Down Expand Up @@ -868,7 +870,7 @@ def fetch(self) -> List[List[Any]]:
# spooled protocol
encoding = rows["encoding"]
segments = rows["segments"]
return list(SegmentIterator(segments, encoding, self._row_mapper, self._request))
return list(SegmentIterator(segments, encoding, self._row_mapper, self._request, self._executor))
else:
return self._row_mapper.map(rows)

Expand Down Expand Up @@ -963,14 +965,23 @@ def _parse_retry_after_header(retry_after):


class SegmentIterator:
def __init__(self, segments: List[SpooledSegment], encoding: str, row_mapper: RowMapper, request: TrinoRequest):
def __init__(
    self,
    segments: List[SpooledSegment],
    encoding: str,
    row_mapper: RowMapper,
    request: TrinoRequest,
    executor: ThreadPoolExecutor
):
    """Set up iteration over the rows contained in ``segments``.

    :param segments: segment descriptors; consumed one at a time through
        an iterator, so the list itself is not modified.
    :param encoding: encoding name, read when a spooled segment is loaded.
    :param row_mapper: mapper whose ``map`` method is applied to the
        JSON-decoded payload of each segment.
    :param request: request object kept on the instance (its use is not
        visible in this part of the file; presumably it downloads spooled
        segments -- confirm against ``_load_spooled_segment``).
    :param executor: executor that segment loading is submitted to.
    """
    # Collaborators supplied by the caller.
    self._encoding = encoding
    self._row_mapper = row_mapper
    self._request = request
    self._executor = executor

    # Iteration state.
    self._segments = iter(segments)
    self._current_segment: Optional[SpooledSegment] = None
    # Rows of the segment currently being served; empty until one is loaded.
    self._rows: Iterator[List[List[Any]]] = iter([])
    # Result handle of the segment load most recently submitted to the executor.
    self._future: Optional[Future] = None
    # Set once the segment iterator has been exhausted.
    self._finished = False

def __iter__(self) -> Iterator[List[Any]]:
    """Return this object itself; it acts as its own iterator."""
    return self
Expand All @@ -991,26 +1002,34 @@ def __next__(self) -> List[Any]:
self._load_next_row_set()

def _load_next_row_set(self):
try:
self._current_segment = segment = next(self._segments)
segment_type = segment["type"]

if segment_type == "inline":
data = segment["data"]
decoded_string = base64.b64decode(data)
rows = self._row_mapper.map(json.loads(decoded_string))
self._rows = iter(rows)

elif segment_type == "spooled":
decoded_string = self._load_spooled_segment(segment)
rows = self._row_mapper.map(json.loads(decoded_string))
self._rows = iter(rows)
else:
raise ValueError(f"Unsupported segment type: {segment_type}")
if self._future:
# Wait for the future to complete and get the result
result = self._future.result()
self._current_segment, rows = result
self._rows = iter(rows)

try:
# Preload the next segment asynchronously
next_segment = next(self._segments)
self._future = self._executor.submit(self._fetch_and_decode_segment, next_segment)
except StopIteration:
self._finished = True

def _fetch_and_decode_segment(self, segment):
    """Load the payload of one segment and map it into rows.

    An ``"inline"`` segment carries its payload base64-encoded under
    ``segment["data"]``; a ``"spooled"`` segment is loaded through
    ``self._load_spooled_segment``. In both cases the payload is parsed
    as JSON and passed through ``self._row_mapper.map``.

    :param segment: mapping with at least a ``"type"`` key.
    :returns: tuple ``(segment, rows)`` -- the segment that was passed in,
        paired with its mapped rows.
    :raises ValueError: if ``segment["type"]`` is neither ``"inline"`` nor
        ``"spooled"``.
    """
    kind = segment["type"]

    # Only the way the raw payload is obtained differs between segment
    # types; decoding and row mapping are shared below.
    if kind == "inline":
        payload = base64.b64decode(segment["data"])
    elif kind == "spooled":
        payload = self._load_spooled_segment(segment)
    else:
        raise ValueError(f"Unsupported segment type: {segment['type']}")

    return segment, self._row_mapper.map(json.loads(payload))

def _load_spooled_segment(self, segment: SpooledSegment) -> str:
uri = segment["uri"]
encoding = self._encoding
Expand Down

0 comments on commit 6ec4713

Please sign in to comment.