Skip to content

Commit 6d4bec2

Browse files
committed
[SPARK-7160][SQL] Support converting DataFrames to typed RDDs.
1 parent 9223388 commit 6d4bec2

File tree

4 files changed

+550
-0
lines changed

4 files changed

+550
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,14 @@ import java.util.{Map => JavaMap}
2424
import javax.annotation.Nullable
2525

2626
import scala.language.existentials
27+
import scala.reflect.ClassTag
2728

2829
import org.apache.spark.sql.Row
2930
import org.apache.spark.sql.catalyst.expressions._
3031
import org.apache.spark.sql.catalyst.util.DateTimeUtils
3132
import org.apache.spark.sql.types._
3233
import org.apache.spark.unsafe.types.UTF8String
34+
import org.apache.spark.util.Utils
3335

3436
/**
3537
* Functions to convert Scala types to Catalyst types and vice versa.
@@ -39,6 +41,8 @@ object CatalystTypeConverters {
3941
// Since the map values can be mutable, we explicitly import scala.collection.Map here.
4042
import scala.collection.Map
4143

44+
// Alias for ScalaReflection.universe. The reflection-based converters in this object
// `import universe._` and take `universe.Type` parameters, so they all refer to this value.
lazy val universe = ScalaReflection.universe
45+
4246
private def isPrimitive(dataType: DataType): Boolean = {
4347
dataType match {
4448
case BooleanType => true
@@ -454,4 +458,166 @@ object CatalystTypeConverters {
454458
def convertToScala(catalystValue: Any, dataType: DataType): Any = {
  // Build the converter for this data type, then apply it to the single value. A new
  // converter is created on every call.
  val toScala = createToScalaConverter(dataType)
  toScala(catalystValue)
}
461+
462+
/**
 * Builds a converter from Catalyst rows to instances of `T`, a subtype of Product such as a
 * case class. This is the Product-typed counterpart of createToScalaConverter(DataType).
 *
 * No compatibility check between `T` and `structType` is made up front: if the two are not
 * compatible, a ClassCastException is ultimately thrown when the returned converter is invoked.
 *
 * The typical use is converting a collection of rows that share one schema: call this method
 * once to obtain the converter, then apply the converter to every row.
 *
 * @param structType schema of the rows to convert
 * @param classTag class tag identifying the target class `T`
 * @return a function turning an InternalRow into a `T`
 */
private[sql] def createToProductConverter[T <: Product](
    structType: StructType)(implicit classTag: ClassTag[T]): InternalRow => T =
  // The reflective setup runs under ScalaReflectionLock to avoid the reflection thread safety
  // issues in Scala 2.10.
  // https://issues.scala-lang.org/browse/SI-6240
  // http://docs.scala-lang.org/overviews/reflection/thread-safety.html
  ScalaReflectionLock.synchronized {
    createToProductConverter(classTag, structType)
  }
481+
482+
/**
 * Reflection-based implementation behind createToProductConverter(StructType): looks up the
 * primary constructor of the class named by `classTag`, builds one value converter per
 * (constructor parameter, schema field) pair, and returns a function that converts a row's
 * fields and passes them to that constructor.
 *
 * This overload does not acquire ScalaReflectionLock itself; the overload taking an implicit
 * ClassTag wraps its call to this method in that lock.
 *
 * @param classTag class tag whose runtime class is instantiated for each row
 * @param structType schema of the rows; fields are paired with constructor parameters by
 *                   position
 * @return a function mapping an InternalRow to a `T`; a null row maps to null. The function
 *         throws ClassCastException (with the reflective failure attached as its cause) when
 *         the constructor rejects the converted arguments.
 * @throws IllegalArgumentException if no primary constructor can be found
 */
private[sql] def createToProductConverter[T <: Product](
    classTag: ClassTag[T], structType: StructType): InternalRow => T = {

  import universe._

  val constructorMirror = {
    val mirror = runtimeMirror(Utils.getContextOrSparkClassLoader)
    val classSymbol = mirror.classSymbol(classTag.runtimeClass)
    val classMirror = mirror.reflectClass(classSymbol)
    val constructorSymbol = {
      // Adapted from ScalaReflection to find primary constructor.
      // https://issues.apache.org/jira/browse/SPARK-4791
      val symbol = classSymbol.toType.declaration(nme.CONSTRUCTOR)
      if (symbol.isMethod) {
        symbol.asMethod
      } else {
        // The constructor is overloaded: pick the primary one among the alternatives.
        symbol.asTerm.alternatives
          .find { s => s.isMethod && s.asMethod.isPrimaryConstructor }
          .map { _.asMethod }
          .getOrElse {
            throw new IllegalArgumentException(s"No primary constructor for ${symbol.name}")
          }
      }
    }
    classMirror.reflectConstructor(constructorSymbol)
  }

  // Only the first parameter list of the constructor is considered.
  val params = constructorMirror.symbol.paramss.head.toSeq
  val paramTypes = params.map { _.asTerm.typeSignature }
  val fields = structType.fields
  val dataTypes = fields.map { _.dataType }
  // zip pairs parameters and fields by position and stops at the shorter of the two.
  val converters: Seq[Any => Any] =
    paramTypes.zip(dataTypes).map { case (pt, dt) => createToScalaConverter(pt, dt) }

  (row: InternalRow) => if (row == null) {
    null.asInstanceOf[T]
  } else {
    val convertedArgs =
      converters.zip(row.toSeq(dataTypes)).map { case (converter, arg) => converter(arg) }
    try {
      constructorMirror.apply(convertedArgs: _*).asInstanceOf[T]
    } catch {
      case e: IllegalArgumentException => // argument type mismatch
        // dataTypes is an Array, whose toString is a JVM identity string; render its elements.
        val dataTypesText = dataTypes.mkString("[", ", ", "]")
        val message =
          s"""|Error constructing ${classTag.runtimeClass.getName}: ${e.getMessage};
              |paramTypes: ${paramTypes}, dataTypes: ${dataTypesText},
              |convertedArgs: ${convertedArgs}""".stripMargin.replace("\n", " ")
        // ClassCastException has no cause-taking constructor, so attach the original
        // reflective failure explicitly to keep its stack trace.
        val failure = new ClassCastException(message)
        failure.initCause(e)
        throw failure
    }
  }
}
534+
535+
/**
 * Variant of createToScalaConverter(DataType) that also receives the expected Scala type as a
 * hint, so the produced converter can target that type (for example wrapping a value in an
 * Option, or building a Product subtype from a struct).
 *
 * Cases are tried in order; the first matching (Scala type, Catalyst type) pair wins, and
 * anything unmatched is passed through unchanged.
 *
 * Please keep in sync with createToScalaConverter(DataType) and ScalaReflection.schemaFor[T].
 *
 * @param universeType the Scala type the converted value is expected to have
 * @param dataType the Catalyst data type of the incoming value
 * @return a function from a Catalyst value to the corresponding Scala value
 */
