diff --git a/R/pkg/tests/fulltests/test_mllib_fpm.R b/R/pkg/tests/fulltests/test_mllib_fpm.R
index 69dda52f0c279..0d13eba4c03bd 100644
--- a/R/pkg/tests/fulltests/test_mllib_fpm.R
+++ b/R/pkg/tests/fulltests/test_mllib_fpm.R
@@ -44,7 +44,8 @@ test_that("spark.fpGrowth", {
   expected_association_rules <- data.frame(
     antecedent = I(list(list("2"), list("3"))),
     consequent = I(list(list("1"), list("1"))),
-    confidence = c(1, 1)
+    confidence = c(1, 1),
+    support = c(0.75, 0.5)
   )
 
   expect_equivalent(expected_association_rules, collect(spark.associationRules(model)))
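Note on the expected values above: the `support` of a rule is the frequency of the union of antecedent and consequent divided by the number of training records. A minimal sketch of that arithmetic in plain Scala; the four transactions are an assumption (the test's input data is outside this hunk), chosen to be consistent with the expected `c(0.75, 0.5)`:

```scala
// Support of a rule = count of transactions containing antecedent ∪ consequent,
// divided by the total number of training records. Illustration only; the four
// transactions below are assumed, not taken from the patch.
object SupportSketch {
  def support(freqUnion: Long, numTrainingRecords: Long): Double =
    freqUnion.toDouble / numTrainingRecords

  def main(args: Array[String]): Unit = {
    val transactions = Seq(Set("1", "2"), Set("1", "2"), Set("1", "2", "3"), Set("1", "3"))
    val freq12 = transactions.count(t => Set("1", "2").subsetOf(t)) // = 3
    val freq13 = transactions.count(t => Set("1", "3").subsetOf(t)) // = 2
    println(support(freq12, transactions.size)) // 0.75 -> rule {2} => {1}
    println(support(freq13, transactions.size)) // 0.5  -> rule {3} => {1}
  }
}
```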
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index 85c483c387ad8..e1b1317a2f257 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -20,6 +20,8 @@ package org.apache.spark.ml.fpm
 import scala.reflect.ClassTag
 
 import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
@@ -34,6 +36,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.VersionUtils
 /**
  * Common params for FPGrowth and FPGrowthModel
  */
@@ -187,7 +190,7 @@ class FPGrowth @Since("2.2.0") (
       items.unpersist()
     }
 
-    copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this)
+    copyValues(new FPGrowthModel(uid, frequentItems, data.count())).setParent(this)
   }
 
   @Since("2.2.0")
@@ -217,7 +220,8 @@ object FPGrowth extends DefaultParamsReadable[FPGrowth] {
 @Experimental
 class FPGrowthModel private[ml] (
     @Since("2.2.0") override val uid: String,
-    @Since("2.2.0") @transient val freqItemsets: DataFrame)
+    @Since("2.2.0") @transient val freqItemsets: DataFrame,
+    @Since("2.4.0") val numTrainingRecords: Long)
   extends Model[FPGrowthModel] with FPGrowthParams with MLWritable {
 
   /** @group setParam */
@@ -241,17 +245,17 @@ class FPGrowthModel private[ml] (
   @transient private var _cachedRules: DataFrame = _
   /**
-   * Get association rules fitted using the minConfidence. Returns a dataframe
-   * with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and
-   * "consequent" are Array[T] and "confidence" is Double.
+   * Get association rules fitted by AssociationRules using the minConfidence. Returns a
+   * dataframe with four fields, "antecedent", "consequent", "confidence" and "support", where
+   * "antecedent" and "consequent" are Array[T], and "confidence" and "support" are Double.
    */
   @Since("2.2.0")
   @transient def associationRules: DataFrame = {
     if ($(minConfidence) == _cachedMinConf) {
       _cachedRules
     } else {
-      _cachedRules = AssociationRules
-        .getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence))
+      _cachedRules = AssociationRules.getAssociationRulesFromFP(
+        freqItemsets, "items", "freq", numTrainingRecords, $(minConfidence))
       _cachedMinConf = $(minConfidence)
       _cachedRules
     }
   }
@@ -301,7 +305,7 @@ class FPGrowthModel private[ml] (
 
   @Since("2.2.0")
   override def copy(extra: ParamMap): FPGrowthModel = {
-    val copied = new FPGrowthModel(uid, freqItemsets)
+    val copied = new FPGrowthModel(uid, freqItemsets, numTrainingRecords)
     copyValues(copied, extra).setParent(this.parent)
   }
 
@@ -323,7 +327,8 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
   class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter {
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val extraMetadata = "numTrainingRecords" -> instance.numTrainingRecords
+      DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
       val dataPath = new Path(path, "data").toString
       instance.freqItemsets.write.parquet(dataPath)
     }
   }
@@ -335,10 +340,20 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
     private val className = classOf[FPGrowthModel].getName
 
     override def load(path: String): FPGrowthModel = {
+      implicit val format = DefaultFormats
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
+      val numTrainingRecords = if (major < 2 || (major == 2 && minor < 4)) {
+        // 2.3 and before
+        1L
+      } else {
+        // 2.4+
+        (metadata.metadata \ "numTrainingRecords").extract[Long]
+      }
+
       val dataPath = new Path(path, "data").toString
       val frequentItems = sparkSession.read.parquet(dataPath)
-      val model = new FPGrowthModel(metadata.uid, frequentItems)
+      val model = new FPGrowthModel(metadata.uid, frequentItems, numTrainingRecords)
       metadata.getAndSetParams(model)
       model
     }
@@ -352,29 +367,32 @@ private[fpm] object AssociationRules {
    * @param dataset DataFrame("items"[Array], "freq"[Long]) containing frequent itemsets obtained
    *                from algorithms like [[FPGrowth]].
    * @param itemsCol column name for frequent itemsets
-   * @param freqCol column name for appearance count of the frequent itemsets
-   * @param minConfidence minimum confidence for generating the association rules
-   * @return a DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double])
-   *         containing the association rules.
+   * @param freqCol column name for the appearance count of the frequent itemsets
+   * @param numTrainingRecords number of records in the training dataset
+   * @param minConfidence minimum confidence for generating the association rules
+   * @return a DataFrame("antecedent", "consequent", "confidence", "support") containing the
+   *         association rules.
    */
   def getAssociationRulesFromFP[T: ClassTag](
-    dataset: Dataset[_],
-    itemsCol: String,
-    freqCol: String,
-    minConfidence: Double): DataFrame = {
+      dataset: Dataset[_],
+      itemsCol: String,
+      freqCol: String,
+      numTrainingRecords: Long,
+      minConfidence: Double): DataFrame = {
     val freqItemSetRdd = dataset.select(itemsCol, freqCol).rdd
       .map(row => new FreqItemset(row.getSeq[T](0).toArray, row.getLong(1)))
     val rows = new MLlibAssociationRules()
       .setMinConfidence(minConfidence)
       .run(freqItemSetRdd)
-      .map(r => Row(r.antecedent, r.consequent, r.confidence))
+      .map(r => Row(r.antecedent, r.consequent, r.confidence, r.freqUnion / numTrainingRecords))
     val dt = dataset.schema(itemsCol).dataType
     val schema = StructType(Seq(
       StructField("antecedent", dt, nullable = false),
       StructField("consequent", dt, nullable = false),
-      StructField("confidence", DoubleType, nullable = false)))
+      StructField("confidence", DoubleType, nullable = false),
+      StructField("support", DoubleType, nullable = false)))
     val rules = dataset.sparkSession.createDataFrame(rows, schema)
     rules
   }
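The writer/reader changes above lean on `DefaultParamsWriter.saveMetadata` accepting an extra json4s `JObject` that is merged into the metadata JSON, and on the reader extracting it back out. A self-contained sketch of just that round-trip (illustrative only; the `numTrainingRecords` field name is the one thing taken from this patch):

```scala
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

object MetadataRoundTrip {
  def main(args: Array[String]): Unit = {
    implicit val formats: Formats = DefaultFormats

    // Writer side: DefaultParamsWriter merges a JObject like this into the metadata JSON.
    val extraMetadata: JObject = "numTrainingRecords" -> 4L
    val json = compact(render(extraMetadata)) // {"numTrainingRecords":4}

    // Reader side: pull the field back out, as FPGrowthModelReader.load does.
    val numTrainingRecords = (parse(json) \ "numTrainingRecords").extract[Long]
    assert(numTrainingRecords == 4L)
  }
}
```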
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
index acb83ac31affd..337ce50e68cc0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
@@ -106,7 +106,7 @@ object AssociationRules {
   class Rule[Item] private[fpm] (
       @Since("1.5.0") val antecedent: Array[Item],
       @Since("1.5.0") val consequent: Array[Item],
-      freqUnion: Double,
+      private[spark] val freqUnion: Double,
       freqAntecedent: Double) extends Serializable {
 
   /**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
index 87f8b9034dde8..61ef755649df2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -38,14 +38,15 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
       val data = dataset.withColumn("items", col("items").cast(ArrayType(dt)))
       val model = new FPGrowth().setMinSupport(0.5).fit(data)
       val generatedRules = model.setMinConfidence(0.5).associationRules
+      generatedRules.show()
       val expectedRules = spark.createDataFrame(Seq(
-        (Array("2"), Array("1"), 1.0),
-        (Array("1"), Array("2"), 0.75)
-      )).toDF("antecedent", "consequent", "confidence")
+        (Array("2"), Array("1"), 1.0, 0.75),
+        (Array("1"), Array("2"), 0.75, 0.75)
+      )).toDF("antecedent", "consequent", "confidence", "support")
         .withColumn("antecedent", col("antecedent").cast(ArrayType(dt)))
         .withColumn("consequent", col("consequent").cast(ArrayType(dt)))
-      assert(expectedRules.sort("antecedent").rdd.collect().sameElements(
-        generatedRules.sort("antecedent").rdd.collect()))
+      assert(expectedRules.collect().toSet.equals(
+        generatedRules.collect().toSet))
 
       val transformed = model.transform(data)
       val expectedTransformed = spark.createDataFrame(Seq(
@@ -75,6 +76,17 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     assert(checkDF.count() == 3 && checkDF.filter(col("freq") === col("expectedFreq")).count() == 3)
   }
 
+  test("FPGrowth associationRules") {
+    val model = new FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset)
+    val expectedRules = spark.createDataFrame(Seq(
+      (Array("2"), Array("1"), 1.0, 0.75),
+      (Array("3"), Array("1"), 1.0, 0.25),
+      (Array("1"), Array("3"), 0.25, 0.25),
+      (Array("1"), Array("2"), 0.75, 0.75)
+    )).toDF("antecedent", "consequent", "confidence", "support")
+    assert(expectedRules.collect().toSet.equals(model.associationRules.collect().toSet))
+  }
+
   test("FPGrowth getFreqItems with Null") {
     val df = spark.createDataFrame(Seq(
       (1, Array("1", "2", "3", "5")),
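The mllib change above only widens visibility: `Rule` already carries the absolute `freqUnion` count that `confidence` is computed from, and the ml layer now divides that same count by the dataset size to get `support`. A small sketch of the relationship; the concrete counts are assumptions chosen to match the suite's expected `(confidence, support) = (0.75, 0.75)` row for `{1} => {2}`:

```scala
// confidence is relative to the antecedent's count, support to the dataset size.
// Names mirror mllib's Rule; the counts (freq({1,2}) = 3 out of 4 records) are
// assumed, consistent with the expected rows in the test above.
final case class RuleStats(freqUnion: Double, freqAntecedent: Double) {
  def confidence: Double = freqUnion / freqAntecedent
  def support(numTrainingRecords: Long): Double = freqUnion / numTrainingRecords
}

object RuleStatsCheck extends App {
  val oneImpliesTwo = RuleStats(freqUnion = 3.0, freqAntecedent = 4.0)
  assert(oneImpliesTwo.confidence == 0.75)
  assert(oneImpliesTwo.support(4L) == 0.75)
}
```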
Array("3"), 0.25, 0.25), + (Array("1"), Array("2"), 0.75, 0.75) + )).toDF("antecedent", "consequent", "confidence", "support") + assert(expectedRules.collect().toSet.equals(model.associationRules.collect().toSet)) + } + test("FPGrowth getFreqItems with Null") { val df = spark.createDataFrame(Seq( (1, Array("1", "2", "3", "5")), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index cdc99a48e5b64..f7d47091d21d3 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -101,7 +101,10 @@ object MimaExcludes { ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), // [SPARK-23042] Use OneHotEncoderModel to encode labels in MultilayerPerceptronClassifier - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.classification.LabelConverter") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.classification.LabelConverter"), + + // [SPARK-19939][ML] Add support for association rules in ML + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.fpm.FPGrowthModel.this") ) // Exclude rules for 2.3.x diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index f9394421e0cc4..456de9cf9f099 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -187,29 +187,29 @@ class FPGrowth(JavaEstimator, HasItemsCol, HasPredictionCol, |[z] | |[x, z, y, r, q, t, p] | +------------------------+ - >>> fp = FPGrowth(minSupport=0.2, minConfidence=0.7) + >>> fp = FPGrowth(minSupport=0.4, minConfidence=0.7) >>> fpm = fp.fit(data) >>> fpm.freqItemsets.show(5) - +---------+----+ - | items|freq| - +---------+----+ - | [s]| 3| - | [s, x]| 3| - |[s, x, z]| 2| - | [s, z]| 2| - | [r]| 3| - +---------+----+ + +------+----+ + | items|freq| + +------+----+ + | [s]| 3| + |[s, x]| 3| + | [r]| 3| + | [y]| 3| + |[y, x]| 3| + +------+----+ only showing top 5 rows >>> fpm.associationRules.show(5) - +----------+----------+----------+ - |antecedent|consequent|confidence| - +----------+----------+----------+ - | [t, s]| [y]| 1.0| - | [t, s]| [x]| 1.0| - | [t, s]| [z]| 1.0| - | [p]| [r]| 1.0| - | [p]| [z]| 1.0| - +----------+----------+----------+ + +----------+----------+----------+-------+ + |antecedent|consequent|confidence|support| + +----------+----------+----------+-------+ + | [t]| [y]| 1.0| 0.5| + | [t]| [x]| 1.0| 0.5| + | [t]| [z]| 1.0| 0.5| + | [y, t, x]| [z]| 1.0| 0.5| + | [x]| [s]| 0.75| 0.5| + +----------+----------+----------+-------+ only showing top 5 rows >>> new_data = spark.createDataFrame([(["t", "s"], )], ["items"]) >>> sorted(fpm.transform(new_data).first().prediction) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 5c87d1de4139b..bb0586ff74420 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -2158,8 +2158,8 @@ def test_association_rules(self): fpm = fp.fit(self.data) expected_association_rules = self.spark.createDataFrame( - [([3], [1], 1.0), ([2], [1], 1.0)], - ["antecedent", "consequent", "confidence"] + [([3], [1], 1.0, 0.5), ([2], [1], 1.0, 0.75)], + ["antecedent", "consequent", "confidence", "support"] ) actual_association_rules = fpm.associationRules