diff --git a/django_elasticsearch/managers.py b/django_elasticsearch/managers.py index 1f5cc7f..531edf7 100644 --- a/django_elasticsearch/managers.py +++ b/django_elasticsearch/managers.py @@ -28,14 +28,30 @@ u'ForeignKey': 'object', u'OneToOneField': 'object', - u'ManyToManyField': 'object' + u'ManyToManyField': 'object', + + # reverse relationship + u'ManyToOneRel': 'object', + u'ManyToManyRel': 'object', + + # dj <1.8 + u'RelatedObject': 'object' } def needs_instance(f): def wrapper(*args, **kwargs): if args[0].instance is None: - raise AttributeError("This method requires an instance of the model.") + raise AttributeError(u"This method requires an instance of the model.") + return f(*args, **kwargs) + return wrapper + + +def no_abstract(f): + def wrapper(*args, **kwargs): + if args[0].model.Elasticsearch.abstract is True: + raise ValueError(u"This model {0} is abstract - not indexed.".format( + args[0].instance.__class__)) return f(*args, **kwargs) return wrapper @@ -60,6 +76,7 @@ def __init__(self, k): self.serializer = None self._mapping = None + @no_abstract def get_index(self): return self.model.Elasticsearch.index @@ -67,6 +84,7 @@ def get_index(self): def index(self): return self.get_index() + @no_abstract def get_doc_type(self): return (self.model.Elasticsearch.doc_type or 'model-{0}'.format(self.model.__name__)) @@ -112,6 +130,7 @@ def deserialize(self, source): else: return serializer.deserialize(source) + @no_abstract @needs_instance def do_index(self): body = self.serialize() @@ -120,6 +139,7 @@ def do_index(self): id=self.instance.id, body=body) + @no_abstract @needs_instance def delete(self): es_client.delete(index=self.index, @@ -127,6 +147,7 @@ def delete(self): id=self.instance.id, ignore=404) + @no_abstract def get(self, **kwargs): if 'pk' in kwargs: pk = kwargs.pop('pk') @@ -140,6 +161,7 @@ def get(self, **kwargs): return self.queryset.get(id=pk) + @no_abstract @needs_instance def mlt(self, **kwargs): """ @@ -158,6 +180,7 @@ def mlt(self, **kwargs): """ return self.queryset.mlt(id=self.instance.id, **kwargs) + @no_abstract def count(self): return self.queryset.count() @@ -165,6 +188,7 @@ def count(self): def queryset(self): return EsQueryset(self.model) + @no_abstract def search(self, query, facets=None, facets_limit=None, global_facets=True, suggest_fields=None, suggest_limit=None, @@ -203,18 +227,22 @@ def search(self, query, return q.query(query) # Convenience methods + @no_abstract def all(self): """ proxy to an empty search. """ return self.search("") + @no_abstract def filter(self, **kwargs): return self.queryset.filter(**kwargs) + @no_abstract def exclude(self, **kwargs): return self.queryset.exclude(**kwargs) + @no_abstract def complete(self, field_name, query): """ Returns a list of close values for auto-completion @@ -227,6 +255,7 @@ def complete(self, field_name, query): complete_name = "{0}_complete".format(field_name) return self.queryset.complete(complete_name, query) + @no_abstract def do_update(self): """ Hit this if you are in a hurry, @@ -248,24 +277,29 @@ def make_mapping(self): for field_name in self.get_fields(): try: - field = self.model._meta.get_field(field_name) + field, a, b, c = self.model._meta.get_field_by_name(field_name) except FieldDoesNotExist: # abstract field mapping = {} else: mapping = {'type': ELASTICSEARCH_FIELD_MAP.get( - field.get_internal_type(), 'string')} + field.__class__.__name__, 'string')} try: # if an analyzer is set as default, use it. # TODO: could be also tokenizer, filter, char_filter if mapping['type'] == 'string': analyzer = settings.ELASTICSEARCH_SETTINGS['analysis']['default'] mapping['analyzer'] = analyzer - except (ValueError, AttributeError, KeyError, TypeError): + except (AttributeError, KeyError): + # AttributeError - settings.ELASTICSEARCH_SETTINGS is not set + # KeyError - either 'analysis' or 'default' is not set pass + try: mapping.update(self.model.Elasticsearch.mappings[field_name]) - except (AttributeError, KeyError, TypeError): + except (AttributeError, KeyError): + # AttributeError - Elasticsearch.mappings is not set + # KeyError - Elastisearch.mappings[field_name] is not set pass mappings[field_name] = mapping @@ -281,6 +315,7 @@ def make_mapping(self): } } + @no_abstract def get_mapping(self): if self._mapping is None: # TODO: could be done once for every index/doc_type ? @@ -296,6 +331,7 @@ def get_settings(self): """ return es_client.indices.get_settings(index=self.index) + @no_abstract @needs_instance def diff(self, source=None): """ @@ -321,6 +357,7 @@ def diff(self, source=None): return diff + @no_abstract def create_index(self, ignore=True): body = {} if hasattr(settings, 'ELASTICSEARCH_SETTINGS'): @@ -329,15 +366,18 @@ def create_index(self, ignore=True): es_client.indices.create(self.index, body=body, ignore=ignore and 400) + es_client.indices.put_mapping(index=self.index, doc_type=self.doc_type, body=self.make_mapping()) + @no_abstract def reindex_all(self, queryset=None): q = queryset or self.model.objects.all() for instance in q: instance.es.do_index() + @no_abstract def flush(self): es_client.indices.delete_mapping(index=self.index, doc_type=self.doc_type, diff --git a/django_elasticsearch/models.py b/django_elasticsearch/models.py index df34dae..8768c15 100644 --- a/django_elasticsearch/models.py +++ b/django_elasticsearch/models.py @@ -16,6 +16,7 @@ class Meta: abstract = True class Elasticsearch: + abstract = False index = getattr(settings, 'ELASTICSEARCH_DEFAULT_INDEX', 'django') doc_type = None # defaults to 'model-{model.name}' mapping = None @@ -57,20 +58,20 @@ def add_es_manager(sender, **kwargs): def es_save_callback(sender, instance, **kwargs): # TODO: batch ?! @task ?! - if not issubclass(sender, EsIndexable): + if not issubclass(sender, EsIndexable) or sender.Elasticsearch.abstract: return instance.es.do_index() def es_delete_callback(sender, instance, **kwargs): - if not issubclass(sender, EsIndexable): + if not issubclass(sender, EsIndexable) or sender.Elasticsearch.abstract: return instance.es.delete() def es_syncdb_callback(sender, app, created_models, **kwargs): for model in created_models: - if issubclass(model, EsIndexable): + if issubclass(model, EsIndexable) and not sender.Elasticsearch.abstract: model.es.create_index() if getattr(settings, 'ELASTICSEARCH_AUTO_INDEX', False): diff --git a/django_elasticsearch/serializers.py b/django_elasticsearch/serializers.py index 474e64e..b494f7c 100644 --- a/django_elasticsearch/serializers.py +++ b/django_elasticsearch/serializers.py @@ -1,11 +1,15 @@ import json import datetime +from django.db.models import Model from django.db.models import FieldDoesNotExist from django.db.models.fields.related import ManyToManyField class EsSerializer(object): + def __init__(self, *args, **kwargs): + pass + def serialize(self, instance): raise NotImplementedError() @@ -24,6 +28,17 @@ def deserialize(self, source): ids = [e[pk_field.name] for e in source] return self.model.objects.filter(**{pk_field.name + '__in': ids}) +def post_save_attr(f): + # Since related fields can't be set on instanciation + # this decorator saves them for later + def wrapper(*args, **kwargs): + val = f(*args, **kwargs) + serializer = args[0] + field_name = args[2] + serializer._post_save_attrs[field_name] = val + return None + return wrapper + class EsJsonToModelMixin(object): """ @@ -31,65 +46,121 @@ class EsJsonToModelMixin(object): from the json elasticsearch source (and disables db operations on the model). """ + def __init__(self, *args, **kwargs): + self._post_save_attrs = {} + super(EsJsonToModelMixin, self).__init__(*args, **kwargs) def instanciate(self, attrs): instance = self.model(**attrs) instance._is_es_deserialized = True + + # set m2m, fks and such + # for k, v in self._post_save_attrs.iteritems(): + # if v: + # try: + # setattr(instance, k, v) + # except TypeError, ValueError: + # # bypass ManyRelatedManager complaining + # # TODO + # # super(Model, instance).__setattr__(k, v) + # pass + return instance - def nested_deserialize(self, field, source): + def nested_deserialize(self, source, rel): # check for Elasticsearch.serializer on the related model - if source: - if hasattr(field.rel.to, 'Elasticsearch'): - serializer = field.rel.to.es.get_serializer() + model = rel.related_model + if source and rel: + if hasattr(model, 'Elasticsearch'): + serializer = model.es.get_serializer() obj = serializer.deserialize(source) return obj elif 'id' in source and 'value' in source: - # id/value fallback - return field.rel.to.objects.get(pk=source.get('id')) + # fallback + return source - def deserialize_field(self, source, field_name): - method_name = 'deserialize_{0}'.format(field_name) - if hasattr(self, method_name): - return getattr(self, method_name)(source, field_name) + def deserialize_type_datetimefield(self, source, field_name): + val = source.get(field_name) + if val: + return datetime.datetime.strptime(val, '%Y-%m-%dT%H:%M:%S.%f') + + def deserialize_type_datefield(self, source, field_name): + val = source.get(field_name) + if val: + return datetime.datetime.strptime(val, '%Y-%m-%d') - field = self.model._meta.get_field(field_name) - field_type_method_name = 'deserialize_type_{0}'.format( - field.__class__.__name__.lower()) - if hasattr(self, field_type_method_name): - return getattr(self, field_type_method_name)(source, field_name) + def deserialize_type_timefield(self, source, field_name): + val = source.get(field_name) + if val: + return datetime.datetime.strptime(val, '%H:%M:%S') + def deserialize_type_rel(self, source, field_name): + rel, model, direct, m2m = self.model._meta.get_field_by_name(field_name) + val = source.get(field_name) + if val: + return [self.nested_deserialize(r, rel) for r in val] + + @post_save_attr + def deserialize_type_manytoonerel(self, source, field_name): + # reverse fk + return self.deserialize_type_rel(source, field_name) + + @post_save_attr + def deserialize_type_manytomanyrel(self, source, field_name): + # reverse m2m + return self.deserialize_type_rel(source, field_name) + + @post_save_attr + def deserialize_type_foreignkey(self, source, field_name): + rel, model, direct, m2m = self.model._meta.get_field_by_name(field_name) val = source.get(field_name) + if val: + return self.nested_deserialize(val, rel) - # datetime - typ = field.get_internal_type() - if val and typ in ('DateField', 'DateTimeField'): - return datetime.datetime.strptime(val, '%Y-%m-%dT%H:%M:%S.%f') + @post_save_attr + def deserialize_type_onetoonefield(self, source, field_name): + return self.deserialize_type_foreignkey(source, field_name) - if field.rel: - # M2M - if isinstance(field, ManyToManyField): - raise AttributeError + @post_save_attr + def deserialize_type_manytomanyfield(self, source, field_name): + return self.deserialize_type_rel(source, field_name) - # FK, OtO - return self.nested_deserialize(field, source.get(field_name)) + # django <1.8 hack + @post_save_attr + def deserialize_type_relatedobject(self, source, field_name): + return self.deserialize_type_rel(source, field_name) - return source.get(field_name) + def deserialize_field(self, source, field_name): + method_name = 'deserialize_{0}'.format(field_name) + if hasattr(self, method_name): + return getattr(self, method_name)(source, field_name) + + try: + field, model, direct, m2m = self.model._meta.get_field_by_name(field_name) + except FieldDoesNotExist: + # Abstract field + field = None + + if field: + field_type_method_name = 'deserialize_type_{0}'.format( + field.__class__.__name__.lower()) + if hasattr(self, field_type_method_name): + return getattr(self, field_type_method_name)(source, field_name) + + return source.get(field_name) def deserialize(self, source): """ Returns a model instance """ attrs = {} - for k, v in source.iteritems(): - try: - attrs[k] = self.deserialize_field(source, k) - except (AttributeError, FieldDoesNotExist): - # m2m, abstract - pass + + for field_name in source.iterkeys(): + val = self.deserialize_field(source, field_name) + if val: + attrs[field_name] = val return self.instanciate(attrs) - # TODO: we can assign m2ms now class EsModelToJsonMixin(object): @@ -98,6 +169,36 @@ def __init__(self, model, max_depth=2, cur_depth=1): # used in case of related field on 'self' to avoid infinite loop self.cur_depth = cur_depth self.max_depth = max_depth + super(EsModelToJsonMixin, self).__init__(model, max_depth=max_depth, cur_depth=cur_depth) + + def serialize_type_rel(self, instance, field_name): + if self.cur_depth >= self.max_depth: + return + + return [self.nested_serialize(r) + for r in getattr(instance, field_name).all()] + + def serialize_type_manytoonerel(self, instance, field_name): + return self.serialize_type_rel(instance, field_name) + + def serialize_type_manytomanyrel(self, instance, field_name): + return self.serialize_type_rel(instance, field_name) + + def serialize_type_foreignkey(self, instance, field_name): + if self.cur_depth >= self.max_depth: + return + + return self.nested_serialize(getattr(instance, field_name)) + + def serialize_type_onetoonefield(self, instance, field_name): + return self.serialize_type_foreignkey(instance, field_name) + + def serialize_type_manytomanyfield(self, instance, field_name): + return self.serialize_type_rel(instance, field_name) + + # django <1.8 hack + def serialize_type_relatedobject(self, instance, field_name): + return self.serialize_type_rel(instance, field_name) def serialize_field(self, instance, field_name): method_name = 'serialize_{0}'.format(field_name) @@ -105,38 +206,28 @@ def serialize_field(self, instance, field_name): return getattr(self, method_name)(instance, field_name) try: - field = self.model._meta.get_field(field_name) + field, model, direct, m2m = self.model._meta.get_field_by_name(field_name) except FieldDoesNotExist: # Abstract field - pass - else: + field = None + + if field: field_type_method_name = 'serialize_type_{0}'.format( field.__class__.__name__.lower()) if hasattr(self, field_type_method_name): return getattr(self, field_type_method_name)(instance, field_name) - if field.rel: - # M2M - if isinstance(field, ManyToManyField): - return [self.nested_serialize(r) - for r in getattr(instance, field.name).all()] - - rel = getattr(instance, field.name) - # FK, OtO - if rel: # should be a model instance - if self.cur_depth >= self.max_depth: - return - - return self.nested_serialize(rel) - try: return getattr(instance, field_name) except AttributeError: raise AttributeError("The serializer doesn't know how to serialize {0}, " - "please provide it a {1} method." + "please provide it a '{1}' method." "".format(field_name, method_name)) def nested_serialize(self, rel): + if rel is None: + return + # check for Elasticsearch.serializer on the related model if hasattr(rel, 'Elasticsearch'): serializer = rel.es.get_serializer(max_depth=self.max_depth, @@ -147,6 +238,7 @@ def nested_serialize(self, rel): # Fallback on a dict with id + __unicode__ value of the related model instance. return dict(id=rel.pk, value=unicode(rel)) + def format(self, instance): # from a model instance to a dict fields = self.model.es.get_fields() diff --git a/django_elasticsearch/tests/test_indexable.py b/django_elasticsearch/tests/test_indexable.py index c751f9e..571118a 100644 --- a/django_elasticsearch/tests/test_indexable.py +++ b/django_elasticsearch/tests/test_indexable.py @@ -7,42 +7,44 @@ from django_elasticsearch.managers import es_client from django_elasticsearch.tests.utils import withattrs -from test_app.models import TestModel +from test_app.models import Test2Model class EsIndexableTestCase(TestCase): def setUp(self): # auto index is disabled for tests so we do it manually - TestModel.es.flush() - self.instance = TestModel.objects.create(username=u"1", - first_name=u"woot", - last_name=u"foo") + Test2Model.es.flush() + self.instance = Test2Model.objects.create(email=u"1", + char=u"woot", + text=u"foo") self.instance.es.do_index() - TestModel.es.do_update() + Test2Model.es.do_update() def tearDown(self): super(EsIndexableTestCase, self).tearDown() - es_client.indices.delete(index=TestModel.es.get_index()) + es_client.indices.delete(index=Test2Model.es.get_index()) def test_needs_instance(self): with self.assertRaises(AttributeError): - TestModel.es.do_index() + Test2Model.es.do_index() def test_check_cluster(self): - self.assertEqual(TestModel.es.check_cluster(), True) + self.assertEqual(Test2Model.es.check_cluster(), True) def test_get_api(self): self.assertEqual(self.instance.es.get(), - TestModel.es.get(pk=self.instance.pk), - TestModel.es.get(id=self.instance.pk)) + Test2Model.es.get(pk=self.instance.pk), + Test2Model.es.get(id=self.instance.pk)) with self.assertRaises(AttributeError): - TestModel.es.get() + Test2Model.es.get() def test_do_index(self): self.instance.es.do_index() - r = TestModel.es.deserialize(self.instance.es.get()) - self.assertTrue(isinstance(r, TestModel)) + r = es_client.get(index=self.instance.es.get_index(), + doc_type=self.instance.es.get_doc_type(), + id=self.instance.id) + self.assertEqual(r['_source']['id'], self.instance.id) def test_delete(self): self.instance.es.delete() @@ -50,44 +52,44 @@ def test_delete(self): self.instance.es.get() def test_mlt(self): - qs = self.instance.es.mlt(mlt_fields=['first_name',], min_term_freq=1, min_doc_freq=1) + qs = self.instance.es.mlt(mlt_fields=['char',], min_term_freq=1, min_doc_freq=1) self.assertEqual(qs.count(), 0) - a = TestModel.objects.create(username=u"2", first_name=u"woot", last_name=u"foo fooo") + a = Test2Model.objects.create(email=u"2", char=u"woot", text=u"foo fooo") a.es.do_index() a.es.do_update() - results = self.instance.es.mlt(mlt_fields=['first_name',], min_term_freq=1, min_doc_freq=1).deserialize() + results = self.instance.es.mlt(mlt_fields=['char',], min_term_freq=1, min_doc_freq=1).deserialize() self.assertEqual(results.count(), 1) self.assertEqual(results[0], a) def test_search(self): - hits = TestModel.es.search('wee') + hits = Test2Model.es.search('wee') self.assertEqual(hits.count(), 0) - hits = TestModel.es.search('woot') + hits = Test2Model.es.search('woot') self.assertEqual(hits.count(), 1) def test_search_with_facets(self): - s = TestModel.es.search('whatever').facet(['first_name',]) + s = Test2Model.es.search('whatever').facet(['char',]) self.assertEqual(s.count(), 0) expected = [{u'doc_count': 1, u'key': u'woot'}] self.assertEqual(s.facets['doc_count'], 1) - self.assertEqual(s.facets['first_name']['buckets'], expected) + self.assertEqual(s.facets['char']['buckets'], expected) def test_fuzziness(self): - hits = TestModel.es.search('woo') # instead of woot + hits = Test2Model.es.search('woo') # instead of woot self.assertEqual(hits.count(), 1) - hits = TestModel.es.search('woo', fuzziness=0) + hits = Test2Model.es.search('woo', fuzziness=0) self.assertEqual(hits.count(), 0) - hits = TestModel.es.search('waat', fuzziness=2) + hits = Test2Model.es.search('waat', fuzziness=2) self.assertEqual(hits.count(), 1) - @withattrs(TestModel.Elasticsearch, 'fields', ['username']) - @withattrs(TestModel.Elasticsearch, 'mappings', {"username": {"boost": 20}}) - @withattrs(TestModel.Elasticsearch, 'completion_fields', ['username']) + @withattrs(Test2Model.Elasticsearch, 'fields', ['email']) + @withattrs(Test2Model.Elasticsearch, 'mappings', {"email": {"boost": 20}}) + @withattrs(Test2Model.Elasticsearch, 'completion_fields', ['email']) @override_settings(ELASTICSEARCH_SETTINGS={ "analysis": { "default": "test_analyzer", @@ -102,73 +104,81 @@ def test_fuzziness(self): def test_custom_mapping(self): # should take the defaults into accounts expected = { - TestModel.Elasticsearch.doc_type: { + Test2Model.Elasticsearch.doc_type: { 'properties': { - 'username': { + 'email': { 'analyzer': 'test_analyzer', 'boost': 20, 'type': 'string' }, - 'username_complete': { + 'email_complete': { 'type': 'completion' } } } } # reset cache on _fields - self.assertEqual(expected, TestModel.es.make_mapping()) + self.assertEqual(expected, Test2Model.es.make_mapping()) - @withattrs(TestModel.Elasticsearch, 'completion_fields', ['first_name']) + @withattrs(Test2Model.Elasticsearch, 'completion_fields', ['char']) def test_auto_completion(self): # Note: we need to call setUp again to create the mapping taking # the new field(s) into account :( - TestModel.es.flush() - TestModel.es.do_update() - data = TestModel.es.complete('first_name', 'woo') + Test2Model.es.flush() + Test2Model.es.do_update() + data = Test2Model.es.complete('char', 'woo') self.assertTrue('woot' in data) - @withattrs(TestModel.Elasticsearch, 'fields', ['username', 'date_joined']) - def test_get_mapping(self): - TestModel.es._mapping = None - TestModel.es.flush() - TestModel.es.do_update() - - expected = {u'date_joined': {u'format': u'dateOptionalTime', u'type': u'date'}, - u'username': {u'index': u'not_analyzed', u'type': u'string'}} + def _test_mapping(self, expected): + Test2Model.es._mapping = None + Test2Model.es.flush() + Test2Model.es.do_update() # Reset the eventual cache on the Model mapping - mapping = TestModel.es.get_mapping() - TestModel.es._mapping = None - self.assertEqual(expected, mapping) + mapping = Test2Model.es.get_mapping() + Test2Model.es._mapping = None + self.assertEqual(expected, mapping) + + @withattrs(Test2Model.Elasticsearch, 'fields', ['email', 'datef']) + def test_get_mapping(self): + expected = {u'datef': {u'format': u'dateOptionalTime', u'type': u'date'}, + u'email': {u'index': u'not_analyzed', u'type': u'string'}} + self._test_mapping(expected) + + @withattrs(Test2Model.Elasticsearch, 'fields', ['dummies', 'dummiesm2m']) + def test_reverse_relationship_mapping(self): + expected = {u'dummies': {u'type': u'object'}, + u'dummiesm2m': {u'type': u'object'}} + self._test_mapping(expected) def test_get_settings(self): # Note i don't really know what's in there so i just check # it doesn't crash and deserialize well. - settings = TestModel.es.get_settings() + settings = Test2Model.es.get_settings() self.assertEqual(dict, type(settings)) def test_custom_index(self): - es_client.indices.exists(TestModel.Elasticsearch.index) + es_client.indices.exists(Test2Model.Elasticsearch.index) def test_custom_doc_type(self): es_client.indices.exists_type('django-test', 'test-doc-type') def test_reevaluate(self): # test that the request is resent if something changed filters, ordering, ndx - TestModel.es.flush() - TestModel.es.do_update() + Test2Model.es.flush() + Test2Model.es.do_update() - q = TestModel.es.search('woot') + q = Test2Model.es.search('woot') self.assertTrue(self.instance in q.deserialize()) # evaluate - q = q.filter(last_name='grut') + q = q.filter(text='grut') self.assertFalse(self.instance in q.deserialize()) # evaluate def test_diff(self): self.assertEqual(self.instance.es.diff(), {}) - self.instance.first_name = 'pouet' + self.instance.char = 'pouet' expected = { - u'first_name': { + u'char': { 'es': u'woot', 'db': u'pouet' } @@ -178,5 +188,5 @@ def test_diff(self): self.assertEqual(self.instance.es.diff(source=self.instance.es.get()), {}) # force diff to reload from db - deserialized = TestModel.es.all().deserialize()[0] + deserialized = Test2Model.es.all().deserialize()[0] self.assertEqual(deserialized.es.diff(), {}) diff --git a/django_elasticsearch/tests/test_serializer.py b/django_elasticsearch/tests/test_serializer.py index 5dd1def..eb5bc5d 100644 --- a/django_elasticsearch/tests/test_serializer.py +++ b/django_elasticsearch/tests/test_serializer.py @@ -1,3 +1,5 @@ +import datetime + from django.test import TestCase from django_elasticsearch.utils import dict_depth @@ -18,15 +20,19 @@ def serialize_char(self, instance, field_name): class EsJsonSerializerTestCase(TestCase): def setUp(self): Test2Model.es.flush() - self.target = Dummy.objects.create() + self.target = Dummy.objects.create(foo='test') self.instance = Test2Model.objects.create(fk=self.target, oto=self.target) # to test for infinite nested recursion self.instance.fkself = self.instance self.instance.save() - self.instance.es.do_index() - Test2Model.es.do_update() + self.instance.mtm.add(self.target) + + # reverse relations + self.target.reversefk = self.instance + self.target.save() + self.target.reversem2m.add(self.instance) def tearDown(self): super(EsJsonSerializerTestCase, self).tearDown() @@ -43,42 +49,120 @@ def test_dynamic_serializer_import(self): self.assertTrue(isinstance(obj, basestring)) def test_deserialize(self): - instance = Test2Model.es.deserialize({'char': 'test'}) - self.assertEqual(instance.char, 'test') + source = { + 'char': 'char test', + 'text': 'text\ntest', + 'email': 'test@test.com', + 'filef': 'f/test.png', + 'filepf': 'f/test/test.png', + 'ipaddr': '192.168.0.1', + 'genipaddr': '192.168.0.2', + 'slug': 'test', + 'url': 'http://www.perdu.com/', + + 'intf': 42, + 'bigint': 922337203685477585, + 'intlist': [1, 2, 3], + 'floatf': 15.5, + 'dec': 5/3, + 'posint': 13, + 'smint': -5, + 'possmint': 6, + + 'boolf': True, + 'nullboolf': None, + + 'datef': '2017-05-02', + 'datetf': {'iso': '2017-05-02T15:22:05.5432'}, + 'timef': '05:57:44', + } + + instance = Test2Model.es.deserialize(source) + self.assertTrue(isinstance(instance, Test2Model)) + + self.assertEqual(instance.char, 'char test') + self.assertEqual(instance.text, 'text\ntest') + self.assertEqual(instance.email, 'test@test.com') + self.assertEqual(instance.filef, 'f/test.png') + self.assertEqual(instance.filepf, 'f/test/test.png') + self.assertEqual(instance.ipaddr, '192.168.0.1') + self.assertEqual(instance.genipaddr, '192.168.0.2') + self.assertEqual(instance.slug, 'test') + self.assertEqual(instance.url, 'http://www.perdu.com/') + + self.assertEqual(instance.intf, 45) + self.assertEqual(instance.bigint, 922337203685477585) + self.assertEqual(instance.intlist, [1, 2, 3]) + self.assertEqual(instance.floatf, 15.5) + self.assertEqual(instance.dec, 5/3) + self.assertEqual(instance.posint, 13) + self.assertEqual(instance.smint, -5) + self.assertEqual(instance.possmint, 6) + + self.assertEqual(instance.boolf, True) + self.assertEqual(instance.nullboolf, None) + + self.assertEqual(instance.datef, datetime.datetime(2017, 5, 2, 0, 0)) + self.assertEqual(instance.datetf, datetime.datetime(2017, 5, 2, 15, 22, 5, 543200)) + self.assertEqual(instance.timef, datetime.datetime(1900, 1, 1, 5, 57, 44)) + self.assertRaises(ValueError, instance.save) + def test_deserialize_related(self): + source = { + 'fk': {'id': self.target.id, 'foo': 'test'}, + 'oto': {'id': self.target.id, 'foo': 'test'}, + 'fkself': {'id': self.instance.id, 'char': 'test'}, + 'mtm': [{'id': self.target.id, 'foo': 'test'}] + } + + with self.assertNumQueries(0): + instance = Test2Model.es.deserialize(source) + self.assertTrue(isinstance(instance, Test2Model)) + + self.assertTrue(isinstance(instance.fk, Dummy)) + self.assertEqual(instance.fk.id, self.target.id) + self.assertTrue(isinstance(instance.oto, Dummy)) + self.assertEqual(instance.oto.id, self.target.id) + self.assertTrue(hasattr(instance.mtm, '__iter__')) # need something more explicit + self.assertEqual(instance.mtm[0].id, self.target.id) + @withattrs(Test2Model.Elasticsearch, 'serializer_class', CustomSerializer) def test_custom_serializer(self): json = self.instance.es.serialize() self.assertIn('"char": "FOO"', json) + @withattrs(Test2Model.Elasticsearch, 'fields', ['id', 'fk']) def test_nested_fk(self): - # if the target model got a Elasticsearch.serializer, we use it - u = Test2Model.es.all()[0] - self.assertTrue('fk' in u) - self.assertTrue(type(u['fk']) is dict) + serializer = Test2Model.es.get_serializer() + obj = serializer.format(self.instance) + expected = {'id': 1, 'fk': {'id':1, 'foo': 'test'}} + self.assertEqual(obj, expected) + @withattrs(Test2Model.Elasticsearch, 'fields', ['id', 'oto']) def test_nested_oto(self): - # if the target model got a Elasticsearch.serializer, we use it - u = Test2Model.es.all()[0] - self.assertTrue('oto' in u) - self.assertTrue(type(u['oto']) is dict) + serializer = Test2Model.es.get_serializer() + obj = serializer.format(self.instance) + expected = {'id': 1, 'oto': {'id':1, 'foo': 'test'}} + self.assertEqual(obj, expected) - @withattrs(Test2Model.Elasticsearch, 'fields', ['fkself',]) + @withattrs(Test2Model.Elasticsearch, 'fields', ['id', 'fkself']) def test_self_fk_depth_test(self): Test2Model.es.serializer = None # reset cache serializer = Test2Model.es.get_serializer(max_depth=3) obj = serializer.format(self.instance) self.assertEqual(dict_depth(obj), 3) + @withattrs(Test2Model.Elasticsearch, 'fields', ['id', 'mtm']) def test_nested_m2m(self): - u = Test2Model.es.all()[0] - self.assertTrue('mtm' in u) - self.assertTrue(type(u['mtm']) is list) + serializer = Test2Model.es.get_serializer() + obj = serializer.format(self.instance) + expected = {'id': 1, 'mtm': [{'id':1, 'foo': 'test'},]} + self.assertEqual(obj, expected) @withattrs(Test2Model.Elasticsearch, 'fields', ['abstract_prop', 'abstract_method']) def test_abstract_field(self): - serializer = Test2Model.es.get_serializer() + serializer = Test2Model.es.get_serializer() obj = serializer.format(self.instance) expected = {'abstract_method': 'woot', 'abstract_prop': 'weez'} self.assertEqual(obj, expected) @@ -91,10 +175,10 @@ def test_unknown_field(self): def test_specific_field_method(self): serializer = Test2Model.es.get_serializer() obj = serializer.format(self.instance) - self.assertEqual(obj["bigint"], 42) + self.assertEqual(obj["intf"], 42) instance = Test2Model.es.deserialize(obj) - self.assertEqual(instance.bigint, 45) + self.assertEqual(instance.intf, 45) def test_type_specific_field_method(self): serializer = Test2Model.es.get_serializer() @@ -108,3 +192,31 @@ def test_type_specific_field_method(self): def test_simple_serializer(self): results = Test2Model.es.deserialize([{'id': self.instance.pk},]) self.assertTrue(self.instance in results) + + @withattrs(Test2Model.Elasticsearch, 'fields', ['id', 'dummies']) + def test_reverse_fk(self): + serializer = Test2Model.es.get_serializer() + obj = serializer.format(self.instance) + expected = {'id': 1, 'dummies': [{'id':1, 'foo': 'test'},]} + self.assertEqual(obj, expected) + + @withattrs(Test2Model.Elasticsearch, 'fields', ['id', 'dummiesm2m']) + def test_serialize_reverse_m2m(self): + serializer = Test2Model.es.get_serializer() + obj = serializer.format(self.instance) + expected = {'id': 1, 'dummiesm2m': [{'id':1, 'foo': 'test'},]} + self.assertEqual(obj, expected) + + def test_deserialize_reverse_relationships(self): + # make sure no sql query is done + instance = Test2Model.es.deserialize({'dummies': [{'id':1, 'foo': 'test'},], + 'dummiesm2m': [{'id':1, 'foo': 'test'},]}) + self.assertTrue(isinstance(instance, Test2Model)) + + self.assertEqual(len(instance.dummies), 1) + self.assertTrue(isinstance(instance.dummies[0], Dummy)) + self.assertEqual(instance.dummies[0].foo, 'test') + + self.assertEqual(len(instance.dummiesm2m), 1) + self.assertTrue(isinstance(instance.dummiesm2m[0], Dummy)) + self.assertEqual(instance.dummiesm2m[0].foo, 'test') diff --git a/test_project/test_app/models.py b/test_project/test_app/models.py index 933fbdb..26e786d 100644 --- a/test_project/test_app/models.py +++ b/test_project/test_app/models.py @@ -36,8 +36,19 @@ class Meta: ordering = ('id',) -class Dummy(models.Model): +class Dummy(EsIndexable): foo = models.CharField(max_length=256, null=True) + reversefk = models.ForeignKey('Test2Model', + related_name='dummies', + null=True) + + reversem2m = models.ManyToManyField('Test2Model', + related_name='dummiesm2m', + null=True) + + class Elasticsearch(EsIndexable.Elasticsearch): + abstract = True + fields = ['id', 'foo'] class Test2Serializer(EsJsonSerializer): @@ -50,17 +61,17 @@ def serialize_type_datetimefield(self, instance, field_name): 'time': d and d.time().isoformat()[:5] } - def deserialize_type_datetimefield(self, instance, field_name): - return datetime.strptime(instance.get(field_name)['iso'], + def deserialize_type_datetimefield(self, source, field_name): + return datetime.strptime(source.get(field_name)['iso'], '%Y-%m-%dT%H:%M:%S.%f') def serialize_abstract_method(self, instance, field_name): return 'woot' - def serialize_bigint(self, instance, field_name): + def serialize_intf(self, instance, field_name): return 42 - def deserialize_bigint(self, source, field_name): + def deserialize_intf(self, source, field_name): return 45 @@ -112,7 +123,8 @@ class Elasticsearch(EsIndexable.Elasticsearch): # Note: we need to specify this field since the value returned # by the serializer does not correspond to it's default mapping # see: Test2Serializer.serialize_type_datetimefield - mappings = {'datetf': {'type': 'object'}} + mappings = {'email': {"index": "not_analyzed"}, + 'datetf': {'type': 'object'}} @property def abstract_prop(self):