From 8993c0ec4542f3df89d673c1c13111f918037e24 Mon Sep 17 00:00:00 2001 From: Joshi Date: Thu, 7 May 2015 16:36:59 -0700 Subject: [PATCH 1/4] SPARK-7137: Add checkInputColumn back to Params and print more info --- .../scala/org/apache/spark/ml/param/params.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 51ce19d29cd29..ea779f7d409cf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -24,7 +24,8 @@ import scala.annotation.varargs import scala.collection.mutable import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.sql.types.{DataType, StructType} /** * :: AlphaComponent :: @@ -380,6 +381,18 @@ trait Params extends Identifiable with Serializable { this } + /** + * Check whether the given schema contains an input column. + * @param colName Input column name + * @param dataType Input column DataType + */ + protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = { + val actualDataType = schema(colName).dataType + SchemaUtils.checkColumnType(schema, colName, dataType) + require(actualDataType.equals(dataType), s"Input column Name: $colName Description: ${getParam(colName)}") + } + + /** * Gets the default value of a parameter. */ From acf3e17b46e44b20b4144b585e08de7151b517d4 Mon Sep 17 00:00:00 2001 From: Joshi Date: Wed, 1 Jul 2015 16:07:35 -0700 Subject: [PATCH 2/4] update checkInputColumn to print more info if needed --- .../scala/org/apache/spark/ml/param/params.scala | 15 +-------------- .../org/apache/spark/ml/util/SchemaUtils.scala | 5 +++-- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ea779f7d409cf..51ce19d29cd29 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -24,8 +24,7 @@ import scala.annotation.varargs import scala.collection.mutable import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.ml.util.Identifiable /** * :: AlphaComponent :: @@ -381,18 +380,6 @@ trait Params extends Identifiable with Serializable { this } - /** - * Check whether the given schema contains an input column. - * @param colName Input column name - * @param dataType Input column DataType - */ - protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = { - val actualDataType = schema(colName).dataType - SchemaUtils.checkColumnType(schema, colName, dataType) - require(actualDataType.equals(dataType), s"Input column Name: $colName Description: ${getParam(colName)}") - } - - /** * Gets the default value of a parameter. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 0383bf0b382b7..9252618715625 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -34,10 +34,11 @@ object SchemaUtils { * @param colName column name * @param dataType required column data type */ - def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = { + def checkColumnType(schema: StructType, colName: String, dataType: DataType, + msg: String = ""): Unit = { val actualDataType = schema(colName).dataType require(actualDataType.equals(dataType), - s"Column $colName must be of type $dataType but was actually $actualDataType.") + s"Column $colName must be of type $dataType but was actually $actualDataType.$msg") } /** From 33ddd2e3fa56212f48c748dee7751eeafe6763f0 Mon Sep 17 00:00:00 2001 From: Joshi Date: Wed, 1 Jul 2015 19:58:41 -0700 Subject: [PATCH 3/4] update checkInputColumn to print more info if needed --- .../main/scala/org/apache/spark/ml/util/SchemaUtils.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 9252618715625..a5f3e14dac696 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -34,11 +34,14 @@ object SchemaUtils { * @param colName column name * @param dataType required column data type */ - def checkColumnType(schema: StructType, colName: String, dataType: DataType, + def checkColumnType(schema: StructType, + colName: String, + dataType: DataType, msg: String = ""): Unit = { val actualDataType = schema(colName).dataType + val message = if (msg != null && msg.trim.length > 0) " " + msg else "" require(actualDataType.equals(dataType), - s"Column $colName must be of type $dataType but was actually $actualDataType.$msg") + s"Column $colName must be of type $dataType but was actually $actualDataType.$message") } /** From 8c42b57cd2539f9e8b9aadb76c702ef6cacb9712 Mon Sep 17 00:00:00 2001 From: Joshi Date: Thu, 2 Jul 2015 18:02:05 -0700 Subject: [PATCH 4/4] update checkInputColumn to print more info if needed --- .../scala/org/apache/spark/ml/util/SchemaUtils.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index a5f3e14dac696..a1ebdf12d11d4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -34,10 +34,11 @@ object SchemaUtils { * @param colName column name * @param dataType required column data type */ - def checkColumnType(schema: StructType, - colName: String, - dataType: DataType, - msg: String = ""): Unit = { + def checkColumnType( + schema: StructType, + colName: String, + dataType: DataType, + msg: String = ""): Unit = { val actualDataType = schema(colName).dataType val message = if (msg != null && msg.trim.length > 0) " " + msg else "" require(actualDataType.equals(dataType),