Skip to content

Commit eb63949

Browse files
TeodorDjeliccloud-fan
authored andcommitted
[SPARK-52908][CORE] Prevent for iterator variable name clashing with names of labels in the path to the root of AST
### What changes were proposed in this pull request? Proposed change is to explicitly prohibit the interaction of iterator variable hiding the scoped variable if the label of scope and the iterator variable names are the same. ### Why are the changes needed? For iterator variable hides scoped variables if the label of the scope and iterator variable name are the same. This interaction leads to undesirable behavior: - Column of the iterator variable and a variable in scope having the same name will result in the column of the iterator variable hiding the variable in scope; - Trying to access the variable in scope that does not clash with the column of the iterator variable will result in the compiler not being able to resolve the variable in scope. ### Does this PR introduce _any_ user-facing change? Yes, it does. Changes are: - Error LABEL_ALREADY_DEFINED was renamed to LABEL_OR_FOR_VARIABLE_ALREADY_DEFINED; - Error LABEL_NAME_FORBIDDEN was renamed to LABEL_OR_FOR_VARIABLE_NAME_FORBIDDEN. Old behavior: ![2C77E459-18FA-4B0E-8BDF-AA5D3E9B412F](https://github.com/user-attachments/assets/007ec190-e27f-4d70-b5a8-fb0272e05057) <img width="1198" height="214" alt="image" src="https://github.com/user-attachments/assets/b2586697-49e9-4cbc-b57e-53b6c91700bc" /> New behavior: <img width="1618" height="162" alt="image" src="https://github.com/user-attachments/assets/d023715c-08a1-47a2-9db1-3a19758140d6" /> <img width="1393" height="110" alt="image" src="https://github.com/user-attachments/assets/855d80c1-3fbd-42a6-ab1e-f664c4b4b47e" /> ### How was this patch tested? New tests in SqlScriptingExecutionSuite and existing tests. Instead of printing a variable resolution exception, exception printed is stating the prohibition of such interactions. Old behavior: <img width="1335" height="380" alt="467960247-895da398-3ace-4334-b597-1be4a400acf4" src="https://github.com/user-attachments/assets/2d9412e1-5896-4286-b230-b728315a0fc6" /> New behavior: <img width="1651" height="263" alt="467961070-92a0cbb7-bb93-410a-8266-3ef2591350f8" src="https://github.com/user-attachments/assets/5a44b505-6be2-4c1a-9f05-7fa563eab441" /> ### Was this patch authored or co-authored using generative AI tooling? No. Closes #51595 from TeodorDjelic/prevent-for-iterator-variable-name-clashing-with-names-of-labels-in-the-path-to-the-root-of-ast. Authored-by: Teodor Djelic <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 3713725 commit eb63949

File tree

6 files changed

+206
-29
lines changed

6 files changed

