Skip to content

Commit

Permalink
fix(ingest): use typing_extensions.Self (datahub-project#12230)
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 authored Dec 27, 2024
1 parent d042354 commit 4e3103e
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 31 deletions.
7 changes: 3 additions & 4 deletions metadata-ingestion/scripts/avro_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ def merge_schemas(schemas_obj: List[dict]) -> str:
# Patch add_name method to NOT complain about duplicate names.
class NamesWithDups(avro.schema.Names):
def add_name(self, name_attr, space_attr, new_schema):

to_add = avro.schema.Name(name_attr, space_attr, self.default_namespace)
assert to_add.name
assert to_add.space
Expand Down Expand Up @@ -626,7 +625,7 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str:
class {class_name}(_SpecificUrn):
ENTITY_TYPE: ClassVar[str] = "{entity_type}"
URN_PARTS: ClassVar[int] = {arg_count}
_URN_PARTS: ClassVar[int] = {arg_count}
def __init__(self, {init_args}, *, _allow_coercion: bool = True) -> None:
if _allow_coercion:
Expand All @@ -640,8 +639,8 @@ def __init__(self, {init_args}, *, _allow_coercion: bool = True) -> None:
@classmethod
def _parse_ids(cls, entity_ids: List[str]) -> "{class_name}":
if len(entity_ids) != cls.URN_PARTS:
raise InvalidUrnError(f"{class_name} should have {{cls.URN_PARTS}} parts, got {{len(entity_ids)}}: {{entity_ids}}")
if len(entity_ids) != cls._URN_PARTS:
raise InvalidUrnError(f"{class_name} should have {{cls._URN_PARTS}} parts, got {{len(entity_ids)}}: {{entity_ids}}")
return cls({parse_ids_mapping}, _allow_coercion=False)
@classmethod
Expand Down
2 changes: 1 addition & 1 deletion metadata-ingestion/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

base_requirements = {
# Our min version of typing_extensions is somewhat constrained by Airflow.
"typing_extensions>=3.10.0.2",
"typing_extensions>=4.2.0",
# Actual dependencies.
"typing-inspect",
# pydantic 1.8.2 is incompatible with mypy 0.910.
Expand Down
7 changes: 2 additions & 5 deletions metadata-ingestion/src/datahub/configuration/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
List,
Optional,
Type,
TypeVar,
Union,
runtime_checkable,
)
Expand All @@ -19,14 +18,12 @@
from cached_property import cached_property
from pydantic import BaseModel, Extra, ValidationError
from pydantic.fields import Field
from typing_extensions import Protocol
from typing_extensions import Protocol, Self

from datahub.configuration._config_enum import ConfigEnum as ConfigEnum # noqa: I250
from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2
from datahub.utilities.dedup_list import deduplicate_list

_ConfigSelf = TypeVar("_ConfigSelf", bound="ConfigModel")

REDACT_KEYS = {
"password",
"token",
Expand Down Expand Up @@ -109,7 +106,7 @@ def _schema_extra(schema: Dict[str, Any], model: Type["ConfigModel"]) -> None:
schema_extra = _schema_extra

@classmethod
def parse_obj_allow_extras(cls: Type[_ConfigSelf], obj: Any) -> _ConfigSelf:
def parse_obj_allow_extras(cls, obj: Any) -> Self:
if PYDANTIC_VERSION_2:
try:
with unittest.mock.patch.dict(
Expand Down
6 changes: 3 additions & 3 deletions metadata-ingestion/src/datahub/ingestion/api/closeable.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from abc import abstractmethod
from contextlib import AbstractContextManager
from types import TracebackType
from typing import Optional, Type, TypeVar
from typing import Optional, Type

_Self = TypeVar("_Self", bound="Closeable")
from typing_extensions import Self


class Closeable(AbstractContextManager):
@abstractmethod
def close(self) -> None:
pass

def __enter__(self: _Self) -> _Self:
def __enter__(self) -> Self:
# This method is mainly required for type checking.
return self

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, NewType, Optional, Type, TypeVar
from typing import Any, Dict, NewType, Optional

from typing_extensions import Self

import datahub.emitter.mce_builder as builder
from datahub.configuration.common import ConfigModel
Expand All @@ -17,9 +19,6 @@ class IngestionCheckpointingProviderConfig(ConfigModel):
pass


_Self = TypeVar("_Self", bound="IngestionCheckpointingProviderBase")


@dataclass()
class IngestionCheckpointingProviderBase(StatefulCommittable[CheckpointJobStatesMap]):
"""
Expand All @@ -32,9 +31,7 @@ def __init__(self, name: str, commit_policy: CommitPolicy = CommitPolicy.ALWAYS)

@classmethod
@abstractmethod
def create(
cls: Type[_Self], config_dict: Dict[str, Any], ctx: PipelineContext
) -> "_Self":
def create(cls, config_dict: Dict[str, Any], ctx: PipelineContext) -> Self:
pass

@abstractmethod
Expand Down
5 changes: 4 additions & 1 deletion metadata-ingestion/src/datahub/ingestion/api/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ def to_pure_python_obj(some_val: Any) -> Any:
return some_val.as_obj()
elif isinstance(some_val, pydantic.BaseModel):
return Report.to_pure_python_obj(some_val.dict())
elif dataclasses.is_dataclass(some_val):
elif dataclasses.is_dataclass(some_val) and not isinstance(some_val, type):
# The `is_dataclass` function returns `True` for both instances and classes.
# We need an extra check to ensure an instance was passed in.
# https://docs.python.org/3/library/dataclasses.html#dataclasses.is_dataclass
return dataclasses.asdict(some_val)
elif isinstance(some_val, list):
return [Report.to_pure_python_obj(v) for v in some_val if v is not None]
Expand Down
7 changes: 4 additions & 3 deletions metadata-ingestion/src/datahub/ingestion/api/sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import dataclass, field
from typing import Any, Generic, Optional, Type, TypeVar, cast

from typing_extensions import Self

from datahub.configuration.common import ConfigModel
from datahub.ingestion.api.closeable import Closeable
from datahub.ingestion.api.common import PipelineContext, RecordEnvelope, WorkUnit
Expand Down Expand Up @@ -79,7 +81,6 @@ def on_failure(

SinkReportType = TypeVar("SinkReportType", bound=SinkReport, covariant=True)
SinkConfig = TypeVar("SinkConfig", bound=ConfigModel, covariant=True)
Self = TypeVar("Self", bound="Sink")


class Sink(Generic[SinkConfig, SinkReportType], Closeable, metaclass=ABCMeta):
Expand All @@ -90,7 +91,7 @@ class Sink(Generic[SinkConfig, SinkReportType], Closeable, metaclass=ABCMeta):
report: SinkReportType

@classmethod
def get_config_class(cls: Type[Self]) -> Type[SinkConfig]:
def get_config_class(cls) -> Type[SinkConfig]:
config_class = get_class_from_annotation(cls, Sink, ConfigModel)
assert config_class, "Sink subclasses must define a config class"
return cast(Type[SinkConfig], config_class)
Expand All @@ -112,7 +113,7 @@ def __post_init__(self) -> None:
pass

@classmethod
def create(cls: Type[Self], config_dict: dict, ctx: PipelineContext) -> "Self":
def create(cls, config_dict: dict, ctx: PipelineContext) -> "Self":
return cls(ctx, cls.get_config_class().parse_obj(config_dict))

def handle_work_unit_start(self, workunit: WorkUnit) -> None:
Expand Down
12 changes: 5 additions & 7 deletions metadata-ingestion/src/datahub/utilities/urns/_urn_base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import functools
import urllib.parse
from abc import abstractmethod
from typing import ClassVar, Dict, List, Optional, Type, TypeVar
from typing import ClassVar, Dict, List, Optional, Type

from deprecated import deprecated
from typing_extensions import Self

from datahub.utilities.urns.error import InvalidUrnError

Expand Down Expand Up @@ -42,9 +43,6 @@ def _split_entity_id(entity_id: str) -> List[str]:
return parts


_UrnSelf = TypeVar("_UrnSelf", bound="Urn")


@functools.total_ordering
class Urn:
"""
Expand Down Expand Up @@ -88,7 +86,7 @@ def entity_ids(self) -> List[str]:
return self._entity_ids

@classmethod
def from_string(cls: Type[_UrnSelf], urn_str: str) -> "_UrnSelf":
def from_string(cls, urn_str: str) -> Self:
"""
Creates an Urn from its string representation.
Expand Down Expand Up @@ -174,7 +172,7 @@ def __hash__(self) -> int:

@classmethod
@deprecated(reason="prefer .from_string")
def create_from_string(cls: Type[_UrnSelf], urn_str: str) -> "_UrnSelf":
def create_from_string(cls, urn_str: str) -> Self:
return cls.from_string(urn_str)

@deprecated(reason="prefer .entity_ids")
Expand Down Expand Up @@ -270,5 +268,5 @@ def underlying_key_aspect_type(cls) -> Type:

@classmethod
@abstractmethod
def _parse_ids(cls: Type[_UrnSelf], entity_ids: List[str]) -> _UrnSelf:
def _parse_ids(cls, entity_ids: List[str]) -> Self:
raise NotImplementedError()

0 comments on commit 4e3103e

Please sign in to comment.