Skip to content

Commit d31854d

Browse files
Earthson authored and mengxr committed
[SPARK-12746][ML] ArrayType(_, true) should also accept ArrayType(_, false) fix for branch-1.6
https://issues.apache.org/jira/browse/SPARK-13359 Author: Earthson Lu <[email protected]> Closes #11237 from Earthson/SPARK-13359.
1 parent 2902798 commit d31854d

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,8 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
7070

7171
/** Validates and transforms the input schema. */
7272
protected def validateAndTransformSchema(schema: StructType): StructType = {
73-
SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
73+
val typeCandidates = List(ArrayType(StringType, true), ArrayType(StringType, false))
74+
SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates)
7475
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
7576
}
7677

mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,23 @@ private[spark] object SchemaUtils {
4343
s"Column $colName must be of type $dataType but was actually $actualDataType.$message")
4444
}
4545

46+
/**
47+
* Check whether the given schema contains a column of one of the required data types.
48+
* @param colName column name
49+
* @param dataTypes required column data types
50+
*/
51+
def checkColumnTypes(
52+
schema: StructType,
53+
colName: String,
54+
dataTypes: Seq[DataType],
55+
msg: String = ""): Unit = {
56+
val actualDataType = schema(colName).dataType
57+
val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
58+
require(dataTypes.exists(actualDataType.equals),
59+
s"Column $colName must be of type equal to one of the following types: " +
60+
s"${dataTypes.mkString("[", ", ", "]")} but was actually of type $actualDataType.$message")
61+
}
62+
4663
/**
4764
* Appends a new column to the input schema. This fails if the given output column already exists.
4865
* @param schema input schema

0 commit comments

Comments
 (0)