private[sql] def createToScalaConverter(
    universeType: universe.Type, dataType: DataType): Any => Any = {

  import universe._

  (universeType, dataType) match {
    // Option[E]: convert as E, then wrap, so that a null result becomes None.
    case (optionType, innerDataType) if optionType <:< typeOf[Option[_]] =>
      val TypeRef(_, _, Seq(innerType)) = optionType
      val convertInner: Any => Any = createToScalaConverter(innerType, innerDataType)
      (value: Any) => Option(convertInner(value))

    // User-defined types deserialize themselves; null is passed through.
    case (_, udt: UserDefinedType[_]) =>
      (value: Any) => if (value == null) null else udt.deserialize(value)

    case (_, _: BinaryType) => identity

    // Rejected when the converter is created, not when it is applied.
    case (arrayType, _: ArrayType) if arrayType <:< typeOf[Array[_]] =>
      throw new UnsupportedOperationException("Array[_] is not supported; try using Seq instead.")

    case (seqType, at: ArrayType) if seqType <:< typeOf[Seq[_]] =>
      val TypeRef(_, _, Seq(itemType)) = seqType
      val convertItem: Any => Any = createToScalaConverter(itemType, at.elementType)
      val fromArrayData: Any => Any = {
        case arrayData: ArrayData =>
          arrayData.toArray[Any](at.elementType).map(convertItem).toSeq
        case other => other
      }
      fromArrayData

    case (mapType, mt: MapType) if mapType <:< typeOf[Map[_, _]] =>
      val TypeRef(_, _, Seq(keyType, valueType)) = mapType
      val convertKey: Any => Any = createToScalaConverter(keyType, mt.keyType)
      val convertValue: Any => Any = createToScalaConverter(valueType, mt.valueType)
      val fromMapData: Any => Any = {
        case mapData: MapData =>
          val rawKeys = mapData.keyArray().toArray[Any](mt.keyType)
          val rawValues = mapData.valueArray().toArray[Any](mt.valueType)
          rawKeys.map(convertKey).zip(rawValues.map(convertValue)).toMap
        case other => other
      }
      fromMapData

    // Structs become instances of the hinted Product class, which must be loadable by name.
    case (productType, st: StructType) if productType <:< typeOf[Product] =>
      val className = productType.erasure.typeSymbol.asClass.fullName
      if (!Utils.classIsLoadable(className)) {
        throw new IllegalArgumentException(s"$className is not loadable")
      }
      val productTag = scala.reflect.ClassTag(Utils.classForName(className))
      createToProductConverter(productTag, st).asInstanceOf[Any => Any]

    case (stringType, StringType) if stringType <:< typeOf[String] =>
      val fromUtf8: Any => Any = {
        case utf8: UTF8String => utf8.toString
        case other => other
      }
      fromUtf8

    case (dateType, DateType) if dateType <:< typeOf[Date] =>
      val fromDays: Any => Any = {
        case days: Int => DateTimeUtils.toJavaDate(days)
        case other => other
      }
      fromDays

    case (timestampType, TimestampType) if timestampType <:< typeOf[Timestamp] =>
      val fromLong: Any => Any = {
        case ts: Long => DateTimeUtils.toJavaTimestamp(ts)
        case other => other
      }
      fromLong

    case (decimalType, _: DecimalType) if decimalType <:< typeOf[BigDecimal] =>
      val toScalaDecimal: Any => Any = {
        case d: Decimal => d.toBigDecimal
        case other => other
      }
      toScalaDecimal

    case (decimalType, _: DecimalType) if decimalType <:< typeOf[java.math.BigDecimal] =>
      val toJavaDecimal: Any => Any = {
        case d: Decimal => d.toJavaBigDecimal
        case other => other
      }
      toJavaDecimal

    // Pass non-string primitives through. (Strings are converted from UTF8Strings above.)
    // For everything else, hope for the best.
    case _ => identity
  }
}
457623
}

0 commit comments

Comments
 (0)