Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -149,17 +149,17 @@ class Analyzer(
* @param exprs the attributes in sequence
* @return the attributes of non selected specified via bitmask (with the bit set to 1)
*/
private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
: OpenHashSet[Expression] = {
val set = new OpenHashSet[Expression](2)
private def buildNonSelectExprs(bitmask: Int, exprs: Seq[Expression])
: collection.mutable.ArrayBuffer[Expression] = {
val buffer = new collection.mutable.ArrayBuffer[Expression]()

var bit = exprs.length - 1
while (bit >= 0) {
if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
if (((bitmask >> bit) & 1) == 0) buffer += exprs(bit)
bit -= 1
}

set
buffer
}

/*
Expand Down Expand Up @@ -198,10 +198,12 @@ class Analyzer(

g.bitmasks.foreach { bitmask =>
// get the non selected grouping attributes according to the bit mask
val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, g.groupByExprs)
val nonSelectedGroupExprs = buildNonSelectExprs(bitmask, g.groupByExprs)

val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown {
case x: Expression if nonSelectedGroupExprSet.contains(x) =>
case x: Expression
if nonSelectedGroupExprs.find(
ExpressionEquality(_) == ExpressionEquality(x)).isDefined =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we add semanticEquals method to Expression and implement it with equals as default? Then we can override semanticEquals for AttributeReference and just write _.semanticEquals(x) here.
ExpressionEquality seems a little hack to me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But you need to add semanticEquals for all of the non-leaf expression classes.

// if the input attribute in the Invalid Grouping Expression set of for this group
// replace it with constant null
Literal.create(null, expr.dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,14 @@ trait CheckAnalysis {
case Aggregate(groupingExprs, aggregateExprs, child) =>
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
case _: AggregateExpression => // OK
case e: Attribute if !groupingExprs.contains(e) =>
case e: Attribute
if groupingExprs.find(ExpressionEquality(_) == ExpressionEquality(e)).isEmpty =>
failAnalysis(
s"expression '${e.prettyString}' is neither present in the group by, " +
s"nor is it an aggregate function. " +
"Add to group by or wrap in first() if you don't care which value you get.")
case e if groupingExprs.contains(e) => // OK
case e if groupingExprs.find(
ExpressionEquality(_) == ExpressionEquality(e)).isDefined => // OK
case e if e.references.isEmpty => // OK
case e => e.children.foreach(checkValidAggregateExpression)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,19 @@ case class GroupExpression(children: Seq[Expression]) extends Expression {
override def dataType: DataType = throw new UnsupportedOperationException
}

/**
* The output expression probably will be invalid, and this is ONLY
* for expression equality checking purpose.
*/
object ExpressionEquality {
def apply(expr: Expression): Expression = expr.transformUp {
case n: AttributeReference =>
// This is a hack way to simplify the expression, as we don't care about
// the `name` for AttributeReference in semantic equality for an expression
new AttributeReference(null, n.dataType, n.nullable, n.metadata)(n.exprId, n.qualifiers)
}
}

/**
* Expressions that require a specific `DataType` as input should implement this trait
* so that the proper type conversions can be performed in the analyzer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ object PartialAggregation {
// Should trim aliases around `GetField`s. These aliases are introduced while
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
// (Should we just turn `GetField` into a `NamedExpression`?)
namedGroupingExpressions
.get(e.transform { case Alias(g: ExtractValue, _) => g })
.map(_.toAttribute)
val ee = ExpressionEquality(e.transform { case Alias(g: ExtractValue, _) => g })
namedGroupingExpressions.find { case (k, v) => ExpressionEquality(k) == ee }
.map(_._2.toAttribute)
.getOrElse(e)
}).asInstanceOf[Seq[NamedExpression]]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,24 @@ class SQLQuerySuite extends QueryTest {
sql("SELECT `key` FROM src").collect().toSeq)
}

test("SPARK-7269 Check analysis failed in case in-sensitive") {
Seq(1, 2, 3).map { i =>
(i.toString, i.toString)
}.toDF("key", "value").registerTempTable("df_analysis")
sql("SELECT kEy from df_analysis group by key").collect()
sql("SELECT kEy+3 from df_analysis group by key+3").collect()
sql("SELECT kEy+3, a.kEy, A.kEy from df_analysis A group by key").collect()
sql("SELECT cast(kEy+1 as Int) from df_analysis A group by cast(key+1 as int)").collect()
sql("SELECT cast(kEy+1 as Int) from df_analysis A group by key+1").collect()
sql("SELECT 2 from df_analysis A group by key+1").collect()
intercept[AnalysisException] {
sql("SELECT kEy+1 from df_analysis group by key+3")
}
intercept[AnalysisException] {
sql("SELECT cast(key+2 as Int) from df_analysis A group by cast(key+1 as int)")
}
}

test("SPARK-3814 Support Bitwise & operator") {
checkAnswer(
sql("SELECT case when 1&1=1 then 1 else 0 end FROM src"),
Expand Down