From 5477c8085d4afb434f7f3964b9c6a9fcd5bd1569 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 5 Mar 2017 00:15:47 +0900 Subject: [PATCH 1/8] Add an API to load DataFrame from Dataset[String] storing CSV --- .../apache/spark/sql/DataFrameReader.scala | 75 +++++++++++++++++-- .../datasources/csv/CSVDataSource.scala | 36 +++++---- .../datasources/csv/CSVOptions.scala | 2 +- .../datasources/csv/UnivocityParser.scala | 2 +- .../execution/datasources/csv/CSVSuite.scala | 23 ++++++ 5 files changed, 114 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 41470ae6aae19..ec86f9cfb9612 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -21,6 +21,8 @@ import java.util.Properties import scala.collection.JavaConverters._ +import com.univocity.parsers.csv.CsvParser + import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.Partition @@ -29,6 +31,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.datasources.csv._ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.JsonInferSchema @@ -368,14 +371,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { createParser) } - // Check a field requirement for corrupt records here to throw an exception in a driver side - schema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => - val f = schema(corruptFieldIndex) - if (f.dataType != StringType || !f.nullable) { - throw new AnalysisException( - "The field for corrupt records must be string type and nullable") - } - } + verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) val parsed = jsonDataset.rdd.mapPartitions { iter => val parser = new JacksonParser(schema, parsedOptions) @@ -398,6 +394,53 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { csv(Seq(path): _*) } + /** + * Loads an `Dataset[String]` storing CSV rows and returns the result as a `DataFrame`. + * + * If the schema is not specified using `schema` function and `inferSchema` option is enabled, + * this function goes through the input once to determine the input schema. + * + * If the schema is not specified using `schema` function and `inferSchema` option is disabled, + * it determines the columns as string types and it reads only the first line to determine the + * names and the number of fields. 
+ * + * @param csvDataset input Dataset with one CSV row per record + * @since 2.2.0 + */ + def csv(csvDataset: Dataset[String]): DataFrame = { + val parsedOptions: CSVOptions = new CSVOptions( + extraOptions.toMap, + sparkSession.sessionState.conf.sessionLocalTimeZone) + val filteredLines = CSVUtils.filterCommentAndEmpty(csvDataset, parsedOptions) + val maybeFirstLine = filteredLines.take(1).headOption + if (maybeFirstLine.isEmpty) { + return sparkSession.emptyDataFrame + } + val firstLine = maybeFirstLine.get + + val schema = userSpecifiedSchema.getOrElse { + TextInputCSVDataSource.inferFromDataset( + sparkSession, + csvDataset, + firstLine, + parsedOptions) + } + + verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) + + val linesWithoutHeader: RDD[String] = filteredLines.rdd.mapPartitions { iter => + CSVUtils.filterHeaderLine(iter, firstLine, parsedOptions) + } + val parsed = linesWithoutHeader.mapPartitions { iter => + val parser = new UnivocityParser(schema, parsedOptions) + iter.flatMap(line => parser.parse(line)) + } + + Dataset.ofRows( + sparkSession, + LogicalRDD(schema.toAttributes, parsed)(sparkSession)) + } + /** * Loads a CSV file and returns the result as a `DataFrame`. * @@ -604,6 +647,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } } + /** + * A convenient function for schema validation in datasources supporting + * `columnNameOfCorruptRecord` as an option. + */ + private def verifyColumnNameOfCorruptRecord( + schema: StructType, + columnNameOfCorruptRecord: String): Unit = { + schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex => + val f = schema(corruptFieldIndex) + if (f.dataType != StringType || !f.nullable) { + throw new AnalysisException( + "The field for corrupt records must be string type and nullable") + } + } + } + /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 47567032b0195..4328d6447b217 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.execution.datasources.csv -import java.io.InputStream import java.nio.charset.{Charset, StandardCharsets} -import com.univocity.parsers.csv.{CsvParser, CsvParserSettings} +import com.univocity.parsers.csv.CsvParser import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.Job @@ -136,23 +135,32 @@ object TextInputCSVDataSource extends CSVDataSource { val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions) CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption match { case Some(firstLine) => - val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) - val tokenRDD = 
csv.rdd.mapPartitions { iter => - val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) - val linesWithoutHeader = - CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) - val parser = new CsvParser(parsedOptions.asParserSettings) - linesWithoutHeader.map(parser.parseLine) - } - Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) + Some(inferFromDataset(sparkSession, csv, firstLine, parsedOptions)) case None => // If the first line could not be read, just return the empty schema. Some(StructType(Nil)) } } + def inferFromDataset( + sparkSession: SparkSession, + csv: Dataset[String], + firstLine: String, + parsedOptions: CSVOptions): StructType = { + val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val tokenRDD = csv.rdd.mapPartitions { iter => + val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) + val linesWithoutHeader = + CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) + val parser = new CsvParser(parsedOptions.asParserSettings) + linesWithoutHeader.map(parser.parseLine) + } + + CSVInferSchema.infer(tokenRDD, header, parsedOptions) + } + private def createBaseDataset( sparkSession: SparkSession, inputPaths: Seq[FileStatus], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 50503385ad6d1..0b1e5dac2da66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -26,7 +26,7 @@ import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes} -private[csv] class CSVOptions( +class CSVOptions( @transient private val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 3b3b87e4354d6..e42ea3fa391f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -private[csv] class UnivocityParser( +class UnivocityParser( schema: StructType, requiredSchema: StructType, private val options: CSVOptions) extends Logging { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index eaedede349134..76ea83595afa5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -129,6 +129,22 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { verifyCars(cars, withHeader = true, checkTypes = true) } + test("simple 
csv test with string dataset") { + val csvDataset = spark.read.text(testFile(carsFile)).as[String] + val cars = spark.read + .option("header", "true") + .option("inferSchema", "true") + .csv(csvDataset) + + verifyCars(cars, withHeader = true, checkTypes = true) + + val carsWithoutHeader = spark.read + .option("header", "false") + .csv(csvDataset) + + verifyCars(carsWithoutHeader, withHeader = false, checkTypes = false) + } + test("test inferring booleans") { val result = spark.read .format("csv") @@ -1088,4 +1104,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { checkAnswer(df, spark.emptyDataFrame) } } + + test("Empty file produces empty dataframe with empty schema - CSV string dataset") { + val df = spark.read.csv(spark.emptyDataset[String]) + assert(df.schema === spark.emptyDataFrame.schema) + checkAnswer(df, spark.emptyDataFrame) + } + } From 477fc98c7720edaa93b14cb73f5a9da34829bb51 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 7 Mar 2017 19:58:49 +0900 Subject: [PATCH 2/8] Add a comment --- .../spark/sql/execution/datasources/csv/CSVDataSource.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 4328d6447b217..242d2de3236c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -142,6 +142,9 @@ object TextInputCSVDataSource extends CSVDataSource { } } + /** + * Infers the schema from `Dataset` that stores CSV string records. + */ def inferFromDataset( sparkSession: SparkSession, csv: Dataset[String], From 92dfdf94993091c8e5072c6f6c85030d8b8664a5 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 7 Mar 2017 20:00:11 +0900 Subject: [PATCH 3/8] Remove unused import --- .../src/main/scala/org/apache/spark/sql/DataFrameReader.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index ec86f9cfb9612..957f547870692 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -21,8 +21,6 @@ import java.util.Properties import scala.collection.JavaConverters._ -import com.univocity.parsers.csv.CsvParser - import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.Partition From a2739fd07866700b9e62a26d9c3d6fbf9b5d1d5c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 7 Mar 2017 20:44:23 +0900 Subject: [PATCH 4/8] Cleaner --- .../apache/spark/sql/DataFrameReader.scala | 19 +++++---- .../datasources/csv/CSVDataSource.scala | 40 +++++++++---------- .../execution/datasources/csv/CSVSuite.scala | 13 ++++-- 3 files changed, 38 insertions(+), 34 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 957f547870692..a753c2dfb50c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -409,26 +409,25 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val parsedOptions: CSVOptions = new CSVOptions( 
extraOptions.toMap, sparkSession.sessionState.conf.sessionLocalTimeZone) - val filteredLines = CSVUtils.filterCommentAndEmpty(csvDataset, parsedOptions) - val maybeFirstLine = filteredLines.take(1).headOption - if (maybeFirstLine.isEmpty) { - return sparkSession.emptyDataFrame - } - val firstLine = maybeFirstLine.get + val filteredLines: Dataset[String] = CSVUtils.filterCommentAndEmpty(csvDataset, parsedOptions) + val maybeFirstLine: Option[String] = filteredLines.take(1).headOption val schema = userSpecifiedSchema.getOrElse { TextInputCSVDataSource.inferFromDataset( sparkSession, csvDataset, - firstLine, + maybeFirstLine, parsedOptions) } verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) - val linesWithoutHeader: RDD[String] = filteredLines.rdd.mapPartitions { iter => - CSVUtils.filterHeaderLine(iter, firstLine, parsedOptions) - } + val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => + filteredLines.rdd.mapPartitions { iter => + CSVUtils.filterHeaderLine(iter, firstLine, parsedOptions) + } + }.getOrElse(filteredLines.rdd) + val parsed = linesWithoutHeader.mapPartitions { iter => val parser = new UnivocityParser(schema, parsedOptions) iter.flatMap(line => parser.parse(line)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 242d2de3236c0..35ff924f27ce5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -133,13 +133,8 @@ object TextInputCSVDataSource extends CSVDataSource { inputPaths: Seq[FileStatus], parsedOptions: CSVOptions): Option[StructType] = { val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions) - CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption match { - case Some(firstLine) => - Some(inferFromDataset(sparkSession, csv, firstLine, parsedOptions)) - case None => - // If the first line could not be read, just return the empty schema. 
- Some(StructType(Nil)) - } + val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption + Some(inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions)) } /** @@ -148,20 +143,23 @@ object TextInputCSVDataSource extends CSVDataSource { def inferFromDataset( sparkSession: SparkSession, csv: Dataset[String], - firstLine: String, - parsedOptions: CSVOptions): StructType = { - val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) - val tokenRDD = csv.rdd.mapPartitions { iter => - val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) - val linesWithoutHeader = - CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) - val parser = new CsvParser(parsedOptions.asParserSettings) - linesWithoutHeader.map(parser.parseLine) - } - - CSVInferSchema.infer(tokenRDD, header, parsedOptions) + maybeFirstLine: Option[String], + parsedOptions: CSVOptions): StructType = maybeFirstLine match { + case Some(firstLine) => + val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val tokenRDD = csv.rdd.mapPartitions { iter => + val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) + val linesWithoutHeader = + CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) + val parser = new CsvParser(parsedOptions.asParserSettings) + linesWithoutHeader.map(parser.parseLine) + } + CSVInferSchema.infer(tokenRDD, header, parsedOptions) + case None => + // If the first line could not be read, just return the empty schema. 
+ StructType(Nil) } private def createBaseDataset( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 76ea83595afa5..ed6d3903c8e98 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1106,9 +1106,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Empty file produces empty dataframe with empty schema - CSV string dataset") { - val df = spark.read.csv(spark.emptyDataset[String]) - assert(df.schema === spark.emptyDataFrame.schema) - checkAnswer(df, spark.emptyDataFrame) + val emptyDF = spark.createDataFrame( + spark.sparkContext.emptyRDD[Row], + StructType(StructField("a", StringType) :: Nil)) + val df = spark.read.schema(emptyDF.schema).csv(spark.emptyDataset[String]) + assert(df.schema === emptyDF.schema) + checkAnswer(df, emptyDF) + + val emptyDFWithoutSchema = spark.read.csv(spark.emptyDataset[String]) + assert(emptyDFWithoutSchema.schema === spark.emptyDataFrame.schema) + checkAnswer(emptyDFWithoutSchema, spark.emptyDataFrame) } } From abae589289b827a74de9ec2ede0ce5ba2c85e981 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 7 Mar 2017 21:05:06 +0900 Subject: [PATCH 5/8] Add some more comments and make it cleaner --- .../org/apache/spark/sql/DataFrameReader.scala | 7 +++---- .../sql/execution/datasources/csv/CSVSuite.scala | 16 +++++++++------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index a753c2dfb50c2..a5e38e25b1ec5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -409,7 +409,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val parsedOptions: CSVOptions = new CSVOptions( extraOptions.toMap, sparkSession.sessionState.conf.sessionLocalTimeZone) - val filteredLines: Dataset[String] = CSVUtils.filterCommentAndEmpty(csvDataset, parsedOptions) + val filteredLines: Dataset[String] = + CSVUtils.filterCommentAndEmpty(csvDataset, parsedOptions) val maybeFirstLine: Option[String] = filteredLines.take(1).headOption val schema = userSpecifiedSchema.getOrElse { @@ -423,9 +424,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => - filteredLines.rdd.mapPartitions { iter => - CSVUtils.filterHeaderLine(iter, firstLine, parsedOptions) - } + filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions)) }.getOrElse(filteredLines.rdd) val parsed = linesWithoutHeader.mapPartitions { iter => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index ed6d3903c8e98..3f5e0731a7c8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1106,16 +1106,18 @@ class CSVSuite extends QueryTest with SharedSQLContext with 
SQLTestUtils { } test("Empty file produces empty dataframe with empty schema - CSV string dataset") { + // Empty dataframe with schema. val emptyDF = spark.createDataFrame( spark.sparkContext.emptyRDD[Row], StructType(StructField("a", StringType) :: Nil)) - val df = spark.read.schema(emptyDF.schema).csv(spark.emptyDataset[String]) - assert(df.schema === emptyDF.schema) - checkAnswer(df, emptyDF) - - val emptyDFWithoutSchema = spark.read.csv(spark.emptyDataset[String]) - assert(emptyDFWithoutSchema.schema === spark.emptyDataFrame.schema) - checkAnswer(emptyDFWithoutSchema, spark.emptyDataFrame) + val df1 = spark.read.schema(emptyDF.schema).csv(spark.emptyDataset[String]) + assert(df1.schema === emptyDF.schema) + checkAnswer(df1, emptyDF) + + // Empty dataframe without schema. + val df2 = spark.read.csv(spark.emptyDataset[String]) + assert(df2.schema === spark.emptyDataFrame.schema) + checkAnswer(df2, spark.emptyDataFrame) } } From a14df70f2244462fd53f551f79ff8887950d6883 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 7 Mar 2017 21:06:33 +0900 Subject: [PATCH 6/8] Fix test title --- .../apache/spark/sql/execution/datasources/csv/CSVSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 3f5e0731a7c8e..a903c5342d75e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1105,7 +1105,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } - test("Empty file produces empty dataframe with empty schema - CSV string dataset") { + test("Empty file produces empty dataframe - CSV string dataset") { // Empty dataframe with schema. val emptyDF = spark.createDataFrame( spark.sparkContext.emptyRDD[Row], From a0a79dcf59365a7d0a017b87cb3b90a89a2bdd0e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 7 Mar 2017 21:07:36 +0900 Subject: [PATCH 7/8] Fix title --- .../apache/spark/sql/execution/datasources/csv/CSVSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index a903c5342d75e..34c5a514e12ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1105,7 +1105,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } - test("Empty file produces empty dataframe - CSV string dataset") { + test("Empty dataframe produces empty dataframe") { // Empty dataframe with schema. 
val emptyDF = spark.createDataFrame( spark.sparkContext.emptyRDD[Row], From 3f42c4cc7b93e32cb8d4f2517987097b73e733fd Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 8 Mar 2017 18:44:31 +0900 Subject: [PATCH 8/8] Addresss comment on the test --- .../execution/datasources/csv/CSVSuite.scala | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 34c5a514e12ab..4435e4df38ef6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1105,19 +1105,14 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } - test("Empty dataframe produces empty dataframe") { - // Empty dataframe with schema. - val emptyDF = spark.createDataFrame( - spark.sparkContext.emptyRDD[Row], - StructType(StructField("a", StringType) :: Nil)) - val df1 = spark.read.schema(emptyDF.schema).csv(spark.emptyDataset[String]) - assert(df1.schema === emptyDF.schema) - checkAnswer(df1, emptyDF) - - // Empty dataframe without schema. - val df2 = spark.read.csv(spark.emptyDataset[String]) - assert(df2.schema === spark.emptyDataFrame.schema) - checkAnswer(df2, spark.emptyDataFrame) + test("Empty string dataset produces empty dataframe and keep user-defined schema") { + val df1 = spark.read.csv(spark.emptyDataset[String]) + assert(df1.schema === spark.emptyDataFrame.schema) + checkAnswer(df1, spark.emptyDataFrame) + + val schema = StructType(StructField("a", StringType) :: Nil) + val df2 = spark.read.schema(schema).csv(spark.emptyDataset[String]) + assert(df2.schema === schema) } }
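
For reference, a minimal usage sketch of the API added by this patch series: reading a
`Dataset[String]` of CSV lines directly with `spark.read.csv(...)`. It assumes a running
SparkSession named `spark` with `spark.implicits._` in scope; the sample rows and column
names are illustrative only and do not come from the patches.

    import org.apache.spark.sql.Dataset
    import org.apache.spark.sql.types.{StringType, StructField, StructType}
    import spark.implicits._

    // Each element of the Dataset is one CSV record; the first line is a header.
    val csvLines: Dataset[String] = Seq(
      "year,make,model",
      "2012,Tesla,S",
      "1997,Ford,E350"
    ).toDS()

    // Parse the lines in place, without writing them out to files first.
    val cars = spark.read
      .option("header", "true")       // use the first non-comment line as column names
      .option("inferSchema", "true")  // pass over the data once to infer column types
      .csv(csvLines)

    // An empty string dataset yields an empty DataFrame, and a user-specified
    // schema is preserved (the test added in patch 8 covers this behaviour).
    val schema = StructType(StructField("a", StringType) :: Nil)
    val emptyWithSchema = spark.read.schema(schema).csv(spark.emptyDataset[String])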