From a5b2d8d69ba0c7c632851daab0453e6f905efa0e Mon Sep 17 00:00:00 2001 From: Ashhar Hasan Date: Sat, 3 Feb 2024 17:05:31 +0530 Subject: [PATCH] Add type hints to types module and enable mypy --- setup.cfg | 2 +- trino/mapper.py | 3 +-- trino/types.py | 25 ++++++++++++++----------- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/setup.cfg b/setup.cfg index 7d391545..372d84b0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,5 +19,5 @@ ignore_missing_imports = true no_implicit_optional = true warn_unused_ignores = true -[mypy-tests.*,trino.client,trino.dbapi,trino.sqlalchemy.*,trino.types] +[mypy-tests.*,trino.client,trino.dbapi,trino.sqlalchemy.*] ignore_errors = true diff --git a/trino/mapper.py b/trino/mapper.py index 4ceba8b4..48a905a6 100644 --- a/trino/mapper.py +++ b/trino/mapper.py @@ -204,10 +204,9 @@ def __init__(self, mappers: List[ValueMapper[Any]], names: List[Optional[str]], def map(self, value: Optional[List[Any]]) -> Optional[Tuple[Optional[Any], ...]]: if value is None: return None - # TODO: Fix typing for self.names return NamedRowTuple( list(self.mappers[i].map(v) for i, v in enumerate(value)), - self.names, # type: ignore + self.names, self.types ) diff --git a/trino/types.py b/trino/types.py index 8a745f52..627b8d3d 100644 --- a/trino/types.py +++ b/trino/types.py @@ -3,7 +3,7 @@ import abc from datetime import datetime, time, timedelta from decimal import Decimal -from typing import Any, Dict, Generic, List, TypeVar, Union +from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union, cast from dateutil import tz @@ -26,7 +26,7 @@ def new_instance(self, value: PythonTemporalType, fraction: Decimal) -> Temporal def to_python_type(self) -> PythonTemporalType: pass - def round_to(self, precision: int) -> TemporalType: + def round_to(self, precision: int) -> TemporalType[PythonTemporalType]: """ Python datetime and time only support up to microsecond precision In case the supplied value exceeds the specified precision, @@ -34,7 +34,8 @@ def round_to(self, precision: int) -> TemporalType: """ precision = min(precision, MAX_PYTHON_TEMPORAL_PRECISION_POWER) remaining_fractional_seconds = self._remaining_fractional_seconds - digits = abs(remaining_fractional_seconds.as_tuple().exponent) + # exponent can return `n`, `N`, `F` too if the value is a NaN for example + digits = abs(remaining_fractional_seconds.as_tuple().exponent) # type: ignore if digits > precision: rounding_factor = POWERS_OF_TEN[precision] rounded = remaining_fractional_seconds.quantize(Decimal(1 / rounding_factor)) @@ -101,16 +102,18 @@ def new_instance(self, value: datetime, fraction: Decimal) -> TimestampWithTimeZ def normalize(self, value: datetime) -> datetime: if tz.datetime_ambiguous(value): - return self._whole_python_temporal_value.tzinfo.normalize(value) + # This appears to be dead code since tzinfo doesn't actually have a `normalize` method. + # TODO: Fix this or remove. + return self._whole_python_temporal_value.tzinfo.normalize(value) # type: ignore return value -class NamedRowTuple(tuple): +class NamedRowTuple(Tuple[Any, ...]): """Custom tuple class as namedtuple doesn't support missing or duplicate names""" - def __new__(cls, values, names: List[str], types: List[str]): - return super().__new__(cls, values) + def __new__(cls, values: List[Any], names: List[str], types: List[str]) -> NamedRowTuple: + return cast(NamedRowTuple, super().__new__(cls, values)) - def __init__(self, values, names: List[str], types: List[str]): + def __init__(self, values: List[Any], names: List[Optional[str]], types: List[str]): self._names = names # With names and types users can retrieve the name and Trino data type of a row self.__annotations__ = dict() @@ -118,16 +121,16 @@ def __init__(self, values, names: List[str], types: List[str]): self.__annotations__["types"] = types elements: List[Any] = [] for name, value in zip(names, values): - if names.count(name) == 1: + if name is not None and names.count(name) == 1: setattr(self, name, value) elements.append(f"{name}: {repr(value)}") else: elements.append(repr(value)) self._repr = "(" + ", ".join(elements) + ")" - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if self._names.count(name): raise ValueError("Ambiguous row field reference: " + name) - def __repr__(self): + def __repr__(self) -> str: return self._repr