Skip to content

Commit dd1c332

Browse files
committed
Optimizes reduceLanes usage in the diversity calculation of PQVectors for the Euclidean function
1 parent 1c29821 commit dd1c332

File tree

5 files changed

+187
-8
lines changed

5 files changed

+187
-8
lines changed

jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -298,13 +298,7 @@ public ScoreFunction.ApproximateScoreFunction diversityFunctionFor(int node1, Ve
298298
var node2Chunk = getChunk(node2);
299299
var node2Offset = getOffsetInChunk(node2);
300300
// compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points
301-
float sum = 0;
302-
for (int m = 0; m < subspaceCount; m++) {
303-
int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
304-
int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
305-
int centroidLength = pq.subvectorSizesAndOffsets[m][0];
306-
sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex1 * centroidLength, pq.codebooks[m], centroidIndex2 * centroidLength, centroidLength);
307-
}
301+
float sum = VectorUtil.pqDiversityEuclidean(pq.codebooks, pq.subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
308302
// scale to [0, 1]
309303
return 1 / (1 + sum);
310304
};

jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,4 +584,18 @@ public float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValu
584584
return squaredSum;
585585
}
586586

587+
/**
 * Scalar implementation: sums, over the first {@code subspaceCount} subspaces, the squared
 * distance between the two codebook centroids selected by the encoded bytes of two nodes.
 *
 * @param codebooks one flat centroid buffer per subspace; centroid {@code c} of subspace
 *                  {@code m} is read starting at {@code c * subvectorSizesAndOffsets[m][0]}
 * @param subvectorSizesAndOffsets per-subspace metadata; only element {@code [m][0]}
 *                  (the centroid length) is read here
 * @param node1Chunk byte sequence holding the encoded first node
 * @param node1Offset index in {@code node1Chunk} of the first node's byte for subspace 0
 * @param node2Chunk byte sequence holding the encoded second node
 * @param node2Offset index in {@code node2Chunk} of the second node's byte for subspace 0
 * @param subspaceCount number of subspaces to accumulate over
 * @return the accumulated sum of per-subspace squared distances
 */
@Override
588+
public float pqDiversityEuclidean(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
589+
float sum = 0;
590+
for (int m = 0; m < subspaceCount; m++) {
591+
// Encoded bytes are treated as unsigned centroid indexes.
int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
592+
int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
593+
int centroidLength = subvectorSizesAndOffsets[m][0];
594+
595+
// Both centroids live in the same per-subspace codebook, at index * centroidLength.
sum += squareDistance(codebooks[m], centroidIndex1 * centroidLength, codebooks[m], centroidIndex2 * centroidLength, centroidLength);
596+
}
597+
return sum;
598+
599+
}
600+
587601
}

jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,10 @@ public static float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int encod
227227
return impl.pqDecodedCosineSimilarity(encoded, encodedOffset, encodedLength, clusterCount, partialSums, aMagnitude, bMagnitude);
228228
}
229229

230+
/**
 * Forwards to {@code impl.pqDiversityEuclidean}, passing every argument through unchanged.
 *
 * @param codebooks per-subspace centroid buffers
 * @param subvectorSizesAndOffsets per-subspace metadata array
 * @param node1Chunk byte sequence holding the encoded first node
 * @param node1Offset offset of the first node's encoding inside {@code node1Chunk}
 * @param node2Chunk byte sequence holding the encoded second node
 * @param node2Offset offset of the second node's encoding inside {@code node2Chunk}
 * @param subspaceCount number of subspaces to accumulate over
 * @return the value returned by the underlying implementation
 */
public static float pqDiversityEuclidean(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
231+
return impl.pqDiversityEuclidean(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
232+
}
233+
230234
/**
 * Forwards to {@code impl.nvqDotProduct8bit}, passing every argument through unchanged.
 *
 * @return the value returned by the underlying implementation
 */
public static float nvqDotProduct8bit(VectorFloat<?> vector, ByteSequence<?> bytes, float growthRate, float midpoint, float minValue, float maxValue) {
231235
return impl.nvqDotProduct8bit(vector, bytes, growthRate, midpoint, minValue, maxValue);
232236
}
@@ -254,4 +258,5 @@ public static float nvqLoss(VectorFloat<?> vector, float growthRate, float midpo
254258
public static float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValue, int nBits) {
255259
return impl.nvqUniformLoss(vector, minValue, maxValue, nBits);
256260
}
261+
257262
}

jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,5 +337,5 @@ default float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int encodedOffs
337337
* @param nBits the number of bits per dimension
338338
*/
339339
float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValue, int nBits);
340-
340+
/**
 * Computes the sum, over the first {@code subspaceCount} subspaces, of the squared Euclidean
 * distance between the two codebook centroids selected by the encoded bytes of two nodes.
 * @param codebooks one centroid buffer per subspace
 * @param subvectorSizesAndOffsets per-subspace metadata; element {@code [m][0]} is the centroid length of subspace {@code m}
 * @param node1Chunk byte sequence holding the encoded first node
 * @param node1Offset index in {@code node1Chunk} of the first node's byte for subspace 0
 * @param node2Chunk byte sequence holding the encoded second node
 * @param node2Offset index in {@code node2Chunk} of the second node's byte for subspace 0
 * @param subspaceCount the number of subspaces to accumulate over
 * @return the accumulated sum of squared distances
 */
float pqDiversityEuclidean(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount);
341341
}

jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1576,5 +1576,171 @@ public void calculatePartialSums(VectorFloat<?> codebook, int codebookIndex, int
15761576
public float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude) {
15771577
return pqDecodedCosineSimilarity(encoded, 0, encoded.length(), clusterCount, partialSums, aMagnitude, bMagnitude);
15781578
}
1579+
1580+
/**
 * Variant of the PQ diversity computation that uses {@code FloatVector.SPECIES_64}.
 * <p>
 * For each subspace {@code m} the two centroid indexes are the unsigned bytes at
 * {@code m + node1Offset} and {@code m + node2Offset}; both centroids are read from
 * {@code codebooks[m]} at {@code index * subvectorSizesAndOffsets[m][0]}. Lane-wise squared
 * differences are accumulated in one vector across all subspaces and reduced a single time
 * at the end; elements that do not fill a whole vector are accumulated in scalar.
 *
 * @return the sum of squared differences over the first {@code subspaceCount} subspaces
 */
float pqDiversityEuclidean_64(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
1581+
// Scalar accumulator for tail elements; the vector accumulator is folded into it at the end.
float res = 0;
1582+
FloatVector sum = FloatVector.zero(FloatVector.SPECIES_64);
1583+
for (int m = 0; m < subspaceCount; m++) {
1584+
int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
1585+
int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
1586+
int centroidLength = subvectorSizesAndOffsets[m][0];
1587+
final int vectorizedLength = FloatVector.SPECIES_64.loopBound(centroidLength);
1588+
// Despite the names, length1/length2 are the start offsets of the two centroids in codebooks[m].
int length1 = centroidIndex1 * centroidLength;
1589+
int length2 = centroidIndex2 * centroidLength;
1590+
// Fast path: the centroid fills exactly one vector, so one load per centroid suffices.
if (centroidLength == FloatVector.SPECIES_64.length()) {
1591+
FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1);
1592+
FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length2);
1593+
var diff = a.sub(b);
1594+
sum = diff.mul(diff).add(sum);
1595+
}
1596+
else {
1597+
int i = 0;
1598+
for (; i < vectorizedLength; i += FloatVector.SPECIES_64.length()) {
1599+
FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1 + i);
1600+
FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length2 + i);
1601+
var diff = a.sub(b);
1602+
sum = diff.mul(diff).add(sum);
1603+
}
1604+
// Process the tail: elements left over after the last full vector are handled in scalar
1605+
1606+
for (; i < centroidLength ; ++i) {
1607+
var diff = codebooks[m].get(length1 + i) - codebooks[m].get(length2 + i);
1608+
res += diff * diff;
1609+
}
1610+
}
1611+
}
1612+
// Single horizontal reduction for all subspaces, added to the scalar tail sum.
res += sum.reduceLanes(VectorOperators.ADD);
1613+
return res;
1614+
}
1615+
1616+
/**
 * Variant of the PQ diversity computation that uses {@code FloatVector.SPECIES_128}.
 * <p>
 * For each subspace {@code m} the two centroid indexes are the unsigned bytes at
 * {@code m + node1Offset} and {@code m + node2Offset}; both centroids are read from
 * {@code codebooks[m]} at {@code index * subvectorSizesAndOffsets[m][0]}. Lane-wise squared
 * differences are accumulated in one vector across all subspaces and reduced a single time
 * at the end; elements that do not fill a whole vector are accumulated in scalar.
 *
 * @return the sum of squared differences over the first {@code subspaceCount} subspaces
 */
