[SPARK-23786][SQL] Checking column names of csv headers #20894
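This PR adds header validation to the CSV datasource: when `header` is `true` and the user supplies a schema, the column names in the file's header are checked against the schema's field names, controlled by a new `checkHeader` option (default `true`). A minimal sketch of the user-facing behavior, assuming a `spark-shell` session; the path is a placeholder:

```scala
import org.apache.spark.sql.types.{DoubleType, StructType}
import spark.implicits._

// Write a CSV file whose header line is "f1,f2".
val path = "/tmp/csv-header-check"  // hypothetical location, for illustration only
Seq((1.0, 1234.5)).toDF("f1", "f2").write.option("header", "true").csv(path)

// Read it back with the field names swapped in the schema. With the check
// enabled (the default after this PR), collect() fails with
// "Fields in the header of csv file are not matched to field names of the schema".
val schema = new StructType().add("f2", DoubleType).add("f1", DoubleType)
spark.read.schema(schema).option("header", "true").csv(path).collect()
```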
python/pyspark/sql/tests.py

```diff
@@ -2974,6 +2974,21 @@ def test_create_dateframe_from_pandas_with_dst(self):
         os.environ['TZ'] = orig_env_tz
         time.tzset()
 
+    def test_checking_csv_header(self):
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        self.spark.createDataFrame([[1, 1000], [2000, 2]]).\
+            toDF('f1', 'f2').write.option("header", "true").csv(tmpPath)
+        schema = StructType([
+            StructField('f2', IntegerType(), nullable=True),
+            StructField('f1', IntegerType(), nullable=True)])
+        df = self.spark.read.option('header', 'true').schema(schema).csv(tmpPath)
+        self.assertRaisesRegexp(
+            Exception,
+            "Fields in the header of csv file are not matched to field names of the schema",
+            lambda: df.collect())
+        shutil.rmtree(tmpPath)
+
 
 class HiveSparkSubmitTests(SparkSubmitTests):
```
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala

```diff
@@ -50,7 +50,9 @@ abstract class CSVDataSource extends Serializable {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      schema: StructType): Iterator[InternalRow]
+      schema: StructType, // Schema of the projection
+      dataSchema: StructType // Schema of the data in the CSV files
+  ): Iterator[InternalRow]
 
   /**
    * Infers the schema from `inputPaths` files.
@@ -127,7 +129,8 @@ object TextInputCSVDataSource extends CSVDataSource {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      schema: StructType): Iterator[InternalRow] = {
+      schema: StructType,
+      dataSchema: StructType): Iterator[InternalRow] = {
     val lines = {
       val linesReader = new HadoopFileLinesReader(file, conf)
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
@@ -136,8 +139,22 @@ object TextInputCSVDataSource extends CSVDataSource {
       }
     }
 
-    val shouldDropHeader = parser.options.headerFlag && file.start == 0
-    UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema)
+    val hasHeader = parser.options.headerFlag && file.start == 0
+    if (hasHeader) {
+      // Check that the column names in the header match the field names of the schema.
+      // The header is removed from `lines` as a side effect.
+      // Note: if the first block contains only comments, the header might not be extracted.
+      val checkHeader = UnivocityParser.checkHeader(
+        parser,
+        dataSchema,
+        _: String,
+        file.filePath
+      )
+      CSVUtils.extractHeader(lines, parser.options).foreach(checkHeader(_))
+    }
+
+    UnivocityParser.parseIterator(lines, parser, schema)
   }
 
   override def infer(
@@ -204,24 +221,35 @@ object MultiLineCSVDataSource extends CSVDataSource {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      schema: StructType): Iterator[InternalRow] = {
+      schema: StructType,
+      dataSchema: StructType): Iterator[InternalRow] = {
+    val checkHeader = UnivocityParser.checkHeaderColumnNames(
+      parser,
+      dataSchema,
+      _: Array[String],
+      file.filePath
+    )
     UnivocityParser.parseStream(
       CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))),
       parser.options.headerFlag,
       parser,
-      schema)
+      schema,
+      checkHeader)
   }
 
   override def infer(
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): StructType = {
     val csv = createBaseRdd(sparkSession, inputPaths, parsedOptions)
+    // The header is not checked here because there is no schema against which it could be checked.
+    val checkHeader = (_: Array[String]) => ()
     csv.flatMap { lines =>
       val path = new Path(lines.getPath())
       UnivocityParser.tokenizeStream(
         CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, path),
         shouldDropHeader = false,
+        checkHeader,
         new CsvParser(parsedOptions.asParserSettings))
     }.take(1).headOption match {
       case Some(firstRow) =>
@@ -233,6 +261,7 @@ object MultiLineCSVDataSource extends CSVDataSource {
             lines.getConfiguration,
             new Path(lines.getPath())),
           parsedOptions.headerFlag,
+          checkHeader,
           new CsvParser(parsedOptions.asParserSettings))
       }
       CSVInferSchema.infer(tokenRDD, header, parsedOptions)
```
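The `_: String` argument in `UnivocityParser.checkHeader(parser, dataSchema, _: String, file.filePath)` above is Scala partial application: every argument except the header line is fixed, yielding a `String => Unit` that is later applied to the header extracted by `CSVUtils.extractHeader`. A self-contained sketch of the pattern, using hypothetical names that are not part of the PR:

```scala
// check validates a raw header line against expected column names.
def check(expected: Seq[String], header: String, fileName: String): Unit = {
  val columns = header.split(",").toSeq
  require(columns == expected,
    s"Header of $fileName does not match: got $columns, expected $expected")
}

// Fixing all arguments except `header` produces a function value,
// just like the `checkHeader` value in the diff above.
val checkHeader: String => Unit = check(Seq("f1", "f2"), _: String, "cars.csv")

checkHeader("f1,f2")    // passes
// checkHeader("f2,f1") // would throw IllegalArgumentException
```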
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala

```diff
@@ -150,6 +150,12 @@ class CSVOptions(
 
   val isCommentSet = this.comment != '\u0000'
 
+  /**
+   * Enables checking of headers in CSV files. In particular, column names in a header
+   * are matched against field names of the provided schema.
+   */
+  val checkHeader = getBool("checkHeader", true)
+
   def asWriterSettings: CsvWriterSettings = {
     val writerSettings = new CsvWriterSettings()
     val format = writerSettings.getFormat
```
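Based on the tests later in this PR, the new option can be passed both through `DataFrameReader` and through `OPTIONS` in SQL. A sketch assuming a running `spark` session; the file paths are placeholders:

```scala
// Disable the header check when the supplied column names deliberately
// differ from the names in the file's header.
val cars = spark.read
  .option("header", "true")
  .option("checkHeader", "false")
  .option("inferSchema", "true")
  .csv("/path/to/cars.csv")

// The same option in SQL; the declared column names need not match the header.
spark.sql(
  """CREATE TEMPORARY VIEW carsTable
    |(yearMade double, makeName string, modelName string)
    |USING csv
    |OPTIONS (path "/path/to/cars.tsv", header "true", checkHeader "false", delimiter "\t")
  """.stripMargin)
```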
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala

```diff
@@ -237,8 +237,9 @@ private[csv] object UnivocityParser {
   def tokenizeStream(
       inputStream: InputStream,
       shouldDropHeader: Boolean,
+      checkHeader: Array[String] => Unit,
       tokenizer: CsvParser): Iterator[Array[String]] = {
-    convertStream(inputStream, shouldDropHeader, tokenizer)(tokens => tokens)
+    convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader)(tokens => tokens)
   }
 
   /**
@@ -248,26 +249,30 @@ private[csv] object UnivocityParser {
       inputStream: InputStream,
       shouldDropHeader: Boolean,
       parser: UnivocityParser,
-      schema: StructType): Iterator[InternalRow] = {
+      schema: StructType,
+      checkHeader: Array[String] => Unit): Iterator[InternalRow] = {
     val tokenizer = parser.tokenizer
     val safeParser = new FailureSafeParser[Array[String]](
       input => Seq(parser.convert(input)),
       parser.options.parseMode,
       schema,
       parser.options.columnNameOfCorruptRecord)
-    convertStream(inputStream, shouldDropHeader, tokenizer) { tokens =>
+    convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) { tokens =>
       safeParser.parse(tokens)
     }.flatten
   }
 
   private def convertStream[T](
       inputStream: InputStream,
       shouldDropHeader: Boolean,
-      tokenizer: CsvParser)(convert: Array[String] => T) = new Iterator[T] {
+      tokenizer: CsvParser,
+      checkHeader: Array[String] => Unit
+  )(convert: Array[String] => T) = new Iterator[T] {
     tokenizer.beginParsing(inputStream)
     private var nextRecord = {
       if (shouldDropHeader) {
-        tokenizer.parseNext()
+        val header = tokenizer.parseNext()
+        checkHeader(header)
       }
       tokenizer.parseNext()
     }
@@ -289,27 +294,52 @@ private[csv] object UnivocityParser {
    */
   def parseIterator(
       lines: Iterator[String],
-      shouldDropHeader: Boolean,
       parser: UnivocityParser,
       schema: StructType): Iterator[InternalRow] = {
     val options = parser.options
 
-    val linesWithoutHeader = if (shouldDropHeader) {
-      // Note that if there are only comments in the first block, the header would probably
-      // be not dropped.
-      CSVUtils.dropHeaderLine(lines, options)
-    } else {
-      lines
-    }
-
     val filteredLines: Iterator[String] =
-      CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options)
+      CSVUtils.filterCommentAndEmpty(lines, options)
 
     val safeParser = new FailureSafeParser[String](
       input => Seq(parser.parse(input)),
       parser.options.parseMode,
       schema,
       parser.options.columnNameOfCorruptRecord)
 
     filteredLines.flatMap(safeParser.parse)
   }
 
+  def checkHeaderColumnNames(
+      parser: UnivocityParser,
+      schema: StructType,
+      columnNames: Array[String],
+      fileName: String
+  ): Unit = {
+    if (parser.options.checkHeader && columnNames != null) {
+      val fieldNames = schema.map(_.name)
+      val isMatched = fieldNames.zip(columnNames).forall { pair =>
+        val (nameInSchema, nameInHeader) = pair
+        nameInSchema == nameInHeader
+      }
+      if (!isMatched) {
+        throw new IllegalArgumentException(
+          s"""|Fields in the header of csv file are not matched to field names of the schema:
+              | Header: ${columnNames.mkString(", ")}
+              | Schema: ${fieldNames.mkString(", ")}
+              |CSV file: $fileName""".stripMargin
+        )
+      }
+    }
+  }
+
+  def checkHeader(
+      parser: UnivocityParser,
+      schema: StructType,
+      header: String,
+      fileName: String
+  ): Unit = {
+    lazy val columnNames = parser.tokenizer.parseLine(header)
+    checkHeaderColumnNames(parser, schema, columnNames, fileName)
+  }
 }
```

Review comment (on the `shouldDropHeader` parameter of `convertStream`): BTW, why did we rename this variable?

Author's reply: To show what exactly the parameter controls: dropping the first record in the stream. Responsibility for header manipulation belongs to a higher level.
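One detail of `checkHeaderColumnNames` worth noting: names are compared positionally via `zip`, so a header with more columns than the schema (or vice versa) passes the check as long as the overlapping prefix matches. A standalone sketch of that rule, illustrative only:

```scala
// Positional comparison of schema field names and header column names,
// mirroring the zip/forall logic in checkHeaderColumnNames above.
def namesMatch(fieldNames: Seq[String], columnNames: Seq[String]): Boolean =
  fieldNames.zip(columnNames).forall { case (inSchema, inHeader) => inSchema == inHeader }

println(namesMatch(Seq("f1", "f2"), Seq("f1", "f2")))        // true
println(namesMatch(Seq("f1", "f2"), Seq("f2", "f1")))        // false: order matters
println(namesMatch(Seq("f1", "f2"), Seq("f1", "f2", "f3")))  // true: zip ignores the extra column
```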
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala

```diff
@@ -252,7 +252,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
          |(yearMade double, makeName string, modelName string, priceTag decimal,
          | comments string, grp string)
          |USING csv
-         |OPTIONS (path "${testFile(carsTsvFile)}", header "true", delimiter "\t")
+         |OPTIONS (
+         |  path "${testFile(carsTsvFile)}",
+         |  header "true", checkHeader "false",
+         |  delimiter "\t"
+         |)
       """.stripMargin.replaceAll("\n", " "))
 
     assert(
@@ -275,7 +279,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   test("test for blank column names on read and select columns") {
     val cars = spark.read
       .format("csv")
-      .options(Map("header" -> "true", "inferSchema" -> "true"))
+      .options(Map("header" -> "true", "checkHeader" -> "false", "inferSchema" -> "true"))
       .load(testFile(carsBlankColName))
 
     assert(cars.select("customer").collect().size == 2)
@@ -348,15 +352,15 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     spark.sql(
       s"""
          |CREATE TEMPORARY VIEW carsTable
-         |(yearMade double, makeName string, modelName string, comments string, blank string)
+         |(year double, make string, model string, comment string, blank string)
          |USING csv
          |OPTIONS (path "${testFile(carsFile)}", header "true")
       """.stripMargin.replaceAll("\n", " "))
 
     val cars = spark.table("carsTable")
     verifyCars(cars, withHeader = true, checkHeader = false, checkValues = false)
     assert(
-      cars.schema.fieldNames === Array("yearMade", "makeName", "modelName", "comments", "blank"))
+      cars.schema.fieldNames === Array("year", "make", "model", "comment", "blank"))
   }
 }
@@ -1279,4 +1283,30 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       Row("0,2013-111-11 12:13:14") :: Row(null) :: Nil
     )
   }
+
+  def checkHeader(multiLine: String): Unit = {
+    test(s"SPARK-23786: Checking column names against schema ($multiLine)") {
+      withTempPath { path =>
+        import collection.JavaConverters._
+
+        val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType)
+        val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema)
+        odf.write.option("header", "true").csv(path.getCanonicalPath)
+        val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType)
+        val exception = intercept[SparkException] {
+          spark.read
+            .schema(ischema)
+            .option("multiLine", multiLine)
+            .option("header", "true")
+            .option("checkHeader", "true")
+            .csv(path.getCanonicalPath)
+            .collect()
+        }
+        assert(exception.getMessage.contains(
+          "Fields in the header of csv file are not matched to field names of the schema"
+        ))
+      }
+    }
+  }
+
+  List("false", "true").foreach(checkHeader(_))
 }
```
Review comment: try-except