Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 26 additions & 15 deletions sunburnt/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

from .schema import SolrError, SolrBooleanField, SolrUnicodeField, WildcardFieldInstance


class LuceneQuery(object):
default_term_re = re.compile(r'^\w+$')
def __init__(self, schema, option_flag=None, original=None):
def __init__(self, schema, option_flag=None, original=None, multiple_tags_allowed=False):
self.schema = schema
self.normalized = False
if original is None:
self.option_flag = option_flag
self.multiple_tags_allowed = multiple_tags_allowed
self.terms = collections.defaultdict(set)
self.phrases = collections.defaultdict(set)
self.ranges = set()
Expand All @@ -21,6 +21,7 @@ def __init__(self, schema, option_flag=None, original=None):
self.boosts = []
else:
self.option_flag = original.option_flag
self.multiple_tags_allowed = original.multiple_tags_allowed
self.terms = copy.copy(original.terms)
self.phrases = copy.copy(original.phrases)
self.ranges = copy.copy(original.ranges)
Expand All @@ -36,7 +37,7 @@ def clone(self):

def options(self):
opts = {}
s = unicode(self)
s = self.__unicode_special__()
if s:
opts[self.option_flag] = s
return opts
Expand Down Expand Up @@ -78,7 +79,7 @@ def serialize_term_queries(self, terms):
s += [u'%s:%s' % (name, value.to_query()) for value in value_set]
else:
s += [value.to_query() for value in value_set]
return u' AND '.join(sorted(s))
return sorted(s)

