Skip to content

Commit 72446f1

Browse files
committed
Move tests.
1 parent 42303d2 commit 72446f1

File tree

2 files changed

+12
-27
lines changed

2 files changed

+12
-27
lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import org.apache.spark.SparkFunSuite
2929
import org.apache.spark.sql.Encoders
3030
import org.apache.spark.sql.catalyst.expressions.AttributeReference
3131
import org.apache.spark.sql.catalyst.util.ArrayData
32-
import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
32+
import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData, ScalaReflection}
3333
import org.apache.spark.sql.types.{StructType, ArrayType}
3434

3535
case class RepeatedStruct(s: Seq[PrimitiveData])
@@ -239,6 +239,16 @@ class ExpressionEncoderSuite extends SparkFunSuite {
239239
ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
240240
}
241241

242+
test("user type with ScalaReflection") {
243+
val point = (new ExamplePoint(0.1, 0.2), new ExamplePoint(0.3, 0.4))
244+
val schema = ScalaReflection.schemaFor[Tuple2[ExamplePoint, ExamplePoint]]
245+
.dataType.asInstanceOf[StructType]
246+
val attributeSeq = schema.toAttributes
247+
val boundEncoder = encoderFor[Tuple2[ExamplePoint, ExamplePoint]]
248+
.resolve(attributeSeq, outers).bind(attributeSeq)
249+
assert(boundEncoder.fromRow(boundEncoder.toRow(point)) === point)
250+
}
251+
242252
test("nullable of encoder schema") {
243253
def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = {
244254
assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq)

sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,7 @@ import scala.reflect.runtime.universe.TypeTag
2727
import com.google.common.collect.MapMaker
2828

2929
import org.apache.spark.rdd.RDD
30-
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection}
31-
import org.apache.spark.sql.catalyst.encoders._
32-
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
33-
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
34-
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
30+
import org.apache.spark.sql.catalyst.CatalystTypeConverters
3531
import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
3632
import org.apache.spark.sql.functions._
3733
import org.apache.spark.sql.test.SharedSQLContext
@@ -101,27 +97,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
10197
private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap()
10298
outers.put(getClass.getName, this)
10399

104-
test("user type with ScalaReflection") {
105-
val points = Seq(
106-
MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
107-
MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0))))
108-
109-
val schema = ScalaReflection.schemaFor[MyLabeledPoint].dataType.asInstanceOf[StructType]
110-
val attributeSeq = schema.toAttributes
111-
112-
val pointEncoder = encoderFor[MyLabeledPoint]
113-
val unsafeRows = points.map(pointEncoder.toRow(_).copy())
114-
val df = DataFrame(sqlContext, LocalRelation(attributeSeq, unsafeRows))
115-
val decodedPoints = df.collect()
116-
points.zip(decodedPoints).foreach { case (p, p2) =>
117-
assert(p.label == p2(0) && p.features == p2(1))
118-
}
119-
120-
val boundEncoder = pointEncoder.resolve(attributeSeq, outers).bind(attributeSeq)
121-
val point = MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0)))
122-
assert(boundEncoder.fromRow(boundEncoder.toRow(point)) === point)
123-
}
124-
125100
test("UDTs and UDFs") {
126101
sqlContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector])
127102
pointsRDD.registerTempTable("points")

0 commit comments

Comments
 (0)