Skip to content

Commit

Permalink
Add type hints to types module and enable mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
hashhar committed Feb 5, 2024
1 parent 7c2a0f4 commit 1a9226b
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 14 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ ignore_missing_imports = true
no_implicit_optional = true
warn_unused_ignores = true

[mypy-tests.*,trino.client,trino.dbapi,trino.sqlalchemy.*,trino.types]
[mypy-tests.*,trino.client,trino.dbapi,trino.sqlalchemy.*]
ignore_errors = true
3 changes: 1 addition & 2 deletions trino/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,9 @@ def __init__(self, mappers: List[ValueMapper[Any]], names: List[Optional[str]],
def map(self, value: Optional[List[Any]]) -> Optional[Tuple[Optional[Any], ...]]:
if value is None:
return None
# TODO: Fix typing for self.names
return NamedRowTuple(
list(self.mappers[i].map(v) for i, v in enumerate(value)),
self.names, # type: ignore
self.names,
self.types
)

Expand Down
25 changes: 14 additions & 11 deletions trino/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import abc
from datetime import datetime, time, timedelta
from decimal import Decimal
from typing import Any, Dict, Generic, List, TypeVar, Union
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union, cast

from dateutil import tz

Expand All @@ -26,15 +26,16 @@ def new_instance(self, value: PythonTemporalType, fraction: Decimal) -> Temporal
def to_python_type(self) -> PythonTemporalType:
pass

def round_to(self, precision: int) -> TemporalType:
def round_to(self, precision: int) -> TemporalType[PythonTemporalType]:
"""
Python datetime and time only support up to microsecond precision
In case the supplied value exceeds the specified precision,
the value needs to be rounded.
"""
precision = min(precision, MAX_PYTHON_TEMPORAL_PRECISION_POWER)
remaining_fractional_seconds = self._remaining_fractional_seconds
digits = abs(remaining_fractional_seconds.as_tuple().exponent)
# exponent can return `n`, `N`, `F` too if the value is a NaN for example
digits = abs(remaining_fractional_seconds.as_tuple().exponent) # type: ignore
if digits > precision:
rounding_factor = POWERS_OF_TEN[precision]
rounded = remaining_fractional_seconds.quantize(Decimal(1 / rounding_factor))
Expand Down Expand Up @@ -101,33 +102,35 @@ def new_instance(self, value: datetime, fraction: Decimal) -> TimestampWithTimeZ

def normalize(self, value: datetime) -> datetime:
if tz.datetime_ambiguous(value):
return self._whole_python_temporal_value.tzinfo.normalize(value)
# This appears to be dead code since tzinfo doesn't actually have a `normalize` method.
# TODO: Fix this or remove. (https://github.com/trinodb/trino-python-client/issues/449)
return self._whole_python_temporal_value.tzinfo.normalize(value) # type: ignore
return value


class NamedRowTuple(tuple):
class NamedRowTuple(Tuple[Any, ...]):
"""Custom tuple class as namedtuple doesn't support missing or duplicate names"""
def __new__(cls, values, names: List[str], types: List[str]):
return super().__new__(cls, values)
def __new__(cls, values: List[Any], names: List[str], types: List[str]) -> NamedRowTuple:
return cast(NamedRowTuple, super().__new__(cls, values))

def __init__(self, values, names: List[str], types: List[str]):
def __init__(self, values: List[Any], names: List[Optional[str]], types: List[str]):
self._names = names
# With names and types users can retrieve the name and Trino data type of a row
self.__annotations__ = dict()
self.__annotations__["names"] = names
self.__annotations__["types"] = types
elements: List[Any] = []
for name, value in zip(names, values):
if names.count(name) == 1:
if name is not None and names.count(name) == 1:
setattr(self, name, value)
elements.append(f"{name}: {repr(value)}")
else:
elements.append(repr(value))
self._repr = "(" + ", ".join(elements) + ")"

def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
if self._names.count(name):
raise ValueError("Ambiguous row field reference: " + name)

def __repr__(self):
def __repr__(self) -> str:
return self._repr

0 comments on commit 1a9226b

Please sign in to comment.