Skip to content

Commit

Permalink
Validate embedded model fields.
Browse files Browse the repository at this point in the history
  • Loading branch information
WaVEV committed Dec 30, 2024
1 parent 6f43abc commit bf71f56
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
16 changes: 14 additions & 2 deletions django_mongodb/fields/embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ def __init__(self, embedded_model, *args, **kwargs):
stored. Like other relational fields, it may also be passed as a
string.
"""
if not isinstance(embedded_model, str):
self._validate_embedded_field(self, embedded_model)

self.embedded_model = embedded_model
super().__init__(*args, **kwargs)

Expand All @@ -28,6 +31,14 @@ def deconstruct(self):
def get_internal_type(self):
return "EmbeddedModelField"

@staticmethod
def _validate_embedded_field(_, model):
for field in model._meta.local_fields:
if isinstance(field, models.ForeignKey | models.OneToOneField):
raise TypeError(
f"Field of type {type(field)!r} is not supported within an EmbeddedModelField."
)

def _set_model(self, model):
"""
Resolve embedded model class once the field knows the model it belongs
Expand All @@ -48,6 +59,7 @@ def _resolve_lookup(_, resolved_model):
self.embedded_model = resolved_model

lazy_related_operation(_resolve_lookup, model, self.embedded_model)
lazy_related_operation(self._validate_embedded_field, model, self.embedded_model)

model = property(lambda self: self._model, _set_model)

Expand Down Expand Up @@ -139,7 +151,7 @@ def formfield(self, **kwargs):


class KeyTransform(Transform):
def __init__(self, key_name, ref_field=None, *args, **kwargs):
def __init__(self, key_name, ref_field, *args, **kwargs):
super().__init__(*args, **kwargs)
self.key_name = str(key_name)
self.ref_field = ref_field
Expand Down Expand Up @@ -195,7 +207,7 @@ def key_transform(self, compiler, connection):


class KeyTransformFactory:
def __init__(self, key_name, ref_field=None):
def __init__(self, key_name, ref_field):
self.key_name = key_name
self.ref_field = ref_field

Expand Down
12 changes: 12 additions & 0 deletions tests/model_fields_/test_embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
F,
IntegerField,
Max,
Model,
OuterRef,
Subquery,
Sum,
Expand All @@ -21,6 +22,7 @@
Author,
Book,
DecimalKey,
DecimalParent,
EmbeddedModel,
EmbeddedModelFieldModel,
Library,
Expand Down Expand Up @@ -97,6 +99,16 @@ def test_embedded_field_with_foreign_conversion(self):
# decimal_parent = DecimalParent.objects.create(child=decimal)
EmbeddedModelFieldModel.objects.create(decimal_parent=decimal)

def test_foreign_key_in_embedded_object(self):
msg = (
"Field of type <class 'django.db.models.fields.related.ForeignKey'> "
"is not supported within an EmbeddedModelField."
)
with self.assertRaisesMessage(TypeError, msg):

class EmbeddedModelTest(Model):
decimal = EmbeddedModelField(DecimalParent, null=True, blank=True)


class QueryingTests(TestCase):
@classmethod
Expand Down

0 comments on commit bf71f56

Please sign in to comment.