From fba836a616d301b8dadf48d404bb4791642f1bf5 Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Tue, 5 Nov 2024 12:48:33 -0500 Subject: [PATCH] Fix `isinstance` behavior for urls (#10766) --- pydantic/networks.py | 355 +++++++++++++++++++++++++++++++++++--- tests/test_json_schema.py | 12 +- tests/test_networks.py | 27 ++- 3 files changed, 366 insertions(+), 28 deletions(-) diff --git a/pydantic/networks.py b/pydantic/networks.py index 3fe0714a21..c565549295 100644 --- a/pydantic/networks.py +++ b/pydantic/networks.py @@ -5,11 +5,14 @@ import dataclasses as _dataclasses import re from dataclasses import fields +from functools import lru_cache from importlib.metadata import version from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network from typing import TYPE_CHECKING, Any, ClassVar -from pydantic_core import MultiHostUrl, PydanticCustomError, Url, core_schema +from pydantic_core import MultiHostHost, PydanticCustomError, core_schema +from pydantic_core import MultiHostUrl as _CoreMultiHostUrl +from pydantic_core import Url as _CoreUrl from typing_extensions import Annotated, Self, TypeAlias from pydantic.errors import PydanticUserError @@ -18,6 +21,7 @@ from ._migration import getattr_migration from .annotated_handlers import GetCoreSchemaHandler from .json_schema import JsonSchemaValue +from .type_adapter import TypeAdapter if TYPE_CHECKING: import email_validator @@ -95,17 +99,175 @@ def defined_constraints(self) -> dict[str, Any]: return {field.name: getattr(self, field.name) for field in fields(self)} -# TODO: there's a lot of repeated code in these two base classes - should we consolidate, or does that up -# the complexity enough that it's not worth saving a few lines? +class _BaseUrl: + _constraints: ClassVar[UrlConstraints] = UrlConstraints() + _url: _CoreUrl + def __init__(self, url: str | _CoreUrl | _BaseUrl) -> None: + self._url = _build_type_adapter(self.__class__).validate_python(url) -class _BaseUrl(Url): - _constraints: ClassVar[UrlConstraints] = UrlConstraints() + @property + def scheme(self) -> str: + """The scheme part of the URL. + + e.g. `https` in `https://user:pass@host:port/path?query#fragment` + """ + return self._url.scheme + + @property + def username(self) -> str | None: + """The username part of the URL, or `None`. + + e.g. `user` in `https://user:pass@host:port/path?query#fragment` + """ + return self._url.username + + @property + def password(self) -> str | None: + """The password part of the URL, or `None`. + + e.g. `pass` in `https://user:pass@host:port/path?query#fragment` + """ + return self._url.password + + @property + def host(self) -> str | None: + """The host part of the URL, or `None`. + + If the URL must be punycode encoded, this is the encoded host, e.g if the input URL is `https://£££.com`, + `host` will be `xn--9aaa.com` + """ + return self._url.host + + def unicode_host(self) -> str | None: + """The host part of the URL as a unicode string, or `None`. + + e.g. `host` in `https://user:pass@host:port/path?query#fragment` + + If the URL must be punycode encoded, this is the decoded host, e.g if the input URL is `https://£££.com`, + `unicode_host()` will be `£££.com` + """ + return self._url.unicode_host() + + @property + def port(self) -> int | None: + """The port part of the URL, or `None`. + + e.g. `port` in `https://user:pass@host:port/path?query#fragment` + """ + return self._url.port + + @property + def path(self) -> str | None: + """The path part of the URL, or `None`. + + e.g. `/path` in `https://user:pass@host:port/path?query#fragment` + """ + return self._url.path + + @property + def query(self) -> str | None: + """The query part of the URL, or `None`. + + e.g. `query` in `https://user:pass@host:port/path?query#fragment` + """ + return self._url.query + + def query_params(self) -> list[tuple[str, str]]: + """The query part of the URL as a list of key-value pairs. + + e.g. `[('foo', 'bar')]` in `https://user:pass@host:port/path?foo=bar#fragment` + """ + return self._url.query_params() + + @property + def fragment(self) -> str | None: + """The fragment part of the URL, or `None`. + + e.g. `fragment` in `https://user:pass@host:port/path?query#fragment` + """ + return self._url.fragment + + def unicode_string(self) -> str: + """The URL as a unicode string, unlike `__str__()` this will not punycode encode the host. + + If the URL must be punycode encoded, this is the decoded string, e.g if the input URL is `https://£££.com`, + `unicode_string()` will be `https://£££.com` + """ + return self._url.unicode_string() + + def __str__(self) -> str: + """The URL as a string, this will punycode encode the host if required.""" + return str(self._url) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({str(self._url)!r})' + + def __deepcopy__(self, memo: dict) -> Self: + return self.__class__(self._url) @classmethod - def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + def build( + cls, + *, + scheme: str, + username: str | None = None, + password: str | None = None, + host: str, + port: int | None = None, + path: str | None = None, + query: str | None = None, + fragment: str | None = None, + ) -> Self: + """Build a new `Url` instance from its component parts. + + Args: + scheme: The scheme part of the URL. + username: The username part of the URL, or omit for no username. + password: The password part of the URL, or omit for no password. + host: The host part of the URL. + port: The port part of the URL, or omit for no port. + path: The path part of the URL, or omit for no path. + query: The query part of the URL, or omit for no query. + fragment: The fragment part of the URL, or omit for no fragment. + + Returns: + An instance of URL + """ + return cls( + _CoreUrl.build( + scheme=scheme, + username=username, + password=password, + host=host, + port=port, + path=path, + query=query, + fragment=fragment, + ) + ) + + @classmethod + def __get_pydantic_core_schema__( + cls, source: type[_BaseUrl], handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: if issubclass(cls, source): - return core_schema.url_schema(**cls._constraints.defined_constraints) + + def wrap_val(value, handler): + if isinstance(value, source): + return value + if isinstance(value, _BaseUrl): + value = str(value) + core_url = handler(value) + instance = source.__new__(source) + instance._url = core_url + return instance + + return core_schema.no_info_wrap_validator_function( + wrap_val, + schema=core_schema.url_schema(**cls._constraints.defined_constraints), + serialization=core_schema.to_string_ser_schema(), + ) else: schema = handler(source) # TODO: this logic is used in types.py as well in the _check_annotated_type function, should we move that to somewhere more central? @@ -118,13 +280,153 @@ def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaH return schema -class _BaseMultiHostUrl(MultiHostUrl): +class _BaseMultiHostUrl: _constraints: ClassVar[UrlConstraints] = UrlConstraints() + _url: _CoreMultiHostUrl + + def __init__(self, url: str | _CoreMultiHostUrl | _BaseMultiHostUrl) -> None: + self._url = _build_type_adapter(self.__class__).validate_python(url) + + @property + def scheme(self) -> str: + """The scheme part of the URL. + + e.g. `https` in `https://foo.com,bar.com/path?query#fragment` + """ + return self._url.scheme + + @property + def path(self) -> str | None: + """The path part of the URL, or `None`. + + e.g. `/path` in `https://foo.com,bar.com/path?query#fragment` + """ + return self._url.path + + @property + def query(self) -> str | None: + """The query part of the URL, or `None`. + + e.g. `query` in `https://foo.com,bar.com/path?query#fragment` + """ + return self._url.query + + def query_params(self) -> list[tuple[str, str]]: + """The query part of the URL as a list of key-value pairs. + + e.g. `[('foo', 'bar')]` in `https://foo.com,bar.com/path?query#fragment` + """ + return self._url.query_params() + + @property + def fragment(self) -> str | None: + """The fragment part of the URL, or `None`. + + e.g. `fragment` in `https://foo.com,bar.com/path?query#fragment` + """ + return self._url.fragment + + def hosts(self) -> list[MultiHostHost]: + '''The hosts of the `MultiHostUrl` as [`MultiHostHost`][pydantic_core.MultiHostHost] typed dicts. + + ```py + from pydantic_core import MultiHostUrl + + mhu = MultiHostUrl('https://foo.com:123,foo:bar@bar.com/path') + print(mhu.hosts()) + """ + [ + {'username': None, 'password': None, 'host': 'foo.com', 'port': 123}, + {'username': 'foo', 'password': 'bar', 'host': 'bar.com', 'port': 443} + ] + ``` + Returns: + A list of dicts, each representing a host. + ''' + return self._url.hosts() + + def unicode_string(self) -> str: + """The URL as a unicode string, unlike `__str__()` this will not punycode encode the hosts.""" + return self._url.unicode_string() + + def __str__(self) -> str: + """The URL as a string, this will punycode encode the host if required.""" + return str(self._url) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({str(self._url)!r})' + + def __deepcopy__(self, memo: dict) -> Self: + return self.__class__(self._url) @classmethod - def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + def build( + cls, + *, + scheme: str, + hosts: list[MultiHostHost] | None = None, + username: str | None = None, + password: str | None = None, + host: str | None = None, + port: int | None = None, + path: str | None = None, + query: str | None = None, + fragment: str | None = None, + ) -> Self: + """Build a new `MultiHostUrl` instance from its component parts. + + This method takes either `hosts` - a list of `MultiHostHost` typed dicts, or the individual components + `username`, `password`, `host` and `port`. + + Args: + scheme: The scheme part of the URL. + hosts: Multiple hosts to build the URL from. + username: The username part of the URL. + password: The password part of the URL. + host: The host part of the URL. + port: The port part of the URL. + path: The path part of the URL. + query: The query part of the URL, or omit for no query. + fragment: The fragment part of the URL, or omit for no fragment. + + Returns: + An instance of `MultiHostUrl` + """ + return cls( + _CoreMultiHostUrl.build( + scheme=scheme, + hosts=hosts, + username=username, + password=password, + host=host, + port=port, + path=path, + query=query, + fragment=fragment, + ) + ) + + @classmethod + def __get_pydantic_core_schema__( + cls, source: type[_BaseMultiHostUrl], handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: if issubclass(cls, source): - return core_schema.multi_host_url_schema(**cls._constraints.defined_constraints) + + def wrap_val(value, handler): + if isinstance(value, source): + return value + if isinstance(value, _BaseMultiHostUrl): + value = str(value) + core_url = handler(value) + instance = source.__new__(source) + instance._url = core_url + return instance + + return core_schema.no_info_wrap_validator_function( + wrap_val, + schema=core_schema.multi_host_url_schema(**cls._constraints.defined_constraints), + serialization=core_schema.to_string_ser_schema(), + ) else: schema = handler(source) # TODO: this logic is used in types.py as well in the _check_annotated_type function, should we move that to somewhere more central? @@ -137,6 +439,11 @@ def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaH return schema +@lru_cache +def _build_type_adapter(cls: type[_BaseUrl | _BaseMultiHostUrl]) -> TypeAdapter: + return TypeAdapter(cls) + + class AnyUrl(_BaseUrl): """Base type for all URLs. @@ -162,7 +469,7 @@ class AnyUrl(_BaseUrl): @property def host(self) -> str: """The required URL host.""" - ... + return self._url.host # type: ignore class AnyHttpUrl(_BaseUrl): @@ -177,7 +484,7 @@ class AnyHttpUrl(_BaseUrl): @property def host(self) -> str: """The required URL host.""" - ... + return self._url.host # type: ignore class HttpUrl(_BaseUrl): @@ -263,7 +570,7 @@ class MyModel(BaseModel): @property def host(self) -> str: """The required URL host.""" - ... + return self._url.host # type: ignore class AnyWebsocketUrl(_BaseUrl): @@ -278,7 +585,7 @@ class AnyWebsocketUrl(_BaseUrl): @property def host(self) -> str: """The required URL host.""" - ... + return self._url.host # type: ignore class WebsocketUrl(_BaseUrl): @@ -294,7 +601,7 @@ class WebsocketUrl(_BaseUrl): @property def host(self) -> str: """The required URL host.""" - ... + return self._url.host # type: ignore class FileUrl(_BaseUrl): @@ -342,7 +649,7 @@ class MyModel(BaseModel): # the repr() method for a url will display all properties of the url print(repr(m.url)) - #> Url('http://www.example.com/') + #> HttpUrl('http://www.example.com/') print(m.url.scheme) #> http print(m.url.host) @@ -371,7 +678,7 @@ def check_db_name(cls, v): db Assertion failed, database must be provided assert (None) - + where None = MultiHostUrl('postgres://user:pass@localhost:5432').path [type=assertion_error, input_value='postgres://user:pass@localhost:5432', input_type=str] + + where None = PostgresDsn('postgres://user:pass@localhost:5432').path [type=assertion_error, input_value='postgres://user:pass@localhost:5432', input_type=str] ''' ``` """ @@ -394,7 +701,7 @@ def check_db_name(cls, v): @property def host(self) -> str: """The required URL host.""" - ... + return self._url.host # type: ignore class CockroachDsn(_BaseUrl): @@ -417,7 +724,7 @@ class CockroachDsn(_BaseUrl): @property def host(self) -> str: """The required URL host.""" - ... + return self._url.host # type: ignore class AmqpDsn(_BaseUrl): @@ -450,7 +757,7 @@ class RedisDsn(_BaseUrl): @property def host(self) -> str: """The required URL host.""" - ... + return self._url.host # type: ignore class MongoDsn(_BaseMultiHostUrl): @@ -518,7 +825,7 @@ class MySQLDsn(_BaseUrl): @property def host(self) -> str: """The required URL host.""" - ... + return self._url.host # type: ignore class MariaDBDsn(_BaseUrl): @@ -538,7 +845,7 @@ class MariaDBDsn(_BaseUrl): @property def host(self) -> str: """The required URL host.""" - ... + return self._url.host # type: ignore class ClickHouseDsn(_BaseUrl): @@ -559,7 +866,7 @@ class ClickHouseDsn(_BaseUrl): @property def host(self) -> str: """The required URL host.""" - ... + return self._url.host # type: ignore class SnowflakeDsn(_BaseUrl): @@ -578,7 +885,7 @@ class SnowflakeDsn(_BaseUrl): @property def host(self) -> str: """The required URL host.""" - ... + return self._url.host # type: ignore def import_email_validator() -> None: diff --git a/tests/test_json_schema.py b/tests/test_json_schema.py index fb7039e075..bad671d117 100644 --- a/tests/test_json_schema.py +++ b/tests/test_json_schema.py @@ -76,7 +76,15 @@ model_json_schema, models_json_schema, ) -from pydantic.networks import AnyUrl, EmailStr, IPvAnyAddress, IPvAnyInterface, IPvAnyNetwork, MultiHostUrl, NameEmail +from pydantic.networks import ( + AnyUrl, + EmailStr, + IPvAnyAddress, + IPvAnyInterface, + IPvAnyNetwork, + NameEmail, + _CoreMultiHostUrl, +) from pydantic.type_adapter import TypeAdapter from pydantic.types import ( UUID1, @@ -933,7 +941,7 @@ class Model(BaseModel): Annotated[AnyUrl, Field(max_length=2**16)], {'title': 'A', 'type': 'string', 'format': 'uri', 'minLength': 1, 'maxLength': 2**16}, ), - (MultiHostUrl, {'title': 'A', 'type': 'string', 'format': 'multi-host-uri', 'minLength': 1}), + (_CoreMultiHostUrl, {'title': 'A', 'type': 'string', 'format': 'multi-host-uri', 'minLength': 1}), ], ) def test_special_str_types(field_type, expected_schema): diff --git a/tests/test_networks.py b/tests/test_networks.py index 31f954ea3e..184f4031fb 100644 --- a/tests/test_networks.py +++ b/tests/test_networks.py @@ -24,6 +24,7 @@ RedisDsn, SnowflakeDsn, Strict, + TypeAdapter, UrlConstraints, ValidationError, WebsocketUrl, @@ -167,7 +168,7 @@ class Model(BaseModel): def test_any_url_parts(): url = validate_url('http://example.org') assert str(url) == 'http://example.org/' - assert repr(url) == "Url('http://example.org/')" + assert repr(url) == "AnyUrl('http://example.org/')" assert url.scheme == 'http' assert url.host == 'example.org' assert url.port == 80 @@ -176,7 +177,7 @@ def test_any_url_parts(): def test_url_repr(): url = validate_url('http://user:password@example.org:1234/the/path/?query=here#fragment=is;this=bit') assert str(url) == 'http://user:password@example.org:1234/the/path/?query=here#fragment=is;this=bit' - assert repr(url) == "Url('http://user:password@example.org:1234/the/path/?query=here#fragment=is;this=bit')" + assert repr(url) == "AnyUrl('http://user:password@example.org:1234/the/path/?query=here#fragment=is;this=bit')" assert url.scheme == 'http' assert url.username == 'user' assert url.password == 'password' @@ -1030,3 +1031,25 @@ class Model(BaseModel): obj = json.loads(m.model_dump_json()) Model(email=obj['email']) + + +def test_specialized_urls() -> None: + ta = TypeAdapter(HttpUrl) + + http_url = ta.validate_python('http://example.com/something') + assert str(http_url) == 'http://example.com/something' + assert repr(http_url) == "HttpUrl('http://example.com/something')" + assert http_url.__class__ == HttpUrl + assert http_url.host == 'example.com' + assert http_url.path == '/something' + assert http_url.username is None + assert http_url.password is None + + http_url2 = ta.validate_python(http_url) + assert str(http_url2) == 'http://example.com/something' + assert repr(http_url2) == "HttpUrl('http://example.com/something')" + assert http_url2.__class__ == HttpUrl + assert http_url2.host == 'example.com' + assert http_url2.path == '/something' + assert http_url2.username is None + assert http_url2.password is None