
Commit d01c661

Add UserDefinedType to Cast.
1 parent: 51b76b1

2 files changed: 17 additions & 0 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 6 additions & 0 deletions

@@ -81,6 +81,9 @@ object Cast {
               toField.nullable)
         }

+    case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt1.userClass == udt2.userClass =>
+      true
+
     case _ => false
   }
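Aside (not part of the commit; a hedged sketch of what the hunk above changes): the new match arm is added to what appears to be Cast.canCast in object Cast, so two UserDefinedType instances now count as cast-compatible whenever they wrap the same user class. Using two made-up UDT implementations:

  // Sketch only: VectorUDT and PointUDT are hypothetical UserDefinedType subclasses
  // wrapping different user classes; Cast.canCast is assumed to be the enclosing method.
  import org.apache.spark.sql.catalyst.expressions.Cast
  Cast.canCast(new VectorUDT, new VectorUDT)   // true after this commit: same userClass
  Cast.canCast(new VectorUDT, new PointUDT)    // still false: different userClass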

@@ -473,6 +476,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
       castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx)
     case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx)
     case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx)
+    case udt: UserDefinedType[_]
+      if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
+      (c, evPrim, evNull) => s"$evPrim = $c;"
   }

   // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's
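Aside (again a hedged sketch, not the author's wording): the codegen arm above returns a cast function that emits a bare assignment, so casting a UDT to a UDT with the same userClass compiles to a no-op copy. Roughly what the lambda produces, with made-up variable names:

  // Sketch only: c, evPrim and evNull are variable names supplied by the code generator.
  val noOpCast = (c: String, evPrim: String, evNull: String) => s"$evPrim = $c;"
  noOpCast("inputValue", "resultValue", "resultIsNull")   // "resultValue = inputValue;"

Note that evNull is left untouched by this branch; null handling stays with the surrounding generated code.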

sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala

Lines changed: 11 additions & 0 deletions

@@ -17,11 +17,15 @@

 package org.apache.spark.sql

+import java.util.concurrent.ConcurrentMap
+
 import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}

 import scala.beans.{BeanInfo, BeanProperty}
 import scala.reflect.runtime.universe.TypeTag

+import com.google.common.collect.MapMaker
+
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection}
 import org.apache.spark.sql.catalyst.encoders._

@@ -94,6 +98,9 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
     assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0))))
   }

+  private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap()
+  outers.put(getClass.getName, this)
+
   test("user type with ScalaReflection") {
     val points = Seq(
       MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),

@@ -109,6 +116,10 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
     points.zip(decodedPoints).foreach { case (p, p2) =>
       assert(p.label == p2(0) && p.features == p2(1))
     }
+
+    val boundEncoder = pointEncoder.resolve(attributeSeq, outers).bind(attributeSeq)
+    val point = MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0)))
+    assert(boundEncoder.fromRow(boundEncoder.toRow(point)) === point)
   }

   test("UDTs and UDFs") {

0 commit comments
