Skip to content

Commit

Permalink
chore: mypy for ak.forms
Browse files: browse the repository at this point in the history
  • Loading branch information
agoose77 committed Nov 1, 2023
1 parent a0ffb82 commit a7b8020
Show file tree
Hide file tree
Showing 15 changed files with 140 additions and 222 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ module = [
'awkward._nplikes.*',
'awkward._behavior.*',
'awkward._backends.*',
'awkward.forms.*',
]
ignore_errors = false
ignore_missing_imports = true
Expand Down
16 changes: 5 additions & 11 deletions src/awkward/_nplikes/numpylike.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ class NumpyMetadata(PublicSingleton):
str_ = numpy.str_
bytes_ = numpy.bytes_

datetime64 = numpy.datetime64
timedelta64 = numpy.timedelta64

intp = numpy.intp
integer = numpy.integer
signedinteger = numpy.signedinteger
Expand All @@ -228,11 +231,8 @@ class NumpyMetadata(PublicSingleton):
inf = numpy.inf

nat = numpy.datetime64("NaT")
datetime_data = numpy.datetime_data

@property
def issubdtype(self):
return numpy.issubdtype
datetime_data = staticmethod(numpy.datetime_data)
issubdtype = staticmethod(numpy.issubdtype)

AxisError = numpy.AxisError

Expand All @@ -246,12 +246,6 @@ def issubdtype(self):
if hasattr(numpy, "complex256"):
NumpyMetadata.complex256 = numpy.complex256 # type: ignore[attr-defined]

if hasattr(numpy, "datetime64"):
NumpyMetadata.datetime64 = numpy.datetime64 # type: ignore[attr-defined]

if hasattr(numpy, "timedelta64"):
NumpyMetadata.timedelta64 = numpy.timedelta64 # type: ignore[attr-defined]


class UfuncLike(Protocol):
nargs: int
Expand Down
26 changes: 6 additions & 20 deletions src/awkward/forms/bitmaskedform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import awkward as ak
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._parameters import type_parameters_equal
from awkward._typing import Iterator, JSONSerializable, Self, final
from awkward._typing import DType, Iterator, JSONSerializable, Self, final
from awkward._util import UNSET
from awkward.forms.form import Form, index_to_dtype
from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype

np = NumpyMetadata.instance()

Expand Down Expand Up @@ -209,24 +209,10 @@ def _prune_columns(self, is_inside_record_or_union: bool) -> Self | None:
if next_content is None:
return None
else:
return BitMaskedForm(
self._mask,
next_content,
self._valid_when,
self._lsb_order,
parameters=self._parameters,
form_key=self._form_key,
)
return self.copy(content=next_content)

def _select_columns(self, match_specifier):
return BitMaskedForm(
self._mask,
self._content._select_columns(match_specifier),
self._valid_when,
self._lsb_order,
parameters=self._parameters,
form_key=self._form_key,
)
def _select_columns(self, match_specifier: _SpecifierMatcher) -> Self:
return self.copy(content=self._content._select_columns(match_specifier))

def _column_types(self):
return self._content._column_types()
Expand Down Expand Up @@ -263,7 +249,7 @@ def __setstate__(self, state):

def _expected_from_buffers(
self, getkey: Callable[[Form, str], str], recursive: bool
) -> Iterator[tuple[str, np.dtype]]:
) -> Iterator[tuple[str, DType]]:
yield (getkey(self, "mask"), index_to_dtype[self._mask])
if recursive:
yield from self._content._expected_from_buffers(getkey, recursive)
24 changes: 6 additions & 18 deletions src/awkward/forms/bytemaskedform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import awkward as ak
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._parameters import type_parameters_equal
from awkward._typing import Iterator, JSONSerializable, Self, final
from awkward._typing import DType, Iterator, JSONSerializable, Self, final
from awkward._util import UNSET
from awkward.forms.form import Form, index_to_dtype
from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype

np = NumpyMetadata.instance()

Expand Down Expand Up @@ -187,22 +187,10 @@ def _prune_columns(self, is_inside_record_or_union: bool) -> Self | None:
if next_content is None:
return None
else:
return ByteMaskedForm(
self._mask,
next_content,
self._valid_when,
parameters=self._parameters,
form_key=self._form_key,
)
return self.copy(content=next_content)

def _select_columns(self, match_specifier):
return ByteMaskedForm(
self._mask,
self._content._select_columns(match_specifier),
self._valid_when,
parameters=self._parameters,
form_key=self._form_key,
)
def _select_columns(self, match_specifier: _SpecifierMatcher) -> Self:
return self.copy(content=self._content._select_columns(match_specifier))

def _column_types(self):
return self._content._column_types()
Expand All @@ -226,7 +214,7 @@ def __setstate__(self, state):