float pqDiversityEuclidean_128(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
1617+
// Scalar accumulator for tail elements; the vector accumulator is folded into it at the end.
float res = 0;
1618+
FloatVector sum = FloatVector.zero(FloatVector.SPECIES_128);
1619+
for (int m = 0; m < subspaceCount; m++) {
1620+
int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
1621+
int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
1622+
int centroidLength = subvectorSizesAndOffsets[m][0];
1623+
final int vectorizedLength = FloatVector.SPECIES_128.loopBound(centroidLength);
1624+
// Despite the names, length1/length2 are the start offsets of the two centroids in codebooks[m].
int length1 = centroidIndex1 * centroidLength;
1625+
int length2 = centroidIndex2 * centroidLength;
1626+
// Fast path: the centroid fills exactly one vector, so one load per centroid suffices.
if (centroidLength == FloatVector.SPECIES_128.length()) {
1627+
FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1);
1628+
FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length2);
1629+
var diff = a.sub(b);
1630+
sum = diff.mul(diff).add(sum);
1631+
}
1632+
else {
1633+
int i = 0;
1634+
for (; i < vectorizedLength; i += FloatVector.SPECIES_128.length()) {
1635+
FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1 + i);
1636+
FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length2 + i);
1637+
var diff = a.sub(b);
1638+
sum = diff.mul(diff).add(sum);
1639+
}
1640+
// Process the tail: elements left over after the last full vector are handled in scalar
1641+
1642+
for (; i < centroidLength ; ++i) {
1643+
var diff = codebooks[m].get(length1 + i) - codebooks[m].get(length2 + i);
1644+
res += diff * diff;
1645+
}
1646+
}
1647+
}
1648+
// Single horizontal reduction for all subspaces, added to the scalar tail sum.
res += sum.reduceLanes(VectorOperators.ADD);
1649+
return res;
1650+
}
1651+
1652+
/**
 * Variant of the PQ diversity computation that uses {@code FloatVector.SPECIES_256}.
 * <p>
 * For each subspace {@code m} the two centroid indexes are the unsigned bytes at
 * {@code m + node1Offset} and {@code m + node2Offset}; both centroids are read from
 * {@code codebooks[m]} at {@code index * subvectorSizesAndOffsets[m][0]}. Lane-wise squared
 * differences are accumulated in one vector across all subspaces and reduced a single time
 * at the end; elements that do not fill a whole vector are accumulated in scalar.
 *
 * @return the sum of squared differences over the first {@code subspaceCount} subspaces
 */
float pqDiversityEuclidean_256(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
1653+
// Scalar accumulator for tail elements; the vector accumulator is folded into it at the end.
float res = 0;
1654+
FloatVector sum = FloatVector.zero(FloatVector.SPECIES_256);
1655+
for (int m = 0; m < subspaceCount; m++) {
1656+
int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
1657+
int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
1658+
int centroidLength = subvectorSizesAndOffsets[m][0];
1659+
final int vectorizedLength = FloatVector.SPECIES_256.loopBound(centroidLength);
1660+
// Despite the names, length1/length2 are the start offsets of the two centroids in codebooks[m].
int length1 = centroidIndex1 * centroidLength;
1661+
int length2 = centroidIndex2 * centroidLength;
1662+
1663+
// Fast path: the centroid fills exactly one vector, so one load per centroid suffices.
if (centroidLength == FloatVector.SPECIES_256.length()) {
1664+
FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1);
1665+
FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length2);
1666+
var diff = a.sub(b);
1667+
sum = diff.mul(diff).add(sum);
1668+
}
1669+
else {
1670+
int i = 0;
1671+
for (; i < vectorizedLength; i += FloatVector.SPECIES_256.length()) {
1672+
FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1 + i);
1673+
FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length2 + i);
1674+
var diff = a.sub(b);
1675+
sum = diff.mul(diff).add(sum);
1676+
}
1677+
// Process the tail: elements left over after the last full vector are handled in scalar
for (; i < centroidLength ; ++i) {
1678+
var diff = codebooks[m].get(length1 + i) - codebooks[m].get(length2 + i);
1679+
res += diff * diff;
1680+
}
1681+
}
1682+
}
1683+
// Single horizontal reduction for all subspaces, added to the scalar tail sum.
res += sum.reduceLanes(VectorOperators.ADD);
1684+
return res;
1685+
}
1687+
1688+
/**
 * Variant of the PQ diversity computation that uses {@code FloatVector.SPECIES_PREFERRED}.
 * Note that, despite the {@code _512} suffix, the species is not fixed at 512 bits: every
 * load and loop bound in this method uses {@code SPECIES_PREFERRED}.
 * <p>
 * For each subspace {@code m} the two centroid indexes are the unsigned bytes at
 * {@code m + node1Offset} and {@code m + node2Offset}; both centroids are read from
 * {@code codebooks[m]} at {@code index * subvectorSizesAndOffsets[m][0]}. Lane-wise squared
 * differences are accumulated in one vector across all subspaces and reduced a single time
 * at the end; elements that do not fill a whole vector are accumulated in scalar.
 *
 * @return the sum of squared differences over the first {@code subspaceCount} subspaces
 */
