[SPARK-25684][SQL] Organize header related codes in CSV datasource #22676

Changes from all commits
@@ -51,11 +51,8 @@ abstract class CSVDataSource extends Serializable {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      requiredSchema: StructType,
-      // Actual schema of data in the csv file
-      dataSchema: StructType,
-      caseSensitive: Boolean,
-      columnPruning: Boolean): Iterator[InternalRow]
+      headerChecker: CSVHeaderChecker,
+      requiredSchema: StructType): Iterator[InternalRow]

   /**
    * Infers the schema from `inputPaths` files.
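The single `headerChecker` argument replaces the `dataSchema`/`caseSensitive`/`columnPruning` trio, which existed only to validate headers. `CSVHeaderChecker` itself is added elsewhere in this PR and is not shown in this excerpt; the following is a hypothetical sketch of the shape implied by the call sites below, not the actual class:

```scala
import org.apache.spark.sql.types.StructType

// Hypothetical sketch only: names and fields are inferred from how
// headerChecker is used in this diff (built once per file, then handed to
// UnivocityParser.parseIterator / parseStream).
class CSVHeaderChecker(
    schema: StructType,      // schema the CSV header must conform to
    enforceSchema: Boolean,  // warn instead of failing on a mismatch
    caseSensitive: Boolean,  // whether name comparison is case sensitive
    source: String) {        // e.g. the file path, used in error messages

  def checkHeaderColumnNames(columnNames: Array[String]): Unit = {
    // Would compare columnNames against schema.fieldNames and either log a
    // warning (enforceSchema = true) or throw, mirroring the contract of the
    // removed CSVDataSource.checkHeaderColumnNames shown further down.
  }
}
```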
@@ -75,48 +72,6 @@ abstract class CSVDataSource extends Serializable {
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): StructType
-
-  /**
-   * Generates a header from the given row which is null-safe and duplicate-safe.
-   */
-  protected def makeSafeHeader(
-      row: Array[String],
-      caseSensitive: Boolean,
-      options: CSVOptions): Array[String] = {
-    if (options.headerFlag) {
-      val duplicates = {
-        val headerNames = row.filter(_ != null)
-          // scalastyle:off caselocale
-          .map(name => if (caseSensitive) name else name.toLowerCase)
-          // scalastyle:on caselocale
-        headerNames.diff(headerNames.distinct).distinct
-      }
-
-      row.zipWithIndex.map { case (value, index) =>
-        if (value == null || value.isEmpty || value == options.nullValue) {
-          // When there are empty strings or the values set in `nullValue`, put the
-          // index as the suffix.
-          s"_c$index"
-          // scalastyle:off caselocale
-        } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
-          // scalastyle:on caselocale
-          // When there are case-insensitive duplicates, put the index as the suffix.
-          s"$value$index"
-        } else if (duplicates.contains(value)) {
-          // When there are duplicates, put the index as the suffix.
-          s"$value$index"
-        } else {
-          value
-        }
-      }
-    } else {
-      row.zipWithIndex.map { case (_, index) =>
-        // Uses default column names, "_c#" where # is its position of fields
-        // when header option is disabled.
-        s"_c$index"
-      }
-    }
-  }
 }

 object CSVDataSource extends Logging {
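`makeSafeHeader` is moved rather than deleted: the `infer` hunks below now call it as `CSVUtils.makeSafeHeader`. To make its naming rules concrete, here is a self-contained simplification of the logic above (header flag assumed on, and the `options.nullValue` check dropped):

```scala
// Simplified, standalone restatement: null/empty cells and duplicated names
// (case-insensitively when caseSensitive = false) fall back to
// position-based names.
object MakeSafeHeaderDemo {
  def makeSafeHeader(row: Array[String], caseSensitive: Boolean): Array[String] = {
    val headerNames = row.filter(_ != null)
      .map(name => if (caseSensitive) name else name.toLowerCase)
    val duplicates = headerNames.diff(headerNames.distinct).distinct

    row.zipWithIndex.map { case (value, index) =>
      if (value == null || value.isEmpty) {
        s"_c$index"                      // null or empty cell
      } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
        s"$value$index"                  // case-insensitive duplicate
      } else if (duplicates.contains(value)) {
        s"$value$index"                  // exact duplicate
      } else {
        value
      }
    }
  }

  def main(args: Array[String]): Unit = {
    println(makeSafeHeader(Array("id", "ID", null, "name"), caseSensitive = false).mkString(", "))
    // prints: id0, ID1, _c2, name
  }
}
```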
@@ -127,67 +82,6 @@ object CSVDataSource extends Logging {
       TextInputCSVDataSource
     }
   }
-
-  /**
-   * Checks that column names in a CSV header and field names in the schema are the same
-   * by taking into account case sensitivity.
-   *
-   * @param schema - provided (or inferred) schema to which CSV must conform.
-   * @param columnNames - names of CSV columns that must be checked against to the schema.
-   * @param fileName - name of CSV file that are currently checked. It is used in error messages.
-   * @param enforceSchema - if it is `true`, column names are ignored otherwise the CSV column
-   *                        names are checked for conformance to the schema. In the case if
-   *                        the column name don't conform to the schema, an exception is thrown.
-   * @param caseSensitive - if it is set to `false`, comparison of column names and schema field
-   *                        names is not case sensitive.
-   */
-  def checkHeaderColumnNames(
-      schema: StructType,
-      columnNames: Array[String],
-      fileName: String,
-      enforceSchema: Boolean,
-      caseSensitive: Boolean): Unit = {
-    if (columnNames != null) {
-      val fieldNames = schema.map(_.name).toIndexedSeq
-      val (headerLen, schemaSize) = (columnNames.size, fieldNames.length)
-      var errorMessage: Option[String] = None
-
-      if (headerLen == schemaSize) {
-        var i = 0
-        while (errorMessage.isEmpty && i < headerLen) {
-          var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i))
-          if (!caseSensitive) {
-            // scalastyle:off caselocale
-            nameInSchema = nameInSchema.toLowerCase
-            nameInHeader = nameInHeader.toLowerCase
-            // scalastyle:on caselocale
-          }
-          if (nameInHeader != nameInSchema) {
-            errorMessage = Some(
-              s"""|CSV header does not conform to the schema.
-                  | Header: ${columnNames.mkString(", ")}
-                  | Schema: ${fieldNames.mkString(", ")}
-                  |Expected: ${fieldNames(i)} but found: ${columnNames(i)}
-                  |CSV file: $fileName""".stripMargin)
-          }
-          i += 1
-        }
-      } else {
-        errorMessage = Some(
-          s"""|Number of column in CSV header is not equal to number of fields in the schema:
-              | Header length: $headerLen, schema size: $schemaSize
-              |CSV file: $fileName""".stripMargin)
-      }
-
-      errorMessage.foreach { msg =>
-        if (enforceSchema) {
-          logWarning(msg)
-        } else {
-          throw new IllegalArgumentException(msg)
-        }
-      }
-    }
-  }
 }

 object TextInputCSVDataSource extends CSVDataSource {
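The essence of the removed check, as a standalone sketch (plain `Seq[String]` standing in for `StructType` field names, and the mismatch returned instead of logged or thrown):

```scala
// Returns the first mismatch as an error string, or None when the header
// conforms. The real method above logs the message when enforceSchema is
// true and throws IllegalArgumentException otherwise.
def headerMismatch(
    fieldNames: Seq[String],
    columnNames: Seq[String],
    caseSensitive: Boolean): Option[String] = {
  if (columnNames.length != fieldNames.length) {
    Some(s"Header length: ${columnNames.length}, schema size: ${fieldNames.length}")
  } else {
    fieldNames.zip(columnNames).collectFirst {
      case (s, h) if (if (caseSensitive) s != h else !s.equalsIgnoreCase(h)) =>
        s"Expected: $s but found: $h"
    }
  }
}

// headerMismatch(Seq("id", "name"), Seq("id", "Name"), caseSensitive = true)
//   => Some("Expected: name but found: Name")
// headerMismatch(Seq("id", "name"), Seq("id", "Name"), caseSensitive = false)
//   => None
```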
@@ -197,10 +91,8 @@ object TextInputCSVDataSource extends CSVDataSource {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      requiredSchema: StructType,
-      dataSchema: StructType,
-      caseSensitive: Boolean,
-      columnPruning: Boolean): Iterator[InternalRow] = {
+      headerChecker: CSVHeaderChecker,
+      requiredSchema: StructType): Iterator[InternalRow] = {
     val lines = {
       val linesReader = new HadoopFileLinesReader(file, conf)
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
@@ -209,25 +101,7 @@ object TextInputCSVDataSource extends CSVDataSource {
       }
     }

-    val hasHeader = parser.options.headerFlag && file.start == 0
-    if (hasHeader) {
-      // Checking that column names in the header are matched to field names of the schema.
-      // The header will be removed from lines.
-      // Note: if there are only comments in the first block, the header would probably
-      // be not extracted.
-      CSVUtils.extractHeader(lines, parser.options).foreach { header =>
-        val schema = if (columnPruning) requiredSchema else dataSchema
-        val columnNames = parser.tokenizer.parseLine(header)
-        CSVDataSource.checkHeaderColumnNames(
-          schema,
-          columnNames,
-          file.filePath,
-          parser.options.enforceSchema,
-          caseSensitive)
-      }
-    }
-
-    UnivocityParser.parseIterator(lines, parser, requiredSchema)
+    UnivocityParser.parseIterator(lines, parser, headerChecker, requiredSchema)
   }

   override def infer(
@@ -251,7 +125,7 @@ object TextInputCSVDataSource extends CSVDataSource {
     maybeFirstLine.map(csvParser.parseLine(_)) match {
       case Some(firstRow) if firstRow != null =>
         val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
-        val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+        val header = CSVUtils.makeSafeHeader(firstRow, caseSensitive, parsedOptions)
         val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions)
         val tokenRDD = sampled.rdd.mapPartitions { iter =>
           val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)

Review comment (on the `CSVUtils.makeSafeHeader` line): What about importing it from `CSVUtils`, so the prefix wouldn't be needed at the call sites?

Reply: Because this code mostly uses the `CSVUtils.`-prefixed form.
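For reference, the reviewer's import-based alternative would presumably look like this (hypothetical; since the caller sits in the same package as `CSVUtils`, a bare member import would suffice):

```scala
// Hypothetical alternative: import the member once so the CSVUtils. prefix
// is not repeated at each call site.
import CSVUtils.makeSafeHeader

val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
```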
@@ -298,26 +172,13 @@ object MultiLineCSVDataSource extends CSVDataSource {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      requiredSchema: StructType,
-      dataSchema: StructType,
-      caseSensitive: Boolean,
-      columnPruning: Boolean): Iterator[InternalRow] = {
-    def checkHeader(header: Array[String]): Unit = {
-      val schema = if (columnPruning) requiredSchema else dataSchema
-      CSVDataSource.checkHeaderColumnNames(
-        schema,
-        header,
-        file.filePath,
-        parser.options.enforceSchema,
-        caseSensitive)
-    }
-
+      headerChecker: CSVHeaderChecker,
+      requiredSchema: StructType): Iterator[InternalRow] = {
     UnivocityParser.parseStream(
       CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))),
-      parser.options.headerFlag,
       parser,
-      requiredSchema,
-      checkHeader)
+      headerChecker,
+      requiredSchema)
   }

   override def infer(
@@ -334,7 +195,7 @@ object MultiLineCSVDataSource extends CSVDataSource {
     }.take(1).headOption match {
       case Some(firstRow) =>
         val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
-        val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+        val header = CSVUtils.makeSafeHeader(firstRow, caseSensitive, parsedOptions)
         val tokenRDD = csv.flatMap { lines =>
           UnivocityParser.tokenizeStream(
             CodecStreams.createInputStreamWithCloseResource(
Review comment: It is not directly related to your changes. Just in case, why do we convert `Dataset` to `RDD` here?

Reply: I don't exactly remember. Looks like we can change it to `Dataset`.
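A hedged sketch of the `Dataset`-based variant being discussed, presumably referring to the `sampled.rdd.mapPartitions` call in `TextInputCSVDataSource.infer` above. It assumes an implicit `Encoder[Array[String]]` is available from `sparkSession.implicits._`; the need for such an encoder is one common reason this kind of code drops down to `RDD`:

```scala
// Hypothetical Dataset-only rewrite of the tokenization step; not part of
// this PR. CsvParser is univocity's com.univocity.parsers.csv.CsvParser,
// and header-line filtering is omitted for brevity.
import sparkSession.implicits._

val tokenDS: Dataset[Array[String]] = sampled.mapPartitions { iter =>
  val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
  val parser = new CsvParser(parsedOptions.asParserSettings)
  filteredLines.map(parser.parseLine(_))
}
```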