def _expected_from_buffers(
self, getkey: Callable[[Form, str], str], recursive: bool
) -> Iterator[tuple[str, np.dtype]]:
) -> Iterator[tuple[str, DType]]:
yield (getkey(self, "mask"), index_to_dtype[self._mask])
if recursive:
yield from self._content._expected_from_buffers(getkey, recursive)
27 changes: 17 additions & 10 deletions src/awkward/forms/emptyform.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from awkward._errors import deprecate
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.shape import ShapeItem
from awkward._typing import Iterator, JSONSerializable, Self, final
from awkward._util import UNSET
from awkward.forms.form import Form, JSONMapping
from awkward._typing import DType, Iterator, JSONSerializable, Self, final
from awkward._util import UNSET, Sentinel
from awkward.forms.form import Form, JSONMapping, _SpecifierMatcher

np = NumpyMetadata.instance()

Expand All @@ -22,22 +22,29 @@ class EmptyForm(Form):
is_numpy = True
is_unknown = True

def __init__(self, *, parameters: JSONMapping | None = None, form_key=None):
def __init__(
self, *, parameters: JSONMapping | None = None, form_key: str | None = None
):
if not (parameters is None or len(parameters) == 0):
raise TypeError(f"{type(self).__name__} cannot contain parameters")
self._init(parameters=parameters, form_key=form_key)

def copy(
self, *, parameters: JSONMapping | None = UNSET, form_key=UNSET
self,
*,
parameters: JSONMapping | Sentinel | None = UNSET,
form_key: str | Sentinel | None = UNSET,
) -> EmptyForm:
if not (parameters is UNSET or parameters is None or len(parameters) == 0):
if not (parameters is UNSET or parameters is None or len(parameters) == 0): # type: ignore[arg-type]
raise TypeError(f"{type(self).__name__} cannot contain parameters")
return EmptyForm(
form_key=self._form_key if form_key is UNSET else form_key,
form_key=self._form_key if form_key is UNSET else form_key, # type: ignore[arg-type]
)

@classmethod
def simplified(cls, *, parameters=None, form_key=None) -> Form:
def simplified(
cls, *, parameters: JSONMapping | None = None, form_key: str | None = None
) -> Form:
if not (parameters is None or len(parameters) == 0):
raise TypeError(f"{cls.__name__} cannot contain parameters")
return cls(parameters=parameters, form_key=form_key)
Expand Down Expand Up @@ -123,7 +130,7 @@ def dimension_optiontype(self) -> bool:
def _columns(self, path, output, list_indicator):
output.append(".".join(path))

def _select_columns(self, match_specifier):
def _select_columns(self, match_specifier: _SpecifierMatcher) -> Self:
return self

def _prune_columns(self, is_inside_record_or_union: bool) -> Self:
Expand Down Expand Up @@ -152,5 +159,5 @@ def __setstate__(self, state):

def _expected_from_buffers(
self, getkey: Callable[[Form, str], str], recursive: bool
) -> Iterator[tuple[str, np.dtype]]:
) -> Iterator[tuple[str, DType]]:
yield from ()
48 changes: 28 additions & 20 deletions src/awkward/forms/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.shape import ShapeItem, unknown_length
from awkward._parameters import parameters_union
from awkward._typing import Final, Iterator, JSONMapping, JSONSerializable, Self
from awkward._typing import (
ClassVar,
DType,
Final,
Iterator,
JSONMapping,
JSONSerializable,
Self,
)

np = NumpyMetadata.instance()
numpy_backend = NumpyBackend.instance()
Expand Down Expand Up @@ -328,16 +336,14 @@ def __call__(self, field: str, *, next_match_if_empty: bool = False) -> Self | N
next_specifiers.extend(self._match_to_next_specifiers[pattern])

if has_matched:
return _SpecifierMatcher(
next_specifiers, match_if_empty=next_match_if_empty
)
return type(self)(next_specifiers, match_if_empty=next_match_if_empty)
elif self.is_empty and self._match_if_empty:
return self
else:
return
return None


def regularize_buffer_key(buffer_key: str | callable) -> Callable[[Form, str], str]:
def regularize_buffer_key(buffer_key: str | Callable) -> Callable[[Form, str], str]:
if isinstance(buffer_key, str):

def getkey(form, attribute):
Expand All @@ -358,7 +364,7 @@ def getkey(form, attribute):
)


