diff --git a/README.md b/README.md index 3c911e4e..6f0eb90a 100644 --- a/README.md +++ b/README.md @@ -469,6 +469,30 @@ conn = connect( ) ``` +## Spooled protocol + +The client spooling protocol requires [a Trino server with spooling protocol support](https://trino.io/docs/current/client/client-protocol.html#spooling-protocol). + +Enable the spooling protocol by specifying a supported encoding in the `encoding` parameter: + +```python +from trino.dbapi import connect + +conn = connect( + encoding="json+zstd" +) +``` + +or a list of supported encodings: + +```python +from trino.dbapi import connect + +conn = connect( + encoding=["json+zstd", "json"] +) +``` + ## Transactions The client runs by default in *autocommit* mode. To enable transactions, set diff --git a/etc/catalog/jmx.properties b/etc/catalog/jmx.properties new file mode 100644 index 00000000..b6e0372b --- /dev/null +++ b/etc/catalog/jmx.properties @@ -0,0 +1 @@ +connector.name=jmx diff --git a/etc/catalog/memory.properties b/etc/catalog/memory.properties new file mode 100644 index 00000000..833abd3f --- /dev/null +++ b/etc/catalog/memory.properties @@ -0,0 +1 @@ +connector.name=memory diff --git a/etc/catalog/tpcds.properties b/etc/catalog/tpcds.properties new file mode 100644 index 00000000..ba8147db --- /dev/null +++ b/etc/catalog/tpcds.properties @@ -0,0 +1 @@ +connector.name=tpcds diff --git a/etc/catalog/tpch.properties b/etc/catalog/tpch.properties new file mode 100644 index 00000000..599f5ec6 --- /dev/null +++ b/etc/catalog/tpch.properties @@ -0,0 +1,2 @@ +connector.name=tpch +tpch.splits-per-node=4 diff --git a/etc/config-pre-466.properties b/etc/config-pre-466.properties new file mode 100644 index 00000000..e28f2281 --- /dev/null +++ b/etc/config-pre-466.properties @@ -0,0 +1,11 @@ +node.id=coordinator +node.environment=test + +coordinator=true +node-scheduler.include-coordinator=true +http-server.http.port=8080 +query.max-memory=1GB +discovery.uri=http://localhost:8080 + +# Disable http request 
log +http-server.log.enabled=false diff --git a/etc/config.properties b/etc/config.properties new file mode 100644 index 00000000..10372938 --- /dev/null +++ b/etc/config.properties @@ -0,0 +1,17 @@ +node.id=coordinator +node.environment=test + +coordinator=true +experimental.concurrent-startup=true +node-scheduler.include-coordinator=true +http-server.http.port=8080 +query.max-memory=1GB +discovery.uri=http://localhost:8080 + +# spooling protocol settings +protocol.spooling.enabled=true +protocol.spooling.shared-secret-key=jxTKysfCBuMZtFqUf8UJDQ1w9ez8rynEJsJqgJf66u0= +protocol.spooling.retrieval-mode=coordinator_proxy + +# Disable http request log +http-server.log.enabled=false diff --git a/etc/jvm-pre-466.config b/etc/jvm-pre-466.config new file mode 100644 index 00000000..09753c04 --- /dev/null +++ b/etc/jvm-pre-466.config @@ -0,0 +1,16 @@ +-server +-Xmx2G +-XX:G1HeapRegionSize=32M +-XX:+ExplicitGCInvokesConcurrent +-XX:+ExitOnOutOfMemoryError +-XX:+HeapDumpOnOutOfMemoryError +-XX:-OmitStackTraceInFastThrow +-XX:ReservedCodeCacheSize=150M +-XX:PerMethodRecompilationCutoff=10000 +-XX:PerBytecodeRecompilationCutoff=10000 +-Djdk.attach.allowAttachSelf=true +# jdk.nio.maxCachedBufferSize controls what buffers can be allocated in per-thread "temporary buffer cache" (sun.nio.ch.Util). Value of 0 disables the cache. 
+-Djdk.nio.maxCachedBufferSize=0 +# Allow loading dynamic agent used by JOL +-XX:+EnableDynamicAgentLoading +-XX:+UnlockDiagnosticVMOptions diff --git a/etc/jvm.config b/etc/jvm.config new file mode 100644 index 00000000..08e3285d --- /dev/null +++ b/etc/jvm.config @@ -0,0 +1,17 @@ +-server +-Xmx2G +-XX:G1HeapRegionSize=32M +-XX:+ExplicitGCInvokesConcurrent +-XX:+ExitOnOutOfMemoryError +-XX:+HeapDumpOnOutOfMemoryError +-XX:-OmitStackTraceInFastThrow +-XX:ReservedCodeCacheSize=150M +-XX:PerMethodRecompilationCutoff=10000 +-XX:PerBytecodeRecompilationCutoff=10000 +-Djdk.attach.allowAttachSelf=true +# jdk.nio.maxCachedBufferSize controls what buffers can be allocated in per-thread "temporary buffer cache" (sun.nio.ch.Util). Value of 0 disables the cache. +-Djdk.nio.maxCachedBufferSize=0 +# Allow loading dynamic agent used by JOL +-XX:+EnableDynamicAgentLoading +-XX:+UnlockDiagnosticVMOptions +--enable-native-access=ALL-UNNAMED diff --git a/etc/spooling-manager.properties b/etc/spooling-manager.properties new file mode 100644 index 00000000..72d8e396 --- /dev/null +++ b/etc/spooling-manager.properties @@ -0,0 +1,8 @@ +spooling-manager.name=filesystem +fs.s3.enabled=true +fs.location=s3://spooling/ +s3.endpoint=http://localstack:4566/ +s3.region=us-east-1 +s3.aws-access-key=test +s3.aws-secret-key=test +s3.path-style-access=true diff --git a/setup.py b/setup.py index 0a512e6e..e497ab36 100755 --- a/setup.py +++ b/setup.py @@ -46,7 +46,9 @@ "pre-commit", "black", "isort", - "keyring" + "keyring", + "testcontainers", + "boto3" ] setup( @@ -81,11 +83,13 @@ ], python_requires=">=3.9", install_requires=[ + "lz4", "python-dateutil", "pytz", # requests CVE https://github.com/advisories/GHSA-j8r2-6x86-q33q "requests>=2.31.0", "tzlocal", + "zstandard", ], extras_require={ "all": all_require, diff --git a/tests/development_server.py b/tests/development_server.py new file mode 100644 index 00000000..422f4bcd --- /dev/null +++ b/tests/development_server.py @@ -0,0 +1,127 @@ +import 
os +import time +from contextlib import contextmanager +from pathlib import Path + +from testcontainers.core.container import DockerContainer +from testcontainers.core.network import Network +from testcontainers.core.waiting_utils import wait_for_logs +from testcontainers.localstack import LocalStackContainer + +from trino.constants import DEFAULT_PORT + +MINIO_ROOT_USER = "minio-access-key" +MINIO_ROOT_PASSWORD = "minio-secret-key" + +TRINO_VERSION = os.environ.get("TRINO_VERSION") or "latest" +TRINO_HOST = "localhost" + + +def create_bucket(s3_client): + bucket_name = "spooling" + try: + print("Checking for bucket existence...") + response = s3_client.list_buckets() + buckets = [bucket["Name"] for bucket in response["Buckets"]] + if bucket_name in buckets: + print("Bucket exists!") + return + except s3_client.exceptions.ClientError as e: + if e.response['Error']['Code'] != '404': + print("An error occurred:", e) + return + + try: + print("Creating bucket...") + s3_client.create_bucket( + Bucket=bucket_name, + ) + print("Bucket created!") + except s3_client.exceptions.ClientError as e: + print("An error occurred:", e) + + +@contextmanager +def start_development_server(port=None, trino_version=TRINO_VERSION): + network = None + localstack = None + trino = None + + try: + network = Network().create() + supports_spooling_protocol = trino_version == "latest" or int(trino_version) >= 466 + if supports_spooling_protocol: + localstack = LocalStackContainer(image="localstack/localstack:latest", region_name="us-east-1") \ + .with_name("localstack") \ + .with_network(network) \ + .with_bind_ports(4566, 4566) \ + .with_bind_ports(4571, 4571) \ + .with_env("SERVICES", "s3") + + # Start the container + print("Starting LocalStack container...") + localstack.start() + + # Wait for logs indicating LocalStack has started + wait_for_logs(localstack, "Ready.", timeout=30) + + # create spooling bucket + create_bucket(localstack.get_client("s3")) + + trino = 
DockerContainer(f"trinodb/trino:{trino_version}") \ + .with_name("trino") \ + .with_network(network) \ + .with_env("TRINO_CONFIG_DIR", "/etc/trino") \ + .with_bind_ports(DEFAULT_PORT, port) + + root = Path(__file__).parent.parent + + trino = trino \ + .with_volume_mapping(str(root / "etc/catalog"), "/etc/trino/catalog") + + # Enable spooling config + if supports_spooling_protocol: + trino \ + .with_volume_mapping( + str(root / "etc/spooling-manager.properties"), + "/etc/trino/spooling-manager.properties", "rw") \ + .with_volume_mapping(str(root / "etc/jvm.config"), "/etc/trino/jvm.config") \ + .with_volume_mapping(str(root / "etc/config.properties"), "/etc/trino/config.properties") + else: + trino \ + .with_volume_mapping(str(root / "etc/jvm-pre-466.config"), "/etc/trino/jvm.config") \ + .with_volume_mapping(str(root / "etc/config-pre-466.properties"), "/etc/trino/config.properties") + + print("Starting Trino container...") + trino.start() + + # Wait for logs indicating the service has started + wait_for_logs(trino, "SERVER STARTED", timeout=60) + + # Otherwise some tests fail with No nodes available + time.sleep(2) + + yield localstack, trino, network + finally: + # Stop containers when exiting the context + if trino: + print("Stopping Trino container...") + trino.stop() + if localstack: + print("Stopping LocalStack container...") + localstack.stop() + if network: + network.remove() + + +def main(): + """Run Trino setup independently from pytest.""" + with start_development_server(port=DEFAULT_PORT): + print(f"Trino started at {TRINO_HOST}:{DEFAULT_PORT}") + + # Keep the process running so that the containers stay up + input("Press Enter to stop containers...") + + +if __name__ == "__main__": + main() diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 23fe3037..3184de6e 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -11,33 +11,27 @@ # limitations under the License. 
import os import socket -import subprocess -import time +import sys from contextlib import closing -from uuid import uuid4 import pytest import trino.logging -from trino.client import ClientSession -from trino.client import TrinoQuery -from trino.client import TrinoRequest +from tests.development_server import start_development_server +from tests.development_server import TRINO_HOST +from tests.development_server import TRINO_VERSION from trino.constants import DEFAULT_PORT logger = trino.logging.get_logger(__name__) -TRINO_VERSION = os.environ.get("TRINO_VERSION") or "latest" -TRINO_HOST = "127.0.0.1" -TRINO_PORT = 8080 - - -def is_trino_available(): +def is_trino_available(host, port): with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: sock.settimeout(2) - result = sock.connect_ex((TRINO_HOST, DEFAULT_PORT)) + result = sock.connect_ex((host, port)) if result == 0: return True + return False def get_local_port(): @@ -46,116 +40,24 @@ def get_local_port(): return s.getsockname()[1] -def get_default_trino_image_tag(): - return "trinodb/trino:" + TRINO_VERSION - - -def start_trino(image_tag=None): - if not image_tag: - image_tag = get_default_trino_image_tag() - - container_id = "trino-python-client-tests-" + uuid4().hex[:7] - local_port = get_local_port() - logger.info("starting Docker container") - docker_run = [ - "docker", - "run", - "--rm", - "-p", - "{host_port}:{cont_port}".format(host_port=local_port, cont_port=TRINO_PORT), - "--name", - container_id, - image_tag, - ] - run = subprocess.Popen(docker_run, universal_newlines=True, stderr=subprocess.PIPE) - return (container_id, run, "localhost", local_port) - - -def wait_for_trino_workers(host, port, timeout=180): - request = TrinoRequest( - host=host, - port=port, - client_session=ClientSession( - user="test_fixture" - ) - ) - sql = "SELECT state FROM system.runtime.nodes" - t0 = time.time() - while True: - query = TrinoQuery(request, sql) - rows = list(query.execute()) - if any(row[0] == 
"active" for row in rows): - break - if time.time() - t0 > timeout: - raise TimeoutError - time.sleep(1) - - -def wait_for_trino_coordinator(stream, timeout=180): - started_tag = "======== SERVER STARTED ========" - t0 = time.time() - for line in iter(stream.readline, b""): - if line: - print(line) - if started_tag in line: - time.sleep(5) - return True - if time.time() - t0 > timeout: - logger.error("coordinator took longer than %s to start", timeout) - raise TimeoutError - return False - - -def start_local_trino_server(image_tag): - container_id, proc, host, port = start_trino(image_tag) - print("trino.server.state starting") - trino_ready = wait_for_trino_coordinator(proc.stderr) - if not trino_ready: - raise Exception("Trino server did not start") - wait_for_trino_workers(host, port) - print("trino.server.state ready") - return container_id, proc, host, port - - -def start_trino_and_wait(image_tag=None): - container_id = None - proc = None - host = os.environ.get("TRINO_RUNNING_HOST", None) - if host: - port = os.environ.get("TRINO_RUNNING_PORT", DEFAULT_PORT) - else: - container_id, proc, host, port = start_local_trino_server( - image_tag - ) - - print("trino.server.hostname {}".format(host)) - print("trino.server.port {}".format(port)) - if proc: - print("trino.server.pid {}".format(proc.pid)) - if container_id: - print("trino.server.contained_id {}".format(container_id)) - return container_id, proc, host, port - - -def stop_trino(container_id, proc): - subprocess.check_call(["docker", "kill", container_id]) - - -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def run_trino(): - if is_trino_available(): - yield None, TRINO_HOST, DEFAULT_PORT - return + host = os.environ.get("TRINO_RUNNING_HOST", TRINO_HOST) + port = int(os.environ.get("TRINO_RUNNING_PORT", DEFAULT_PORT)) - image_tag = os.environ.get("TRINO_IMAGE") - if not image_tag: - image_tag = get_default_trino_image_tag() + # Is there any local Trino available + if is_trino_available(host, 
port): + yield host, port + return - container_id, proc, host, port = start_trino_and_wait(image_tag) - yield proc, host, port - if container_id or proc: - stop_trino(container_id, proc) + # Start Trino and LocalStack server + print(f"Could not connect to Trino at {host}:{port}, starting server...") + local_port = get_local_port() + with start_development_server(port=local_port): + yield TRINO_HOST, local_port -def trino_version(): - return TRINO_VERSION +def trino_version() -> int: + if TRINO_VERSION == "latest": + return sys.maxsize + return int(TRINO_VERSION) diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index 7165408f..d94f97f0 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -10,6 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +import sys import time as t import uuid from datetime import date @@ -37,18 +38,19 @@ from trino.transaction import IsolationLevel -@pytest.fixture -def trino_connection(run_trino): - _, host, port = run_trino +@pytest.fixture(params=[None, "json+zstd", "json+lz4", "json"]) +def trino_connection(request, run_trino): + host, port = run_trino + encoding = request.param yield trino.dbapi.Connection( - host=host, port=port, user="test", source="test", max_attempts=1 + host=host, port=port, user="test", source="test", max_attempts=1, encoding=encoding ) @pytest.fixture def trino_connection_with_transaction(run_trino): - _, host, port = run_trino + host, port = run_trino yield trino.dbapi.Connection( host=host, @@ -62,7 +64,7 @@ def trino_connection_with_transaction(run_trino): @pytest.fixture def trino_connection_in_autocommit(run_trino): - _, host, port = run_trino + host, port = run_trino yield trino.dbapi.Connection( host=host, @@ -80,10 +82,10 @@ def test_select_query(trino_connection): rows = cur.fetchall() assert len(rows) > 0 row = rows[0] - if 
trino_version() == "latest": + if trino_version() == sys.maxsize: assert row[2] is not None else: - assert row[2] == trino_version() + assert row[2] == str(trino_version()) columns = dict([desc[:2] for desc in cur.description]) assert columns["node_id"] == "varchar" assert columns["http_uri"] == "varchar" @@ -112,7 +114,7 @@ def test_select_query_result_iteration(trino_connection): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -146,7 +148,7 @@ def test_select_query_result_iteration_statement_params(legacy_prepared_statemen [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -165,7 +167,7 @@ def test_none_query_param(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -185,7 +187,7 @@ def test_string_query_param(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -224,7 +226,7 @@ def test_execute_many(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -249,7 +251,7 @@ def test_execute_many_without_params(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -268,7 +270,7 @@ def test_legacy_primitive_types_with_connection_and_cursor( cursor_legacy_primitive_types, 
run_trino ): - _, host, port = run_trino + host, port = run_trino connection = trino.dbapi.Connection( host=host, @@ -333,7 +335,7 @@ def test_legacy_primitive_types_with_connection_and_cursor( [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -353,7 +355,7 @@ def test_decimal_query_param(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -392,7 +394,7 @@ def test_null_decimal(trino_connection): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -413,7 +415,7 @@ def test_biggest_decimal(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -434,7 +436,7 @@ def test_smallest_decimal(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -455,7 +457,7 @@ def test_highest_precision_decimal(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -477,7 +479,7 @@ def test_datetime_query_param(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -499,7 +501,7 @@ def 
test_datetime_with_utc_time_zone_query_param(legacy_prepared_statements, run [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -523,7 +525,7 @@ def test_datetime_with_numeric_offset_time_zone_query_param(legacy_prepared_stat [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -575,7 +577,7 @@ def test_datetime_with_time_zone_numeric_offset(trino_connection): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -595,7 +597,7 @@ def test_datetimes_with_time_zone_in_dst_gap_query_param(legacy_prepared_stateme [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -619,7 +621,7 @@ def test_doubled_datetimes(fold, legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -667,7 +669,7 @@ def test_unsupported_python_dates(trino_connection): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -718,7 +720,7 @@ def test_char(trino_connection): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -740,7 +742,7 @@ def test_time_query_param(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 
417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -762,7 +764,7 @@ def test_time_with_named_time_zone_query_param(legacy_prepared_statements, run_t [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -838,7 +840,7 @@ def test_null_date_with_time_zone(trino_connection): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -867,7 +869,7 @@ def test_binary_query_param(binary_input, legacy_prepared_statements, run_trino) [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -896,7 +898,7 @@ def test_array_query_param(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -922,7 +924,7 @@ def test_array_none_query_param(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -948,7 +950,7 @@ def test_array_none_and_another_type_query_param(legacy_prepared_statements, run [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -974,7 +976,7 @@ def test_array_timestamp_query_param(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -1003,7 +1005,7 @@ def 
test_array_timestamp_with_timezone_query_param(legacy_prepared_statements, r [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -1027,7 +1029,7 @@ def test_dict_query_param(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -1047,7 +1049,7 @@ def test_dict_timestamp_query_param_types(legacy_prepared_statements, run_trino) [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -1071,7 +1073,7 @@ def test_boolean_query_param(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -1090,7 +1092,7 @@ def test_row(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -1160,7 +1162,7 @@ def test_nested_named_row(trino_connection): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -1179,7 +1181,7 @@ def test_float_query_param(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -1199,7 +1201,7 @@ def test_float_nan_query_param(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + 
trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -1223,7 +1225,7 @@ def test_float_inf_query_param(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -1248,7 +1250,7 @@ def test_int_query_param(legacy_prepared_statements, run_trino): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -1364,7 +1366,7 @@ def test_close_cursor(trino_connection): def test_session_properties(run_trino): - _, host, port = run_trino + host, port = run_trino connection = trino.dbapi.Connection( host=host, @@ -1419,8 +1421,8 @@ def test_transaction_multiple(trino_connection_with_transaction): assert len(rows2) == 1000 -@pytest.mark.skipif(trino_version() == '351', reason="Autocommit behaves " - "differently in older Trino versions") +@pytest.mark.skipif(trino_version() == 351, reason="Autocommit behaves " + "differently in older Trino versions") def test_transaction_autocommit(trino_connection_in_autocommit): with trino_connection_in_autocommit as connection: connection.start_transaction() @@ -1441,7 +1443,7 @@ def test_transaction_autocommit(trino_connection_in_autocommit): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -1512,7 +1514,7 @@ def test_client_tags_special_characters(run_trino): def retrieve_client_tags_from_query(run_trino, client_tags): - _, host, port = run_trino + host, port = run_trino trino_connection = trino.dbapi.Connection( host=host, @@ -1536,7 +1538,7 @@ def retrieve_client_tags_from_query(run_trino, client_tags): return query_client_tags -@pytest.mark.skipif(trino_version() == '351', 
reason="current_catalog not supported in older Trino versions") +@pytest.mark.skipif(trino_version() == 351, reason="current_catalog not supported in older Trino versions") def test_use_catalog_schema(trino_connection): cur = trino_connection.cursor() cur.execute('SELECT current_catalog, current_schema') @@ -1559,9 +1561,9 @@ def test_use_catalog_schema(trino_connection): assert result[0][1] == 'sf1' -@pytest.mark.skipif(trino_version() == '351', reason="current_catalog not supported in older Trino versions") +@pytest.mark.skipif(trino_version() == 351, reason="current_catalog not supported in older Trino versions") def test_use_schema(run_trino): - _, host, port = run_trino + host, port = run_trino trino_connection = trino.dbapi.Connection( host=host, port=port, user="test", source="test", catalog="tpch", max_attempts=1 @@ -1588,7 +1590,7 @@ def test_use_schema(run_trino): def test_set_role(run_trino): - _, host, port = run_trino + host, port = run_trino trino_connection = trino.dbapi.Connection( host=host, port=port, user="test", catalog="tpch" @@ -1600,7 +1602,7 @@ def test_set_role(run_trino): cur.execute("SET ROLE ALL") cur.fetchall() - if trino_version() == "351": + if trino_version() == 351: assert_role_headers(cur, "tpch=ALL") else: # Newer Trino versions return the system role @@ -1608,7 +1610,7 @@ def test_set_role(run_trino): def test_set_role_in_connection(run_trino): - _, host, port = run_trino + host, port = run_trino trino_connection = trino.dbapi.Connection( host=host, port=port, user="test", catalog="tpch", roles={"system": "ALL"} @@ -1620,7 +1622,7 @@ def test_set_role_in_connection(run_trino): def test_set_system_role_in_connection(run_trino): - _, host, port = run_trino + host, port = run_trino trino_connection = trino.dbapi.Connection( host=host, port=port, user="test", catalog="tpch", roles="ALL" @@ -1640,7 +1642,7 @@ def assert_role_headers(cursor, expected_header): [ True, pytest.param(None, marks=pytest.mark.skipif( - trino_version() > 
'417', + trino_version() > 417, reason="This would use EXECUTE IMMEDIATE")) ] ) @@ -1675,7 +1677,7 @@ def test_prepared_statements(legacy_prepared_statements, run_trino): def test_set_timezone_in_connection(run_trino): - _, host, port = run_trino + host, port = run_trino trino_connection = trino.dbapi.Connection( host=host, port=port, user="test", catalog="tpch", timezone="Europe/Brussels" @@ -1687,7 +1689,7 @@ def test_set_timezone_in_connection(run_trino): def test_connection_without_timezone(run_trino): - _, host, port = run_trino + host, port = run_trino trino_connection = trino.dbapi.Connection( host=host, port=port, user="test", catalog="tpch" @@ -1703,7 +1705,7 @@ def test_connection_without_timezone(run_trino): def test_describe(run_trino): - _, host, port = run_trino + host, port = run_trino trino_connection = trino.dbapi.Connection( host=host, port=port, user="test", catalog="tpch", @@ -1719,7 +1721,7 @@ def test_describe(run_trino): def test_describe_table_query(run_trino): - _, host, port = run_trino + host, port = run_trino trino_connection = trino.dbapi.Connection( host=host, port=port, user="test", catalog="tpch", @@ -1799,7 +1801,7 @@ def test_rowcount_insert(trino_connection): [ True, pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', + trino_version() <= 417, reason="EXECUTE IMMEDIATE was introduced in version 418")), None ] @@ -1809,7 +1811,7 @@ def test_prepared_statement_capability_autodetection(legacy_prepared_statements, trino.dbapi.must_use_legacy_prepared_statements = TimeBoundLRUCache(1024, 3600) user_name = f"user_{t.monotonic_ns()}" - _, host, port = run_trino + host, port = run_trino connection = trino.dbapi.Connection( host=host, port=port, @@ -1829,8 +1831,61 @@ def test_prepared_statement_capability_autodetection(legacy_prepared_statements, assert statements.count("EXECUTE IMMEDIATE 'SELECT 1'") == (1 if legacy_prepared_statements is None else 0) +@pytest.mark.skipif( + trino_version() <= 466, + 
reason="spooling protocol was introduced in version 466" +) +def test_select_query_spooled_segments(trino_connection): + cur = trino_connection.cursor() + cur.execute("""SELECT l.* + FROM tpch.tiny.lineitem l, TABLE(sequence( + start => 1, + stop => 5, + step => 1)) n""") + rows = cur.fetchall() + assert len(rows) == 300875 + for row in rows: + assert isinstance(row[0], int), f"Expected integer for orderkey, got {type(row[0])}" + assert isinstance(row[1], int), f"Expected integer for partkey, got {type(row[1])}" + assert isinstance(row[2], int), f"Expected integer for suppkey, got {type(row[2])}" + assert isinstance(row[3], int), f"Expected int for linenumber, got {type(row[3])}" + assert isinstance(row[4], float), f"Expected float for quantity, got {type(row[4])}" + assert isinstance(row[5], float), f"Expected float for extendedprice, got {type(row[5])}" + assert isinstance(row[6], float), f"Expected float for discount, got {type(row[6])}" + assert isinstance(row[7], float), f"Expected float for tax, got {type(row[7])}" + assert isinstance(row[8], str), f"Expected string for returnflag, got {type(row[8])}" + assert isinstance(row[9], str), f"Expected string for linestatus, got {type(row[9])}" + assert isinstance(row[10], date), f"Expected date for shipdate, got {type(row[10])}" + assert isinstance(row[11], date), f"Expected date for commitdate, got {type(row[11])}" + assert isinstance(row[12], date), f"Expected date for receiptdate, got {type(row[12])}" + assert isinstance(row[13], str), f"Expected string for shipinstruct, got {type(row[13])}" + + +@pytest.mark.skipif( + trino_version() <= 466, + reason="spooling protocol was introduced in version 466" +) +def test_segments_cursor(trino_connection): + if trino_connection._client_session.encoding is None: + with pytest.raises(ValueError, match=".*encoding.*"): + trino_connection.cursor("segment") + return + cur = trino_connection.cursor("segment") + cur.execute("""SELECT l.* + FROM tpch.tiny.lineitem l, 
TABLE(sequence( + start => 1, + stop => 5, + step => 1)) n""") + rows = cur.fetchall() + assert len(rows) > 0 + for spooled_data, spooled_segment in rows: + assert spooled_data.encoding == trino_connection._client_session.encoding + assert isinstance(spooled_segment.uri, str), f"Expected string for uri, got {spooled_segment.uri}" + assert isinstance(spooled_segment.ack_uri, str), f"Expected string for ack_uri, got {spooled_segment.ack_uri}" + + def get_cursor(legacy_prepared_statements, run_trino): - _, host, port = run_trino + host, port = run_trino connection = trino.dbapi.Connection( host=host, diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 9d813365..896d0d91 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -29,9 +29,9 @@ @pytest.fixture def trino_connection(run_trino, request): - _, host, port = run_trino + host, port = run_trino connect_args = {"source": "test", "max_attempts": 1} - if trino_version() <= '417': + if trino_version() <= 417: connect_args["legacy_prepared_statements"] = True engine = sqla.create_engine(f"trino://test@{host}:{port}/{request.param}", connect_args=connect_args) @@ -739,7 +739,7 @@ def test_get_view_names_raises(trino_connection): @pytest.mark.parametrize('trino_connection', ['system'], indirect=True) -@pytest.mark.skipif(trino_version() == '351', reason="version() not supported in older Trino versions") +@pytest.mark.skipif(trino_version() == 351, reason="version() not supported in older Trino versions") def test_version_is_lazy(trino_connection): _, conn = trino_connection result = conn.execute(sqla.text("SELECT 1")) diff --git a/tests/integration/test_types_integration.py b/tests/integration/test_types_integration.py index 881619d0..cc927883 100644 --- a/tests/integration/test_types_integration.py +++ b/tests/integration/test_types_integration.py @@ -17,12 +17,18 @@ from 
tests.integration.conftest import trino_version -@pytest.fixture -def trino_connection(run_trino): - _, host, port = run_trino +@pytest.fixture(params=[None, "json+zstd", "json+lz4", "json"]) +def trino_connection(request, run_trino): + host, port = run_trino + encoding = request.param yield trino.dbapi.Connection( - host=host, port=port, user="test", source="test", max_attempts=1 + host=host, + port=port, + user="test", + source="test", + max_attempts=1, + encoding=encoding ) @@ -168,7 +174,7 @@ def test_date(trino_connection): ).execute() -@pytest.mark.skipif(trino_version() == '351', reason="time not rounded correctly in older Trino versions") +@pytest.mark.skipif(trino_version() == 351, reason="time not rounded correctly in older Trino versions") def test_time(trino_connection): ( SqlTest(trino_connection) @@ -290,7 +296,7 @@ def test_time(trino_connection): ).execute() -@pytest.mark.skipif(trino_version() == '351', reason="time not rounded correctly in older Trino versions") +@pytest.mark.skipif(trino_version() == 351, reason="time not rounded correctly in older Trino versions") @pytest.mark.parametrize( 'tz_str', [ diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 653423a0..b33b72f5 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -97,6 +97,7 @@ def test_request_headers(mock_get_and_post): accept_encoding_value = "identity,deflate,gzip" client_info_header = constants.HEADER_CLIENT_INFO client_info_value = "some_client_info" + encoding = "json+zstd" with pytest.deprecated_call(): req = TrinoRequest( @@ -109,6 +110,7 @@ def test_request_headers(mock_get_and_post): catalog=catalog, schema=schema, timezone=timezone, + encoding=encoding, headers={ accept_encoding_header: accept_encoding_value, client_info_header: client_info_value, @@ -143,7 +145,8 @@ def assert_headers(headers): "catalog2=" + urllib.parse.quote("ROLE{catalog2_role}") ) assert headers["User-Agent"] == f"{constants.CLIENT_NAME}/{__version__}" - assert 
len(headers.keys()) == 13 + assert headers[constants.HEADER_ENCODING] == encoding + assert len(headers.keys()) == 14 req.post("URL") _, post_kwargs = post.call_args diff --git a/trino/client.py b/trino/client.py index da5e4047..077a9ba1 100644 --- a/trino/client.py +++ b/trino/client.py @@ -34,27 +34,37 @@ """ from __future__ import annotations +import abc +import base64 import copy import functools +import json import os import random import re import threading import urllib.parse import warnings +from abc import abstractmethod +from collections.abc import Iterator from dataclasses import dataclass from datetime import datetime from email.utils import parsedate_to_datetime from time import sleep from typing import Any +from typing import cast from typing import Dict from typing import List +from typing import Literal from typing import Optional from typing import Tuple +from typing import TypedDict from typing import Union from zoneinfo import ZoneInfo +import lz4.block import requests +import zstandard from tzlocal import get_localzone_name # type: ignore import trino.logging @@ -64,7 +74,16 @@ from trino.mapper import RowMapper from trino.mapper import RowMapperFactory -__all__ = ["ClientSession", "TrinoQuery", "TrinoRequest", "PROXIES"] +__all__ = [ + "ClientSession", + "TrinoQuery", + "TrinoRequest", + "PROXIES", + "SpooledData", + "SpooledSegment", + "InlineSegment", + "Segment" +] logger = trino.logging.get_logger(__name__) @@ -114,6 +133,7 @@ class ClientSession: :param roles: roles for the current session. Some connectors do not support role management. See connector documentation for more details. :param timezone: The timezone for query processing. Defaults to the system's local timezone. + :param encoding: The encoding for the spooling protocol. Defaults to None. 
""" def __init__( @@ -130,6 +150,7 @@ def __init__( client_tags: Optional[List[str]] = None, roles: Optional[Union[Dict[str, str], str]] = None, timezone: Optional[str] = None, + encoding: Optional[Union[str, List[str]]] = None, ): self._object_lock = threading.Lock() self._prepared_statements: Dict[str, str] = {} @@ -148,6 +169,7 @@ def __init__( self._timezone = timezone or get_localzone_name() if timezone: # Check timezone validity ZoneInfo(timezone) + self._encoding = encoding @property def user(self) -> str: @@ -243,6 +265,11 @@ def timezone(self) -> str: with self._object_lock: return self._timezone + @property + def encoding(self): + with self._object_lock: + return self._encoding + @staticmethod def _format_roles(roles: Union[Dict[str, str], str]) -> Dict[str, str]: if isinstance(roles, str): @@ -308,7 +335,7 @@ class TrinoStatus: next_uri: Optional[str] update_type: Optional[str] update_count: Optional[int] - rows: List[Any] + rows: Union[List[Any], Dict[str, Any]] columns: List[Any] def __repr__(self): @@ -471,6 +498,14 @@ def http_headers(self) -> Dict[str, str]: headers[constants.HEADER_USER] = self._client_session.user headers[constants.HEADER_AUTHORIZATION_USER] = self._client_session.authorization_user headers[constants.HEADER_TIMEZONE] = self._client_session.timezone + if self._client_session.encoding is None: + pass + elif isinstance(self._client_session.encoding, list): + headers[constants.HEADER_ENCODING] = ",".join(self._client_session.encoding) + elif isinstance(self._client_session.encoding, str): + headers[constants.HEADER_ENCODING] = self._client_session.encoding + else: + raise ValueError("Invalid type for encoding: expected str or list") headers[constants.HEADER_CLIENT_CAPABILITIES] = 'PARAMETRIC_DATETIME' headers["user-agent"] = f"{constants.CLIENT_NAME}/{__version__}" if len(self._client_session.roles.values()): @@ -743,6 +778,7 @@ def __init__( request: TrinoRequest, query: str, legacy_primitive_types: bool = False, + fetch_mode: 
Literal["mapped", "segments"] = "mapped" ) -> None: self._query_id: Optional[str] = None self._stats: Dict[Any, Any] = {} @@ -759,6 +795,7 @@ def __init__( self._result: Optional[TrinoResult] = None self._legacy_primitive_types = legacy_primitive_types self._row_mapper: Optional[RowMapper] = None + self._fetch_mode = fetch_mode @property def query_id(self) -> Optional[str]: @@ -844,7 +881,7 @@ def _update_state(self, status): if status.columns: self._columns = status.columns - def fetch(self) -> List[List[Any]]: + def fetch(self) -> List[Union[List[Any]], Any]: """Continue fetching data for the current query_id""" try: response = self._request.get(self._request.next_uri) @@ -858,7 +895,34 @@ def fetch(self) -> List[List[Any]]: if not self._row_mapper: return [] - return self._row_mapper.map(status.rows) + rows = status.rows + if isinstance(status.rows, dict): + # spooling protocol + rows = cast(_SpooledProtocolResponseTO, rows) + segments = self._to_segments(rows) + if self._fetch_mode == "segments": + return segments + return list(SegmentIterator(segments, self._row_mapper)) + elif isinstance(status.rows, list): + return self._row_mapper.map(rows) + else: + raise ValueError(f"Unexpected type: {type(status.rows)}") + + def _to_segments(self, rows: _SpooledProtocolResponseTO) -> SpooledData: + encoding = rows["encoding"] + segments = [] + for segment in rows["segments"]: + segment_type = segment["type"] + if segment_type == "inline": + inline_segment = cast(_InlineSegmentTO, segment) + segments.append(InlineSegment(inline_segment)) + elif segment_type == "spooled": + spooled_segment = cast(_SpooledSegmentTO, segment) + segments.append(SpooledSegment(spooled_segment, self._request)) + else: + raise ValueError(f"Unsupported segment type: {segment_type}") + + return SpooledData(encoding, segments) def cancel(self) -> None: """Cancel the current query""" @@ -934,3 +998,284 @@ def _parse_retry_after_header(retry_after): retry_date = parsedate_to_datetime(retry_after) now 
# Trino spooled-protocol transfer objects: TypedDict views over the JSON
# structures handled by the segment classes below.
class _SpooledProtocolResponseTO(TypedDict):
    # One encoding applies to every segment of the response.
    encoding: Literal["json", "json+zstd", "json+lz4"]
    segments: List["_SegmentTO"]


class _SegmentMetadataTO(TypedDict):
    # Sizes are byte counts (they are compared against len(data)).
    # NOTE(review): originally declared as str; the decoders coerce with int()
    # so both numeric and string values are accepted - confirm what the server sends.
    # "uncompressedSize" is only looked up when present (see CompressedQueryDataDecoder).
    uncompressedSize: Union[int, str]
    segmentSize: Union[int, str]


class _SegmentTO(TypedDict):
    # The size fields live under "metadata", not on the segment itself.
    type: Literal["spooled", "inline"]
    metadata: _SegmentMetadataTO


class _SpooledSegmentTO(_SegmentTO):
    uri: str
    ackUri: str
    headers: Dict[str, List[str]]


class _InlineSegmentTO(_SegmentTO):
    data: str


class Segment(abc.ABC):
    """
    Abstract base class representing a segment of data produced by the spooling protocol.

    Attributes:
        data (property): The raw, still encoded bytes of the segment. Implemented by subclasses.
        metadata (property): The ``metadata`` entry of the underlying transfer object.
    """
    def __init__(self, segment: _SegmentTO) -> None:
        self._segment = segment

    @property
    @abstractmethod
    def data(self) -> bytes:
        pass

    @property
    def metadata(self) -> _SegmentMetadataTO:
        return self._segment["metadata"]


class InlineSegment(Segment):
    """
    A Segment whose payload is embedded in the response as a base64 string.

    Attributes:
        data (property): The base64-decoded payload bytes. Decoding into rows is done by a
            ``QueryDataDecoder``, not by the segment.
    """
    def __init__(self, segment: _InlineSegmentTO) -> None:
        super().__init__(segment)
        # Re-assign only to narrow the static type of ``_segment``.
        self._segment = cast(_InlineSegmentTO, segment)

    @property
    def data(self) -> bytes:
        return base64.b64decode(self._segment["data"])

    def __repr__(self):
        return f"InlineSegment(metadata={self.metadata})"


class SpooledSegment(Segment):
    """
    A Segment whose payload has to be downloaded with an HTTP request.

    Attributes:
        data (property): The body of the HTTP response for ``uri``. Every access sends a new request.
        uri (property): The URI for the spooled segment.
        ack_uri (property): The URI for acknowledging the processing of the spooled segment.
        headers (property): The headers to send with requests for this segment (``{}`` if absent).

    Methods:
        acknowledge(): Sends an acknowledgment request for the segment in a background thread.
    """
    def __init__(
        self,
        segment: _SpooledSegmentTO,
        request: "TrinoRequest",
    ) -> None:
        super().__init__(segment)
        # Re-assign only to narrow the static type of ``_segment``.
        self._segment = cast(_SpooledSegmentTO, segment)
        self._request = request

    @property
    def data(self) -> bytes:
        http_response = self._send_spooling_request(self.uri)
        if not http_response.ok:
            self._request.raise_response_error(http_response)
        return http_response.content

    @property
    def uri(self) -> str:
        return self._segment["uri"]

    @property
    def ack_uri(self) -> str:
        return self._segment["ackUri"]

    @property
    def headers(self) -> Dict[str, List[str]]:
        return self._segment.get("headers", {})

    def acknowledge(self) -> None:
        """Send the acknowledgment request without blocking the caller.

        Best effort: any failure is logged and never propagated.
        """
        def acknowledge_request():
            try:
                http_response = self._send_spooling_request(self.ack_uri, timeout=2)
                if not http_response.ok:
                    self._request.raise_response_error(http_response)
            except Exception as e:
                logger.error(f"Failed to acknowledge spooling request for segment {self}: {e}")
        # Start the acknowledgment in a background thread
        thread = threading.Thread(target=acknowledge_request, daemon=True)
        thread.start()

    def _send_spooling_request(self, uri: str, **kwargs) -> "requests.Response":
        """GET ``uri`` with this segment's headers flattened to a single value per key.

        :raises ValueError: if a header carries more than one value.
        """
        headers_with_single_value = {}
        for key, values in self.headers.items():
            if len(values) > 1:
                raise ValueError(f"Header '{key}' contains multiple values: {values}")
            headers_with_single_value[key] = values[0]
        return self._request._get(uri, headers=headers_with_single_value, **kwargs)

    def __repr__(self):
        return (
            f"SpooledSegment(metadata={self.metadata})"
        )


class SpooledData:
    """
    Represents a collection of spooled segments of data, with an encoding format.

    Iterating yields ``(spooled_data, segment)`` tuples. The object is its own iterator, so the
    segments can be iterated only once.

    Attributes:
        encoding (str): The encoding format of the spooled data.
        segments (List[Segment]): The list of segments in the spooled data.
    """
    def __init__(self, encoding: str, segments: List[Segment]) -> None:
        self._encoding = encoding
        self._segments = segments
        self._segments_iterator = iter(segments)

    @property
    def encoding(self):
        return self._encoding

    @property
    def segments(self):
        return self._segments

    def __iter__(self) -> Iterator[Tuple["SpooledData", "Segment"]]:
        return self

    def __next__(self) -> Tuple["SpooledData", "Segment"]:
        return self, next(self._segments_iterator)

    def __repr__(self):
        return (f"SpooledData(encoding={self._encoding}, segments={list(self._segments)})")


class SegmentIterator:
    """
    Iterates over the decoded rows of all segments of a ``SpooledData``, one segment at a time.

    A ``SpooledSegment`` is acknowledged exactly once, right after its rows are exhausted.
    """
    def __init__(self, spooled_data: SpooledData, mapper: "RowMapper") -> None:
        self._segments = iter(spooled_data.segments)
        self._decoder = SegmentDecoder(CompressedQueryDataDecoderFactory(mapper).create(spooled_data.encoding))
        self._rows: Iterator[List[Any]] = iter([])
        self._finished = False
        self._current_segment: Optional[Segment] = None

    def __iter__(self) -> Iterator[List[Any]]:
        return self

    def __next__(self) -> List[Any]:
        # If rows are exhausted, fetch the next segment
        while True:
            try:
                return next(self._rows)
            except StopIteration:
                self._acknowledge_current_segment()
                if self._finished:
                    raise
                self._load_next_segment()

    def _acknowledge_current_segment(self) -> None:
        """Acknowledge the segment that was just consumed, then forget it.

        Clearing the reference keeps later StopIteration passes (end of data, repeated
        ``next()`` calls) from acknowledging the same segment again.
        """
        segment = self._current_segment
        self._current_segment = None
        if isinstance(segment, SpooledSegment):
            segment.acknowledge()

    def _load_next_segment(self):
        try:
            segment = next(self._segments)
        except StopIteration:
            self._finished = True
            return
        self._current_segment = segment
        self._rows = iter(self._decoder.decode(segment))


class SegmentDecoder():
    """Feeds a segment's raw bytes and metadata to a ``QueryDataDecoder``."""
    def __init__(self, decoder: "QueryDataDecoder"):
        self._decoder = decoder

    def decode(self, segment: Segment) -> List[List[Any]]:
        """Return the decoded rows of ``segment``.

        :raises ValueError: if the segment is neither inline nor spooled.
        """
        if isinstance(segment, (InlineSegment, SpooledSegment)):
            # ``data`` is read first: for spooled segments this is the download.
            return self._decoder.decode(segment.data, segment.metadata)
        raise ValueError(f"Unsupported segment type: {type(segment)}")


class CompressedQueryDataDecoderFactory():
    """Creates the ``QueryDataDecoder`` matching a spooling-protocol encoding name."""
    def __init__(self, mapper: "RowMapper") -> None:
        self._mapper = mapper

    def create(self, encoding: str) -> "QueryDataDecoder":
        """
        :param encoding: one of ``"json+zstd"``, ``"json+lz4"`` or ``"json"``.
        :raises ValueError: for any other encoding.
        """
        if encoding == "json+zstd":
            return ZStdQueryDataDecoder(JsonQueryDataDecoder(self._mapper))
        elif encoding == "json+lz4":
            return Lz4QueryDataDecoder(JsonQueryDataDecoder(self._mapper))
        elif encoding == "json":
            return JsonQueryDataDecoder(self._mapper)
        else:
            raise ValueError(f"Unsupported encoding: {encoding}")


class QueryDataDecoder(abc.ABC):
    """Turns the raw bytes of a segment into rows."""
    @abstractmethod
    def decode(self, data: bytes, metadata: _SegmentMetadataTO) -> List[List[Any]]:
        pass


class JsonQueryDataDecoder(QueryDataDecoder):
    """Parses UTF-8 JSON and passes the result through the row mapper."""
    def __init__(self, mapper: "RowMapper") -> None:
        self._mapper = mapper

    def decode(self, data: bytes, metadata: _SegmentMetadataTO) -> List[List[Any]]:
        # ``metadata`` is unused here; it is accepted to satisfy the decoder interface.
        return self._mapper.map(json.loads(data.decode("utf8")))


class CompressedQueryDataDecoder(QueryDataDecoder):
    """
    Base class for decoders that decompress the payload before delegating to another decoder.

    Subclasses implement ``decompress``.
    """
    def __init__(self, delegate: QueryDataDecoder) -> None:
        self._delegate = delegate

    @abstractmethod
    def decompress(self, data: bytes, metadata: _SegmentMetadataTO) -> bytes:
        pass

    def decode(self, data: bytes, metadata: _SegmentMetadataTO) -> List[List[Any]]:
        """Validate sizes, decompress if needed and delegate.

        :raises RuntimeError: if the received or the decompressed size does not match the metadata.
        """
        if "uncompressedSize" in metadata:
            # Data is compressed
            # int() so that sizes delivered as strings compare correctly with len().
            expected_compressed_size = int(metadata["segmentSize"])
            if not len(data) == expected_compressed_size:
                raise RuntimeError(f"Expected to read {expected_compressed_size} bytes but got {len(data)}")
            decompressed_data = self.decompress(data, metadata)
            expected_uncompressed_size = int(metadata["uncompressedSize"])
            if not len(decompressed_data) == expected_uncompressed_size:
                raise RuntimeError(
                    "Decompressed size does not match expected segment size, "
                    f"expected {expected_uncompressed_size}, got {len(decompressed_data)}"
                )
            return self._delegate.decode(decompressed_data, metadata)
        # Data not compressed - below threshold
        return self._delegate.decode(data, metadata)


class ZStdQueryDataDecoder(CompressedQueryDataDecoder):
    """Decompresses with ``zstandard``."""
    def decompress(self, data: bytes, metadata: _SegmentMetadataTO) -> bytes:
        zstd_decompressor = zstandard.ZstdDecompressor()
        return zstd_decompressor.decompress(data)


class Lz4QueryDataDecoder(CompressedQueryDataDecoder):
    """Decompresses with ``lz4.block``, passing the expected size from the metadata."""
    def decompress(self, data: bytes, metadata: _SegmentMetadataTO) -> bytes:
        expected_uncompressed_size = metadata["uncompressedSize"]
        decoded_bytes = lz4.block.decompress(data, uncompressed_size=int(expected_uncompressed_size))
        return decoded_bytes
the execution of a SQL @@ -159,10 +163,18 @@ def __init__( legacy_prepared_statements=None, roles=None, timezone=None, + encoding: Union[str, List[str]] = _USE_DEFAULT_ENCODING, ): # Automatically assign http_schema, port based on hostname parsed_host = urlparse(host, allow_fragments=False) + if encoding is _USE_DEFAULT_ENCODING: + encoding = [ + "json+zstd", + "json+lz4", + "json", + ] + self.host = host if parsed_host.hostname is None else parsed_host.hostname + parsed_host.path self.port = port if parsed_host.port is None else parsed_host.port self.user = user @@ -182,6 +194,7 @@ def __init__( client_tags=client_tags, roles=roles, timezone=timezone, + encoding=encoding, ) # mypy cannot follow module import if http_session is None: @@ -255,7 +268,7 @@ def _create_request(self): self.request_timeout, ) - def cursor(self, legacy_primitive_types: bool = None): + def cursor(self, cursor_style: str = "row", legacy_primitive_types: bool = None): """Return a new :py:class:`Cursor` object using the connection.""" if self.isolation_level != IsolationLevel.AUTOCOMMIT: if self.transaction is None: @@ -264,11 +277,21 @@ def cursor(self, legacy_primitive_types: bool = None): request = self.transaction.request else: request = self._create_request() - return Cursor( + + cursor_class = { + # Add any custom Cursor classes here + "segment": SegmentCursor, + "row": Cursor + }.get(cursor_style.lower(), Cursor) + + return cursor_class( self, request, - # if legacy params are not explicitly set in Cursor, take them from Connection - legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types + legacy_primitive_types=( + legacy_primitive_types + if legacy_primitive_types is not None + else self.legacy_primitive_types + ) ) def _use_legacy_prepared_statements(self): @@ -701,6 +724,28 @@ def close(self): # but also any other outstanding queries executed through this cursor. 
class SegmentCursor(Cursor):
    """
    Cursor whose queries are created with ``fetch_mode="segments"``.

    :raises ValueError: on construction, if the connection's client session has no encoding set.
    """
    def __init__(
            self,
            connection,
            request,
            legacy_primitive_types: bool = False):
        super().__init__(connection, request, legacy_primitive_types=legacy_primitive_types)
        session_encoding = self.connection._client_session.encoding
        if session_encoding is None:
            raise ValueError("SegmentCursor can only be used if encoding is set on the connection")

    def execute(self, operation, params=None):
        """
        Run ``operation`` and return this cursor.

        :param operation: the SQL text handed to ``TrinoQuery``.
        :param params: must be empty/``None``; query parameters are not supported yet.
        :raises ValueError: if ``params`` is truthy.
        """
        if params:
            # TODO: refactor code to allow for params to be supported
            raise ValueError("params not supported")

        query_kwargs = {
            "query": operation,
            "legacy_primitive_types": self._legacy_primitive_types,
            "fetch_mode": "segments",
        }
        self._query = trino.client.TrinoQuery(self._request, **query_kwargs)
        self._iterator = iter(self._query.execute())
        return self


Date = datetime.date
Time = datetime.time
Timestamp = datetime.datetime