diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 047a378b79aa7..817ecb285a9a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -196,8 +196,13 @@ final class OneVsRestModel private[ml] ( } // output label and label metadata as prediction + val predictionMetadata = new MetadataBuilder() + .withMetadata(labelMetadata) + .putString("name", predictionCol.name) + .build() + aggregatedDataset - .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata) + .withColumn($(predictionCol), labelUDF(col(accColName)), predictionMetadata) .drop(accColName) }