From bcac9498a4c8183d01593591d64ee097fe62069a Mon Sep 17 00:00:00 2001
From: sueann
Date: Wed, 22 Feb 2017 15:11:58 -0800
Subject: [PATCH 01/10] skeleton fns, doesn't compile

---
 .../scala/org/apache/spark/ml/recommendation/ALS.scala | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 799e881fad74a..ccf6fa0bca2be 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -327,6 +327,15 @@ class ALSModel private[ml] (
 
   @Since("1.6.0")
   override def write: MLWriter = new ALSModel.ALSModelWriter(this)
+
+  // TODO: output is DataFrame ?? DataSet ?? what exactly is the output schema?
+  def recommendForAllUsers(): DataFrame = {
+
+  }
+
+  def recommendForAllItems(): DataFrame = {
+
+  }
 }
 
 @Since("1.6.0")
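For orientation before the implementation lands: a hedged sketch of the call pattern this skeleton grows into over the rest of the series. The `als` estimator, `ratings` DataFrame, and argument values here are hypothetical; the actual signatures are only settled in the later patches.

    val model = als.fit(ratings)                   // a fitted ALSModel
    val userRecs = model.recommendForAllUsers(10)  // top 10 items for every user
    val itemRecs = model.recommendForAllItems(10)  // top 10 users for every item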
From d4616ecdcddc52d48cf90e76b75a5ba94625fb19 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Mon, 27 Feb 2017 16:23:16 -0800
Subject: [PATCH 02/10] simple working tests

---
 .../apache/spark/ml/recommendation/ALS.scala  | 52 ++++++++++++++-----
 .../recommendation/TopByKeyAggregator.scala   | 49 +++++++++++++++++
 .../spark/ml/recommendation/ALSSuite.scala    | 43 ++++++++++++++-
 3 files changed, 131 insertions(+), 13 deletions(-)
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala

diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index ccf6fa0bca2be..9dba9cfc4cf31 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -284,18 +284,18 @@ class ALSModel private[ml] (
   @Since("2.2.0")
   def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)
 
+  private val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
+    if (userFeatures != null && itemFeatures != null) {
+      blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1)
+    } else {
+      Float.NaN
+    }
+  }
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema)
-    // Register a UDF for DataFrame, and then
     // create a new column named map(predictionCol) by running the predict UDF.
-    val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
-      if (userFeatures != null && itemFeatures != null) {
-        blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1)
-      } else {
-        Float.NaN
-      }
-    }
     val predictions = dataset
       .join(userFactors, checkedCast(dataset($(userCol))) === userFactors("id"), "left")
@@ -328,13 +328,41 @@ class ALSModel private[ml] (
 
   @Since("1.6.0")
   override def write: MLWriter = new ALSModel.ALSModelWriter(this)
 
-  // TODO: output is DataFrame ?? DataSet ?? what exactly is the output schema?
-  def recommendForAllUsers(): DataFrame = {
-
+  @Since("2.2.0")
+  def recommendForAllUsers(num: Int): DataFrame = {
+    recommendForAll(userFactors, itemFactors, $(userCol), num)
   }
 
-  def recommendForAllItems(): DataFrame = {
+  @Since("2.2.0")
+  def recommendForAllItems(num: Int): DataFrame = {
+    recommendForAll(itemFactors, userFactors, $(itemCol), num)
+  }
 
+  /**
+   * Makes recommendations for all users (or items).
+   * @param srcFactors src factors for which to generate recommendations
+   * @param dstFactors dst factors used to make recommendations
+   * @param srcOutputColumn name of the column for the source in the output DataFrame
+   * @param num number of recommendations for each record
+   * @return a DataFrame of (srcOutputColumn: Int, recommendations), where recommendations are
+   *         stored as an array of (dstId: Int, ratingL: Double) tuples.
+   */
+  private def recommendForAll(
+      srcFactors: DataFrame,
+      dstFactors: DataFrame,
+      srcOutputColumn: String,
+      num: Int): DataFrame = {
+    import srcFactors.sparkSession.implicits._
+
+    val ratings = srcFactors.crossJoin(dstFactors)
+      .select(
+        srcFactors("id").as("srcId"),
+        dstFactors("id").as("dstId"),
+        predict(srcFactors("features"), dstFactors("features")).as($(predictionCol)))
+    // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
+    val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
+    ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
+      .toDF(srcOutputColumn, "recommendations")
   }
 }

diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
new file mode 100644
index 0000000000000..b3602d557a9f5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+import scala.language.implicitConversions
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.{Encoder, Encoders}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.util.BoundedPriorityQueue
+
+/**
+ * TODO: Some comments go here about the class
+ * TODO: should probably move it to somewhere else
+ */
+
+private[recommendation] class TopByKeyAggregator[K1: TypeTag, K2: TypeTag, V: TypeTag]
+  (num: Int, ord: Ordering[(K2, V)])
+  extends Aggregator[(K1, K2, V), BoundedPriorityQueue[(K2, V)], Array[(K2, V)]] {
+
+  override def zero: BoundedPriorityQueue[(K2, V)] = new BoundedPriorityQueue[(K2, V)](num)(ord)
+  override def reduce(
+      q: BoundedPriorityQueue[(K2, V)],
+      a: (K1, K2, V)): BoundedPriorityQueue[(K2, V)] = q += {(a._2, a._3)}
+  override def merge(
+      q1: BoundedPriorityQueue[(K2, V)],
+      q2: BoundedPriorityQueue[(K2, V)]): BoundedPriorityQueue[(K2, V)] = q1 ++= q2
+  override def finish(r: BoundedPriorityQueue[(K2, V)]): Array[(K2, V)] =
+    r.toArray.sorted(ord.reverse)
+  override def bufferEncoder: Encoder[BoundedPriorityQueue[(K2, V)]] =
+    Encoders.kryo[BoundedPriorityQueue[(K2, V)]]
+  override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]
+}

diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index c8228dd004374..47da317a504a1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -22,6 +22,7 @@ import java.util.Random
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.WrappedArray
 import scala.collection.JavaConverters._
 import scala.language.existentials
@@ -367,7 +368,7 @@ class ALSSuite
       implicitPrefs: Boolean = false,
       numUserBlocks: Int = 2,
       numItemBlocks: Int = 3,
-      targetRMSE: Double = 0.05): Unit = {
+      targetRMSE: Double = 0.05): ALSModel = {
     val spark = this.spark
     import spark.implicits._
     val als = new ALS()
@@ -410,6 +411,8 @@ class ALSSuite
 
     // copied model must have the same parent.
     MLTestingUtils.checkCopy(model)
+
+    model
   }
 
   test("exact rank-1 matrix") {
@@ -659,6 +662,44 @@ class ALSSuite
     Seq("nan", "NaN", "Nan", "drop", "DROP", "Drop").foreach { s =>
       model.setColdStartStrategy(s).transform(data)
     }
+
+  test("recommendForAllUsers") {
+    val numUsers = 20
+    val numItems = 40
+    val numRecs = 5
+    val (training, test) = genExplicitTestData(numUsers, numItems, rank = 2, noiseStd = 0.01)
+    val topItems =
+      testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03)
+        .recommendForAllUsers(numRecs)
+
+    assert(topItems.count() == numUsers)
+    assert(topItems.columns.contains("user"))
+    checkRecommendationOrdering(topItems, numRecs)
+  }
+
+  test("recommendForAllItems") {
+    val numUsers = 20
+    val numItems = 40
+    val numRecs = 5
+    val (training, test) = genExplicitTestData(numUsers, numItems, rank = 2, noiseStd = 0.01)
+    val topUsers =
+      testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03)
+        .recommendForAllItems(numRecs)
+
+    assert(topUsers.count() == numItems)
+    assert(topUsers.columns.contains("item"))
+    checkRecommendationOrdering(topUsers, numRecs)
+  }
+
+  private def checkRecommendationOrdering(topK: DataFrame, k: Int): Unit = {
+    assert(topK.columns.contains("recommendations"))
+    topK.select("recommendations").collect().foreach(
+      row => {
+        val recs = row.getAs[WrappedArray[Row]]("recommendations")
+        assert(recs.length == k)
+        assert(recs.sorted(Ordering.by((x: Row) => x(1).asInstanceOf[Float]).reverse) == recs)
+      }
+    )
   }
 }
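The core of the patch above is the crossJoin → score → top-k-per-key pipeline in recommendForAll. A minimal self-contained sketch of the same flow on toy data — the factors are hypothetical, and the snippet assumes a SparkSession `spark` plus placement inside the org.apache.spark.ml.recommendation package so the private[recommendation] aggregator is visible:

    import org.apache.spark.sql.functions.udf
    import spark.implicits._

    // Toy rank-2 factors: 2 users, 3 items.
    val userFactors = Seq((0, Seq(1.0f, 2.0f)), (1, Seq(2.0f, 0.0f))).toDF("id", "features")
    val itemFactors = Seq((10, Seq(1.0f, 0.0f)), (11, Seq(0.0f, 1.0f)), (12, Seq(1.0f, 1.0f)))
      .toDF("id", "features")

    // Plain dot product standing in for blas.sdot.
    val score = udf { (a: Seq[Float], b: Seq[Float]) =>
      a.zip(b).map { case (x, y) => x * y }.sum
    }

    val ratings = userFactors.crossJoin(itemFactors)
      .select(
        userFactors("id").as("srcId"),
        itemFactors("id").as("dstId"),
        score(userFactors("features"), itemFactors("features")).as("rating"))

    // Keep the top 2 items per user, exactly as recommendForAll does.
    val top2 = new TopByKeyAggregator[Int, Int, Float](2, Ordering.by(_._2))
    ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(top2.toColumn).show(false)
    // User 0 scores items 10/11/12 as 1.0/2.0/3.0, so its row is [(12, 3.0), (11, 2.0)].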
From 2de5b451f0666bdb5f2e5e09cd814e1b68f76873 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Mon, 27 Feb 2017 16:25:29 -0800
Subject: [PATCH 03/10] comments

---
 .../apache/spark/ml/recommendation/TopByKeyAggregator.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
index b3602d557a9f5..fceb19a95e324 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
@@ -26,8 +26,8 @@ import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.util.BoundedPriorityQueue
 
 /**
- * TODO: Some comments go here about the class
- * TODO: should probably move it to somewhere else
+ * Works on rows of the form (K1, K2, V) where K1 & K2 are IDs and V is the score value. Finds
+ * the top `num` K2 items based on the given Ordering.
  */
 
 private[recommendation] class TopByKeyAggregator[K1: TypeTag, K2: TypeTag, V: TypeTag]
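The new comment is easiest to verify on concrete values: the buffer is a bounded priority queue that evicts its smallest element once `num` entries are held. A small sketch of the reduce/finish cycle (assumes access to Spark's private BoundedPriorityQueue; here num = 2 and the ordering is by score):

    import org.apache.spark.util.BoundedPriorityQueue

    val ord = Ordering.by[(Int, Float), Float](_._2)
    val q = new BoundedPriorityQueue[(Int, Float)](2)(ord)  // zero: capacity-2 buffer

    q += ((10, 1.0f))  // reduce: queue = {(10, 1.0)}
    q += ((11, 2.0f))  // queue = {(10, 1.0), (11, 2.0)}
    q += ((12, 3.0f))  // full, smallest evicted: queue = {(11, 2.0), (12, 3.0)}

    // finish: descending sort for the output array
    val top = q.toArray.sorted(ord.reverse)  // Array((12, 3.0), (11, 2.0))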
From 6de4c38b42d466633569962a27d5dcd08baa7a3e Mon Sep 17 00:00:00 2001
From: Your Name
Date: Mon, 27 Feb 2017 16:59:06 -0800
Subject: [PATCH 04/10] scaladoc, formatting

---
 .../apache/spark/ml/recommendation/ALS.scala   | 15 +++++++++++-
 .../recommendation/TopByKeyAggregator.scala    | 23 ++++++++++++++-----
 2 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 9dba9cfc4cf31..8f82c837c51a8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -286,6 +286,7 @@ class ALSModel private[ml] (
 
   private val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
     if (userFeatures != null && itemFeatures != null) {
+      // TODO: try dot-producting on Seqs or another non-converted type for potential optimization
       blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1)
     } else {
       Float.NaN
@@ -328,11 +329,23 @@ class ALSModel private[ml] (
   @Since("1.6.0")
   override def write: MLWriter = new ALSModel.ALSModelWriter(this)
 
+  /**
+   * Returns top `num` items recommended for each user, for all users.
+   * @param num number of recommendations for each user
+   * @return a DataFrame of (userCol: Int, recommendations), where recommendations are
+   *         stored as an array of (itemId: Int, rating: Double) tuples.
+   */
   @Since("2.2.0")
   def recommendForAllUsers(num: Int): DataFrame = {
     recommendForAll(userFactors, itemFactors, $(userCol), num)
   }
 
+  /**
+   * Returns top `num` users recommended for each item, for all items.
+   * @param num number of recommendations for each item
+   * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are
+   *         stored as an array of (userId: Int, rating: Double) tuples.
+   */
   @Since("2.2.0")
   def recommendForAllItems(num: Int): DataFrame = {
     recommendForAll(itemFactors, userFactors, $(itemCol), num)
@@ -345,7 +358,7 @@ class ALSModel private[ml] (
    * @param srcOutputColumn name of the column for the source in the output DataFrame
    * @param num number of recommendations for each record
    * @return a DataFrame of (srcOutputColumn: Int, recommendations), where recommendations are
-   *         stored as an array of (dstId: Int, ratingL: Double) tuples.
+   *         stored as an array of (dstId: Int, rating: Double) tuples.
    */
   private def recommendForAll(
       srcFactors: DataFrame,

diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
index fceb19a95e324..b249af9705184 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
@@ -25,25 +25,36 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.util.BoundedPriorityQueue
 
+
 /**
  * Works on rows of the form (K1, K2, V) where K1 & K2 are IDs and V is the score value. Finds
  * the top `num` K2 items based on the given Ordering.
  */
-
 private[recommendation] class TopByKeyAggregator[K1: TypeTag, K2: TypeTag, V: TypeTag]
   (num: Int, ord: Ordering[(K2, V)])
   extends Aggregator[(K1, K2, V), BoundedPriorityQueue[(K2, V)], Array[(K2, V)]] {
 
   override def zero: BoundedPriorityQueue[(K2, V)] = new BoundedPriorityQueue[(K2, V)](num)(ord)
+
   override def reduce(
-      q: BoundedPriorityQueue[(K2, V)],
-      a: (K1, K2, V)): BoundedPriorityQueue[(K2, V)] = q += {(a._2, a._3)}
+      q: BoundedPriorityQueue[(K2, V)],
+      a: (K1, K2, V)): BoundedPriorityQueue[(K2, V)] = {
+    q += {(a._2, a._3)}
+  }
+
   override def merge(
       q1: BoundedPriorityQueue[(K2, V)],
-      q2: BoundedPriorityQueue[(K2, V)]): BoundedPriorityQueue[(K2, V)] = q1 ++= q2
-  override def finish(r: BoundedPriorityQueue[(K2, V)]): Array[(K2, V)] =
+      q2: BoundedPriorityQueue[(K2, V)]): BoundedPriorityQueue[(K2, V)] = {
+    q1 ++= q2
+  }
+
+  override def finish(r: BoundedPriorityQueue[(K2, V)]): Array[(K2, V)] = {
     r.toArray.sorted(ord.reverse)
-  override def bufferEncoder: Encoder[BoundedPriorityQueue[(K2, V)]] =
+  }
+
+  override def bufferEncoder: Encoder[BoundedPriorityQueue[(K2, V)]] = {
     Encoders.kryo[BoundedPriorityQueue[(K2, V)]]
+  }
+
   override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]
 }
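A design note on the two encoders kept by this class: the aggregation buffer is a mutable JVM object with no Catalyst schema, so it can only travel as an opaque Kryo blob, while the finished array of tuples has a proper structured encoding. A hedged illustration with concrete type parameters:

    import org.apache.spark.sql.{Encoder, Encoders}
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.util.BoundedPriorityQueue

    // Opaque binary: fine for an intermediate buffer that only lives inside the aggregation.
    val bufEnc: Encoder[BoundedPriorityQueue[(Int, Float)]] =
      Encoders.kryo[BoundedPriorityQueue[(Int, Float)]]

    // Structured: the final Array[(Int, Float)] surfaces as a Catalyst array of structs,
    // which is what downstream DataFrame code sees.
    val outEnc: Encoder[Array[(Int, Float)]] = ExpressionEncoder[Array[(Int, Float)]]()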
From 1b35f0a6dc906b3723be68e57b40f339cb4b3af7 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Mon, 27 Feb 2017 18:48:30 -0800
Subject: [PATCH 05/10] clean-ups, comments.

---
 .../apache/spark/ml/recommendation/ALS.scala   | 32 +++++++++++--------
 .../recommendation/TopByKeyAggregator.scala    |  2 +-
 2 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 8f82c837c51a8..d0e4b2366441f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -40,7 +40,8 @@ import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.CholeskyDecomposition
 import org.apache.spark.mllib.optimization.NNLS
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
@@ -286,7 +287,8 @@ class ALSModel private[ml] (
 
   private val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
     if (userFeatures != null && itemFeatures != null) {
-      // TODO: try dot-producting on Seqs or another non-converted type for potential optimization
+      // TODO(SPARK-19759): try dot-producting on Seqs or another non-converted type for
+      // potential optimization.
       blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1)
     } else {
       Float.NaN
@@ -332,25 +334,25 @@ class ALSModel private[ml] (
   override def write: MLWriter = new ALSModel.ALSModelWriter(this)
 
   /**
-   * Returns top `num` items recommended for each user, for all users.
-   * @param num number of recommendations for each user
+   * Returns top `numItems` items recommended for each user, for all users.
+   * @param numItems max number of recommendations for each user
    * @return a DataFrame of (userCol: Int, recommendations), where recommendations are
    *         stored as an array of (itemId: Int, rating: Double) tuples.
    */
   @Since("2.2.0")
-  def recommendForAllUsers(num: Int): DataFrame = {
-    recommendForAll(userFactors, itemFactors, $(userCol), num)
+  def recommendForAllUsers(numItems: Int): DataFrame = {
+    recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems)
   }
 
   /**
-   * Returns top `num` users recommended for each item, for all items.
-   * @param num number of recommendations for each item
+   * Returns top `numUsers` users recommended for each item, for all items.
+   * @param numUsers max number of recommendations for each item
    * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are
    *         stored as an array of (userId: Int, rating: Double) tuples.
    */
   @Since("2.2.0")
-  def recommendForAllItems(num: Int): DataFrame = {
-    recommendForAll(itemFactors, userFactors, $(itemCol), num)
+  def recommendForAllItems(numUsers: Int): DataFrame = {
+    recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers)
   }
 
   /**
    * Makes recommendations for all users (or items).
    * @param srcFactors src factors for which to generate recommendations
    * @param dstFactors dst factors used to make recommendations
    * @param srcOutputColumn name of the column for the source in the output DataFrame
-   * @param num number of recommendations for each record
+   * @param dstOutputColumn name of the column for the destination in the output DataFrame
+   * @param num max number of recommendations for each record
    * @return a DataFrame of (srcOutputColumn: Int, recommendations), where recommendations are
    *         stored as an array of (dstId: Int, rating: Double) tuples.
    */
   private def recommendForAll(
       srcFactors: DataFrame,
       dstFactors: DataFrame,
       srcOutputColumn: String,
+      dstOutputColumn: String,
       num: Int): DataFrame = {
     import srcFactors.sparkSession.implicits._
 
     val ratings = srcFactors.crossJoin(dstFactors)
       .select(
-        srcFactors("id").as("srcId"),
-        dstFactors("id").as("dstId"),
-        predict(srcFactors("features"), dstFactors("features")).as($(predictionCol)))
+        srcFactors("id"),
+        dstFactors("id"),
+        predict(srcFactors("features"), dstFactors("features")))
     // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
     val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
     ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
       .toDF(srcOutputColumn, "recommendations")
   }
 }

diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
index b249af9705184..517179c0eb9ae 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
@@ -56,5 +56,5 @@ private[recommendation] class TopByKeyAggregator[K1: TypeTag, K2: TypeTag, V: Ty
     Encoders.kryo[BoundedPriorityQueue[(K2, V)]]
   }
 
-  override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]
+  override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]()
 }
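Dropping the srcId/dstId aliases is safe because the subsequent `as[(Int, Int, Float)]` resolves tuple fields by position, not by name, so the duplicate `id` columns from the cross join never need disambiguating. A small hedged demonstration (assumes a SparkSession `spark`):

    import spark.implicits._

    val a = Seq((1, "x")).toDF("id", "features")
    val b = Seq((2, "y")).toDF("id", "features")

    // Both output columns are literally named "id"...
    val joined = a.crossJoin(b).select(a("id"), b("id"))

    // ...but the typed view is positional: _1 is a's id, _2 is b's id.
    joined.as[(Int, Int)].collect()  // Array((1, 2))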
From 08a58e45fcb46c345bd734b8b289e218f71c5ce6 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Tue, 28 Feb 2017 16:21:24 -0800
Subject: [PATCH 06/10] added tests for TopByKeyAggregator, more precise tests
 for recommendAll fns

---
 .../apache/spark/ml/recommendation/ALS.scala  |  18 ++--
 .../spark/ml/recommendation/ALSSuite.scala    | 102 +++++++++++++-----
 .../TopByKeyAggregatorSuite.scala             |  81 ++++++++++++++
 3 files changed, 163 insertions(+), 38 deletions(-)
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala

diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index d0e4b2366441f..658486f37c5d8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -285,11 +285,11 @@ class ALSModel private[ml] (
   @Since("2.2.0")
   def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)
 
-  private val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
-    if (userFeatures != null && itemFeatures != null) {
+  private val predict = udf { (featuresA: Seq[Float], featuresB: Seq[Float]) =>
+    if (featuresA != null && featuresB != null) {
       // TODO(SPARK-19759): try dot-producting on Seqs or another non-converted type for
       // potential optimization.
-      blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1)
+      blas.sdot(rank, featuresA.toArray, 1, featuresB.toArray, 1)
     } else {
       Float.NaN
     }
@@ -335,22 +335,20 @@ class ALSModel private[ml] (
    * Returns top `numItems` items recommended for each user, for all users.
    * @param numItems max number of recommendations for each user
    * @return a DataFrame of (userCol: Int, recommendations), where recommendations are
-   *         stored as an array of (itemId: Int, rating: Double) tuples.
+   *         stored as an array of (itemId: Int, rating: Float) tuples.
    */
   @Since("2.2.0")
   def recommendForAllUsers(numItems: Int): DataFrame = {
-    recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems)
+    recommendForAll(userFactors, itemFactors, $(userCol), numItems)
   }
 
   /**
    * Returns top `numUsers` users recommended for each item, for all items.
    * @param numUsers max number of recommendations for each item
    * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are
-   *         stored as an array of (userId: Int, rating: Double) tuples.
+   *         stored as an array of (userId: Int, rating: Float) tuples.
    */
   @Since("2.2.0")
   def recommendForAllItems(numUsers: Int): DataFrame = {
-    recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers)
+    recommendForAll(itemFactors, userFactors, $(itemCol), numUsers)
   }
 
   /**
    * Makes recommendations for all users (or items).
    * @param srcFactors src factors for which to generate recommendations
    * @param dstFactors dst factors used to make recommendations
    * @param srcOutputColumn name of the column for the source in the output DataFrame
-   * @param dstOutputColumn name of the column for the destination in the output DataFrame
    * @param num max number of recommendations for each record
    * @return a DataFrame of (srcOutputColumn: Int, recommendations), where recommendations are
-   *         stored as an array of (dstId: Int, rating: Double) tuples.
+   *         stored as an array of (dstId: Int, rating: Float) tuples.
    */
   private def recommendForAll(
       srcFactors: DataFrame,
       dstFactors: DataFrame,
       srcOutputColumn: String,
-      dstOutputColumn: String,
       num: Int): DataFrame = {
     import srcFactors.sparkSession.implicits._

diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 47da317a504a1..0346c6bf9cbd7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -662,44 +662,90 @@ class ALSSuite
     Seq("nan", "NaN", "Nan", "drop", "DROP", "Drop").foreach { s =>
       model.setColdStartStrategy(s).transform(data)
     }
+  }
 
-  test("recommendForAllUsers") {
-    val numUsers = 20
-    val numItems = 40
-    val numRecs = 5
-    val (training, test) = genExplicitTestData(numUsers, numItems, rank = 2, noiseStd = 0.01)
-    val topItems =
-      testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03)
-        .recommendForAllUsers(numRecs)
-
-    assert(topItems.count() == numUsers)
-    assert(topItems.columns.contains("user"))
-    checkRecommendationOrdering(topItems, numRecs)
-  }
-
-  test("recommendForAllItems") {
-    val numUsers = 20
-    val numItems = 40
-    val numRecs = 5
-    val (training, test) = genExplicitTestData(numUsers, numItems, rank = 2, noiseStd = 0.01)
-    val topUsers =
-      testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03)
-        .recommendForAllItems(numRecs)
-
-    assert(topUsers.count() == numItems)
-    assert(topUsers.columns.contains("item"))
-    checkRecommendationOrdering(topUsers, numRecs)
-  }
-
-  private def checkRecommendationOrdering(topK: DataFrame, k: Int): Unit = {
-    assert(topK.columns.contains("recommendations"))
-    topK.select("recommendations").collect().foreach(
-      row => {
-        val recs = row.getAs[WrappedArray[Row]]("recommendations")
-        assert(recs.length == k)
-        assert(recs.sorted(Ordering.by((x: Row) => x(1).asInstanceOf[Float]).reverse) == recs)
-      }
-    )
+
+  private def getALSModel = {
+    val spark = this.spark
+    import spark.implicits._
+
+    val userFactors = Seq(
+      (0, Array(6.0f, 4.0f)),
+      (1, Array(3.0f, 4.0f)),
+      (2, Array(3.0f, 6.0f))
+    ).toDF("id", "features")
+    val itemFactors = Seq(
+      (3, Array(5.0f, 6.0f)),
+      (4, Array(6.0f, 2.0f)),
+      (5, Array(3.0f, 6.0f)),
+      (6, Array(4.0f, 1.0f))
+    ).toDF("id", "features")
+    val als = new ALS().setRank(2)
+    new ALSModel(als.uid, als.getRank, userFactors, itemFactors)
+      .setUserCol("user")
+      .setItemCol("item")
+  }
+
+  test("recommendForAllUsers with k < num_items") {
+    val topItems = getALSModel.recommendForAllUsers(2)
+    assert(topItems.count() == 3)
+    assert(topItems.columns.contains("user"))
+
+    val expected = Map(
+      0 -> Array(Row(3, 54f), Row(4, 44f)),
+      1 -> Array(Row(3, 39f), Row(5, 33f)),
+      2 -> Array(Row(3, 51f), Row(5, 45f))
+    )
+    checkRecommendations(topItems, expected)
+  }
+
+  test("recommendForAllUsers with k = num_items") {
+    val topItems = getALSModel.recommendForAllUsers(4)
+    assert(topItems.count() == 3)
+    assert(topItems.columns.contains("user"))
+
+    val expected = Map(
+      0 -> Array(Row(3, 54f), Row(4, 44f), Row(5, 42f), Row(6, 28f)),
+      1 -> Array(Row(3, 39f), Row(5, 33f), Row(4, 26f), Row(6, 16f)),
+      2 -> Array(Row(3, 51f), Row(5, 45f), Row(4, 30f), Row(6, 18f))
+    )
+    checkRecommendations(topItems, expected)
+  }
+
+  test("recommendForAllItems with k < num_users") {
+    val topUsers = getALSModel.recommendForAllItems(2)
+    assert(topUsers.count() == 4)
+    assert(topUsers.columns.contains("item"))
+
+    val expected = Map(
+      3 -> Array(Row(0, 54f), Row(2, 51f)),
+      4 -> Array(Row(0, 44f), Row(2, 30f)),
+      5 -> Array(Row(2, 45f), Row(0, 42f)),
+      6 -> Array(Row(0, 28f), Row(2, 18f))
+    )
+    checkRecommendations(topUsers, expected)
+  }
+
+  test("recommendForAllItems with k = num_users") {
+    val topUsers = getALSModel.recommendForAllItems(3)
+    assert(topUsers.count() == 4)
+    assert(topUsers.columns.contains("item"))
+
+    val expected = Map(
+      3 -> Array(Row(0, 54f), Row(2, 51f), Row(1, 39f)),
+      4 -> Array(Row(0, 44f), Row(2, 30f), Row(1, 26f)),
+      5 -> Array(Row(2, 45f), Row(0, 42f), Row(1, 33f)),
+      6 -> Array(Row(0, 28f), Row(2, 18f), Row(1, 16f))
+    )
+    checkRecommendations(topUsers, expected)
+  }
+
+  private def checkRecommendations(topK: DataFrame, expected: Map[Int, Array[Row]]): Unit = {
+    assert(topK.columns.contains("recommendations"))
+    topK.collect().foreach { row =>
+      val id = row.getInt(0)
+      val recs = row.getAs[WrappedArray[Row]]("recommendations")
+      assert(recs === expected(id))
+    }
   }
 }

diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala
new file mode 100644
index 0000000000000..2f210f4b1e368
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Dataset
+
+
+class TopByKeyAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  private def getTopK(k: Int): Dataset[(Int, Array[(Int, Float)])] = {
+    val sqlContext = spark.sqlContext
+    import sqlContext.implicits._
+
+    val topKAggregator = new TopByKeyAggregator[Int, Int, Float](k, Ordering.by(_._2))
+    Seq(
+      (0, 3, 54f),
+      (0, 4, 44f),
+      (0, 5, 42f),
+      (0, 6, 28f),
+      (1, 3, 39f),
+      (1, 4, 26f),
+      (1, 5, 33f),
+      (1, 6, 16f),
+      (2, 3, 51f),
+      (2, 4, 30f),
+      (2, 5, 45f),
+      (2, 6, 18f)
+    ).toDS().groupByKey(_._1).agg(topKAggregator.toColumn)
+  }
+
+  test("topByKey with k < #items") {
+    val topK = getTopK(2)
+    assert(topK.count() === 3)
+
+    val expected = Map(
+      0 -> Array((3, 54f), (4, 44f)),
+      1 -> Array((3, 39f), (5, 33f)),
+      2 -> Array((3, 51f), (5, 45f))
+    )
+    checkTopK(topK, expected)
+  }
+
+  test("topByKey with k > #items") {
+    val topK = getTopK(5)
+    assert(topK.count() === 3)
+
+    val expected = Map(
+      0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)),
+      1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)),
+      2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f))
+    )
+    checkTopK(topK, expected)
+  }
+
+  private def checkTopK(
+      topK: Dataset[(Int, Array[(Int, Float)])],
+      expected: Map[Int, Array[(Int, Float)]]): Unit = {
+    topK.collect().foreach { record =>
+      val id = record._1
+      val recs = record._2
+      assert(recs === expected(id))
+    }
+  }
+}
From b139c265cff37b2b6659cd5736a9d87b3eaa6128 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Tue, 28 Feb 2017 16:23:40 -0800
Subject: [PATCH 07/10] cleanup

---
 .../scala/org/apache/spark/ml/recommendation/ALSSuite.scala | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 0346c6bf9cbd7..5eeb11816eb04 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -368,7 +368,7 @@ class ALSSuite
       implicitPrefs: Boolean = false,
       numUserBlocks: Int = 2,
       numItemBlocks: Int = 3,
-      targetRMSE: Double = 0.05): ALSModel = {
+      targetRMSE: Double = 0.05): Unit = {
     val spark = this.spark
     import spark.implicits._
     val als = new ALS()
@@ -411,8 +411,6 @@ class ALSSuite
 
     // copied model must have the same parent.
     MLTestingUtils.checkCopy(model)
-
-    model
   }
 
   test("exact rank-1 matrix") {

From c1973e6f3c5c8ba61eb2317c79322f02aa470356 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Tue, 28 Feb 2017 16:45:48 -0800
Subject: [PATCH 08/10] name the output columns in the recommendations Array

---
 .../apache/spark/ml/recommendation/ALS.scala  | 29 ++++++++++++++-----
 .../spark/ml/recommendation/ALSSuite.scala    | 15 ++++++----
 2 files changed, 32 insertions(+), 12 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 658486f37c5d8..4a5743650c547 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -335,37 +335,39 @@ class ALSModel private[ml] (
    * Returns top `numItems` items recommended for each user, for all users.
    * @param numItems max number of recommendations for each user
    * @return a DataFrame of (userCol: Int, recommendations), where recommendations are
-   *         stored as an array of (itemId: Int, rating: Float) tuples.
+   *         stored as an array of (itemCol: Int, rating: Float) Rows.
    */
   @Since("2.2.0")
   def recommendForAllUsers(numItems: Int): DataFrame = {
-    recommendForAll(userFactors, itemFactors, $(userCol), numItems)
+    recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems)
   }
 
   /**
    * Returns top `numUsers` users recommended for each item, for all items.
    * @param numUsers max number of recommendations for each item
    * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are
-   *         stored as an array of (userId: Int, rating: Float) tuples.
+   *         stored as an array of (userCol: Int, rating: Float) Rows.
   */
   @Since("2.2.0")
   def recommendForAllItems(numUsers: Int): DataFrame = {
-    recommendForAll(itemFactors, userFactors, $(itemCol), numUsers)
+    recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers)
   }
 
   /**
    * Makes recommendations for all users (or items).
    * @param srcFactors src factors for which to generate recommendations
    * @param dstFactors dst factors used to make recommendations
-   * @param srcOutputColumn name of the column for the source in the output DataFrame
+   * @param srcOutputColumn name of the column for the source ID in the output DataFrame
+   * @param dstOutputColumn name of the column for the destination ID in the output DataFrame
    * @param num max number of recommendations for each record
    * @return a DataFrame of (srcOutputColumn: Int, recommendations), where recommendations are
-   *         stored as an array of (dstId: Int, rating: Float) tuples.
+   *         stored as an array of (dstOutputColumn: Int, rating: Float) Rows.
   */
   private def recommendForAll(
       srcFactors: DataFrame,
       dstFactors: DataFrame,
       srcOutputColumn: String,
+      dstOutputColumn: String,
       num: Int): DataFrame = {
     import srcFactors.sparkSession.implicits._
@@ -376,8 +378,21 @@ class ALSModel private[ml] (
         predict(srcFactors("features"), dstFactors("features")))
     // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
     val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
-    ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
+    val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
       .toDF(srcOutputColumn, "recommendations")
+
+    // There is some performance hit from converting the (Int, Float) tuples to
+    // (dstOutputColumn: Int, rating: Float) structs using .rdd. Need SPARK-16483 for a fix.
+    val schema = new StructType()
+      .add(srcOutputColumn, IntegerType)
+      .add("recommendations",
+        ArrayType(
+          StructType(
+            StructField(dstOutputColumn, IntegerType, nullable = false) ::
+            StructField("rating", FloatType, nullable = false) ::
+            Nil
+          )))
+    recs.sparkSession.createDataFrame(recs.rdd, schema)
   }
 }

diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 5eeb11816eb04..339c5f22155c7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -693,7 +693,7 @@ class ALSSuite
       1 -> Array(Row(3, 39f), Row(5, 33f)),
       2 -> Array(Row(3, 51f), Row(5, 45f))
     )
-    checkRecommendations(topItems, expected)
+    checkRecommendations(topItems, expected, "item")
   }
@@ -706,7 +706,7 @@ class ALSSuite
       1 -> Array(Row(3, 39f), Row(5, 33f), Row(4, 26f), Row(6, 16f)),
       2 -> Array(Row(3, 51f), Row(5, 45f), Row(4, 30f), Row(6, 18f))
     )
-    checkRecommendations(topItems, expected)
+    checkRecommendations(topItems, expected, "item")
   }
@@ -720,7 +720,7 @@ class ALSSuite
       5 -> Array(Row(2, 45f), Row(0, 42f)),
       6 -> Array(Row(0, 28f), Row(2, 18f))
     )
-    checkRecommendations(topUsers, expected)
+    checkRecommendations(topUsers, expected, "user")
   }
@@ -734,15 +734,20 @@ class ALSSuite
       5 -> Array(Row(2, 45f), Row(0, 42f), Row(1, 33f)),
       6 -> Array(Row(0, 28f), Row(2, 18f), Row(1, 16f))
     )
-    checkRecommendations(topUsers, expected)
+    checkRecommendations(topUsers, expected, "user")
   }
 
-  private def checkRecommendations(topK: DataFrame, expected: Map[Int, Array[Row]]): Unit = {
+  private def checkRecommendations(
+      topK: DataFrame,
+      expected: Map[Int, Array[Row]],
+      dstColName: String): Unit = {
     assert(topK.columns.contains("recommendations"))
     topK.collect().foreach { row =>
       val id = row.getInt(0)
       val recs = row.getAs[WrappedArray[Row]]("recommendations")
       assert(recs === expected(id))
+      assert(recs(0).fieldIndex(dstColName) == 0)
+      assert(recs(0).fieldIndex("rating") == 1)
     }
   }
 }
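With the struct fields named after the actual user/item columns, consumers can address recommendations by name instead of position. A hedged sketch of how a caller might flatten the output (assumes `userRecs` came from recommendForAllUsers on a model whose columns are named "user" and "item"):

    import org.apache.spark.sql.functions.{col, explode}

    // One row per (user, recommendation) pair...
    val flat = userRecs.select(col("user"), explode(col("recommendations")).as("rec"))

    // ...then pull out the named struct fields; "item" and "rating" exist because the
    // schema built above is array<struct<item: int, rating: float>>.
    flat.select(col("user"), col("rec.item"), col("rec.rating")).show()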
Row(4, 30f), Row(6, 18f)) + 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)), + 1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)), + 2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f)) ) checkRecommendations(topItems, expected, "item") } @@ -715,10 +715,10 @@ class ALSSuite assert(topUsers.columns.contains("item")) val expected = Map( - 3 -> Array(Row(0, 54f), Row(2, 51f)), - 4 -> Array(Row(0, 44f), Row(2, 30f)), - 5 -> Array(Row(2, 45f), Row(0, 42f)), - 6 -> Array(Row(0, 28f), Row(2, 18f)) + 3 -> Array((0, 54f), (2, 51f)), + 4 -> Array((0, 44f), (2, 30f)), + 5 -> Array((2, 45f), (0, 42f)), + 6 -> Array((0, 28f), (2, 18f)) ) checkRecommendations(topUsers, expected, "user") } @@ -729,23 +729,27 @@ class ALSSuite assert(topUsers.columns.contains("item")) val expected = Map( - 3 -> Array(Row(0, 54f), Row(2, 51f), Row(1, 39f)), - 4 -> Array(Row(0, 44f), Row(2, 30f), Row(1, 26f)), - 5 -> Array(Row(2, 45f), Row(0, 42f), Row(1, 33f)), - 6 -> Array(Row(0, 28f), Row(2, 18f), Row(1, 16f)) + 3 -> Array((0, 54f), (2, 51f), (1, 39f)), + 4 -> Array((0, 44f), (2, 30f), (1, 26f)), + 5 -> Array((2, 45f), (0, 42f), (1, 33f)), + 6 -> Array((0, 28f), (2, 18f), (1, 16f)) ) checkRecommendations(topUsers, expected, "user") } private def checkRecommendations( topK: DataFrame, - expected: Map[Int, Array[Row]], + expected: Map[Int, Array[(Int, Float)]], dstColName: String): Unit = { + val spark = this.spark + import spark.implicits._ + assert(topK.columns.contains("recommendations")) + topK.as[(Int, Seq[(Int, Float)])].collect().foreach { case (id: Int, recs: Seq[(Int, Float)]) => + assert(recs === expected(id)) + } topK.collect().foreach { row => - val id = row.getInt(0) val recs = row.getAs[WrappedArray[Row]]("recommendations") - assert(recs === expected(id)) assert(recs(0).fieldIndex(dstColName) == 0) assert(recs(0).fieldIndex("rating") == 1) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala index 2f210f4b1e368..5e763a8e908b8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala @@ -35,11 +35,7 @@ class TopByKeyAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { (0, 5, 42f), (0, 6, 28f), (1, 3, 39f), - (1, 4, 26f), - (1, 5, 33f), - (1, 6, 16f), (2, 3, 51f), - (2, 4, 30f), (2, 5, 45f), (2, 6, 18f) ).toDS().groupByKey(_._1).agg(topKAggregator.toColumn) @@ -51,7 +47,7 @@ class TopByKeyAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { val expected = Map( 0 -> Array((3, 54f), (4, 44f)), - 1 -> Array((3, 39f), (5, 33f)), + 1 -> Array((3, 39f)), 2 -> Array((3, 51f), (5, 45f)) ) checkTopK(topK, expected) @@ -63,8 +59,8 @@ class TopByKeyAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { val expected = Map( 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)), - 1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)), - 2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f)) + 1 -> Array((3, 39f)), + 2 -> Array((3, 51f), (5, 45f), (6, 18f)) ) checkTopK(topK, expected) } @@ -72,10 +68,6 @@ class TopByKeyAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { private def checkTopK( topK: Dataset[(Int, Array[(Int, Float)])], expected: Map[Int, Array[(Int, Float)]]): Unit = { - topK.collect().foreach { record => - val id = record._1 - val recs = record._2 - assert(recs === expected(id)) - } + 
topK.collect().foreach { case (id, recs) => assert(recs === expected(id)) } } } From 6a7e3d138b33c66644cdf68b6b20287ab0705aa6 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 3 Mar 2017 14:32:59 -0800 Subject: [PATCH 10/10] no longer needing to cause serialization costs --- .../apache/spark/ml/recommendation/ALS.scala | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 4a5743650c547..60dd7367053e2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -379,20 +379,14 @@ class ALSModel private[ml] ( // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output. val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2)) val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn) - .toDF(srcOutputColumn, "recommendations") - - // There is some performance hit from converting the (Int, Float) tuples to - // (dstOutputColumn: Int, rating: Float) structs using .rdd. Need SPARK-16483 for a fix. - val schema = new StructType() - .add(srcOutputColumn, IntegerType) - .add("recommendations", - ArrayType( - StructType( - StructField(dstOutputColumn, IntegerType, nullable = false) :: - StructField("rating", FloatType, nullable = false) :: - Nil - ))) - recs.sparkSession.createDataFrame(recs.rdd, schema) + .toDF("id", "recommendations") + + val arrayType = ArrayType( + new StructType() + .add(dstOutputColumn, IntegerType) + .add("rating", FloatType) + ) + recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType) } }
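Taken together, the series leaves ALSModel with two batch-recommendation methods whose array-of-structs column is produced by a Catalyst cast rather than an RDD round-trip. An end-to-end usage sketch of the finished API — the column names and the `ratings` DataFrame are hypothetical, and a SparkSession `spark` is assumed:

    import org.apache.spark.ml.recommendation.ALS
    import org.apache.spark.sql.functions.explode
    import spark.implicits._

    // Assumes ratings(user: Int, item: Int, rating: Float).
    val als = new ALS()
      .setUserCol("user").setItemCol("item").setRatingCol("rating")
      .setRank(10).setMaxIter(10)
    val model = als.fit(ratings)

    // Schema: (user int, recommendations array<struct<item: int, rating: float>>).
    val userRecs = model.recommendForAllUsers(5)
    val itemRecs = model.recommendForAllItems(5)  // symmetric, keyed by item

    userRecs.select($"user", explode($"recommendations")).show(false)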