Skip to content

Commit a991b33

Browse files
committed
Optimizes reducelanes in diversityCalculation of PQVectors, for Euclidean function
1 parent 6639992 commit a991b33

File tree

5 files changed

+185
-8
lines changed

5 files changed

+185
-8
lines changed

jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -337,13 +337,7 @@ public ScoreFunction.ApproximateScoreFunction diversityFunctionFor(int node1, Ve
337337
var node2Chunk = getChunk(node2);
338338
var node2Offset = getOffsetInChunk(node2);
339339
// compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points
340-
float sum = 0;
341-
for (int m = 0; m < subspaceCount; m++) {
342-
int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
343-
int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
344-
int centroidLength = pq.subvectorSizesAndOffsets[m][0];
345-
sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex1 * centroidLength, pq.codebooks[m], centroidIndex2 * centroidLength, centroidLength);
346-
}
340+
float sum = VectorUtil.pqDiversityEuclidean(pq.codebooks, pq.subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
347341
// scale to [0, 1]
348342
return 1 / (1 + sum);
349343
};

jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,4 +554,18 @@ public float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValu
554554
return squaredSum;
555555
}
556556

557+
@Override
558+
public float pqDiversityEuclidean(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
559+
float sum = 0;
560+
for (int m = 0; m < subspaceCount; m++) {
561+
int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
562+
int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
563+
int centroidLength = subvectorSizesAndOffsets[m][0];
564+
565+
sum += squareDistance(codebooks[m], centroidIndex1 * centroidLength, codebooks[m], centroidIndex2 * centroidLength, centroidLength);
566+
}
567+
return sum;
568+
569+
}
570+
557571
}

jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,5 +250,8 @@ public static float nvqLoss(VectorFloat<?> vector, float growthRate, float midpo
250250
public static float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValue, int nBits) {
251251
return impl.nvqUniformLoss(vector, minValue, maxValue, nBits);
252252
}
253+
public static float pqDiversityEuclidean(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
254+
return impl.pqDiversityEuclidean(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
255+
}
253256

254257
}

jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,5 +320,5 @@ default float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int encodedOffs
320320
* @param nBits the number of bits per dimension
321321
*/
322322
float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValue, int nBits);
323-
323+
/**
 * Calculates the distance between two PQ-encoded nodes, summed over all subspaces, using the
 * codebook centroids that the nodes are encoded to.
 * @param codebooks one codebook per subspace
 * @param subvectorSizesAndOffsets per-subspace table; entry {@code [m][0]} is used as the centroid length of subspace {@code m}
 * @param node1Chunk the byte sequence containing the encoding of the first node
 * @param node1Offset the index in {@code node1Chunk} at which the first node's codes begin
 * @param node2Chunk the byte sequence containing the encoding of the second node
 * @param node2Offset the index in {@code node2Chunk} at which the second node's codes begin
 * @param subspaceCount the number of subspaces to accumulate over
 * @return the accumulated distance
 */
float pqDiversityEuclidean(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount);
324324
}

jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1369,5 +1369,171 @@ public void calculatePartialSums(VectorFloat<?> codebook, int codebookIndex, int
13691369
public float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude) {
13701370
return pqDecodedCosineSimilarity(encoded, 0, encoded.length(), clusterCount, partialSums, aMagnitude, bMagnitude);
13711371
}
1372+
1373+
float pqDiversityEuclidean_64(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
1374+
float res = 0;
1375+
FloatVector sum = FloatVector.zero(FloatVector.SPECIES_64);
1376+
for (int m = 0; m < subspaceCount; m++) {
1377+
int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
1378+
int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
1379+
int centroidLength = subvectorSizesAndOffsets[m][0];
1380+
final int vectorizedLength = FloatVector.SPECIES_64.loopBound(centroidLength);
1381+
int length1 = centroidIndex1 * centroidLength;
1382+
int length2 = centroidIndex2 * centroidLength;
1383+
if (centroidLength == FloatVector.SPECIES_64.length()) {
1384+
FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1);
1385+
FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length2);
1386+
var diff = a.sub(b);
1387+
sum = diff.mul(diff).add(sum);
1388+
}
1389+
else {
1390+
int i = 0;
1391+
for (; i < vectorizedLength; i += FloatVector.SPECIES_64.length()) {
1392+
FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1 + i);
1393+
FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length2 + i);
1394+
var diff = a.sub(b);
1395+
sum = diff.mul(diff).add(sum);
1396+
}
1397+
// Process the tail
1398+
1399+
for (; i < centroidLength ; ++i) {
1400+
var diff = codebooks[m].get(length1 + i) - codebooks[m].get(length2 + i);
1401+
res += diff * diff;
1402+
}
1403+
}
1404+
}
1405+
res += sum.reduceLanes(VectorOperators.ADD);
1406+
return res;
1407+
}
1408+
1409+
float pqDiversityEuclidean_128(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
1410+
float res = 0;
1411+
FloatVector sum = FloatVector.zero(FloatVector.SPECIES_128);
1412+
for (int m = 0; m < subspaceCount; m++) {
1413+
int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
1414+
int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
1415+
int centroidLength = subvectorSizesAndOffsets[m][0];
1416+
final int vectorizedLength = FloatVector.SPECIES_128.loopBound(centroidLength);
1417+
int length1 = centroidIndex1 * centroidLength;
1418+
int length2 = centroidIndex2 * centroidLength;
1419+
if (centroidLength == FloatVector.SPECIES_128.length()) {
1420+
FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1);
1421+
FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length2);
1422+
var diff = a.sub(b);
1423+
sum = diff.mul(diff).add(sum);
1424+
}
1425+
else {
1426+
int i = 0;
1427+
for (; i < vectorizedLength; i += FloatVector.SPECIES_128.length()) {
1428+
FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1 + i);
1429+
FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length2 + i);
1430+
var diff = a.sub(b);
1431+
sum = diff.mul(diff).add(sum);
1432+
}
1433+
// Process the tail
1434+
1435+
for (; i < centroidLength ; ++i) {
1436+
var diff = codebooks[m].get(length1 + i) - codebooks[m].get(length2 + i);
1437+
res += diff * diff;
1438+
}
1439+
}
1440+
}
1441+
res += sum.reduceLanes(VectorOperators.ADD);
1442+
return res;
1443+
}
1444+
1445+
float pqDiversityEuclidean_256(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
1446+
float res = 0;
1447+
FloatVector sum = FloatVector.zero(FloatVector.SPECIES_256);
1448+
for (int m = 0; m < subspaceCount; m++) {
1449+
int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
1450+
int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
1451+
int centroidLength = subvectorSizesAndOffsets[m][0];
1452+
final int vectorizedLength = FloatVector.SPECIES_256.loopBound(centroidLength);
1453+
int length1 = centroidIndex1 * centroidLength;
1454+
int length2 = centroidIndex2 * centroidLength;
1455+
1456+
if (centroidLength == FloatVector.SPECIES_256.length()) {
1457+
FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1);
1458+
FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length2);
1459+
var diff = a.sub(b);
1460+
sum = diff.mul(diff).add(sum);
1461+
}
1462+
else {
1463+
int i = 0;
1464+
for (; i < vectorizedLength; i += FloatVector.SPECIES_256.length()) {
1465+
FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1 + i);
1466+
FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length2 + i);
1467+
var diff = a.sub(b);
1468+
sum = diff.mul(diff).add(sum);
1469+
}
1470+
// Process the tail
1471+
for (; i < centroidLength ; ++i) {
1472+
var diff = codebooks[m].get(length1 + i) - codebooks[m].get(length2 + i);
1473+
res += diff * diff;
1474+
}
1475+
}
1476+
}
1477+
res += sum.reduceLanes(VectorOperators.ADD);
1478+
return res;
1479+
}
1480+
1481+
float pqDiversityEuclidean_512(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
1482+
float res = 0;
1483+
FloatVector sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
1484+
for (int m = 0; m < subspaceCount; m++) {
1485+
int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
1486+
int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
1487+
int centroidLength = subvectorSizesAndOffsets[m][0];
1488+
final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(centroidLength);
1489+
int length1 = centroidIndex1 * centroidLength;
1490+
int length2 = centroidIndex2 * centroidLength;
1491+
if (centroidLength == FloatVector.SPECIES_PREFERRED.length()) {
1492+
FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1);
1493+
FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length2);
1494+
var diff = a.sub(b);
1495+
sum = diff.mul(diff).add(sum);
1496+
}
1497+
else {
1498+
int i = 0;
1499+
for (; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
1500+
FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1 + i);
1501+
FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length2 + i);
1502+
var diff = a.sub(b);
1503+
sum = diff.mul(diff).add(sum);
1504+
}
1505+
// Process the tail
1506+
for (; i < centroidLength ; ++i) {
1507+
var diff = codebooks[m].get(length1 + i) - codebooks[m].get(length2 + i);
1508+
res += diff * diff;
1509+
}
1510+
}
1511+
}
1512+
res += sum.reduceLanes(VectorOperators.ADD);
1513+
return res;
1514+
}
1515+
1516+
@Override
1517+
public float pqDiversityEuclidean(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
1518+
//Since centroid length can vary, picking the first entry in the array which is the largest one
1519+
if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_PREFERRED.length() )
1520+
{
1521+
return pqDiversityEuclidean_512( codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
1522+
}
1523+
else if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_256.length() )
1524+
{
1525+
return pqDiversityEuclidean_256( codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
1526+
}
1527+
//adding following two for completeness, will it get here?
1528+
else if (subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_128.length())
1529+
{
1530+
return pqDiversityEuclidean_128( codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
1531+
}
1532+
else
1533+
{
1534+
return pqDiversityEuclidean_64( codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
1535+
}
1536+
}
1537+
13721538
}
13731539

0 commit comments

Comments
 (0)