Skip to content

Commit 79551f5

Browse files
MaxGekkHyukjinKwon
authored and committed
[SPARK-25945][SQL] Support locale while parsing date/timestamp from CSV/JSON
## What changes were proposed in this pull request?

In the PR, I propose to add a new option `locale` to CSVOptions/JSONOptions to make parsing dates/timestamps in local languages possible. Currently the locale is hard-coded to `Locale.US`.

## How was this patch tested?

Added two tests for parsing a date from CSV/JSON - `ноя 2018`.

Closes #22951 from MaxGekk/locale.

Authored-by: Maxim Gekk <[email protected]>
Signed-off-by: hyukjinkwon <[email protected]>
1 parent 973f7c0 commit 79551f5

File tree

10 files changed

+109
-14
lines changed

10 files changed

+109
-14
lines changed

python/pyspark/sql/readwriter.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
177177
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
178178
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
179179
multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None,
180-
dropFieldIfAllNull=None, encoding=None):
180+
dropFieldIfAllNull=None, encoding=None, locale=None):
181181
"""
182182
Loads JSON files and returns the results as a :class:`DataFrame`.
183183
@@ -249,6 +249,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
249249
:param dropFieldIfAllNull: whether to ignore column of all null values or empty
250250
array/struct during schema inference. If None is set, it
251251
uses the default value, ``false``.
252+
:param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
253+
it uses the default value, ``en-US``. For instance, ``locale`` is used while
254+
parsing dates and timestamps.
252255
253256
>>> df1 = spark.read.json('python/test_support/sql/people.json')
254257
>>> df1.dtypes
@@ -267,7 +270,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
267270
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
268271
timestampFormat=timestampFormat, multiLine=multiLine,
269272
allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep,
270-
samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding)
273+
samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding,
274+
locale=locale)
271275
if isinstance(path, basestring):
272276
path = [path]
273277
if type(path) == list:
@@ -349,7 +353,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
349353
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
350354
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
351355
columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
352-
samplingRatio=None, enforceSchema=None, emptyValue=None):
356+
samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None):
353357
r"""Loads a CSV file and returns the result as a :class:`DataFrame`.
354358
355359
This function will go through the input once to determine the input schema if
@@ -446,6 +450,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
446450
If None is set, it uses the default value, ``1.0``.
447451
:param emptyValue: sets the string representation of an empty value. If None is set, it uses
448452
the default value, empty string.
453+
:param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
454+
it uses the default value, ``en-US``. For instance, ``locale`` is used while
455+
parsing dates and timestamps.
449456
450457
>>> df = spark.read.csv('python/test_support/sql/ages.csv')
451458
>>> df.dtypes
@@ -465,7 +472,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
465472
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
466473
columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
467474
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio,
468-
enforceSchema=enforceSchema, emptyValue=emptyValue)
475+
enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale)
469476
if isinstance(path, basestring):
470477
path = [path]
471478
if type(path) == list:

