Skip to content
This repository was archived by the owner on Feb 20, 2019. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion elasticutils/contrib/django/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ class MappingType(BaseMappingType):
`get_model()`.

"""

id_field = 'pk'

def get_object(self):
"""Returns the database object for this result

Expand All @@ -205,7 +208,8 @@ def get_object(self):
self.get_model().objects.get(pk=self._id)

"""
return self.get_model().objects.get(pk=self._id)
kwargs = {self.id_field: self._id}
return self.get_model().objects.get(**kwargs)

@classmethod
def get_model(cls):
Expand Down
14 changes: 9 additions & 5 deletions elasticutils/contrib/django/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@


@task
def index_objects(mapping_type, ids, chunk_size=100, es=None, index=None):
def index_objects(mapping_type, ids, chunk_size=100, id_field='id', es=None,
index=None):
"""Index documents of a specified mapping type.

This allows for asynchronous indexing.
Expand Down Expand Up @@ -48,21 +49,24 @@ def update_in_index(sender, instance, **kw):

# Get the model this mapping type is based on.
model = mapping_type.get_model()
filter_key = '{0}__in'.format(id_field)

# Retrieve all the objects that we're going to index and do it in
# bulk.
for id_list in chunked(ids, chunk_size):
documents = []

for obj in model.objects.filter(id__in=id_list):
for obj in model.objects.filter(**{filter_key: id_list}):
try:
documents.append(mapping_type.extract_document(obj.id, obj))
except Exception as exc:
_id = str(getattr(obj, id_field))
documents.append(mapping_type.extract_document(_id, obj))
except StandardError as exc:
log.exception('Unable to extract document {0}: {1}'.format(
obj, repr(exc)))

if documents:
mapping_type.bulk_index(documents, id_field='id', es=es, index=index)
mapping_type.bulk_index(documents, id_field=id_field, es=es,
index=index)


@task
Expand Down
38 changes: 28 additions & 10 deletions elasticutils/contrib/django/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# We need to put these in a separate module so they're easy to import
# on a test-by-test basis so that we can skip django-requiring tests
# if django isn't installed.
from uuid import UUID


from elasticutils.contrib.django import MappingType, Indexable
Expand All @@ -22,14 +23,20 @@ class SearchQuerySet(object):
# Yes. This is kind of crazy, but ... whatever.
def __init__(self, model):
self.model = model
self.id_field = model.id_field
self.steps = []

def get(self, pk):
pk = int(pk)
return [m for m in _model_cache if m.id == pk][0]

def filter(self, id__in=None):
self.steps.append(('filter', id__in))
def get(self, pk=None, uuid=None):
if pk:
pk = int(pk)
return [m for m in _model_cache if m.id == pk][0]
if uuid:
uuid = UUID(uuid)
return [m for m in _model_cache if m.uuid == uuid][0]
return []

def filter(self, id__in=None, uuid__in=None):
self.steps.append(('filter', id__in or uuid__in))
return self

def order_by(self, *fields):
Expand All @@ -47,7 +54,8 @@ def __iter__(self):

for mem in self.steps:
if mem[0] == 'filter':
objs = [obj for obj in objs if obj.id in mem[1]]
objs = [obj for obj in objs
if getattr(obj, self.id_field) in mem[1]]
elif mem[0] == 'order_by':
order_by_field = mem[1][0]
elif mem[0] == 'values_list':
Expand All @@ -58,16 +66,16 @@ def __iter__(self):

if values_list:
# Note: Hard-coded to just id and flat
objs = [obj.id for obj in objs]
objs = [getattr(obj, self.id_field) for obj in objs]
return iter(objs)


class Manager(object):
def get_query_set(self):
return SearchQuerySet(self)

def get(self, pk):
return self.get_query_set().get(pk)
def get(self, pk=None, uuid=None):
return self.get_query_set().get(pk=pk, uuid=uuid)

def filter(self, *args, **kwargs):
return self.get_query_set().filter(*args, **kwargs)
Expand All @@ -84,6 +92,7 @@ class FakeModel(object):
objects = Manager()

def __init__(self, **kw):
self.objects.id_field = kw.pop('id_field', 'id')
self._doc = kw
for key in kw:
setattr(self, key, kw[key])
Expand All @@ -102,3 +111,12 @@ def extract_document(cls, obj_id, obj=None):
'what to do with these args.')

return obj._doc

class FakeDjangoWithUuidMappingType(FakeDjangoMappingType):
    """Fake mapping type that uses ``uuid`` as its id field.

    Identical to ``FakeDjangoMappingType`` except that ``id_field`` is
    set to ``'uuid'`` and extracted documents have all values converted
    to strings.
    """

    id_field = 'uuid'

    @classmethod
    def extract_document(cls, obj_id, obj=None):
        """Return the parent's document with every value cast to ``str``.

        :arg obj_id: passed through unchanged to the parent implementation
        :arg obj: passed through unchanged to the parent implementation

        :returns: a new dict with the same keys as the parent's document
            and ``str(value)`` for each value (so a ``UUID`` object
            becomes its string form)
        """
        doc = super(FakeDjangoWithUuidMappingType, cls).extract_document(
            obj_id, obj=obj)
        # ``items()`` is available on dicts in both Python 2 and Python 3,
        # whereas ``iteritems()`` only exists on Python 2.
        return {key: str(value) for key, value in doc.items()}
