Skip to content

Commit

Permalink
Step 8
Browse files Browse the repository at this point in the history
  • Loading branch information
pfouque committed Jul 5, 2024
1 parent 802178b commit d8836db
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 21 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ repos:
- id: ruff

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.7.1
rev: v1.10.0
hooks:
- id: mypy
additional_dependencies:
- django-stubs==4.2.6
- django-stubs==5.0.0
- django-guardian
36 changes: 18 additions & 18 deletions django_fsm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,6 @@
from __future__ import annotations

import inspect
from collections.abc import Callable
from collections.abc import Collection
from collections.abc import Generator
from collections.abc import Iterable
from collections.abc import Sequence
from functools import partialmethod
from functools import wraps
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -41,6 +36,11 @@
]

if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Collection
from collections.abc import Generator
from collections.abc import Iterable
from collections.abc import Sequence
from typing import Self

from _typeshed import Incomplete
Expand Down Expand Up @@ -121,10 +121,10 @@ def has_perm(self, instance: _Instance, user: UserWithPermissions) -> bool:
return True
return False

def __hash__(self):
def __hash__(self) -> int:
return hash(self.name)

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if isinstance(other, str):
return other == self.name
if isinstance(other, Transition):
Expand Down Expand Up @@ -240,8 +240,8 @@ def has_transition_perm(self, instance: _Instance, state: str, user: UserWithPer

if not transition:
return False
else:
return bool(transition.has_perm(instance, user))

return bool(transition.has_perm(instance, user))

def next_state(self, current_state: str) -> _StateValue:
transition = self.get_transition(current_state)
Expand Down Expand Up @@ -309,7 +309,7 @@ def deconstruct(self) -> Any:
def get_state(self, instance: _Instance) -> Any:
# The state field may be deferred. We delegate the logic of figuring this out
# and loading the deferred field on-demand to Django's built-in DeferredAttribute class.
return DeferredAttribute(self).__get__(instance) # type: ignore[attr-defined]
return DeferredAttribute(self).__get__(instance)

def set_state(self, instance: _Instance, state: str) -> None:
instance.__dict__[self.name] = state
Expand Down Expand Up @@ -479,14 +479,14 @@ class FSMModelMixin(_FSMModel):
Mixin that allows refresh_from_db for models with fsm protected fields
"""

def _get_protected_fsm_fields(self):
def is_fsm_and_protected(f):
def _get_protected_fsm_fields(self) -> set[str]:
def is_fsm_and_protected(f: object) -> Any:
return isinstance(f, FSMFieldMixin) and f.protected

protected_fields = filter(is_fsm_and_protected, self._meta.concrete_fields)
protected_fields: Iterable[Any] = filter(is_fsm_and_protected, self._meta.concrete_fields) # type: ignore[attr-defined, arg-type]
return {f.attname for f in protected_fields}

def refresh_from_db(self, *args, **kwargs):
def refresh_from_db(self, *args: Any, **kwargs: Any) -> None:
fields = kwargs.pop("fields", None)

# Use provided fields, if not set then reload all non-deferred fields.0
Expand All @@ -495,7 +495,7 @@ def refresh_from_db(self, *args, **kwargs):
protected_fields = self._get_protected_fsm_fields()
skipped_fields = deferred_fields.union(protected_fields)

fields = [f.attname for f in self._meta.concrete_fields if f.attname not in skipped_fields]
fields = [f.attname for f in self._meta.concrete_fields if f.attname not in skipped_fields] # type: ignore[attr-defined]

kwargs["fields"] = fields
super().refresh_from_db(*args, **kwargs)
Expand Down Expand Up @@ -538,9 +538,9 @@ def state_fields(self) -> Iterable[Any]:
def _do_update(
self,
base_qs: QuerySet[Self],
using: Any,
using: str | None,
pk_val: Any,
values: Collection[Any] | None,
values: Collection[tuple[_Field, type[models.Model] | None, Any]],
update_fields: Iterable[str] | None,
forced_update: bool,
) -> bool:
Expand All @@ -553,7 +553,7 @@ def _do_update(
# state filter will be used to narrow down the standard filter checking only PK
state_filter = {field.attname: self.__initial_states[field.attname] for field in filter_on}

updated: bool = super()._do_update( # type: ignore[misc]
updated: bool = super()._do_update(
base_qs=base_qs.filter(**state_filter),
using=using,
pk_val=pk_val,
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ extend-select = [
"RET",
"C",
# "B",
"TCH", # trailing comma
]
fixable = ["I"]
fixable = ["I", "TCH"]


[tool.ruff.lint.isort]
Expand Down

0 comments on commit d8836db

Please sign in to comment.