Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions mongoengine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ class BaseField(object):
# Fields may have _types inserted into indexes by default
_index_with_types = True
_geo_index = False
_parent = None

# These track each time a Field instance is created. Used to retain order.
# The auto_creation_counter is used for fields that MongoEngine implicitly
Expand Down Expand Up @@ -199,6 +200,10 @@ def __get__(self, instance, owner):
if callable(value):
value = value()

if isinstance(value, BaseDocument):
# if the field is an EmbeddedDocument, set the parent document
value._parent = instance

return value

def __set__(self, instance, value):
Expand All @@ -208,6 +213,10 @@ def __set__(self, instance, value):
if instance._initialised:
instance._mark_as_changed(self.name)

if isinstance(value, BaseDocument):
# if the field is an EmbeddedDocument, set the parent document
value._parent = instance

def error(self, message="", errors=None, field_name=None):
"""Raises a ValidationError.
"""
Expand Down
28 changes: 28 additions & 0 deletions tests/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -2680,6 +2680,34 @@ class Employee(self.Person):
promoted_employee.reload()
self.assertEqual(promoted_employee.details, None)

def test_embedded_document_parents(self):
"""Ensure that embedded documents provide a _parent property
"""
class Position(EmbeddedDocument):
name = StringField()

class Employee(self.Person):
salary = IntField()
position = EmbeddedDocumentField(Position)
previous_positions = ListField(EmbeddedDocumentField(Position))

# assert _position is populated on __set__
employee = Employee(name='Test Employee', salary=20000)
position = Position(name='Developer')
employee.position = position
self.assertEqual(position._parent, employee)

# assert _position is populated on __get__
employee.save()
employee.reload()

self.assertEqual(employee.position._parent, employee)

# TODO get this working for list fields
employee.previous_positions = [Position(name='Analyst'),
Position(name='Intern')]
#self.assertEqual(employee.previous_positions[0]._parent, employee)

def test_mixins_dont_add_to_types(self):

class Mixin(object):
Expand Down