Skip to content

Commit c72460c

Browse files
Adds optional text word weights to HybridQuery and TextQuery (#410)
This PR adds the ability to individually increase or decrease the score contribution of specific words when performing a scored text matching query.
1 parent 3b1804a commit c72460c

File tree

6 files changed

+386
-23
lines changed

6 files changed

+386
-23
lines changed

redisvl/query/aggregate.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def __init__(
102102
return_fields: Optional[List[str]] = None,
103103
stopwords: Optional[Union[str, Set[str]]] = "english",
104104
dialect: int = 2,
105+
text_weights: Optional[Dict[str, float]] = None,
105106
):
106107
"""
107108
Instantiates a HybridQuery object.
@@ -127,6 +128,9 @@ def __init__(
127128
set, or tuple of strings is provided then those will be used as stopwords.
128129
Defaults to "english". if set to "None" then no stopwords will be removed.
129130
dialect (int, optional): The Redis dialect version. Defaults to 2.
131+
text_weights (Optional[Dict[str, float]]): The importance weighting of individual words
132+
within the query text. Defaults to None, as no modifications will be made to the
133+
text_scorer score.
130134
131135
Raises:
132136
ValueError: If the text string is empty, or if the text string becomes empty after
@@ -146,6 +150,7 @@ def __init__(
146150
self._dtype = dtype
147151
self._num_results = num_results
148152
self._set_stopwords(stopwords)
153+
self._text_weights = self._parse_text_weights(text_weights)
149154

150155
query_string = self._build_query_string()
151156
super().__init__(query_string)
@@ -193,6 +198,7 @@ def _set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
193198
language will be used. if a list, set, or tuple of strings is provided then those
194199
will be used as stopwords. Defaults to "english". if set to "None" then no stopwords
195200
will be removed.
201+
196202
Raises:
197203
TypeError: If the stopwords are not a set, list, or tuple of strings.
198204
"""
@@ -222,6 +228,7 @@ def _tokenize_and_escape_query(self, user_query: str) -> str:
222228
223229
Returns:
224230
str: The tokenized and escaped query string.
231+
225232
Raises:
226233
ValueError: If the text string becomes empty after stopwords are removed.
227234
"""
@@ -233,13 +240,57 @@ def _tokenize_and_escape_query(self, user_query: str) -> str:
233240
)
234241
for token in user_query.split()
235242
]
236-
tokenized = " | ".join(
237-
[token for token in tokens if token and token not in self._stopwords]
238-
)
239243

240-
if not tokenized:
244+
token_list = [
245+
token for token in tokens if token and token not in self._stopwords
246+
]
247+
for i, token in enumerate(token_list):
248+
if token in self._text_weights:
249+
token_list[i] = f"{token}=>{{$weight:{self._text_weights[token]}}}"
250+
251+
if not token_list:
241252
raise ValueError("text string cannot be empty after removing stopwords")
242-
return tokenized
253+
return " | ".join(token_list)
254+
255+
def _parse_text_weights(
256+
self, weights: Optional[Dict[str, float]]
257+
) -> Dict[str, float]:
258+
parsed_weights: Dict[str, float] = {}
259+
if not weights:
260+
return parsed_weights
261+
for word, weight in weights.items():
262+
word = word.strip().lower()
263+
if not word or " " in word:
264+
raise ValueError(
265+
f"Only individual words may be weighted. Got {{ {word}:{weight} }}"
266+
)
267+
if (
268+
not (isinstance(weight, float) or isinstance(weight, int))
269+
or weight < 0.0
270+
):
271+
raise ValueError(
272+
f"Weights must be positive number. Got {{ {word}:{weight} }}"
273+
)
274+
parsed_weights[word] = weight
275+
return parsed_weights
276+
277+
def set_text_weights(self, weights: Dict[str, float]):
278+
"""Set or update the text weights for the query.
279+
280+
Args:
281+
text_weights: Dictionary of word:weight mappings
282+
"""
283+
self._text_weights = self._parse_text_weights(weights)
284+
self._query = self._build_query_string()
285+
286+
@property
287+
def text_weights(self) -> Dict[str, float]:
288+
"""Get the text weights.
289+
290+
Returns:
291+
Dictionary of word:weight mappings.
292+
"""
293+
return self._text_weights
243294

244295
def _build_query_string(self) -> str:
245296
"""Build the full query string for text search with optional filtering."""

redisvl/query/query.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,7 @@ def __init__(
10281028
in_order: bool = False,
10291029
params: Optional[Dict[str, Any]] = None,
10301030
stopwords: Optional[Union[str, Set[str]]] = "english",
1031+
text_weights: Optional[Dict[str, float]] = None,
10311032
):
10321033
"""A query for running a full text search, along with an optional filter expression.
10331034
@@ -1064,13 +1065,16 @@ def __init__(
10641065
a default set of stopwords for that language will be used. Users may specify
10651066
their own stop words by providing a List or Set of words. if set to None,
10661067
then no words will be removed. Defaults to 'english'.
1067-
1068+
text_weights (Optional[Dict[str, float]]): The importance weighting of individual words
1069+
within the query text. Defaults to None, as no modifications will be made to the
1070+
text_scorer score.
10681071
Raises:
10691072
ValueError: if stopwords language string cannot be loaded.
10701073
TypeError: If stopwords is not a valid iterable set of strings.
10711074
"""
10721075
self._text = text
10731076
self._field_weights = self._parse_field_weights(text_field_name)
1077+
self._text_weights = self._parse_text_weights(text_weights)
10741078
self._num_results = num_results
10751079

10761080
self._set_stopwords(stopwords)
@@ -1151,9 +1155,14 @@ def _tokenize_and_escape_query(self, user_query: str) -> str:
11511155
)
11521156
for token in user_query.split()
11531157
]
1154-
return " | ".join(
1155-
[token for token in tokens if token and token not in self._stopwords]
1156-
)
1158+
token_list = [
1159+
token for token in tokens if token and token not in self._stopwords
1160+
]
1161+
for i, token in enumerate(token_list):
1162+
if token in self._text_weights:
1163+
token_list[i] = f"{token}=>{{$weight:{self._text_weights[token]}}}"
1164+
1165+
return " | ".join(token_list)
11571166

11581167
def _parse_field_weights(
11591168
self, field_spec: Union[str, Dict[str, float]]
@@ -1220,6 +1229,46 @@ def text_field_name(self) -> Union[str, Dict[str, float]]:
12201229
return field
12211230
return self._field_weights.copy()
12221231

1232+
def _parse_text_weights(
1233+
self, weights: Optional[Dict[str, float]]
1234+
) -> Dict[str, float]:
1235+
parsed_weights: Dict[str, float] = {}
1236+
if not weights:
1237+
return parsed_weights
1238+
for word, weight in weights.items():
1239+
word = word.strip().lower()
1240+
if not word or " " in word:
1241+
raise ValueError(
1242+
f"Only individual words may be weighted. Got {{ {word}:{weight} }}"
1243+
)
1244+
if (
1245+
not (isinstance(weight, float) or isinstance(weight, int))
1246+
or weight < 0.0
1247+
):
1248+
raise ValueError(
1249+
f"Weights must be positive number. Got {{ {word}:{weight} }}"
1250+
)
1251+
parsed_weights[word] = weight
1252+
return parsed_weights
1253+
1254+
def set_text_weights(self, weights: Dict[str, float]):
1255+
"""Set or update the text weights for the query.
1256+
1257+
Args:
1258+
text_weights: Dictionary of word:weight mappings
1259+
"""
1260+
self._text_weights = self._parse_text_weights(weights)
1261+
self._built_query_string = None
1262+
1263+
@property
1264+
def text_weights(self) -> Dict[str, float]:
1265+
"""Get the text weights.
1266+
1267+
Returns:
1268+
Dictionary of word:weight mappings.
1269+
"""
1270+
return self._text_weights
1271+
12231272
def _build_query_string(self) -> str:
12241273
"""Build the full query string for text search with optional filtering."""
12251274
filter_expression = self._filter_expression

tests/integration/test_aggregation.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,82 @@ def test_hybrid_query_with_text_filter(index):
317317
assert "research" not in result[text_field].lower()
318318

319319

320+
@pytest.mark.parametrize("scorer", ["BM25", "BM25STD", "TFIDF", "TFIDF.DOCNORM"])
321+
def test_hybrid_query_word_weights(index, scorer):
322+
skip_if_redis_version_below(index.client, "7.2.0")
323+
324+
text = "a medical professional with expertise in lung cancers"
325+
text_field = "description"
326+
vector = [0.1, 0.1, 0.5]
327+
vector_field = "user_embedding"
328+
return_fields = ["description"]
329+
330+
weights = {"medical": 3.4, "cancers": 5}
331+
332+
# test we can run a query with text weights
333+
weighted_query = HybridQuery(
334+
text=text,
335+
text_field_name=text_field,
336+
vector=vector,
337+
vector_field_name=vector_field,
338+
return_fields=return_fields,
339+
text_scorer=scorer,
340+
text_weights=weights,
341+
)
342+
343+
weighted_results = index.query(weighted_query)
344+
assert len(weighted_results) == 7
345+
346+
# test that weights do change the scores on results
347+
unweighted_query = HybridQuery(
348+
text=text,
349+
text_field_name=text_field,
350+
vector=vector,
351+
vector_field_name=vector_field,
352+
return_fields=return_fields,
353+
text_scorer=scorer,
354+
text_weights={},
355+
)
356+
357+
unweighted_results = index.query(unweighted_query)
358+
359+
for weighted, unweighted in zip(weighted_results, unweighted_results):
360+
for word in weights:
361+
if word in weighted["description"] or word in unweighted["description"]:
362+
assert float(weighted["text_score"]) > float(unweighted["text_score"])
363+
364+
# test that weights do change the document score and order of results
365+
weights = {"medical": 5, "cancers": 3.4} # switch the weights
366+
weighted_query = HybridQuery(
367+
text=text,
368+
text_field_name=text_field,
369+
vector=vector,
370+
vector_field_name=vector_field,
371+
return_fields=return_fields,
372+
text_scorer=scorer,
373+
text_weights=weights,
374+
)
375+
376+
weighted_results = index.query(weighted_query)
377+
assert weighted_results != unweighted_results
378+
379+
# test assigning weights on construction is equivalent to setting them on the query object
380+
new_query = HybridQuery(
381+
text=text,
382+
text_field_name=text_field,
383+
vector=vector,
384+
vector_field_name=vector_field,
385+
return_fields=return_fields,
386+
text_scorer=scorer,
387+
text_weights=None,
388+
)
389+
390+
new_query.set_text_weights(weights)
391+
392+
new_weighted_results = index.query(new_query)
393+
assert new_weighted_results == weighted_results
394+
395+
320396
def test_multivector_query(index):
321397
skip_if_redis_version_below(index.client, "7.2.0")
322398

tests/integration/test_query.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,72 @@ def test_text_query_with_text_filter(index):
888888
assert "research" not in result[text_field]
889889

890890

891+
@pytest.mark.parametrize("scorer", ["BM25", "BM25STD", "TFIDF", "TFIDF.DOCNORM"])
892+
def test_text_query_word_weights(index, scorer):
893+
skip_if_redis_version_below(index.client, "7.2.0")
894+
895+
text = "a medical professional with expertise in lung cancers"
896+
text_field = "description"
897+
return_fields = ["description"]
898+
899+
weights = {"medical": 3.4, "cancers": 5}
900+
901+
# test we can run a query with text weights
902+
weighted_query = TextQuery(
903+
text=text,
904+
text_field_name=text_field,
905+
return_fields=return_fields,
906+
text_scorer=scorer,
907+
text_weights=weights,
908+
)
909+
910+
weighted_results = index.query(weighted_query)
911+
assert len(weighted_results) == 4
912+
913+
# test that weights do change the scores on results
914+
unweighted_query = TextQuery(
915+
text=text,
916+
text_field_name=text_field,
917+
return_fields=return_fields,
918+
text_scorer=scorer,
919+
text_weights={},
920+
)
921+
922+
unweighted_results = index.query(unweighted_query)
923+
924+
for weighted, unweighted in zip(weighted_results, unweighted_results):
925+
for word in weights:
926+
if word in weighted["description"] or word in unweighted["description"]:
927+
assert weighted["score"] > unweighted["score"]
928+
929+
# test that weights do change the document score and order of results
930+
weights = {"medical": 5, "cancers": 3.4} # switch the weights
931+
weighted_query = TextQuery(
932+
text=text,
933+
text_field_name=text_field,
934+
return_fields=return_fields,
935+
text_scorer=scorer,
936+
text_weights=weights,
937+
)
938+
939+
weighted_results = index.query(weighted_query)
940+
assert weighted_results != unweighted_results
941+
942+
# test assigning weights on construction is equivalent to setting them on the query object
943+
new_query = TextQuery(
944+
text=text,
945+
text_field_name=text_field,
946+
return_fields=return_fields,
947+
text_scorer=scorer,
948+
text_weights=None,
949+
)
950+
951+
new_query.set_text_weights(weights)
952+
953+
new_weighted_results = index.query(new_query)
954+
assert new_weighted_results == weighted_results
955+
956+
891957
def test_vector_query_with_ef_runtime(index, vector_query, sample_data):
892958
"""
893959
Integration test: Verify that setting EF_RUNTIME on a VectorQuery works correctly.

0 commit comments

Comments
 (0)