23 changes: 19 additions & 4 deletions elasticutils/contrib/django/tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import uuid

from nose.tools import eq_

from elasticutils.contrib.django import S, get_es
from elasticutils.contrib.django.tests import (
FakeDjangoMappingType, FakeModel, reset_model_cache)
FakeDjangoMappingType, FakeDjangoWithUuidMappingType, FakeModel,
reset_model_cache)
from elasticutils.contrib.django.estestcase import ESTestCase


Expand All @@ -20,12 +23,13 @@ def tearDown(self):
IndexableTest.cleanup_index(FakeDjangoMappingType.get_index())
reset_model_cache()

def persist_data(self, data):
def persist_data(self, data, id_field='id'):
for doc in data:
FakeModel(**doc)
FakeModel(id_field=id_field, **doc)

# Index the document with .index()
FakeDjangoMappingType.index(doc, id_=doc['id'])
FakeDjangoMappingType.index({k:str(v) for k,v in doc.iteritems()},
id_=str(doc[id_field]))

self.refresh(FakeDjangoMappingType.get_index())

Expand All @@ -51,6 +55,17 @@ def test_get_object(self):
obj = s[0]
eq_(obj.object.id, 1)

def test_get_object_with_custom_pk(self):
    # Documents are persisted keyed on their ``uuid`` value; the object
    # attached to the first search hit must carry the same uuid.
    odin = {'uuid': uuid.uuid4(), 'name': 'odin skullcrusher'}
    olaf = {'uuid': uuid.uuid4(), 'name': 'olaf bloodbiter'}
    self.persist_data([odin, olaf], id_field='uuid')

    results = S(FakeDjangoWithUuidMappingType).query(name__prefix='odin')
    first_hit = results[0]
    eq_(first_hit.object.uuid, odin['uuid'])

def test_get_indexable(self):
self.persist_data([
{'id': 1, 'name': 'odin skullcrusher'},
Expand Down
38 changes: 36 additions & 2 deletions elasticutils/contrib/django/tests/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import uuid

from nose.tools import eq_

from elasticutils.contrib.django import get_es
from elasticutils.contrib.django.tasks import index_objects, unindex_objects
from elasticutils.contrib.django.tests import (
FakeDjangoMappingType, FakeModel, reset_model_cache)
FakeDjangoMappingType, FakeDjangoWithUuidMappingType, FakeModel,
reset_model_cache)
from elasticutils.contrib.django.estestcase import ESTestCase


Expand Down Expand Up @@ -76,7 +79,7 @@ def bulk_index(cls, *args, **kwargs):
index_objects(MockMappingType, [1, 2, 3], chunk_size=1)
eq_(MockMappingType.bulk_index_count, 3)

# test index and es kwargs
# test index and es kwargs
MockMappingType.index_kwarg = None
MockMappingType.es_kwarg = None
index_objects(MockMappingType, [1, 2, 3])
Expand All @@ -86,3 +89,34 @@ def bulk_index(cls, *args, **kwargs):
index_objects(MockMappingType, [1, 2, 3], es='crazy_es', index='crazy_index')
eq_(MockMappingType.index_kwarg, 'crazy_index')
eq_(MockMappingType.es_kwarg, 'crazy_es')

def test_tasks_with_custom_id_field(self):
    # Runs index_objects / unindex_objects against a mapping type whose
    # id field is ``uuid`` and checks the document count after each step.
    names = ('odin skullcrusher', 'heimdall kneebiter', 'erik rose')
    records = [{'uuid': uuid.uuid4(), 'name': name} for name in names]

    # Create the fake models and collect their uuids in one pass.
    uuids = []
    for record in records:
        FakeModel(id_field='uuid', **record)
        uuids.append(record['uuid'])

    mapping_type = FakeDjangoWithUuidMappingType

    # Step 1: call the task without ``id_field``.
    index_objects(mapping_type, uuids)
    mapping_type.refresh_index()
    # The original author was unsure whether a pause is needed here; it
    # is kept unchanged.
    from time import sleep
    sleep(1)
    # Nothing was indexed: the task caught a StandardError without
    # re-raising it. Failing explicitly might be better.
    eq_(mapping_type.search().count(), 0)

    # Step 2: with the matching ``id_field`` every document is indexed.
    index_objects(mapping_type, uuids, id_field='uuid')
    mapping_type.refresh_index()
    eq_(mapping_type.search().count(), 3)

    # Step 3: unindexing the same ids empties the index again.
    unindex_objects(mapping_type, uuids)
    mapping_type.refresh_index()
    eq_(mapping_type.search().count(), 0)