Commit 4dfe604

infer schema
1 parent d96f103 commit 4dfe604

3 files changed: +40 -47 lines changed

python/pyspark/sql/dataframe.py

Lines changed: 11 additions & 5 deletions
@@ -1374,7 +1374,7 @@ def __init__(self, prev, func, output_schema=None):
         self._lazy_rdd = None

         if output_schema is not None:
-            # This transformation is applying schema, just copy member variables from prev.
+            # This transformation is adding schema, just copy member variables from prev.
             self.func = func
             self._prev_jdf = prev._prev_jdf
         elif not isinstance(prev, PipelinedDataFrame) or not prev.is_cached:
@@ -1385,16 +1385,22 @@ def __init__(self, prev, func, output_schema=None):
             self.func = _pipeline_func(prev.func, func)
             self._prev_jdf = prev._prev_jdf  # maintain the pipeline

-    def applySchema(self, schema):
+    def schema(self, schema):
         return PipelinedDataFrame(self, self.func, schema)

     @property
     def _jdf(self):
+        from pyspark.sql.types import _infer_type, _merge_type
+
         if self._jdf_val is None:
             if self.output_schema is None:
-                schema = StructType().add("binary", BinaryType(), False, {"pickled": True})
-                final_func = self.func
-            elif isinstance(self.output_schema, StructType):
+                # If no schema is specified, infer it from the whole data set.
+                jrdd = self._prev_jdf.javaToPython()
+                rdd = RDD(jrdd, self._sc, BatchedSerializer(PickleSerializer()))
+                func = self.func  # assign to a local variable to avoid referencing self in closure.
+                self.output_schema = rdd.mapPartitions(func).map(_infer_type).reduce(_merge_type)
+
+            if isinstance(self.output_schema, StructType):
                 schema = self.output_schema
                 to_row = lambda iterator: map(schema.toInternal, iterator)
                 final_func = _pipeline_func(self.func, to_row)
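Note on the new inference branch above: when no output schema is given, the transformed data set is scanned once, a type is inferred per record, and the per-record types are folded into a single schema. A minimal local sketch of that pattern follows, assuming a PySpark installation; _infer_type and _merge_type are the private helpers the diff imports, and the sample records are illustrative.

from functools import reduce
from pyspark.sql.types import _infer_type, _merge_type  # private helpers imported by the diff

# Locally mimic rdd.mapPartitions(func).map(_infer_type).reduce(_merge_type):
records = [3, 6, None]  # illustrative output of the user's map function
inferred = reduce(_merge_type, map(_infer_type, records))
print(inferred)  # expected LongType: None infers as NullType and is absorbed by the merge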

python/pyspark/sql/tests.py

Lines changed: 10 additions & 3 deletions
@@ -1159,24 +1159,31 @@ def test_dataset(self):
         func = lambda row: {"key": row.key + 1, "value": row.value}  # convert row to python dict
         ds2 = ds.mapPartitions2(lambda iterator: map(func, iterator))
         schema = StructType().add("key", IntegerType()).add("value", StringType())
-        ds3 = ds2.applySchema(schema)
+        ds3 = ds2.schema(schema)
         result = ds3.select("key").collect()
         self.assertEqual(result[0][0], 2)
         self.assertEqual(result[1][0], 3)

         schema = StructType().add("value", StringType())  # use a different but compatible schema
-        ds3 = ds2.applySchema(schema)
+        ds3 = ds2.schema(schema)
         result = ds3.collect()
         self.assertEqual(result[0][0], "1")
         self.assertEqual(result[1][0], "2")

         func = lambda row: row.key * 3
         ds2 = ds.mapPartitions2(lambda iterator: map(func, iterator))
-        ds3 = ds2.applySchema(IntegerType())  # use a flat schema
+        ds3 = ds2.schema(IntegerType())  # use a flat schema
         result = ds3.collect()
         self.assertEqual(result[0][0], 3)
         self.assertEqual(result[1][0], 6)

+        result = ds2.collect()  # schema can be inferred automatically
+        self.assertEqual(result[0][0], 3)
+        self.assertEqual(result[1][0], 6)
+
+        # row count should be correct even when no schema is specified.
+        self.assertEqual(ds2.count(), 2)
+

 class HiveContextSQLTests(ReusedPySparkTestCase):
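The test above exercises three ways of getting typed results out of a pipelined transformation: an explicit StructType, a flat non-struct type, and (new in this commit) no schema at all. A hedged end-to-end sketch of the same flow follows; mapPartitions2 and the renamed schema() method are the experimental API from this branch, and sqlContext plus the two-row input (key 1/2, value "1"/"2", as the assertions imply) are assumptions.

from pyspark.sql.types import StructType, IntegerType, StringType

ds = sqlContext.createDataFrame([(1, "1"), (2, "2")], ["key", "value"])  # assumed input

# 1. Explicit schema after a row -> dict transformation.
ds2 = ds.mapPartitions2(lambda it: ({"key": r.key + 1, "value": r.value} for r in it))
print(ds2.schema(StructType().add("key", IntegerType()).add("value", StringType())).collect())

# 2. Flat (non-struct) schema for a transformation that emits plain values.
print(ds.mapPartitions2(lambda it: (r.key * 3 for r in it)).schema(IntegerType()).collect())

# 3. No schema at all: with this commit, collect() and count() infer it from the data.
ds3 = ds.mapPartitions2(lambda it: (r.key * 3 for r in it))
print(ds3.collect(), ds3.count())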

sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala

Lines changed: 19 additions & 39 deletions
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.python.EvaluatePython
-import org.apache.spark.sql.types.{BinaryType, ObjectType, StructType}
+import org.apache.spark.sql.types.ObjectType

 /**
  * Helper functions for physical operators that work with user defined objects.
@@ -81,34 +81,22 @@ case class PythonMapPartitions(

   override def expressions: Seq[Expression] = Nil

-  private def isPickled(schema: StructType): Boolean = {
-    schema.length == 1 && schema.head.dataType == BinaryType &&
-      schema.head.metadata.contains("pickled")
-  }
-
   override protected def doExecute(): RDD[InternalRow] = {
     val inputRDD = child.execute().map(_.copy())
     val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
     val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
-    val childIsPickled = isPickled(child.schema)
-    val outputIsPickled = isPickled(schema)

     inputRDD.mapPartitions { iter =>
-      val inputIterator = if (childIsPickled) {
-        iter.map(_.getBinary(0))
-      } else {
-        EvaluatePython.registerPicklers()  // register pickler for Row
-
-        val pickle = new Pickler
-
-        // Input iterator to Python: input rows are grouped so we send them in batches to Python.
-        // For each row, add it to the queue.
-        iter.grouped(100).map { inputRows =>
-          val toBePickled = inputRows.map { row =>
-            EvaluatePython.toJava(row, child.schema)
-          }.toArray
-          pickle.dumps(toBePickled)
-        }
+      EvaluatePython.registerPicklers()  // register pickler for Row
+      val pickle = new Pickler
+
+      // Input iterator to Python: input rows are grouped so we send them in batches to Python.
+      // For each row, add it to the queue.
+      val inputIterator = iter.grouped(100).map { inputRows =>
+        val toBePickled = inputRows.map { row =>
+          EvaluatePython.toJava(row, child.schema)
+        }.toArray
+        pickle.dumps(toBePickled)
      }

       val context = TaskContext.get()
@@ -127,22 +115,14 @@ case class PythonMapPartitions(
         reuseWorker
       ).compute(inputIterator, context.partitionId(), context)

-      val resultProj = UnsafeProjection.create(output, output)
-
-      if (outputIsPickled) {
-        val row = new GenericMutableRow(1)
-        outputIterator.map { bytes =>
-          row(0) = bytes
-          resultProj(row)
-        }
-      } else {
-        val unpickle = new Unpickler
-        outputIterator.flatMap { pickedResult =>
-          val unpickledBatch = unpickle.loads(pickedResult)
-          unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
-        }.map { result =>
-          resultProj(EvaluatePython.fromJava(result, schema).asInstanceOf[InternalRow])
-        }
+      val unpickle = new Unpickler
+      val toUnsafe = UnsafeProjection.create(output, output)
+
+      outputIterator.flatMap { pickedResult =>
+        val unpickledBatch = unpickle.loads(pickedResult)
+        unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
+      }.map { result =>
+        toUnsafe(EvaluatePython.fromJava(result, schema).asInstanceOf[InternalRow])
       }
     }
   }
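With the pickled-binary fast path removed, every input partition now takes the same route: rows are converted with EvaluatePython.toJava, grouped into batches of 100, and each batch is pickled as one payload before being streamed to the Python worker. A rough Python illustration of that batching follows; the helper name and the plain pickle module are stand-ins for the Scala operator and net.razorvine's Pickler, not anything in this commit.

import pickle
from itertools import islice

def pickled_batches(rows, batch_size=100):
    """Yield one pickled payload per batch, like iter.grouped(100).map { pickle.dumps(...) }."""
    it = iter(rows)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield pickle.dumps(batch)

# 250 rows -> 3 payloads (100 + 100 + 50), each unpickling back to a list of rows.
payloads = list(pickled_batches(range(250)))
assert len(payloads) == 3 and len(pickle.loads(payloads[-1])) == 50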
