From 3c2b027147878d0fdc6cc7a442b0244fc5481ccd Mon Sep 17 00:00:00 2001 From: Prabod Rathnayaka Date: Wed, 20 Aug 2025 07:19:55 +0000 Subject: [PATCH 1/7] Add GGUFRankingFinisher and corresponding tests for ranking capabilities --- .../nlp/finisher/GGUFRankingFinisher.scala | 367 ++++++++++++++++++ .../seq2seq/AutoGGUFRerankerTest.scala | 30 +- .../finisher/GGUFRankingFinisherTest.scala | 349 +++++++++++++++++ 3 files changed, 740 insertions(+), 6 deletions(-) create mode 100644 src/main/scala/com/johnsnowlabs/nlp/finisher/GGUFRankingFinisher.scala create mode 100644 src/test/scala/com/johnsnowlabs/nlp/finisher/GGUFRankingFinisherTest.scala diff --git a/src/main/scala/com/johnsnowlabs/nlp/finisher/GGUFRankingFinisher.scala b/src/main/scala/com/johnsnowlabs/nlp/finisher/GGUFRankingFinisher.scala new file mode 100644 index 00000000000000..a7343f3ed9620f --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/finisher/GGUFRankingFinisher.scala @@ -0,0 +1,367 @@ +/* + * Copyright 2017-2024 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.johnsnowlabs.nlp.finisher + +import com.johnsnowlabs.nlp.{Annotation, AnnotatorType} +import com.johnsnowlabs.nlp.util.FinisherUtil +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.{BooleanParam, DoubleParam, IntParam, ParamMap, StringArrayParam} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.expressions.{UserDefinedFunction, Window} +import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType} +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.types._ + +import scala.collection.mutable + +/** Finisher for AutoGGUFReranker outputs that provides ranking capabilities including top-k + * selection, sorting by relevance score, and score normalization. + * + * This finisher processes the output of AutoGGUFReranker, which contains documents with + * relevance scores in their metadata. It provides several options for post-processing: + * + * - Top-k selection: Select only the top k documents by relevance score + * - Score thresholding: Filter documents by minimum relevance score + * - Min-max scaling: Normalize relevance scores to 0-1 range + * - Sorting: Sort documents by relevance score in descending order + * - Ranking: Add rank information to document metadata + * + * The finisher preserves the document annotation structure while adding ranking information to + * the metadata and optionally filtering/sorting the documents. 
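+  *
+  * ==Score scaling==
+  *
+  * When `minMaxScaling` is enabled, relevance scores are rescaled to the 0-1 range before
+  * the `minRelevanceScore` threshold is applied, so the threshold should be chosen on that
+  * scale. A minimal sketch of the normalization (illustrative scores; assumes the minimum
+  * and maximum differ — when they are equal, all scores are set to 1.0):
+  *
+  * {{{
+  * // scaled = (score - min) / (max - min), with min/max taken over all input rows
+  * val scores = Seq(7.02, 2.1, -10.78)
+  * val (lo, hi) = (scores.min, scores.max)
+  * scores.map(s => (s - lo) / (hi - lo)) // List(1.0, 0.723..., 0.0)
+  * }}}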
+ * + * ==Example== + * + * {{{ + * import com.johnsnowlabs.nlp.base._ + * import com.johnsnowlabs.nlp.annotators._ + * import com.johnsnowlabs.nlp.finisher._ + * import org.apache.spark.ml.Pipeline + * import spark.implicits._ + * + * val document = new DocumentAssembler() + * .setInputCol("text") + * .setOutputCol("document") + * + * val reranker = AutoGGUFReranker + * .pretrained("bge-reranker-v2-m3-Q4_K_M") + * .setInputCols("document") + * .setOutputCol("reranked_documents") + * .setQuery("A man is eating pasta.") + * + * val finisher = new GGUFRankingFinisher() + * .setInputCols("reranked_documents") + * .setOutputCol("ranked_documents") + * .setTopK(3) + * .setMinRelevanceScore(0.1) + * .setMinMaxScaling(true) + * + * val pipeline = new Pipeline().setStages(Array(document, reranker, finisher)) + * + * val data = Seq( + * "A man is eating food.", + * "A man is eating a piece of bread.", + * "The girl is carrying a baby.", + * "A man is riding a horse." + * ).toDF("text") + * + * val result = pipeline.fit(data).transform(data) + * result.select("ranked_documents").show(truncate = false) + * // Documents will be sorted by relevance with rank information in metadata + * }}} + * + * @param uid + * required uid for storing finisher to disk + * @groupname anno Annotator types + * @groupdesc anno + * Required input and expected output annotator types + * @groupname Ungrouped Members + * @groupname param Parameters + * @groupname setParam Parameter setters + * @groupname getParam Parameter getters + * @groupname Ungrouped Members + * @groupprio param 1 + * @groupprio anno 2 + * @groupprio Ungrouped 3 + * @groupprio setParam 4 + * @groupprio getParam 5 + * @groupdesc param + * A list of (hyper-)parameter keys this finisher can take. Users can set and get the parameter + * values through setters and getters, respectively. 
+ */ +case class GGUFRankingFinisher(override val uid: String) + extends Transformer + with DefaultParamsWritable { + + def this() = this(Identifiable.randomUID("GGUF_RANKING_FINISHER")) + + val RELEVANCE_SCORE_COL_NAME = "relevance_score" + val QUERY_COL_NAME = "query" + val RANK_COL_NAME = "rank" + + /** Name of input annotation cols containing reranked documents + * + * @group param + */ + val inputCols: StringArrayParam = + new StringArrayParam( + this, + "inputCols", + "Name of input annotation cols containing reranked documents") + + /** Name of input annotation cols containing reranked documents + * + * @group setParam + */ + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** Name of input annotation cols containing reranked documents + * + * @group setParam + */ + def setInputCols(value: String*): this.type = setInputCols(value.toArray) + + /** Name of input annotation cols containing reranked documents + * + * @group getParam + */ + def getInputCols: Array[String] = $(inputCols) + + /** Name of output annotation column containing ranked documents + * + * @group param + */ + val outputCol: StringArrayParam = + new StringArrayParam( + this, + "outputCol", + "Name of output annotation column containing ranked documents") + + /** Name of output annotation column containing ranked documents + * + * @group setParam + */ + def setOutputCol(value: String): this.type = set(outputCol, Array(value)) + + /** Name of output annotation column containing ranked documents + * + * @group getParam + */ + def getOutputCol: String = $(outputCol).headOption.getOrElse("ranked_documents") + + /** Maximum number of top documents to return based on relevance score + * + * @group param + */ + val topK: IntParam = + new IntParam( + this, + "topK", + "Maximum number of top documents to return based on relevance score") + + /** Set maximum number of top documents to return + * + * @group setParam + */ + def setTopK(value: Int): this.type = set(topK, value) + + /** Get maximum number of top documents to return + * + * @group getParam + */ + def getTopK: Int = $(topK) + + /** Minimum relevance score threshold for filtering documents + * + * @group param + */ + val minRelevanceScore: DoubleParam = + new DoubleParam( + this, + "minRelevanceScore", + "Minimum relevance score threshold for filtering documents") + + /** Set minimum relevance score threshold + * + * @group setParam + */ + def setMinRelevanceScore(value: Double): this.type = set(minRelevanceScore, value) + + /** Get minimum relevance score threshold + * + * @group getParam + */ + def getMinRelevanceScore: Double = $(minRelevanceScore) + + /** Whether to apply min-max scaling to normalize relevance scores to 0-1 range + * + * @group param + */ + val minMaxScaling: BooleanParam = + new BooleanParam( + this, + "minMaxScaling", + "Whether to apply min-max scaling to normalize relevance scores to 0-1 range") + + /** Set whether to apply min-max scaling + * + * @group setParam + */ + def setMinMaxScaling(value: Boolean): this.type = set(minMaxScaling, value) + + /** Get whether to apply min-max scaling + * + * @group getParam + */ + def getMinMaxScaling: Boolean = $(minMaxScaling) + + setDefault( + topK -> -1, // -1 means no limit + minRelevanceScore -> Double.MinValue, // No threshold by default + minMaxScaling -> false, + outputCol -> Array("ranked_documents")) + + override def transform(dataset: Dataset[_]): DataFrame = { + val inputCol = getInputCols.head + val outputColumnName = getOutputCol + + import 
dataset.sparkSession.implicits._ + + // First, flatten all annotations across all rows to get global statistics + val allAnnotations = dataset + .withColumn("row_id", monotonically_increasing_id()) + .select($"row_id", explode(col(inputCol)).as("annotation")) + + // Extract scores from all annotations + val scoresDF = allAnnotations + .withColumn( + "score", + when( + col("annotation.metadata").getItem(RELEVANCE_SCORE_COL_NAME).isNotNull && + col("annotation.metadata").getItem(RELEVANCE_SCORE_COL_NAME) =!= "", + col("annotation.metadata").getItem(RELEVANCE_SCORE_COL_NAME).cast("double")) + .otherwise(0.0)) + + // Get global min/max for scaling if enabled + val (globalMin, globalMax) = if (getMinMaxScaling) { + val stats = scoresDF.agg(min($"score"), max($"score")).collect().head + (stats.getDouble(0), stats.getDouble(1)) + } else { + (0.0, 0.0) + } + + // Calculate scaled scores and assign global ranks + val scaledDF = if (getMinMaxScaling && globalMax != globalMin) { + scoresDF.withColumn("scaled_score", ($"score" - globalMin) / (globalMax - globalMin)) + } else if (getMinMaxScaling && globalMax == globalMin) { + scoresDF.withColumn("scaled_score", lit(1.0)) + } else { + scoresDF.withColumn("scaled_score", $"score") + } + + // Filter by threshold and assign global ranks + val rankedDF = scaledDF + .filter($"scaled_score" >= getMinRelevanceScore) + .withColumn("global_rank", row_number().over(Window.orderBy($"scaled_score".desc))) + + // Apply top-k limit if specified + val limitedDF = if (getTopK > 0) { + rankedDF.filter($"global_rank" <= getTopK) + } else { + rankedDF + } + + // Create UDF to update annotation metadata + val updateAnnotationMetadataUDF: UserDefinedFunction = udf { + (annotationRow: Row, score: Double, rank: Int) => + // Convert Row to Annotation + val annotation = Annotation( + annotatorType = annotationRow.getString(0), + begin = annotationRow.getInt(1), + end = annotationRow.getInt(2), + result = annotationRow.getString(3), + metadata = annotationRow.getAs[Map[String, String]](4)) + + val updatedMetadata = annotation.metadata ++ Map( + RELEVANCE_SCORE_COL_NAME -> score.toString, + RANK_COL_NAME -> rank.toString) + + Annotation( + annotatorType = annotation.annotatorType, + begin = annotation.begin, + end = annotation.end, + result = annotation.result, + metadata = updatedMetadata) + } + + // Update annotations with new metadata + val updatedAnnotationsDF = limitedDF + .withColumn( + "updated_annotation", + updateAnnotationMetadataUDF($"annotation", $"scaled_score", $"global_rank")) + .select($"row_id", $"updated_annotation") + + // Group annotations back by row_id + val groupedDF = updatedAnnotationsDF + .groupBy($"row_id") + .agg(collect_list($"updated_annotation").as("processed_annotations")) + + // Join back with original dataset and filter out rows with no annotations + val originalWithRowId = dataset.withColumn("row_id", monotonically_increasing_id()) + + val result = originalWithRowId + .join( + groupedDF, + Seq("row_id"), + "inner" + ) // Use inner join to exclude rows with no annotations + .withColumn(outputColumnName, $"processed_annotations") + .drop("row_id", "processed_annotations") + + result + } + + override def copy(extra: ParamMap): Transformer = defaultCopy(extra) + + override def transformSchema(schema: StructType): StructType = { + val documentAnnotators = Seq(AnnotatorType.DOCUMENT) + + getInputCols.foreach { annotationColumn => + FinisherUtil.checkIfInputColsExist(getInputCols, schema) + 
FinisherUtil.checkIfAnnotationColumnIsSparkNLPAnnotation(schema, annotationColumn) + + /** Check if the annotationColumn has Document type. It must be annotators that produce + * Document annotations with relevance score metadata (like AutoGGUFReranker) + */ + require( + documentAnnotators.contains(schema(annotationColumn).metadata.getString("annotatorType")), + s"column [$annotationColumn] must be of type Document with relevance score metadata") + } + + // Add output column to schema + val outputColumnName = getOutputCol + schema.add( + StructField( + outputColumnName, + ArrayType(Annotation.dataType), + nullable = false, + metadata = new MetadataBuilder() + .putString("annotatorType", AnnotatorType.DOCUMENT) + .build())) + } +} + +object GGUFRankingFinisher extends DefaultParamsReadable[GGUFRankingFinisher] diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFRerankerTest.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFRerankerTest.scala index d7bfd20f7f9a19..9056cee4ca846c 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFRerankerTest.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFRerankerTest.scala @@ -1,6 +1,7 @@ package com.johnsnowlabs.nlp.annotators.seq2seq import com.johnsnowlabs.nlp.Annotation +import com.johnsnowlabs.nlp.finisher.GGUFRankingFinisher import com.johnsnowlabs.nlp.base.DocumentAssembler import com.johnsnowlabs.nlp.util.io.ResourceHelper import com.johnsnowlabs.tags.{SlowTest, FastTest} @@ -23,7 +24,7 @@ class AutoGGUFRerankerTest extends AnyFlatSpec { lazy val model: AutoGGUFReranker = AutoGGUFReranker .loadSavedModel(modelPath, ResourceHelper.spark) .setInputCols("document") - .setOutputCol("completions") + .setOutputCol("reranked_documents") .setBatchSize(4) .setQuery(query) @@ -34,11 +35,19 @@ class AutoGGUFRerankerTest extends AnyFlatSpec { "The girl is carrying a baby.", "A man is riding a horse.", "A young girl is playing violin.").toDF("text").repartition(1) - lazy val pipeline: Pipeline = new Pipeline().setStages(Array(documentAssembler, model)) + + lazy val finisher: GGUFRankingFinisher = new GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") + .setTopK(-1) + .setMinRelevanceScore(0.1) + .setMinMaxScaling(true) + lazy val pipeline: Pipeline = + new Pipeline().setStages(Array(documentAssembler, model)) def assertAnnotationsNonEmpty(resultDf: DataFrame): Unit = { Annotation - .collect(resultDf, "completions") + .collect(resultDf, "reranked_documents") .foreach(annotations => { println(annotations.head) println(annotations.head.metadata) @@ -77,7 +86,7 @@ class AutoGGUFRerankerTest extends AnyFlatSpec { newPipeline .fit(data) .transform(data) - .select("completions") + .select("reranked_documents") .show(truncate = false) } @@ -96,7 +105,7 @@ class AutoGGUFRerankerTest extends AnyFlatSpec { val model = AutoGGUFReranker .pretrained() .setInputCols("document") - .setOutputCol("completions") + .setOutputCol("reranked_documents") .setBatchSize(2) val pipeline = @@ -109,7 +118,7 @@ class AutoGGUFRerankerTest extends AnyFlatSpec { val model: AutoGGUFReranker = AutoGGUFReranker .loadSavedModel(modelPath, ResourceHelper.spark) .setInputCols("document") - .setOutputCol("completions") + .setOutputCol("reranked_documents") .setBatchSize(4) val pipeline = new Pipeline().setStages(Array(documentAssembler, model)) assertThrows[org.apache.spark.SparkException] { @@ -117,4 +126,13 @@ class AutoGGUFRerankerTest extends AnyFlatSpec { 
result.show() } } + + it should "be able to finisher the reranked documents" taggedAs FastTest in { + model.setQuery(query) + val pipeline = new Pipeline().setStages(Array(documentAssembler, model, finisher)) + val result = pipeline.fit(data).transform(data) + +// assertAnnotationsNonEmpty(result) + result.select("ranked_documents").show(truncate = false) + } } diff --git a/src/test/scala/com/johnsnowlabs/nlp/finisher/GGUFRankingFinisherTest.scala b/src/test/scala/com/johnsnowlabs/nlp/finisher/GGUFRankingFinisherTest.scala new file mode 100644 index 00000000000000..3797bc8f29eb56 --- /dev/null +++ b/src/test/scala/com/johnsnowlabs/nlp/finisher/GGUFRankingFinisherTest.scala @@ -0,0 +1,349 @@ +/* + * Copyright 2017-2024 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.finisher + +import com.johnsnowlabs.nlp.{Annotation, AnnotatorType, ContentProvider} +import com.johnsnowlabs.nlp.base.DocumentAssembler +import com.johnsnowlabs.nlp.util.io.ResourceHelper +import com.johnsnowlabs.tags.FastTest +import com.johnsnowlabs.util.Benchmark +import org.apache.spark.ml.Pipeline +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.scalatest.flatspec.AnyFlatSpec + +class GGUFRankingFinisherTest extends AnyFlatSpec { + + import ResourceHelper.spark.implicits._ + // Mock data to simulate AutoGGUFReranker output + def createMockRerankerOutput(): DataFrame = { + val spark = ResourceHelper.spark + + val documents = Seq( + ("A man is eating food.", 0.85, "A man is eating pasta."), + ("A man is eating a piece of bread.", 0.72, "A man is eating pasta."), + ("The girl is carrying a baby.", 0.15, "A man is eating pasta."), + ("A man is riding a horse.", 0.28, "A man is eating pasta."), + ("A young girl is playing violin.", 0.05, "A man is eating pasta.")) + + val mockAnnotations = documents.map { case (text, score, query) => + Row( + AnnotatorType.DOCUMENT, + 0, + text.length - 1, + text, + Map("relevance_score" -> score.toString, "query" -> query), + Array.empty[Float]) + } + + val rows = Seq(Row(mockAnnotations)) + val annotationSchema = StructType( + Array( + StructField("annotatorType", StringType, nullable = false), + StructField("begin", IntegerType, nullable = false), + StructField("end", IntegerType, nullable = false), + StructField("result", StringType, nullable = false), + StructField("metadata", MapType(StringType, StringType), nullable = false), + StructField("embeddings", ArrayType(FloatType), nullable = false))) + val schema = StructType( + Array(StructField("reranked_documents", ArrayType(annotationSchema), nullable = false))) + + spark.createDataFrame(spark.sparkContext.parallelize(rows), schema) + } + + "GGUFRankingFinisher with default settings" should "process documents and add rank metadata" taggedAs FastTest in { + val mockData = createMockRerankerOutput() + + val finisher = new GGUFRankingFinisher() + .setInputCols("reranked_documents") + 
.setOutputCol("ranked_documents") + + val result = finisher.transform(mockData) + + assert(result.columns.contains("ranked_documents")) + + // Get the ranked documents + val rankedDocs = + result.select("ranked_documents").rdd.map(_.getAs[Seq[Row]](0)).collect().head + + assert(rankedDocs.length == 5) + + // Check that results are sorted by relevance score in descending order + val scores = + rankedDocs.map(_.getAs[Map[String, String]]("metadata")("relevance_score").toDouble) + assert(scores.zip(scores.tail).forall { case (a, b) => a >= b }) + + // Check that rank metadata is added + val ranks = rankedDocs.map(_.getAs[Map[String, String]]("metadata")("rank").toInt) + assert(ranks == Seq(1, 2, 3, 4, 5)) + } + + "GGUFRankingFinisher with topK" should "return only top k results" taggedAs FastTest in { + val mockData = createMockRerankerOutput() + + val finisher = new GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") + .setTopK(3) + + val result = finisher.transform(mockData) + + // Should have only 1 row since all annotations are in a single row + assert(result.count() == 1) + + val rankedDocs = + result.select("ranked_documents").rdd.map(_.getAs[Seq[Row]](0)).collect().head + assert(rankedDocs.length == 3) + + // Check that we get the top 3 scores + val scores = + rankedDocs.map(_.getAs[Map[String, String]]("metadata")("relevance_score").toDouble) + assert(scores.length == 3) + assert(scores.contains(0.85)) + assert(scores.contains(0.72)) + assert(scores.contains(0.28)) + + // Check ranks are correct + val ranks = rankedDocs.map(_.getAs[Map[String, String]]("metadata")("rank").toInt) + assert(ranks == Seq(1, 2, 3)) + } + + "GGUFRankingFinisher with threshold" should "filter by minimum relevance score" taggedAs FastTest in { + val mockData = createMockRerankerOutput() + + val finisher = new GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") + .setMinRelevanceScore(0.3) + + val result = finisher.transform(mockData) + + val rankedDocs = + result.select("ranked_documents").rdd.map(_.getAs[Seq[Row]](0)).collect().head + assert(rankedDocs.length == 2) // Only scores >= 0.3 (0.85 and 0.72) + + val scores = + rankedDocs.map(_.getAs[Map[String, String]]("metadata")("relevance_score").toDouble) + assert(scores.forall(_ >= 0.3)) + } + + "GGUFRankingFinisher with min-max scaling" should "normalize scores to 0-1 range" taggedAs FastTest in { + val mockData = createMockRerankerOutput() + + val finisher = new GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") + .setMinMaxScaling(true) + + val result = finisher.transform(mockData) + + val rankedDocs = + result.select("ranked_documents").rdd.map(_.getAs[Seq[Row]](0)).collect().head + val scores = + rankedDocs.map(_.getAs[Map[String, String]]("metadata")("relevance_score").toDouble) + + // Check that scores are between 0 and 1 + assert(scores.forall(score => score >= 0.0 && score <= 1.0)) + + // Check that we have both min (0.0) and max (1.0) values + assert(scores.contains(1.0)) // Max original score should be 1.0 + assert(scores.contains(0.0)) // Min original score should be 0.0 + } + + "GGUFRankingFinisher with combined filters" should "apply topK, threshold, and scaling together" taggedAs FastTest in { + val mockData = createMockRerankerOutput() + + val finisher = new GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") + .setTopK(2) + .setMinRelevanceScore(0.1) // After scaling, 
this should filter some results + .setMinMaxScaling(true) + + val result = finisher.transform(mockData) + + val rankedDocs = + result.select("ranked_documents").rdd.map(_.getAs[Seq[Row]](0)).collect().head + + // Should have at most 2 results due to topK + assert(rankedDocs.length <= 2) + + val scores = + rankedDocs.map(_.getAs[Map[String, String]]("metadata")("relevance_score").toDouble) + + // All scores should be >= 0.1 and <= 1.0 + assert(scores.forall(score => score >= 0.1 && score <= 1.0)) + + // Results should be sorted descending + assert(scores.zip(scores.tail).forall { case (a, b) => a >= b }) + + // Check that ranks are correct + val ranks = rankedDocs.map(_.getAs[Map[String, String]]("metadata")("rank").toInt) + assert(ranks == (1 to rankedDocs.length).toSeq) + } + + "GGUFRankingFinisher" should "handle empty input" taggedAs FastTest in { + val spark = ResourceHelper.spark + + val emptyRows = Seq(Row(Array.empty[Row])) + val annotationSchema = StructType( + Array( + StructField("annotatorType", StringType, nullable = false), + StructField("begin", IntegerType, nullable = false), + StructField("end", IntegerType, nullable = false), + StructField("result", StringType, nullable = false), + StructField("metadata", MapType(StringType, StringType), nullable = false), + StructField("embeddings", ArrayType(FloatType), nullable = false))) + val schema = StructType( + Array(StructField("reranked_documents", ArrayType(annotationSchema), nullable = false))) + + val emptyData = spark.createDataFrame(spark.sparkContext.parallelize(emptyRows), schema) + + val finisher = new GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") + + val result = finisher.transform(emptyData) + + // Since we now filter out empty rows, the result should have no rows + assert(result.count() == 0) + } + + "GGUFRankingFinisher" should "preserve query information in metadata" taggedAs FastTest in { + val mockData = createMockRerankerOutput() + + val finisher = new GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") + + val result = finisher.transform(mockData) + + val rankedDocs = + result.select("ranked_documents").rdd.map(_.getAs[Seq[Row]](0)).collect().head + + // Check that query information is preserved in metadata + rankedDocs.foreach { doc => + val metadata = doc.getAs[Map[String, String]]("metadata") + assert(metadata.contains("query")) + assert(metadata("query") == "A man is eating pasta.") + } + } + + "GGUFRankingFinisher" should "handle documents with missing relevance scores" taggedAs FastTest in { + val spark = ResourceHelper.spark + + val documents = Seq( + ("A man is eating food.", Some("0.85"), "A man is eating pasta."), + ("A man is eating a piece of bread.", None, "A man is eating pasta."), // Missing score + ("The girl is carrying a baby.", Some("0.15"), "A man is eating pasta.")) + + val testAnnotations: Seq[Row] = documents.map { case (text, scoreOpt, query) => + val metadata = Map("query" -> query) ++ scoreOpt.map("relevance_score" -> _).toMap + Row(AnnotatorType.DOCUMENT, 0, text.length - 1, text, metadata, Array.empty[Float]) + } + + val rows = Seq(Row(testAnnotations)) + val annotationSchema = StructType( + Array( + StructField("annotatorType", StringType, nullable = false), + StructField("begin", IntegerType, nullable = false), + StructField("end", IntegerType, nullable = false), + StructField("result", StringType, nullable = false), + StructField("metadata", MapType(StringType, StringType), nullable = 
false), + StructField("embeddings", ArrayType(FloatType), nullable = false))) + val schema = StructType( + Array(StructField("reranked_documents", ArrayType(annotationSchema), nullable = false))) + + val testData = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema) + + val finisher = new GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") + + val result = finisher.transform(testData) + + val rankedDocs = + result.select("ranked_documents").rdd.map(_.getAs[Seq[Row]](0)).collect().head + assert(rankedDocs.length == 3) + + // Document with missing score should get 0.0 and be ranked last + val scores = + rankedDocs.map(_.getAs[Map[String, String]]("metadata")("relevance_score").toDouble) + assert(scores.last == 0.0) // Missing score becomes 0.0 + } + + "GGUFRankingFinisher with topK across multiple rows" should "filter out empty rows and return only top k globally" taggedAs FastTest in { + val spark = ResourceHelper.spark + + val documents = Seq( + ("A man is eating food.", 0.85, "A man is eating pasta."), + ("A man is eating a piece of bread.", 0.72, "A man is eating pasta."), + ("The girl is carrying a baby.", 0.15, "A man is eating pasta."), + ("A man is riding a horse.", 0.28, "A man is eating pasta."), + ("A young girl is playing violin.", 0.05, "A man is eating pasta.")) + + // Create individual rows, each with one annotation (simulating real usage) + val testRows: Seq[Row] = documents.map { case (text, score, query) => + val annotation = Row( + AnnotatorType.DOCUMENT, + 0, + text.length - 1, + text, + Map("relevance_score" -> score.toString, "query" -> query), + Array.empty[Float]) + Row(Array(annotation)) + } + + val annotationSchema = StructType( + Array( + StructField("annotatorType", StringType, nullable = false), + StructField("begin", IntegerType, nullable = false), + StructField("end", IntegerType, nullable = false), + StructField("result", StringType, nullable = false), + StructField("metadata", MapType(StringType, StringType), nullable = false), + StructField("embeddings", ArrayType(FloatType), nullable = false))) + val schema = StructType( + Array(StructField("reranked_documents", ArrayType(annotationSchema), nullable = false))) + + val testData = spark.createDataFrame(spark.sparkContext.parallelize(testRows), schema) + + val finisher = new GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") + .setTopK(3) + + val result = finisher.transform(testData) + + // Should have only 3 rows (top-3 globally) + assert(result.count() == 3) + + val allAnnotations = result.select("ranked_documents").collect().flatMap(_.getSeq[Row](0)) + val scores = + allAnnotations.map(_.getAs[Map[String, String]]("metadata")("relevance_score").toDouble) + val ranks = allAnnotations.map(_.getAs[Map[String, String]]("metadata")("rank").toInt) + + // Should have exactly 3 documents with scores 0.85, 0.72, 0.28 + assert(scores.length == 3) + assert(scores.contains(0.85)) + assert(scores.contains(0.72)) + assert(scores.contains(0.28)) + + // Ranks should be 1, 2, 3 + assert(ranks.sorted sameElements Array(1, 2, 3)) + } +} From 3719ef0f70f6a822ffee35a0a2aa1dd8326debfe Mon Sep 17 00:00:00 2001 From: Prabod Rathnayaka Date: Wed, 20 Aug 2025 07:46:07 +0000 Subject: [PATCH 2/7] Add GGUFRankingFinisher implementation and tests for ranking functionality --- docs/en/finisher_gguf_ranking.md | 152 ++++++++++ python/sparknlp/base/__init__.py | 1 + python/sparknlp/base/gguf_ranking_finisher.py | 234 +++++++++++++++ 
.../seq2seq/auto_gguf_reranker_test.py | 163 +++++++++- .../test/base/gguf_ranking_finisher_test.py | 282 ++++++++++++++++++ 5 files changed, 830 insertions(+), 2 deletions(-) create mode 100644 docs/en/finisher_gguf_ranking.md create mode 100644 python/sparknlp/base/gguf_ranking_finisher.py create mode 100644 python/test/base/gguf_ranking_finisher_test.py diff --git a/docs/en/finisher_gguf_ranking.md b/docs/en/finisher_gguf_ranking.md new file mode 100644 index 00000000000000..e92a83fc365ae0 --- /dev/null +++ b/docs/en/finisher_gguf_ranking.md @@ -0,0 +1,152 @@ +# GGUFRankingFinisher + +The `GGUFRankingFinisher` is a Spark NLP finisher designed to post-process the output of `AutoGGUFReranker`. It provides advanced ranking capabilities including top-k selection, score-based filtering, and normalization. + +## Features + +- **Top-K Selection**: Select only the top k documents by relevance score +- **Score Thresholding**: Filter documents by minimum relevance score +- **Min-Max Scaling**: Normalize relevance scores to 0-1 range +- **Sorting**: Automatically sorts documents by relevance score in descending order +- **Ranking**: Adds rank metadata to each document + +## Parameters + +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `inputCols` | `Array[String]` | Name of input annotation columns containing reranked documents | - | +| `outputCol` | `String` | Name of output annotation column containing ranked documents | `"ranked_documents"` | +| `topK` | `Int` | Maximum number of top documents to return (-1 for no limit) | `-1` | +| `minRelevanceScore` | `Double` | Minimum relevance score threshold | `Double.MinValue` | +| `minMaxScaling` | `Boolean` | Whether to apply min-max scaling to normalize scores | `false` | + +## Usage + +### Basic Usage + +```scala +import com.johnsnowlabs.nlp.finisher.GGUFRankingFinisher + +val finisher = new GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") +``` + +### Top-K Selection + +```scala +val finisher = new GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") + .setTopK(5) // Get top 5 most relevant documents +``` + +### Score Thresholding + +```scala +val finisher = new GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") + .setMinRelevanceScore(0.3) // Only documents with score >= 0.3 +``` + +### Min-Max Scaling + +```scala +val finisher = new GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") + .setMinMaxScaling(true) // Normalize scores to 0-1 range +``` + +### Combined Usage + +```scala +val finisher = new GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") + .setTopK(3) + .setMinRelevanceScore(0.2) + .setMinMaxScaling(true) +``` + +## Complete Pipeline Example + +```scala +import com.johnsnowlabs.nlp.base.DocumentAssembler +import com.johnsnowlabs.nlp.annotators.seq2seq.AutoGGUFReranker +import com.johnsnowlabs.nlp.finisher.GGUFRankingFinisher +import org.apache.spark.ml.Pipeline + +// Document assembler +val documentAssembler = new DocumentAssembler() + .setInputCol("text") + .setOutputCol("document") + +// Reranker +val reranker = AutoGGUFReranker + .pretrained("bge-reranker-v2-m3-Q4_K_M") + .setInputCols("document") + .setOutputCol("reranked_documents") + .setQuery("A man is eating pasta.") + +// Finisher +val finisher = new GGUFRankingFinisher() + .setInputCols("reranked_documents") + 
.setOutputCol("ranked_documents") + .setTopK(3) + .setMinMaxScaling(true) + +// Pipeline +val pipeline = new Pipeline() + .setStages(Array(documentAssembler, reranker, finisher)) +``` + +## Python Usage + +```python +from sparknlp.finisher import GGUFRankingFinisher +from sparknlp.annotator import AutoGGUFReranker +from sparknlp.base import DocumentAssembler +from pyspark.ml import Pipeline + +# Create finisher +finisher = GGUFRankingFinisher() \ + .setInputCols("reranked_documents") \ + .setOutputCol("ranked_documents") \ + .setTopK(3) \ + .setMinMaxScaling(True) + +# Create pipeline +pipeline = Pipeline(stages=[document_assembler, reranker, finisher]) +``` + +## Output Schema + +The finisher produces a DataFrame with the output annotation column containing ranked documents. Each document annotation contains: + +- **result**: The document text +- **metadata**: Including `relevance_score`, `rank`, and original `query` information +- **begin/end**: Character positions in the original text +- **annotatorType**: Set to `DOCUMENT` + +## Processing Order + +The finisher applies operations in the following order: + +1. **Extract** documents and metadata from annotations across all rows +2. **Scale** relevance scores (if min-max scaling is enabled) +3. **Filter** by minimum relevance score threshold +4. **Sort** by relevance score (descending) +5. **Limit** to top-k results globally (if specified) +6. **Add rank** metadata to each document +7. **Return** filtered rows with ranked annotations + +## Notes + +- The finisher expects input from `AutoGGUFReranker` or compatible annotators that produce documents with `relevance_score` metadata +- Min-max scaling is applied before threshold filtering, so thresholds should be set according to the scaled range (0.0-1.0) +- Results are always sorted by relevance score in descending order +- Top-k filtering is applied globally across all input rows, not per row +- The finisher adds `rank` metadata to each document indicating its position in the ranking +- Rows with empty annotation arrays are filtered out from the result diff --git a/python/sparknlp/base/__init__.py b/python/sparknlp/base/__init__.py index f4fbeadc55e91d..95facb1ea1a68d 100644 --- a/python/sparknlp/base/__init__.py +++ b/python/sparknlp/base/__init__.py @@ -17,6 +17,7 @@ from sparknlp.base.multi_document_assembler import * from sparknlp.base.embeddings_finisher import * from sparknlp.base.finisher import * +from sparknlp.base.gguf_ranking_finisher import * from sparknlp.base.graph_finisher import * from sparknlp.base.has_recursive_fit import * from sparknlp.base.has_recursive_transform import * diff --git a/python/sparknlp/base/gguf_ranking_finisher.py b/python/sparknlp/base/gguf_ranking_finisher.py new file mode 100644 index 00000000000000..7c7ca423015adf --- /dev/null +++ b/python/sparknlp/base/gguf_ranking_finisher.py @@ -0,0 +1,234 @@ +# Copyright 2017-2024 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Contains classes for the GGUFRankingFinisher.""" + +from pyspark import keyword_only +from pyspark.ml.param import TypeConverters, Params, Param +from sparknlp.internal import AnnotatorTransformer + + +class GGUFRankingFinisher(AnnotatorTransformer): + """Finisher for AutoGGUFReranker outputs that provides ranking capabilities + including top-k selection, sorting by relevance score, and score normalization. + + This finisher processes the output of AutoGGUFReranker, which contains documents with + relevance scores in their metadata. It provides several options for post-processing: + + - Top-k selection: Select only the top k documents by relevance score + - Score thresholding: Filter documents by minimum relevance score + - Min-max scaling: Normalize relevance scores to 0-1 range + - Sorting: Sort documents by relevance score in descending order + - Ranking: Add rank information to document metadata + + The finisher preserves the document annotation structure while adding ranking information + to the metadata and optionally filtering/sorting the documents. + + For extended examples of usage, see the `Examples + `__. + + ====================== ====================== + Input Annotation types Output Annotation type + ====================== ====================== + ``DOCUMENT`` ``DOCUMENT`` + ====================== ====================== + + Parameters + ---------- + inputCols + Name of input annotation columns containing reranked documents + outputCol + Name of output annotation column containing ranked documents, by default "ranked_documents" + topK + Maximum number of top documents to return based on relevance score (-1 for no limit), by default -1 + minRelevanceScore + Minimum relevance score threshold for filtering documents, by default Double.MinValue + minMaxScaling + Whether to apply min-max scaling to normalize relevance scores to 0-1 range, by default False + + Examples + -------- + >>> import sparknlp + >>> from sparknlp.base import * + >>> from sparknlp.annotator import * + >>> from pyspark.ml import Pipeline + >>> documentAssembler = DocumentAssembler() \\ + ... .setInputCol("text") \\ + ... .setOutputCol("document") + >>> reranker = AutoGGUFReranker.pretrained("bge-reranker-v2-m3-Q4_K_M") \\ + ... .setInputCols("document") \\ + ... .setOutputCol("reranked_documents") \\ + ... .setQuery("A man is eating pasta.") + >>> finisher = GGUFRankingFinisher() \\ + ... .setInputCols("reranked_documents") \\ + ... .setOutputCol("ranked_documents") \\ + ... .setTopK(3) \\ + ... .setMinMaxScaling(True) + >>> pipeline = Pipeline().setStages([documentAssembler, reranker, finisher]) + >>> data = spark.createDataFrame([ + ... ("A man is eating food.",), + ... ("A man is eating a piece of bread.",), + ... ("The girl is carrying a baby.",), + ... ("A man is riding a horse.",) + ... 
], ["text"]) + >>> result = pipeline.fit(data).transform(data) + >>> result.select("ranked_documents").show(truncate=False) + # Documents will be sorted by relevance with rank information in metadata + """ + + name = "GGUFRankingFinisher" + + inputCols = Param(Params._dummy(), + "inputCols", + "Name of input annotation columns containing reranked documents", + typeConverter=TypeConverters.toListString) + + outputCol = Param(Params._dummy(), + "outputCol", + "Name of output annotation column containing ranked documents", + typeConverter=TypeConverters.toListString) + + topK = Param(Params._dummy(), + "topK", + "Maximum number of top documents to return based on relevance score (-1 for no limit)", + typeConverter=TypeConverters.toInt) + + minRelevanceScore = Param(Params._dummy(), + "minRelevanceScore", + "Minimum relevance score threshold for filtering documents", + typeConverter=TypeConverters.toFloat) + + minMaxScaling = Param(Params._dummy(), + "minMaxScaling", + "Whether to apply min-max scaling to normalize relevance scores to 0-1 range", + typeConverter=TypeConverters.toBoolean) + + @keyword_only + def __init__(self): + super(GGUFRankingFinisher, self).__init__( + classname="com.johnsnowlabs.nlp.finisher.GGUFRankingFinisher") + self._setDefault( + topK=-1, + minRelevanceScore=float('-inf'), # Equivalent to Double.MinValue + minMaxScaling=False, + outputCol=["ranked_documents"] + ) + + @keyword_only + def setParams(self): + kwargs = self._input_kwargs + return self._set(**kwargs) + + def setInputCols(self, *value): + """Sets input annotation column names. + + Parameters + ---------- + value : List[str] + Input annotation column names containing reranked documents + """ + if len(value) == 1 and isinstance(value[0], list): + return self._set(inputCols=value[0]) + else: + return self._set(inputCols=list(value)) + + def getInputCols(self): + """Gets input annotation column names. + + Returns + ------- + List[str] + Input annotation column names + """ + return self.getOrDefault(self.inputCols) + + def setOutputCol(self, value): + """Sets output annotation column name. + + Parameters + ---------- + value : str + Output annotation column name + """ + return self._set(outputCol=[value]) + + def getOutputCol(self): + """Gets output annotation column name. + + Returns + ------- + str + Output annotation column name + """ + output_cols = self.getOrDefault(self.outputCol) + return output_cols[0] if output_cols else "ranked_documents" + + def setTopK(self, value): + """Sets maximum number of top documents to return. + + Parameters + ---------- + value : int + Maximum number of top documents to return (-1 for no limit) + """ + return self._set(topK=value) + + def getTopK(self): + """Gets maximum number of top documents to return. + + Returns + ------- + int + Maximum number of top documents to return + """ + return self.getOrDefault(self.topK) + + def setMinRelevanceScore(self, value): + """Sets minimum relevance score threshold. + + Parameters + ---------- + value : float + Minimum relevance score threshold + """ + return self._set(minRelevanceScore=value) + + def getMinRelevanceScore(self): + """Gets minimum relevance score threshold. + + Returns + ------- + float + Minimum relevance score threshold + """ + return self.getOrDefault(self.minRelevanceScore) + + def setMinMaxScaling(self, value): + """Sets whether to apply min-max scaling. 
+ + Parameters + ---------- + value : bool + Whether to apply min-max scaling to normalize scores + """ + return self._set(minMaxScaling=value) + + def getMinMaxScaling(self): + """Gets whether to apply min-max scaling. + + Returns + ------- + bool + Whether min-max scaling is enabled + """ + return self.getOrDefault(self.minMaxScaling) diff --git a/python/test/annotator/seq2seq/auto_gguf_reranker_test.py b/python/test/annotator/seq2seq/auto_gguf_reranker_test.py index 2ec3999b1a083e..5f7543a4522aed 100644 --- a/python/test/annotator/seq2seq/auto_gguf_reranker_test.py +++ b/python/test/annotator/seq2seq/auto_gguf_reranker_test.py @@ -254,5 +254,164 @@ def runTest(self): print(f"Expected behavior when query not set: {str(e)}") -if __name__ == "__main__": - unittest.main() +@pytest.mark.slow +class AutoGGUFRerankerWithFinisherTestSpec(unittest.TestCase): + def setUp(self): + self.spark = SparkContextForTest.spark + self.query = "A man is eating pasta." + self.data = ( + self.spark.createDataFrame( + [ + ["A man is eating food."], + ["A man is eating a piece of bread."], + ["The girl is carrying a baby."], + ["A man is riding a horse."], + ["A young girl is playing violin."], + ] + ) + .toDF("text") + .repartition(1) + ) + + def runTest(self): + document_assembler = ( + DocumentAssembler().setInputCol("text").setOutputCol("document") + ) + + model_path = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf" + + # Skip test if model file doesn't exist + import os + if not os.path.exists(model_path): + self.skipTest(f"Model file not found: {model_path}") + + reranker = ( + AutoGGUFReranker.loadSavedModel(model_path, self.spark) + .setInputCols("document") + .setOutputCol("reranked_documents") + .setBatchSize(4) + .setQuery(self.query) + ) + + # Add the GGUFRankingFinisher to test the full pipeline + finisher = ( + GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") + .setTopK(3) + .setMinMaxScaling(True) + ) + + pipeline = Pipeline().setStages([document_assembler, reranker, finisher]) + results = pipeline.fit(self.data).transform(self.data) + + # Check that results are returned + collected_results = results.collect() + self.assertGreater(len(collected_results), 0) + + # Should have at most 3 results due to topK + self.assertLessEqual(len(collected_results), 3) + + # Check that each result has ranked_documents column + for row in collected_results: + self.assertIsNotNone(row["ranked_documents"]) + # Check that annotations have metadata with relevance_score and rank + annotations = row["ranked_documents"] + for annotation in annotations: + self.assertIn("relevance_score", annotation.metadata) + self.assertIn("rank", annotation.metadata) + self.assertIn("query", annotation.metadata) + self.assertEqual(annotation.metadata["query"], self.query) + + # Check that relevance score is normalized (due to minMaxScaling) + score = float(annotation.metadata["relevance_score"]) + self.assertTrue(0.0 <= score <= 1.0) + + # Check that rank is a valid integer + rank = int(annotation.metadata["rank"]) + self.assertIsInstance(rank, int) + self.assertGreaterEqual(rank, 1) + + # Verify that results are sorted by rank + for row in collected_results: + annotations = row["ranked_documents"] + ranks = [int(annotation.metadata["rank"]) for annotation in annotations] + self.assertEqual(ranks, sorted(ranks)) + + print("Pipeline with AutoGGUFReranker and GGUFRankingFinisher completed successfully") + results.select("ranked_documents").show(truncate=False) + + +@pytest.mark.slow +class 
AutoGGUFRerankerFinisherCombinedFiltersTestSpec(unittest.TestCase): + def setUp(self): + self.spark = SparkContextForTest.spark + self.query = "A man is eating pasta." + self.data = ( + self.spark.createDataFrame( + [ + ["A man is eating food."], + ["A man is eating a piece of bread."], + ["The girl is carrying a baby."], + ["A man is riding a horse."], + ["A young girl is playing violin."], + ["A woman is cooking dinner."], + ] + ) + .toDF("text") + .repartition(1) + ) + + def runTest(self): + document_assembler = ( + DocumentAssembler().setInputCol("text").setOutputCol("document") + ) + + model_path = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf" + + # Skip test if model file doesn't exist + import os + if not os.path.exists(model_path): + self.skipTest(f"Model file not found: {model_path}") + + reranker = ( + AutoGGUFReranker.loadSavedModel(model_path, self.spark) + .setInputCols("document") + .setOutputCol("reranked_documents") + .setBatchSize(4) + .setQuery(self.query) + ) + + # Test with combined filters: topK, threshold, and scaling + finisher = ( + GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") + .setTopK(2) + .setMinRelevanceScore(0.1) + .setMinMaxScaling(True) + ) + + pipeline = Pipeline().setStages([document_assembler, reranker, finisher]) + results = pipeline.fit(self.data).transform(self.data) + + collected_results = results.collect() + + # Should have at most 2 results due to topK + self.assertLessEqual(len(collected_results), 2) + + # Check that all results meet the criteria + for row in collected_results: + annotations = row["ranked_documents"] + for annotation in annotations: + # Check normalized scores are >= 0.1 threshold + score = float(annotation.metadata["relevance_score"]) + self.assertTrue(0.1 <= score <= 1.0) + + # Check rank metadata exists + self.assertIn("rank", annotation.metadata) + rank = int(annotation.metadata["rank"]) + self.assertGreaterEqual(rank, 1) + + print("Combined filters test completed successfully") + results.select("ranked_documents").show(truncate=False) diff --git a/python/test/base/gguf_ranking_finisher_test.py b/python/test/base/gguf_ranking_finisher_test.py new file mode 100644 index 00000000000000..e8a907be188b25 --- /dev/null +++ b/python/test/base/gguf_ranking_finisher_test.py @@ -0,0 +1,282 @@ +# Copyright 2017-2024 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import pytest +from pyspark.sql.types import StructType, StructField +from pyspark.sql import Row + +from sparknlp.base import GGUFRankingFinisher +from sparknlp.annotation import Annotation +from test.util import SparkContextForTest + + +@pytest.mark.fast +class GGUFRankingFinisherTestSpec(unittest.TestCase): + + def setUp(self): + self.spark = SparkContextForTest.spark + + def create_mock_reranker_output(self): + """Create mock data to simulate AutoGGUFReranker output.""" + + documents = [ + ("A man is eating food.", 0.85, "A man is eating pasta."), + ("A man is eating a piece of bread.", 0.72, "A man is eating pasta."), + ("The girl is carrying a baby.", 0.15, "A man is eating pasta."), + ("A man is riding a horse.", 0.28, "A man is eating pasta."), + ("A young girl is playing violin.", 0.05, "A man is eating pasta.") + ] + + annotations = [] + for text, score, query in documents: + annotation = Annotation( + annotatorType="document", + begin=0, + end=len(text) - 1, + result=text, + metadata={"relevance_score": str(score), "query": query}, + embeddings=[] + ) + annotations.append(annotation) + + # Create DataFrame with annotation array + rows = [Row(reranked_documents=annotations)] + schema = StructType([ + StructField("reranked_documents", Annotation.arrayType(), nullable=False) + ]) + + return self.spark.createDataFrame(rows, schema) + + def test_default_settings(self): + """Test GGUFRankingFinisher with default settings.""" + mock_data = self.create_mock_reranker_output() + + finisher = GGUFRankingFinisher() \ + .setInputCols("reranked_documents") \ + .setOutputCol("ranked_documents") + + result = finisher.transform(mock_data) + + self.assertIn("ranked_documents", result.columns) + + # Get the ranked documents + ranked_docs = result.collect()[0]["ranked_documents"] + + self.assertEqual(len(ranked_docs), 5) + + # Check that results are sorted by relevance score in descending order + scores = [float(doc.metadata["relevance_score"]) for doc in ranked_docs] + self.assertEqual(scores, sorted(scores, reverse=True)) + + # Check that rank metadata is added + ranks = [int(doc.metadata["rank"]) for doc in ranked_docs] + self.assertEqual(ranks, [1, 2, 3, 4, 5]) + + def test_top_k(self): + """Test GGUFRankingFinisher with topK setting.""" + mock_data = self.create_mock_reranker_output() + + finisher = GGUFRankingFinisher() \ + .setInputCols("reranked_documents") \ + .setOutputCol("ranked_documents") \ + .setTopK(3) + + result = finisher.transform(mock_data) + + ranked_docs = result.collect()[0]["ranked_documents"] + self.assertEqual(len(ranked_docs), 3) + + # Check that we get the top 3 scores + scores = [float(doc.metadata["relevance_score"]) for doc in ranked_docs] + self.assertEqual(len(scores), 3) + self.assertIn(0.85, scores) + self.assertIn(0.72, scores) + self.assertIn(0.28, scores) + + # Check ranks are correct + ranks = [int(doc.metadata["rank"]) for doc in ranked_docs] + self.assertEqual(ranks, [1, 2, 3]) + + def test_threshold_filtering(self): + """Test GGUFRankingFinisher with minimum relevance score threshold.""" + mock_data = self.create_mock_reranker_output() + + finisher = GGUFRankingFinisher() \ + .setInputCols("reranked_documents") \ + .setOutputCol("ranked_documents") \ + .setMinRelevanceScore(0.3) + + result = finisher.transform(mock_data) + + ranked_docs = result.collect()[0]["ranked_documents"] + self.assertEqual(len(ranked_docs), 2) # Only scores >= 0.3 (0.85 and 0.72) + + scores = [float(doc.metadata["relevance_score"]) for doc in ranked_docs] + 
self.assertTrue(all(score >= 0.3 for score in scores)) + + def test_min_max_scaling(self): + """Test GGUFRankingFinisher with min-max scaling.""" + mock_data = self.create_mock_reranker_output() + + finisher = GGUFRankingFinisher() \ + .setInputCols("reranked_documents") \ + .setOutputCol("ranked_documents") \ + .setMinMaxScaling(True) + + result = finisher.transform(mock_data) + + ranked_docs = result.collect()[0]["ranked_documents"] + scores = [float(doc.metadata["relevance_score"]) for doc in ranked_docs] + + # Check that scores are between 0 and 1 + self.assertTrue(all(0.0 <= score <= 1.0 for score in scores)) + + # Check that we have both min (0.0) and max (1.0) values + self.assertIn(1.0, scores) # Max original score should be 1.0 + self.assertIn(0.0, scores) # Min original score should be 0.0 + + def test_combined_filters(self): + """Test GGUFRankingFinisher with combined topK, threshold, and scaling.""" + mock_data = self.create_mock_reranker_output() + + finisher = GGUFRankingFinisher() \ + .setInputCols("reranked_documents") \ + .setOutputCol("ranked_documents") \ + .setTopK(2) \ + .setMinRelevanceScore(0.1) \ + .setMinMaxScaling(True) + + result = finisher.transform(mock_data) + + ranked_docs = result.collect()[0]["ranked_documents"] + + # Should have at most 2 results due to topK + self.assertLessEqual(len(ranked_docs), 2) + + scores = [float(doc.metadata["relevance_score"]) for doc in ranked_docs] + + # All scores should be >= 0.1 and <= 1.0 + self.assertTrue(all(0.1 <= score <= 1.0 for score in scores)) + + # Results should be sorted descending + self.assertEqual(scores, sorted(scores, reverse=True)) + + # Check that ranks are correct + ranks = [int(doc.metadata["rank"]) for doc in ranked_docs] + self.assertEqual(ranks, list(range(1, len(ranked_docs) + 1))) + + def test_empty_input(self): + """Test GGUFRankingFinisher with empty input.""" + + # Create empty annotations + rows = [Row(reranked_documents=[])] + schema = StructType([ + StructField("reranked_documents", Annotation.arrayType(), nullable=False) + ]) + + empty_data = self.spark.createDataFrame(rows, schema) + + finisher = GGUFRankingFinisher() \ + .setInputCols("reranked_documents") \ + .setOutputCol("ranked_documents") + + result = finisher.transform(empty_data) + + # Since empty rows are filtered out, the result should have no rows + result_count = result.count() + self.assertEqual(result_count, 0) + + def test_query_preservation(self): + """Test that query information is preserved in metadata.""" + mock_data = self.create_mock_reranker_output() + + finisher = GGUFRankingFinisher() \ + .setInputCols("reranked_documents") \ + .setOutputCol("ranked_documents") + + result = finisher.transform(mock_data) + + ranked_docs = result.collect()[0]["ranked_documents"] + + # Check that query information is preserved in metadata + for doc in ranked_docs: + self.assertIn("query", doc.metadata) + self.assertEqual(doc.metadata["query"], "A man is eating pasta.") + + def test_missing_relevance_scores(self): + """Test handling of documents with missing relevance scores.""" + + documents = [ + ("A man is eating food.", {"relevance_score": "0.85", "query": "A man is eating pasta."}), + ("A man is eating a piece of bread.", {"query": "A man is eating pasta."}), # Missing score + ("The girl is carrying a baby.", {"relevance_score": "0.15", "query": "A man is eating pasta."}) + ] + + annotations = [] + for text, metadata in documents: + annotation = Annotation( + annotatorType="document", + begin=0, + end=len(text) - 1, + result=text, + 
metadata=metadata,
+                embeddings=[]
+            )
+            annotations.append(annotation)
+
+        rows = [Row(reranked_documents=annotations)]
+        schema = StructType([
+            StructField("reranked_documents", Annotation.arrayType(), nullable=False)
+        ])
+
+        test_data = self.spark.createDataFrame(rows, schema)
+
+        finisher = GGUFRankingFinisher() \
+            .setInputCols("reranked_documents") \
+            .setOutputCol("ranked_documents")
+
+        result = finisher.transform(test_data)
+
+        ranked_docs = result.collect()[0]["ranked_documents"]
+        self.assertEqual(len(ranked_docs), 3)
+
+        # Document with missing score should get 0.0 and be ranked last
+        scores = [float(doc.metadata["relevance_score"]) for doc in ranked_docs]
+        self.assertEqual(scores[-1], 0.0)  # Missing score becomes 0.0
+
+    def test_parameter_getters_setters(self):
+        """Test parameter getters and setters."""
+        finisher = GGUFRankingFinisher()
+
+        # Test topK
+        finisher.setTopK(5)
+        self.assertEqual(finisher.getTopK(), 5)
+
+        # Test minRelevanceScore
+        finisher.setMinRelevanceScore(0.5)
+        self.assertEqual(finisher.getMinRelevanceScore(), 0.5)
+
+        # Test minMaxScaling
+        finisher.setMinMaxScaling(True)
+        self.assertTrue(finisher.getMinMaxScaling())
+
+        # Test outputCol
+        finisher.setOutputCol("custom_output")
+        self.assertEqual(finisher.getOutputCol(), "custom_output")
+
+
+if __name__ == '__main__':
+    unittest.main()

From 00715e44213ae74dd44ddb3c167b80b8c3184afa Mon Sep 17 00:00:00 2001
From: Prabod Rathnayaka 
Date: Wed, 20 Aug 2025 07:48:10 +0000
Subject: [PATCH 3/7] Update test case to tag "finish the reranked documents"
 as SlowTest and fix its name

---
 .../nlp/annotators/seq2seq/AutoGGUFRerankerTest.scala           | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFRerankerTest.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFRerankerTest.scala
index 9056cee4ca846c..2a25ad03b65fde 100644
--- a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFRerankerTest.scala
+++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFRerankerTest.scala
@@ -127,7 +127,7 @@ class AutoGGUFRerankerTest extends AnyFlatSpec {
     }
   }
 
-  it should "be able to finisher the reranked documents" taggedAs FastTest in {
+  it should "be able to finish the reranked documents" taggedAs SlowTest in {
     model.setQuery(query)
     val pipeline = new Pipeline().setStages(Array(documentAssembler, model, finisher))
     val result = pipeline.fit(data).transform(data)

From 9c1f1b1c6d6d233dad5b7470ebf91e52f0492e43 Mon Sep 17 00:00:00 2001
From: Prabod Rathnayaka 
Date: Mon, 25 Aug 2025 08:17:04 +0000
Subject: [PATCH 4/7] Move GGUFRankingFinisher documentation into
 annotator_entries

---
 .../GGUFRankingFinisher.md}                       | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename docs/en/{finisher_gguf_ranking.md => annotator_entries/GGUFRankingFinisher.md} (100%)

diff --git a/docs/en/finisher_gguf_ranking.md b/docs/en/annotator_entries/GGUFRankingFinisher.md
similarity index 100%
rename from docs/en/finisher_gguf_ranking.md
rename to docs/en/annotator_entries/GGUFRankingFinisher.md

From 80cc8a0f4b2ff6fee25cb0aa6188ebff61de2b50 Mon Sep 17 00:00:00 2001
From: Devin Ha 
Date: Mon, 1 Sep 2025 10:44:30 +0200
Subject: [PATCH 5/7] Resolve partition warning for windowing

---
 .../nlp/finisher/GGUFRankingFinisher.scala    | 28 +++++++++++--------
 .../finisher/GGUFRankingFinisherTest.scala    | 28 ++++++++-----------
 2 files changed, 28 insertions(+), 28 deletions(-)

diff --git 
a/src/main/scala/com/johnsnowlabs/nlp/finisher/GGUFRankingFinisher.scala b/src/main/scala/com/johnsnowlabs/nlp/finisher/GGUFRankingFinisher.scala
index a7343f3ed9620f..f12ba37125cb40 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/finisher/GGUFRankingFinisher.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/finisher/GGUFRankingFinisher.scala
@@ -15,18 +15,15 @@
  */
 package com.johnsnowlabs.nlp.finisher
 
-import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}
 import com.johnsnowlabs.nlp.util.FinisherUtil
+import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}
 import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.param.{BooleanParam, DoubleParam, IntParam, ParamMap, StringArrayParam}
+import org.apache.spark.ml.param._
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
-import org.apache.spark.sql.functions._
 import org.apache.spark.sql.expressions.{UserDefinedFunction, Window}
-import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
-
-import scala.collection.mutable
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 
 /** Finisher for AutoGGUFReranker outputs that provides ranking capabilities including top-k
   * selection, sorting by relevance score, and score normalization.
@@ -57,7 +54,7 @@
   *   .setOutputCol("document")
   *
   * val reranker = AutoGGUFReranker
-  *   .pretrained("bge-reranker-v2-m3-Q4_K_M")
+  *   .pretrained("bge_reranker_v2_m3_Q4_K_M")
   *   .setInputCols("document")
   *   .setOutputCol("reranked_documents")
   *   .setQuery("A man is eating pasta.")
@@ -273,9 +270,18 @@
     }
 
     // Filter by threshold and assign global ranks
-    val rankedDF = scaledDF
-      .filter($"scaled_score" >= getMinRelevanceScore)
-      .withColumn("global_rank", row_number().over(Window.orderBy($"scaled_score".desc)))
+    val filteredDF = scaledDF.filter($"scaled_score" >= getMinRelevanceScore)
+
+    // Order by score and add row numbers using zipWithIndex to avoid Window partitioning
+    val orderedDF = filteredDF.orderBy($"scaled_score".desc)
+    val rankedRDD = orderedDF.rdd.zipWithIndex().map { case (row, index) =>
+      Row.fromSeq(row.toSeq :+ (index + 1L))
+    }
+
+    // Create new schema with rank column
+    val schemaWithRank =
+      orderedDF.schema.add(StructField("global_rank", LongType, nullable = false))
+    val rankedDF = orderedDF.sparkSession.createDataFrame(rankedRDD, schemaWithRank)
 
     // Apply top-k limit if specified
     val limitedDF = if (getTopK > 0) {
diff --git a/src/test/scala/com/johnsnowlabs/nlp/finisher/GGUFRankingFinisherTest.scala b/src/test/scala/com/johnsnowlabs/nlp/finisher/GGUFRankingFinisherTest.scala
index 3797bc8f29eb56..787e85e0bb860e 100644
--- a/src/test/scala/com/johnsnowlabs/nlp/finisher/GGUFRankingFinisherTest.scala
+++ b/src/test/scala/com/johnsnowlabs/nlp/finisher/GGUFRankingFinisherTest.scala
@@ -16,30 +16,24 @@
 
 package com.johnsnowlabs.nlp.finisher
 
-import com.johnsnowlabs.nlp.{Annotation, AnnotatorType, ContentProvider}
-import com.johnsnowlabs.nlp.base.DocumentAssembler
+import com.johnsnowlabs.nlp.AnnotatorType
 import com.johnsnowlabs.nlp.util.io.ResourceHelper
 import com.johnsnowlabs.tags.FastTest
-import com.johnsnowlabs.util.Benchmark
-import org.apache.spark.ml.Pipeline
-import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import 
org.apache.spark.sql.{DataFrame, Row} import org.scalatest.flatspec.AnyFlatSpec class GGUFRankingFinisherTest extends AnyFlatSpec { - - import ResourceHelper.spark.implicits._ // Mock data to simulate AutoGGUFReranker output def createMockRerankerOutput(): DataFrame = { val spark = ResourceHelper.spark val documents = Seq( - ("A man is eating food.", 0.85, "A man is eating pasta."), - ("A man is eating a piece of bread.", 0.72, "A man is eating pasta."), - ("The girl is carrying a baby.", 0.15, "A man is eating pasta."), - ("A man is riding a horse.", 0.28, "A man is eating pasta."), - ("A young girl is playing violin.", 0.05, "A man is eating pasta.")) + ("A man is eating food.", 7.02, "A man is eating pasta."), + ("A man is eating a piece of bread.", 2.1, "A man is eating pasta."), + ("The girl is carrying a baby.", -10.78, "A man is eating pasta."), + ("A man is riding a horse.", -8.43, "A man is eating pasta."), + ("A young girl is playing violin.", -10.77, "A man is eating pasta.")) val mockAnnotations = documents.map { case (text, score, query) => Row( @@ -114,9 +108,9 @@ class GGUFRankingFinisherTest extends AnyFlatSpec { val scores = rankedDocs.map(_.getAs[Map[String, String]]("metadata")("relevance_score").toDouble) assert(scores.length == 3) - assert(scores.contains(0.85)) - assert(scores.contains(0.72)) - assert(scores.contains(0.28)) + assert(scores.contains(7.02)) + assert(scores.contains(2.1)) + assert(scores.contains(-8.43)) // Check ranks are correct val ranks = rankedDocs.map(_.getAs[Map[String, String]]("metadata")("rank").toInt) @@ -194,7 +188,7 @@ class GGUFRankingFinisherTest extends AnyFlatSpec { // Check that ranks are correct val ranks = rankedDocs.map(_.getAs[Map[String, String]]("metadata")("rank").toInt) - assert(ranks == (1 to rankedDocs.length).toSeq) + assert(ranks == (1 to rankedDocs.length)) } "GGUFRankingFinisher" should "handle empty input" taggedAs FastTest in { From 33b1f80e2d376fdc0ab9f797a5b86328695308a9 Mon Sep 17 00:00:00 2001 From: Devin Ha Date: Mon, 1 Sep 2025 10:53:10 +0200 Subject: [PATCH 6/7] Add GGUFRankingFinisher notebook --- ...RankingFinisher_for_AutoGGUFReranker.ipynb | 252 ++++++++++++++++++ 1 file changed, 252 insertions(+) create mode 100644 examples/python/llama.cpp/GGUFRankingFinisher_for_AutoGGUFReranker.ipynb diff --git a/examples/python/llama.cpp/GGUFRankingFinisher_for_AutoGGUFReranker.ipynb b/examples/python/llama.cpp/GGUFRankingFinisher_for_AutoGGUFReranker.ipynb new file mode 100644 index 00000000000000..0ca948cc0ca2c1 --- /dev/null +++ b/examples/python/llama.cpp/GGUFRankingFinisher_for_AutoGGUFReranker.ipynb @@ -0,0 +1,252 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "3cac0728", + "metadata": {}, + "source": [ + "![JohnSnowLabs](https://sparknlp.org/assets/images/logo.png)\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/llama.cpp/GGUFRankingFinisher_for_AutoGGUFReranker.ipynb)\n", + "\n", + "# GGUFRankingFinisher for AutoGGUFReranker\n", + "\n", + "This notebook will show you how to use the `GGUFRankingFinisher` to post-process the relevance scores produced by the AutoGGUFReranker.\n", + "\n", + "Let's keep in mind a few things before we start 😊\n", + "\n", + "- `AutoGGUFReranker` was introduced in `Spark NLP 6.1.2`, enabling efficient and quantized reranking of documents with LLMs. 
Please make sure you have upgraded to the latest Spark NLP release.\n",
+    "- `GGUFRankingFinisher` was introduced in `Spark NLP 6.1.3` to post-process the document rankings.\n",
+    "\n",
+    "`GGUFRankingFinisher` is a finisher for `AutoGGUFReranker` outputs that provides ranking capabilities\n",
+    "including top-k selection, sorting by relevance score, and score normalization.\n",
+    "\n",
+    "This finisher processes the output of AutoGGUFReranker, which contains documents with\n",
+    "relevance scores in their metadata. It provides several options for post-processing:\n",
+    "\n",
+    "- Top-k selection: Select only the top k documents by relevance score\n",
+    "- Score thresholding: Filter documents by minimum relevance score\n",
+    "- Min-max scaling: Normalize relevance scores to 0-1 range\n",
+    "- Sorting: Sort documents by relevance score in descending order\n",
+    "- Ranking: Add rank information to document metadata\n",
+    "\n",
+    "The finisher preserves the document annotation structure while adding ranking information\n",
+    "to the metadata and optionally filtering/sorting the documents.\n",
+    "\n",
+    "## Spark NLP Setup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7c8568c4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Only execute this if you are on Google Colab\n",
+    "! wget -q http://setup.johnsnowlabs.com/colab.sh -O - | bash"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a631ac47",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "6.1.3\n"
+     ]
+    }
+   ],
+   "source": [
+    "import sparknlp\n",
+    "\n",
+    "# Let's start Spark NLP with GPU enabled. If you don't have a GPU available, remove this parameter.\n",
+    "spark = sparknlp.start(gpu=True)\n",
+    "print(sparknlp.version())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "46c81114",
+   "metadata": {},
+   "source": [
+    "## Producing Document Rankings\n",
+    "\n",
+    "Let's start by producing some document rankings. We first define a suitable pipeline and run it on some data. The relevance scores will then be in the metadata."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7d781eb9",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Extracted 'libjllama.so' to '/tmp/libjllama.so'\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no \n",
+      "ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no\n",
+      "ggml_cuda_init: found 1 CUDA devices:\n",
+      "  Device 0: NVIDIA GeForce RTX 3070, compute capability 8.6, VMM: yes\n",
+      "[Stage 1:>                                                          (0 + 4) / 4]\r"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "+---------------------------------+-------------------------------------------+\n",
+      "|result                           |reranked_document.metadata[relevance_score]|\n",
+      "+---------------------------------+-------------------------------------------+\n",
+      "|A man is eating food.            |7.023443                                   |\n",
+      "|A man is eating a piece of bread.|2.1200795                                  |\n",
+      "|The girl is carrying a baby.     |-10.790537                                 |\n",
+      "|A man is riding a horse.         |-8.433026                                  |\n",
+      "|A young girl is playing violin.  |-10.778883                                 |\n",
+      "+---------------------------------+-------------------------------------------+\n",
+      "\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "                                                                                \r"
+     ]
+    }
+   ],
+   "source": [
+    "import sparknlp\n",
+    "from sparknlp.base import *\n",
+    "from sparknlp.annotator import *\n",
+    "from pyspark.ml import Pipeline\n",
+    "\n",
+    "document_assembler = DocumentAssembler().setInputCol(\"text\").setOutputCol(\"document\")\n",
+    "\n",
+    "auto_gguf_model = (\n",
+    "    AutoGGUFReranker.loadSavedModel(\n",
+    "        \"/home/ducha/Workspace/scala/spark-nlp-release/tmp_autogguf_reranker/bge-reranker-v2-m3-q4_k_m.gguf\",\n",
+    "        spark,\n",
+    "    )\n",
+    "    .setInputCols(\"document\")\n",
+    "    .setOutputCol(\"reranked_documents\")\n",
+    "    .setQuery(\"A man is eating pasta.\")\n",
+    "    .setDisableLog(True)\n",
+    ")\n",
+    "\n",
+    "pipeline = Pipeline().setStages([document_assembler, auto_gguf_model])\n",
+    "\n",
+    "data = spark.createDataFrame(\n",
+    "    [\n",
+    "        [\"A man is eating food.\"],\n",
+    "        [\"A man is eating a piece of bread.\"],\n",
+    "        [\"The girl is carrying a baby.\"],\n",
+    "        [\"A man is riding a horse.\"],\n",
+    "        [\"A young girl is playing violin.\"],\n",
+    "    ]\n",
+    ").toDF(\"text\")\n",
+    "\n",
+    "result = pipeline.fit(data).transform(data)\n",
+    "\n",
+    "\n",
+    "# Verify results contain relevance scores\n",
+    "result.selectExpr(\"explode(reranked_documents) as reranked_document\").selectExpr(\n",
+    "    \"reranked_document.result\", \"reranked_document.metadata['relevance_score']\"\n",
+    ").show(truncate=False)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "12c3bdd0",
+   "metadata": {},
+   "source": [
+    "## Post-Processing Ranking with `GGUFRankingFinisher`\n",
+    "\n",
+    "Let's now use the `GGUFRankingFinisher` to post-process and sort our results. For this, the finisher will\n",
+    "\n",
+    "1. automatically sort the documents by relevance score\n",
+    "2. only choose the top 3 results\n",
+    "3. scale the relevance scores to the range $[0, 1]$ (the scaled value replaces `relevance_score` in the metadata)\n",
+    "4. 
set a minimum relevance score after rescaling\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e40e9051", + "metadata": {}, + "outputs": [], + "source": [ + "from sparknlp.base import *\n", + "\n", + "finisher = (\n", + " GGUFRankingFinisher()\n", + " .setInputCols(\"reranked_documents\")\n", + " .setOutputCol(\"finished_reranked_documents\")\n", + " .setTopK(3)\n", + " .setMinRelevanceScore(0.3)\n", + " .setMinMaxScaling(True)\n", + ")\n", + "finisher_result = finisher.transform(result)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a4836a6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+------------------------------------------------------------------------------------------------------------------------------------------------------------+\n", + "|finished_reranked_documents |\n", + "+------------------------------------------------------------------------------------------------------------------------------------------------------------+\n", + "|{document, 0, 20, A man is eating food., {sentence -> 0, query -> A man is eating pasta., relevance_score -> 1.0000162790005285, rank -> 1}, []} |\n", + "|{document, 0, 32, A man is eating a piece of bread., {sentence -> 0, query -> A man is eating pasta., relevance_score -> 0.7246697769085113, rank -> 2}, []}|\n", + "+------------------------------------------------------------------------------------------------------------------------------------------------------------+\n", + "\n" + ] + } + ], + "source": [ + "finisher_result.selectExpr(\n", + " \"explode(finished_reranked_documents) as finished_reranked_documents\"\n", + ").show(truncate=False)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python (3.10.12)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From d20e35933d361948eb8c4700b7db1eb30bed190f Mon Sep 17 00:00:00 2001 From: Devin Ha Date: Mon, 1 Sep 2025 12:28:52 +0200 Subject: [PATCH 7/7] Change pretrained AutoGGUFReranking model --- docs/en/annotator_entries/AutoGGUFReranker.md | 4 +- .../annotator_entries/GGUFRankingFinisher.md | 2 +- ...RankingFinisher_for_AutoGGUFReranker.ipynb | 5 +- .../annotator/seq2seq/auto_gguf_reranker.py | 8 +- python/sparknlp/base/gguf_ranking_finisher.py | 2 +- .../seq2seq/auto_gguf_reranker_test.py | 78 +++++++------ .../test/base/gguf_ranking_finisher_test.py | 109 +++++++++++------- .../annotators/seq2seq/AutoGGUFReranker.scala | 6 +- .../seq2seq/AutoGGUFRerankerTest.scala | 4 +- 9 files changed, 122 insertions(+), 96 deletions(-) diff --git a/docs/en/annotator_entries/AutoGGUFReranker.md b/docs/en/annotator_entries/AutoGGUFReranker.md index a6c1ca6cfd5b7d..cc510b0d41b8ae 100644 --- a/docs/en/annotator_entries/AutoGGUFReranker.md +++ b/docs/en/annotator_entries/AutoGGUFReranker.md @@ -33,7 +33,7 @@ val reranker = AutoGGUFReranker.pretrained() .setQuery("A man is eating pasta.") ``` -The default model is `"bge-reranker-v2-m3-Q4_K_M"`, if no name is provided. +The default model is `"bge_reranker_v2_m3_Q4_K_M"`, if no name is provided. For available pretrained models please see the [Models Hub](https://sparknlp.org/models). 
@@ -105,7 +105,7 @@ val document = new DocumentAssembler() .setOutputCol("document") val reranker = AutoGGUFReranker - .pretrained("bge-reranker-v2-m3-Q4_K_M") + .pretrained() .setInputCols("document") .setOutputCol("reranked_documents") .setBatchSize(4) diff --git a/docs/en/annotator_entries/GGUFRankingFinisher.md b/docs/en/annotator_entries/GGUFRankingFinisher.md index e92a83fc365ae0..da7d057e85730e 100644 --- a/docs/en/annotator_entries/GGUFRankingFinisher.md +++ b/docs/en/annotator_entries/GGUFRankingFinisher.md @@ -85,7 +85,7 @@ val documentAssembler = new DocumentAssembler() // Reranker val reranker = AutoGGUFReranker - .pretrained("bge-reranker-v2-m3-Q4_K_M") + .pretrained() .setInputCols("document") .setOutputCol("reranked_documents") .setQuery("A man is eating pasta.") diff --git a/examples/python/llama.cpp/GGUFRankingFinisher_for_AutoGGUFReranker.ipynb b/examples/python/llama.cpp/GGUFRankingFinisher_for_AutoGGUFReranker.ipynb index 0ca948cc0ca2c1..95a07f0ee1a361 100644 --- a/examples/python/llama.cpp/GGUFRankingFinisher_for_AutoGGUFReranker.ipynb +++ b/examples/python/llama.cpp/GGUFRankingFinisher_for_AutoGGUFReranker.ipynb @@ -136,10 +136,7 @@ "document_assembler = DocumentAssembler().setInputCol(\"text\").setOutputCol(\"document\")\n", "\n", "auto_gguf_model = (\n", - " AutoGGUFReranker.loadSavedModel(\n", - " \"/home/ducha/Workspace/scala/spark-nlp-release/tmp_autogguf_reranker/bge-reranker-v2-m3-q4_k_m.gguf\",\n", - " spark,\n", - " )\n", + " AutoGGUFReranker.pretrained()\n", " .setInputCols(\"document\")\n", " .setOutputCol(\"reranked_documents\")\n", " .setQuery(\"A man is eating pasta.\")\n", diff --git a/python/sparknlp/annotator/seq2seq/auto_gguf_reranker.py b/python/sparknlp/annotator/seq2seq/auto_gguf_reranker.py index ca922667bb9ea0..5d4c082dd26f89 100755 --- a/python/sparknlp/annotator/seq2seq/auto_gguf_reranker.py +++ b/python/sparknlp/annotator/seq2seq/auto_gguf_reranker.py @@ -47,7 +47,7 @@ class AutoGGUFReranker(AnnotatorModel, HasBatchedAnnotate, HasLlamaCppProperties ... .setOutputCol("reranked_documents") \\ ... .setQuery("A man is eating pasta.") - The default model is ``"bge-reranker-v2-m3-Q4_K_M"``, if no name is provided. + The default model is ``"bge_reranker_v2_m3_Q4_K_M"``, if no name is provided. For extended examples of usage, see the `AutoGGUFRerankerTest `__ @@ -222,7 +222,7 @@ class AutoGGUFReranker(AnnotatorModel, HasBatchedAnnotate, HasLlamaCppProperties >>> document = DocumentAssembler() \\ ... .setInputCol("text") \\ ... .setOutputCol("document") - >>> reranker = AutoGGUFReranker.pretrained("bge-reranker-v2-m3-Q4_K_M") \\ + >>> reranker = AutoGGUFReranker.pretrained() \\ ... .setInputCols(["document"]) \\ ... .setOutputCol("reranked_documents") \\ ... .setBatchSize(4) \\ @@ -307,13 +307,13 @@ def loadSavedModel(folder, spark_session): return AutoGGUFReranker(java_model=jModel) @staticmethod - def pretrained(name="bge-reranker-v2-m3-Q4_K_M", lang="en", remote_loc=None): + def pretrained(name="bge_reranker_v2_m3_Q4_K_M", lang="en", remote_loc=None): """Downloads and loads a pretrained model. 
Parameters ---------- name : str, optional - Name of the pretrained model, by default "bge-reranker-v2-m3-Q4_K_M" + Name of the pretrained model, by default "bge_reranker_v2_m3_Q4_K_M" lang : str, optional Language of the pretrained model, by default "en" remote_loc : str, optional diff --git a/python/sparknlp/base/gguf_ranking_finisher.py b/python/sparknlp/base/gguf_ranking_finisher.py index 7c7ca423015adf..e98c3304ddfaca 100644 --- a/python/sparknlp/base/gguf_ranking_finisher.py +++ b/python/sparknlp/base/gguf_ranking_finisher.py @@ -65,7 +65,7 @@ class GGUFRankingFinisher(AnnotatorTransformer): >>> documentAssembler = DocumentAssembler() \\ ... .setInputCol("text") \\ ... .setOutputCol("document") - >>> reranker = AutoGGUFReranker.pretrained("bge-reranker-v2-m3-Q4_K_M") \\ + >>> reranker = AutoGGUFReranker.pretrained() \\ ... .setInputCols("document") \\ ... .setOutputCol("reranked_documents") \\ ... .setQuery("A man is eating pasta.") diff --git a/python/test/annotator/seq2seq/auto_gguf_reranker_test.py b/python/test/annotator/seq2seq/auto_gguf_reranker_test.py index 5f7543a4522aed..996221dd115138 100644 --- a/python/test/annotator/seq2seq/auto_gguf_reranker_test.py +++ b/python/test/annotator/seq2seq/auto_gguf_reranker_test.py @@ -47,9 +47,10 @@ def runTest(self): # Use a local model path for testing - in real scenarios, use pretrained() model_path = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf" - + # Skip test if model file doesn't exist import os + if not os.path.exists(model_path): self.skipTest(f"Model file not found: {model_path}") @@ -104,33 +105,29 @@ def runTest(self): DocumentAssembler().setInputCol("text").setOutputCol("document") ) - # Test with pretrained model (may not be available in test environment) - try: - reranker = ( - AutoGGUFReranker.pretrained("bge-reranker-v2-m3-Q4_K_M") - .setInputCols("document") - .setOutputCol("reranked_documents") - .setBatchSize(2) - .setQuery(self.query) - ) + reranker = ( + AutoGGUFReranker.pretrained() + .setInputCols("document") + .setOutputCol("reranked_documents") + .setBatchSize(2) + .setQuery(self.query) + ) - pipeline = Pipeline().setStages([document_assembler, reranker]) - results = pipeline.fit(self.data).transform(self.data) + pipeline = Pipeline().setStages([document_assembler, reranker]) + results = pipeline.fit(self.data).transform(self.data) - # Verify results contain relevance scores - collected_results = results.collect() - for row in collected_results: - annotations = row["reranked_documents"] - for annotation in annotations: - self.assertIn("relevance_score", annotation.metadata) - # Relevance score should be a valid number - score = float(annotation.metadata["relevance_score"]) - self.assertIsInstance(score, float) + # Verify results contain relevance scores + collected_results = results.collect() + for row in collected_results: + annotations = row["reranked_documents"] + for annotation in annotations: + self.assertIn("relevance_score", annotation.metadata) + # Relevance score should be a valid number + score = float(annotation.metadata["relevance_score"]) + self.assertIsInstance(score, float) + + results.show() - results.show() - except Exception as e: - # Skip if pretrained model is not available - self.skipTest(f"Pretrained model not available: {str(e)}") @pytest.mark.slow class AutoGGUFRerankerMetadataTestSpec(unittest.TestCase): @@ -139,9 +136,10 @@ def setUp(self): def runTest(self): model_path = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf" - + # Skip test if model file doesn't exist import os + if not 
os.path.exists(model_path): self.skipTest(f"Model file not found: {model_path}") @@ -150,10 +148,11 @@ def runTest(self): metadata = reranker.getMetadata() self.assertIsNotNone(metadata) self.assertGreater(len(metadata), 0) - + print("Model metadata:") print(eval(metadata)) + # # @pytest.mark.slow # class AutoGGUFRerankerSerializationTestSpec(unittest.TestCase): @@ -215,7 +214,7 @@ def runTest(self): # results.select("reranked_documents").show(truncate=False) -@pytest.mark.slow +@pytest.mark.slow class AutoGGUFRerankerErrorHandlingTestSpec(unittest.TestCase): def setUp(self): self.spark = SparkContextForTest.spark @@ -229,9 +228,10 @@ def runTest(self): data = self.spark.createDataFrame([["Test document"]]).toDF("text") model_path = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf" - + # Skip test if model file doesn't exist import os + if not os.path.exists(model_path): self.skipTest(f"Model file not found: {model_path}") @@ -244,7 +244,7 @@ def runTest(self): ) pipeline = Pipeline().setStages([document_assembler, reranker]) - + # This should still work with empty query (based on implementation) try: results = pipeline.fit(data).transform(data) @@ -279,9 +279,10 @@ def runTest(self): ) model_path = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf" - + # Skip test if model file doesn't exist import os + if not os.path.exists(model_path): self.skipTest(f"Model file not found: {model_path}") @@ -322,11 +323,11 @@ def runTest(self): self.assertIn("rank", annotation.metadata) self.assertIn("query", annotation.metadata) self.assertEqual(annotation.metadata["query"], self.query) - + # Check that relevance score is normalized (due to minMaxScaling) score = float(annotation.metadata["relevance_score"]) self.assertTrue(0.0 <= score <= 1.0) - + # Check that rank is a valid integer rank = int(annotation.metadata["rank"]) self.assertIsInstance(rank, int) @@ -338,7 +339,9 @@ def runTest(self): ranks = [int(annotation.metadata["rank"]) for annotation in annotations] self.assertEqual(ranks, sorted(ranks)) - print("Pipeline with AutoGGUFReranker and GGUFRankingFinisher completed successfully") + print( + "Pipeline with AutoGGUFReranker and GGUFRankingFinisher completed successfully" + ) results.select("ranked_documents").show(truncate=False) @@ -368,9 +371,10 @@ def runTest(self): ) model_path = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf" - + # Skip test if model file doesn't exist import os + if not os.path.exists(model_path): self.skipTest(f"Model file not found: {model_path}") @@ -396,7 +400,7 @@ def runTest(self): results = pipeline.fit(self.data).transform(self.data) collected_results = results.collect() - + # Should have at most 2 results due to topK self.assertLessEqual(len(collected_results), 2) @@ -407,7 +411,7 @@ def runTest(self): # Check normalized scores are >= 0.1 threshold score = float(annotation.metadata["relevance_score"]) self.assertTrue(0.1 <= score <= 1.0) - + # Check rank metadata exists self.assertIn("rank", annotation.metadata) rank = int(annotation.metadata["rank"]) diff --git a/python/test/base/gguf_ranking_finisher_test.py b/python/test/base/gguf_ranking_finisher_test.py index e8a907be188b25..abdaa90e5b1966 100644 --- a/python/test/base/gguf_ranking_finisher_test.py +++ b/python/test/base/gguf_ranking_finisher_test.py @@ -37,7 +37,7 @@ def create_mock_reranker_output(self): ("A man is eating a piece of bread.", 0.72, "A man is eating pasta."), ("The girl is carrying a baby.", 0.15, "A man is eating pasta."), ("A man is riding a horse.", 0.28, "A man is eating pasta."), - ("A young girl is playing violin.", 
0.05, "A man is eating pasta.") + ("A young girl is playing violin.", 0.05, "A man is eating pasta."), ] annotations = [] @@ -48,15 +48,15 @@ def create_mock_reranker_output(self): end=len(text) - 1, result=text, metadata={"relevance_score": str(score), "query": query}, - embeddings=[] + embeddings=[], ) annotations.append(annotation) # Create DataFrame with annotation array rows = [Row(reranked_documents=annotations)] - schema = StructType([ - StructField("reranked_documents", Annotation.arrayType(), nullable=False) - ]) + schema = StructType( + [StructField("reranked_documents", Annotation.arrayType(), nullable=False)] + ) return self.spark.createDataFrame(rows, schema) @@ -64,17 +64,19 @@ def test_default_settings(self): """Test GGUFRankingFinisher with default settings.""" mock_data = self.create_mock_reranker_output() - finisher = GGUFRankingFinisher() \ - .setInputCols("reranked_documents") \ + finisher = ( + GGUFRankingFinisher() + .setInputCols("reranked_documents") .setOutputCol("ranked_documents") + ) result = finisher.transform(mock_data) self.assertIn("ranked_documents", result.columns) - + # Get the ranked documents ranked_docs = result.collect()[0]["ranked_documents"] - + self.assertEqual(len(ranked_docs), 5) # Check that results are sorted by relevance score in descending order @@ -89,10 +91,12 @@ def test_top_k(self): """Test GGUFRankingFinisher with topK setting.""" mock_data = self.create_mock_reranker_output() - finisher = GGUFRankingFinisher() \ - .setInputCols("reranked_documents") \ - .setOutputCol("ranked_documents") \ + finisher = ( + GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") .setTopK(3) + ) result = finisher.transform(mock_data) @@ -114,10 +118,12 @@ def test_threshold_filtering(self): """Test GGUFRankingFinisher with minimum relevance score threshold.""" mock_data = self.create_mock_reranker_output() - finisher = GGUFRankingFinisher() \ - .setInputCols("reranked_documents") \ - .setOutputCol("ranked_documents") \ + finisher = ( + GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") .setMinRelevanceScore(0.3) + ) result = finisher.transform(mock_data) @@ -131,10 +137,12 @@ def test_min_max_scaling(self): """Test GGUFRankingFinisher with min-max scaling.""" mock_data = self.create_mock_reranker_output() - finisher = GGUFRankingFinisher() \ - .setInputCols("reranked_documents") \ - .setOutputCol("ranked_documents") \ + finisher = ( + GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") .setMinMaxScaling(True) + ) result = finisher.transform(mock_data) @@ -152,17 +160,19 @@ def test_combined_filters(self): """Test GGUFRankingFinisher with combined topK, threshold, and scaling.""" mock_data = self.create_mock_reranker_output() - finisher = GGUFRankingFinisher() \ - .setInputCols("reranked_documents") \ - .setOutputCol("ranked_documents") \ - .setTopK(2) \ - .setMinRelevanceScore(0.1) \ + finisher = ( + GGUFRankingFinisher() + .setInputCols("reranked_documents") + .setOutputCol("ranked_documents") + .setTopK(2) + .setMinRelevanceScore(0.1) .setMinMaxScaling(True) + ) result = finisher.transform(mock_data) ranked_docs = result.collect()[0]["ranked_documents"] - + # Should have at most 2 results due to topK self.assertLessEqual(len(ranked_docs), 2) @@ -183,15 +193,17 @@ def test_empty_input(self): # Create empty annotations rows = [Row(reranked_documents=[])] - schema = StructType([ - StructField("reranked_documents", 
Annotation.arrayType(), nullable=False) - ]) + schema = StructType( + [StructField("reranked_documents", Annotation.arrayType(), nullable=False)] + ) empty_data = self.spark.createDataFrame(rows, schema) - finisher = GGUFRankingFinisher() \ - .setInputCols("reranked_documents") \ + finisher = ( + GGUFRankingFinisher() + .setInputCols("reranked_documents") .setOutputCol("ranked_documents") + ) result = finisher.transform(empty_data) @@ -203,14 +215,16 @@ def test_query_preservation(self): """Test that query information is preserved in metadata.""" mock_data = self.create_mock_reranker_output() - finisher = GGUFRankingFinisher() \ - .setInputCols("reranked_documents") \ + finisher = ( + GGUFRankingFinisher() + .setInputCols("reranked_documents") .setOutputCol("ranked_documents") + ) result = finisher.transform(mock_data) ranked_docs = result.collect()[0]["ranked_documents"] - + # Check that query information is preserved in metadata for doc in ranked_docs: self.assertIn("query", doc.metadata) @@ -220,9 +234,18 @@ def test_missing_relevance_scores(self): """Test handling of documents with missing relevance scores.""" documents = [ - ("A man is eating food.", {"relevance_score": "0.85", "query": "A man is eating pasta."}), - ("A man is eating a piece of bread.", {"query": "A man is eating pasta."}), # Missing score - ("The girl is carrying a baby.", {"relevance_score": "0.15", "query": "A man is eating pasta."}) + ( + "A man is eating food.", + {"relevance_score": "0.85", "query": "A man is eating pasta."}, + ), + ( + "A man is eating a piece of bread.", + {"query": "A man is eating pasta."}, + ), # Missing score + ( + "The girl is carrying a baby.", + {"relevance_score": "0.15", "query": "A man is eating pasta."}, + ), ] annotations = [] @@ -233,20 +256,22 @@ def test_missing_relevance_scores(self): end=len(text) - 1, result=text, metadata=metadata, - embeddings=[] + embeddings=[], ) annotations.append(annotation) rows = [Row(reranked_documents=annotations)] - schema = StructType([ - StructField("reranked_documents", Annotation.arrayType(), nullable=False) - ]) + schema = StructType( + [StructField("reranked_documents", Annotation.arrayType(), nullable=False)] + ) test_data = self.spark.createDataFrame(rows, schema) - finisher = GGUFRankingFinisher() \ - .setInputCols("reranked_documents") \ + finisher = ( + GGUFRankingFinisher() + .setInputCols("reranked_documents") .setOutputCol("ranked_documents") + ) result = finisher.transform(test_data) @@ -278,5 +303,5 @@ def test_parameter_getters_setters(self): self.assertEqual(finisher.getOutputCol(), "custom_output") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFReranker.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFReranker.scala index 61dd0681dbb427..f916ad82a071d2 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFReranker.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFReranker.scala @@ -57,7 +57,7 @@ import java.util.{ArrayList, List} * .setOutputCol("reranked_documents") * .setQuery("A man is eating pasta.") * }}} - * The default model is `"bge-reranker-v2-m3-Q4_K_M"`, if no name is provided. + * The default model is `"bge_reranker_v2_m3_Q4_K_M"`, if no name is provided. * * For available pretrained models please see the [[https://sparknlp.org/models Models Hub]]. 
* @@ -90,7 +90,7 @@ import java.util.{ArrayList, List} * .setOutputCol("document") * * val reranker = AutoGGUFReranker - * .pretrained("bge-reranker-v2-m3-Q4_K_M") + * .pretrained() * .setInputCols("document") * .setOutputCol("reranked_documents") * .setBatchSize(4) @@ -252,7 +252,7 @@ class AutoGGUFReranker(override val uid: String) trait ReadablePretrainedAutoGGUFReranker extends ParamsAndFeaturesFallbackReadable[AutoGGUFReranker] with HasPretrained[AutoGGUFReranker] { - override val defaultModelName: Some[String] = Some("bge-reranker-v2-m3-Q4_K_M") + override val defaultModelName: Some[String] = Some("bge_reranker_v2_m3_Q4_K_M") override val defaultLang: String = "en" /** Java compliant-overrides */ diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFRerankerTest.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFRerankerTest.scala index 2a25ad03b65fde..532d9a921967c8 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFRerankerTest.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFRerankerTest.scala @@ -20,7 +20,7 @@ class AutoGGUFRerankerTest extends AnyFlatSpec { .setOutputCol("document") lazy val query: String = "A man is eating pasta." - lazy val modelPath = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf" + lazy val modelPath = "/tmp/bge_reranker_v2_m3_Q4_K_M.gguf" lazy val model: AutoGGUFReranker = AutoGGUFReranker .loadSavedModel(modelPath, ResourceHelper.spark) .setInputCols("document") @@ -91,7 +91,7 @@ class AutoGGUFRerankerTest extends AnyFlatSpec { } it should "contain metadata when loadSavedModel" taggedAs SlowTest in { - lazy val modelPath = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf" + lazy val modelPath = "/tmp/bge_reranker_v2_m3_Q4_K_M.gguf" val model = AutoGGUFReranker.loadSavedModel(modelPath, ResourceHelper.spark) val metadata = model.getMetadata assert(metadata.nonEmpty)
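
---

For reference, the windowless ranking introduced in PATCH 5/7 works because `RDD.zipWithIndex` assigns indices in the partition order of the sorted DataFrame, so it reproduces the result of `row_number().over(Window.orderBy(...))` without collapsing every row into a single partition (the cause of Spark's "No Partition Defined for Window operation" warning). Below is a minimal, self-contained sketch of the same pattern; the object and method names (`RankSketch`, `addGlobalRank`) are illustrative and not part of this PR:

```scala
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.{LongType, StructField}

object RankSketch {
  // Sort by score descending and attach a 1-based global rank without a Window.
  def addGlobalRank(df: DataFrame, scoreCol: String): DataFrame = {
    val ordered = df.orderBy(df(scoreCol).desc)
    // zipWithIndex follows the sorted partition order, so indices match the sort
    val ranked = ordered.rdd.zipWithIndex().map { case (row, idx) =>
      Row.fromSeq(row.toSeq :+ (idx + 1L)) // ranks start at 1
    }
    val schema = ordered.schema.add(StructField("global_rank", LongType, nullable = false))
    ordered.sparkSession.createDataFrame(ranked, schema)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("rank-sketch").getOrCreate()
    import spark.implicits._
    // Sample scores taken from the notebook output above
    val docs = Seq(
      ("A man is eating food.", 7.023443),
      ("A man is eating a piece of bread.", 2.1200795),
      ("A man is riding a horse.", -8.433026)
    ).toDF("doc", "score")
    addGlobalRank(docs, "score").show(truncate = false)
    spark.stop()
  }
}
```

The actual finisher applies this pattern after threshold filtering and optional min-max scaling; the sketch shows only the rank-assignment step.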