diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 8f00daa59f1a5..12804d08a4bc6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -269,12 +269,8 @@ class FPGrowthModel private[ml] ( val predictUDF = udf((items: Seq[_]) => { if (items != null) { val itemset = items.toSet - brRules.value.flatMap(rule => - if (items != null && rule._1.forall(item => itemset.contains(item))) { - rule._2.filter(item => !itemset.contains(item)) - } else { - Seq.empty - }).distinct + brRules.value.filter(_._1.forall(itemset.contains)) + .flatMap(_._2.filter(!itemset.contains(_))).distinct } else { Seq.empty }}, dt)