diff --git a/trino/client.py b/trino/client.py
index 6184c54c..b0b1e64b 100644
--- a/trino/client.py
+++ b/trino/client.py
@@ -44,6 +44,7 @@
 import threading
 import urllib.parse
 import warnings
+from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from datetime import datetime
 from email.utils import parsedate_to_datetime
@@ -764,6 +765,7 @@ def __init__(
         self._result: Optional[TrinoResult] = None
         self._legacy_primitive_types = legacy_primitive_types
         self._row_mapper: Optional[RowMapper] = None
+        self._executor = ThreadPoolExecutor(max_workers=1)
 
     @property
     def query_id(self) -> Optional[str]:
@@ -868,7 +870,7 @@ def fetch(self) -> List[List[Any]]:
             # spooled protocol
             encoding = rows["encoding"]
             segments = rows["segments"]
-            return list(SegmentIterator(segments, encoding, self._row_mapper, self._request))
+            return list(SegmentIterator(segments, encoding, self._row_mapper, self._request, self._executor))
         else:
             return self._row_mapper.map(rows)
 
@@ -963,7 +965,14 @@ def _parse_retry_after_header(retry_after):
 
 
 class SegmentIterator:
-    def __init__(self, segments: List[SpooledSegment], encoding: str, row_mapper: RowMapper, request: TrinoRequest):
+    def __init__(
+        self,
+        segments: List[SpooledSegment],
+        encoding: str,
+        row_mapper: RowMapper,
+        request: TrinoRequest,
+        executor: ThreadPoolExecutor
+    ):
         self._segments = iter(segments)
         self._encoding = encoding
         self._row_mapper = row_mapper
@@ -971,6 +980,8 @@ def __init__(self, segments: List[SpooledSegment], encoding: str, row_mapper: Ro
         self._rows: Iterator[List[List[Any]]] = iter([])
         self._finished = False
         self._current_segment: Optional[SpooledSegment] = None
+        self._executor = executor
+        self._future: Optional[Future] = None
 
     def __iter__(self) -> Iterator[List[Any]]:
         return self
@@ -991,26 +1002,34 @@ def __next__(self) -> List[Any]:
                 self._load_next_row_set()
 
     def _load_next_row_set(self):
-        try:
-            self._current_segment = segment = next(self._segments)
-            segment_type = segment["type"]
-
-            if segment_type == "inline":
-                data = segment["data"]
-                decoded_string = base64.b64decode(data)
-                rows = self._row_mapper.map(json.loads(decoded_string))
-                self._rows = iter(rows)
-
-            elif segment_type == "spooled":
-                decoded_string = self._load_spooled_segment(segment)
-                rows = self._row_mapper.map(json.loads(decoded_string))
-                self._rows = iter(rows)
-            else:
-                raise ValueError(f"Unsupported segment type: {segment_type}")
+        if self._future:
+            # Wait for the future to complete and get the result
+            result = self._future.result()
+            self._current_segment, rows = result
+            self._rows = iter(rows)
+        try:
+            # Preload the next segment asynchronously
+            next_segment = next(self._segments)
+            self._future = self._executor.submit(self._fetch_and_decode_segment, next_segment)
         except StopIteration:
             self._finished = True
 
+    def _fetch_and_decode_segment(self, segment):
+        segment_type = segment["type"]
+
+        if segment_type == "inline":
+            data = segment["data"]
+            decoded_string = base64.b64decode(data)
+            rows = self._row_mapper.map(json.loads(decoded_string))
+            return segment, rows
+        elif segment_type == "spooled":
+            decoded_string = self._load_spooled_segment(segment)
+            rows = self._row_mapper.map(json.loads(decoded_string))
+            return segment, rows
+        else:
+            raise ValueError(f"Unsupported segment type: {segment['type']}")
+
     def _load_spooled_segment(self, segment: SpooledSegment) -> str:
         uri = segment["uri"]
         encoding = self._encoding