@@ -26,6 +26,7 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
@@ -44,6 +45,12 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@ExpressionDescription(
usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of `str1`, `str2`, ..., `strN`.",
extended = """
Arguments:
str - The strings to be concatenated.

The arguments are expressions that return a string value. If any argument is null, the
result is null.

Examples:
> SELECT _FUNC_('Spark','SQL');
SparkSQL
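
For reference, the null behaviour documented above also shows up through the DataFrame API. A minimal sketch, assuming an active SparkSession in scope as spark and illustrative column names:

  import org.apache.spark.sql.functions.concat
  import spark.implicits._  // assumes spark is an active SparkSession

  val df = Seq[(String, String, String)](("a", "b", null)).toDF("x", "y", "z")
  df.select(concat($"x", $"y")).head.getString(0)       // "ab"
  df.select(concat($"x", $"y", $"z")).head.isNullAt(0)  // true: any null input makes the result null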
@@ -88,6 +95,15 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
@ExpressionDescription(
usage = "_FUNC_(sep, [str | array(str)]+) - Returns the concatenation of the strings separated by `sep`.",
extended = """
Arguments:
sep - The separator for the rest of the arguments.
str | array(str) - The strings to be concatenated.

The arguments are expressions that return a string value; from the second argument onward
they can also be expressions that return array<string>. At least three arguments are
required. The function ignores null values and returns an empty string if all values are
null; it returns null only if the separator is null.

Examples:
> SELECT _FUNC_(' ', 'Spark', 'SQL');
Spark SQL
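
Unlike concat, null arguments are skipped rather than propagated, and only a null separator makes the result null (see the tests added below). A minimal sketch under the same assumptions (an active SparkSession in scope as spark):

  import org.apache.spark.sql.functions.concat_ws
  import spark.implicits._  // assumes spark is an active SparkSession

  val df = Seq[(String, String, String)](("a", "b", null)).toDF("x", "y", "z")
  df.select(concat_ws("|", $"x", $"y", $"z")).head.getString(0)  // "a|b": the null in z is ignored
  df.selectExpr("concat_ws(null, x, y)").head.isNullAt(0)        // true: a null separator yields null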
@@ -108,6 +124,14 @@ case class ConcatWs(children: Seq[Expression])

override def dataType: DataType = StringType

override def checkInputDataTypes(): TypeCheckResult = {
if (children.size < 3) {
TypeCheckFailure("requires at least three arguments")
} else {
super.checkInputDataTypes()
}
}

override def nullable: Boolean = children.head.nullable
override def foldable: Boolean = children.forall(_.foldable)
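
The new checkInputDataTypes override turns a call with fewer than three arguments into an analysis-time error instead of a runtime failure. A minimal sketch of what a caller sees, assuming an active SparkSession in scope as spark:

  import org.apache.spark.sql.AnalysisException
  import spark.implicits._  // assumes spark is an active SparkSession

  val df = Seq[(String, String)](("a", "b")).toDF("x", "y")
  try {
    // Only two arguments are passed to concat_ws, so analysis fails.
    df.selectExpr("concat_ws('-', x)")
  } catch {
    case e: AnalysisException =>
      assert(e.getMessage.contains("requires at least three arguments"))
  }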

@@ -17,14 +17,18 @@

package org.apache.spark.sql

import java.sql.{Date, Timestamp}

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._


class StringFunctionsSuite extends QueryTest with SharedSQLContext {
import testImplicits._

- test("string concat") {
+ test("string concat - basic") {
val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")

checkAnswer(
@@ -36,7 +40,28 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
Row("ab", null))
}

- test("string concat_ws") {
+ test("string concat_ws - basic") {
val allTypeData = AllTypeTestData(spark)
val df = allTypeData.dataFrame
checkAnswer(
df.select(concat(allTypeData.getStringCompatibleColumns: _*)),
Row("11111.011.011.010001970-01-011970-01-01 00:00:00aatrue") ::
Row("22222.022.022.020001970-02-021970-01-01 00:00:05bbbbfalse") :: Nil)
}

test("string concat - unsupported types") {
val allTypeData = AllTypeTestData(spark)
val df = allTypeData.dataFrame

Seq(allTypeData.mapCol, allTypeData.arrayIntCol, allTypeData.structCol).foreach { col =>
val e = intercept[AnalysisException] {
df.select(concat(allTypeData.stringCol, col))
}.getMessage
assert(e.contains("argument 2 requires string type"))
}
}

test("string concat_ws - basic") {
val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")

checkAnswer(
@@ -46,6 +71,47 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
checkAnswer(
df.selectExpr("concat_ws('||', a, b, c)"),
Row("a||b"))

checkAnswer(
df.selectExpr("concat_ws(null, a, b, c)"),
Row(null))

checkAnswer(
df.selectExpr("concat_ws(a, b, b)"),
Row("bab"))

val df1 = Seq[(String, String)]((null, null)).toDF("a", "b")

checkAnswer(
df1.selectExpr("concat_ws('||', a, b)"),
Row(""))

val e = intercept[AnalysisException] {
df1.selectExpr("concat_ws('||', b)")
}.getMessage
assert(e.contains("requires at least three arguments"))
}

test("string concat_ws - all compatible types") {
val allTypeData = AllTypeTestData(spark)
val df = allTypeData.dataFrame
checkAnswer(
df.select(concat_ws("_",
allTypeData.getStringCompatibleColumns :+ allTypeData.arrayStringCol: _*)),
Row("1_1_1_1_1.01_1.01_1.01000_1970-01-01_1970-01-01 00:00:00_a_a_true_a_b") ::
Row("2_2_2_2_2.02_2.02_2.02000_1970-02-02_1970-01-01 00:00:05_bb_bb_false_c_d") :: Nil)
}

test("string concat_ws - unsupported types") {
val allTypeData = AllTypeTestData(spark)
val df = allTypeData.dataFrame

Seq(allTypeData.mapCol, allTypeData.arrayIntCol, allTypeData.structCol).foreach { col =>
val e = intercept[AnalysisException] {
df.select(concat_ws("_", col, col))
}.getMessage
assert(e.contains("argument 2 requires (array<string> or string) type"))
}
}

test("string elt") {
@@ -452,3 +518,72 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {

}
}

case class AllTypeTestData(spark: SparkSession) {
private val intSeq = Seq(1, 2)
private val doubleSeq = Seq(1.01d, 2.02d)
private val stringSeq = Seq("a", "bb")
private val booleanSeq = Seq(true, false)
private val dateSeq = Seq("1970-01-01", "1970-02-02").map(Date.valueOf)
private val timestampSeq = Seq("1970-01-01 00:00:00", "1970-01-01 00:00:05")
.map(Timestamp.valueOf)
private val arrayIntSeq = Seq(Seq(1), Seq(2))
private val arrayStringSeq = Seq(Seq("a", "b"), Seq("c", "d"))
private val mapSeq = Seq(Map("a" -> "1", "b" -> "2"), Map("d" -> "3", "e" -> "4"))
private val structSeq = Seq(Row("d"), Row("c"))

private val allTypeSchema = StructType(Seq(
StructField("byteCol", ByteType),
StructField("shortCol", ShortType),
StructField("intCol", IntegerType),
StructField("longCol", LongType),
StructField("floatCol", FloatType),
StructField("doubleCol", DoubleType),
StructField("decimalCol", DecimalType(10, 5)),
StructField("dateCol", DateType),
StructField("timestampCol", TimestampType),
StructField("stringCol", StringType),
StructField("binaryCol", BinaryType),
StructField("booleanCol", BooleanType),
StructField("arrayIntCol", ArrayType(IntegerType, containsNull = true)),
StructField("arrayStringCol", ArrayType(StringType, containsNull = true)),
StructField("mapCol", MapType(StringType, StringType)),
StructField("structCol", new StructType().add("a", StringType))
))

private val rowRDD: RDD[Row] = spark.sparkContext.parallelize(intSeq.indices.map { i =>
Row(intSeq(i).toByte, intSeq(i).toShort, intSeq(i), intSeq(i).toLong,
doubleSeq(i).toFloat, doubleSeq(i), Decimal(doubleSeq(i)),
dateSeq(i), timestampSeq(i),
stringSeq(i), stringSeq(i).getBytes,
booleanSeq(i),
arrayIntSeq(i),
arrayStringSeq(i),
mapSeq(i),
structSeq(i))
})

val dataFrame: DataFrame = spark.createDataFrame(rowRDD, allTypeSchema)

val byteCol = Column("byteCol")
val shortCol = Column("shortCol")
val intCol = Column("intCol")
val longCol = Column("longCol")
val floatCol = Column("floatCol")
val doubleCol = Column("doubleCol")
val decimalCol = Column("decimalCol")
val dateCol = Column("dateCol")
val timestampCol = Column("timestampCol")
val stringCol = Column("stringCol")
val binaryCol = Column("binaryCol")
val booleanCol = Column("booleanCol")
val arrayIntCol = Column("arrayIntCol")
val arrayStringCol = Column("arrayStringCol")
val mapCol = Column("mapCol")
val structCol = Column("structCol")

def getStringCompatibleColumns: Seq[Column] = {
Seq(byteCol, shortCol, intCol, longCol, floatCol, doubleCol, decimalCol,
dateCol, timestampCol, stringCol, binaryCol, booleanCol)
}
}