18 changes: 6 additions & 12 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -505,20 +505,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
val actualSchema =
StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))

-val linesWithoutHeader = if (parsedOptions.headerFlag && maybeFirstLine.isDefined) {
-val firstLine = maybeFirstLine.get
-val parser = new CsvParser(parsedOptions.asParserSettings)
-val columnNames = parser.parseLine(firstLine)
-CSVDataSource.checkHeaderColumnNames(
+val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
+val headerChecker = new CSVHeaderChecker(
actualSchema,
-columnNames,
-csvDataset.getClass.getCanonicalName,
-parsedOptions.enforceSchema,
-sparkSession.sessionState.conf.caseSensitiveAnalysis)
+parsedOptions,
+source = s"CSV source: $csvDataset")
+headerChecker.checkHeaderColumnNames(firstLine)
filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
-} else {
-filteredLines.rdd
-}
+}.getOrElse(filteredLines.rdd)
Member:
It is not directly related to your changes. Just in case, why do we convert Dataset to RDD here?

Member (Author):
I don't exactly remember. Looks like we can change it to Dataset.
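A rough sketch of that idea (not part of this PR): the header-skipping step could stay on the Dataset API instead of dropping to an RDD. `filteredLines`, `maybeFirstLine`, `parsedOptions` and `CSVUtils.filterHeaderLine` are the existing values from the surrounding method, and `org.apache.spark.sql.Encoders` would need to be imported.

```scala
// Hypothetical alternative: keep the header-skipping step as a Dataset
// transformation (header validation via CSVHeaderChecker omitted for brevity).
val linesWithoutHeader: Dataset[String] = maybeFirstLine.map { firstLine =>
  filteredLines.mapPartitions(
    CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))(Encoders.STRING)
}.getOrElse(filteredLines)
```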


val parsed = linesWithoutHeader.mapPartitions { iter =>
val rawParser = new UnivocityParser(actualSchema, parsedOptions)
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -51,11 +51,8 @@ abstract class CSVDataSource extends Serializable {
conf: Configuration,
file: PartitionedFile,
parser: UnivocityParser,
-requiredSchema: StructType,
-// Actual schema of data in the csv file
-dataSchema: StructType,
-caseSensitive: Boolean,
-columnPruning: Boolean): Iterator[InternalRow]
+headerChecker: CSVHeaderChecker,
+requiredSchema: StructType): Iterator[InternalRow]

/**
* Infers the schema from `inputPaths` files.
@@ -75,48 +72,6 @@
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: CSVOptions): StructType

-/**
-* Generates a header from the given row which is null-safe and duplicate-safe.
-*/
-protected def makeSafeHeader(
-row: Array[String],
-caseSensitive: Boolean,
-options: CSVOptions): Array[String] = {
-if (options.headerFlag) {
-val duplicates = {
-val headerNames = row.filter(_ != null)
-// scalastyle:off caselocale
-.map(name => if (caseSensitive) name else name.toLowerCase)
-// scalastyle:on caselocale
-headerNames.diff(headerNames.distinct).distinct
-}
-
-row.zipWithIndex.map { case (value, index) =>
-if (value == null || value.isEmpty || value == options.nullValue) {
-// When there are empty strings or the values set in `nullValue`, put the
-// index as the suffix.
-s"_c$index"
-// scalastyle:off caselocale
-} else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
-// scalastyle:on caselocale
-// When there are case-insensitive duplicates, put the index as the suffix.
-s"$value$index"
-} else if (duplicates.contains(value)) {
-// When there are duplicates, put the index as the suffix.
-s"$value$index"
-} else {
-value
-}
-}
-} else {
-row.zipWithIndex.map { case (_, index) =>
-// Uses default column names, "_c#" where # is its position of fields
-// when header option is disabled.
-s"_c$index"
-}
-}
-}
}

object CSVDataSource extends Logging {
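For illustration only (my own example, not a test from this PR), this is the kind of output the null-safe, duplicate-safe header generation above produces when the header option is enabled and the analysis is case-insensitive:

```scala
// Hypothetical header row parsed from the first CSV line.
val raw = Array("id", null, "Name", "name", "")

// makeSafeHeader(raw, caseSensitive = false, options)
//   => Array("id", "_c1", "Name2", "name3", "_c4")
// Null/empty cells fall back to positional "_c<index>" names, and
// case-insensitive duplicates get their column index appended.
```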
@@ -127,67 +82,6 @@ object CSVDataSource extends Logging {
TextInputCSVDataSource
}
}

-/**
-* Checks that column names in a CSV header and field names in the schema are the same
-* by taking into account case sensitivity.
-*
-* @param schema - provided (or inferred) schema to which CSV must conform.
-* @param columnNames - names of CSV columns that must be checked against to the schema.
-* @param fileName - name of CSV file that are currently checked. It is used in error messages.
-* @param enforceSchema - if it is `true`, column names are ignored otherwise the CSV column
-* names are checked for conformance to the schema. In the case if
-* the column name don't conform to the schema, an exception is thrown.
-* @param caseSensitive - if it is set to `false`, comparison of column names and schema field
-* names is not case sensitive.
-*/
-def checkHeaderColumnNames(
-schema: StructType,
-columnNames: Array[String],
-fileName: String,
-enforceSchema: Boolean,
-caseSensitive: Boolean): Unit = {
-if (columnNames != null) {
-val fieldNames = schema.map(_.name).toIndexedSeq
-val (headerLen, schemaSize) = (columnNames.size, fieldNames.length)
-var errorMessage: Option[String] = None
-
-if (headerLen == schemaSize) {
-var i = 0
-while (errorMessage.isEmpty && i < headerLen) {
-var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i))
-if (!caseSensitive) {
-// scalastyle:off caselocale
-nameInSchema = nameInSchema.toLowerCase
-nameInHeader = nameInHeader.toLowerCase
-// scalastyle:on caselocale
-}
-if (nameInHeader != nameInSchema) {
-errorMessage = Some(
-s"""|CSV header does not conform to the schema.
-| Header: ${columnNames.mkString(", ")}
-| Schema: ${fieldNames.mkString(", ")}
-|Expected: ${fieldNames(i)} but found: ${columnNames(i)}
-|CSV file: $fileName""".stripMargin)
-}
-i += 1
-}
-} else {
-errorMessage = Some(
-s"""|Number of column in CSV header is not equal to number of fields in the schema:
-| Header length: $headerLen, schema size: $schemaSize
-|CSV file: $fileName""".stripMargin)
-}
-
-errorMessage.foreach { msg =>
-if (enforceSchema) {
-logWarning(msg)
-} else {
-throw new IllegalArgumentException(msg)
-}
-}
-}
-}
}

object TextInputCSVDataSource extends CSVDataSource {
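The `enforceSchema` behaviour documented in the removed block is user-visible through the public reader API; a minimal sketch of how it surfaces (illustrative path and schema, not from this PR):

```scala
// With enforceSchema disabled, the CSV header must match the supplied schema;
// a mismatch surfaces as an IllegalArgumentException when the file is read.
// With the default enforceSchema=true, the same mismatch is only logged.
val df = spark.read
  .option("header", "true")
  .option("enforceSchema", "false")
  .schema("id INT, name STRING")
  .csv("/tmp/people.csv")

df.show()  // fails if the file's header is, say, "id, full_name"
```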
@@ -197,10 +91,8 @@ object TextInputCSVDataSource extends CSVDataSource {
conf: Configuration,
file: PartitionedFile,
parser: UnivocityParser,
-requiredSchema: StructType,
-dataSchema: StructType,
-caseSensitive: Boolean,
-columnPruning: Boolean): Iterator[InternalRow] = {
+headerChecker: CSVHeaderChecker,
+requiredSchema: StructType): Iterator[InternalRow] = {
val lines = {
val linesReader = new HadoopFileLinesReader(file, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
@@ -209,25 +101,7 @@
}
}

-val hasHeader = parser.options.headerFlag && file.start == 0
-if (hasHeader) {
-// Checking that column names in the header are matched to field names of the schema.
-// The header will be removed from lines.
-// Note: if there are only comments in the first block, the header would probably
-// be not extracted.
-CSVUtils.extractHeader(lines, parser.options).foreach { header =>
-val schema = if (columnPruning) requiredSchema else dataSchema
-val columnNames = parser.tokenizer.parseLine(header)
-CSVDataSource.checkHeaderColumnNames(
-schema,
-columnNames,
-file.filePath,
-parser.options.enforceSchema,
-caseSensitive)
-}
-}
-
-UnivocityParser.parseIterator(lines, parser, requiredSchema)
+UnivocityParser.parseIterator(lines, parser, headerChecker, requiredSchema)
}

override def infer(
@@ -251,7 +125,7 @@ object TextInputCSVDataSource extends CSVDataSource {
maybeFirstLine.map(csvParser.parseLine(_)) match {
case Some(firstRow) if firstRow != null =>
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
-val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+val header = CSVUtils.makeSafeHeader(firstRow, caseSensitive, parsedOptions)
Member:
What about importing it from CSVUtils? What is the reason to have the prefix here?

Member (Author):
Because this code mostly uses the CSVUtils-prefixed one. I just followed that.
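For reference, the reviewer's suggestion would look roughly like this (a sketch, not a change made in this PR):

```scala
// Import the helper once and call it unqualified instead of prefixing it.
import org.apache.spark.sql.execution.datasources.csv.CSVUtils.makeSafeHeader

val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
```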

val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions)
val tokenRDD = sampled.rdd.mapPartitions { iter =>
val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
@@ -298,26 +172,13 @@ object MultiLineCSVDataSource extends CSVDataSource {
conf: Configuration,
file: PartitionedFile,
parser: UnivocityParser,
-requiredSchema: StructType,
-dataSchema: StructType,
-caseSensitive: Boolean,
-columnPruning: Boolean): Iterator[InternalRow] = {
-def checkHeader(header: Array[String]): Unit = {
-val schema = if (columnPruning) requiredSchema else dataSchema
-CSVDataSource.checkHeaderColumnNames(
-schema,
-header,
-file.filePath,
-parser.options.enforceSchema,
-caseSensitive)
-}
-
+headerChecker: CSVHeaderChecker,
+requiredSchema: StructType): Iterator[InternalRow] = {
UnivocityParser.parseStream(
CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))),
-parser.options.headerFlag,
parser,
-requiredSchema,
-checkHeader)
+headerChecker,
+requiredSchema)
}

override def infer(
@@ -334,7 +195,7 @@
}.take(1).headOption match {
case Some(firstRow) =>
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
-val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+val header = CSVUtils.makeSafeHeader(firstRow, caseSensitive, parsedOptions)
val tokenRDD = csv.flatMap { lines =>
UnivocityParser.tokenizeStream(
CodecStreams.createInputStreamWithCloseResource(
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -130,7 +130,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
"df.filter($\"_corrupt_record\".isNotNull).count()."
)
}
-val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
val columnPruning = sparkSession.sessionState.conf.csvColumnPruning

(file: PartitionedFile) => {
@@ -139,14 +138,16 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
parsedOptions)
+val schema = if (columnPruning) requiredSchema else dataSchema
+val isStartOfFile = file.start == 0
+val headerChecker = new CSVHeaderChecker(
+schema, parsedOptions, source = s"CSV file: ${file.filePath}", isStartOfFile)
CSVDataSource(parsedOptions).readFile(
conf,
file,
parser,
-requiredSchema,
-dataSchema,
-caseSensitive,
-columnPruning)
+headerChecker,
+requiredSchema)
}
}
