2 changes: 2 additions & 0 deletions python/pyspark/sql/readwriter.py
@@ -98,6 +98,8 @@ def schema(self, schema):

:param schema: a :class:`pyspark.sql.types.StructType` object or a DDL-formatted string
(For example ``col0 INT, col1 DOUBLE``).

>>> s = spark.read.schema("col0 INT, col1 DOUBLE")
"""
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
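To illustrate the DDL-string form documented above for the batch reader, here is a minimal hedged sketch; the file path "people.csv" and its columns are placeholders for illustration, not part of this change.

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, DoubleType

spark = SparkSession.builder.getOrCreate()

# The schema can be given as a DDL-formatted string, as in the doctest above ...
df_ddl = spark.read.schema("col0 INT, col1 DOUBLE").csv("people.csv")

# ... which is expected to behave like spelling out the equivalent StructType.
df_struct = spark.read.schema(StructType([
    StructField("col0", IntegerType()),
    StructField("col1", DoubleType()),
])).csv("people.csv")

assert df_ddl.schema == df_struct.schema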
24 changes: 16 additions & 8 deletions python/pyspark/sql/streaming.py
@@ -319,16 +319,21 @@ def schema(self, schema):

.. note:: Evolving.

:param schema: a :class:`pyspark.sql.types.StructType` object
:param schema: a :class:`pyspark.sql.types.StructType` object or a DDL-formatted string
(For example ``col0 INT, col1 DOUBLE``).

>>> s = spark.readStream.schema(sdf_schema)
>>> s = spark.readStream.schema("col0 INT, col1 DOUBLE")
"""
from pyspark.sql import SparkSession
if not isinstance(schema, StructType):
raise TypeError("schema should be StructType")
spark = SparkSession.builder.getOrCreate()
jschema = spark._jsparkSession.parseDataType(schema.json())
self._jreader = self._jreader.schema(jschema)
if isinstance(schema, StructType):
jschema = spark._jsparkSession.parseDataType(schema.json())
self._jreader = self._jreader.schema(jschema)
elif isinstance(schema, basestring):
self._jreader = self._jreader.schema(schema)
else:
raise TypeError("schema should be StructType or string")
return self

@since(2.0)
@@ -372,7 +377,8 @@ def load(self, path=None, format=None, schema=None, **options):

:param path: optional string for file-system backed data sources.
:param format: optional string for format of the data source. Default to 'parquet'.
:param schema: optional :class:`pyspark.sql.types.StructType` for the input schema.
:param schema: optional :class:`pyspark.sql.types.StructType` for the input schema
or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
:param options: all other string options

>>> json_sdf = spark.readStream.format("json") \\
@@ -415,7 +421,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,

:param path: string represents path to the JSON dataset,
or RDD of Strings storing JSON objects.
:param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema.
:param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema
or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
:param primitivesAsString: infers all primitive values as a string type. If None is set,
it uses the default value, ``false``.
:param prefersDecimal: infers all floating-point values as a decimal type. If the values
@@ -542,7 +549,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
.. note:: Evolving.

:param path: string, or list of strings, for input path(s).
:param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema.
:param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema
or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
:param sep: sets the single character as a separator for each field and value.
If None is set, it uses the default value, ``,``.
:param encoding: decodes the CSV files by the given encoding type. If None is set,
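Taken together, the streaming.py changes above let the streaming reader accept a DDL string wherever a StructType was accepted before. A hedged usage sketch, assuming an existing input directory "/tmp/stream-input" (a placeholder, not from this change):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# File-based streaming sources typically require a user-specified schema;
# after this change it can be passed as a DDL-formatted string.
sdf = spark.readStream.schema("col0 INT, col1 DOUBLE").json("/tmp/stream-input")

# Anything that is neither a StructType nor a string is rejected, matching
# the dispatch added to DataStreamReader.schema above.
try:
    spark.readStream.schema(42)
except TypeError as err:
    print(err)  # schema should be StructType or string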
@@ -59,6 +59,18 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
this
}

/**
* Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) can
* infer the input schema automatically from data. By specifying the schema here, the underlying
* data source can skip the schema inference step, and thus speed up data loading.
*
* @since 2.3.0
*/
def schema(schemaString: String): DataStreamReader = {
this.userSpecifiedSchema = Option(StructType.fromDDL(schemaString))
this
}

/**
* Adds an input option for the underlying data source.
*
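The new Scala overload above parses the DDL string with StructType.fromDDL, so a DDL string and the equivalent StructType should yield the same user-specified schema, mirroring the assertion in the test below. A hedged Python-level sketch of that equivalence; the input directory is again an assumed placeholder.

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType

spark = SparkSession.builder.getOrCreate()

ddl_df = spark.readStream.schema("aa INT").json("/tmp/stream-input")
struct_df = spark.readStream.schema(
    StructType([StructField("aa", IntegerType())])).json("/tmp/stream-input")

# Both streams should report a single nullable "aa" column of IntegerType.
assert ddl_df.schema == struct_df.schema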
@@ -663,4 +663,16 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
}
assert(fs.exists(checkpointDir))
}

test("SPARK-20431: Specify a schema by using a DDL-formatted string") {
spark.readStream
.format("org.apache.spark.sql.streaming.test")
.schema("aa INT")
.load()

assert(LastOptions.schema.isDefined)
assert(LastOptions.schema.get === StructType(StructField("aa", IntegerType) :: Nil))

LastOptions.clear()
}
}