Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -21,43 +21,79 @@ import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}

/**
* AbstractStringType is an abstract class for StringType with collation support.
* AbstractStringType is an abstract class for StringType with collation support. Because a trim
* specifier can be combined with any kind of collation, this class is parameterized by whether
* trim collations are supported.
*
* @param supportsTrimCollation
*   when true, [[canUseTrimCollation]] passes for every input; when false (the default) it
*   passes only for string types that do not use a trim collation.
*/
abstract class AbstractStringType extends AbstractDataType {
abstract class AbstractStringType(private[sql] val supportsTrimCollation: Boolean = false)
extends AbstractDataType {
// Looked up through SqlApiConf on every call rather than fixed at construction time.
override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType
override private[sql] def simpleString: String = "string"
/**
* Returns true when this abstract type supports trim collations, or when `other` does not use
* a trim collation.
*
* NOTE(review): when `supportsTrimCollation` is false, `other` is cast to StringType without a
* type check, so a non-string DataType would fail with a ClassCastException. The acceptsType
* implementations in this file test `isInstanceOf[StringType]` before calling this method.
*/
private[sql] def canUseTrimCollation(other: DataType): Boolean =
supportsTrimCollation || !other.asInstanceOf[StringType].usesTrimCollation
}

/**
* Use StringTypeBinary for expressions supporting only binary collation.
*
* @param supportsTrimCollation
*   whether string types that use a trim collation are accepted as well (default: false).
*/
case object StringTypeBinary extends AbstractStringType {
case class StringTypeBinary(override val supportsTrimCollation: Boolean = false)
extends AbstractStringType(supportsTrimCollation) {
// Accepts only a StringType that supports binary equality and satisfies canUseTrimCollation.
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && other.asInstanceOf[StringType].supportsBinaryEquality
other.isInstanceOf[StringType] && other.asInstanceOf[StringType].supportsBinaryEquality &&
canUseTrimCollation(other)
}

/**
 * Companion of the StringTypeBinary case class. The object is itself the instance constructed
 * with `supportsTrimCollation = false`, so the bare name `StringTypeBinary` is usable as a value.
 */
object StringTypeBinary extends StringTypeBinary(supportsTrimCollation = false) {

  /** Creates a StringTypeBinary with the given trim-collation support. */
  def apply(supportsTrimCollation: Boolean): StringTypeBinary =
    new StringTypeBinary(supportsTrimCollation)
}

/**
* Use StringTypeBinaryLcase for expressions supporting only binary and lowercase collation.
*
* @param supportsTrimCollation
*   whether string types that use a trim collation are accepted as well (default: false).
*/
case object StringTypeBinaryLcase extends AbstractStringType {
case class StringTypeBinaryLcase(override val supportsTrimCollation: Boolean = false)
extends AbstractStringType(supportsTrimCollation) {
// Accepts a StringType that supports binary equality or is the UTF8 lcase collation, provided
// it also satisfies canUseTrimCollation.
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && (other.asInstanceOf[StringType].supportsBinaryEquality ||
other.asInstanceOf[StringType].isUTF8LcaseCollation)
other.asInstanceOf[StringType].isUTF8LcaseCollation) && canUseTrimCollation(other)
}

/**
 * Companion of the StringTypeBinaryLcase case class. The object is itself the instance
 * constructed with `supportsTrimCollation = false`, so the bare name `StringTypeBinaryLcase` is
 * usable as a value.
 */
object StringTypeBinaryLcase extends StringTypeBinaryLcase(supportsTrimCollation = false) {

  /** Creates a StringTypeBinaryLcase with the given trim-collation support. */
  def apply(supportsTrimCollation: Boolean): StringTypeBinaryLcase =
    new StringTypeBinaryLcase(supportsTrimCollation)
}

/**
* Use StringTypeWithCaseAccentSensitivity for expressions supporting all collation types (binary
* and ICU) but limited to using case and accent sensitivity specifiers.
*
* @param supportsTrimCollation
*   whether string types that use a trim collation are accepted as well (default: false).
*/
case object StringTypeWithCaseAccentSensitivity extends AbstractStringType {
override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType]
case class StringTypeWithCaseAccentSensitivity(
override val supportsTrimCollation: Boolean = false)
extends AbstractStringType(supportsTrimCollation) {
// Any StringType is accepted as long as canUseTrimCollation(other) holds.
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && canUseTrimCollation(other)
}

/**
 * Companion of the StringTypeWithCaseAccentSensitivity case class. The object is itself the
 * instance constructed with `supportsTrimCollation = false`, so the bare name
 * `StringTypeWithCaseAccentSensitivity` is usable as a value.
 */
object StringTypeWithCaseAccentSensitivity
  extends StringTypeWithCaseAccentSensitivity(supportsTrimCollation = false) {

  /** Creates a StringTypeWithCaseAccentSensitivity with the given trim-collation support. */
  def apply(supportsTrimCollation: Boolean): StringTypeWithCaseAccentSensitivity =
    new StringTypeWithCaseAccentSensitivity(supportsTrimCollation)
}

/**
* Use StringTypeNonCSAICollation for expressions supporting all possible collation types except
* CS_AI collation types.
*
* @param supportsTrimCollation
*   whether string types that use a trim collation are accepted as well (default: false).
*/
case object StringTypeNonCSAICollation extends AbstractStringType {
case class StringTypeNonCSAICollation(override val supportsTrimCollation: Boolean = false)
extends AbstractStringType(supportsTrimCollation) {
// Accepts only a StringType for which isNonCSAI holds and that satisfies canUseTrimCollation.
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI
other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI &&
canUseTrimCollation(other)
}

/**
 * Companion of the StringTypeNonCSAICollation case class. The object is itself the instance
 * constructed with `supportsTrimCollation = false`, so the bare name
 * `StringTypeNonCSAICollation` is usable as a value.
 */
object StringTypeNonCSAICollation
  extends StringTypeNonCSAICollation(supportsTrimCollation = false) {

  /** Creates a StringTypeNonCSAICollation with the given trim-collation support. */
  def apply(supportsTrimCollation: Boolean): StringTypeNonCSAICollation =
    new StringTypeNonCSAICollation(supportsTrimCollation)
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ
/** True unless CollationFactory reports this collation as case-sensitive and accent-insensitive. */
private[sql] def isNonCSAI: Boolean = {
  val caseSensitiveAccentInsensitive =
    CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId)
  !caseSensitiveAccentInsensitive
}

/** Delegates to CollationFactory.usesTrimCollation for this type's collation id. */
private[sql] def usesTrimCollation: Boolean = {
  CollationFactory.usesTrimCollation(this.collationId)
}

/** True when this type's collation id equals CollationFactory.UTF8_BINARY_COLLATION_ID. */
private[sql] def isUTF8BinaryCollation: Boolean = {
  CollationFactory.UTF8_BINARY_COLLATION_ID == collationId
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes {
// The single input may be a string of any collation; passing `true` also admits collations that
// use a trim specifier.
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))
// The result type is always BinaryType.
override def dataType: DataType = BinaryType

final lazy val collationId: Int = expr.dataType match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ case class HllSketchAgg(

// First entry: integer, long, string (any collation, trim collations included) or binary.
// Second entry: integer.
override def inputTypes: Seq[AbstractDataType] =
Seq(
TypeCollection(IntegerType, LongType, StringTypeWithCaseAccentSensitivity, BinaryType),
TypeCollection(
IntegerType,
LongType,
StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true),
BinaryType),
IntegerType)

// The result type is always BinaryType.
override def dataType: DataType = BinaryType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ case class Collate(child: Expression, collationName: String)
extends UnaryExpression with ExpectsInputTypes {
// Resolved from the collation name once, when the expression instance is constructed.
private val collationId = CollationFactory.collationNameToId(collationName)
// The result is a StringType carrying the requested collation id.
override def dataType: DataType = StringType(collationId)
// The input may be a string of any collation; passing `true` also admits collations that use a
// trim specifier.
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))

// Rebuilds this expression around the new child via the case-class copy method.
override protected def withNewChildInternal(
newChild: Expression): Expression = copy(newChild)
Expand Down Expand Up @@ -115,5 +116,6 @@ case class Collation(child: Expression)
val collationName = CollationFactory.fetchCollation(collationId).collationName
Literal.create(collationName, SQLConf.get.defaultStringType)
}
// The input may be a string of any collation; passing `true` also admits collations that use a
// trim specifier.
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))
}
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,11 @@ class CollationSQLExpressionsSuite
StringToMapTestCase("1/AX2/BX3/C", "x", "/", "UNICODE_CI",
Map("1" -> "A", "2" -> "B", "3" -> "C"))
)
val unsupportedTestCase = StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_AI", null)
val unsupportedTestCases = Seq(
StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_AI", null),
StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_RTRIM", null),
StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UTF8_BINARY_RTRIM", null),
StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UTF8_LCASE_RTRIM", null))
testCases.foreach(t => {
// Unit test.
val text = Literal.create(t.text, StringType(t.collation))
Expand All @@ -998,28 +1002,30 @@ class CollationSQLExpressionsSuite
}
})
// Test unsupported collation.
withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) {
val query =
s"select str_to_map('${unsupportedTestCase.text}', '${unsupportedTestCase.pairDelim}', " +
s"'${unsupportedTestCase.keyValueDelim}')"
checkError(
exception = intercept[AnalysisException] {
sql(query).collect()
},
condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
sqlState = Some("42K09"),
parameters = Map(
"sqlExpr" -> ("\"str_to_map('a:1,b:2,c:3' collate UNICODE_AI, " +
"'?' collate UNICODE_AI, '?' collate UNICODE_AI)\""),
"paramIndex" -> "first",
"inputSql" -> "\"'a:1,b:2,c:3' collate UNICODE_AI\"",
"inputType" -> "\"STRING COLLATE UNICODE_AI\"",
"requiredType" -> "\"STRING\""),
context = ExpectedContext(
fragment = "str_to_map('a:1,b:2,c:3', '?', '?')",
start = 7,
stop = 41))
}
unsupportedTestCases.foreach(t => {
withSQLConf(SQLConf.DEFAULT_COLLATION.key -> t.collation) {
val query =
s"select str_to_map('${t.text}', '${t.pairDelim}', " +
s"'${t.keyValueDelim}')"
checkError(
exception = intercept[AnalysisException] {
sql(query).collect()
},
condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
sqlState = Some("42K09"),
parameters = Map(
"sqlExpr" -> ("\"str_to_map('a:1,b:2,c:3' collate " + s"${t.collation}, " +
"'?' collate " + s"${t.collation}, '?' collate ${t.collation})" + "\""),
"paramIndex" -> "first",
"inputSql" -> ("\"'a:1,b:2,c:3' collate " + s"${t.collation}" + "\""),
"inputType" -> ("\"STRING COLLATE " + s"${t.collation}" + "\""),
"requiredType" -> "\"STRING\""),
context = ExpectedContext(
fragment = "str_to_map('a:1,b:2,c:3', '?', '?')",
start = 7,
stop = 41))
}
})
}

test("Support RaiseError misc expression with collation") {
Expand Down