python/pyspark/sql/streaming.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
404404
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
405405
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
406406
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
407-
multiLine=None, allowUnquotedControlChars=None, lineSep=None):
407+
multiLine=None, allowUnquotedControlChars=None, lineSep=None, locale=None):
408408
"""
409409
Loads a JSON file stream and returns the results as a :class:`DataFrame`.
410410
@@ -469,6 +469,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
469469
including tab and line feed characters) or not.
470470
:param lineSep: defines the line separator that should be used for parsing. If None is
471471
set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
472+
:param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
473+
it uses the default value, ``en-US``. For instance, ``locale`` is used while
474+
parsing dates and timestamps.
472475
473476
>>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema)
474477
>>> json_sdf.isStreaming
@@ -483,7 +486,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
483486
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
484487
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
485488
timestampFormat=timestampFormat, multiLine=multiLine,
486-
allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep)
489+
allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, locale=locale)
487490
if isinstance(path, basestring):
488491
return self._df(self._jreader.json(path))
489492
else:
@@ -564,7 +567,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
564567
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
565568
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
566569
columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
567-
enforceSchema=None, emptyValue=None):
570+
enforceSchema=None, emptyValue=None, locale=None):
568571
r"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.
569572
570573
This function will go through the input once to determine the input schema if
@@ -660,6 +663,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
660663
different, ``\0`` otherwise..
661664
:param emptyValue: sets the string representation of an empty value. If None is set, it uses
662665
the default value, empty string.
666+
:param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
667+
it uses the default value, ``en-US``. For instance, ``locale`` is used while
668+
parsing dates and timestamps.
663669
664670
>>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema)
665671
>>> csv_sdf.isStreaming
@@ -677,7 +683,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
677683
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
678684
columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
679685
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema,
680-
emptyValue=emptyValue)
686+
emptyValue=emptyValue, locale=locale)
681687
if isinstance(path, basestring):
682688
return self._df(self._jreader.csv(path))
683689
else:

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,16 @@ class CSVOptions(
131131
val timeZone: TimeZone = DateTimeUtils.getTimeZone(
132132
parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId))
133133

134+
// A language tag in IETF BCP 47 format
135+
val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US)
136+
134137
// Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe.
135138
val dateFormat: FastDateFormat =
136-
FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US)
139+
FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale)
137140

138141
val timestampFormat: FastDateFormat =
139142
FastDateFormat.getInstance(
140-
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US)
143+
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale)
141144

142145
val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false)
143146

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,19 @@ private[sql] class JSONOptions(
7676
// Whether to ignore column of all null values or empty array/struct during schema inference
7777
val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false)
7878

79+
// A language tag in IETF BCP 47 format
80+
val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US)
81+
7982
val timeZone: TimeZone = DateTimeUtils.getTimeZone(
8083
parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId))
8184

8285
// Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe.
8386
val dateFormat: FastDateFormat =
84-
FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US)
87+
FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale)
8588

8689
val timestampFormat: FastDateFormat =
8790
FastDateFormat.getInstance(
88-
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US)
91+
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale)
8992

9093
val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false)
9194

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import java.util.Calendar
20+
import java.text.SimpleDateFormat
21+
import java.util.{Calendar, Locale}
2122

2223
import org.scalatest.exceptions.TestFailedException
2324

@@ -209,4 +210,20 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P
209210
"2015-12-31T16:00:00"
210211
)
211212
}
213+
214+
test("parse date with locale") {
215+
Seq("en-US", "ru-RU").foreach { langTag =>
216+
val locale = Locale.forLanguageTag(langTag)
217+
val date = new SimpleDateFormat("yyyy-MM-dd").parse("2018-11-05")
218+
val schema = new StructType().add("d", DateType)
219+
val dateFormat = "MMM yyyy"
220+
val sdf = new SimpleDateFormat(dateFormat, locale)
221+
val dateStr = sdf.format(date)
222+
val options = Map("dateFormat" -> dateFormat, "locale" -> langTag)
223+
224+
checkEvaluation(
225+
CsvToStructs(schema, options, Literal.create(dateStr), gmtId),
226+
InternalRow(17836)) // number of days from 1970-01-01
227+
}
228+
}
212229
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import java.util.Calendar
20+
import java.text.SimpleDateFormat
21+
import java.util.{Calendar, Locale}
2122

2223
import org.scalatest.exceptions.TestFailedException
2324

@@ -737,4 +738,20 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
737738
CreateMap(Seq(Literal.create("allowNumericLeadingZeros"), Literal.create("true")))),
738739
"struct<col:bigint>")
739740
}
741+
742+
test("parse date with locale") {
743+
Seq("en-US", "ru-RU").foreach { langTag =>
744+
val locale = Locale.forLanguageTag(langTag)
745+
val date = new SimpleDateFormat("yyyy-MM-dd").parse("2018-11-05")
746+
val schema = new StructType().add("d", DateType)
747+
val dateFormat = "MMM yyyy"
748+
val sdf = new SimpleDateFormat(dateFormat, locale)
749+
val dateStr = s"""{"d":"${sdf.format(date)}"}"""
750+
val options = Map("dateFormat" -> dateFormat, "locale" -> langTag)
751+
752+
checkEvaluation(
753+
JsonToStructs(schema, options, Literal.create(dateStr), gmtId),
754+
InternalRow(17836)) // number of days from 1970-01-01
755+
}
756+
}
740757
}

sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
384384
* for schema inferring.</li>
385385
* <li>`dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or
386386
* empty array/struct during schema inference.</li>
387+
* <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
388+
* For instance, this is used while parsing dates and timestamps.</li>
387389
* </ul>
388390
*
389391
* @since 2.0.0
@@ -604,6 +606,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
604606
* `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
605607
* created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
606608
* <li>`multiLine` (default `false`): parse one record, which may span multiple lines.</li>
609+
* <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
610+
* For instance, this is used while parsing dates and timestamps.</li>
607611
* </ul>
608612
*
609613
* @since 2.0.0

sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
296296
* that should be used for parsing.</li>
297297
* <li>`dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or
298298
* empty array/struct during schema inference.</li>
299+
* <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
300+
* For instance, this is used while parsing dates and timestamps.</li>
299301
* </ul>
300302
*
301303
* @since 2.0.0
@@ -372,6 +374,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
372374
* `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
373375
* created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
374376
* <li>`multiLine` (default `false`): parse one record, which may span multiple lines.</li>
377+
* <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
378+
* For instance, this is used while parsing dates and timestamps.</li>
375379
* </ul>
376380
*
377381
* @since 2.0.0

sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
package org.apache.spark.sql
1919

20+
import java.text.SimpleDateFormat
21+
import java.util.Locale
22+
2023
import scala.collection.JavaConverters._
2124

2225
import org.apache.spark.SparkException
@@ -164,4 +167,18 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext {
164167
val df1 = Seq(Tuple1(Tuple1(1))).toDF("a")
165168
checkAnswer(df1.selectExpr("to_csv(a)"), Row("1") :: Nil)
166169
}
170+
171+
test("parse timestamps with locale") {
172+
Seq("en-US", "ko-KR", "zh-CN", "ru-RU").foreach { langTag =>
173+
val locale = Locale.forLanguageTag(langTag)
174+
val ts = new SimpleDateFormat("dd/MM/yyyy HH:mm").parse("06/11/2018 18:00")
175+
val timestampFormat = "dd MMM yyyy HH:mm"
176+
val sdf = new SimpleDateFormat(timestampFormat, locale)
177+
val input = Seq(s"""${sdf.format(ts)}""").toDS()
178+
val options = Map("timestampFormat" -> timestampFormat, "locale" -> langTag)
179+
val df = input.select(from_csv($"value", lit("time timestamp"), options.asJava))
180+
181+
checkAnswer(df, Row(Row(java.sql.Timestamp.valueOf("2018-11-06 18:00:00.0"))))
182+
}
183+
}
167184
}

sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
package org.apache.spark.sql
1919

20+
import java.text.SimpleDateFormat
21+
import java.util.Locale
22+
2023
import collection.JavaConverters._
2124

2225
import org.apache.spark.SparkException
@@ -591,4 +594,18 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
591594
df.select(from_json($"value", schema, Map("columnNameOfCorruptRecord" -> "_unparsed"))),
592595
Row(Row(null, badRec, null)) :: Row(Row(2, null, 12)) :: Nil)
593596
}
597+
598+
test("parse timestamps with locale") {
599+
Seq("en-US", "ko-KR", "zh-CN", "ru-RU").foreach { langTag =>
600+
val locale = Locale.forLanguageTag(langTag)
601+
val ts = new SimpleDateFormat("dd/MM/yyyy HH:mm").parse("06/11/2018 18:00")
602+
val timestampFormat = "dd MMM yyyy HH:mm"
603+
val sdf = new SimpleDateFormat(timestampFormat, locale)
604+
val input = Seq(s"""{"time": "${sdf.format(ts)}"}""").toDS()
605+
val options = Map("timestampFormat" -> timestampFormat, "locale" -> langTag)
606+
val df = input.select(from_json($"value", "time timestamp", options))
607+
608+
checkAnswer(df, Row(Row(java.sql.Timestamp.valueOf("2018-11-06 18:00:00.0"))))
609+
}
610+
}
594611
}

0 commit comments

Comments
 (0)