index_to_dtype: Final[dict[str, np.dtype]] = {
index_to_dtype: Final[dict[str, DType]] = {
"i8": np.dtype("<i1"),
"u8": np.dtype("<u1"),
"i32": np.dtype("<i4"),
Expand All @@ -368,16 +374,16 @@ def getkey(form, attribute):


class Form:
is_numpy = False
is_unknown = False
is_list = False
is_regular = False
is_option = False
is_indexed = False
is_record = False
is_union = False

def _init(self, *, parameters, form_key):
is_numpy: ClassVar = False
is_unknown: ClassVar = False
is_list: ClassVar = False
is_regular: ClassVar = False
is_option: ClassVar = False
is_indexed: ClassVar = False
is_record: ClassVar = False
is_union: ClassVar = False

def _init(self, *, parameters: JSONMapping | None, form_key: str | None):
if parameters is not None and not isinstance(parameters, dict):
raise TypeError(
"{} 'parameters' must be of type dict or None, not {}".format(
Expand Down Expand Up @@ -511,6 +517,8 @@ def select_columns(
specifier = [[] if item == "" else item.split(".") for item in set(specifier)]
match_specifier = _SpecifierMatcher(specifier, match_if_empty=False)
selection = self._select_columns(match_specifier)
assert selection is not None, "top-level selections always return a Form"

if prune_unions_and_records:
return selection._prune_columns(False)
else:
Expand All @@ -522,10 +530,10 @@ def column_types(self):
def _columns(self, path, output, list_indicator):
raise NotImplementedError

def _prune_columns(self, is_inside_record_or_union: bool) -> Self | None:
def _prune_columns(self, is_inside_record_or_union: bool) -> Form | None:
raise NotImplementedError

def _select_columns(self, match_specifier) -> Self | None:
def _select_columns(self, match_specifier: _SpecifierMatcher) -> Form | None:
raise NotImplementedError

def _column_types(self):
Expand Down Expand Up @@ -658,7 +666,7 @@ def prepare(form, multiplier):

def _expected_from_buffers(
self, getkey: Callable[[Form, str], str], recursive: bool
) -> Iterator[tuple[str, np.dtype]]:
) -> Iterator[tuple[str, DType]]:
raise NotImplementedError

def expected_from_buffers(
Expand Down
22 changes: 6 additions & 16 deletions src/awkward/forms/indexedform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import awkward as ak
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._parameters import parameters_union, type_parameters_equal
from awkward._typing import Iterator, JSONSerializable, Self, final
from awkward._typing import DType, Iterator, JSONSerializable, Self, final
from awkward._util import UNSET
from awkward.forms.form import Form, index_to_dtype
from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype

np = NumpyMetadata.instance()

Expand Down Expand Up @@ -195,20 +195,10 @@ def _prune_columns(self, is_inside_record_or_union: bool) -> Self | None:
if next_content is None:
return None
else:
return IndexedForm(
self._index,
next_content,
parameters=self._parameters,
form_key=self._form_key,
)
return self.copy(content=next_content)

def _select_columns(self, match_specifier):
return IndexedForm(
self._index,
self._content._select_columns(match_specifier),
parameters=self._parameters,
form_key=self._form_key,
)
def _select_columns(self, match_specifier: _SpecifierMatcher) -> Self:
return self.copy(content=self._content._select_columns(match_specifier))

def _column_types(self):
return self._content._column_types()
Expand All @@ -230,7 +220,7 @@ def __setstate__(self, state):

def _expected_from_buffers(
self, getkey: Callable[[Form, str], str], recursive: bool
) -> Iterator[tuple[str, np.dtype]]:
) -> Iterator[tuple[str, DType]]:
yield (getkey(self, "index"), index_to_dtype[self._index])
if recursive:
yield from self._content._expected_from_buffers(getkey, recursive)
22 changes: 6 additions & 16 deletions src/awkward/forms/indexedoptionform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import awkward as ak
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._parameters import parameters_union, type_parameters_equal
from awkward._typing import Iterator, JSONSerializable, Self, final
from awkward._typing import DType, Iterator, JSONSerializable, Self, final
from awkward._util import UNSET
from awkward.forms.form import Form, index_to_dtype
from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype

np = NumpyMetadata.instance()

Expand Down Expand Up @@ -176,20 +176,10 @@ def _prune_columns(self, is_inside_record_or_union: bool) -> Self | None:
if next_content is None:
return None
else:
return IndexedOptionForm(
self._index,
next_content,
parameters=self._parameters,
form_key=self._form_key,
)
return self.copy(content=next_content)

def _select_columns(self, match_specifier):
return IndexedOptionForm(
self._index,
self._content._select_columns(match_specifier),
parameters=self._parameters,
form_key=self._form_key,
)
def _select_columns(self, match_specifier: _SpecifierMatcher) -> Self:
return self.copy(content=self._content._select_columns(match_specifier))

def _column_types(self):
return self._content._column_types()
Expand All @@ -211,7 +201,7 @@ def __setstate__(self, state):

def _expected_from_buffers(
self, getkey: Callable[[Form, str], str], recursive: bool
) -> Iterator[tuple[str, np.dtype]]:
) -> Iterator[tuple[str, DType]]:
yield (getkey(self, "index"), index_to_dtype[self._index])
if recursive:
yield from self._content._expected_from_buffers(getkey, recursive)
Loading

0 comments on commit a7b8020

Please sign in to comment.