@@ -74,7 +74,7 @@ class SessionCatalog(
       functionRegistry,
       conf,
       new Configuration(),
-      CatalystSqlParser,
+      new CatalystSqlParser(conf),
       DummyFunctionResourceLoader)
   }

@@ -45,9 +45,11 @@ import org.apache.spark.util.random.RandomSampler
  * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or
  * TableIdentifier.
  */
-class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
+class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging {
   import ParserUtils._

+  def this() = this(new SQLConf())
+
   protected def typedVisit[T](ctx: ParseTree): T = {
     ctx.accept(this).asInstanceOf[T]
   }
@@ -1457,7 +1459,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
    * Special characters can be escaped by using Hive/C-style escaping.
    */
   private def createString(ctx: StringLiteralContext): String = {
-    if (SQLConf.get.escapedStringLiterals) {
+    if (conf.escapedStringLiterals) {
       ctx.STRING().asScala.map(stringWithoutUnescape).mkString
     } else {
       ctx.STRING().asScala.map(string).mkString

@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}

 /**
@@ -121,8 +122,13 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
 /**
  * Concrete SQL parser for Catalyst-only SQL statements.
  */
+class CatalystSqlParser(conf: SQLConf) extends AbstractSqlParser {
+  val astBuilder = new AstBuilder(conf)
+}
+
+/** For test-only. */
 object CatalystSqlParser extends AbstractSqlParser {
-  val astBuilder = new AstBuilder
+  val astBuilder = new AstBuilder(new SQLConf())
 }

 /**

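To illustrate the new entry point, here is a minimal sketch (not part of the patch) of constructing a conf-carrying parser; the classes and the conf key are the ones that appear in this diff:

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.internal.SQLConf

// Each parser instance now carries the SQLConf it was built with,
// rather than reading a global/thread-local conf at parse time.
val conf = new SQLConf()
conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, "true")

val parser = new CatalystSqlParser(conf)
// With escapedStringLiterals enabled, escape sequences are kept verbatim,
// so this literal keeps the characters \t\n instead of a tab and a newline.
val e = parser.parseExpression("'pattern\\t\\n'")
```
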
@@ -167,12 +167,12 @@ class ExpressionParserSuite extends PlanTest {
   }

   test("like expressions with ESCAPED_STRING_LITERALS = true") {
-    val parser = CatalystSqlParser
-    withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> "true") {
-      assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser)
-      assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser)
-      assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser)
-    }
+    val conf = new SQLConf()
+    conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, "true")
+    val parser = new CatalystSqlParser(conf)
+    assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser)
+    assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser)
+    assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser)
   }

   test("is null expressions") {
@@ -435,85 +435,86 @@ class ExpressionParserSuite extends PlanTest {
   }

   test("strings") {
-    val parser = CatalystSqlParser
     Seq(true, false).foreach { escape =>
-      withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> escape.toString) {
-        // tests that have same result whatever the conf is
-        // Single Strings.
-        assertEqual("\"hello\"", "hello", parser)
-        assertEqual("'hello'", "hello", parser)
-
-        // Multi-Strings.
-        assertEqual("\"hello\" 'world'", "helloworld", parser)
-        assertEqual("'hello' \" \" 'world'", "hello world", parser)
-
-        // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a
-        // regular '%'; to get the correct result you need to add another escaped '\'.
-        // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method?
-        assertEqual("'pattern%'", "pattern%", parser)
-        assertEqual("'no-pattern\\%'", "no-pattern\\%", parser)
-
-        // tests that have different result regarding the conf
-        if (escape) {
-          // When SQLConf.ESCAPED_STRING_LITERALS is enabled, string literal parsing fallbacks to
-          // Spark 1.6 behavior.
-
-          // 'LIKE' string literals.
-          assertEqual("'pattern\\\\%'", "pattern\\\\%", parser)
-          assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser)
-
-          // Escaped characters.
-          // Unescape string literal "'\\0'" for ASCII NUL (X'00') doesn't work
-          // when ESCAPED_STRING_LITERALS is enabled.
-          // It is parsed literally.
-          assertEqual("'\\0'", "\\0", parser)
-
-          // Note: Single quote follows 1.6 parsing behavior when ESCAPED_STRING_LITERALS is
-          // enabled.
-          val e = intercept[ParseException](parser.parseExpression("'\''"))
-          assert(e.message.contains("extraneous input '''"))
-
-          // The unescape special characters (e.g., "\\t") for 2.0+ don't work
-          // when ESCAPED_STRING_LITERALS is enabled. They are parsed literally.
-          assertEqual("'\\\"'", "\\\"", parser) // Double quote
-          assertEqual("'\\b'", "\\b", parser) // Backspace
-          assertEqual("'\\n'", "\\n", parser) // Newline
-          assertEqual("'\\r'", "\\r", parser) // Carriage return
-          assertEqual("'\\t'", "\\t", parser) // Tab character
-
-          // The unescape Octals for 2.0+ don't work when ESCAPED_STRING_LITERALS is enabled.
-          // They are parsed literally.
-          assertEqual("'\\110\\145\\154\\154\\157\\041'", "\\110\\145\\154\\154\\157\\041", parser)
-          // The unescape Unicode for 2.0+ doesn't work when ESCAPED_STRING_LITERALS is enabled.
-          // They are parsed literally.
-          assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'",
-            "\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029", parser)
-        } else {
-          // Default behavior
-
-          // 'LIKE' string literals.
-          assertEqual("'pattern\\\\%'", "pattern\\%", parser)
-          assertEqual("'pattern\\\\\\%'", "pattern\\\\%", parser)
-
-          // Escaped characters.
-          // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html
-          assertEqual("'\\0'", "\u0000", parser) // ASCII NUL (X'00')
-          assertEqual("'\\''", "\'", parser) // Single quote
-          assertEqual("'\\\"'", "\"", parser) // Double quote
-          assertEqual("'\\b'", "\b", parser) // Backspace
-          assertEqual("'\\n'", "\n", parser) // Newline
-          assertEqual("'\\r'", "\r", parser) // Carriage return
-          assertEqual("'\\t'", "\t", parser) // Tab character
-          assertEqual("'\\Z'", "\u001A", parser) // ASCII 26 - CTRL + Z (EOF on windows)
-
-          // Octals
-          assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!", parser)
-
-          // Unicode
-          assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)",
-            parser)
-        }
-      }
+      val conf = new SQLConf()
+      conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, escape.toString)
+      val parser = new CatalystSqlParser(conf)
+
+      // Tests that produce the same result regardless of the conf.
+      // Single Strings.
+      assertEqual("\"hello\"", "hello", parser)
+      assertEqual("'hello'", "hello", parser)
+
+      // Multi-Strings.
+      assertEqual("\"hello\" 'world'", "helloworld", parser)
+      assertEqual("'hello' \" \" 'world'", "hello world", parser)
+
+      // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a
+      // regular '%'; to get the correct result you need to add another escaped '\'.
+      // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method?
+      assertEqual("'pattern%'", "pattern%", parser)
+      assertEqual("'no-pattern\\%'", "no-pattern\\%", parser)
+
+      // Tests whose results depend on the conf.
+      if (escape) {
+        // When SQLConf.ESCAPED_STRING_LITERALS is enabled, string literal parsing falls back to
+        // Spark 1.6 behavior.
+
+        // 'LIKE' string literals.
+        assertEqual("'pattern\\\\%'", "pattern\\\\%", parser)
+        assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser)
+
+        // Escaped characters.
+        // Unescaping the string literal "'\\0'" to ASCII NUL (X'00') doesn't work
+        // when ESCAPED_STRING_LITERALS is enabled; it is parsed literally.
+        assertEqual("'\\0'", "\\0", parser)
+
+        // Note: Single quote follows 1.6 parsing behavior when ESCAPED_STRING_LITERALS is enabled.
+        val e = intercept[ParseException](parser.parseExpression("'\''"))
+        assert(e.message.contains("extraneous input '''"))
+
+        // The 2.0+ unescaping of special characters (e.g., "\\t") doesn't work
+        // when ESCAPED_STRING_LITERALS is enabled. They are parsed literally.
+        assertEqual("'\\\"'", "\\\"", parser) // Double quote
+        assertEqual("'\\b'", "\\b", parser) // Backspace
+        assertEqual("'\\n'", "\\n", parser) // Newline
+        assertEqual("'\\r'", "\\r", parser) // Carriage return
+        assertEqual("'\\t'", "\\t", parser) // Tab character
+
+        // The 2.0+ unescaping of octal sequences doesn't work when ESCAPED_STRING_LITERALS is
+        // enabled. They are parsed literally.
+        assertEqual("'\\110\\145\\154\\154\\157\\041'", "\\110\\145\\154\\154\\157\\041", parser)
+        // The 2.0+ unescaping of Unicode sequences doesn't work when ESCAPED_STRING_LITERALS is
+        // enabled. They are parsed literally.
+        assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'",
+          "\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029", parser)
+      } else {
+        // Default behavior
+
+        // 'LIKE' string literals.
+        assertEqual("'pattern\\\\%'", "pattern\\%", parser)
+        assertEqual("'pattern\\\\\\%'", "pattern\\\\%", parser)
+
+        // Escaped characters.
+        // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html
+        assertEqual("'\\0'", "\u0000", parser) // ASCII NUL (X'00')
+        assertEqual("'\\''", "\'", parser) // Single quote
+        assertEqual("'\\\"'", "\"", parser) // Double quote
+        assertEqual("'\\b'", "\b", parser) // Backspace
+        assertEqual("'\\n'", "\n", parser) // Newline
+        assertEqual("'\\r'", "\r", parser) // Carriage return
+        assertEqual("'\\t'", "\t", parser) // Tab character
+        assertEqual("'\\Z'", "\u001A", parser) // ASCII 26 - CTRL + Z (EOF on windows)
+
+        // Octals
+        assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!", parser)
+
+        // Unicode
+        assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)",
+          parser)
+      }
     }
   }

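The motivation for the rewrite above is easier to see side by side: with the conf injected at construction, two parsers with different literal-parsing behavior can coexist in the same JVM without mutating shared state. A sketch (expected results taken from the assertions in this suite):

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.internal.SQLConf

def parserWith(escaped: Boolean): CatalystSqlParser = {
  val conf = new SQLConf()
  conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, escaped.toString)
  new CatalystSqlParser(conf)
}

val defaultParser = parserWith(escaped = false)
val legacyParser = parserWith(escaped = true)

// The default parser unescapes octal sequences ...
defaultParser.parseExpression("'\\110\\145\\154\\154\\157\\041'") // Literal("Hello!")
// ... while the Spark 1.6-style parser keeps them verbatim.
legacyParser.parseExpression("'\\110\\145\\154\\154\\157\\041'") // Literal("\110\145\154\154\157\041")
```
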
@@ -39,11 +39,10 @@ import org.apache.spark.sql.types.StructType
 /**
  * Concrete parser for Spark SQL statements.
  */
-class SparkSqlParser extends AbstractSqlParser {
+class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser {
+  val astBuilder = new SparkSqlAstBuilder(conf)

-  val astBuilder = new SparkSqlAstBuilder
-
-  private val substitutor = new VariableSubstitution
+  private val substitutor = new VariableSubstitution(conf)

   protected override def parse[T](command: String)(toResult: SqlBaseParser => T): T = {
     super.parse(substitutor.substitute(command))(toResult)
@@ -53,11 +52,9 @@ class SparkSqlParser extends AbstractSqlParser {
 /**
  * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier.
  */
-class SparkSqlAstBuilder extends AstBuilder {
+class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
   import org.apache.spark.sql.catalyst.parser.ParserUtils._

-  private def conf: SQLConf = SQLConf.get
-
   /**
    * Create a [[SetCommand]] logical plan.
    *

sql/core/src/main/scala/org/apache/spark/sql/functions.scala (3 changes: 2 additions & 1 deletion)
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint}
 import org.apache.spark.sql.execution.SparkSqlParser
 import org.apache.spark.sql.expressions.UserDefinedFunction
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils

@@ -1275,7 +1276,7 @@
    */
   def expr(expr: String): Column = {
     val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse {
-      new SparkSqlParser
+      new SparkSqlParser(new SQLConf)
     }
     Column(parser.parseExpression(expr))
   }

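A consequence of this fallback worth noting (inferred from the change above, not stated elsewhere in the diff): `functions.expr` only picks up session-specific parser settings while a session is active; otherwise it parses with `new SQLConf` defaults. Sketch:

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.expr

// With an active SparkSession this resolves sessionState.sqlParser (and
// thus the session's conf); without one it parses with default settings,
// e.g. spark.sql.parser.escapedStringLiterals = false.
val c: Column = expr("a + 1")
```
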
@@ -114,7 +114,7 @@ abstract class BaseSessionStateBuilder(
    * Note: this depends on the `conf` field.
    */
   protected lazy val sqlParser: ParserInterface = {
-    extensions.buildParser(session, new SparkSqlParser)
+    extensions.buildParser(session, new SparkSqlParser(conf))
   }

   /**

@@ -25,9 +25,7 @@ import org.apache.spark.internal.config._
  *
  * Variable substitution is controlled by `SQLConf.variableSubstituteEnabled`.
  */
-class VariableSubstitution {
-
-  private def conf = SQLConf.get
+class VariableSubstitution(conf: SQLConf) {

   private val provider = new ConfigProvider {
     override def get(key: String): Option[String] = Option(conf.getConfString(key, ""))

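For reference, a small sketch of the substitutor with an injected conf (the `myapp.table` key is made up for illustration; substitution remains gated by `SQLConf.variableSubstituteEnabled`, as the doc comment above says):

```scala
import org.apache.spark.sql.internal.{SQLConf, VariableSubstitution}

val conf = new SQLConf()
// Hypothetical key, set only to demonstrate ${...} expansion.
conf.setConfString("myapp.table", "logs")

val substitutor = new VariableSubstitution(conf)
// Variables are now resolved against the injected conf, not a global one.
substitutor.substitute("SELECT * FROM ${myapp.table}") // "SELECT * FROM logs"
```
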
@@ -37,7 +37,8 @@ import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
  */
 class SparkSqlParserSuite extends AnalysisTest {

-  private lazy val parser = new SparkSqlParser
+  val newConf = new SQLConf
+  private lazy val parser = new SparkSqlParser(newConf)

   /**
    * Normalizes plans:
@@ -284,7 +285,6 @@ class SparkSqlParserSuite extends AnalysisTest {
   }

   test("query organization") {
-    val conf = SQLConf.get
     // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows
     val baseSql = "select * from t"
     val basePlan =
@@ -293,20 +293,20 @@
assertEqual(s"$baseSql distribute by a, b",
RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
basePlan,
numPartitions = conf.numShufflePartitions))
numPartitions = newConf.numShufflePartitions))
assertEqual(s"$baseSql distribute by a sort by b",
Sort(SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
global = false,
RepartitionByExpression(UnresolvedAttribute("a") :: Nil,
basePlan,
numPartitions = conf.numShufflePartitions)))
numPartitions = newConf.numShufflePartitions)))
assertEqual(s"$baseSql cluster by a, b",
Sort(SortOrder(UnresolvedAttribute("a"), Ascending) ::
SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
global = false,
RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
basePlan,
numPartitions = conf.numShufflePartitions)))
numPartitions = newConf.numShufflePartitions)))
}

test("pipeline concatenation") {
Expand Down
@@ -29,13 +29,13 @@ import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.execution.SparkSqlParser
 import org.apache.spark.sql.execution.datasources.CreateTable
-import org.apache.spark.sql.internal.HiveSerDe
+import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}


 // TODO: merge this with DDLSuite (SPARK-14441)
 class DDLCommandSuite extends PlanTest {
-  private lazy val parser = new SparkSqlParser
+  private lazy val parser = new SparkSqlParser(new SQLConf)

   private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = {
     val e = intercept[ParseException] {