
Commit 7525ce9

HyukjinKwon authored and cloud-fan committed
[SPARK-20431][SS][FOLLOWUP] Specify a schema by using a DDL-formatted string in DataStreamReader
## What changes were proposed in this pull request?

This PR adds support for a DDL-formatted string in `DataStreamReader.schema`, which lets users define a schema without importing the type classes. For example:

```scala
scala> spark.readStream.schema("col0 INT, col1 DOUBLE").load("/tmp/abc").printSchema()
root
 |-- col0: integer (nullable = true)
 |-- col1: double (nullable = true)
```

## How was this patch tested?

Added tests in `DataStreamReaderWriterSuite`.

Author: hyukjinkwon <[email protected]>

Closes #18373 from HyukjinKwon/SPARK-20431.
1 parent 03eb611 · commit 7525ce9
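For PySpark users, a rough Python counterpart of the Scala example above (a sketch only, not part of the commit; it assumes an active `SparkSession` bound to `spark` and input files under the placeholder path `/tmp/abc`):

```python
# Sketch of the equivalent PySpark call enabled by this change. Assumes an
# active SparkSession named `spark`; /tmp/abc is the placeholder path from
# the commit message.
spark.readStream.schema("col0 INT, col1 DOUBLE").load("/tmp/abc").printSchema()
# root
#  |-- col0: integer (nullable = true)
#  |-- col1: double (nullable = true)
```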

File tree: 4 files changed (+42, -8 lines)


python/pyspark/sql/readwriter.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -98,6 +98,8 @@ def schema(self, schema):
 
         :param schema: a :class:`pyspark.sql.types.StructType` object or a DDL-formatted string
             (For example ``col0 INT, col1 DOUBLE``).
+
+        >>> s = spark.read.schema("col0 INT, col1 DOUBLE")
         """
         from pyspark.sql import SparkSession
         spark = SparkSession.builder.getOrCreate()
```
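The new doctest line reflects the batch-side behavior from the original SPARK-20431 change: `spark.read.schema` accepts either a `StructType` or a DDL string. A minimal sketch, assuming an active `spark` session and a placeholder CSV directory:

```python
# Minimal sketch: the two schema forms accepted by DataFrameReader.schema.
# Assumes an active SparkSession `spark`; /tmp/data is a placeholder path.
from pyspark.sql.types import DoubleType, IntegerType, StructField, StructType

struct = StructType([
    StructField("col0", IntegerType()),  # nullable defaults to True
    StructField("col1", DoubleType()),
])
df_a = spark.read.schema(struct).csv("/tmp/data")                   # StructType object
df_b = spark.read.schema("col0 INT, col1 DOUBLE").csv("/tmp/data")  # DDL string
assert df_a.schema == df_b.schema  # both should yield the same schema
```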

python/pyspark/sql/streaming.py

Lines changed: 16 additions & 8 deletions
```diff
@@ -319,16 +319,21 @@ def schema(self, schema):
 
         .. note:: Evolving.
 
-        :param schema: a :class:`pyspark.sql.types.StructType` object
+        :param schema: a :class:`pyspark.sql.types.StructType` object or a DDL-formatted string
+            (For example ``col0 INT, col1 DOUBLE``).
 
         >>> s = spark.readStream.schema(sdf_schema)
+        >>> s = spark.readStream.schema("col0 INT, col1 DOUBLE")
         """
         from pyspark.sql import SparkSession
-        if not isinstance(schema, StructType):
-            raise TypeError("schema should be StructType")
         spark = SparkSession.builder.getOrCreate()
-        jschema = spark._jsparkSession.parseDataType(schema.json())
-        self._jreader = self._jreader.schema(jschema)
+        if isinstance(schema, StructType):
+            jschema = spark._jsparkSession.parseDataType(schema.json())
+            self._jreader = self._jreader.schema(jschema)
+        elif isinstance(schema, basestring):
+            self._jreader = self._jreader.schema(schema)
+        else:
+            raise TypeError("schema should be StructType or string")
         return self
 
     @since(2.0)
```
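The rewritten method now dispatches on the argument type: a `StructType` keeps the old JSON-serialization path, strings are handed to the JVM as DDL text, and anything else raises. A quick sketch of the three paths (assumes an active `spark` session; the `StructType` below stands in for the doctest's `sdf_schema` fixture):

```python
# The three paths through the rewritten schema() method (sketch; assumes an
# active SparkSession `spark`).
from pyspark.sql.types import IntegerType, StructField, StructType

sdf_schema = StructType([StructField("data", IntegerType())])  # stand-in fixture

spark.readStream.schema(sdf_schema)               # StructType: serialized via schema.json()
spark.readStream.schema("col0 INT, col1 DOUBLE")  # string: passed to the JVM as DDL text
try:
    spark.readStream.schema(42)                   # anything else is rejected
except TypeError as err:
    print(err)  # "schema should be StructType or string"
```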
```diff
@@ -372,7 +377,8 @@ def load(self, path=None, format=None, schema=None, **options):
 
         :param path: optional string for file-system backed data sources.
         :param format: optional string for format of the data source. Default to 'parquet'.
-        :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema.
+        :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema
+            or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
         :param options: all other string options
 
         >>> json_sdf = spark.readStream.format("json") \\
```
```diff
@@ -415,7 +421,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
 
         :param path: string represents path to the JSON dataset,
                      or RDD of Strings storing JSON objects.
-        :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema.
+        :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema
+            or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
         :param primitivesAsString: infers all primitive values as a string type. If None is set,
             it uses the default value, ``false``.
         :param prefersDecimal: infers all floating-point values as a decimal type. If the values
```
```diff
@@ -542,7 +549,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
         .. note:: Evolving.
 
         :param path: string, or list of strings, for input path(s).
-        :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema.
+        :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema
+            or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
         :param sep: sets the single character as a separator for each field and value.
             If None is set, it uses the default value, ``,``.
         :param encoding: decodes the CSV files by the given encoding type. If None is set,
```
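Since `load`, `json`, and `csv` all forward their `schema` argument through the `schema()` method patched above, the DDL-string form works in each of them. A brief sketch with placeholder paths, assuming an active `spark` session:

```python
# The schema keyword of load/json/csv funnels into DataStreamReader.schema(),
# so a DDL string is accepted in all three. Placeholder paths; assumes an
# active SparkSession `spark`.
j = spark.readStream.json("/tmp/json-in", schema="col0 INT, col1 DOUBLE")
c = spark.readStream.csv("/tmp/csv-in", schema="col0 INT, col1 DOUBLE", header=True)
p = spark.readStream.load("/tmp/parquet-in", schema="col0 INT, col1 DOUBLE")
```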

sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala

Lines changed: 12 additions & 0 deletions
```diff
@@ -59,6 +59,18 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
     this
   }
 
+  /**
+   * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) can
+   * infer the input schema automatically from data. By specifying the schema here, the underlying
+   * data source can skip the schema inference step, and thus speed up data loading.
+   *
+   * @since 2.3.0
+   */
+  def schema(schemaString: String): DataStreamReader = {
+    this.userSpecifiedSchema = Option(StructType.fromDDL(schemaString))
+    this
+  }
+
   /**
    * Adds an input option for the underlying data source.
    *
```
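This is the JVM overload that PySpark's new string branch reaches: `self._jreader.schema(schema)` passes the raw DDL text through Py4J, and `StructType.fromDDL` does the parsing on the Scala side. A sketch of that hand-off, with internal attributes shown only to make the call path concrete:

```python
# Illustration only: how the Python string branch lands on the new Scala
# overload. _jreader is a PySpark-internal attribute; assumes an active
# SparkSession `spark`.
reader = spark.readStream            # pyspark.sql.streaming.DataStreamReader
jreader = reader._jreader            # underlying JVM DataStreamReader
jreader = jreader.schema("aa INT")   # resolves to schema(schemaString: String)
```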

sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala

Lines changed: 12 additions & 0 deletions
```diff
@@ -663,4 +663,16 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
     }
     assert(fs.exists(checkpointDir))
   }
+
+  test("SPARK-20431: Specify a schema by using a DDL-formatted string") {
+    spark.readStream
+      .format("org.apache.spark.sql.streaming.test")
+      .schema("aa INT")
+      .load()
+
+    assert(LastOptions.schema.isDefined)
+    assert(LastOptions.schema.get === StructType(StructField("aa", IntegerType) :: Nil))
+
+    LastOptions.clear()
+  }
 }
```
