@@ -24,12 +24,14 @@ import java.util.{Map => JavaMap}
2424import  javax .annotation .Nullable 
2525
2626import  scala .language .existentials 
27+ import  scala .reflect .ClassTag 
2728
2829import  org .apache .spark .sql .Row 
2930import  org .apache .spark .sql .catalyst .expressions ._ 
3031import  org .apache .spark .sql .catalyst .util .DateTimeUtils 
3132import  org .apache .spark .sql .types ._ 
3233import  org .apache .spark .unsafe .types .UTF8String 
34+ import  org .apache .spark .util .Utils 
3335
3436/** 
3537 * Functions to convert Scala types to Catalyst types and vice versa. 
@@ -39,6 +41,8 @@ object CatalystTypeConverters {
  // Since the map values can be mutable, we explicitly import scala.collection.Map here.
4042  import  scala .collection .Map 
4143
// Alias for ScalaReflection.universe, evaluated on first access. The reflection-based
// converters in this object `import universe._` from it and use `universe.Type` in their
// signatures, so they all share the same path-dependent reflection types.
44+   lazy  val  universe  =  ScalaReflection .universe
45+ 
4246  private  def  isPrimitive (dataType : DataType ):  Boolean  =  {
4347    dataType match  {
4448      case  BooleanType  =>  true 
@@ -454,4 +458,166 @@ object CatalystTypeConverters {
def convertToScala(catalystValue: Any, dataType: DataType): Any = {
  // Build the Catalyst-to-Scala converter for this data type first, then apply it to the
  // single value that was passed in.
  val toScala = createToScalaConverter(dataType)
  toScala(catalystValue)
}
461+ 
/**
 * Product-typed counterpart of createToScalaConverter(DataType): returns a function that turns
 * a Catalyst row into an instance of `T`, where `T` is a subtype of Product (e.g. a case class).
 * The runtime class of `T` is taken from the implicit ClassTag.
 *
 * Compatibility between `T` and `structType` is not verified up front; if they do not line up,
 * a ClassCastException is ultimately thrown when the returned converter is invoked.
 *
 * Intended usage is to build the converter once for a collection of rows sharing one schema
 * and then apply it to every row.
 */
private[sql] def createToProductConverter[T <: Product](
    structType: StructType)(implicit classTag: ClassTag[T]): InternalRow => T = {
  // Runtime reflection has thread-safety issues in Scala 2.10, so the reflective construction
  // of the converter is done while holding ScalaReflectionLock.
  // https://issues.scala-lang.org/browse/SI-6240
  // http://docs.scala-lang.org/overviews/reflection/thread-safety.html
  ScalaReflectionLock.synchronized {
    val converter: InternalRow => T = createToProductConverter(classTag, structType)
    converter
  }
}
481+ 
/**
 * Builds a function that converts a Catalyst row into an instance of the Product class
 * described by `classTag`, by reflectively invoking that class's primary constructor.
 *
 * One value converter is created per (constructor parameter type, struct field data type)
 * pair, in declaration order, via createToScalaConverter(universe.Type, DataType). Only the
 * first parameter list of the constructor is considered.
 *
 * This method performs runtime reflection and does not take ScalaReflectionLock itself.
 *
 * @param classTag   supplies the runtime class to instantiate
 * @param structType schema of the rows the returned function will be applied to
 * @return a function mapping a null row to null, and any other row to a new instance of `T`
 * @throws IllegalArgumentException if no primary constructor can be found for the class
 * @throws ClassCastException from the returned function, when the converted field values are
 *                            rejected by the constructor (the original exception is the cause)
 */
private[sql] def createToProductConverter[T <: Product](
    classTag: ClassTag[T], structType: StructType): InternalRow => T = {

  import universe._

  val constructorMirror = {
    val mirror = runtimeMirror(Utils.getContextOrSparkClassLoader)
    val classSymbol = mirror.classSymbol(classTag.runtimeClass)
    val classMirror = mirror.reflectClass(classSymbol)
    val constructorSymbol = {
      // Adapted from ScalaReflection to find primary constructor.
      // https://issues.apache.org/jira/browse/SPARK-4791
      val symbol = classSymbol.toType.declaration(nme.CONSTRUCTOR)
      if (symbol.isMethod) {
        symbol.asMethod
      } else {
        // Overloaded constructors: pick the alternative flagged as the primary one.
        symbol.asTerm.alternatives
          .find { s => s.isMethod && s.asMethod.isPrimaryConstructor }
          .map(_.asMethod)
          .getOrElse {
            // Name the class; the constructor symbol's own name would not identify it.
            throw new IllegalArgumentException(
              s"No primary constructor for ${classSymbol.fullName}")
          }
      }
    }
    classMirror.reflectConstructor(constructorSymbol)
  }

  val paramTypes = constructorMirror.symbol.paramss.head.toSeq.map { _.asTerm.typeSignature }
  // Kept as a Seq (not an Array) so it renders readably in the error message below.
  val dataTypes: Seq[DataType] = structType.fields.map { _.dataType }.toSeq
  // NOTE(review): zip truncates to the shorter side, so a schema with more fields than the
  // constructor has parameters is not detected here.
  val converters: Seq[Any => Any] =
    paramTypes.zip(dataTypes).map { case (pt, dt) => createToScalaConverter(pt, dt) }

  (row: InternalRow) => if (row == null) {
    null.asInstanceOf[T]
  } else {
    val convertedArgs =
      converters.zip(row.toSeq(dataTypes)).map { case (converter, arg) => converter(arg) }
    try {
      constructorMirror.apply(convertedArgs: _*).asInstanceOf[T]
    } catch {
      case e: IllegalArgumentException => // argument type mismatch
        val message =
          s"""|Error constructing ${classTag.runtimeClass.getName}: ${e.getMessage};
              |paramTypes: ${paramTypes}, dataTypes: ${dataTypes},
              |convertedArgs: ${convertedArgs}""".stripMargin.replace("\n", " ")
        // ClassCastException has no (message, cause) constructor; attach the cause explicitly
        // so the reflective failure's stack trace is not lost.
        val failure = new ClassCastException(message)
        failure.initCause(e)
        throw failure
    }
  }
}
534+ 
/**
 * Like createToScalaConverter(DataType), but with a Scala type hint.
 *
 * The (Scala type, Catalyst data type) pair selects the conversion. Option, Seq, Map and
 * Product types recurse into their element / key / value / field types. Pairs with no
 * dedicated case fall through to the identity function. Except for Option, which wraps nulls
 * as None, the converters return null inputs unchanged.
 *
 * Please keep in sync with createToScalaConverter(DataType) and ScalaReflection.schemaFor[T].
 *
 * @param universeType the expected Scala type of the converted value
 * @param dataType     the Catalyst data type of the incoming value
 * @return a function converting a Catalyst value to the corresponding Scala value
 * @throws UnsupportedOperationException if an ArrayType is paired with a Scala Array type
 * @throws IllegalArgumentException if the class for a Product type cannot be loaded
 */
