
Commit 383c555

chenghao-intel authored and marmbrus committed
[SPARK-4785][SQL] Initialize Hive UDFs on the driver and serialize them with a wrapper
Unlike Hive 0.12.0, Hive 0.13.1 requires UDF/UDAF/UDTF (i.e. Hive function) objects to be initialized only once, on the driver side, and then serialized to the executors. However, not all function objects are serializable (e.g. GenericUDF doesn't implement Serializable). Hive 0.13.1 solves this with a Kryo or XML serializer, and several ser/de utility methods are provided in class o.a.h.h.q.e.Utilities for this purpose. In this PR we chose Kryo for efficiency. The Kryo serializer used here is the one created by Hive; Spark's own Kryo serializer wasn't used because no SparkConf instance is available at that point.

Author: Cheng Hao <[email protected]>
Author: Cheng Lian <[email protected]>

Closes #3640 from chenghao-intel/udf_serde and squashes the following commits:

8e13756 [Cheng Hao] Update the comment
74466a3 [Cheng Hao] refactor as feedbacks
396c0e1 [Cheng Hao] avoid Simple UDF to be serialized
e9c3212 [Cheng Hao] update the comment
19cbd46 [Cheng Hao] support udf instance ser/de after initialization
1 parent: bcb5cda · commit: 383c555

File tree: 5 files changed, +173 −50 lines
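The crux of the change is visible in the diffs below: every Hive function expression now holds a HiveFunctionWrapper instead of a bare class name, and the 0.13.1 wrapper hand-encodes an already-initialized, possibly non-Serializable function object inside writeExternal/readExternal. A rough, self-contained sketch of that trick (the Widget/Envelope names are hypothetical, and a plain UTF string stands in for the Kryo-encoded bytes produced by Hive's utilities):

```scala
import java.io._

// A payload type that is deliberately NOT java.io.Serializable, playing the
// role of an initialized GenericUDF instance.
class Widget(val name: String) {
  override def toString = s"Widget($name)"
}

// Envelope that lets the non-serializable payload survive Java serialization
// by hand-encoding it in writeExternal/readExternal. In the real patch,
// Hive's Kryo utilities produce the bytes; a plain UTF string stands in here.
class Envelope(private var payload: Widget) extends Externalizable {
  def this() = this(null) // no-arg constructor required by Externalizable

  def get: Widget = payload

  def writeExternal(out: ObjectOutput): Unit = {
    out.writeBoolean(payload != null) // null flag, as in HiveFunctionWrapper
    if (payload != null) out.writeUTF(payload.name)
  }

  def readExternal(in: ObjectInput): Unit =
    if (in.readBoolean()) payload = new Widget(in.readUTF())
}

object EnvelopeDemo extends App {
  val baos = new ByteArrayOutputStream()
  new ObjectOutputStream(baos).writeObject(new Envelope(new Widget("pmod")))
  val restored = new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray))
    .readObject().asInstanceOf[Envelope]
  println(restored.get) // Widget(pmod)
}
```

Java serialization only ever sees the Externalizable envelope, so the payload itself never has to implement Serializable.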

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala

Lines changed: 4 additions & 1 deletion
```diff
@@ -1128,7 +1128,10 @@ private[hive] object HiveQl {
         Explode(attributes, nodeToExpr(child))

       case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) =>
-        HiveGenericUdtf(functionName, attributes, children.map(nodeToExpr))
+        HiveGenericUdtf(
+          new HiveFunctionWrapper(functionName),
+          attributes,
+          children.map(nodeToExpr))

       case a: ASTNode =>
         throw new NotImplementedError(
```

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala

Lines changed: 44 additions & 49 deletions
```diff
@@ -54,46 +54,31 @@ private[hive] abstract class HiveFunctionRegistry
     val functionClassName = functionInfo.getFunctionClass.getName

     if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveSimpleUdf(functionClassName, children)
+      HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children)
     } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveGenericUdf(functionClassName, children)
+      HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children)
     } else if (
      classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveGenericUdaf(functionClassName, children)
+      HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children)
     } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveUdaf(functionClassName, children)
+      HiveUdaf(new HiveFunctionWrapper(functionClassName), children)
     } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveGenericUdtf(functionClassName, Nil, children)
+      HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), Nil, children)
     } else {
       sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
     }
   }
 }

-private[hive] trait HiveFunctionFactory {
-  val functionClassName: String
-
-  def createFunction[UDFType]() =
-    getContextOrSparkClassLoader.loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
-}
-
-private[hive] abstract class HiveUdf extends Expression with Logging with HiveFunctionFactory {
-  self: Product =>
-
-  type UDFType
+private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+  extends Expression with HiveInspectors with Logging {
   type EvaluatedType = Any
+  type UDFType = UDF

   def nullable = true

-  lazy val function = createFunction[UDFType]()
-
-  override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
-}
-
-private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[Expression])
-  extends HiveUdf with HiveInspectors {
-
-  type UDFType = UDF
+  @transient
+  lazy val function = funcWrapper.createFunction[UDFType]()

   @transient
   protected lazy val method =
@@ -131,6 +116,8 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[
         .convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*),
       returnInspector)
   }
+
+  override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
 }

 // Adapter from Catalyst ExpressionResult to Hive DeferredObject
@@ -144,16 +131,23 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector)
   override def get(): AnyRef = wrap(func(), oi)
 }

-private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq[Expression])
-  extends HiveUdf with HiveInspectors {
+private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+  extends Expression with HiveInspectors with Logging {
   type UDFType = GenericUDF
+  type EvaluatedType = Any
+
+  def nullable = true
+
+  @transient
+  lazy val function = funcWrapper.createFunction[UDFType]()

   @transient
   protected lazy val argumentInspectors = children.map(toInspector)

   @transient
-  protected lazy val returnInspector =
+  protected lazy val returnInspector = {
     function.initializeAndFoldConstants(argumentInspectors.toArray)
+  }

   @transient
   protected lazy val isUDFDeterministic = {
@@ -183,18 +177,19 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq
     }
     unwrap(function.evaluate(deferedObjects), returnInspector)
   }
+
+  override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
 }

 private[hive] case class HiveGenericUdaf(
-    functionClassName: String,
+    funcWrapper: HiveFunctionWrapper,
     children: Seq[Expression]) extends AggregateExpression
-  with HiveInspectors
-  with HiveFunctionFactory {
+  with HiveInspectors {

   type UDFType = AbstractGenericUDAFResolver

   @transient
-  protected lazy val resolver: AbstractGenericUDAFResolver = createFunction()
+  protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction()

   @transient
   protected lazy val objectInspector = {
@@ -209,22 +204,22 @@ private[hive] case class HiveGenericUdaf(

   def nullable: Boolean = true

-  override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
+  override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"

-  def newInstance() = new HiveUdafFunction(functionClassName, children, this)
+  def newInstance() = new HiveUdafFunction(funcWrapper, children, this)
 }

 /** It is used as a wrapper for the hive functions which uses UDAF interface */
 private[hive] case class HiveUdaf(
-    functionClassName: String,
+    funcWrapper: HiveFunctionWrapper,
     children: Seq[Expression]) extends AggregateExpression
-  with HiveInspectors
-  with HiveFunctionFactory {
+  with HiveInspectors {

   type UDFType = UDAF

   @transient
-  protected lazy val resolver: AbstractGenericUDAFResolver = new GenericUDAFBridge(createFunction())
+  protected lazy val resolver: AbstractGenericUDAFResolver =
+    new GenericUDAFBridge(funcWrapper.createFunction())

   @transient
   protected lazy val objectInspector = {
@@ -239,10 +234,10 @@ private[hive] case class HiveUdaf(

   def nullable: Boolean = true

-  override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
+  override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"

   def newInstance() =
-    new HiveUdafFunction(functionClassName, children, this, true)
+    new HiveUdafFunction(funcWrapper, children, this, true)
 }

 /**
@@ -257,13 +252,13 @@ private[hive] case class HiveUdaf(
  * user defined aggregations, which have clean semantics even in a partitioned execution.
  */
 private[hive] case class HiveGenericUdtf(
-    functionClassName: String,
+    funcWrapper: HiveFunctionWrapper,
     aliasNames: Seq[String],
     children: Seq[Expression])
-  extends Generator with HiveInspectors with HiveFunctionFactory {
+  extends Generator with HiveInspectors {

   @transient
-  protected lazy val function: GenericUDTF = createFunction()
+  protected lazy val function: GenericUDTF = funcWrapper.createFunction()

   @transient
   protected lazy val inputInspectors = children.map(_.dataType).map(toInspector)
@@ -320,25 +315,24 @@ private[hive] case class HiveGenericUdtf(
     }
   }

-  override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
+  override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
 }

 private[hive] case class HiveUdafFunction(
-    functionClassName: String,
+    funcWrapper: HiveFunctionWrapper,
     exprs: Seq[Expression],
     base: AggregateExpression,
     isUDAFBridgeRequired: Boolean = false)
   extends AggregateFunction
-  with HiveInspectors
-  with HiveFunctionFactory {
+  with HiveInspectors {

   def this() = this(null, null, null)

   private val resolver =
     if (isUDAFBridgeRequired) {
-      new GenericUDAFBridge(createFunction[UDAF]())
+      new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
     } else {
-      createFunction[AbstractGenericUDAFResolver]()
+      funcWrapper.createFunction[AbstractGenericUDAFResolver]()
     }

   private val inspectors = exprs.map(_.dataType).map(toInspector).toArray
@@ -361,3 +355,4 @@ private[hive] case class HiveUdafFunction(
     function.iterate(buffer, inputs)
   }
 }
+
```

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala

Lines changed: 7 additions & 0 deletions
```diff
@@ -60,6 +60,13 @@ class HiveUdfSuite extends QueryTest {
       | getStruct(1).f5 FROM src LIMIT 1
     """.stripMargin).first() === Row(1, 2, 3, 4, 5))
   }
+
+  test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") {
+    checkAnswer(
+      sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"),
+      8
+    )
+  }

   test("hive struct udf") {
     sql(
```

sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala

Lines changed: 11 additions & 0 deletions
```diff
@@ -43,6 +43,17 @@ import scala.language.implicitConversions

 import org.apache.spark.sql.catalyst.types.DecimalType

+class HiveFunctionWrapper(var functionClassName: String) extends java.io.Serializable {
+  // for Serialization
+  def this() = this(null)
+
+  import org.apache.spark.util.Utils._
+  def createFunction[UDFType <: AnyRef](): UDFType = {
+    getContextOrSparkClassLoader
+      .loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
+  }
+}
+
 /**
  * A compatibility layer for interacting with Hive version 0.12.0.
  */
```
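For Hive 0.12.0 the wrapper above only needs to carry the class name, since function objects are created from scratch wherever they are used, so plain java.io.Serializable suffices. A minimal round-trip sketch of that pattern (the NameWrapper name is illustrative, and java.util.ArrayList stands in for a UDF class so nothing Hive-specific is required):

```scala
import java.io._

// Stand-in for the 0.12 HiveFunctionWrapper: a Serializable holder for a
// class name, with reflective instantiation on demand (illustrative names).
class NameWrapper(val functionClassName: String) extends Serializable {
  def createFunction[T <: AnyRef](): T =
    Class.forName(functionClassName).newInstance().asInstanceOf[T]
}

object NameWrapperDemo extends App {
  val baos = new ByteArrayOutputStream()
  new ObjectOutputStream(baos).writeObject(new NameWrapper("java.util.ArrayList"))
  val restored = new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray))
    .readObject().asInstanceOf[NameWrapper]
  // Only the class name travels; the "function" is built on the receiving side.
  println(restored.createFunction[java.util.ArrayList[Int]]()) // prints []
}
```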

sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala

Lines changed: 107 additions & 0 deletions
```diff
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive

 import java.util.{ArrayList => JArrayList}
 import java.util.Properties
+
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapred.InputFormat
@@ -42,6 +43,112 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal
 import scala.collection.JavaConversions._
 import scala.language.implicitConversions

+
+/**
+ * This class provides the UDF creation and also the UDF instance serialization and
+ * deserialization across process boundaries.
+ *
+ * Detailed discussion can be found at https://github.com/apache/spark/pull/3640
+ *
+ * @param functionClassName UDF class name
+ */
+class HiveFunctionWrapper(var functionClassName: String) extends java.io.Externalizable {
+  // for Serialization
+  def this() = this(null)
+
+  import java.io.{OutputStream, InputStream}
+  import com.esotericsoftware.kryo.Kryo
+  import org.apache.spark.util.Utils._
+  import org.apache.hadoop.hive.ql.exec.Utilities
+  import org.apache.hadoop.hive.ql.exec.UDF
+
+  @transient
+  private val methodDeSerialize = {
+    val method = classOf[Utilities].getDeclaredMethod(
+      "deserializeObjectByKryo",
+      classOf[Kryo],
+      classOf[InputStream],
+      classOf[Class[_]])
+    method.setAccessible(true)
+
+    method
+  }
+
+  @transient
+  private val methodSerialize = {
+    val method = classOf[Utilities].getDeclaredMethod(
+      "serializeObjectByKryo",
+      classOf[Kryo],
+      classOf[Object],
+      classOf[OutputStream])
+    method.setAccessible(true)
+
+    method
+  }
+
+  def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = {
+    methodDeSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), is, clazz)
+      .asInstanceOf[UDFType]
+  }
+
+  def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = {
+    methodSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), function, out)
+  }
+
+  private var instance: AnyRef = null
+
+  def writeExternal(out: java.io.ObjectOutput) {
+    // output the function name
+    out.writeUTF(functionClassName)
+
+    // write a flag indicating whether the instance is null
+    out.writeBoolean(instance != null)
+    if (instance != null) {
+      // Some of the UDFs are serializable, but others are not;
+      // Hive Utilities can handle both cases
+      val baos = new java.io.ByteArrayOutputStream()
+      serializePlan(instance, baos)
+      val functionInBytes = baos.toByteArray
+
+      // output the function bytes
+      out.writeInt(functionInBytes.length)
+      out.write(functionInBytes, 0, functionInBytes.length)
+    }
+  }
+
+  def readExternal(in: java.io.ObjectInput) {
+    // read the function name
+    functionClassName = in.readUTF()
+
+    if (in.readBoolean()) {
+      // if the instance is not null,
+      // read the function bytes
+      val functionInBytesLength = in.readInt()
+      val functionInBytes = new Array[Byte](functionInBytesLength)
+      in.read(functionInBytes, 0, functionInBytesLength)
+
+      // deserialize the function object via Hive Utilities
+      instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes),
+        getContextOrSparkClassLoader.loadClass(functionClassName))
+    }
+  }
+
+  def createFunction[UDFType <: AnyRef](): UDFType = {
+    if (instance != null) {
+      instance.asInstanceOf[UDFType]
+    } else {
+      val func = getContextOrSparkClassLoader
+        .loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
+      if (!func.isInstanceOf[UDF]) {
+        // We cache the function if it's not a simple UDF,
+        // as we always have to create a new instance for a simple UDF
+        instance = func
+      }
+      func
+    }
+  }
+}
+
 /**
  * A compatibility layer for interacting with Hive version 0.13.1.
  */
```
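Hive's serializeObjectByKryo/deserializeObjectByKryo helpers are not publicly accessible, which is why the shim looks them up with getDeclaredMethod and unlocks them with setAccessible(true). A self-contained sketch of that reflection pattern (the Secrets class is hypothetical; Hive's real helpers are static, hence the invoke(null, ...) in the diff above, whereas here an instance receiver is passed):

```scala
import java.lang.reflect.Method

// Hypothetical class with a private method, standing in for Hive's Utilities
// (whose Kryo helpers are likewise not directly callable).
class Secrets {
  private def double(x: Int): Int = x * 2
}

object ReflectionDemo extends App {
  val m: Method = classOf[Secrets].getDeclaredMethod("double", classOf[Int])
  m.setAccessible(true) // bypass the private modifier, as the shim does
  println(m.invoke(new Secrets, Int.box(21))) // prints 42
}
```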