+206
-29
lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4051,15 +4051,15 @@
40514051
],
40524052
"sqlState" : "42K0L"
40534053
},
4054-
"LABEL_ALREADY_EXISTS" : {
4054+
"LABEL_OR_FOR_VARIABLE_ALREADY_EXISTS" : {
40554055
"message" : [
4056-
"The label <label> already exists. Choose another name or rename the existing label."
4056+
"The label or FOR variable <label> already exists. Choose another name or rename the existing one."
40574057
],
40584058
"sqlState" : "42K0L"
40594059
},
4060-
"LABEL_NAME_FORBIDDEN" : {
4060+
"LABEL_OR_FOR_VARIABLE_NAME_FORBIDDEN" : {
40614061
"message" : [
4062-
"The label name <label> is forbidden."
4062+
"The label or FOR variable name <label> is forbidden."
40634063
],
40644064
"sqlState" : "42K0L"
40654065
},

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.apache.spark.sql.AnalysisException
2626
import org.apache.spark.sql.catalyst.SqlScriptingContextManager
2727
import org.apache.spark.sql.catalyst.expressions._
2828
import org.apache.spark.sql.catalyst.expressions.SubExprUtils.wrapOuterReference
29-
import org.apache.spark.sql.catalyst.parser.SqlScriptingLabelContext.isForbiddenLabelName
29+
import org.apache.spark.sql.catalyst.parser.SqlScriptingLabelContext.isForbiddenLabelOrForVariableName
3030
import org.apache.spark.sql.catalyst.plans.logical._
3131
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
3232
import org.apache.spark.sql.catalyst.trees.TreePattern._
@@ -269,7 +269,8 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
269269
.filterNot(_ => AnalysisContext.get.isExecuteImmediate)
270270
// If variable name is qualified with session.<varName> treat it as a session variable.
271271
.filterNot(_ =>
272-
nameParts.length > 2 || (nameParts.length == 2 && isForbiddenLabelName(nameParts.head)))
272+
nameParts.length > 2
273+
|| (nameParts.length == 2 && isForbiddenLabelOrForVariableName(nameParts.head)))
273274
.flatMap(_.get(namePartsCaseAdjusted))
274275
.map { varDef =>
275276
VariableReference(

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,13 +557,15 @@ class AstBuilder extends DataTypeAstBuilder
557557
val query = withOrigin(queryCtx) {
558558
SingleStatement(visitQuery(queryCtx))
559559
}
560+
parsingCtx.labelContext.enterForScope(Option(ctx.multipartIdentifier()))
560561
val varName = Option(ctx.multipartIdentifier()).map(_.getText)
561562
val body = visitCompoundBodyImpl(
562563
ctx.compoundBody(),
563564
None,
564565
parsingCtx,
565566
isScope = false
566567
)
568+
parsingCtx.labelContext.exitForScope(Option(ctx.multipartIdentifier()))
567569
parsingCtx.labelContext.exitLabeledScope(Option(ctx.beginLabel()))
568570

569571
ForStatement(query, varName, body, Some(labelText))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.antlr.v4.runtime.tree.{ParseTree, TerminalNodeImpl}
2828

2929
import org.apache.spark.SparkException
3030
import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier
31-
import org.apache.spark.sql.catalyst.parser.SqlBaseParser.{BeginLabelContext, EndLabelContext}
31+
import org.apache.spark.sql.catalyst.parser.SqlBaseParser.{BeginLabelContext, EndLabelContext, MultipartIdentifierContext}
3232
import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, ErrorCondition}
3333
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
3434
import org.apache.spark.sql.catalyst.util.SparkParserUtils
@@ -316,6 +316,23 @@ class SqlScriptingLabelContext {
316316
beginLabelCtx.map(_.multipartIdentifier().getText).isDefined
317317
}
318318

319+
/**
320+
* Assert the identifier is not contained within seenLabels.
321+
* If the identifier is contained within seenLabels, raise an exception.
322+
*/
323+
private def assertIdentifierNotInSeenLabels(
324+
identifierCtx: Option[MultipartIdentifierContext]): Unit = {
325+
identifierCtx.foreach { ctx =>
326+
val identifierName = ctx.getText
327+
if (seenLabels.contains(identifierName.toLowerCase(Locale.ROOT))) {
328+
withOrigin(ctx) {
329+
throw SqlScriptingErrors
330+
.duplicateLabels(CurrentOrigin.get, identifierName.toLowerCase(Locale.ROOT))
331+
}
332+
}
333+
}
334+
}
335+
319336
/**
320337
* Enter a labeled scope and return the label text.
321338
* If the label is defined, it will be returned and added to seenLabels.
@@ -342,9 +359,9 @@ class SqlScriptingLabelContext {
342359
// Do not add the label to the seenLabels set if it is not defined.
343360
java.util.UUID.randomUUID.toString.toLowerCase(Locale.ROOT)
344361
}
345-
if (SqlScriptingLabelContext.isForbiddenLabelName(labelText)) {
362+
if (SqlScriptingLabelContext.isForbiddenLabelOrForVariableName(labelText)) {
346363
withOrigin(beginLabelCtx.get) {
347-
throw SqlScriptingErrors.labelNameForbidden(CurrentOrigin.get, labelText)
364+
throw SqlScriptingErrors.labelOrForVariableNameForbidden(CurrentOrigin.get, labelText)
348365
}
349366
}
350367
labelText
@@ -359,13 +376,46 @@ class SqlScriptingLabelContext {
359376
seenLabels.remove(beginLabelCtx.get.multipartIdentifier().getText.toLowerCase(Locale.ROOT))
360377
}
361378
}
379+
380+
/**
381+
* Enter a for loop scope.
382+
* If the for loop variable is defined, it will be asserted to not be inside seenLabels;
383+
* Then, if the for loop variable is defined, it will be added to seenLabels.
384+
*/
385+
def enterForScope(identifierCtx: Option[MultipartIdentifierContext]): Unit = {
386+
identifierCtx.foreach { ctx =>
387+
val identifierName = ctx.getText
388+
assertIdentifierNotInSeenLabels(identifierCtx)
389+
seenLabels.add(identifierName.toLowerCase(Locale.ROOT))
390+
391+
if (SqlScriptingLabelContext.isForbiddenLabelOrForVariableName(identifierName)) {
392+
withOrigin(ctx) {
393+
throw SqlScriptingErrors.labelOrForVariableNameForbidden(
394+
CurrentOrigin.get,
395+
identifierName.toLowerCase(Locale.ROOT))
396+
}
397+
}
398+
}
399+
}
400+
401+
/**
402+
* Exit a for loop scope.
403+
* If the for loop variable is defined, it will be removed from seenLabels.
404+
*/
405+
def exitForScope(identifierCtx: Option[MultipartIdentifierContext]): Unit = {
406+
identifierCtx.foreach { ctx =>
407+
val identifierName = ctx.getText
408+
seenLabels.remove(identifierName.toLowerCase(Locale.ROOT))
409+
}
410+
}
411+
362412
}
363413

364414
object SqlScriptingLabelContext {
365415
private val forbiddenLabelNames: immutable.Set[Regex] =
366416
immutable.Set("builtin".r, "session".r, "sys.*".r)
367417

368-
def isForbiddenLabelName(labelName: String): Boolean = {
418+
def isForbiddenLabelOrForVariableName(labelName: String): Boolean = {
369419
forbiddenLabelNames.exists(_.matches(labelName.toLowerCase(Locale.ROOT)))
370420
}
371421
}

sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ private[sql] object SqlScriptingErrors {
3333
def duplicateLabels(origin: Origin, label: String): Throwable = {
3434
new SqlScriptingException(
3535
origin = origin,
36-
errorClass = "LABEL_ALREADY_EXISTS",
36+
errorClass = "LABEL_OR_FOR_VARIABLE_ALREADY_EXISTS",
3737
cause = null,
3838
messageParameters = Map("label" -> toSQLId(label)))
3939
}
@@ -54,10 +54,10 @@ private[sql] object SqlScriptingErrors {
5454
messageParameters = Map("endLabel" -> toSQLId(endLabel)))
5555
}
5656

57-
def labelNameForbidden(origin: Origin, label: String): Throwable = {
57+
def labelOrForVariableNameForbidden(origin: Origin, label: String): Throwable = {
5858
new SqlScriptingException(
5959
origin = origin,
60-
errorClass = "LABEL_NAME_FORBIDDEN",
60+
errorClass = "LABEL_OR_FOR_VARIABLE_NAME_FORBIDDEN",
6161
cause = null,
6262
messageParameters = Map("label" -> toSQLId(label))
6363
)

0 commit comments

Comments
 (0)