private[sql] def createToScalaConverter(
    universeType: universe.Type, dataType: DataType): Any => Any = {

  import universe._

  (universeType, dataType) match {
    case (t, dt) if t <:< typeOf[Option[_]] =>
      val TypeRef(_, _, Seq(elementType)) = t
      val converter: Any => Any = createToScalaConverter(elementType, dt)
      // Option(...) turns a null converted value into None.
      (catalystValue: Any) => Option(converter(catalystValue))

    case (_, udt: UserDefinedType[_]) =>
      (catalystValue: Any) => if (catalystValue == null) null else udt.deserialize(catalystValue)

    case (_, _: BinaryType) => identity

    case (t, _: ArrayType) if t <:< typeOf[Array[_]] =>
      throw new UnsupportedOperationException("Array[_] is not supported; try using Seq instead.")

    case (t, at: ArrayType) if t <:< typeOf[Seq[_]] =>
      val TypeRef(_, _, Seq(elementType)) = t
      val converter: Any => Any = createToScalaConverter(elementType, at.elementType)
      (catalystValue: Any) => catalystValue match {
        case arrayData: ArrayData => arrayData.toArray[Any](at.elementType).map(converter).toSeq
        case o => o
      }

    case (t, mt: MapType) if t <:< typeOf[Map[_, _]] =>
      val TypeRef(_, _, Seq(keyType, valueType)) = t
      val keyConverter: Any => Any = createToScalaConverter(keyType, mt.keyType)
      val valueConverter: Any => Any = createToScalaConverter(valueType, mt.valueType)
      (catalystValue: Any) => catalystValue match {
        case mapData: MapData =>
          val keys = mapData.keyArray().toArray[Any](mt.keyType)
          val values = mapData.valueArray().toArray[Any](mt.valueType)
          keys.map(keyConverter).zip(values.map(valueConverter)).toMap
        case o => o
      }

    case (t, st: StructType) if t <:< typeOf[Product] =>
      val classSymbol = t.erasure.typeSymbol.asClass
      val className = classSymbol.fullName
      val runtimeClass: Class[_] =
        if (Utils.classIsLoadable(className)) {
          Utils.classForName(className)
        } else {
          // fullName separates nested classes with '.', whereas the JVM binary name uses '$',
          // so a class declared inside an object or another class is not found by name.
          // Let the runtime mirror resolve the symbol before giving up.
          try {
            runtimeMirror(Utils.getContextOrSparkClassLoader).runtimeClass(classSymbol)
          } catch {
            case e: ClassNotFoundException =>
              throw new IllegalArgumentException(s"$className is not loadable", e)
          }
        }
      // Fully qualified on purpose, to stay independent of what `import universe._` brings in.
      val classTag = scala.reflect.ClassTag[Product](runtimeClass)
      createToProductConverter(classTag, st).asInstanceOf[Any => Any]

    case (t, StringType) if t <:< typeOf[String] =>
      (catalystValue: Any) => catalystValue match {
        case utf8: UTF8String => utf8.toString
        case o => o
      }

    case (t, DateType) if t <:< typeOf[Date] =>
      (catalystValue: Any) => catalystValue match {
        case i: Int => DateTimeUtils.toJavaDate(i)
        case o => o
      }

    case (t, TimestampType) if t <:< typeOf[Timestamp] =>
      (catalystValue: Any) => catalystValue match {
        case x: Long => DateTimeUtils.toJavaTimestamp(x)
        case o => o
      }

    case (t, _: DecimalType) if t <:< typeOf[BigDecimal] =>
      (catalystValue: Any) => catalystValue match {
        case d: Decimal => d.toBigDecimal
        case o => o
      }

    case (t, _: DecimalType) if t <:< typeOf[java.math.BigDecimal] =>
      (catalystValue: Any) => catalystValue match {
        case d: Decimal => d.toJavaBigDecimal
        case o => o
      }

    // Pass non-string primitives through. (Strings are converted from UTF8Strings above.)
    // For everything else, hope for the best.
    case _ => identity
  }
}
457623}
0 commit comments