|  | 
| 24 | 24 | from .aggs import A, AggBase | 
| 25 | 25 | from .connections import get_connection | 
| 26 | 26 | from .exceptions import IllegalOperation | 
| 27 |  | -from .query import Bool, Q | 
|  | 27 | +from .query import Bool, Q, Query | 
| 28 | 28 | from .response import Hit, Response | 
| 29 | 29 | from .utils import AttrDict, DslBase, recursive_to_dict | 
| 30 | 30 | 
 | 
| @@ -319,6 +319,7 @@ def __init__(self, **kwargs): | 
| 319 | 319 |         self.aggs = AggsProxy(self) | 
| 320 | 320 |         self._sort = [] | 
| 321 | 321 |         self._collapse = {} | 
|  | 322 | +        self._knn = [] | 
| 322 | 323 |         self._source = None | 
| 323 | 324 |         self._highlight = {} | 
| 324 | 325 |         self._highlight_opts = {} | 
| @@ -407,6 +408,7 @@ def _clone(self): | 
| 407 | 408 | 
 | 
| 408 | 409 |         s._response_class = self._response_class | 
| 409 | 410 |         s._collapse = self._collapse.copy() | 
|  | 411 | +        s._knn = [knn.copy() for knn in self._knn] | 
| 410 | 412 |         s._sort = self._sort[:] | 
| 411 | 413 |         s._source = copy.copy(self._source) if self._source is not None else None | 
| 412 | 414 |         s._highlight = self._highlight.copy() | 
| @@ -445,6 +447,10 @@ def update_from_dict(self, d): | 
| 445 | 447 |             self.aggs._params = { | 
| 446 | 448 |                 "aggs": {name: A(value) for (name, value) in aggs.items()} | 
| 447 | 449 |             } | 
|  | 450 | +        if "knn" in d: | 
|  | 451 | +            self._knn = d.pop("knn") | 
|  | 452 | +            if isinstance(self._knn, dict): | 
|  | 453 | +                self._knn = [self._knn] | 
| 448 | 454 |         if "collapse" in d: | 
| 449 | 455 |             self._collapse = d.pop("collapse") | 
| 450 | 456 |         if "sort" in d: | 
| @@ -494,6 +500,60 @@ def script_fields(self, **kwargs): | 
| 494 | 500 |         s._script_fields.update(kwargs) | 
| 495 | 501 |         return s | 
| 496 | 502 | 
 | 
|  | 503 | +    def knn( | 
|  | 504 | +        self, | 
|  | 505 | +        field, | 
|  | 506 | +        k, | 
|  | 507 | +        num_candidates, | 
|  | 508 | +        query_vector=None, | 
|  | 509 | +        query_vector_builder=None, | 
|  | 510 | +        filter=None, | 
|  | 511 | +        similarity=None, | 
|  | 512 | +    ): | 
|  | 513 | +        """ | 
|  | 514 | +        Add a k-nearest neighbor (kNN) search. | 
|  | 515 | +
 | 
|  | 516 | +        :arg field: the name of the vector field to search against | 
|  | 517 | +        :arg k: number of nearest neighbors to return as top hits | 
|  | 518 | +        :arg num_candidates: number of nearest neighbor candidates to consider per shard | 
|  | 519 | +        :arg query_vector: the vector to search for | 
|  | 520 | +        :arg query_vector_builder: A dictionary indicating how to build a query vector | 
|  | 521 | +        :arg filter: query to filter the documents that can match | 
|  | 522 | +        :arg similarity: the minimum similarity required for a document to be considered a match, as a float value | 
|  | 523 | +
 | 
|  | 524 | +        Example:: | 
|  | 525 | +
 | 
|  | 526 | +            s = Search() | 
|  | 527 | +            s = s.knn(field='embedding', k=5, num_candidates=10, query_vector=vector, | 
|  | 528 | +                      filter=Q('term', category='blog'))) | 
|  | 529 | +        """ | 
|  | 530 | +        s = self._clone() | 
|  | 531 | +        s._knn.append( | 
|  | 532 | +            { | 
|  | 533 | +                "field": field, | 
|  | 534 | +                "k": k, | 
|  | 535 | +                "num_candidates": num_candidates, | 
|  | 536 | +            } | 
|  | 537 | +        ) | 
|  | 538 | +        if query_vector is None and query_vector_builder is None: | 
|  | 539 | +            raise ValueError("one of query_vector and query_vector_builder is required") | 
|  | 540 | +        if query_vector is not None and query_vector_builder is not None: | 
|  | 541 | +            raise ValueError( | 
|  | 542 | +                "only one of query_vector and query_vector_builder must be given" | 
|  | 543 | +            ) | 
|  | 544 | +        if query_vector is not None: | 
|  | 545 | +            s._knn[-1]["query_vector"] = query_vector | 
|  | 546 | +        if query_vector_builder is not None: | 
|  | 547 | +            s._knn[-1]["query_vector_builder"] = query_vector_builder | 
|  | 548 | +        if filter is not None: | 
|  | 549 | +            if isinstance(filter, Query): | 
|  | 550 | +                s._knn[-1]["filter"] = filter.to_dict() | 
|  | 551 | +            else: | 
|  | 552 | +                s._knn[-1]["filter"] = filter | 
|  | 553 | +        if similarity is not None: | 
|  | 554 | +            s._knn[-1]["similarity"] = similarity | 
|  | 555 | +        return s | 
|  | 556 | + | 
| 497 | 557 |     def source(self, fields=None, **kwargs): | 
| 498 | 558 |         """ | 
| 499 | 559 |         Selectively control how the _source field is returned. | 
| @@ -677,6 +737,12 @@ def to_dict(self, count=False, **kwargs): | 
| 677 | 737 |         if self.query: | 
| 678 | 738 |             d["query"] = self.query.to_dict() | 
| 679 | 739 | 
 | 
|  | 740 | +        if self._knn: | 
|  | 741 | +            if len(self._knn) == 1: | 
|  | 742 | +                d["knn"] = self._knn[0] | 
|  | 743 | +            else: | 
|  | 744 | +                d["knn"] = self._knn | 
|  | 745 | + | 
| 680 | 746 |         # count request doesn't care for sorting and other things | 
| 681 | 747 |         if not count: | 
| 682 | 748 |             if self.post_filter: | 
|  | 
0 commit comments