From 2891aba6c2cd09a43969e20a94b74d2299fcd37c Mon Sep 17 00:00:00 2001 From: Max Muoto Date: Fri, 6 Sep 2024 20:53:31 -0500 Subject: [PATCH] Add overloads on `returning` (#40) * add overloads on returning * Remove unused import --- CHANGELOG.md | 6 ++ pgbulk/core.py | 152 ++++++++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 2 +- 3 files changed, 158 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ea38b4b..dcf60cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## 3.0.1 + +#### Trivial + + - Add overloads on `upsert`, `aupsert`, `update`, and `aupdate` to improve type-checking on `returning=...` by [@max-muoto](https://github.com/max-muoto) in [#40](https://github.com/Opus10/django-pgbulk/pull/40/). + ## 3.0.0 #### Breaking Changes diff --git a/pgbulk/core.py b/pgbulk/core.py index 038e644..1cd8be0 100644 --- a/pgbulk/core.py +++ b/pgbulk/core.py @@ -11,6 +11,7 @@ TypeVar, Union, cast, + overload, ) from asgiref.sync import sync_to_async @@ -63,7 +64,6 @@ def _psycopg_version() -> Tuple[int, int, int]: if psycopg_maj_version == 2: - from psycopg2.extensions import AsIs as Literal # type: ignore from psycopg2.extensions import quote_ident # type: ignore elif psycopg_maj_version == 3: import psycopg.adapt # type: ignore @@ -598,6 +598,42 @@ def _update( return updated if returning else None +@overload +def update( + queryset: QuerySet[_M], + model_objs: Iterable[_M], + update_fields: Union[List[str], None] = None, + *, + exclude: Union[List[str], None] = None, + returning: Union[List[str], Literal[True]] = ..., + ignore_unchanged: bool = False, +) -> List["Row"]: ... + + +@overload +def update( + queryset: QuerySet[_M], + model_objs: Iterable[_M], + update_fields: Union[List[str], None] = None, + *, + exclude: Union[List[str], None] = None, + returning: Literal[False] = False, + ignore_unchanged: bool = False, +) -> None: ... + + +@overload +def update( + queryset: QuerySet[_M], + model_objs: Iterable[_M], + update_fields: Union[List[str], None] = None, + *, + exclude: Union[List[str], None] = None, + returning: Union[List[str], bool] = False, + ignore_unchanged: bool = False, +) -> Union[List["Row"], None]: ... + + def update( queryset: QuerySet[_M], model_objs: Iterable[_M], @@ -641,6 +677,42 @@ def update( ) +@overload +async def aupdate( + queryset: QuerySet[_M], + model_objs: Iterable[_M], + update_fields: Union[List[str], None] = None, + *, + exclude: Union[List[str], None] = None, + returning: Union[List[str], Literal[True]] = ..., + ignore_unchanged: bool = False, +) -> List["Row"]: ... + + +@overload +async def aupdate( + queryset: QuerySet[_M], + model_objs: Iterable[_M], + update_fields: Union[List[str], None] = None, + *, + exclude: Union[List[str], None] = None, + returning: Literal[False] = False, + ignore_unchanged: bool = False, +) -> None: ... + + +@overload +async def aupdate( + queryset: QuerySet[_M], + model_objs: Iterable[_M], + update_fields: Union[List[str], None] = None, + *, + exclude: Union[List[str], None] = None, + returning: Union[List[str], bool] = False, + ignore_unchanged: bool = False, +) -> Union[List["Row"], None]: ... + + async def aupdate( queryset: QuerySet[_M], model_objs: Iterable[_M], @@ -670,6 +742,45 @@ async def aupdate( ) +@overload +def upsert( + queryset: QuerySet[_M], + model_objs: Iterable[_M], + unique_fields: List[str], + update_fields: UpdateFieldsTypeDef = None, + *, + exclude: Union[List[str], None] = None, + returning: Union[List[str], Literal[True]] = ..., + ignore_unchanged: bool = False, +) -> UpsertResult: ... + + +@overload +def upsert( + queryset: QuerySet[_M], + model_objs: Iterable[_M], + unique_fields: List[str], + update_fields: UpdateFieldsTypeDef = None, + *, + exclude: Union[List[str], None] = None, + returning: Literal[False] = False, + ignore_unchanged: bool = False, +) -> None: ... + + +@overload +def upsert( + queryset: QuerySet[_M], + model_objs: Iterable[_M], + unique_fields: List[str], + update_fields: UpdateFieldsTypeDef = None, + *, + exclude: Union[List[str], None] = None, + returning: Union[List[str], bool] = False, + ignore_unchanged: bool = False, +) -> Union[UpsertResult, None]: ... + + def upsert( queryset: QuerySet[_M], model_objs: Iterable[_M], @@ -724,6 +835,45 @@ def upsert( ) +@overload +async def aupsert( + queryset: QuerySet[_M], + model_objs: Iterable[_M], + unique_fields: List[str], + update_fields: UpdateFieldsTypeDef = None, + *, + exclude: Union[List[str], None] = None, + returning: Union[List[str], Literal[True]] = ..., + ignore_unchanged: bool = False, +) -> UpsertResult: ... + + +@overload +async def aupsert( + queryset: QuerySet[_M], + model_objs: Iterable[_M], + unique_fields: List[str], + update_fields: UpdateFieldsTypeDef = None, + *, + exclude: Union[List[str], None] = None, + returning: Literal[False] = False, + ignore_unchanged: bool = False, +) -> None: ... + + +@overload +async def aupsert( + queryset: QuerySet[_M], + model_objs: Iterable[_M], + unique_fields: List[str], + update_fields: UpdateFieldsTypeDef = None, + *, + exclude: Union[List[str], None] = None, + returning: Union[List[str], bool] = False, + ignore_unchanged: bool = False, +) -> Union[UpsertResult, None]: ... + + async def aupsert( queryset: QuerySet[_M], model_objs: Iterable[_M], diff --git a/pyproject.toml b/pyproject.toml index 165cd29..7e9bdd8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ packages = [ exclude = [ "*/tests/" ] -version = "3.0.0" +version = "3.0.1" description = "Native Postgres update, upsert, and copy operations." authors = ["Wes Kendall"] classifiers = [