62 commits
112ce2d
Checks column names are compatible to provided schema
MaxGekk Mar 20, 2018
a85ccce
Checking header is matched to schema in per-line mode
MaxGekk Mar 20, 2018
75e1534
Extract header and check that it is matched to schema
MaxGekk Mar 20, 2018
8eb45b8
Checking column names in header in multiLine mode
MaxGekk Mar 21, 2018
9b1a986
Adding the checkHeader option with true by default
MaxGekk Mar 21, 2018
6442633
Fix csv test by changing headers or disabling header checking
MaxGekk Mar 21, 2018
9440d8a
Adding comment for the checkHeader option
MaxGekk Mar 21, 2018
9f91ce7
Added comments
MaxGekk Mar 21, 2018
0878f7a
Adding a space between column names
MaxGekk Mar 21, 2018
a341dd7
Fix a test: checking name duplication in schemas
MaxGekk Mar 21, 2018
98c27ea
Fixing the test and adding ticket number to test's title
MaxGekk Mar 23, 2018
811df6f
Refactoring - removing unneeded parameter
MaxGekk Mar 23, 2018
691cfbc
Output filename in the exception
MaxGekk Mar 25, 2018
efb0105
PySpark: adding a test and checkHeader parameter
MaxGekk Mar 25, 2018
c9f5e14
Removing unneeded parameter - fileName
MaxGekk Mar 25, 2018
e195838
Fix for pycodestyle checks
MaxGekk Mar 25, 2018
d6d370d
Adding description of the checkHeader option
MaxGekk Mar 25, 2018
acd6d2e
Improving error messages and handling the case when header size is no…
MaxGekk Mar 26, 2018
13892fd
Refactoring: check header by calling an uniVocity method
MaxGekk Mar 26, 2018
476b517
Refactoring: convert val to def
MaxGekk Mar 26, 2018
f8167e4
Parse header only if the checkHeader option is true
MaxGekk Mar 26, 2018
d068f6c
Moving header checks to CSVDataSource
MaxGekk Mar 26, 2018
08cfcf4
Making uniVocity wrapper unaware of header
MaxGekk Mar 26, 2018
f6a1694
Fix the test: error mesage was changed
MaxGekk Mar 27, 2018
adbedf3
Revert CSV tests as it was before the option was introduced
MaxGekk Mar 31, 2018
0904daf
Renaming checkHeader to enforceSchema
MaxGekk Mar 31, 2018
191b415
Pass required parameter
MaxGekk Apr 1, 2018
718f7ca
Merge branch 'master' of github.com:apache/spark into check-column-names
MaxGekk Apr 5, 2018
75c1ce6
Merge remote-tracking branch 'origin/master' into check-column-names
MaxGekk Apr 13, 2018
ab9c514
Addressing Xiao Li's review comments
MaxGekk Apr 13, 2018
0405863
Making header validation case sensitive
MaxGekk Apr 13, 2018
714c66d
Describing enforceSchema in PySpark's csv method
MaxGekk Apr 13, 2018
78d9f66
Respect to caseSensitive parameter
MaxGekk Apr 13, 2018
b43a7c7
Check header on csv parsing from dataset of strings
MaxGekk Apr 13, 2018
a5f2916
Merge branch 'master' of github.com:apache/spark into check-column-names
MaxGekk Apr 18, 2018
9b2d403
Make Scala style checker happy
MaxGekk Apr 18, 2018
1fffc16
Merge remote-tracking branch 'origin/master' into check-column-names
MaxGekk Apr 27, 2018
ad6cda4
Merge remote-tracking branch 'origin/master' into check-column-names
MaxGekk May 1, 2018
4bdabe2
Merge branch 'master' of github.com:apache/spark into check-column-names
MaxGekk May 4, 2018
2bd2713
Merge branch 'master' into check-column-names
MaxGekk May 14, 2018
b4bfd1d
Merge branch 'check-column-names' of github.com:MaxGekk/spark-1 into …
MaxGekk May 14, 2018
21f8b10
Removing a space to make Scala style checker happy.
MaxGekk May 14, 2018
aca4db9
Merge branch 'master' of github.com:apache/spark into check-column-names
MaxGekk May 16, 2018
e3b4275
Addressing review comments
MaxGekk May 16, 2018
d704766
Removing unnecessary empty checks
MaxGekk May 17, 2018
04199e0
Addressing review comments
MaxGekk May 17, 2018
d5fde52
Merge remote-tracking branch 'origin/master' into check-column-names
MaxGekk May 17, 2018
795a878
Addressing Hyukjin Kwon's review comments
MaxGekk May 17, 2018
05fc7cd
Improving description of the option
MaxGekk May 18, 2018
9606711
Merge remote-tracking branch 'origin/master' into check-column-names
MaxGekk May 18, 2018
11c7591
Addressing Wenchen Fan's review comment
MaxGekk May 18, 2018
7dce1e7
Merge remote-tracking branch 'origin/master' into check-column-names
MaxGekk May 22, 2018
c008328
Output warnings when enforceSchema is enabled and the schema is not c…
MaxGekk May 25, 2018
9f7c440
Added tests for inferSchema is true and enforceSchema is false
MaxGekk May 25, 2018
e83ad60
Rename dropFirstRecord to shouldDropHeader
MaxGekk May 25, 2018
26ae4f9
Merge remote-tracking branch 'origin/master' into check-column-names
MaxGekk May 25, 2018
4b6495b
Merge remote-tracking branch 'origin/master' into check-column-names
MaxGekk Jun 1, 2018
c5ee207
Renaming of 'is not conform' to 'does not conform'
MaxGekk Jun 1, 2018
a2cbb7b
Fix Scala coding style
MaxGekk Jun 1, 2018
70e2b75
Added description of checkHeaderColumnNames's arguments
MaxGekk Jun 1, 2018
e7c3ace
Test checks a warning presents in logs
MaxGekk Jun 1, 2018
3b37712
fix python tests
MaxGekk Jun 1, 2018
15 changes: 13 additions & 2 deletions python/pyspark/sql/readwriter.py
@@ -346,7 +346,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
samplingRatio=None):
samplingRatio=None, enforceSchema=None):
"""Loads a CSV file and returns the result as a :class:`DataFrame`.

This function will go through the input once to determine the input schema if
@@ -373,6 +373,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
default value, ``false``.
:param inferSchema: infers the input schema automatically from data. It requires one extra
pass over the data. If None is set, it uses the default value, ``false``.
:param enforceSchema: If it is set to ``true``, the specified or inferred schema will be

[Inline review]
Member: can you add this option to streaming reader and writer?
Member (Author): added

forcibly applied to datasource files, and headers in CSV files will be
ignored. If the option is set to ``false``, the schema will be
validated against all headers in CSV files or the first header in RDD
if the ``header`` option is set to ``true``. Field names in the schema
and column names in CSV headers are checked by their positions
taking into account ``spark.sql.caseSensitive``. If None is set,
``true`` is used by default. Though the default value is ``true``,
it is recommended to disable the ``enforceSchema`` option
to avoid incorrect results.
:param ignoreLeadingWhiteSpace: A flag indicating whether or not leading whitespaces from
values being read should be skipped. If None is set, it
uses the default value, ``false``.
@@ -449,7 +459,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
maxCharsPerColumn=maxCharsPerColumn,
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio)
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio,
enforceSchema=enforceSchema)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
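The ``enforceSchema`` parameter documented above can be exercised end to end with a short PySpark sketch. This is not part of the diff; the temporary path and app name are made up, and it assumes a Spark build that contains this change.

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType

spark = SparkSession.builder.appName("enforce-schema-demo").getOrCreate()
path = "/tmp/enforce_schema_demo"  # hypothetical location

# Write a small CSV file whose header line is "f1,f2".
spark.createDataFrame([(1, 1000), (2000, 2)], ["f1", "f2"]) \
    .write.mode("overwrite").option("header", True).csv(path)

# Field names are deliberately swapped relative to the header.
schema = StructType([
    StructField("f2", IntegerType()),
    StructField("f1", IntegerType())])

# Default (enforceSchema is true): the header is ignored and the schema is
# applied positionally; a mismatch is only reported in the logs.
spark.read.csv(path, header=True, schema=schema).show()

# enforceSchema=False: the header is validated against the schema by position,
# so the action fails with "CSV header does not conform to the schema".
try:
    spark.read.csv(path, header=True, schema=schema, enforceSchema=False).show()
except Exception as error:
    print(error)

The second read fails only when an action runs, which is also what the new test in python/pyspark/sql/tests.py below relies on.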
15 changes: 13 additions & 2 deletions python/pyspark/sql/streaming.py
@@ -564,7 +564,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None):
columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
enforceSchema=None):
"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.

This function will go through the input once to determine the input schema if
@@ -592,6 +593,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
default value, ``false``.
:param inferSchema: infers the input schema automatically from data. It requires one extra
pass over the data. If None is set, it uses the default value, ``false``.
:param enforceSchema: If it is set to ``true``, the specified or inferred schema will be
forcibly applied to datasource files, and headers in CSV files will be
ignored. If the option is set to ``false``, the schema will be
validated against all headers in CSV files or the first header in RDD
if the ``header`` option is set to ``true``. Field names in the schema
and column names in CSV headers are checked by their positions
taking into account ``spark.sql.caseSensitive``. If None is set,
``true`` is used by default. Though the default value is ``true``,
it is recommended to disable the ``enforceSchema`` option
to avoid incorrect results.
:param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from
values being read should be skipped. If None is set, it
uses the default value, ``false``.
@@ -664,7 +675,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
maxCharsPerColumn=maxCharsPerColumn,
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping)
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema)
if isinstance(path, basestring):
return self._df(self._jreader.csv(path))
else:
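The same option on the streaming reader, again as a rough sketch outside this diff (the directory name is hypothetical; streaming file sources always need an explicit schema):

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType

spark = SparkSession.builder.appName("enforce-schema-stream-demo").getOrCreate()

schema = StructType([
    StructField("f1", IntegerType()),
    StructField("f2", IntegerType())])

# With enforceSchema=False the header of every CSV file picked up from the
# monitored directory is checked against the user-provided schema.
stream_df = spark.readStream.csv("/tmp/enforce_schema_stream", schema=schema,
                                 header=True, enforceSchema=False)

# A file whose header does not match the schema will fail the query after it
# starts, e.g. stream_df.writeStream.format("console").start().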
18 changes: 18 additions & 0 deletions python/pyspark/sql/tests.py
@@ -3056,6 +3056,24 @@ def test_csv_sampling_ratio(self):
.csv(rdd, samplingRatio=0.5).schema
self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)]))

def test_checking_csv_header(self):
path = tempfile.mkdtemp()
shutil.rmtree(path)
try:
self.spark.createDataFrame([[1, 1000], [2000, 2]])\
.toDF('f1', 'f2').write.option("header", "true").csv(path)
schema = StructType([
StructField('f2', IntegerType(), nullable=True),
StructField('f1', IntegerType(), nullable=True)])
df = self.spark.read.option('header', 'true').schema(schema)\
.csv(path, enforceSchema=False)
self.assertRaisesRegexp(
Exception,
"CSV header does not conform to the schema",
lambda: df.collect())
finally:
shutil.rmtree(path)


class HiveSparkSubmitTests(SparkSubmitTests):

19 changes: 19 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -22,6 +22,7 @@ import java.util.{Locale, Properties}
import scala.collection.JavaConverters._

import com.fasterxml.jackson.databind.ObjectMapper
import com.univocity.parsers.csv.CsvParser

import org.apache.spark.Partition
import org.apache.spark.annotation.InterfaceStability
@@ -474,6 +475,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* it determines the columns as string types and it reads only the first line to determine the
* names and the number of fields.
*
 * If `enforceSchema` is set to `false`, only the CSV header in the first line is checked
 * for conformance to the specified or inferred schema.
*
* @param csvDataset input Dataset with one CSV row per record
* @since 2.2.0
*/
@@ -499,6 +503,13 @@
StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))

val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
CSVDataSource.checkHeader(
firstLine,
new CsvParser(parsedOptions.asParserSettings),
actualSchema,
csvDataset.getClass.getCanonicalName,
parsedOptions.enforceSchema,
sparkSession.sessionState.conf.caseSensitiveAnalysis)
filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
}.getOrElse(filteredLines.rdd)

@@ -539,6 +550,13 @@
* <li>`comment` (default empty string): sets a single character used for skipping lines
* beginning with this character. By default, it is disabled.</li>
* <li>`header` (default `false`): uses the first line as names of columns.</li>
* <li>`enforceSchema` (default `true`): If it is set to `true`, the specified or inferred schema
* will be forcibly applied to datasource files, and headers in CSV files will be ignored.
 * If the option is set to `false`, the schema will be validated against all headers in CSV files
 * when the `header` option is set to `true`. Field names in the schema
 * and column names in CSV headers are checked by their positions taking into account
 * `spark.sql.caseSensitive`. Though the default value is `true`, it is recommended to disable
* the `enforceSchema` option to avoid incorrect results.</li>
* <li>`inferSchema` (default `false`): infers the input schema automatically from data. It
* requires one extra pass over the data.</li>
* <li>`samplingRatio` (default is 1.0): defines fraction of rows used for schema inferring.</li>
@@ -583,6 +601,7 @@
* created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
* <li>`multiLine` (default `false`): parse one record, which may span multiple lines.</li>
* </ul>
*
* @since 2.0.0
*/
@scala.annotation.varargs
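For the `csv(csvDataset)` overload touched above, the PySpark equivalent goes through an RDD of CSV strings. A small sketch (values invented, not part of this diff) of how the first element is treated as the header and checked when ``enforceSchema`` is disabled:

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType

spark = SparkSession.builder.getOrCreate()

# An RDD of CSV lines; the first element is the header.
lines = spark.sparkContext.parallelize(["f1,f2", "1,1000", "2000,2"])

schema = StructType([
    StructField("f1", IntegerType()),
    StructField("f2", IntegerType())])

# With header=True the first line is treated as the header; with
# enforceSchema=False it is checked against the schema before being dropped.
spark.read.csv(lines, schema=schema, header=True, enforceSchema=False).show()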
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -30,6 +30,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat

import org.apache.spark.TaskContext
import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.{BinaryFileRDD, RDD}
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
@@ -50,7 +51,10 @@ abstract class CSVDataSource extends Serializable {
conf: Configuration,
file: PartitionedFile,
parser: UnivocityParser,
schema: StructType): Iterator[InternalRow]
requiredSchema: StructType,
// Actual schema of data in the csv file
dataSchema: StructType,
caseSensitive: Boolean): Iterator[InternalRow]

/**
* Infers the schema from `inputPaths` files.
@@ -110,14 +114,92 @@
}
}

object CSVDataSource {
object CSVDataSource extends Logging {
def apply(options: CSVOptions): CSVDataSource = {
if (options.multiLine) {
MultiLineCSVDataSource
} else {
TextInputCSVDataSource
}
}

/**
* Checks that column names in a CSV header and field names in the schema are the same
* by taking into account case sensitivity.
*
* @param schema - provided (or inferred) schema to which CSV must conform.
 * @param columnNames - names of CSV columns that must be checked against the schema.
 * @param fileName - name of the CSV file that is currently being checked. It is used in error messages.
 * @param enforceSchema - if it is `true`, column names are ignored; otherwise the CSV column
 *                        names are checked for conformance to the schema. If the column
 *                        names don't conform to the schema, an exception is thrown.
* @param caseSensitive - if it is set to `false`, comparison of column names and schema field
* names is not case sensitive.
*/
def checkHeaderColumnNames(
schema: StructType,
columnNames: Array[String],
fileName: String,
enforceSchema: Boolean,
caseSensitive: Boolean): Unit = {
if (columnNames != null) {
val fieldNames = schema.map(_.name).toIndexedSeq
val (headerLen, schemaSize) = (columnNames.size, fieldNames.length)
var errorMessage: Option[String] = None

if (headerLen == schemaSize) {
var i = 0
while (errorMessage.isEmpty && i < headerLen) {
var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i))
if (!caseSensitive) {
nameInSchema = nameInSchema.toLowerCase
nameInHeader = nameInHeader.toLowerCase
}
if (nameInHeader != nameInSchema) {
errorMessage = Some(
s"""|CSV header does not conform to the schema.
| Header: ${columnNames.mkString(", ")}
| Schema: ${fieldNames.mkString(", ")}
|Expected: ${fieldNames(i)} but found: ${columnNames(i)}
|CSV file: $fileName""".stripMargin)
}
i += 1
}
} else {
errorMessage = Some(
s"""|Number of column in CSV header is not equal to number of fields in the schema:
| Header length: $headerLen, schema size: $schemaSize
|CSV file: $fileName""".stripMargin)
}

errorMessage.foreach { msg =>
if (enforceSchema) {
logWarning(msg)
} else {
throw new IllegalArgumentException(msg)
}
}
}
}

/**
* Checks that CSV header contains the same column names as fields names in the given schema
* by taking into account case sensitivity.
*/
def checkHeader(
header: String,
parser: CsvParser,
schema: StructType,
fileName: String,
enforceSchema: Boolean,
caseSensitive: Boolean): Unit = {
checkHeaderColumnNames(
schema,
parser.parseLine(header),
fileName,
enforceSchema,
caseSensitive)
}
}

object TextInputCSVDataSource extends CSVDataSource {
@@ -127,7 +209,9 @@
conf: Configuration,
file: PartitionedFile,
parser: UnivocityParser,
schema: StructType): Iterator[InternalRow] = {
requiredSchema: StructType,
dataSchema: StructType,
caseSensitive: Boolean): Iterator[InternalRow] = {
val lines = {
val linesReader = new HadoopFileLinesReader(file, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
@@ -136,8 +220,24 @@
}
}

val shouldDropHeader = parser.options.headerFlag && file.start == 0
UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema)
val hasHeader = parser.options.headerFlag && file.start == 0
if (hasHeader) {
      // Check that column names in the header match the field names of the schema.
      // The header will be removed from lines.
      // Note: if there are only comments in the first block, the header might not
      // be extracted.
CSVUtils.extractHeader(lines, parser.options).foreach { header =>
CSVDataSource.checkHeader(
header,
parser.tokenizer,
dataSchema,
file.filePath,
parser.options.enforceSchema,
caseSensitive)
}
}

UnivocityParser.parseIterator(lines, parser, requiredSchema)
}

override def infer(
@@ -206,12 +306,24 @@ object MultiLineCSVDataSource extends CSVDataSource {
conf: Configuration,
file: PartitionedFile,
parser: UnivocityParser,
schema: StructType): Iterator[InternalRow] = {
requiredSchema: StructType,
dataSchema: StructType,
caseSensitive: Boolean): Iterator[InternalRow] = {
def checkHeader(header: Array[String]): Unit = {
CSVDataSource.checkHeaderColumnNames(
dataSchema,
header,
file.filePath,
parser.options.enforceSchema,
caseSensitive)
}

UnivocityParser.parseStream(
CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))),
parser.options.headerFlag,
parser,
schema)
requiredSchema,
checkHeader)
}

override def infer(
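Since `readFile` now threads `caseSensitive` from `spark.sql.caseSensitive` into the header check, a quick way to see the effect from the user side is the sketch below (not part of this diff; the path is made up):

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType

spark = SparkSession.builder.getOrCreate()
path = "/tmp/case_sensitive_header_demo"  # hypothetical location

# The CSV header is "A,B" while the schema uses lower-case names.
spark.createDataFrame([(1, 2)], ["A", "B"]) \
    .write.mode("overwrite").option("header", True).csv(path)

schema = StructType([
    StructField("a", IntegerType()),
    StructField("b", IntegerType())])

# Case-insensitive analysis (the default): "A" matches "a", the check passes.
spark.conf.set("spark.sql.caseSensitive", "false")
spark.read.csv(path, header=True, schema=schema, enforceSchema=False).show()

# Case-sensitive analysis: the same header no longer conforms to the schema.
spark.conf.set("spark.sql.caseSensitive", "true")
try:
    spark.read.csv(path, header=True, schema=schema, enforceSchema=False).show()
except Exception as error:
    print(error)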
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -130,14 +130,21 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
"df.filter($\"_corrupt_record\".isNotNull).count()."
)
}
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis

(file: PartitionedFile) => {
val conf = broadcastedHadoopConf.value.value
val parser = new UnivocityParser(
StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
parsedOptions)
CSVDataSource(parsedOptions).readFile(conf, file, parser, requiredSchema)
CSVDataSource(parsedOptions).readFile(
conf,
file,
parser,
requiredSchema,
dataSchema,
caseSensitive)
}
}

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -156,6 +156,12 @@ class CSVOptions(
val samplingRatio =
parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)

/**
* Forcibly apply the specified or inferred schema to datasource files.
* If the option is enabled, headers of CSV files will be ignored.
*/
val enforceSchema = getBool("enforceSchema", default = true)

def asWriterSettings: CsvWriterSettings = {
val writerSettings = new CsvWriterSettings()
val format = writerSettings.getFormat
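Finally, with `enforceSchema` left at its default of `true` (as defined here in CSVOptions), a mismatched header is only reported via `logWarning` rather than failing the read. A hedged sketch of that default behaviour, outside this diff, with a hypothetical path:

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType

spark = SparkSession.builder.getOrCreate()
path = "/tmp/enforce_schema_default_demo"  # hypothetical location

spark.createDataFrame([(1, 1000)], ["f1", "f2"]) \
    .write.mode("overwrite").option("header", True).csv(path)

# Field names are swapped relative to the header on purpose.
schema = StructType([
    StructField("f2", IntegerType()),
    StructField("f1", IntegerType())])

# The option can also be passed through .option(); with the default of true the
# schema wins over the header, columns are bound by position, and the message
# "CSV header does not conform to the schema" only appears in the logs.
spark.read.option("header", True).option("enforceSchema", True) \
    .schema(schema).csv(path).show()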