Skip to content

Commit ba0d60f

Browse files
committed
add complex data type support
1 parent d127ac4 commit ba0d60f

File tree

2 files changed

+285
-77
lines changed

2 files changed

+285
-77
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 262 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -2376,112 +2376,297 @@ case class ArrayDistinct(child: Expression)
23762376

23772377
// Element type of this expression's result, read from `dataType` cast to ArrayType.
lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
23782378

2379+
// Interpreted ordering for `elementType`; `nullSafeEval` uses `ordering.equiv` to compare
// elements on the path where `elementTypeSupportEquals` is false.
@transient private lazy val ordering: Ordering[Any] =
2380+
TypeUtils.getInterpretedOrdering(elementType)
2381+
2382+
/**
 * Runs the parent type check first; only when that succeeds does it additionally require
 * the element type to pass `TypeUtils.checkForOrderingExpr`.
 */
override def checkInputDataTypes(): TypeCheckResult = {
  val parentResult = super.checkInputDataTypes()
  parentResult match {
    case TypeCheckResult.TypeCheckSuccess =>
      TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName")
    case failure: TypeCheckResult.TypeCheckFailure =>
      // Propagate the parent's failure unchanged.
      failure
  }
}
2389+
2390+
// True only for atomic element types other than BinaryType. Callers in this class use the
// flag to choose between the hash-set/`distinct` path and the pairwise-comparison path.
@transient private lazy val elementTypeSupportEquals = elementType match {
  case _: AtomicType => elementType != BinaryType
  case _ => false
}
2395+
23792396
/**
 * Interpreted evaluation: returns the input array with duplicates removed, keeping the
 * first occurrence of every element (and at most one null) in input order.
 *
 * When `elementTypeSupportEquals` is true the elements are de-duplicated with
 * `Array.distinct`. Otherwise elements are compared pairwise with `ordering.equiv`.
 *
 * The previous pairwise implementation only counted the distinct elements and then
 * returned `data.slice(0, pos)`, i.e. the first `pos` input elements rather than the
 * distinct ones (e.g. [[1], [1], [2]] produced [[1], [1]]). The survivors are now
 * collected explicitly.
 *
 * @param array the non-null input value, expected to be an `ArrayData`
 * @return a `GenericArrayData` holding the distinct elements
 */
override def nullSafeEval(array: Any): Any = {
  val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
  if (elementTypeSupportEquals) {
    new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
  } else {
    // Holds the elements kept so far, in first-occurrence order.
    val kept = new scala.collection.mutable.ArrayBuffer[Any](data.length)
    var foundNullElement = false
    data.foreach { elem =>
      if (elem == null) {
        // Keep a single null, at the position of its first occurrence.
        if (!foundNullElement) {
          foundNullElement = true
          kept += null
        }
      } else {
        // Compare only against already-kept non-null elements; null is never passed
        // to the ordering.
        val alreadyKept = kept.exists(prev => prev != null && ordering.equiv(prev, elem))
        if (!alreadyKept) {
          kept += elem
        }
      }
    }
    new GenericArrayData(kept.toArray[Any])
  }
}
23842426

