diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 5f533fecf8d07..98d12c2ef4460 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -26,6 +26,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.types._
@@ -44,6 +45,12 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 @ExpressionDescription(
   usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of `str1`, `str2`, ..., `strN`.",
   extended = """
+    Arguments:
+      str - The strings to be concatenated.
+
+      The arguments are expressions that return string values. If any argument is null, the
+      result is null.
+
     Examples:
       > SELECT _FUNC_('Spark','SQL');
        SparkSQL
@@ -88,6 +95,15 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
 @ExpressionDescription(
   usage = "_FUNC_(sep, [str | array(str)]+) - Returns the concatenation of the strings separated by `sep`.",
   extended = """
+    Arguments:
+      sep - The separator placed between the other arguments.
+      str | array(str) - The strings to be concatenated.
+
+      The arguments are expressions that return string values; arguments after the separator
+      may also be expressions that return arrays of strings. At least three arguments are
+      required. The function ignores null values and returns an empty string if all values are
+      null; it returns null only if the separator is null.
+
     Examples:
       > SELECT _FUNC_(' ', 'Spark', 'SQL');
        Spark SQL
@@ -108,6 +124,14 @@ case class ConcatWs(children: Seq[Expression])
 
   override def dataType: DataType = StringType
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.size < 3) {
+      TypeCheckFailure("requires at least three arguments")
+    } else {
+      super.checkInputDataTypes()
+    }
+  }
+
   override def nullable: Boolean = children.head.nullable
 
   override def foldable: Boolean = children.forall(_.foldable)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index bcc2351049953..d43ed2e0fc903 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -17,14 +17,18 @@
 
 package org.apache.spark.sql
 
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
 
 class StringFunctionsSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
-  test("string concat") {
+  test("string concat - basic") {
     val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
 
     checkAnswer(
@@ -36,7 +40,28 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
       Row("ab", null))
   }
 
-  test("string concat_ws") {
+  test("string concat - all compatible types") {
+    val allTypeData = AllTypeTestData(spark)
+    val df = allTypeData.dataFrame
+    checkAnswer(
+      df.select(concat(allTypeData.getStringCompatibleColumns: _*)),
+      Row("11111.011.011.010001970-01-011970-01-01 00:00:00aatrue") ::
+        Row("22222.022.022.020001970-02-021970-01-01 00:00:05bbbbfalse") :: Nil)
+  }
+
+  test("string concat - unsupported types") {
+    val allTypeData = AllTypeTestData(spark)
+    val df = allTypeData.dataFrame
+
+    Seq(allTypeData.mapCol, allTypeData.arrayIntCol, allTypeData.structCol).foreach { col =>
+      val e = intercept[AnalysisException] {
+        df.select(concat(allTypeData.stringCol, col))
+      }.getMessage
+      assert(e.contains("argument 2 requires string type"))
+    }
+  }
+
+  test("string concat_ws - basic") {
     val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
 
     checkAnswer(
@@ -46,6 +71,47 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
     checkAnswer(
       df.selectExpr("concat_ws('||', a, b, c)"),
       Row("a||b"))
+
+    checkAnswer(
+      df.selectExpr("concat_ws(null, a, b, c)"),
+      Row(null))
+
+    checkAnswer(
+      df.selectExpr("concat_ws(a, b, b)"),
+      Row("bab"))
+
+    val df1 = Seq[(String, String)]((null, null)).toDF("a", "b")
+
+    checkAnswer(
+      df1.selectExpr("concat_ws('||', a, b)"),
+      Row(""))
+
+    val e = intercept[AnalysisException] {
+      df1.selectExpr("concat_ws('||', b)")
+    }.getMessage
+    assert(e.contains("requires at least three arguments"))
+  }
+
+  test("string concat_ws - all compatible types") {
+    val allTypeData = AllTypeTestData(spark)
+    val df = allTypeData.dataFrame
+    checkAnswer(
+      df.select(concat_ws("_",
+        allTypeData.getStringCompatibleColumns :+ allTypeData.arrayStringCol: _*)),
+      Row("1_1_1_1_1.01_1.01_1.01000_1970-01-01_1970-01-01 00:00:00_a_a_true_a_b") ::
+        Row("2_2_2_2_2.02_2.02_2.02000_1970-02-02_1970-01-01 00:00:05_bb_bb_false_c_d") :: Nil)
+  }
+
+  test("string concat_ws - unsupported types") {
+    val allTypeData = AllTypeTestData(spark)
+    val df = allTypeData.dataFrame
+
+    Seq(allTypeData.mapCol, allTypeData.arrayIntCol, allTypeData.structCol).foreach { col =>
+      val e = intercept[AnalysisException] {
+        df.select(concat_ws("_", col, col))
+      }.getMessage
+      assert(e.contains("argument 2 requires (array or string) type"))
+    }
   }
 
   test("string elt") {
@@ -452,3 +518,72 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
 
   }
 }
+
+case class AllTypeTestData(spark: SparkSession) {
+  private val intSeq = Seq(1, 2)
+  private val doubleSeq = Seq(1.01d, 2.02d)
+  private val stringSeq = Seq("a", "bb")
+  private val booleanSeq = Seq(true, false)
+  private val dateSeq = Seq("1970-01-01", "1970-02-02").map(Date.valueOf)
+  private val timestampSeq = Seq("1970-01-01 00:00:00", "1970-01-01 00:00:05")
+    .map(Timestamp.valueOf)
+  private val arrayIntSeq = Seq(Seq(1), Seq(2))
+  private val arrayStringSeq = Seq(Seq("a", "b"), Seq("c", "d"))
+  private val mapSeq = Seq(Map("a" -> "1", "b" -> "2"), Map("d" -> "3", "e" -> "4"))
+  private val structSeq = Seq(Row("d"), Row("c"))
+
+  private val allTypeSchema = StructType(Seq(
+    StructField("byteCol", ByteType),
+    StructField("shortCol", ShortType),
+    StructField("intCol", IntegerType),
+    StructField("longCol", LongType),
+    StructField("floatCol", FloatType),
+    StructField("doubleCol", DoubleType),
+    StructField("decimalCol", DecimalType(10, 5)),
+    StructField("dateCol", DateType),
+    StructField("timestampCol", TimestampType),
+    StructField("stringCol", StringType),
+    StructField("binaryCol", BinaryType),
+    StructField("booleanCol", BooleanType),
+    StructField("arrayIntCol", ArrayType(IntegerType, containsNull = true)),
+    StructField("arrayStringCol", ArrayType(StringType, containsNull = true)),
+    StructField("mapCol", MapType(StringType, StringType)),
+    StructField("structCol", new StructType().add("a", StringType))
+  ))
+
+  private val rowRDD: RDD[Row] = spark.sparkContext.parallelize(intSeq.indices.map { i =>
+    Row(intSeq(i).toByte, intSeq(i).toShort, intSeq(i), intSeq(i).toLong,
+      doubleSeq(i).toFloat, doubleSeq(i), Decimal(doubleSeq(i)),
+      dateSeq(i), timestampSeq(i),
+      stringSeq(i), stringSeq(i).getBytes,
+      booleanSeq(i),
+      arrayIntSeq(i),
+      arrayStringSeq(i),
+      mapSeq(i),
+      structSeq(i))
+  })
+
+  val dataFrame: DataFrame = spark.createDataFrame(rowRDD, allTypeSchema)
+
+  val byteCol = Column("byteCol")
+  val shortCol = Column("shortCol")
+  val intCol = Column("intCol")
+  val longCol = Column("longCol")
+  val floatCol = Column("floatCol")
+  val doubleCol = Column("doubleCol")
+  val decimalCol = Column("decimalCol")
+  val dateCol = Column("dateCol")
+  val timestampCol = Column("timestampCol")
+  val stringCol = Column("stringCol")
+  val binaryCol = Column("binaryCol")
+  val booleanCol = Column("booleanCol")
+  val arrayIntCol = Column("arrayIntCol")
+  val arrayStringCol = Column("arrayStringCol")
+  val mapCol = Column("mapCol")
+  val structCol = Column("structCol")
+
+  def getStringCompatibleColumns: Seq[Column] = {
+    Seq(byteCol, shortCol, intCol, longCol, floatCol, doubleCol, decimalCol,
+      dateCol, timestampCol, stringCol, binaryCol, booleanCol)
+  }
+}
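
Reviewer note (illustrative only, not part of the patch): the concat_ws semantics the new docs and tests pin down - null values ignored, empty string when every value is null, null only for a null separator, and the new minimum-arity check - can be reproduced in a spark-shell session. `spark` is the session spark-shell provides; the DataFrame mirrors the all-null one used in the test.

    import org.apache.spark.sql.functions._
    import spark.implicits._

    val df = Seq[(String, String)]((null, null)).toDF("a", "b")

    // Null values are ignored, so all-null inputs yield an empty string.
    df.selectExpr("concat_ws('||', a, b)").show()   // ""

    // A null separator is the only way to get a null result.
    df.selectExpr("concat_ws(null, a, b)").show()   // null

    // Fewer than three arguments (separator included) now fails analysis:
    // org.apache.spark.sql.AnalysisException: ... requires at least three arguments
    df.selectExpr("concat_ws('||', a)")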
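
A second sketch (also illustrative, not part of the patch) of what the type checks in the tests mean in practice: Concat mixes in ImplicitCastInputTypes, so string-compatible atomic columns are cast to string automatically, while map, array-of-int, and struct columns are rejected during analysis.

    import org.apache.spark.sql.functions._
    import spark.implicits._

    val df = Seq((1, "a", Map("k" -> "v"))).toDF("i", "s", "m")

    // The int column is implicitly cast to string: the result is "a1".
    df.select(concat($"s", $"i")).show()

    // The map column cannot be cast to string, so analysis fails:
    // org.apache.spark.sql.AnalysisException: ... argument 2 requires string type
    df.select(concat($"s", $"m"))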