diff --git a/src/brackets/mixins/form_views.py b/src/brackets/mixins/form_views.py index 03cc67a..fe887d4 100644 --- a/src/brackets/mixins/form_views.py +++ b/src/brackets/mixins/form_views.py @@ -69,13 +69,12 @@ class MultipleFormsMixin(FormMixin): form_initial_values: Optional[Mapping[str, Mapping[str, Any]]] = None form_instances: Optional[Mapping[str, models.Model]] = None - def get_context_data(self, **kwargs: Mapping[str, Any]) -> dict[str, Any]: + def get_context_data(self, **kwargs: dict[str, Any]) -> dict[str, Any]: """Add the forms to the view context.""" kwargs.setdefault("view", self) if self.extra_context is not None: kwargs.update(self.extra_context) - if "forms" not in kwargs: - kwargs["forms"] = self.get_forms() + kwargs["forms"] = self.get_forms() return kwargs def get_form_classes(self) -> Mapping[str, type[forms.BaseForm]]: @@ -97,10 +96,10 @@ def get_form_classes(self) -> Mapping[str, type[forms.BaseForm]]: def get_forms(self) -> dict[str, forms.BaseForm]: """Instantiate the forms with their kwargs.""" - _forms: dict[str, forms.BaseForm] = {} + forms: dict[str, forms.BaseForm] = {} for name, form_class in self.get_form_classes().items(): - _forms[name] = form_class(**self.get_form_kwargs(name)) - return _forms + forms[name] = form_class(**self.get_form_kwargs(name)) + return forms def get_instance(self, name: str) -> models.Model: """Connect instances to forms.""" diff --git a/tests/mixins/test_form_views.py b/tests/mixins/test_form_views.py index 9507039..9ba1a7a 100644 --- a/tests/mixins/test_form_views.py +++ b/tests/mixins/test_form_views.py @@ -72,6 +72,16 @@ def test_csrf_exempt(self, form, view, rf): class TestMultipleFormsMixin: """Tests related to the MultipleFormsMixin.""" + def test_extra_context(self, form_view, form_class, rf): + """A view can take extra context.""" + request = rf.get("/") + view_class = form_view()( + extra_context={"foo": "bar"}, + form_classes={"one": form_class()}, + request=request, + ) + assert view_class.get_context_data()["foo"] == "bar" + def test_missing_form_classes(self, form_view): """A view with no instances or initials should fail.""" view_class = form_view()