From fce4027f655ea47ed52bb8d8c3c187d08df6e1ce Mon Sep 17 00:00:00 2001
From: Pavel Chugaev
Date: Fri, 10 Jan 2025 10:41:13 +0500
Subject: [PATCH] update linters (#24)
* update linters
* update github workflow
* fix mypy errors
* fix tests
---
.github/workflows/qa.yml | 4 +-
Makefile | 19 ++--
pyproject.toml | 106 ++++++++++++++-----
setup.py | 17 +--
sqlalchemy_kusto/__init__.py | 7 +-
sqlalchemy_kusto/dbapi.py | 125 +++++++++++++++--------
sqlalchemy_kusto/dialect_base.py | 118 +++++++++++++++------
sqlalchemy_kusto/dialect_kql.py | 52 ++++++----
sqlalchemy_kusto/dialect_sql.py | 17 +--
sqlalchemy_kusto/errors.py | 2 +-
tests/integration/conftest.py | 8 +-
tests/integration/test_dbapi.py | 7 +-
tests/integration/test_dialect_sql.py | 87 ++++++++++------
tests/integration/test_error_handling.py | 2 +-
tests/unit/test_dialect_kql.py | 92 ++++++++++++-----
15 files changed, 457 insertions(+), 206 deletions(-)
diff --git a/.github/workflows/qa.yml b/.github/workflows/qa.yml
index 5ae432b..2bb9aa2 100644
--- a/.github/workflows/qa.yml
+++ b/.github/workflows/qa.yml
@@ -18,7 +18,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v2
with:
- python-version: 3.8
+ python-version: '3.10'
- uses: actions/cache@v2
with:
@@ -44,7 +44,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v2
with:
- python-version: 3.8
+ python-version: '3.10'
- uses: actions/cache@v2
with:
diff --git a/Makefile b/Makefile
index 415fa69..a1bcb55 100644
--- a/Makefile
+++ b/Makefile
@@ -27,17 +27,24 @@ install-dev: # Install dev dependencies
##############################################################################
# Development process
##############################################################################
-check: # Run formatters and linters
- @echo "Running checkers..."
+format:
+ @echo "Running formatters..."
- @echo "\n1. Run $(GREEN_ITALIC)isort$(DEFAULT) to order imports."
- $(PYTHON) -m isort --profile black .
+ @echo "\n1. Run $(GREEN_ITALIC)ruff$(DEFAULT) to format code."
+ $(PYTHON) -m ruff check --fix-only .
@echo "\n2. Run $(GREEN_ITALIC)black$(DEFAULT) to format code."
$(PYTHON) -m black .
- @echo "\n3. Run $(GREEN_ITALIC)pylint$(DEFAULT) to lint the project."
- $(PYTHON) -m pylint setup.py sqlalchemy_kusto/
+
+check: # Run formatters and linters
+ @echo "Running checkers..."
+
+ @echo "\n1. Run $(GREEN_ITALIC)ruff$(DEFAULT) to check code."
+ $(PYTHON) -m ruff check .
+
+ @echo "\n2. Run $(GREEN_ITALIC)black$(DEFAULT) to check code formatting."
+ $(PYTHON) -m black . --check
@echo "\n4. Run $(GREEN_ITALIC)mypy$(DEFAULT) for type checking."
$(PYTHON) -m mypy .
diff --git a/pyproject.toml b/pyproject.toml
index f04d999..58eb0c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,9 +5,84 @@ requires = [
]
build-backend = "setuptools.build_meta"
-[tool.black]
+[tool.ruff]
+target-version = "py310"
+
+# https://beta.ruff.rs/docs/settings/#line-length
line-length = 120
-target-version = ["py38", "py39", "py310", "py311"]
+
+# https://beta.ruff.rs/docs/settings/#select
+lint.select = [
+ "F", # Pyflakes (https://beta.ruff.rs/docs/rules/#pyflakes-f)
+ "E", # pycodestyle (https://beta.ruff.rs/docs/rules/#pycodestyle-e-w)
+ "C90", # mccabe (https://beta.ruff.rs/docs/rules/#mccabe-c90)
+ "N", # pep8-naming (https://beta.ruff.rs/docs/rules/#pep8-naming-n)
+ "D", # pydocstyle (https://beta.ruff.rs/docs/rules/#pydocstyle-d)
+ "UP", # pyupgrade (https://beta.ruff.rs/docs/rules/#pyupgrade-up)
+ "ANN", # flake8-annotations (https://beta.ruff.rs/docs/rules/#flake8-annotations-ann)
+ "B", # flake8-bugbear (https://beta.ruff.rs/docs/rules/#flake8-bugbear-b)
+ "C4", # flake8-comprehensions (https://beta.ruff.rs/docs/rules/#flake8-comprehensions-c4)
+ "G", # flake8-logging-format (https://beta.ruff.rs/docs/rules/#flake8-logging-format-g)
+ "T20", # flake8-print (https://beta.ruff.rs/docs/rules/#flake8-print-t20)
+ "PT", # flake8-pytest-style (https://beta.ruff.rs/docs/rules/#flake8-pytest-style-pt)
+ "TID", # flake8-tidy-imports (https://beta.ruff.rs/docs/rules/#flake8-tidy-imports-tid)
+ "ARG", # flake8-unused-arguments (https://beta.ruff.rs/docs/rules/#flake8-unused-arguments-arg)
+ "PTH", # flake8-use-pathlib (https://beta.ruff.rs/docs/rules/#flake8-use-pathlib-pth)
+ "ERA", # eradicate (https://beta.ruff.rs/docs/rules/#eradicate-era)
+ "PL", # pylint (https://beta.ruff.rs/docs/rules/#pylint-pl)
+ "TRY", # tryceratops (https://beta.ruff.rs/docs/rules/#tryceratops-try)
+ "RUF100", # Unused noqa directive
+]
+
+# https://beta.ruff.rs/docs/settings/#ignore
+lint.ignore = [
+ "C901", # too complex
+
+ # pycodestyle (https://beta.ruff.rs/docs/rules/#pydocstyle-d)
+ "D100", # Missing docstring in public module
+ "D101", # Missing docstring in public class
+ "D102", # Missing docstring in public method
+ "D103", # Missing docstring in public function
+ "D104", # Missing docstring in public package
+ "D105", # Missing docstring in magic method
+ "D106", # Missing docstring in public nested class
+ "D107", # Missing docstring in `__init__`
+ "D203", # 1 blank line required before class docstring
+ "D205", # 1 blank line required between summary line and description
+ "D212", # Multi-line docstring summary should start at the first line
+
+ "N818", # Exception name {name} should be named with an Error suffix;
+
+ "TRY003", # Avoid specifying long messages outside the exception class
+
+ # flake8-annotations
+ "ANN001", # Missing type annotation for function argument
+ "ANN002", # Missing type annotation for `*args`
+ "ANN003", # Missing type annotation for `**kwargs`
+ "ANN201", # Missing return type annotation for public function
+ "ANN202", # Missing return type annotation for private function
+ "ANN204", # Missing return type annotation for special method
+ "ANN401", # Dynamically typed expressions (typing.Any) are disallowed
+
+ "ARG002", # Unused method argument
+
+ "PLR0913", # Too many arguments in function definition
+]
+
+[tool.ruff.lint.pycodestyle]
+max-doc-length = 120
+
+[tool.ruff.lint.pydocstyle]
+# Use Google-style docstrings
+convention = "google"
+
+[tool.ruff.lint.flake8-pytest-style]
+# Set the parametrize values type in tests.
+parametrize-values-type = "list"
+
+[tool.black]
+line-length = 88
+target-version = ["py310", "py311"]
include = ".pyi?$"
exclude = """
(
@@ -24,36 +99,13 @@ exclude = """
)
"""
-[tool.isort]
-line_length = 120
-multi_line_output = 3
-include_trailing_comma = true
-force_grid_wrap = 0
-use_parentheses = true
-
[tool.mypy]
-python_version = "3.8"
+python_version = "3.10"
strict_optional = true
show_error_codes = true
warn_redundant_casts = true
warn_unused_ignores = true
-disallow_any_generics = true
+disallow_any_generics = false
check_untyped_defs = true
no_implicit_reexport = true
ignore_missing_imports = true
-
-[tool.pylint.messages_control]
-max-line-length = 120
-disable = [
- "consider-using-f-string",
- "missing-class-docstring",
- "missing-function-docstring",
- "missing-module-docstring",
- "no-self-use",
- "protected-access",
- "too-few-public-methods",
- "too-many-arguments",
- "too-many-locals",
- "too-many-public-methods",
- "unused-argument",
-]
diff --git a/setup.py b/setup.py
index b39a299..1bcf0bb 100644
--- a/setup.py
+++ b/setup.py
@@ -1,3 +1,4 @@
+from pathlib import Path
from setuptools import find_packages, setup
NAME = "sqlalchemy-kusto"
@@ -12,16 +13,16 @@
]
EXTRAS = {
"dev": [
- "black>=21.12b0",
- "isort>=5.10.1",
- "mypy==0.971",
- "pylint==2.15.0",
- "pytest>=6.2.5",
- "python-dotenv>=0.19.2",
+ "black>=24.10.0",
+ "mypy>=1.14.1",
+ "pytest>=8.3.4",
+ "python-dotenv>=1.0.1",
+ "ruff>=0.8.6",
]
}
-with open("README.md", "r", encoding="utf-8") as f:
+path = Path("README.md")
+with path.open(encoding="utf-8") as f:
LONG_DESCRIPTION = f.read()
setup(
@@ -51,7 +52,7 @@
project_urls={
"Bug Tracker": "https://github.com/dodopizza/sqlalchemy-kusto/issues",
},
- python_requires=">=3.8",
+ python_requires=">=3.10",
version=VERSION,
zip_safe=False,
)
diff --git a/sqlalchemy_kusto/__init__.py b/sqlalchemy_kusto/__init__.py
index 2d1062c..2f5fcb1 100644
--- a/sqlalchemy_kusto/__init__.py
+++ b/sqlalchemy_kusto/__init__.py
@@ -1,6 +1,5 @@
from sqlalchemy_kusto.dbapi import connect
-# pylint: disable=redefined-builtin
from sqlalchemy_kusto.errors import (
DatabaseError,
DataError,
@@ -31,7 +30,7 @@
"Warning",
]
-apilevel = "2.0" # pylint: disable=invalid-name
+apilevel = "2.0"
# Threads may share the module and connections
-threadsafety = 2 # pylint: disable=invalid-name
-paramstyle = "pyformat" # pylint: disable=invalid-name
+threadsafety = 2
+paramstyle = "pyformat"
diff --git a/sqlalchemy_kusto/dbapi.py b/sqlalchemy_kusto/dbapi.py
index 71ba521..1b697a4 100644
--- a/sqlalchemy_kusto/dbapi.py
+++ b/sqlalchemy_kusto/dbapi.py
@@ -1,7 +1,12 @@
from collections import namedtuple
-from typing import Any, List, Optional, Tuple
+from typing import Any
-from azure.kusto.data import ClientRequestProperties, KustoClient, KustoConnectionStringBuilder
+from azure.identity import WorkloadIdentityCredential
+from azure.kusto.data import (
+ ClientRequestProperties,
+ KustoClient,
+ KustoConnectionStringBuilder,
+)
from azure.kusto.data._models import KustoResultColumn
from azure.kusto.data.exceptions import KustoAuthenticationError, KustoServiceError
@@ -13,7 +18,7 @@ def check_closed(func):
def decorator(self, *args, **kwargs):
if self.closed:
- raise Exception("{klass} already closed".format(klass=self.__class__.__name__))
+ raise ValueError(f"{self.__class__.__name__} already closed")
return func(self, *args, **kwargs)
return decorator
@@ -23,8 +28,8 @@ def check_result(func):
"""Decorator that checks if the cursor has results from `execute`."""
def decorator(self, *args, **kwargs):
- if self._results is None: # pylint: disable=protected-access
- raise Exception("Called before `execute`")
+ if self._results is None:
+ raise ValueError("Called before `execute`")
return func(self, *args, **kwargs)
return decorator
@@ -34,13 +39,23 @@ def connect(
cluster: str,
database: str,
msi: bool = False,
- user_msi: str = None,
- azure_ad_client_id: str = None,
- azure_ad_client_secret: str = None,
- azure_ad_tenant_id: str = None,
+ workload_identity: bool = False,
+ user_msi: str | None = None,
+ azure_ad_client_id: str | None = None,
+ azure_ad_client_secret: str | None = None,
+ azure_ad_tenant_id: str | None = None,
):
"""Return a connection to the database."""
- return Connection(cluster, database, msi, user_msi, azure_ad_client_id, azure_ad_client_secret, azure_ad_tenant_id)
+ return Connection(
+ cluster,
+ database,
+ msi,
+ workload_identity,
+ user_msi,
+ azure_ad_client_id,
+ azure_ad_client_secret,
+ azure_ad_tenant_id,
+ )
class Connection:
@@ -51,13 +66,14 @@ def __init__(
cluster: str,
database: str,
msi: bool = False,
- user_msi: str = None,
- azure_ad_client_id: str = None,
- azure_ad_client_secret: str = None,
- azure_ad_tenant_id: str = None,
+ workload_identity: bool = False,
+ user_msi: str | None = None,
+ azure_ad_client_id: str | None = None,
+ azure_ad_client_secret: str | None = None,
+ azure_ad_tenant_id: str | None = None,
):
self.closed = False
- self.cursors: List[Cursor] = []
+ self.cursors: list[Cursor] = []
kcsb = None
if azure_ad_client_id and azure_ad_client_secret and azure_ad_tenant_id:
@@ -68,15 +84,27 @@ def __init__(
app_key=azure_ad_client_secret,
authority_id=azure_ad_tenant_id,
)
+ elif workload_identity:
+ # Workload Identity
+ kcsb = KustoConnectionStringBuilder.with_azure_token_credential(
+ cluster, WorkloadIdentityCredential()
+ )
elif msi:
# Managed Service Identity (MSI)
- kcsb = KustoConnectionStringBuilder.with_aad_managed_service_identity_authentication(
- cluster, client_id=user_msi
- )
+ if user_msi is None or user_msi == "":
+ # System managed identity
+ kcsb = KustoConnectionStringBuilder.with_aad_managed_service_identity_authentication(
+ cluster
+ )
+ else:
+ # user managed identity
+ kcsb = KustoConnectionStringBuilder.with_aad_managed_service_identity_authentication(
+ cluster, client_id=user_msi
+ )
else:
# neither SP or MSI
kcsb = KustoConnectionStringBuilder.with_az_cli_authentication(cluster)
- kcsb._set_connector_details("sqlalchemy-kusto", "0.1.0") # pylint: disable=protected-access
+ kcsb._set_connector_details("sqlalchemy-kusto", "1.1.0")
self.kusto_client = KustoClient(kcsb)
self.database = database
self.properties = ClientRequestProperties()
@@ -84,7 +112,6 @@ def __init__(
@check_closed
def close(self):
"""Close the connection now. Kusto does not require to close the connection."""
- # self.closed = True
for cursor in self.cursors:
cursor.close()
@@ -130,17 +157,19 @@ def __init__(
self,
kusto_client: KustoClient,
database: str,
- properties: Optional[ClientRequestProperties] = None,
+ properties: ClientRequestProperties | None = None,
):
- self._results: Optional[List[Tuple[Any, ...]]] = None
+ self._results: list[tuple[Any, ...]] | None = None
self.kusto_client = kusto_client
self.database = database
self.closed = False
- self.description: Optional[List[CursorDescriptionRow]] = None
+ self.description: list[CursorDescriptionRow] | None = None
self.current_item_index = 0
- self.properties = properties if properties is not None else ClientRequestProperties()
+ self.properties = (
+ properties if properties is not None else ClientRequestProperties()
+ )
- @property # type: ignore
+ @property
@check_result
@check_closed
def rowcount(self) -> int:
@@ -152,7 +181,6 @@ def rowcount(self) -> int:
@check_closed
def close(self):
"""Closes the cursor."""
- # self.closed = True
@check_closed
def execute(self, operation, parameters=None) -> "Cursor":
@@ -165,23 +193,29 @@ def execute(self, operation, parameters=None) -> "Cursor":
query = Cursor._apply_parameters(operation, parameters)
query = query.rstrip()
try:
- server_response = self.kusto_client.execute(self.database, query, self.properties)
+ server_response = self.kusto_client.execute(
+ self.database, query, self.properties
+ )
except KustoServiceError as kusto_error:
- raise errors.DatabaseError(str(kusto_error))
+ raise errors.DatabaseError(str(kusto_error)) from kusto_error
except KustoAuthenticationError as context_error:
- raise errors.OperationalError(str(context_error))
+ raise errors.OperationalError(str(context_error)) from context_error
rows = []
for row in server_response.primary_results[0]:
rows.append(tuple(row.to_list()))
self._results = rows
- self.description = self._get_description_from_columns(server_response.primary_results[0].columns)
+ self.description = self._get_description_from_columns(
+ server_response.primary_results[0].columns
+ )
return self
@check_closed
def executemany(self, operation, seq_of_parameters=None):
- """Not supported"""
- raise NotImplementedError("`executemany` is not supported, use `execute` instead")
+ """Not supported."""
+ raise NotImplementedError(
+ "`executemany` is not supported, use `execute` instead"
+ )
@check_result
@check_closed
@@ -199,7 +233,7 @@ def fetchone(self):
@check_result
@check_closed
- def fetchmany(self, size: int = None):
+ def fetchmany(self, size: int | None = None):
"""
Fetches the next set of rows of a query result, returning a sequence of
sequences (e.g. a list of tuples). An empty sequence is returned when
@@ -224,15 +258,17 @@ def fetchall(self):
@check_closed
def setinputsizes(self, sizes):
- """Not supported"""
+ """Not supported."""
@check_closed
def setoutputsizes(self, sizes):
- """Not supported"""
+ """Not supported."""
@staticmethod
- def _get_description_from_columns(columns: List[KustoResultColumn]) -> List[CursorDescriptionRow]:
- """Gets CursorDescriptionRow for Kusto columns"""
+ def _get_description_from_columns(
+ columns: list[KustoResultColumn],
+ ) -> list[CursorDescriptionRow]:
+ """Gets CursorDescriptionRow for Kusto columns."""
return [
CursorDescriptionRow(
name=column.column_name,
@@ -258,31 +294,32 @@ def __next__(self):
next = __next__
@staticmethod
- def _apply_parameters(operation, parameters) -> str:
- """Applies parameters to operation string"""
+ def _apply_parameters(operation, parameters: dict) -> str:
+ """Applies parameters to operation string."""
if not parameters:
return operation
- escaped_parameters = {key: Cursor._escape(value) for key, value in parameters.items()}
+ escaped_parameters = {
+ key: Cursor._escape(value) for key, value in parameters.items()
+ }
return operation % escaped_parameters
@staticmethod
- def _escape(value) -> str:
+ def _escape(value: Any) -> str:
"""
Escape the parameter value.
Note that bool is a subclass of int so order of statements matter.
"""
-
if value == "*":
return value
if isinstance(value, str):
return "'{}'".format(value.replace("'", "''"))
if isinstance(value, bool):
return "TRUE" if value else "FALSE"
- if isinstance(value, (int, float)):
+ if isinstance(value, int | float):
return str(value)
- if isinstance(value, (list, tuple)):
+ if isinstance(value, list | tuple):
return ", ".join(Cursor._escape(element) for element in value)
return value
diff --git a/sqlalchemy_kusto/dialect_base.py b/sqlalchemy_kusto/dialect_base.py
index ce874ba..0438f95 100644
--- a/sqlalchemy_kusto/dialect_base.py
+++ b/sqlalchemy_kusto/dialect_base.py
@@ -1,12 +1,20 @@
import json
from abc import ABC
from types import ModuleType
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
from sqlalchemy.engine import Connection, default
from sqlalchemy.engine.url import URL
from sqlalchemy.sql import compiler
-from sqlalchemy.types import DATE, TIMESTAMP, BigInteger, Boolean, Float, Integer, String
+from sqlalchemy.types import (
+ DATE,
+ TIMESTAMP,
+ BigInteger,
+ Boolean,
+ Float,
+ Integer,
+ String,
+)
import sqlalchemy_kusto
@@ -56,7 +64,7 @@ class KustoBaseDialect(default.DefaultDialect, ABC):
description_encoding = None
supports_native_boolean = True
supports_simple_order_by_label = True
- _map_parse_connection_parameters: Dict[str, Any] = {
+ _map_parse_connection_parameters: dict[str, Any] = {
"msi": parse_bool_argument,
"azure_ad_client_id": str,
"azure_ad_client_secret": str,
@@ -66,11 +74,11 @@ class KustoBaseDialect(default.DefaultDialect, ABC):
}
@classmethod
- def dbapi(cls) -> ModuleType: # pylint: disable-msg=method-hidden
+ def dbapi(cls) -> ModuleType:
return sqlalchemy_kusto
- def create_connect_args(self, url: URL) -> Tuple[List[Any], Dict[str, Any]]:
- kwargs: Dict[str, Any] = {
+ def create_connect_args(self, url: URL) -> tuple[list[Any], dict[str, Any]]:
+ kwargs: dict[str, Any] = {
"cluster": "https://" + url.host,
"database": url.database,
}
@@ -84,21 +92,33 @@ def create_connect_args(self, url: URL) -> Tuple[List[Any], Dict[str, Any]]:
return [], kwargs
- def get_schema_names(self, connection: Connection, **kwargs) -> List[str]:
+ def get_schema_names(self, connection: Connection, **kwargs) -> list[str]:
result = connection.execute(".show databases | project DatabaseName")
return [row.DatabaseName for row in result]
- def has_table(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kwargs) -> bool:
+ def has_table(
+ self,
+ connection: Connection,
+ table_name: str,
+ schema: str | None = None,
+ **kwargs,
+ ) -> bool:
return table_name in self.get_table_names(connection, schema)
- def get_table_names(self, connection: Connection, schema: Optional[str] = None, **kwargs) -> List[str]:
+ def get_table_names(
+ self, connection: Connection, schema: str | None = None, **kwargs
+ ) -> list[str]:
# Schema is not used in Kusto cause database is written in the connection string
result = connection.execute(".show tables | project TableName")
return [row.TableName for row in result]
def get_columns(
- self, connection: Connection, table_name: str, schema: Optional[str] = None, **kwargs
- ) -> List[Dict[str, Any]]:
+ self,
+ connection: Connection,
+ table_name: str,
+ schema: str | None = None,
+ **kwargs,
+ ) -> list[dict[str, Any]]:
table_search_query = f"""
.show tables
| where TableName == "{table_name}"
@@ -117,16 +137,23 @@ def get_columns(
query_result = connection.execute(function_schema)
rows = list(query_result)
entity_schema = json.loads(rows[0].Schema)
- return [self.schema_definition(column) for column in entity_schema["OutputColumns"]]
- entity_type = "table" if table_search_result.rowcount == 1 else "materialized-view"
+ return [
+ self.schema_definition(column)
+ for column in entity_schema["OutputColumns"]
+ ]
+ entity_type = (
+ "table" if table_search_result.rowcount == 1 else "materialized-view"
+ )
query = f".show {entity_type} {table_name} schema as json"
query_result = connection.execute(query)
rows = list(query_result)
entity_schema = json.loads(rows[0].Schema)
- return [self.schema_definition(column) for column in entity_schema["OrderedColumns"]]
+ return [
+ self.schema_definition(column) for column in entity_schema["OrderedColumns"]
+ ]
@staticmethod
- def schema_definition(column):
+ def schema_definition(column) -> dict:
return {
"name": column["Name"],
"type": kql_to_sql_types[column["CslType"].lower()],
@@ -134,39 +161,65 @@ def schema_definition(column):
"default": "",
}
- def get_view_names(self, connection: Connection, schema: Optional[str] = None, **kwargs) -> List[str]:
- materialized_views = connection.execute(".show materialized-views | project Name")
+ def get_view_names(
+ self, connection: Connection, schema: str | None = None, **kwargs
+ ) -> list[str]:
+ materialized_views = connection.execute(
+ ".show materialized-views | project Name"
+ )
# Functions are also Views.
# Filtering no input functions specifically here as there is no way to pass parameters today
- functions = connection.execute(".show functions | where Parameters =='()' | project Name")
+ functions = connection.execute(
+ ".show functions | where Parameters =='()' | project Name"
+ )
materialized_view = [row.Name for row in materialized_views]
view = [row.Name for row in functions]
return materialized_view + view
- def get_pk_constraint(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kw):
+ def get_pk_constraint(
+ self, connection: Connection, table_name: str, schema: str | None = None, **kw
+ ):
return {"constrained_columns": [], "name": None}
def get_foreign_keys(self, connection, table_name, schema=None, **kwargs):
return []
- def get_check_constraints(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kwargs):
+ def get_check_constraints(
+ self,
+ connection: Connection,
+ table_name: str,
+ schema: str | None = None,
+ **kwargs,
+ ):
return []
def get_table_comment(
- self, connection: Connection, table_name, schema: Optional[str] = None, **kwargs
- ) -> Dict[str, Any]:
- """Not implemented"""
+ self, connection: Connection, table_name, schema: str | None = None, **kwargs
+ ) -> dict[str, Any]:
+ """Not implemented."""
return {"text": ""}
def get_indexes(
- self, connection: Connection, table_name: str, schema: Optional[str] = None, **kwargs
- ) -> List[Dict[str, Any]]:
+ self,
+ connection: Connection,
+ table_name: str,
+ schema: str | None = None,
+ **kwargs,
+ ) -> list[dict[str, Any]]:
return []
- def get_unique_constraints(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kwargs):
+ def get_unique_constraints(
+ self,
+ connection: Connection,
+ table_name: str,
+ schema: str | None = None,
+ **kwargs,
+ ):
return []
- def _check_unicode_returns(self, connection: Connection, additional_tests: List[Any] = None) -> bool:
+ def _check_unicode_returns(
+ self, connection: Connection, additional_tests: list[Any] | None = None
+ ) -> bool:
return True
def _check_unicode_description(self, connection: Connection) -> bool:
@@ -176,9 +229,10 @@ def do_ping(self, dbapi_connection: sqlalchemy_kusto.dbapi.Connection):
try:
query = ".show tables"
dbapi_connection.execute(query)
- return True
except sqlalchemy_kusto.OperationalError:
return False
+ else:
+ return True
def do_rollback(self, dbapi_connection: sqlalchemy_kusto.dbapi.Connection):
pass
@@ -225,7 +279,13 @@ def set_isolation_level(self, dbapi_conn, level):
def get_isolation_level(self, dbapi_conn):
pass
- def get_view_definition(self, connection: Connection, view_name: str, schema: Optional[str] = None, **kwargs):
+ def get_view_definition(
+ self,
+ connection: Connection,
+ view_name: str,
+ schema: str | None = None,
+ **kwargs,
+ ):
pass
def get_primary_keys(self, connection, table_name, schema=None, **kw):
diff --git a/sqlalchemy_kusto/dialect_kql.py b/sqlalchemy_kusto/dialect_kql.py
index c60b43a..efbec5e 100644
--- a/sqlalchemy_kusto/dialect_kql.py
+++ b/sqlalchemy_kusto/dialect_kql.py
@@ -1,6 +1,5 @@
import logging
import re
-from typing import List, Optional, Tuple
from sqlalchemy import Column, exc
from sqlalchemy.sql import compiler, operators, selectable
@@ -17,7 +16,7 @@
class UniversalSet:
- def __contains__(self, item):
+ def __contains__(self, item) -> bool:
return True
@@ -25,7 +24,7 @@ class KustoKqlIdentifierPreparer(compiler.IdentifierPreparer):
# We want to quote all table and column names to prevent unconventional names usage
reserved_words = UniversalSet()
- def __init__(self, dialect, **kw):
+ def __init__(self, dialect, **kw) -> None:
super().__init__(dialect, initial_quote='["', final_quote='"]', **kw)
@@ -48,11 +47,13 @@ def visit_select(
lateral=False,
from_linter=None,
**kwargs,
- ):
+ ) -> str:
logger.debug("Incoming query: %s", select_stmt)
if len(select_stmt.get_final_froms()) != 1:
- raise NotSupportedError('Only single "select from" query is supported in kql compiler')
+ raise NotSupportedError(
+ 'Only single "select from" query is supported in kql compiler'
+ )
compiled_query_lines = []
@@ -61,7 +62,9 @@ def visit_select(
query = self._get_most_inner_element(from_object.element)
(main, lets) = self._extract_let_statements(query.text)
compiled_query_lines.extend(lets)
- compiled_query_lines.append(f"let {from_object.name} = ({self._convert_schema_in_statement(main)});")
+ compiled_query_lines.append(
+ f"let {from_object.name} = ({self._convert_schema_in_statement(main)});"
+ )
compiled_query_lines.append(from_object.name)
elif hasattr(from_object, "name"):
if from_object.schema is not None:
@@ -70,7 +73,9 @@ def visit_select(
unquoted_name = from_object.name.strip("\"'")
compiled_query_lines.append(f'["{unquoted_name}"]')
else:
- compiled_query_lines.append(self._convert_schema_in_statement(from_object.text))
+ compiled_query_lines.append(
+ self._convert_schema_in_statement(from_object.text)
+ )
if select_stmt._whereclause is not None:
where_clause = select_stmt._whereclause._compiler_dispatch(self, **kwargs)
@@ -81,11 +86,11 @@ def visit_select(
if projections:
compiled_query_lines.append(projections)
- if select_stmt._limit_clause is not None: # pylint: disable=protected-access
+ if select_stmt._limit_clause is not None:
kwargs["literal_execute"] = True
compiled_query_lines.append(
f"| take {self.process(select_stmt._limit_clause, **kwargs)}"
- ) # pylint: disable=protected-access
+ )
compiled_query_lines = list(filter(None, compiled_query_lines))
@@ -97,7 +102,7 @@ def limit_clause(self, select, **kw):
return ""
def _get_projection_or_summarize(self, select: selectable.Select) -> str:
- """Builds the ending part of the query either project or summarize"""
+ """Builds the ending part of the query either project or summarize."""
columns = select.inner_columns
if columns is not None:
column_labels = []
@@ -108,10 +113,14 @@ def _get_projection_or_summarize(self, select: selectable.Select) -> str:
if column_name in aggregates_sql_to_kql:
is_summarize = True
column_labels.append(
- self._build_column_projection(aggregates_sql_to_kql[column_name], column_alias)
+ self._build_column_projection(
+ aggregates_sql_to_kql[column_name], column_alias
+ )
)
else:
- column_labels.append(self._build_column_projection(column_name, column_alias))
+ column_labels.append(
+ self._build_column_projection(column_name, column_alias)
+ )
if column_labels:
projection_type = "summarize" if is_summarize else "project"
@@ -119,7 +128,7 @@ def _get_projection_or_summarize(self, select: selectable.Select) -> str:
return ""
def _get_most_inner_element(self, clause):
- """Finds the most nested element in clause"""
+ """Finds the most nested element in clause."""
inner_element = getattr(clause, "element", None)
if inner_element is not None:
return self._get_most_inner_element(inner_element)
@@ -127,8 +136,8 @@ def _get_most_inner_element(self, clause):
return clause
@staticmethod
- def _extract_let_statements(clause) -> Tuple[str, List[str]]:
- """Separates the final query from let statements"""
+ def _extract_let_statements(clause) -> tuple[str, list[str]]:
+ """Separates the final query from let statements."""
rows = [s.strip() for s in clause.split(";")]
main = next(filter(lambda row: not row.startswith("let"), rows), None)
@@ -139,15 +148,17 @@ def _extract_let_statements(clause) -> Tuple[str, List[str]]:
return main, lets
@staticmethod
- def _extract_column_name_and_alias(column: Column) -> Tuple[str, Optional[str]]:
+ def _extract_column_name_and_alias(column: Column) -> tuple[str, str | None]:
if hasattr(column, "element"):
return column.element.name, column.name
return column.name, None
@staticmethod
- def _build_column_projection(column_name: str, column_alias: str = None):
- """Generates column alias semantic for project statement"""
+ def _build_column_projection(
+ column_name: str, column_alias: str | None = None
+ ) -> str:
+ """Generates column alias semantic for project statement."""
return f"{column_alias} = {column_name}" if column_alias else column_name
@staticmethod
@@ -166,7 +177,6 @@ def _convert_schema_in_statement(query: str) -> str:
- ["schema"].["table"] -> database("schema").["table"]
- ["table"] -> ["table"]
"""
-
pattern = r"^\[?([a-zA-Z0-9]+\b|\"[a-zA-Z0-9 \-_.]+\")?\]?\.?\[?([a-zA-Z0-9]+\b|\"[a-zA-Z0-9 \-_.]+\")\]?"
match = re.search(pattern, query)
if not match:
@@ -179,7 +189,9 @@ def _convert_schema_in_statement(query: str) -> str:
return query.replace(original, f'["{unquoted_table}"]', 1)
unquoted_schema = match.group(1).strip("\"'")
- return query.replace(original, f'database("{unquoted_schema}").["{unquoted_table}"]', 1)
+ return query.replace(
+ original, f'database("{unquoted_schema}").["{unquoted_table}"]', 1
+ )
class KustoKqlHttpsDialect(KustoBaseDialect):
diff --git a/sqlalchemy_kusto/dialect_sql.py b/sqlalchemy_kusto/dialect_sql.py
index 6404080..6ce2de8 100644
--- a/sqlalchemy_kusto/dialect_sql.py
+++ b/sqlalchemy_kusto/dialect_sql.py
@@ -5,17 +5,17 @@
class KustoSqlCompiler(compiler.SQLCompiler):
def get_select_precolumns(self, select, **kw) -> str:
- """Kusto uses TOP instead of LIMIT"""
+ """Kusto uses TOP instead of LIMIT."""
select_precolumns = super().get_select_precolumns(select, **kw)
if select._limit_clause is not None:
kw["literal_execute"] = True
- select_precolumns += "TOP %s " % self.process(select._limit_clause, **kw)
+ select_precolumns += f"TOP {self.process(select._limit_clause, **kw)} "
return select_precolumns
def limit_clause(self, select, **kw):
- """Do not add LIMIT to the end of the query"""
+ """Do not add LIMIT to the end of the query."""
return ""
def visit_sequence(self, sequence, **kw):
@@ -24,16 +24,21 @@ def visit_sequence(self, sequence, **kw):
def visit_empty_set_expr(self, element_types):
pass
- def update_from_clause(self, update_stmt, from_table, extra_froms, from_hints, **kw):
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
pass
- def delete_extra_from_clause(self, update_stmt, from_table, extra_froms, from_hints, **kw):
+ def delete_extra_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
pass
class KustoSqlHttpsDialect(KustoBaseDialect):
name = "kustosql"
statement_compiler = KustoSqlCompiler
- # For some reason supports_statement_cache doesn't work when defined in the KustoBaseDialect.
+ # For some reason supports_statement_cache
+ # doesn't work when defined in the KustoBaseDialect.
# Need to investigate why it happens.
supports_statement_cache = True
diff --git a/sqlalchemy_kusto/errors.py b/sqlalchemy_kusto/errors.py
index ef40a75..70116d1 100644
--- a/sqlalchemy_kusto/errors.py
+++ b/sqlalchemy_kusto/errors.py
@@ -2,7 +2,7 @@ class Error(Exception):
pass
-class Warning(Exception): # pylint: disable-msg=redefined-builtin
+class Warning(Exception):
pass
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index 50e0f3c..90a2bcb 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -3,8 +3,12 @@
from dotenv import load_dotenv
from sqlalchemy.dialects import registry
-registry.register("kustosql.https", "sqlalchemy_kusto.dialect_sql", "KustoSqlHttpsDialect")
-registry.register("kustokql.https", "sqlalchemy_kusto.dialect_kql", "KustoKqlHttpsDialect")
+registry.register(
+ "kustosql.https", "sqlalchemy_kusto.dialect_sql", "KustoSqlHttpsDialect"
+)
+registry.register(
+ "kustokql.https", "sqlalchemy_kusto.dialect_kql", "KustoKqlHttpsDialect"
+)
load_dotenv()
AZURE_AD_CLIENT_ID = os.environ.get("AZURE_AD_CLIENT_ID", "")
diff --git a/tests/integration/test_dbapi.py b/tests/integration/test_dbapi.py
index f68fe4a..94c7d39 100644
--- a/tests/integration/test_dbapi.py
+++ b/tests/integration/test_dbapi.py
@@ -8,20 +8,21 @@
)
-def test_connect():
+def test_connect() -> None:
connection = connect("test", DATABASE, True)
assert connection is not None
-def test_execute():
+def test_execute() -> None:
connection = connect(
KUSTO_URL,
DATABASE,
False,
+ False,
None,
azure_ad_client_id=AZURE_AD_CLIENT_ID,
azure_ad_client_secret=AZURE_AD_CLIENT_SECRET,
azure_ad_tenant_id=AZURE_AD_TENANT_ID,
)
- result = connection.execute(f"select 1").fetchall()
+ result = connection.execute("select 1").fetchall()
assert result is not None
diff --git a/tests/integration/test_dialect_sql.py b/tests/integration/test_dialect_sql.py
index 07168d2..4f26563 100644
--- a/tests/integration/test_dialect_sql.py
+++ b/tests/integration/test_dialect_sql.py
@@ -1,7 +1,13 @@
+from collections.abc import Generator
+from typing import Any
import uuid
import pytest
-from azure.kusto.data import ClientRequestProperties, KustoClient, KustoConnectionStringBuilder
+from azure.kusto.data import (
+ ClientRequestProperties,
+ KustoClient,
+ KustoConnectionStringBuilder,
+)
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine
from tests.integration.conftest import (
@@ -21,31 +27,31 @@
)
-def test_ping():
+def test_ping() -> None:
conn = engine.connect()
result = engine.dialect.do_ping(conn)
assert result is True
-def test_get_table_names(temp_table_name):
+def test_get_table_names(temp_table_name: str) -> None:
conn = engine.connect()
result = engine.dialect.get_table_names(conn)
assert temp_table_name in result
-def test_get_view_names(temp_table_name):
+def test_get_view_names(temp_table_name: str) -> None:
conn = engine.connect()
result = engine.dialect.get_view_names(conn)
assert f"{temp_table_name}_fn" in result
-def test_get_columns(temp_table_name):
+def test_get_columns(temp_table_name: str) -> None:
conn = engine.connect()
columns_result = engine.dialect.get_columns(conn, temp_table_name)
- assert set(["Id", "Text"]) == set([c["name"] for c in columns_result])
+ assert {"Id", "Text"} == {c["name"] for c in columns_result}
-def test_fetch_one(temp_table_name):
+def test_fetch_one(temp_table_name: str) -> None:
engine.connect()
result = engine.execute(f"select top 2 * from {temp_table_name} order by Id")
assert result.fetchone() == (1, "value_1")
@@ -53,21 +59,33 @@ def test_fetch_one(temp_table_name):
assert result.fetchone() is None
-def test_fetch_many(temp_table_name):
+def test_fetch_many(temp_table_name: str) -> None:
engine.connect()
result = engine.execute(f"select top 5 * from {temp_table_name} order by Id")
- assert set([(x[0], x[1]) for x in result.fetchmany(3)]) == set([(1, "value_1"), (2, "value_2"), (3, "value_3")])
- assert set([(x[0], x[1]) for x in result.fetchmany(3)]) == set([(4, "value_4"), (5, "value_5")])
+ assert {(x[0], x[1]) for x in result.fetchmany(3)} == {
+ (1, "value_1"),
+ (2, "value_2"),
+ (3, "value_3"),
+ }
+ assert {(x[0], x[1]) for x in result.fetchmany(3)} == {
+ (4, "value_4"),
+ (5, "value_5"),
+ }
-def test_fetch_all(temp_table_name):
+def test_fetch_all(temp_table_name: str) -> None:
engine.connect()
result = engine.execute(f"select top 3 * from {temp_table_name} order by Id")
- assert set([(x[0], x[1]) for x in result.fetchall()]) == set([(1, "value_1"), (2, "value_2"), (3, "value_3")])
+ assert {(x[0], x[1]) for x in result.fetchall()} == {
+ (1, "value_1"),
+ (2, "value_2"),
+ (3, "value_3"),
+ }
-def test_limit(temp_table_name):
+def test_limit(temp_table_name: str) -> None:
+ limit = 5
stream = Table(
temp_table_name,
MetaData(),
@@ -75,61 +93,72 @@ def test_limit(temp_table_name):
Column("Text", String),
)
- query = stream.select().limit(5)
+ query = stream.select().limit(limit)
engine.connect()
result = engine.execute(query)
result_length = len(result.fetchall())
- assert result_length == 5
+ assert result_length == limit
-def get_kcsb():
+def get_kcsb() -> Any:
return (
KustoConnectionStringBuilder.with_az_cli_authentication(KUSTO_URL)
- if not AZURE_AD_CLIENT_ID and not AZURE_AD_CLIENT_SECRET and not AZURE_AD_TENANT_ID
+ if not AZURE_AD_CLIENT_ID
+ and not AZURE_AD_CLIENT_SECRET
+ and not AZURE_AD_TENANT_ID
else KustoConnectionStringBuilder.with_aad_application_key_authentication(
KUSTO_URL, AZURE_AD_CLIENT_ID, AZURE_AD_CLIENT_SECRET, AZURE_AD_TENANT_ID
)
)
-def _create_temp_table(table_name: str):
+def _create_temp_table(table_name: str) -> None:
client = KustoClient(get_kcsb())
- response = client.execute(DATABASE, f".create table {table_name}(Id: int, Text: string)", ClientRequestProperties())
+ client.execute(
+ DATABASE,
+ f".create table {table_name}(Id: int, Text: string)",
+ ClientRequestProperties(),
+ )
-def _create_temp_fn(fn_name: str):
+def _create_temp_fn(fn_name: str) -> None:
client = KustoClient(get_kcsb())
- response = client.execute(DATABASE, f".create function {fn_name}() {{ print now()}}", ClientRequestProperties())
+ client.execute(
+ DATABASE,
+ f".create function {fn_name}() {{ print now()}}",
+ ClientRequestProperties(),
+ )
-def _ingest_data_to_table(table_name: str):
+def _ingest_data_to_table(table_name: str) -> None:
client = KustoClient(get_kcsb())
data_to_ingest = {i: "value_" + str(i) for i in range(1, 10)}
str_data = "\n".join("{},{}".format(*p) for p in data_to_ingest.items())
ingest_query = f""".ingest inline into table {table_name} <|
{str_data}"""
- response = client.execute(DATABASE, ingest_query, ClientRequestProperties())
+ client.execute(DATABASE, ingest_query, ClientRequestProperties())
-def _drop_table(table_name: str):
+def _drop_table(table_name: str) -> None:
client = KustoClient(get_kcsb())
_ = client.execute(DATABASE, f".drop table {table_name}", ClientRequestProperties())
- _ = client.execute(DATABASE, f".drop function {table_name}_fn", ClientRequestProperties())
+ _ = client.execute(
+ DATABASE, f".drop function {table_name}_fn", ClientRequestProperties()
+ )
-@pytest.fixture()
-def temp_table_name():
+@pytest.fixture
+def temp_table_name() -> str:
return "_temp_" + uuid.uuid4().hex
@pytest.fixture(autouse=True)
-def run_around_tests(temp_table_name):
+def run_around_tests(temp_table_name: str) -> Generator[str, None, None]:
_create_temp_table(temp_table_name)
_create_temp_fn(f"{temp_table_name}_fn")
_ingest_data_to_table(temp_table_name)
# A test function will be run at this point
yield temp_table_name
_drop_table(temp_table_name)
- # assert files_before == files_after
diff --git a/tests/integration/test_error_handling.py b/tests/integration/test_error_handling.py
index 38079e5..0a618b2 100644
--- a/tests/integration/test_error_handling.py
+++ b/tests/integration/test_error_handling.py
@@ -5,7 +5,7 @@
from tests.integration.conftest import DATABASE, KUSTO_SQL_ALCHEMY_URL
-def test_operational_error():
+def test_operational_error() -> None:
wrong_tenant_id = "wrong_tenant_id"
azure_ad_client_id = "x"
azure_ad_client_secret = "x"
diff --git a/tests/unit/test_dialect_kql.py b/tests/unit/test_dialect_kql.py
index e621075..71824c5 100644
--- a/tests/unit/test_dialect_kql.py
+++ b/tests/unit/test_dialect_kql.py
@@ -1,12 +1,22 @@
import pytest
import sqlalchemy as sa
-from sqlalchemy import Column, MetaData, String, Table, column, create_engine, literal_column, select, text
+from sqlalchemy import (
+ Column,
+ MetaData,
+ String,
+ Table,
+ column,
+ create_engine,
+ literal_column,
+ select,
+ text,
+)
from sqlalchemy.sql.selectable import TextAsFrom
engine = create_engine("kustokql+https://localhost/testdb")
-def test_compiler_with_projection():
+def test_compiler_with_projection() -> None:
statement_str = "logs | take 10"
stmt = TextAsFrom(sa.text(statement_str), []).alias("virtual_table")
query = sa.select(
@@ -31,7 +41,7 @@ def test_compiler_with_projection():
assert query_compiled == query_expected
-def test_compiler_with_star():
+def test_compiler_with_star() -> None:
statement_str = "logs | take 10"
stmt = TextAsFrom(sa.text(statement_str), []).alias("virtual_table")
query = sa.select(
@@ -42,20 +52,30 @@ def test_compiler_with_star():
query = query.limit(10)
query_compiled = str(query.compile(engine)).replace("\n", "")
- query_expected = 'let virtual_table = (["logs"] | take 10);' "virtual_table" "| take __[POSTCOMPILE_param_1]"
+ query_expected = (
+ 'let virtual_table = (["logs"] | take 10);'
+ "virtual_table"
+ "| take __[POSTCOMPILE_param_1]"
+ )
assert query_compiled == query_expected
-def test_select_from_text():
- query = select([column("Field1"), column("Field2")]).select_from(text("logs")).limit(100)
- query_compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})).replace("\n", "")
+def test_select_from_text() -> None:
+ query = (
+ select([column("Field1"), column("Field2")])
+ .select_from(text("logs"))
+ .limit(100)
+ )
+ query_compiled = str(
+ query.compile(engine, compile_kwargs={"literal_binds": True})
+ ).replace("\n", "")
query_expected = '["logs"]' "| project Field1, Field2" "| take 100"
assert query_compiled == query_expected
-def test_use_table():
+def test_use_table() -> None:
metadata = MetaData()
stream = Table(
"logs",
@@ -67,23 +87,31 @@ def test_use_table():
query = stream.select().limit(5)
query_compiled = str(query.compile(engine)).replace("\n", "")
- query_expected = '["logs"]' "| project Field1, Field2" "| take __[POSTCOMPILE_param_1]"
+ query_expected = (
+ '["logs"]' "| project Field1, Field2" "| take __[POSTCOMPILE_param_1]"
+ )
assert query_compiled == query_expected
-def test_limit():
+def test_limit() -> None:
sql = "logs"
limit = 5
- query = select("*").select_from(TextAsFrom(text(sql), ["*"]).alias("inner_qry")).limit(limit)
+ query = (
+ select("*")
+ .select_from(TextAsFrom(text(sql), ["*"]).alias("inner_qry"))
+ .limit(limit)
+ )
- query_compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})).replace("\n", "")
+ query_compiled = str(
+ query.compile(engine, compile_kwargs={"literal_binds": True})
+ ).replace("\n", "")
query_expected = 'let inner_qry = (["logs"]);' "inner_qry" "| take 5"
assert query_compiled == query_expected
-def test_select_count():
+def test_select_count() -> None:
kql_query = "logs"
column_count = literal_column("count(*)").label("count")
query = (
@@ -95,7 +123,9 @@ def test_select_count():
.limit(5)
)
- query_compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})).replace("\n", "")
+ query_compiled = str(
+ query.compile(engine, compile_kwargs={"literal_binds": True})
+ ).replace("\n", "")
query_expected = (
'let inner_qry = (["logs"]);'
@@ -108,11 +138,17 @@ def test_select_count():
assert query_compiled == query_expected
-def test_select_with_let():
+def test_select_with_let() -> None:
kql_query = "let x = 5; let y = 3; MyTable | where Field1 == x and Field2 == y"
- query = select("*").select_from(TextAsFrom(text(kql_query), ["*"]).alias("inner_qry")).limit(5)
+ query = (
+ select("*")
+ .select_from(TextAsFrom(text(kql_query), ["*"]).alias("inner_qry"))
+ .limit(5)
+ )
- query_compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})).replace("\n", "")
+ query_compiled = str(
+ query.compile(engine, compile_kwargs={"literal_binds": True})
+ ).replace("\n", "")
query_expected = (
"let x = 5;"
@@ -125,7 +161,7 @@ def test_select_with_let():
assert query_compiled == query_expected
-def test_quotes():
+def test_quotes() -> None:
quote = engine.dialect.identifier_preparer.quote
metadata = MetaData()
stream = Table(
@@ -150,7 +186,7 @@ def test_quotes():
@pytest.mark.parametrize(
- "schema_name,table_name,expected_table_name",
+ ("schema_name", "table_name", "expected_table_name"),
[
("schema", "table", 'database("schema").["table"]'),
("schema", '"table.name"', 'database("schema").["table.name"]'),
@@ -161,7 +197,9 @@ def test_quotes():
(None, "MyTable", '["MyTable"]'),
],
)
-def test_schema_from_metadata(table_name: str, schema_name: str, expected_table_name: str):
+def test_schema_from_metadata(
+ table_name: str, schema_name: str, expected_table_name: str
+) -> None:
metadata = MetaData(schema=schema_name) if schema_name else MetaData()
stream = Table(
table_name,
@@ -176,7 +214,7 @@ def test_schema_from_metadata(table_name: str, schema_name: str, expected_table_
@pytest.mark.parametrize(
- "query_table_name,expected_table_name",
+ ("query_table_name", "expected_table_name"),
[
("schema.table", 'database("schema").["table"]'),
('schema."table.name"', 'database("schema").["table.name"]'),
@@ -189,10 +227,16 @@ def test_schema_from_metadata(table_name: str, schema_name: str, expected_table_
('["table"]', '["table"]'),
],
)
-def test_schema_from_query(query_table_name: str, expected_table_name: str):
- query = select("*").select_from(TextAsFrom(text(query_table_name), ["*"]).alias("inner_qry")).limit(5)
+def test_schema_from_query(query_table_name: str, expected_table_name: str) -> None:
+ query = (
+ select("*")
+ .select_from(TextAsFrom(text(query_table_name), ["*"]).alias("inner_qry"))
+ .limit(5)
+ )
- query_compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})).replace("\n", "")
+ query_compiled = str(
+ query.compile(engine, compile_kwargs={"literal_binds": True})
+ ).replace("\n", "")
query_expected = f"let inner_qry = ({expected_table_name});inner_qry| take 5"
assert query_compiled == query_expected