@@ -26,8 +26,7 @@ import org.apache.hadoop.hive.ql.plan.{TableDesc, FileSinkDesc}
 import org.apache.hadoop.hive.serde.serdeConstants
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption
 import org.apache.hadoop.hive.serde2.objectinspector._
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveVarcharObjectInspector
+import org.apache.hadoop.hive.serde2.objectinspector.primitive._
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils
 import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Serializer}
 import org.apache.hadoop.io.Writable
@@ -95,29 +94,34 @@ case class HiveTableScan(
     attributes.map { a =>
       val ordinal = relation.partitionKeys.indexOf(a)
       if (ordinal >= 0) {
+        val dataType = relation.partitionKeys(ordinal).dataType
         (_: Any, partitionKeys: Array[String]) => {
-          val value = partitionKeys(ordinal)
-          val dataType = relation.partitionKeys(ordinal).dataType
-          unwrapHiveData(castFromString(value, dataType))
+          castFromString(partitionKeys(ordinal), dataType)
         }
       } else {
         val ref = objectInspector.getAllStructFieldRefs
           .find(_.getFieldName == a.name)
           .getOrElse(sys.error(s"Can't find attribute $a"))
+        val fieldObjectInspector = ref.getFieldObjectInspector
+
+        val unwrapHiveData = fieldObjectInspector match {
+          case _: HiveVarcharObjectInspector =>
+            (value: Any) => value.asInstanceOf[HiveVarchar].getValue
+          case _: HiveDecimalObjectInspector =>
+            (value: Any) => BigDecimal(value.asInstanceOf[HiveDecimal].bigDecimalValue())
+          case _ =>
+            identity[Any] _
+        }
+
         (row: Any, _: Array[String]) => {
           val data = objectInspector.getStructFieldData(row, ref)
-          unwrapHiveData(unwrapData(data, ref.getFieldObjectInspector))
+          val hiveData = unwrapData(data, fieldObjectInspector)
+          if (hiveData != null) unwrapHiveData(hiveData) else null
         }
       }
     }
   }
 
-  private def unwrapHiveData(value: Any) = value match {
-    case varchar: HiveVarchar => varchar.getValue
-    case decimal: HiveDecimal => BigDecimal(decimal.bigDecimalValue)
-    case other => other
-  }
-
   private def castFromString(value: String, dataType: DataType) = {
     Cast(Literal(value), dataType).eval(null)
   }