@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.python.EvaluatePython
-import org.apache.spark.sql.types.{BinaryType, ObjectType, StructType}
+import org.apache.spark.sql.types.ObjectType
 
 /**
  * Helper functions for physical operators that work with user defined objects.
@@ -81,34 +81,22 @@ case class PythonMapPartitions(
 
   override def expressions: Seq[Expression] = Nil
 
-  private def isPickled(schema: StructType): Boolean = {
-    schema.length == 1 && schema.head.dataType == BinaryType &&
-      schema.head.metadata.contains("pickled")
-  }
-
   override protected def doExecute(): RDD[InternalRow] = {
     val inputRDD = child.execute().map(_.copy())
     val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
     val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
-    val childIsPickled = isPickled(child.schema)
-    val outputIsPickled = isPickled(schema)
 
     inputRDD.mapPartitions { iter =>
-      val inputIterator = if (childIsPickled) {
-        iter.map(_.getBinary(0))
-      } else {
-        EvaluatePython.registerPicklers()  // register pickler for Row
-
-        val pickle = new Pickler
-
-        // Input iterator to Python: input rows are grouped so we send them in batches to Python.
-        // For each row, add it to the queue.
-        iter.grouped(100).map { inputRows =>
-          val toBePickled = inputRows.map { row =>
-            EvaluatePython.toJava(row, child.schema)
-          }.toArray
-          pickle.dumps(toBePickled)
-        }
+      EvaluatePython.registerPicklers()  // register pickler for Row
+      val pickle = new Pickler
+
+      // Input iterator to Python: input rows are grouped so we send them in batches to Python.
+      // For each row, add it to the queue.
+      val inputIterator = iter.grouped(100).map { inputRows =>
+        val toBePickled = inputRows.map { row =>
+          EvaluatePython.toJava(row, child.schema)
+        }.toArray
+        pickle.dumps(toBePickled)
      }
 
      val context = TaskContext.get()
@@ -127,22 +115,14 @@ case class PythonMapPartitions(
         reuseWorker
       ).compute(inputIterator, context.partitionId(), context)
 
-      val resultProj = UnsafeProjection.create(output, output)
-
-      if (outputIsPickled) {
-        val row = new GenericMutableRow(1)
-        outputIterator.map { bytes =>
-          row(0) = bytes
-          resultProj(row)
-        }
-      } else {
-        val unpickle = new Unpickler
-        outputIterator.flatMap { pickedResult =>
-          val unpickledBatch = unpickle.loads(pickedResult)
-          unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
-        }.map { result =>
-          resultProj(EvaluatePython.fromJava(result, schema).asInstanceOf[InternalRow])
-        }
+      val unpickle = new Unpickler
+      val toUnsafe = UnsafeProjection.create(output, output)
+
+      outputIterator.flatMap { pickedResult =>
+        val unpickledBatch = unpickle.loads(pickedResult)
+        unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
+      }.map { result =>
+        toUnsafe(EvaluatePython.fromJava(result, schema).asInstanceOf[InternalRow])
+      }
     }
   }
 }
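Note: for readers outside the diff, below is a minimal standalone sketch of the batch-pickle round trip this change standardizes on, using the same net.razorvine.pickle (Pyrolite) Pickler/Unpickler API as the operator. The object name PickleBatchSketch and the plain String rows are illustrative stand-ins only; the real operator converts InternalRows via EvaluatePython.toJava/fromJava and exchanges the payloads with a Python worker rather than round-tripping them in one JVM.

import scala.collection.JavaConverters._

import net.razorvine.pickle.{Pickler, Unpickler}

object PickleBatchSketch {
  def main(args: Array[String]): Unit = {
    // Stand-in rows; the operator would produce Java-converted InternalRows here.
    val rows: Iterator[String] = (1 to 250).iterator.map(i => s"row-$i")

    // Input side: group rows into batches of 100 and pickle each batch into a
    // single byte payload, mirroring iter.grouped(100).map { ... pickle.dumps(...) }.
    val pickle = new Pickler
    val payloads: Iterator[Array[Byte]] = rows.grouped(100).map { batch =>
      pickle.dumps(batch.asJava)
    }

    // Output side: each payload unpickles to one java.util.ArrayList holding a
    // whole batch; flatMap flattens the batches back into a per-row iterator,
    // mirroring how the operator consumes the Python worker's output.
    val unpickle = new Unpickler
    val roundTripped = payloads.flatMap { payload =>
      unpickle.loads(payload).asInstanceOf[java.util.ArrayList[Any]].asScala
    }

    println(roundTripped.size)  // 250
  }
}

As the in-code comment notes, rows are grouped before pickling so each dumps call and each write to the Python worker carries a whole batch, amortizing pickle framing and socket overhead instead of paying them per row.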