
Commit 369c40c

panbingkun authored and MaxGekk committed
[SPARK-50067][SQL] Codegen Support for SchemaOfCsv(by Invoke & RuntimeReplaceable)
### What changes were proposed in this pull request?
This PR adds `Codegen` support for `schema_of_csv`.

### Why are the changes needed?
- Improve codegen coverage.
- Simplify the code.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Passes GA and existing UTs (e.g. CsvFunctionsSuite#`*schema_of_csv*`).

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48595 from panbingkun/SPARK-50067.

Authored-by: panbingkun <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
1 parent 2cb7a16 commit 369c40c
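
For reference, `schema_of_csv` infers a DDL-style schema string from a sample CSV record, and after this change that inference runs through the Invoke-based replacement shown in the diffs below. A minimal usage sketch, assuming an active SparkSession named `spark`; the inferred types in the comment are illustrative:

// schema_of_csv infers a schema from one example record; the second argument
// is an options map, here setting a pipe separator.
val ddl = spark.sql("SELECT schema_of_csv('1|abc', map('sep', '|'))").head().getString(0)
// Typically yields something like: STRUCT<_c0: INT, _c1: STRING>
println(ddl)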

4 files changed, +69 -26 lines changed
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.csv
+
+import com.univocity.parsers.csv.CsvParser
+
+import org.apache.spark.sql.catalyst.csv.{CSVInferSchema, CSVOptions}
+import org.apache.spark.sql.types.{DataType, NullType, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+case class SchemaOfCsvEvaluator(options: Map[String, String]) {
+
+  @transient
+  private lazy val csvOptions: CSVOptions = {
+    // 'lineSep' is a plan-wise option so we set a noncharacter, according to
+    // the unicode specification, which should not appear in Java's strings.
+    // See also SPARK-38955 and https://www.unicode.org/charts/PDF/UFFF0.pdf.
+    // scalastyle:off nonascii
+    val exprOptions = options ++ Map("lineSep" -> '\uFFFF'.toString)
+    // scalastyle:on nonascii
+    new CSVOptions(exprOptions, true, "UTC")
+  }
+
+  @transient
+  private lazy val csvParser: CsvParser = new CsvParser(csvOptions.asParserSettings)
+
+  @transient
+  private lazy val csvInferSchema = new CSVInferSchema(csvOptions)
+
+  final def evaluate(csv: UTF8String): Any = {
+    val row = csvParser.parseLine(csv.toString)
+    assert(row != null, "Parsed CSV record should not be null.")
+    val header = row.zipWithIndex.map { case (_, index) => s"_c$index" }
+    val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
+    val fieldTypes = csvInferSchema.inferRowType(startType, row)
+    val st = StructType(csvInferSchema.toStructFields(fieldTypes, header))
+    UTF8String.fromString(st.sql)
+  }
+}
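
Since the evaluator above is an ordinary case class, it can also be exercised directly; a rough, test-style sketch, assuming the Catalyst classes from this diff are on the classpath:

import org.apache.spark.sql.catalyst.expressions.csv.SchemaOfCsvEvaluator
import org.apache.spark.unsafe.types.UTF8String

// Build an evaluator with a pipe separator and infer the schema of one record.
val evaluator = SchemaOfCsvEvaluator(Map("sep" -> "|"))
val inferred = evaluator.evaluate(UTF8String.fromString("1|abc"))
// Expected to be a UTF8String holding the inferred DDL, e.g. STRUCT<_c0: INT, _c1: STRING>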

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala

Lines changed: 13 additions & 23 deletions
@@ -19,14 +19,14 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.io.CharArrayWriter
 
-import com.univocity.parsers.csv.CsvParser
-
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.csv._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.csv.SchemaOfCsvEvaluator
+import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.TypeUtils._
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
@@ -170,7 +170,7 @@ case class CsvToStructs(
 case class SchemaOfCsv(
     child: Expression,
     options: Map[String, String])
-  extends UnaryExpression with CodegenFallback with QueryErrorsBase {
+  extends UnaryExpression with RuntimeReplaceable with QueryErrorsBase {
 
   def this(child: Expression) = this(child, Map.empty[String, String])
 
@@ -202,30 +202,20 @@ case class SchemaOfCsv(
     }
   }
 
-  override def eval(v: InternalRow): Any = {
-    // 'lineSep' is a plan-wise option so we set a noncharacter, according to
-    // the unicode specification, which should not appear in Java's strings.
-    // See also SPARK-38955 and https://www.unicode.org/charts/PDF/UFFF0.pdf.
-    // scalastyle:off nonascii
-    val exprOptions = options ++ Map("lineSep" -> '\uFFFF'.toString)
-    // scalastyle:on nonascii
-    val parsedOptions = new CSVOptions(exprOptions, true, "UTC")
-    val parser = new CsvParser(parsedOptions.asParserSettings)
-    val row = parser.parseLine(csv.toString)
-    assert(row != null, "Parsed CSV record should not be null.")
-
-    val header = row.zipWithIndex.map { case (_, index) => s"_c$index" }
-    val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
-    val inferSchema = new CSVInferSchema(parsedOptions)
-    val fieldTypes = inferSchema.inferRowType(startType, row)
-    val st = StructType(inferSchema.toStructFields(fieldTypes, header))
-    UTF8String.fromString(st.sql)
-  }
-
   override def prettyName: String = "schema_of_csv"
 
   override protected def withNewChildInternal(newChild: Expression): SchemaOfCsv =
     copy(child = newChild)
+
+  @transient
+  private lazy val evaluator: SchemaOfCsvEvaluator = SchemaOfCsvEvaluator(options)
+
+  override def replacement: Expression = Invoke(
+    Literal.create(evaluator, ObjectType(classOf[SchemaOfCsvEvaluator])),
+    "evaluate",
+    dataType,
+    Seq(child),
+    Seq(child.dataType))
 }
 
 /**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala

Lines changed: 2 additions & 2 deletions
@@ -64,7 +64,7 @@ case class JsonToStructsEvaluator(
     nullableSchema: DataType,
     nameOfCorruptRecord: String,
    timeZoneId: Option[String],
-    variantAllowDuplicateKeys: Boolean) extends Serializable {
+    variantAllowDuplicateKeys: Boolean) {
 
   // This converts parsed rows to the desired output by the given schema.
   @transient
@@ -117,7 +117,7 @@ case class JsonToStructsEvaluator(
 case class StructsToJsonEvaluator(
     options: Map[String, String],
     inputSchema: DataType,
-    timeZoneId: Option[String]) extends Serializable {
+    timeZoneId: Option[String]) {
 
   @transient
   private lazy val writer = new CharArrayWriter()
Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
-Project [schema_of_csv(1|abc, (sep,|)) AS schema_of_csv(1|abc)#0]
+Project [invoke(SchemaOfCsvEvaluator(Map(sep -> |)).evaluate(1|abc)) AS schema_of_csv(1|abc)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
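
One hedged way to reproduce a plan of this shape from user code, assuming an active SparkSession named `spark` (column names and exact plan text will differ):

import scala.jdk.CollectionConverters._
import org.apache.spark.sql.functions.{lit, schema_of_csv}

// With the RuntimeReplaceable rewrite, the optimized plan shows the evaluator
// being invoked instead of a CodegenFallback expression.
spark.range(1)
  .select(schema_of_csv(lit("1|abc"), Map("sep" -> "|").asJava))
  .explain(true)
// The optimized plan should contain a projection like:
//   invoke(SchemaOfCsvEvaluator(Map(sep -> |)).evaluate(1|abc))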
