diff --git a/sunburnt/search.py b/sunburnt/search.py
index d90a9bf..6429acb 100644
--- a/sunburnt/search.py
+++ b/sunburnt/search.py
@@ -4,14 +4,14 @@
 from .schema import SolrError, SolrBooleanField, SolrUnicodeField, WildcardFieldInstance
-
 class LuceneQuery(object):
     default_term_re = re.compile(r'^\w+$')
-    def __init__(self, schema, option_flag=None, original=None):
+    def __init__(self, schema, option_flag=None, original=None, multiple_tags_allowed=False):
         self.schema = schema
         self.normalized = False
         if original is None:
             self.option_flag = option_flag
+            self.multiple_tags_allowed = multiple_tags_allowed
             self.terms = collections.defaultdict(set)
             self.phrases = collections.defaultdict(set)
             self.ranges = set()
@@ -21,6 +21,7 @@ def __init__(self, schema, option_flag=None, original=None):
             self.boosts = []
         else:
             self.option_flag = original.option_flag
+            self.multiple_tags_allowed = original.multiple_tags_allowed
             self.terms = copy.copy(original.terms)
             self.phrases = copy.copy(original.phrases)
             self.ranges = copy.copy(original.ranges)
@@ -36,7 +37,7 @@ def clone(self):
 
     def options(self):
         opts = {}
-        s = unicode(self)
+        s = self.__unicode_special__()
         if s:
             opts[self.option_flag] = s
         return opts
@@ -78,7 +79,7 @@ def serialize_term_queries(self, terms):
                 s += [u'%s:%s' % (name, value.to_query()) for value in value_set]
             else:
                 s += [value.to_query() for value in value_set]
-        return u' AND '.join(sorted(s))
+        return sorted(s)
 
     range_query_templates = {
         "any": u"[* TO *]",
@@ -95,7 +96,7 @@ def serialize_range_queries(self):
             range_s = self.range_query_templates[rel] % \
                 tuple(value.to_query() for value in sorted(values, key=lambda x: getattr(x, "value")))
             s.append(u"%s:%s" % (name, range_s))
-        return u' AND '.join(s)
+        return s
 
     def child_needs_parens(self, child):
         if len(child) == 1:
@@ -178,7 +179,10 @@ def normalize(self):
         self.normalized = True
         return self, mutated
 
-    def __unicode__(self, level=0, op=None):
+    def __unicode__(self):
+        return self.__unicode_special__(force_serialize=True)
+
+    def __unicode_special__(self, level=0, op=None, force_serialize=False):
         if not self.normalized:
             self, _ = self.normalize()
         if self.boosts:
@@ -189,20 +193,27 @@ def __unicode__(self, level=0, op=None):
                              for kwargs, boost_score in self.boosts]
             newself = newself | (newself & reduce(operator.or_, boost_queries))
             newself, _ = newself.normalize()
-            return newself.__unicode__(level=level)
+            return newself.__unicode_special__(level=level, force_serialize=force_serialize)
         else:
-            u = [s for s in [self.serialize_term_queries(self.terms),
+            alliter = [self.serialize_term_queries(self.terms),
                              self.serialize_term_queries(self.phrases),
                              self.serialize_range_queries()]
-                 if s]
+            u = []
+            for iterator in alliter:
+                u.extend(iterator)
             for q in self.subqueries:
                 op_ = u'OR' if self._or else u'AND'
                 if self.child_needs_parens(q):
-                    u.append(u"(%s)"%q.__unicode__(level=level+1, op=op_))
+                    u.append(u"(%s)"%q.__unicode_special__(level=level+1, op=op_))
                 else:
-                    u.append(u"%s"%q.__unicode__(level=level+1, op=op_))
+                    u.append(u"%s"%q.__unicode_special__(level=level+1, op=op_))
             if self._and:
-                return u' AND '.join(u)
+                if (not force_serialize and
+                    level == 0 and
+                    self.multiple_tags_allowed):
+                    return u
+                else:
+                    return u' AND '.join(u)
             elif self._or:
                 return u' OR '.join(u)
             elif self._not:
@@ -265,7 +276,7 @@ def __pow__(self, value):
         q._and = False
         q._pow = value
         return q
-    
+
     def add(self, args, kwargs):
         self.normalized = False
         _args = []
@@ -369,7 +380,7 @@ class BaseSearch(object):
 
     def _init_common_modules(self):
         self.query_obj = LuceneQuery(self.schema, u'q')
-        self.filter_obj = LuceneQuery(self.schema, u'fq')
+        self.filter_obj = LuceneQuery(self.schema, u'fq', multiple_tags_allowed=True)
         self.paginator = PaginateOptions(self.schema)
         self.highlighter = HighlightOptions(self.schema)
         self.faceter = FacetOptions(self.schema)
@@ -506,7 +517,7 @@ def params(self):
     _count = None
 
     def count(self):
-        # get the total count for the current query without retrieving any results 
+        # get the total count for the current query without retrieving any results
         # cache it, since it may be needed multiple times when used with django paginator
         if self._count is None:
            # are we already paginated? then we'll behave as if that's
diff --git a/sunburnt/sunburnt.py b/sunburnt/sunburnt.py
index 880756c..3bbe935 100644
--- a/sunburnt/sunburnt.py
+++ b/sunburnt/sunburnt.py
@@ -100,7 +100,7 @@ def url_for_update(self, commit=None, commitWithin=None, softCommit=None, optimi
         return self.update_url
 
     def select(self, params):
-        qs = urllib.urlencode(params)
+        qs = urllib.urlencode(params, doseq=True)
         url = "%s?%s" % (self.select_url, qs)
         if len(url) > self.max_length_get_url:
             warnings.warn("Long query URL encountered - POSTing instead of "
diff --git a/sunburnt/test_search.py b/sunburnt/test_search.py
index e804ac4..f635313 100644
--- a/sunburnt/test_search.py
+++ b/sunburnt/test_search.py
@@ -102,9 +102,9 @@ class MockInterface(object):
         (["hello"], {},
          [("fq", u"hello"), ("q", "*:*")]),
         (["hello"], {"int_field":3},
-         [("fq", u"hello AND int_field:3"), ("q", "*:*")]),
+         [("fq", u"hello"), ("fq", u"int_field:3"), ("q", "*:*")]),
         (["hello", "world"], {},
-         [("fq", u"hello AND world"), ("q", "*:*")]),
+         [("fq", u"hello"), ("fq", u"world"), ("q", "*:*")]),
         # NB this next is not really what we want,
         # probably this should warn
         (["hello world"], {},
@@ -115,9 +115,9 @@ class MockInterface(object):
         (["hello"], {},
          [("fq", u"hello"), ("q", "*:*")]),
         (["hello"], {"int_field":3},
-         [("fq", u"int_field:3 AND hello"), ("q", "*:*")]),
+         [("fq", u"hello"), ("fq", u"int_field:3"), ("q", "*:*")]),
         (["hello", "world"], {},
-         [("fq", u"hello AND world"), ("q", "*:*")]),
+         [("fq", u"hello"), ("fq", u"world"), ("q", "*:*")]),
         (["hello world"], {},
          [("fq", u"hello\\ world"), ("q", "*:*")]),
     ),
@@ -126,9 +126,9 @@ class MockInterface(object):
         (["hello"], {},
          [("fq", u"hello"), ("q", "*:*")]),
         (["hello"], {"int_field":3},
-         [("fq", u"hello AND int_field:3"), ("q", "*:*")]),
+         [("fq", u"hello"), ("fq", "int_field:3"), ("q", "*:*")]),
         (["hello", "world"], {},
-         [("fq", u"hello AND world"), ("q", "*:*")]),
+         [("fq", u"hello"), ("fq", u"world"), ("q", "*:*")]),
         (["hello world"], {},
          [("fq", u"hello\\ world"), ("q", "*:*")]),
     ),
@@ -418,7 +418,7 @@ def test_bad_option_data():
     (lambda q: q.query("hello world").filter(q.Q(text_field="tow") | q.Q(boolean_field=False, int_field__gt=3)),
      [('fq', u'text_field:tow OR (boolean_field:false AND int_field:{3 TO *})'), ('q', u'hello\\ world')]),
     (lambda q: q.query("hello world").filter(q.Q(text_field="tow") & q.Q(boolean_field=False, int_field__gt=3)),
-     [('fq', u'boolean_field:false AND text_field:tow AND int_field:{3 TO *}'), ('q', u'hello\\ world')]),
+     [('fq', u'boolean_field:false'), ('fq', u'int_field:{3 TO *}'), ('fq', u'text_field:tow'), ('q', u'hello\\ world')]),
     # Test various combinations of NOTs at the top level.
     # Sometimes we need to do the *:* trick, sometimes not.
     (lambda q: q.query(~q.Q("hello world")),
@@ -492,7 +492,7 @@ def test_bad_option_data():
 def check_complex_boolean_query(solr_search, query, output):
     p = query(solr_search).params()
     try:
-        assert p == output
+        assert p == output, "Unequal: %r, %r" % (p, output)
     except AssertionError:
         if debug:
             print p
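Reviewer note (not part of the patch): with multiple_tags_allowed set on the filter query, the "fq" option can now be a list of strings instead of a single "A AND B" string, which is why select() must pass doseq=True to urllib.urlencode. Solr treats repeated fq parameters as an intersection, so results are unchanged, but each filter can then be cached independently in Solr's filterCache. A minimal standard-library sketch of the encoding difference:

    import urllib

    # One fq per top-level filter clause, as produced when multiple_tags_allowed=True.
    params = [("fq", [u"hello", u"int_field:3"]), ("q", u"*:*")]

    urllib.urlencode(params, doseq=True)
    # -> 'fq=hello&fq=int_field%3A3&q=%2A%3A%2A'   (one fq parameter per filter)

    # Without doseq=True the list value would be passed through str(), producing a
    # single fq parameter containing the Python list repr, which Solr cannot parse.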
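A second note, sketching (under the names introduced in this diff, with a configured schema assumed) how the same filter now renders both ways: options() returns the list form consumed by urlencode, while unicode()/str() still force full serialization via __unicode_special__(force_serialize=True), so logging and any existing string handling keep working:

    from sunburnt.search import LuceneQuery

    fq = LuceneQuery(schema, u'fq', multiple_tags_allowed=True)  # schema: an existing SolrSchema
    fq.add(["hello"], {"int_field": 3})

    fq.options()   # {u'fq': [u'hello', u'int_field:3']}  - one list entry per fq parameter
    unicode(fq)    # u'hello AND int_field:3'              - unchanged string form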