diff --git a/elasticutils/contrib/django/__init__.py b/elasticutils/contrib/django/__init__.py index 1b8892c..6a9075f 100644 --- a/elasticutils/contrib/django/__init__.py +++ b/elasticutils/contrib/django/__init__.py @@ -197,6 +197,9 @@ class MappingType(BaseMappingType): `get_model()`. """ + + id_field = 'pk' + def get_object(self): """Returns the database object for this result @@ -205,7 +208,8 @@ def get_object(self): self.get_model().objects.get(pk=self._id) """ - return self.get_model().objects.get(pk=self._id) + kwargs = {self.id_field: self._id} + return self.get_model().objects.get(**kwargs) @classmethod def get_model(cls): diff --git a/elasticutils/contrib/django/tasks.py b/elasticutils/contrib/django/tasks.py index 7c65458..d19d9c9 100644 --- a/elasticutils/contrib/django/tasks.py +++ b/elasticutils/contrib/django/tasks.py @@ -10,7 +10,8 @@ @task -def index_objects(mapping_type, ids, chunk_size=100, es=None, index=None): +def index_objects(mapping_type, ids, chunk_size=100, id_field='id', es=None, + index=None): """Index documents of a specified mapping type. This allows for asynchronous indexing. @@ -48,21 +49,24 @@ def update_in_index(sender, instance, **kw): # Get the model this mapping type is based on. model = mapping_type.get_model() + filter_key = '{0}__in'.format(id_field) # Retrieve all the objects that we're going to index and do it in # bulk. for id_list in chunked(ids, chunk_size): documents = [] - for obj in model.objects.filter(id__in=id_list): + for obj in model.objects.filter(**{filter_key: id_list}): try: - documents.append(mapping_type.extract_document(obj.id, obj)) - except Exception as exc: + _id = str(getattr(obj, id_field)) + documents.append(mapping_type.extract_document(_id, obj)) + except StandardError as exc: log.exception('Unable to extract document {0}: {1}'.format( obj, repr(exc))) if documents: - mapping_type.bulk_index(documents, id_field='id', es=es, index=index) + mapping_type.bulk_index(documents, id_field=id_field, es=es, + index=index) @task diff --git a/elasticutils/contrib/django/tests/__init__.py b/elasticutils/contrib/django/tests/__init__.py index 7941d60..d181506 100644 --- a/elasticutils/contrib/django/tests/__init__.py +++ b/elasticutils/contrib/django/tests/__init__.py @@ -1,6 +1,7 @@ # We need to put these in a separate module so they're easy to import # on a test-by-test basis so that we can skip django-requiring tests # if django isn't installed. +from uuid import UUID from elasticutils.contrib.django import MappingType, Indexable @@ -22,14 +23,20 @@ class SearchQuerySet(object): # Yes. This is kind of crazy, but ... whatever. def __init__(self, model): self.model = model + self.id_field = model.id_field self.steps = [] - def get(self, pk): - pk = int(pk) - return [m for m in _model_cache if m.id == pk][0] - - def filter(self, id__in=None): - self.steps.append(('filter', id__in)) + def get(self, pk=None, uuid=None): + if pk: + pk = int(pk) + return [m for m in _model_cache if m.id == pk][0] + if uuid: + uuid = UUID(uuid) + return [m for m in _model_cache if m.uuid == uuid][0] + return [] + + def filter(self, id__in=None, uuid__in=None): + self.steps.append(('filter', id__in or uuid__in)) return self def order_by(self, *fields): @@ -47,7 +54,8 @@ def __iter__(self): for mem in self.steps: if mem[0] == 'filter': - objs = [obj for obj in objs if obj.id in mem[1]] + objs = [obj for obj in objs + if getattr(obj, self.id_field) in mem[1]] elif mem[0] == 'order_by': order_by_field = mem[1][0] elif mem[0] == 'values_list': @@ -58,7 +66,7 @@ def __iter__(self): if values_list: # Note: Hard-coded to just id and flat - objs = [obj.id for obj in objs] + objs = [getattr(obj, self.id_field) for obj in objs] return iter(objs) @@ -66,8 +74,8 @@ class Manager(object): def get_query_set(self): return SearchQuerySet(self) - def get(self, pk): - return self.get_query_set().get(pk) + def get(self, pk=None, uuid=None): + return self.get_query_set().get(pk=pk, uuid=uuid) def filter(self, *args, **kwargs): return self.get_query_set().filter(*args, **kwargs) @@ -84,6 +92,7 @@ class FakeModel(object): objects = Manager() def __init__(self, **kw): + self.objects.id_field = kw.pop('id_field', 'id') self._doc = kw for key in kw: setattr(self, key, kw[key]) @@ -102,3 +111,12 @@ def extract_document(cls, obj_id, obj=None): 'what to do with these args.') return obj._doc + +class FakeDjangoWithUuidMappingType(FakeDjangoMappingType): + id_field = 'uuid' + + @classmethod + def extract_document(cls, obj_id, obj=None): + doc = super(FakeDjangoWithUuidMappingType, cls)\ + .extract_document(obj_id, obj=obj) + return {k:str(v) for k,v in doc.iteritems()} diff --git a/elasticutils/contrib/django/tests/test_models.py b/elasticutils/contrib/django/tests/test_models.py index a9c6762..8e4c172 100644 --- a/elasticutils/contrib/django/tests/test_models.py +++ b/elasticutils/contrib/django/tests/test_models.py @@ -1,8 +1,11 @@ +import uuid + from nose.tools import eq_ from elasticutils.contrib.django import S, get_es from elasticutils.contrib.django.tests import ( - FakeDjangoMappingType, FakeModel, reset_model_cache) + FakeDjangoMappingType, FakeDjangoWithUuidMappingType, FakeModel, + reset_model_cache) from elasticutils.contrib.django.estestcase import ESTestCase @@ -20,12 +23,13 @@ def tearDown(self): IndexableTest.cleanup_index(FakeDjangoMappingType.get_index()) reset_model_cache() - def persist_data(self, data): + def persist_data(self, data, id_field='id'): for doc in data: - FakeModel(**doc) + FakeModel(id_field=id_field, **doc) # Index the document with .index() - FakeDjangoMappingType.index(doc, id_=doc['id']) + FakeDjangoMappingType.index({k:str(v) for k,v in doc.iteritems()}, + id_=str(doc[id_field])) self.refresh(FakeDjangoMappingType.get_index()) @@ -51,6 +55,17 @@ def test_get_object(self): obj = s[0] eq_(obj.object.id, 1) + def test_get_object_with_custom_pk(self): + data = [ + {'uuid': uuid.uuid4(), 'name': 'odin skullcrusher'}, + {'uuid': uuid.uuid4(), 'name': 'olaf bloodbiter'}, + ] + self.persist_data(data, id_field='uuid') + + s = S(FakeDjangoWithUuidMappingType).query(name__prefix='odin') + obj = s[0] + eq_(obj.object.uuid, data[0]['uuid']) + def test_get_indexable(self): self.persist_data([ {'id': 1, 'name': 'odin skullcrusher'}, diff --git a/elasticutils/contrib/django/tests/test_tasks.py b/elasticutils/contrib/django/tests/test_tasks.py index 216df23..61b1431 100644 --- a/elasticutils/contrib/django/tests/test_tasks.py +++ b/elasticutils/contrib/django/tests/test_tasks.py @@ -1,9 +1,12 @@ +import uuid + from nose.tools import eq_ from elasticutils.contrib.django import get_es from elasticutils.contrib.django.tasks import index_objects, unindex_objects from elasticutils.contrib.django.tests import ( - FakeDjangoMappingType, FakeModel, reset_model_cache) + FakeDjangoMappingType, FakeDjangoWithUuidMappingType, FakeModel, + reset_model_cache) from elasticutils.contrib.django.estestcase import ESTestCase @@ -76,7 +79,7 @@ def bulk_index(cls, *args, **kwargs): index_objects(MockMappingType, [1, 2, 3], chunk_size=1) eq_(MockMappingType.bulk_index_count, 3) - # test index and es kwargs + # test index and es kwargs MockMappingType.index_kwarg = None MockMappingType.es_kwarg = None index_objects(MockMappingType, [1, 2, 3]) @@ -86,3 +89,34 @@ def bulk_index(cls, *args, **kwargs): index_objects(MockMappingType, [1, 2, 3], es='crazy_es', index='crazy_index') eq_(MockMappingType.index_kwarg, 'crazy_index') eq_(MockMappingType.es_kwarg, 'crazy_es') + + def test_tasks_with_custom_id_field(self): + docs = [ + {'uuid': uuid.uuid4(), 'name': 'odin skullcrusher'}, + {'uuid': uuid.uuid4(), 'name': 'heimdall kneebiter'}, + {'uuid': uuid.uuid4(), 'name': 'erik rose'} + ] + + for d in docs: + FakeModel(id_field='uuid', **d) + + ids = [d['uuid'] for d in docs] + + # Test index_objects task + index_objects(FakeDjangoWithUuidMappingType, ids) + FakeDjangoWithUuidMappingType.refresh_index() + # need some sleep ? + from time import sleep; sleep(1) + # nothing was indexed because a StandardError was catched silently, + # may be explicit should be better. + eq_(FakeDjangoWithUuidMappingType.search().count(), 0) + + # Test everything has been indexed + index_objects(FakeDjangoWithUuidMappingType, ids, id_field='uuid') + FakeDjangoWithUuidMappingType.refresh_index() + eq_(FakeDjangoWithUuidMappingType.search().count(), 3) + + # Test unindex_objects task + unindex_objects(FakeDjangoWithUuidMappingType, ids) + FakeDjangoWithUuidMappingType.refresh_index() + eq_(FakeDjangoWithUuidMappingType.search().count(), 0)