float pqDiversityEuclidean_512(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
1689+
// Scalar accumulator for tail elements; the vector accumulator is folded into it at the end.
float res = 0;
1690+
FloatVector sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
1691+
for (int m = 0; m < subspaceCount; m++) {
1692+
int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
1693+
int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
1694+
int centroidLength = subvectorSizesAndOffsets[m][0];
1695+
final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(centroidLength);
1696+
// Despite the names, length1/length2 are the start offsets of the two centroids in codebooks[m].
int length1 = centroidIndex1 * centroidLength;
1697+
int length2 = centroidIndex2 * centroidLength;
1698+
// Fast path: the centroid fills exactly one vector, so one load per centroid suffices.
if (centroidLength == FloatVector.SPECIES_PREFERRED.length()) {
1699+
FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1);
1700+
FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length2);
1701+
var diff = a.sub(b);
1702+
sum = diff.mul(diff).add(sum);
1703+
}
1704+
else {
1705+
int i = 0;
1706+
for (; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
1707+
FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1 + i);
1708+
FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length2 + i);
1709+
var diff = a.sub(b);
1710+
sum = diff.mul(diff).add(sum);
1711+
}
1712+
// Process the tail: elements left over after the last full vector are handled in scalar
for (; i < centroidLength ; ++i) {
1713+
var diff = codebooks[m].get(length1 + i) - codebooks[m].get(length2 + i);
1714+
res += diff * diff;
1715+
}
1716+
}
1717+
}
1718+
// Single horizontal reduction for all subspaces, added to the scalar tail sum.
res += sum.reduceLanes(VectorOperators.ADD);
1719+
return res;
1720+
}
1722+
1723+
/**
 * Dispatches the PQ diversity computation to one of the species-specific variants.
 * <p>
 * The variant is chosen from {@code subvectorSizesAndOffsets[0][0]} alone: the first
 * threshold (checked in order: {@code SPECIES_PREFERRED}, {@code SPECIES_256},
 * {@code SPECIES_128}) whose lane count does not exceed that value wins, and
 * {@code SPECIES_64} is the fallback. All arguments are forwarded unchanged.
 *
 * @return the value computed by the selected variant
 */
@Override
1724+
public float pqDiversityEuclidean(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
1725+
// Centroid length can vary between subspaces, so the first entry is used to pick the variant;
// NOTE(review): this relies on the first entry being the largest one - confirm against how
// subvectorSizesAndOffsets is built.
1726+
if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_PREFERRED.length() )
1727+
{
1728+
return pqDiversityEuclidean_512( codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
1729+
}
1730+
else if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_256.length() )
1731+
{
1732+
return pqDiversityEuclidean_256( codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
1733+
}
1734+
// The two narrower variants below are included for completeness.
// NOTE(review): it is unclear whether centroid lengths this small occur in practice.
1735+
else if (subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_128.length())
1736+
{
1737+
return pqDiversityEuclidean_128( codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
1738+
}
1739+
else
1740+
{
1741+
return pqDiversityEuclidean_64( codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
1742+
}
1743+
}
1744+
15791745
}
15801746

0 commit comments

Comments
 (0)