23852427
/**
 * Generates Java code in two passes. The code emitted here first counts the distinct
 * elements of the input array; the code appended by `genCodeForResult` then allocates a
 * result of that size and fills it.
 *
 * The counting pass uses an `OpenHashSet` when `elementTypeSupportEquals` is true, and a
 * nested loop over earlier elements with `ctx.genEqual` otherwise. Nulls are counted at
 * most once in both variants.
 */
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
  nullSafeCodeGen(ctx, ev, input => {
    // Fresh names are requested in a fixed order so the generated identifiers are stable.
    val outerIdx = ctx.freshName("i")
    val innerIdx = ctx.freshName("j")
    val distinctCount = ctx.freshName("sizeOfDistinctArray")
    val outerValue = CodeGenerator.getValue(input, elementType, outerIdx)
    val innerValue = CodeGenerator.getValue(input, elementType, innerIdx)
    val sawNull = ctx.freshName("foundNullElement")
    val hashSetClass = classOf[OpenHashSet[_]].getName
    val seen = ctx.freshName("hs")
    val objectTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"

    if (!elementTypeSupportEquals) {
      // Pairwise variant: an element is new when no earlier non-null element equals it,
      // i.e. the inner loop ran to completion.
      s"""
         |int $distinctCount = 0;
         |boolean $sawNull = false;
         |for (int $outerIdx = 0; $outerIdx < $input.numElements(); $outerIdx ++) {
         | if ($input.isNullAt($outerIdx)) {
         | if (!($sawNull)) {
         | $distinctCount = $distinctCount + 1;
         | $sawNull = true;
         | }
         | }
         | else {
         | int $innerIdx;
         | for ($innerIdx = 0; $innerIdx < $outerIdx; $innerIdx++) {
         | if (!$input.isNullAt($innerIdx) && ${ctx.genEqual(elementType, outerValue, innerValue)})
         | break;
         | }
         | if ($outerIdx == $innerIdx) {
         | $distinctCount = $distinctCount + 1;
         | }
         | }
         |}
         |
         |${genCodeForResult(ctx, ev, input, distinctCount)}
       """.stripMargin
    } else {
      // Hash-set variant: the count is the set size plus one if a null was seen.
      s"""
         |int $distinctCount = 0;
         |boolean $sawNull = false;
         |$hashSetClass $seen = new $hashSetClass($objectTag);
         |for (int $outerIdx = 0; $outerIdx < $input.numElements(); $outerIdx++) {
         | if ($input.isNullAt($outerIdx)) {
         | if (!($sawNull)) {
         | $sawNull = true;
         | }
         | }
         | else {
         | if (!($seen.contains($outerValue))) {
         | $seen.add($outerValue);
         | }
         | }
         |}
         |$distinctCount = $seen.size() + ($sawNull ? 1 : 0);
         |${genCodeForResult(ctx, ev, input, distinctCount)}
       """.stripMargin
    }
  })
}
2487+
2488+
/**
 * Builds the Java snippet that writes one null into the result array the first time a
 * null input element is met, then advances the write position and raises the flag.
 *
 * @param isPrimitive true to emit `setNullAt(pos)`, false to emit `array[pos] = null`
 * @param foundNullElement name of the generated boolean flag guarding the write
 * @param distinctArray name of the generated result array variable
 * @param pos name of the generated write-position variable
 * @return the generated Java source fragment
 */
private def setNull(
    isPrimitive: Boolean,
    foundNullElement: String,
    distinctArray: String,
    pos: String): String = {
  val nullAssignment =
    if (isPrimitive) {
      s"""
         |$distinctArray.setNullAt($pos);
       """.stripMargin
    } else {
      s"""
         |$distinctArray[$pos] = null;
       """.stripMargin
    }

  s"""
     |if (!($foundNullElement)) {
     | $nullAssignment;
     | $pos = $pos + 1;
     | $foundNullElement = true;
     |}
   """.stripMargin
}
2514+
2515+
/**
 * Builds the Java statement that stores a non-null element into the result array.
 *
 * @param isPrimitive true to emit a typed `set<Type>(pos, value)` call, false to emit a
 *                    plain `array[pos] = value` assignment
 * @param distinctArray name of the generated result array variable
 * @param pos name of the generated write-position variable
 * @param getValue1 Java expression that reads the element to store
 * @param primitiveValueTypeName suffix of the typed setter; only used when `isPrimitive`
 * @return the generated Java source fragment
 */
