Commit 7e5359b

HyukjinKwon authored and cloud-fan committed
[SPARK-19610][SQL] Support parsing multiline CSV files
## What changes were proposed in this pull request?

This PR proposes support for multiline records in CSV, resembling the multiline support in the JSON datasource (in the case of JSON, per file). It introduces a `wholeFile` option which makes the format non-splittable so that each file is read as a whole. Since the Univocity parser can produce each row from a stream, it should be capable of parsing very large documents as long as the internal rows fit in memory.

## How was this patch tested?

Unit tests in `CSVSuite` and `tests.py`.

Manual tests with a single 9GB CSV file on a local file system, for example:

```scala
spark.read.option("wholeFile", true).option("inferSchema", true).csv("tmp.csv").count()
```

Author: hyukjinkwon <[email protected]>

Closes #16976 from HyukjinKwon/SPARK-19610.
1 parent ce233f1 commit 7e5359b
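
As a quick, hedged illustration of what the new option does (the file path and contents below are hypothetical, not part of this commit): with `wholeFile` enabled the datasource stops splitting files at line boundaries, so quoted fields containing newlines are parsed as part of a single record.

```scala
// Hypothetical file /tmp/people_multiline.csv:
//   Joe,20,"Hi,
//   I am Joe"
//   Tom,30,"My name is Tom"

// `spark` is an existing SparkSession (e.g. in spark-shell).
val df = spark.read
  .option("wholeFile", true)     // read each file as a whole; records may span lines
  .option("inferSchema", true)
  .csv("/tmp/people_multiline.csv")

df.count()                       // 2 records, even though the file has 4 physical lines
```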

File tree

14 files changed (+525 / -197 lines)


python/pyspark/sql/readwriter.py

Lines changed: 4 additions & 2 deletions

@@ -308,7 +308,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
             negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
             maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None,
-            columnNameOfCorruptRecord=None):
+            columnNameOfCorruptRecord=None, wholeFile=None):
         """Loads a CSV file and returns the result as a :class:`DataFrame`.

         This function will go through the input once to determine the input schema if
@@ -385,6 +385,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
                                           ``spark.sql.columnNameOfCorruptRecord``. If None is set,
                                           it uses the value specified in
                                           ``spark.sql.columnNameOfCorruptRecord``.
+        :param wholeFile: parse records, which may span multiple lines. If None is
+                          set, it uses the default value, ``false``.

         >>> df = spark.read.csv('python/test_support/sql/ages.csv')
         >>> df.dtypes
@@ -398,7 +400,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
            dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns,
            maxCharsPerColumn=maxCharsPerColumn,
            maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone,
-           columnNameOfCorruptRecord=columnNameOfCorruptRecord)
+           columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile)
         if isinstance(path, basestring):
             path = [path]
         return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))

python/pyspark/sql/streaming.py

Lines changed: 4 additions & 2 deletions

@@ -562,7 +562,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
             negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
             maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None,
-            columnNameOfCorruptRecord=None):
+            columnNameOfCorruptRecord=None, wholeFile=None):
         """Loads a CSV file stream and returns the result as a :class:`DataFrame`.

         This function will go through the input once to determine the input schema if
@@ -637,6 +637,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
                                           ``spark.sql.columnNameOfCorruptRecord``. If None is set,
                                           it uses the value specified in
                                           ``spark.sql.columnNameOfCorruptRecord``.
+        :param wholeFile: parse one record, which may span multiple lines. If None is
+                          set, it uses the default value, ``false``.

         >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema)
         >>> csv_sdf.isStreaming
@@ -652,7 +654,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
            dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns,
            maxCharsPerColumn=maxCharsPerColumn,
            maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone,
-           columnNameOfCorruptRecord=columnNameOfCorruptRecord)
+           columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile)
         if isinstance(path, basestring):
             return self._df(self._jreader.csv(path))
         else:
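
The streaming reader threads the option through in the same way. A minimal Scala sketch, assuming a hypothetical input directory and schema (streaming CSV sources require a user-supplied schema rather than inference):

```scala
import org.apache.spark.sql.types._

// Hypothetical schema matching three-column CSV records.
val schema = new StructType()
  .add("name", StringType)
  .add("age", IntegerType)
  .add("bio", StringType)

val csvStream = spark.readStream
  .schema(schema)
  .option("wholeFile", true)   // records within each file may span multiple lines
  .csv("/tmp/csv-input-dir")   // hypothetical monitored directory
```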

python/pyspark/sql/tests.py

Lines changed: 8 additions & 1 deletion

@@ -437,12 +437,19 @@ def test_udf_with_order_by_and_limit(self):
         self.assertEqual(res.collect(), [Row(id=0, copy=0)])

     def test_wholefile_json(self):
-        from pyspark.sql.types import StringType
         people1 = self.spark.read.json("python/test_support/sql/people.json")
         people_array = self.spark.read.json("python/test_support/sql/people_array.json",
                                             wholeFile=True)
         self.assertEqual(people1.collect(), people_array.collect())

+    def test_wholefile_csv(self):
+        ages_newlines = self.spark.read.csv(
+            "python/test_support/sql/ages_newlines.csv", wholeFile=True)
+        expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'),
+                    Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'),
+                    Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')]
+        self.assertEqual(ages_newlines.collect(), expected)
+
     def test_udf_with_input_file_name(self):
         from pyspark.sql.functions import udf, input_file_name
         from pyspark.sql.types import StringType
python/test_support/sql/ages_newlines.csv

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+Joe,20,"Hi,
+I am Jeo"
+Tom,30,"My name is Tom"
+Hyukjin,25,"I am Hyukjin
+
+I love Spark!"

sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala

Lines changed: 1 addition & 0 deletions

@@ -463,6 +463,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * <li>`columnNameOfCorruptRecord` (default is the value specified in
    * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
    * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
+   * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines.</li>
    * </ul>
    * @since 2.0.0
    */

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala

Lines changed: 12 additions & 0 deletions

@@ -27,6 +27,8 @@ import org.apache.hadoop.mapreduce.JobContext
 import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
 import org.apache.hadoop.util.ReflectionUtils

+import org.apache.spark.TaskContext
+
 object CodecStreams {
   private def getDecompressionCodec(config: Configuration, file: Path): Option[CompressionCodec] = {
     val compressionCodecs = new CompressionCodecFactory(config)
@@ -42,6 +44,16 @@ object CodecStreams {
       .getOrElse(inputStream)
   }

+  /**
+   * Creates an input stream from the string path and add a closure for the input stream to be
+   * closed on task completion.
+   */
+  def createInputStreamWithCloseResource(config: Configuration, path: String): InputStream = {
+    val inputStream = createInputStream(config, new Path(path))
+    Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close()))
+    inputStream
+  }
+
   private def getCompressionCodec(
       context: JobContext,
       file: Option[Path] = None): Option[CompressionCodec] = {
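
The helper added above ties the stream's lifetime to the running task: a completion listener is registered only when a `TaskContext` exists (i.e. on an executor), otherwise the caller keeps ownership. A standalone sketch of the same pattern, with illustrative names that are not part of Spark's API:

```scala
import java.io.Closeable

import org.apache.spark.TaskContext

// Close any resource when the current task finishes; no-op on the driver,
// where TaskContext.get() returns null.
def closeOnTaskCompletion[T <: Closeable](resource: T): T = {
  Option(TaskContext.get()).foreach { ctx =>
    ctx.addTaskCompletionListener(_ => resource.close())
  }
  resource
}
```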
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala

Lines changed: 239 additions & 0 deletions

@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import java.io.InputStream
+import java.nio.charset.{Charset, StandardCharsets}
+
+import com.univocity.parsers.csv.{CsvParser, CsvParserSettings}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.io.{LongWritable, Text}
+import org.apache.hadoop.mapred.TextInputFormat
+import org.apache.hadoop.mapreduce.Job
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
+
+import org.apache.spark.TaskContext
+import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
+import org.apache.spark.rdd.{BinaryFileRDD, RDD}
+import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.text.TextFileFormat
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Common functions for parsing CSV files
+ */
+abstract class CSVDataSource extends Serializable {
+  def isSplitable: Boolean
+
+  /**
+   * Parse a [[PartitionedFile]] into [[InternalRow]] instances.
+   */
+  def readFile(
+      conf: Configuration,
+      file: PartitionedFile,
+      parser: UnivocityParser,
+      parsedOptions: CSVOptions): Iterator[InternalRow]
+
+  /**
+   * Infers the schema from `inputPaths` files.
+   */
+  def infer(
+      sparkSession: SparkSession,
+      inputPaths: Seq[FileStatus],
+      parsedOptions: CSVOptions): Option[StructType]
+
+  /**
+   * Generates a header from the given row which is null-safe and duplicate-safe.
+   */
+  protected def makeSafeHeader(
+      row: Array[String],
+      caseSensitive: Boolean,
+      options: CSVOptions): Array[String] = {
+    if (options.headerFlag) {
+      val duplicates = {
+        val headerNames = row.filter(_ != null)
+          .map(name => if (caseSensitive) name else name.toLowerCase)
+        headerNames.diff(headerNames.distinct).distinct
+      }
+
+      row.zipWithIndex.map { case (value, index) =>
+        if (value == null || value.isEmpty || value == options.nullValue) {
+          // When there are empty strings or the values set in `nullValue`, put the
+          // index as the suffix.
+          s"_c$index"
+        } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
+          // When there are case-insensitive duplicates, put the index as the suffix.
+          s"$value$index"
+        } else if (duplicates.contains(value)) {
+          // When there are duplicates, put the index as the suffix.
+          s"$value$index"
+        } else {
+          value
+        }
+      }
+    } else {
+      row.zipWithIndex.map { case (_, index) =>
+        // Uses default column names, "_c#" where # is its position of fields
+        // when header option is disabled.
+        s"_c$index"
+      }
+    }
+  }
+}
+
+object CSVDataSource {
+  def apply(options: CSVOptions): CSVDataSource = {
+    if (options.wholeFile) {
+      WholeFileCSVDataSource
+    } else {
+      TextInputCSVDataSource
+    }
+  }
+}
+
+object TextInputCSVDataSource extends CSVDataSource {
+  override val isSplitable: Boolean = true
+
+  override def readFile(
+      conf: Configuration,
+      file: PartitionedFile,
+      parser: UnivocityParser,
+      parsedOptions: CSVOptions): Iterator[InternalRow] = {
+    val lines = {
+      val linesReader = new HadoopFileLinesReader(file, conf)
+      Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
+      linesReader.map { line =>
+        new String(line.getBytes, 0, line.getLength, parsedOptions.charset)
+      }
+    }
+
+    val shouldDropHeader = parsedOptions.headerFlag && file.start == 0
+    UnivocityParser.parseIterator(lines, shouldDropHeader, parser)
+  }
+
+  override def infer(
+      sparkSession: SparkSession,
+      inputPaths: Seq[FileStatus],
+      parsedOptions: CSVOptions): Option[StructType] = {
+    val csv: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions)
+    val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).first()
+    val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
+    val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+    val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+    val tokenRDD = csv.rdd.mapPartitions { iter =>
+      val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
+      val linesWithoutHeader =
+        CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
+      val parser = new CsvParser(parsedOptions.asParserSettings)
+      linesWithoutHeader.map(parser.parseLine)
+    }
+
+    Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
+  }
+
+  private def createBaseDataset(
+      sparkSession: SparkSession,
+      inputPaths: Seq[FileStatus],
+      options: CSVOptions): Dataset[String] = {
+    val paths = inputPaths.map(_.getPath.toString)
+    if (Charset.forName(options.charset) == StandardCharsets.UTF_8) {
+      sparkSession.baseRelationToDataFrame(
+        DataSource.apply(
+          sparkSession,
+          paths = paths,
+          className = classOf[TextFileFormat].getName
+        ).resolveRelation(checkFilesExist = false))
+        .select("value").as[String](Encoders.STRING)
+    } else {
+      val charset = options.charset
+      val rdd = sparkSession.sparkContext
+        .hadoopFile[LongWritable, Text, TextInputFormat](paths.mkString(","))
+        .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset)))
+      sparkSession.createDataset(rdd)(Encoders.STRING)
+    }
+  }
+}
+
+object WholeFileCSVDataSource extends CSVDataSource {
+  override val isSplitable: Boolean = false
+
+  override def readFile(
+      conf: Configuration,
+      file: PartitionedFile,
+      parser: UnivocityParser,
+      parsedOptions: CSVOptions): Iterator[InternalRow] = {
+    UnivocityParser.parseStream(
+      CodecStreams.createInputStreamWithCloseResource(conf, file.filePath),
+      parsedOptions.headerFlag,
+      parser)
+  }
+
+  override def infer(
+      sparkSession: SparkSession,
+      inputPaths: Seq[FileStatus],
+      parsedOptions: CSVOptions): Option[StructType] = {
+    val csv: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions)
+    val maybeFirstRow: Option[Array[String]] = csv.flatMap { lines =>
+      UnivocityParser.tokenizeStream(
+        CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()),
+        false,
+        new CsvParser(parsedOptions.asParserSettings))
+    }.take(1).headOption
+
+    if (maybeFirstRow.isDefined) {
+      val firstRow = maybeFirstRow.get
+      val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+      val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+      val tokenRDD = csv.flatMap { lines =>
+        UnivocityParser.tokenizeStream(
+          CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()),
+          parsedOptions.headerFlag,
+          new CsvParser(parsedOptions.asParserSettings))
+      }
+      Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
+    } else {
+      // If the first row could not be read, just return the empty schema.
+      Some(StructType(Nil))
+    }
+  }
+
+  private def createBaseRdd(
+      sparkSession: SparkSession,
+      inputPaths: Seq[FileStatus],
+      options: CSVOptions): RDD[PortableDataStream] = {
+    val paths = inputPaths.map(_.getPath)
+    val name = paths.mkString(",")
+    val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
+    FileInputFormat.setInputPaths(job, paths: _*)
+    val conf = job.getConfiguration
+
+    val rdd = new BinaryFileRDD(
+      sparkSession.sparkContext,
+      classOf[StreamInputFormat],
+      classOf[String],
+      classOf[PortableDataStream],
+      conf,
+      sparkSession.sparkContext.defaultMinPartitions)
+
+    // Only returns `PortableDataStream`s without paths.
+    rdd.setName(s"CSVFile: $name").values
+  }
+}
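
To make the header sanitization in `makeSafeHeader` concrete, here is a simplified, standalone re-statement of the rule (it ignores the `nullValue` option and is not the Spark API itself), followed by a hypothetical input and its result:

```scala
// Null or empty names fall back to positional "_c<index>" names; duplicated
// names (case-insensitively when caseSensitive = false) get the index appended.
def safeHeader(row: Array[String], caseSensitive: Boolean): Array[String] = {
  val names = row.filter(_ != null).map(n => if (caseSensitive) n else n.toLowerCase)
  val duplicates = names.diff(names.distinct).distinct
  row.zipWithIndex.map { case (value, index) =>
    if (value == null || value.isEmpty) s"_c$index"
    else if (!caseSensitive && duplicates.contains(value.toLowerCase)) s"$value$index"
    else if (duplicates.contains(value)) s"$value$index"
    else value
  }
}

// safeHeader(Array("name", null, "", "name", "AGE"), caseSensitive = false)
//   => Array("name0", "_c1", "_c2", "name3", "AGE")
```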
