Skip to content

Commit edec2d8

Browse files
committed
Add sort_map expression and enable aggregates and joins
1 parent 8dd8ddb commit edec2d8

File tree

10 files changed

+300
-15
lines changed

10 files changed

+300
-15
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ class Analyzer(
104104
ResolveAggregateFunctions ::
105105
TimeWindowing ::
106106
ResolveInlineTables ::
107+
SortMaps ::
107108
TypeCoercion.typeCoercionRules ++
108109
extendedResolutionRules : _*),
109110
Batch("Nondeterministic", Once,
@@ -2332,3 +2333,27 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
23322333
CreateNamedStruct(children.toList)
23332334
}
23342335
}
2336+
2337+
/**
 * Sorts unordered [[MapType]] expressions so they become comparable.
 *
 * Map values are only comparable when their entries are in a canonical (sorted) order. This
 * rule wraps map-typed operands of binary comparisons and map-typed grouping expressions in
 * [[SortMap]], which produces a map marked as ordered.
 */
object SortMaps extends Rule[LogicalPlan] {
  /** Returns true when the expression produces a map that has not been put in order yet. */
  private def hasUnorderedMap(e: Expression): Boolean = e.dataType match {
    case m: MapType => !m.ordered
    case _ => false
  }

  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
    // Wrap every unordered-map operand. Handling both sides in a single case fixes the
    // situation where BOTH operands are unordered maps: a per-side case would only rewrite
    // the left operand, leaving an ordered-vs-unordered comparison.
    case cmp @ BinaryComparison(left, right)
        if cmp.resolved && (hasUnorderedMap(left) || hasUnorderedMap(right)) =>
      val newLeft = if (hasUnorderedMap(left)) SortMap(left) else left
      val newRight = if (hasUnorderedMap(right)) SortMap(right) else right
      cmp.withNewChildren(newLeft :: newRight :: Nil)
  } transform {
    case a: Aggregate if a.resolved && a.groupingExpressions.exists(hasUnorderedMap) =>
      a.transformExpressionsUp {
        // Keep the original exprId/qualifier so downstream attribute references still resolve.
        case attr: Attribute if hasUnorderedMap(attr) =>
          Alias(SortMap(attr), attr.name)(exprId = attr.exprId, qualifier = attr.qualifier)
        case e if hasUnorderedMap(e) => SortMap(e)
      }
  }
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,8 @@ trait CheckAnalysis extends PredicateHelper {
190190
case e if e.dataType.isInstanceOf[BinaryType] =>
191191
failAnalysis(s"binary type expression ${e.sql} cannot be used " +
192192
"in join conditions")
193-
case e if e.dataType.isInstanceOf[MapType] =>
193+
case e if e.dataType.isInstanceOf[MapType] &&
194+
!e.dataType.asInstanceOf[MapType].ordered =>
194195
failAnalysis(s"map type expression ${e.sql} cannot be used " +
195196
"in join conditions")
196197
case _ => // OK

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,53 @@ class CodegenContext {
544544
"""
545545
addNewFunction(compareFunc, funcCode)
546546
s"this.$compareFunc($c1, $c2)"
547+
case MapType(keyType, valueType, _, true) =>
  // Generate a lexicographic comparator over (key, value) entry sequences; ties on the
  // common prefix are broken by map size.
  val compareFunc = freshName("compareMap")
  val funcCode: String =
    s"""
      public int $compareFunc(MapData a, MapData b) {
        int lengthA = a.numElements();
        int lengthB = b.numElements();
        ArrayData aKeys = a.keyArray();
        ArrayData aValues = a.valueArray();
        ArrayData bKeys = b.keyArray();
        ArrayData bValues = b.valueArray();
        int minLength = (lengthA > lengthB) ? lengthB : lengthA;
        for (int i = 0; i < minLength; i++) {
          // Keys must be read and compared with the KEY type; the original code used
          // valueType here, producing wrong accessors whenever key/value types differ.
          ${javaType(keyType)} keyA = ${getValue("aKeys", keyType, "i")};
          ${javaType(keyType)} keyB = ${getValue("bKeys", keyType, "i")};
          int comp = ${genComp(keyType, "keyA", "keyB")};
          if (comp != 0) {
            return comp;
          }
          boolean isNullA = aValues.isNullAt(i);
          boolean isNullB = bValues.isNullAt(i);
          if (isNullA && isNullB) {
            // Both values null: this entry is equal, move on.
          } else if (isNullA) {
            return -1;
          } else if (isNullB) {
            return 1;
          } else {
            ${javaType(valueType)} valueA = ${getValue("aValues", valueType, "i")};
            ${javaType(valueType)} valueB = ${getValue("bValues", valueType, "i")};
            // Reuse 'comp': redeclaring 'int comp' in this nested scope would not compile
            // in the generated Java (duplicate local variable).
            comp = ${genComp(valueType, "valueA", "valueB")};
            if (comp != 0) {
              return comp;
            }
          }
        }

        if (lengthA < lengthB) {
          return -1;
        } else if (lengthA > lengthB) {
          return 1;
        }
        return 0;
      }
    """
  addNewFunction(compareFunc, funcCode)
  s"this.$compareFunc($c1, $c2)"
547594
case schema: StructType =>
548595
INPUT_ROW = "i"
549596
val comparisons = GenerateOrdering.genComparisons(this, schema)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 136 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import java.util.Comparator
2121
import org.apache.spark.sql.catalyst.InternalRow
2222
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2323
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
24-
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
24+
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
2525
import org.apache.spark.sql.types._
2626

2727
/**
@@ -287,3 +287,138 @@ case class ArrayContains(left: Expression, right: Expression)
287287

288288
override def prettyName: String = "array_contains"
289289
}
290+
291+
/**
 * An expression that sorts a map into ascending key order, recursively ordering any maps
 * nested inside arrays and structs. The output contains exactly the same entries as the
 * input, but in a canonical order, which makes the value usable in comparisons, grouping
 * and joins. The resulting [[MapType]] is flagged as ordered.
 */
case class SortMap(child: Expression) extends UnaryExpression with ExpectsInputTypes {

  override def inputTypes: Seq[AbstractDataType] = Seq(MapType)

  /** Create a data type in which all (nested) maps are marked as ordered. */
  private[this] def createDataType(dataType: DataType): DataType = dataType match {
    case StructType(fields) =>
      StructType(fields.map { field =>
        field.copy(dataType = createDataType(field.dataType))
      })
    case ArrayType(elementType, containsNull) =>
      ArrayType(createDataType(elementType), containsNull)
    case MapType(keyType, valueType, valueContainsNull, false) =>
      MapType(createDataType(keyType), createDataType(valueType), valueContainsNull, true)
    case _ =>
      dataType
  }

  override lazy val dataType: DataType = createDataType(child.dataType)

  // Identity transform used for (nested) types that need no reordering.
  private[this] val id = identity[Any] _

  /**
   * Create a function that transforms a Spark SQL datum to a new datum for which all MapData
   * elements have been ordered. Returns None when the type contains no unordered map, so
   * callers can skip the copy entirely.
   */
  private[this] def createTransform(dataType: DataType): Option[Any => Any] = {
    dataType match {
      case m @ MapType(keyType, valueType, _, false) =>
        val keyTransform = createTransform(keyType).getOrElse(id)
        val valueTransform = createTransform(valueType).getOrElse(id)
        // Sort entries by key first, then by value, using the map's interpreted orderings.
        val ordering = Ordering.Tuple2(m.interpretedKeyOrdering, m.interpretedValueOrdering)
        Option((data: Any) => {
          val input = data.asInstanceOf[MapData]
          val length = input.numElements()
          val buffer = Array.ofDim[(Any, Any)](length)

          // Move the entries into a temporary buffer, transforming nested data on the way.
          var i = 0
          val keys = input.keyArray()
          val values = input.valueArray()
          while (i < length) {
            val key = keyTransform(keys.get(i, keyType))
            val value = if (!values.isNullAt(i)) {
              valueTransform(values.get(i, valueType))
            } else {
              null
            }
            buffer(i) = key -> value
            i += 1
          }

          // Sort the buffer.
          java.util.Arrays.sort(buffer, ordering)

          // Recreate the map data from the sorted entries.
          i = 0
          val sortedKeys = Array.ofDim[Any](length)
          val sortedValues = Array.ofDim[Any](length)
          while (i < length) {
            sortedKeys(i) = buffer(i)._1
            sortedValues(i) = buffer(i)._2
            i += 1
          }
          ArrayBasedMapData(sortedKeys, sortedValues)
        })
      case ArrayType(dt, _) =>
        createTransform(dt).map { transform =>
          data: Any => {
            val input = data.asInstanceOf[ArrayData]
            val length = input.numElements()
            val output = Array.ofDim[Any](length)
            var i = 0
            while (i < length) {
              if (!input.isNullAt(i)) {
                output(i) = transform(input.get(i, dt))
              }
              // BUG FIX: was `i += i`, which never advances from 0 (infinite loop).
              i += 1
            }
            new GenericArrayData(output)
          }
        }
      case StructType(fields) =>
        val transformOpts = fields.map { field =>
          createTransform(field.dataType)
        }
        // Only transform a struct if a meaningful transformation has been defined.
        if (transformOpts.exists(_.isDefined)) {
          val transforms = transformOpts.zip(fields).map { case (opt, field) =>
            val dataType = field.dataType
            val transform = opt.getOrElse(id)
            (input: InternalRow, i: Int) => {
              transform(input.get(i, dataType))
            }
          }
          val length = fields.length
          val tf = (data: Any) => {
            val input = data.asInstanceOf[InternalRow]
            val output = Array.ofDim[Any](length)
            var i = 0
            while (i < length) {
              if (!input.isNullAt(i)) {
                output(i) = transforms(i)(input, i)
              }
              i += 1
            }
            new GenericInternalRow(output)
          }
          Some(tf)
        } else {
          None
        }
      case _ =>
        None
    }
  }

  @transient private[this] lazy val transform = {
    createTransform(child.dataType).getOrElse(id)
  }

  override protected def nullSafeEval(input: Any): Any = transform(input)

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // TODO we should code generate this; for now the interpreted transform is invoked
    // through a reference object. The cast is safe because inputTypes restricts the child
    // to MapType.
    val tf = ctx.addReferenceObj("transform", transform, classOf[Any => Any].getCanonicalName)
    nullSafeCodeGen(ctx, ev, eval => {
      s"${ev.value} = (MapData)$tf.apply($eval);"
    })
  }
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow
5353
a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
5454
case a: ArrayType if order.direction == Descending =>
5555
a.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
56+
case m: MapType if m.ordered && order.direction == Ascending =>
57+
m.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
58+
case m: MapType if m.ordered && order.direction == Descending =>
59+
m.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
5660
case s: StructType if order.direction == Ascending =>
5761
s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
5862
case s: StructType if order.direction == Descending =>
@@ -66,7 +70,7 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow
6670
}
6771
i += 1
6872
}
69-
return 0
73+
0
7074
}
7175
}
7276

@@ -92,6 +96,7 @@ object RowOrdering {
9296
case dt: AtomicType => true
9397
case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType))
9498
case array: ArrayType => isOrderable(array.elementType)
99+
case map: MapType => map.ordered
95100
case udt: UserDefinedType[_] => isOrderable(udt.sqlType)
96101
case _ => false
97102
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,14 @@ object TypeUtils {
6969
}
7070

7171
/**
 * Compares two byte arrays lexicographically using signed byte comparison; when one array is
 * a prefix of the other, the shorter array sorts first. Returns a negative, zero or positive
 * integer accordingly.
 */
def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
  val limit = scala.math.min(x.length, y.length)
  // Scan past the common prefix.
  var idx = 0
  while (idx < limit && x(idx) == y(idx)) {
    idx += 1
  }
  if (idx < limit) x(idx) - y(idx) else x.length - y.length
}

0 commit comments

Comments
 (0)