@@ -317,6 +317,80 @@ def test_hybrid_query_with_text_filter(index):
317317 assert "research" not in result [text_field ].lower ()
318318
319319
320+ @pytest .mark .parametrize ("scorer" , ["BM25" , "BM25STD" , "TFIDF" , "TFIDF.DOCNORM" ])
321+ def test_hybrid_query_word_weights (index , scorer ):
322+ text = "a medical professional with expertise in lung cancers"
323+ text_field = "description"
324+ vector = [0.1 , 0.1 , 0.5 ]
325+ vector_field = "user_embedding"
326+ return_fields = ["description" ]
327+
328+ weights = {"medical" : 3.4 , "cancers" : 5 }
329+
330+ # test we can run a query with text weights
331+ weighted_query = HybridQuery (
332+ text = text ,
333+ text_field_name = text_field ,
334+ vector = vector ,
335+ vector_field_name = vector_field ,
336+ return_fields = return_fields ,
337+ text_scorer = scorer ,
338+ text_weights = weights ,
339+ )
340+
341+ weighted_results = index .query (weighted_query )
342+ assert len (weighted_results ) == 7
343+
344+ # test that weights do change the scores on results
345+ unweighted_query = HybridQuery (
346+ text = text ,
347+ text_field_name = text_field ,
348+ vector = vector ,
349+ vector_field_name = vector_field ,
350+ return_fields = return_fields ,
351+ text_scorer = scorer ,
352+ text_weights = {},
353+ )
354+
355+ unweighted_results = index .query (unweighted_query )
356+
357+ for weighted , unweighted in zip (weighted_results , unweighted_results ):
358+ for word in weights :
359+ if word in weighted ["description" ] or word in unweighted ["description" ]:
360+ assert float (weighted ["text_score" ]) > float (unweighted ["text_score" ])
361+
362+ # test that weights do change the document score and order of results
363+ weights = {"medical" : 5 , "cancers" : 3.4 } # switch the weights
364+ weighted_query = HybridQuery (
365+ text = text ,
366+ text_field_name = text_field ,
367+ vector = vector ,
368+ vector_field_name = vector_field ,
369+ return_fields = return_fields ,
370+ text_scorer = scorer ,
371+ text_weights = weights ,
372+ )
373+
374+ weighted_results = index .query (weighted_query )
375+ assert weighted_results != unweighted_results
376+
377+ # test assigning weights on construction is equivalent to setting them on the query object
378+ new_query = HybridQuery (
379+ text = text ,
380+ text_field_name = text_field ,
381+ vector = vector ,
382+ vector_field_name = vector_field ,
383+ return_fields = return_fields ,
384+ text_scorer = scorer ,
385+ text_weights = None ,
386+ )
387+
388+ new_query .set_text_weights (weights )
389+
390+ new_weighted_results = index .query (new_query )
391+ assert new_weighted_results == weighted_results
392+
393+
320394def test_multivector_query (index ):
321395 skip_if_redis_version_below (index .client , "7.2.0" )
322396
0 commit comments