Expand Up @@ -137,32 +137,14 @@ trait CheckAnalysis {
case e => e.children.foreach(checkValidAggregateExpression)
}

-def checkSupportedGroupingDataType(
-expressionString: String,
-dataType: DataType): Unit = dataType match {
-case BinaryType =>
-failAnalysis(s"expression $expressionString cannot be used in " +
-s"grouping expression because it is in binary type or its inner field is " +
-s"in binary type")
-case a: ArrayType =>
-failAnalysis(s"expression $expressionString cannot be used in " +
-s"grouping expression because it is in array type or its inner field is " +
-s"in array type")
-case m: MapType =>
-failAnalysis(s"expression $expressionString cannot be used in " +
-s"grouping expression because it is in map type or its inner field is " +
-s"in map type")
-case s: StructType =>
-s.fields.foreach { f =>
-checkSupportedGroupingDataType(expressionString, f.dataType)
-}
-case udt: UserDefinedType[_] =>
-checkSupportedGroupingDataType(expressionString, udt.sqlType)
-case _ => // OK
-}

def checkValidGroupingExprs(expr: Expression): Unit = {
-checkSupportedGroupingDataType(expr.prettyString, expr.dataType)
+// Check if the data type of expr is orderable.
+if (!RowOrdering.isOrderable(expr.dataType)) {
+failAnalysis(
+s"expression ${expr.prettyString} cannot be used as a grouping expression " +
+s"because its data type ${expr.dataType.simpleString} is not an orderable " +
+s"data type.")
+}

if (!expr.deterministic) {
// This is just a sanity check, our analysis rule PullOutNondeterministic should
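The deleted method above enumerated unsupported grouping types by hand; the new check delegates to RowOrdering.isOrderable (see the ordering hunk below) and emits one uniform message. A hedged sketch of the user-visible effect, assuming a 1.x-style sqlContext and a hypothetical table t with columns a: array<int> and m: map<string,int>:

val df = sqlContext.table("t")
df.groupBy(df("a")).count() // analyzes fine after this patch: arrays are orderable
df.groupBy(df("m")).count() // fails analysis: "expression m cannot be used as a
                            // grouping expression because its data type
                            // map<string,int> is not an orderable data type."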
@@ -267,6 +267,49 @@ class CodeGenContext {
case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
case NullType => "0"
+case array: ArrayType =>
+val elementType = array.elementType
+val elementA = freshName("elementA")
+val isNullA = freshName("isNullA")
+val elementB = freshName("elementB")
+val isNullB = freshName("isNullB")
+val compareFunc = freshName("compareArray")
+val minLength = freshName("minLength")
+val funcCode: String =
+s"""
+public int $compareFunc(ArrayData a, ArrayData b) {
+int lengthA = a.numElements();
+int lengthB = b.numElements();
+int $minLength = (lengthA > lengthB) ? lengthB : lengthA;
+for (int i = 0; i < $minLength; i++) {
+boolean $isNullA = a.isNullAt(i);
+boolean $isNullB = b.isNullAt(i);
+if ($isNullA && $isNullB) {
+// Nothing
+} else if ($isNullA) {
+return -1;
+} else if ($isNullB) {
+return 1;
+} else {
+${javaType(elementType)} $elementA = ${getValue("a", elementType, "i")};
+${javaType(elementType)} $elementB = ${getValue("b", elementType, "i")};
+int comp = ${genComp(elementType, elementA, elementB)};
+if (comp != 0) {
+return comp;
+}
+}
+}
+
+if (lengthA < lengthB) {
+return -1;
+} else if (lengthA > lengthB) {
+return 1;
+}
+return 0;
+}
+"""
+addNewFunction(compareFunc, funcCode)
+s"this.$compareFunc($c1, $c2)"
case schema: StructType =>
val comparisons = GenerateOrdering.genComparisons(this, schema)
val compareFunc = freshName("compareStruct")
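The generated Java mirrors the interpreted ordering added to ArrayType later in this patch: compare element-wise with nulls first, then break ties on length. A reference model of that contract in plain Scala (Option[Int] stands in for a nullable int element; this is not the codegen output itself):

// Reference semantics for the generated compareArray function, int arrays only.
def compareArrays(a: Seq[Option[Int]], b: Seq[Option[Int]]): Int = {
  val firstDiff = a.iterator.zip(b.iterator).map {
    case (None, None)       => 0    // both NULL: keep scanning
    case (None, _)          => -1   // NULL sorts before any value
    case (_, None)          => 1
    case (Some(x), Some(y)) => java.lang.Integer.compare(x, y)
  }.find(_ != 0)
  firstDiff.getOrElse(java.lang.Integer.compare(a.length, b.length))
}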
@@ -68,6 +68,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
private lazy val lt: Comparator[Any] = {
val ordering = base.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
+case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
}

@@ -90,6 +91,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
private lazy val gt: Comparator[Any] = {
val ordering = base.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
+case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
}

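With the new ArrayType case, sort_array can order arrays whose elements are themselves arrays. A hedged usage sketch (data and column name invented; 1.x-era DataFrame API assumed):

val df = sqlContext.createDataFrame(Seq(
  Tuple1(Seq(Seq(3, 4), Seq(1, 2))),
  Tuple1(Seq(Seq(5, 0), Seq(5))))).toDF("nested")
// Inner arrays are compared with ArrayType.interpretedOrdering, so
// [[3,4],[1,2]] sorts to [[1,2],[3,4]] and [[5,0],[5]] to [[5],[5,0]].
df.selectExpr("sort_array(nested)").show(false)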
@@ -48,6 +48,10 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
dt.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
case dt: AtomicType if order.direction == Descending =>
dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
+case a: ArrayType if order.direction == Ascending =>
+a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
+case a: ArrayType if order.direction == Descending =>
+a.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
case s: StructType if order.direction == Ascending =>
s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
case s: StructType if order.direction == Descending =>
@@ -86,6 +90,8 @@ object RowOrdering {
case NullType => true
case dt: AtomicType => true
case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType))
+case array: ArrayType => isOrderable(array.elementType)
+case udt: UserDefinedType[_] => isOrderable(udt.sqlType)
case _ => false
}

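The net effect of the two hunks above: orderability is now defined recursively over the type, with maps as the only structural dead end. A few spot checks, assuming the catalyst classes are on the classpath:

import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types._

RowOrdering.isOrderable(ArrayType(IntegerType))                       // true (new)
RowOrdering.isOrderable(ArrayType(ArrayType(StringType)))             // true, by recursion
RowOrdering.isOrderable(MapType(StringType, IntegerType))             // false
RowOrdering.isOrderable(ArrayType(MapType(StringType, IntegerType)))  // false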
@@ -57,6 +57,7 @@ object TypeUtils {
def getInterpretedOrdering(t: DataType): Ordering[Any] = {
t match {
case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
+case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
}
}
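getInterpretedOrdering is the shared entry point for the interpreted comparison path; a small check of the new array branch, assuming the private[sql] catalyst API is accessible (e.g. from test code under an org.apache.spark.sql package):

import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils}
import org.apache.spark.sql.types.{ArrayType, IntegerType}

val ord = TypeUtils.getInterpretedOrdering(ArrayType(IntegerType))
val cmp = ord.compare(
  new GenericArrayData(Array[Any](1, 2)),
  new GenericArrayData(Array[Any](1, 3)))
assert(cmp < 0) // [1, 2] sorts before [1, 3]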
@@ -84,6 +84,7 @@ private[sql] object TypeCollection {
* Types that can be ordered/compared. In the long run we should probably make this a trait
* that can be mixed into each data type, and perhaps create an [[AbstractDataType]].
*/
+// TODO: Should we consolidate this with RowOrdering.isOrderable?
val Ordered = TypeCollection(
BooleanType,
ByteType, ShortType, IntegerType, LongType,
@@ -17,10 +17,13 @@

package org.apache.spark.sql.types

+import org.apache.spark.sql.catalyst.util.ArrayData
import org.json4s.JsonDSL._

import org.apache.spark.annotation.DeveloperApi

+import scala.math.Ordering


object ArrayType extends AbstractDataType {
/** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */
@@ -81,4 +84,49 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType {
override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
f(this) || elementType.existsRecursively(f)
}

+@transient
+private[sql] lazy val interpretedOrdering: Ordering[ArrayData] = new Ordering[ArrayData] {
+private[this] val elementOrdering: Ordering[Any] = elementType match {
+case dt: AtomicType => dt.ordering.asInstanceOf[Ordering[Any]]
+case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
+case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
+case other =>
+throw new IllegalArgumentException(s"Type $other does not support ordered operations")
+}
+
+def compare(x: ArrayData, y: ArrayData): Int = {
+val leftArray = x
+val rightArray = y
+val minLength = scala.math.min(leftArray.numElements(), rightArray.numElements())
+var i = 0
+while (i < minLength) {
+val isNullLeft = leftArray.isNullAt(i)
+val isNullRight = rightArray.isNullAt(i)
+if (isNullLeft && isNullRight) {
+// Do nothing.
+} else if (isNullLeft) {
+return -1
+} else if (isNullRight) {
+return 1
+} else {
+val comp =
+elementOrdering.compare(
+leftArray.get(i, elementType),
+rightArray.get(i, elementType))
+if (comp != 0) {
+return comp
+}
+}
+i += 1
+}
+if (leftArray.numElements() < rightArray.numElements()) {
+return -1
+} else if (leftArray.numElements() > rightArray.numElements()) {
+return 1
+} else {
+return 0
+}
+}
+}
+}
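The ordering defined here is lexicographic: NULL elements sort before non-NULL ones, the first unequal element decides, and a strict prefix sorts before the longer array. Spot checks, under the same assumption that the private[sql] member is reachable:

import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types.{ArrayType, IntegerType}

val ord = ArrayType(IntegerType).interpretedOrdering
def arr(xs: Any*) = new GenericArrayData(xs.toArray)

assert(ord.compare(arr(null, 9), arr(0, 0)) < 0) // NULL element sorts first
assert(ord.compare(arr(1, 2), arr(1, 3)) < 0)    // first difference decides
assert(ord.compare(arr(1, 2), arr(1, 2, 0)) < 0) // prefix sorts before longer array
assert(ord.compare(arr(1, 2), arr(1, 2)) == 0)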
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
+import org.apache.spark.sql.catalyst.util.{MapData, ArrayBasedMapData, GenericArrayData, ArrayData}
import org.apache.spark.sql.types._

import scala.beans.{BeanProperty, BeanInfo}
@@ -53,21 +53,29 @@ private[sql] class GroupableUDT extends UserDefinedType[GroupableData] {
}

@BeanInfo
-private[sql] case class UngroupableData(@BeanProperty data: Array[Int])
+private[sql] case class UngroupableData(@BeanProperty data: Map[Int, Int])

private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] {

-override def sqlType: DataType = ArrayType(IntegerType)
+override def sqlType: DataType = MapType(IntegerType, IntegerType)

-override def serialize(obj: Any): ArrayData = {
+override def serialize(obj: Any): MapData = {
obj match {
-case groupableData: UngroupableData => new GenericArrayData(groupableData.data)
+case groupableData: UngroupableData =>
+val keyArray = new GenericArrayData(groupableData.data.keys.toSeq)
+val valueArray = new GenericArrayData(groupableData.data.values.toSeq)
+new ArrayBasedMapData(keyArray, valueArray)
}
}

override def deserialize(datum: Any): UngroupableData = {
datum match {
-case data: Array[Int] => UngroupableData(data)
+case data: MapData =>
+val keyArray = data.keyArray().array
+val valueArray = data.valueArray().array
+assert(keyArray.length == valueArray.length)
+val mapData = keyArray.zip(valueArray).toMap.asInstanceOf[Map[Int, Int]]
+UngroupableData(mapData)
}
}

@@ -154,8 +162,8 @@ class AnalysisErrorSuite extends AnalysisTest {

errorTest(
"sorting by unsupported column types",
-listRelation.orderBy('list.asc),
-"sort" :: "type" :: "array<int>" :: Nil)
+mapRelation.orderBy('map.asc),
+"sort" :: "type" :: "map<int,int>" :: Nil)

errorTest(
"non-boolean filters",
@@ -259,32 +267,33 @@ class AnalysisErrorSuite extends AnalysisTest {
case true =>
assertAnalysisSuccess(plan, true)
case false =>
assertAnalysisError(plan, "expression a cannot be used in grouping expression" :: Nil)
assertAnalysisError(plan, "expression a cannot be used as a grouping expression" :: Nil)
}

}

val supportedDataTypes = Seq(
-StringType,
+StringType, BinaryType,
NullType, BooleanType,
ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
DateType, TimestampType,
+ArrayType(IntegerType),
new StructType()
.add("f1", FloatType, nullable = true)
.add("f2", StringType, nullable = true),
+new StructType()
+.add("f1", FloatType, nullable = true)
+.add("f2", ArrayType(BooleanType, containsNull = true), nullable = true),
new GroupableUDT())
supportedDataTypes.foreach { dataType =>
checkDataType(dataType, shouldSuccess = true)
}

val unsupportedDataTypes = Seq(
-BinaryType,
-ArrayType(IntegerType),
MapType(StringType, LongType),
new StructType()
.add("f1", FloatType, nullable = true)
.add("f2", ArrayType(BooleanType, containsNull = true), nullable = true),
.add("f2", MapType(StringType, LongType), nullable = true),
new UngroupableUDT())
unsupportedDataTypes.foreach { dataType =>
checkDataType(dataType, shouldSuccess = false)
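Since arrays are now groupable, the suite needs a fixture type that still is not: UngroupableUDT is rebased onto MapType, with serialize flattening the map into parallel key/value arrays and deserialize zipping them back. A round-trip sketch using the fixtures above:

val udt = new UngroupableUDT()
val data = UngroupableData(Map(1 -> 10, 2 -> 20))
// serialize -> ArrayBasedMapData -> deserialize should reconstruct the map.
val restored = udt.deserialize(udt.serialize(data))
assert(restored == data)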
@@ -24,15 +24,16 @@ import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.types.{TypeCollection, StringType}
+import org.apache.spark.sql.types.{LongType, TypeCollection, StringType}

class ExpressionTypeCheckingSuite extends SparkFunSuite {

val testRelation = LocalRelation(
'intField.int,
'stringField.string,
'booleanField.boolean,
-'complexField.array(StringType))
+'arrayField.array(StringType),
+'mapField.map(StringType, LongType))

def assertError(expr: Expression, errorMessage: String): Unit = {
val e = intercept[AnalysisException] {
@@ -90,9 +91,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertError(BitwiseOr('booleanField, 'booleanField), "requires integral type")
assertError(BitwiseXor('booleanField, 'booleanField), "requires integral type")

-assertError(MaxOf('complexField, 'complexField),
+assertError(MaxOf('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
-assertError(MinOf('complexField, 'complexField),
+assertError(MinOf('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
}

@@ -109,31 +110,31 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertSuccess(EqualTo('intField, 'booleanField))
assertSuccess(EqualNullSafe('intField, 'booleanField))

-assertErrorForDifferingTypes(EqualTo('intField, 'complexField))
-assertErrorForDifferingTypes(EqualNullSafe('intField, 'complexField))
+assertErrorForDifferingTypes(EqualTo('intField, 'mapField))
+assertErrorForDifferingTypes(EqualNullSafe('intField, 'mapField))
assertErrorForDifferingTypes(LessThan('intField, 'booleanField))
assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField))
assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))

-assertError(LessThan('complexField, 'complexField),
+assertError(LessThan('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
-assertError(LessThanOrEqual('complexField, 'complexField),
+assertError(LessThanOrEqual('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
-assertError(GreaterThan('complexField, 'complexField),
+assertError(GreaterThan('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
-assertError(GreaterThanOrEqual('complexField, 'complexField),
+assertError(GreaterThanOrEqual('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")

assertError(If('intField, 'stringField, 'stringField),
"type of predicate expression in If should be boolean")
assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField))

assertError(
-CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'complexField)),
+CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'mapField)),
"THEN and ELSE expressions should all be same type or coercible to a common type")
assertError(
-CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'complexField)),
+CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'mapField)),
"THEN and ELSE expressions should all be same type or coercible to a common type")
assertError(
CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)),
@@ -147,9 +148,10 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
// We will cast String to Double for sum and average
assertSuccess(Sum('stringField))
assertSuccess(Average('stringField))
+assertSuccess(Min('arrayField))

-assertError(Min('complexField), "min does not support ordering on type")
-assertError(Max('complexField), "max does not support ordering on type")
+assertError(Min('mapField), "min does not support ordering on type")
+assertError(Max('mapField), "max does not support ordering on type")
assertError(Sum('booleanField), "function sum requires numeric type")
assertError(Average('booleanField), "function average requires numeric type")
}
@@ -184,7 +186,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {

assertError(Round('intField, 'intField), "Only foldable Expression is allowed")
assertError(Round('intField, 'booleanField), "requires int type")
-assertError(Round('intField, 'complexField), "requires int type")
+assertError(Round('intField, 'mapField), "requires int type")
assertError(Round('booleanField, 'intField), "requires numeric type")
}
}
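Taken together, the patch makes array columns first-class for sorting, grouping, and min/max. A final end-to-end sketch (invented data; 1.x-style API assumed):

val df = sqlContext.createDataFrame(Seq(
  Tuple1(Seq(2, 1)), Tuple1(Seq(1, 9)), Tuple1(Seq(1)))).toDF("a")
df.orderBy("a").collect() // [1], [1,9], [2,1]: element order, then length
df.groupBy("a").count()   // analyzes successfully after this patch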