Skip to content

Commit

Permalink
Add overloads on returning (#40)
Browse files Browse the repository at this point in the history
* add overloads on returning

* Remove unused import
  • Loading branch information
max-muoto authored Sep 7, 2024
1 parent 798415a commit 2891aba
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 2 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## 3.0.1

#### Trivial

- Add overloads on `upsert`, `aupsert`, `update`, and `aupdate` to improve type-checking on `returning=...` by [@max-muoto](https://github.com/max-muoto) in [#40](https://github.com/Opus10/django-pgbulk/pull/40/).

## 3.0.0

#### Breaking Changes
Expand Down
152 changes: 151 additions & 1 deletion pgbulk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
TypeVar,
Union,
cast,
overload,
)

from asgiref.sync import sync_to_async
Expand Down Expand Up @@ -63,7 +64,6 @@ def _psycopg_version() -> Tuple[int, int, int]:


if psycopg_maj_version == 2:
from psycopg2.extensions import AsIs as Literal # type: ignore
from psycopg2.extensions import quote_ident # type: ignore
elif psycopg_maj_version == 3:
import psycopg.adapt # type: ignore
Expand Down Expand Up @@ -598,6 +598,42 @@ def _update(
return updated if returning else None


@overload
def update(
queryset: QuerySet[_M],
model_objs: Iterable[_M],
update_fields: Union[List[str], None] = None,
*,
exclude: Union[List[str], None] = None,
returning: Union[List[str], Literal[True]] = ...,
ignore_unchanged: bool = False,
) -> List["Row"]: ...


@overload
def update(
queryset: QuerySet[_M],
model_objs: Iterable[_M],
update_fields: Union[List[str], None] = None,
*,
exclude: Union[List[str], None] = None,
returning: Literal[False] = False,
ignore_unchanged: bool = False,
) -> None: ...


@overload
def update(
queryset: QuerySet[_M],
model_objs: Iterable[_M],
update_fields: Union[List[str], None] = None,
*,
exclude: Union[List[str], None] = None,
returning: Union[List[str], bool] = False,
ignore_unchanged: bool = False,
) -> Union[List["Row"], None]: ...


def update(
queryset: QuerySet[_M],
model_objs: Iterable[_M],
Expand Down Expand Up @@ -641,6 +677,42 @@ def update(
)


@overload
async def aupdate(
queryset: QuerySet[_M],
model_objs: Iterable[_M],
update_fields: Union[List[str], None] = None,
*,
exclude: Union[List[str], None] = None,
returning: Union[List[str], Literal[True]] = ...,
ignore_unchanged: bool = False,
) -> List["Row"]: ...


@overload
async def aupdate(
queryset: QuerySet[_M],
model_objs: Iterable[_M],
update_fields: Union[List[str], None] = None,
*,
exclude: Union[List[str], None] = None,
returning: Literal[False] = False,
ignore_unchanged: bool = False,
) -> None: ...


@overload
async def aupdate(
queryset: QuerySet[_M],
model_objs: Iterable[_M],
update_fields: Union[List[str], None] = None,
*,
exclude: Union[List[str], None] = None,
returning: Union[List[str], bool] = False,
ignore_unchanged: bool = False,
) -> Union[List["Row"], None]: ...


async def aupdate(
queryset: QuerySet[_M],
model_objs: Iterable[_M],
Expand Down Expand Up @@ -670,6 +742,45 @@ async def aupdate(
)


@overload
def upsert(
queryset: QuerySet[_M],
model_objs: Iterable[_M],
unique_fields: List[str],
update_fields: UpdateFieldsTypeDef = None,
*,
exclude: Union[List[str], None] = None,
returning: Union[List[str], Literal[True]] = ...,
ignore_unchanged: bool = False,
) -> UpsertResult: ...


@overload
def upsert(
queryset: QuerySet[_M],
model_objs: Iterable[_M],
unique_fields: List[str],
update_fields: UpdateFieldsTypeDef = None,
*,
exclude: Union[List[str], None] = None,
returning: Literal[False] = False,
ignore_unchanged: bool = False,
) -> None: ...


@overload
def upsert(
queryset: QuerySet[_M],
model_objs: Iterable[_M],
unique_fields: List[str],
update_fields: UpdateFieldsTypeDef = None,
*,
exclude: Union[List[str], None] = None,
returning: Union[List[str], bool] = False,
ignore_unchanged: bool = False,
) -> Union[UpsertResult, None]: ...


def upsert(
queryset: QuerySet[_M],
model_objs: Iterable[_M],
Expand Down Expand Up @@ -724,6 +835,45 @@ def upsert(
)


@overload
async def aupsert(
queryset: QuerySet[_M],
model_objs: Iterable[_M],
unique_fields: List[str],
update_fields: UpdateFieldsTypeDef = None,
*,
exclude: Union[List[str], None] = None,
returning: Union[List[str], Literal[True]] = ...,
ignore_unchanged: bool = False,
) -> UpsertResult: ...


@overload
async def aupsert(
queryset: QuerySet[_M],
model_objs: Iterable[_M],
unique_fields: List[str],
update_fields: UpdateFieldsTypeDef = None,
*,
exclude: Union[List[str], None] = None,
returning: Literal[False] = False,
ignore_unchanged: bool = False,
) -> None: ...


@overload
async def aupsert(
queryset: QuerySet[_M],
model_objs: Iterable[_M],
unique_fields: List[str],
update_fields: UpdateFieldsTypeDef = None,
*,
exclude: Union[List[str], None] = None,
returning: Union[List[str], bool] = False,
ignore_unchanged: bool = False,
) -> Union[UpsertResult, None]: ...


async def aupsert(
queryset: QuerySet[_M],
model_objs: Iterable[_M],
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ packages = [
exclude = [
"*/tests/"
]
version = "3.0.0"
version = "3.0.1"
description = "Native Postgres update, upsert, and copy operations."
authors = ["Wes Kendall"]
classifiers = [
Expand Down

0 comments on commit 2891aba

Please sign in to comment.