private def setNotNullValue(isPrimitive: Boolean,
    distinctArray: String,
    pos: String,
    getValue1: String,
    primitiveValueTypeName: String): String = {
  if (isPrimitive) {
    s"""
       |$distinctArray.set$primitiveValueTypeName($pos, $getValue1);
     """.stripMargin
  } else {
    s"""
       |$distinctArray[$pos] = $getValue1;
     """.stripMargin
  }
}
2530+
2531+
/**
 * Builds the Java snippet for the hash-set path: if the element is not yet in the set,
 * add it, store it in the result array and advance the write position.
 *
 * @param isPrimitive forwarded to `setNotNullValue` to pick the store form
 * @param hs name of the generated hash-set variable
 * @param distinctArray name of the generated result array variable
 * @param pos name of the generated write-position variable
 * @param getValue1 Java expression that reads the current element
 * @param primitiveValueTypeName forwarded to `setNotNullValue`
 * @return the generated Java source fragment
 */
private def setValueForFastEval(
    isPrimitive: Boolean,
    hs: String,
    distinctArray: String,
    pos: String,
    getValue1: String,
    primitiveValueTypeName: String): String = {
  val storeElement =
    setNotNullValue(isPrimitive, distinctArray, pos, getValue1, primitiveValueTypeName)
  s"""
     |if (!($hs.contains($getValue1))) {
     | $hs.add($getValue1);
     | $storeElement;
     | $pos = $pos + 1;
     |}
   """.stripMargin
}
2548+
2549+
/**
 * Builds the Java snippet for the pairwise-comparison path: scan the earlier input
 * elements and, if none of the non-null ones satisfies `isEqual`, store the current
 * element in the result array and advance the write position.
 *
 * @param isPrimitive forwarded to `setNotNullValue` to pick the store form
 * @param i name of the generated outer loop index (current element)
 * @param j name of the generated inner loop index (declared by this snippet)
 * @param inputArray name of the generated input array variable
 * @param distinctArray name of the generated result array variable
 * @param pos name of the generated write-position variable
 * @param getValue1 Java expression that reads the current element
 * @param isEqual Java boolean expression comparing the elements at `i` and `j`
 * @param primitiveValueTypeName forwarded to `setNotNullValue`
 * @return the generated Java source fragment
 */
private def setValueForbruteForceEval(
    isPrimitive: Boolean,
    i: String,
    j: String,
    inputArray: String,
    distinctArray: String,
    pos: String,
    getValue1: String,
    isEqual: String,
    primitiveValueTypeName: String): String = {
  val storeElement =
    setNotNullValue(isPrimitive, distinctArray, pos, getValue1, primitiveValueTypeName)
  // In the emitted code, `i == j` after the loop means the scan finished without a match.
  s"""
     |int $j;
     |for ($j = 0; $j < $i; $j ++) {
     | if (!$inputArray.isNullAt($j) && $isEqual)
     | break;
     | }
     | if ($i == $j) {
     | $storeElement;
     | $pos = $pos + 1;
     | }
   """.stripMargin
}
24172573

24182574
/**
 * Generates the Java code that allocates the result array and fills it with the distinct
 * elements of `inputArray`, assigning the result to `ev.value`.
 *
 * Four variants are produced, selected by two flags:
 *  - `CodeGenerator.isPrimitiveType(elementType)`: array created through
 *    `ctx.createUnsafeArray` and written with typed setters, versus an `Object[]`
 *    wrapped in `GenericArrayData`;
 *  - `elementTypeSupportEquals`: duplicates detected with an `OpenHashSet`, versus a
 *    scan over earlier elements using `ctx.genEqual`.
 *
 * @param ctx codegen context used for fresh names and helper code
 * @param ev expression code whose `value` receives the result
 * @param inputArray name of the generated input array variable
 * @param size name of the generated variable holding the number of distinct elements
 * @return the generated Java source fragment
 */
