Skip to content

Commit 1e8925a

Browse files
authored
Add rerank (#159)
## Problem Add rerank. ## Solution 1. Added the following two methods for rerank: a. Method that accepts required parameters only and the optional parameters are set to default. More details on the api and its default parameters can be found [here](https://docs.pinecone.io/reference/api/2024-10/inference/rerank). ```java rerank(String model, String query, List<Map<String, String>> documents) ``` b. Method that accepts all parameters: ```java rerank(String model, String query, List<Map<String, String>> documents, List<String> rankFields, int topN, boolean returnDocuments, Map<String, Object> parameters) ``` 2. Added integration tests for both rerank methods. 3. Added an example in the README file. ## Type of Change - [ ] Bug fix (non-breaking change which fixes an issue) - [X] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] This change requires a documentation update - [ ] Infrastructure change (CI configs, etc) - [ ] Non-code change (docs, etc) - [ ] None of the above: (explain here) ## Test Plan Added integration test.
1 parent 0b3defd commit 1e8925a

File tree

3 files changed

+217
-3
lines changed

3 files changed

+217
-3
lines changed

README.md

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ UpdateResponse updateResponse = index.update("v1", values, "example-namespace");
487487

488488
# Collections
489489

490-
Collections fall under data plane operations.
490+
Collections fall under control plane operations.
491491

492492
## Create collection
493493

@@ -545,8 +545,8 @@ Pinecone pinecone = new Pinecone.Builder("PINECONE_API_KEY").build();
545545
pinecone.deleteCollection("example-collection");
546546
```
547547

548-
## Inference
549-
548+
# Inference
549+
## Embed
550550
The Pinecone SDK now supports creating embeddings via the [Inference API](https://docs.pinecone.io/guides/inference/understanding-inference).
551551

552552
```java
@@ -583,6 +583,66 @@ EmbeddingsList embeddings = inference.embed(embeddingModel, parameters, inputs);
583583
List<Embedding> embeddedData = embeddings.getData();
584584
```
585585

586+
## Rerank
587+
The following example shows how to rerank items according to their relevance to a query.
588+
589+
```java
590+
import io.pinecone.clients.Inference;
591+
import io.pinecone.clients.Pinecone;
592+
import org.openapitools.inference.client.model.RerankResult;
593+
594+
import java.util.*;
595+
596+
...
597+
598+
// The model to use for reranking
599+
String model = "bge-reranker-v2-m3";
600+
601+
// The query to rerank documents against
602+
String query = "The tech company Apple is known for its innovative products like the iPhone.";
603+
604+
// Add the documents to rerank
605+
List<Map<String, String>> documents = new ArrayList<>();
606+
Map<String, String> doc1 = new HashMap<>();
607+
doc1.put("id", "vec1");
608+
doc1.put("my_field", "Apple is a popular fruit known for its sweetness and crisp texture.");
609+
documents.add(doc1);
610+
611+
Map<String, String> doc2 = new HashMap<>();
612+
doc2.put("id", "vec2");
613+
doc2.put("my_field", "Many people enjoy eating apples as a healthy snack.");
614+
documents.add(doc2);
615+
616+
Map<String, String> doc3 = new HashMap<>();
617+
doc3.put("id", "vec3");
618+
doc3.put("my_field", "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces.");
619+
documents.add(doc3);
620+
621+
Map<String, String> doc4 = new HashMap<>();
622+
doc4.put("id", "vec4");
623+
doc4.put("my_field", "An apple a day keeps the doctor away, as the saying goes.");
624+
documents.add(doc4);
625+
626+
// The fields to rank the documents by. If not provided, the default is "text"
627+
List<String> rankFields = Arrays.asList("my_field");
628+
629+
// The number of results to return sorted by relevance. Defaults to the number of inputs
630+
int topN = 2;
631+
632+
// Whether to return the documents in the response
633+
boolean returnDocuments = true;
634+
635+
// Additional model-specific parameters for the reranker
636+
Map<String, Object> parameters = new HashMap<>();
637+
parameters.put("truncate", "END");
638+
639+
// Send ranking request
640+
RerankResult result = inference.rerank(model, query, documents, rankFields, topN, returnDocuments, parameters);
641+
642+
// Get ranked data
643+
System.out.println(result.getData());
644+
```
645+
586646
## Examples
587647

588648
- The data and control plane operation examples can be found in `io/pinecone/integration` folder.
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package io.pinecone.integration.inference;
2+
3+
import io.pinecone.clients.Inference;
4+
import io.pinecone.clients.Pinecone;
5+
import org.junit.jupiter.api.Assertions;
6+
import org.junit.jupiter.api.Test;
7+
import org.openapitools.inference.client.ApiException;
8+
import org.openapitools.inference.client.model.RerankResult;
9+
10+
import java.util.*;
11+
12+
import static org.junit.jupiter.api.Assertions.assertNotNull;
13+
14+
public class RerankTest {
15+
private static final Pinecone pinecone = new Pinecone
16+
.Builder(System.getenv("PINECONE_API_KEY"))
17+
.withSourceTag("pinecone_test")
18+
.build();
19+
private static final Inference inference = pinecone.getInferenceClient();
20+
21+
@Test
22+
public void testRerank() throws ApiException {
23+
String model = "bge-reranker-v2-m3";
24+
String query = "The tech company Apple is known for its innovative products like the iPhone.";
25+
List<Map<String, String>> documents = new ArrayList<>();
26+
27+
Map<String, String> doc1 = new HashMap<>();
28+
doc1.put("id", "vec1");
29+
doc1.put("my_field", "Apple is a popular fruit known for its sweetness and crisp texture.");
30+
documents.add(doc1);
31+
32+
Map<String, String> doc2 = new HashMap<>();
33+
doc2.put("id", "vec2");
34+
doc2.put("my_field", "Many people enjoy eating apples as a healthy snack.");
35+
documents.add(doc2);
36+
37+
Map<String, String> doc3 = new HashMap<>();
38+
doc3.put("id", "vec3");
39+
doc3.put("my_field", "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces.");
40+
documents.add(doc3);
41+
42+
Map<String, String> doc4 = new HashMap<>();
43+
doc4.put("id", "vec4");
44+
doc4.put("my_field", "An apple a day keeps the doctor away, as the saying goes.");
45+
documents.add(doc4);
46+
47+
List<String> rankFields = Arrays.asList("my_field");
48+
int topN = 2;
49+
boolean returnDocuments = true;
50+
Map<String, Object> parameters = new HashMap<>();
51+
parameters.put("truncate", "END");
52+
53+
RerankResult result = inference.rerank(model, query, documents, rankFields, topN, returnDocuments, parameters);
54+
55+
assertNotNull(result);
56+
Assertions.assertEquals(result.getData().size(), topN);
57+
Assertions.assertEquals(result.getData().get(0).getIndex(), 2);
58+
Assertions.assertEquals(result.getData().get(0).getDocument().get("my_field"), doc3.get("my_field"));
59+
Assertions.assertEquals(result.getData().size(), 2);
60+
}
61+
62+
@Test
63+
public void testRerankWithRequiredParameters() throws ApiException {
64+
String model = "bge-reranker-v2-m3";
65+
String query = "The tech company Apple is known for its innovative products like the iPhone.";
66+
List<Map<String, String>> documents = new ArrayList<>();
67+
68+
Map<String, String> doc1 = new HashMap<>();
69+
doc1.put("id", "vec1");
70+
doc1.put("text", "Apple is a popular fruit known for its sweetness and crisp texture.");
71+
documents.add(doc1);
72+
73+
Map<String, String> doc2 = new HashMap<>();
74+
doc2.put("id", "vec2");
75+
doc2.put("text", "Many people enjoy eating apples as a healthy snack.");
76+
documents.add(doc2);
77+
78+
Map<String, String> doc3 = new HashMap<>();
79+
doc3.put("id", "vec3");
80+
doc3.put("text", "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces.");
81+
documents.add(doc3);
82+
83+
Map<String, String> doc4 = new HashMap<>();
84+
doc4.put("id", "vec4");
85+
doc4.put("text", "An apple a day keeps the doctor away, as the saying goes.");
86+
documents.add(doc4);
87+
88+
RerankResult result = inference.rerank(model, query, documents);
89+
90+
assertNotNull(result);
91+
Assertions.assertEquals(result.getData().size(), documents.size());
92+
Assertions.assertEquals(result.getData().get(0).getIndex(), 2);
93+
Assertions.assertEquals(result.getData().get(0).getDocument().get("text"), doc3.get("text"));
94+
}
95+
}

src/main/java/io/pinecone/clients/Inference.java

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,65 @@ public EmbeddingsList embed(String model, Map<String, Object> parameters, List<S
6060
return inferenceApi.embed(embedRequest);
6161
}
6262

63+
/**
64+
* Reranks a list of documents based on the relevance to a query using the specified model. Since rest of the
65+
* parameters are optional, they are set to their default values.
66+
*
67+
* @param model The model to be used for reranking the documents.
68+
* @param query The query string to rank the documents against.
69+
* @param documents A list of maps representing the documents to be ranked.
70+
* Each map should contain document attributes, such as "text".
71+
*
72+
* @return RerankResult containing the ranked documents and their scores.
73+
* @throws ApiException If the API call fails, an ApiException is thrown.
74+
*/
75+
public RerankResult rerank(String model,
76+
String query,
77+
List<Map<String, String>> documents) throws ApiException {
78+
return rerank(model,
79+
query,
80+
documents,
81+
Arrays.asList("text"),
82+
documents.size(),
83+
true,
84+
new HashMap<>());
85+
}
86+
87+
/**
88+
* Reranks a list of documents based on the relevance to a query using the specified model with additional options.
89+
*
90+
* @param model The model to be used for reranking the documents.
91+
* @param query The query string to rank the documents against.
92+
* @param documents A list of maps representing the documents to be ranked.
93+
* Each map should contain document attributes, such as "text".
94+
* @param rankFields A list of fields in the documents to be used for ranking, typically "text".
95+
* @param topN The number of top-ranked documents to return.
96+
* @param returnDocuments Whether to return the documents along with the ranking scores.
97+
* @param parameters A map containing additional model-specific parameters for reranking.
98+
* @return RerankResult containing the ranked documents and their scores.
99+
* @throws ApiException If the API call fails, an ApiException is thrown.
100+
*/
101+
public RerankResult rerank(String model,
102+
String query,
103+
List<Map<String, String>> documents,
104+
List<String> rankFields,
105+
int topN,
106+
boolean returnDocuments,
107+
Map<String, Object> parameters) throws ApiException {
108+
RerankRequest rerankRequest = new RerankRequest();
109+
110+
rerankRequest
111+
.model(model)
112+
.query(query)
113+
.documents(documents)
114+
.rankFields(rankFields)
115+
.topN(topN)
116+
.returnDocuments(returnDocuments)
117+
.putAdditionalProperty("parameters", parameters);
118+
119+
return inferenceApi.rerank(rerankRequest);
120+
}
121+
63122
/**
64123
* Converts a list of input strings to EmbedRequestInputsInner objects.
65124
*

0 commit comments

Comments
 (0)