range_query_templates = {
"any": u"[* TO *]",
Expand All @@ -95,7 +96,7 @@ def serialize_range_queries(self):
range_s = self.range_query_templates[rel] % \
tuple(value.to_query() for value in sorted(values, key=lambda x: getattr(x, "value")))
s.append(u"%s:%s" % (name, range_s))
return u' AND '.join(s)
return s

def child_needs_parens(self, child):
if len(child) == 1:
Expand Down Expand Up @@ -178,7 +179,10 @@ def normalize(self):
self.normalized = True
return self, mutated

def __unicode__(self, level=0, op=None):
def __unicode__(self):
return self.__unicode_special__(force_serialize=True)

def __unicode_special__(self, level=0, op=None, force_serialize=False):
if not self.normalized:
self, _ = self.normalize()
if self.boosts:
Expand All @@ -189,20 +193,27 @@ def __unicode__(self, level=0, op=None):
for kwargs, boost_score in self.boosts]
newself = newself | (newself & reduce(operator.or_, boost_queries))
newself, _ = newself.normalize()
return newself.__unicode__(level=level)
return newself.__unicode_special__(level=level, force_serialize=force_serialize)
else:
u = [s for s in [self.serialize_term_queries(self.terms),
alliter = [self.serialize_term_queries(self.terms),
self.serialize_term_queries(self.phrases),
self.serialize_range_queries()]
if s]
u = []
for iterator in alliter:
u.extend(iterator)
for q in self.subqueries:
op_ = u'OR' if self._or else u'AND'
if self.child_needs_parens(q):
u.append(u"(%s)"%q.__unicode__(level=level+1, op=op_))
u.append(u"(%s)"%q.__unicode_special__(level=level+1, op=op_))
else:
u.append(u"%s"%q.__unicode__(level=level+1, op=op_))
u.append(u"%s"%q.__unicode_special__(level=level+1, op=op_))
if self._and:
return u' AND '.join(u)
if (not force_serialize and
level == 0 and
self.multiple_tags_allowed):
return u
else:
return u' AND '.join(u)
elif self._or:
return u' OR '.join(u)
elif self._not:
Expand Down Expand Up @@ -265,7 +276,7 @@ def __pow__(self, value):
q._and = False
q._pow = value
return q

def add(self, args, kwargs):
self.normalized = False
_args = []
Expand Down Expand Up @@ -369,7 +380,7 @@ class BaseSearch(object):

def _init_common_modules(self):
self.query_obj = LuceneQuery(self.schema, u'q')
self.filter_obj = LuceneQuery(self.schema, u'fq')
self.filter_obj = LuceneQuery(self.schema, u'fq', multiple_tags_allowed=True)
self.paginator = PaginateOptions(self.schema)
self.highlighter = HighlightOptions(self.schema)
self.faceter = FacetOptions(self.schema)
Expand Down Expand Up @@ -506,7 +517,7 @@ def params(self):

_count = None
def count(self):
# get the total count for the current query without retrieving any results
# get the total count for the current query without retrieving any results
# cache it, since it may be needed multiple times when used with django paginator
if self._count is None:
# are we already paginated? then we'll behave as if that's
Expand Down
2 changes: 1 addition & 1 deletion sunburnt/sunburnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def url_for_update(self, commit=None, commitWithin=None, softCommit=None, optimi
return self.update_url

def select(self, params):
qs = urllib.urlencode(params)
qs = urllib.urlencode(params, doseq=True)
url = "%s?%s" % (self.select_url, qs)
if len(url) > self.max_length_get_url:
warnings.warn("Long query URL encountered - POSTing instead of "
Expand Down
16 changes: 8 additions & 8 deletions sunburnt/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ class MockInterface(object):
(["hello"], {},
[("fq", u"hello"), ("q", "*:*")]),
(["hello"], {"int_field":3},
[("fq", u"hello AND int_field:3"), ("q", "*:*")]),
[("fq", u"hello"), ("fq", u"int_field:3"), ("q", "*:*")]),
(["hello", "world"], {},
[("fq", u"hello AND world"), ("q", "*:*")]),
[("fq", u"hello"), ("fq", u"world"), ("q", "*:*")]),
# NB this next is not really what we want,
# probably this should warn
(["hello world"], {},
Expand All @@ -115,9 +115,9 @@ class MockInterface(object):
(["hello"], {},
[("fq", u"hello"), ("q", "*:*")]),
(["hello"], {"int_field":3},
[("fq", u"int_field:3 AND hello"), ("q", "*:*")]),
[("fq", u"hello"), ("fq", u"int_field:3"), ("q", "*:*")]),
(["hello", "world"], {},
[("fq", u"hello AND world"), ("q", "*:*")]),
[("fq", u"hello"), ("fq", u"world"), ("q", "*:*")]),
(["hello world"], {},
[("fq", u"hello\\ world"), ("q", "*:*")]),
),
Expand All @@ -126,9 +126,9 @@ class MockInterface(object):
(["hello"], {},
[("fq", u"hello"), ("q", "*:*")]),
(["hello"], {"int_field":3},
[("fq", u"hello AND int_field:3"), ("q", "*:*")]),
[("fq", u"hello"), ("fq", "int_field:3"), ("q", "*:*")]),
(["hello", "world"], {},
[("fq", u"hello AND world"), ("q", "*:*")]),
[("fq", u"hello"), ("fq", u"world"), ("q", "*:*")]),
(["hello world"], {},
[("fq", u"hello\\ world"), ("q", "*:*")]),
),
Expand Down Expand Up @@ -418,7 +418,7 @@ def test_bad_option_data():
(lambda q: q.query("hello world").filter(q.Q(text_field="tow") | q.Q(boolean_field=False, int_field__gt=3)),
[('fq', u'text_field:tow OR (boolean_field:false AND int_field:{3 TO *})'), ('q', u'hello\\ world')]),
(lambda q: q.query("hello world").filter(q.Q(text_field="tow") & q.Q(boolean_field=False, int_field__gt=3)),
[('fq', u'boolean_field:false AND text_field:tow AND int_field:{3 TO *}'), ('q', u'hello\\ world')]),
[('fq', u'boolean_field:false'), ('fq', u'int_field:{3 TO *}'), ('fq', u'text_field:tow'), ('q', u'hello\\ world')]),
# Test various combinations of NOTs at the top level.
# Sometimes we need to do the *:* trick, sometimes not.
(lambda q: q.query(~q.Q("hello world")),
Expand Down Expand Up @@ -492,7 +492,7 @@ def test_bad_option_data():
def check_complex_boolean_query(solr_search, query, output):
p = query(solr_search).params()
try:
assert p == output
assert p == output, "Unequal: %r, %r" % (p, output)
except AssertionError:
if debug:
print p
Expand Down