def genCodeForResult(
    ctx: CodegenContext,
    ev: ExprCode,
    inputArray: String,
    size: String): String = {
  // Fresh names and the equality expression are requested in a fixed order so the
  // generated identifiers stay stable.
  val resultArr = ctx.freshName("distinctArray")
  val outerIdx = ctx.freshName("i")
  val innerIdx = ctx.freshName("j")
  val writePos = ctx.freshName("pos")
  val outerValue = CodeGenerator.getValue(inputArray, elementType, outerIdx)
  val innerValue = CodeGenerator.getValue(inputArray, elementType, innerIdx)
  val equalsExpr = ctx.genEqual(elementType, outerValue, innerValue)
  val sawNull = ctx.freshName("foundNullElement")
  val seen = ctx.freshName("hs")
  val hashSetClass = classOf[OpenHashSet[_]].getName

  if (CodeGenerator.isPrimitiveType(elementType)) {
    val typeName = CodeGenerator.primitiveTypeName(elementType)
    val writeNull = setNull(true, sawNull, resultArr, writePos)
    val typedTag = s"scala.reflect.ClassTag$$.MODULE$$.$typeName()"
    if (elementTypeSupportEquals) {
      val writeIfUnseen =
        setValueForFastEval(true, seen, resultArr, writePos, outerValue, typeName)
      s"""
         |${ctx.createUnsafeArray(resultArr, size, elementType, s" $prettyName failed.")}
         |int $writePos = 0;
         |boolean $sawNull = false;
         |$hashSetClass $seen = new $hashSetClass($typedTag);
         |for (int $outerIdx = 0; $outerIdx < $inputArray.numElements(); $outerIdx ++) {
         | if ($inputArray.isNullAt($outerIdx)) {
         | $writeNull;
         | }
         | else {
         | $writeIfUnseen;
         | }
         |}
         |${ev.value} = $resultArr;
       """.stripMargin
    } else {
      val writeIfNoEarlierMatch = setValueForbruteForceEval(true, outerIdx, innerIdx,
        inputArray, resultArr, writePos, outerValue, equalsExpr, typeName)
      s"""
         |${ctx.createUnsafeArray(resultArr, size, elementType, s" $prettyName failed.")}
         |int $writePos = 0;
         |boolean $sawNull = false;
         |for (int $outerIdx = 0; $outerIdx < $inputArray.numElements(); $outerIdx ++) {
         | if ($inputArray.isNullAt($outerIdx)) {
         | $writeNull;
         | }
         | else {
         | $writeIfNoEarlierMatch;
         | }
         |}
         |${ev.value} = $resultArr;
       """.stripMargin
    }
  } else {
    val wrapperClass = classOf[GenericArrayData].getName
    val objectTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"
    val writeNull = setNull(false, sawNull, resultArr, writePos)
    if (elementTypeSupportEquals) {
      val writeIfUnseen =
        setValueForFastEval(false, seen, resultArr, writePos, outerValue, "")
      s"""
         |int $writePos = 0;
         |Object[] $resultArr = new Object[$size];
         |boolean $sawNull = false;
         |$hashSetClass $seen = new $hashSetClass($objectTag);
         |for (int $outerIdx = 0; $outerIdx < $inputArray.numElements(); $outerIdx ++) {
         | if ($inputArray.isNullAt($outerIdx)) {
         | $writeNull;
         | }
         | else {
         | $writeIfUnseen;
         | }
         |}
         |${ev.value} = new $wrapperClass($resultArr);
       """.stripMargin
    } else {
      val writeIfNoEarlierMatch = setValueForbruteForceEval(false, outerIdx, innerIdx,
        inputArray, resultArr, writePos, outerValue, equalsExpr, "")
      s"""
         |int $writePos = 0;
         |Object[] $resultArr = new Object[$size];
         |boolean $sawNull = false;
         |for (int $outerIdx = 0; $outerIdx < $inputArray.numElements(); $outerIdx ++) {
         | if ($inputArray.isNullAt($outerIdx)) {
         | $writeNull;
         | }
         | else {
         | $writeIfNoEarlierMatch;
         | }
         |}
         |${ev.value} = new $wrapperClass($resultArr);
       """.stripMargin
    }
  }
}
24872672

0 commit comments

Comments
 (0)