Skip to content

Commit 715c589

Browse files
committed
address comments
1 parent 7ea5b31 commit 715c589

File tree

13 files changed

+96
-92
lines changed

13 files changed

+96
-92
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,9 +375,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
375375
protected lazy val primary: PackratParser[Expression] =
376376
( literal
377377
| expression ~ ("[" ~> expression <~ "]") ^^
378-
{ case base ~ ordinal => UnresolvedGetField(base, ordinal) }
378+
{ case base ~ ordinal => UnresolvedExtractValue(base, ordinal) }
379379
| (expression <~ ".") ~ ident ^^
380-
{ case base ~ fieldName => UnresolvedGetField(base, Literal(fieldName)) }
380+
{ case base ~ fieldName => UnresolvedExtractValue(base, Literal(fieldName)) }
381381
| cast
382382
| "(" ~> expression <~ ")"
383383
| function

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,8 @@ class Analyzer(
311311
withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
312312
logDebug(s"Resolving $u to $result")
313313
result
314-
case UnresolvedGetField(child, fieldExpr) if child.resolved =>
315-
GetField(child, fieldExpr, resolver)
314+
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
315+
ExtractValue(child, fieldExpr, resolver)
316316
}
317317
}
318318

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,16 @@ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star {
185185
}
186186

187187
/**
188-
* Get field of an expression
188+
* Extracts a value or values from an Expression
189189
*
190-
* @param child The expression to get field of, can be Map, Array, Struct or array of Struct.
191-
* @param fieldExpr The expression to describe the field,
192-
* can be key of Map, index of Array, field name of Struct.
190+
* @param child The expression to extract value from,
191+
* can be Map, Array, Struct or array of Structs.
192+
* @param extraction The expression to describe the extraction,
193+
* can be key of Map, index of Array, field name of Struct.
193194
*/
194-
case class UnresolvedGetField(child: Expression, fieldExpr: Expression) extends UnaryExpression {
195+
case class UnresolvedExtractValue(child: Expression, extraction: Expression)
196+
extends UnaryExpression {
197+
195198
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
196199
override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
197200
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
@@ -200,5 +203,5 @@ case class UnresolvedGetField(child: Expression, fieldExpr: Expression) extends
200203
override def eval(input: Row = null): EvaluatedType =
201204
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
202205

203-
override def toString: String = s"$child[$fieldExpr]"
206+
override def toString: String = s"$child[$extraction]"
204207
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
2222
import scala.language.implicitConversions
2323
import scala.reflect.runtime.universe.{TypeTag, typeTag}
2424

25-
import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, UnresolvedAttribute}
25+
import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute}
2626
import org.apache.spark.sql.catalyst.expressions._
2727
import org.apache.spark.sql.catalyst.plans.logical._
2828
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -100,9 +100,9 @@ package object dsl {
100100
def isNull: Predicate = IsNull(expr)
101101
def isNotNull: Predicate = IsNotNull(expr)
102102

103-
def getItem(ordinal: Expression): UnresolvedGetField = UnresolvedGetField(expr, ordinal)
104-
def getField(fieldName: String): UnresolvedGetField =
105-
UnresolvedGetField(expr, Literal(fieldName))
103+
def getItem(ordinal: Expression): UnresolvedExtractValue = UnresolvedExtractValue(expr, ordinal)
104+
def getField(fieldName: String): UnresolvedExtractValue =
105+
UnresolvedExtractValue(expr, Literal(fieldName))
106106

107107
def cast(to: DataType): Expression = Cast(expr, to)
108108

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GetField.scala renamed to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,50 +23,50 @@ import org.apache.spark.sql.AnalysisException
2323
import org.apache.spark.sql.catalyst.analysis._
2424
import org.apache.spark.sql.types._
2525

26-
object GetField {
26+
object ExtractValue {
2727
/**
28-
* Returns the resolved `GetField`. It will return one kind of concrete `GetField`,
29-
* depend on the type of `child` and `fieldExpr`.
28+
* Returns the resolved `ExtractValue`. It will return one kind of concrete `ExtractValue`,
29+
* depending on the type of `child` and `extraction`.
3030
*
31-
* `child` | `fieldExpr` | concrete `GetField`
32-
* -------------------------------------------------------------
33-
* Struct | Literal String | SimpleStructGetField
34-
* Array[Struct] | Literal String | ArrayStructGetField
35-
* Array | Integral type | ArrayOrdinalGetField
36-
* Map | Any type | MapOrdinalGetField
31+
* `child` | `extraction` | concrete `ExtractValue`
32+
* ----------------------------------------------------------------
33+
* Struct | Literal String | GetStructField
34+
* Array[Struct] | Literal String | GetArrayStructFields
35+
* Array | Integral type | GetArrayItem
36+
* Map | Any type | GetMapValue
3737
*/
3838
def apply(
3939
child: Expression,
40-
fieldExpr: Expression,
41-
resolver: Resolver): GetField = {
40+
extraction: Expression,
41+
resolver: Resolver): ExtractValue = {
4242

43-
(child.dataType, fieldExpr) match {
43+
(child.dataType, extraction) match {
4444
case (StructType(fields), Literal(fieldName, StringType)) =>
4545
val ordinal = findField(fields, fieldName.toString, resolver)
46-
SimpleStructGetField(child, fields(ordinal), ordinal)
46+
GetStructField(child, fields(ordinal), ordinal)
4747
case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) =>
4848
val ordinal = findField(fields, fieldName.toString, resolver)
49-
ArrayStructGetField(child, fields(ordinal), ordinal, containsNull)
50-
case (_: ArrayType, _) if fieldExpr.dataType.isInstanceOf[IntegralType] =>
51-
ArrayOrdinalGetField(child, fieldExpr)
49+
GetArrayStructFields(child, fields(ordinal), ordinal, containsNull)
50+
case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
51+
GetArrayItem(child, extraction)
5252
case (_: MapType, _) =>
53-
MapOrdinalGetField(child, fieldExpr)
53+
GetMapValue(child, extraction)
5454
case (otherType, _) =>
5555
val errorMsg = otherType match {
5656
case StructType(_) | ArrayType(StructType(_), _) =>
57-
s"Field name should be String Literal, but it's $fieldExpr"
57+
s"Field name should be String Literal, but it's $extraction"
5858
case _: ArrayType =>
59-
s"Array index should be integral type, but it's ${fieldExpr.dataType}"
59+
s"Array index should be integral type, but it's ${extraction.dataType}"
6060
case other =>
61-
s"Can't get field on $child"
61+
s"Can't extract value from $child"
6262
}
6363
throw new AnalysisException(errorMsg)
6464
}
6565
}
6666

67-
def unapply(g: GetField): Option[(Expression, Expression)] = {
67+
def unapply(g: ExtractValue): Option[(Expression, Expression)] = {
6868
g match {
69-
case o: OrdinalGetField => Some((o.child, o.ordinal))
69+
case o: ExtractValueWithOrdinal => Some((o.child, o.ordinal))
7070
case _ => Some((g.child, null))
7171
}
7272
}
@@ -90,7 +90,7 @@ object GetField {
9090
}
9191
}
9292

93-
trait GetField extends UnaryExpression {
93+
trait ExtractValue extends UnaryExpression {
9494
self: Product =>
9595

9696
type EvaluatedType = Any
@@ -99,8 +99,8 @@ trait GetField extends UnaryExpression {
9999
/**
100100
* Returns the value of fields in the Struct `child`.
101101
*/
102-
case class SimpleStructGetField(child: Expression, field: StructField, ordinal: Int)
103-
extends GetField {
102+
case class GetStructField(child: Expression, field: StructField, ordinal: Int)
103+
extends ExtractValue {
104104

105105
override def dataType: DataType = field.dataType
106106
override def nullable: Boolean = child.nullable || field.nullable
@@ -116,11 +116,11 @@ case class SimpleStructGetField(child: Expression, field: StructField, ordinal:
116116
/**
117117
* Returns an array of the values of the given field across the Array of Struct `child`.
118118
*/
119-
case class ArrayStructGetField(
119+
case class GetArrayStructFields(
120120
child: Expression,
121121
field: StructField,
122122
ordinal: Int,
123-
containsNull: Boolean) extends GetField {
123+
containsNull: Boolean) extends ExtractValue {
124124

125125
override def dataType: DataType = ArrayType(field.dataType, containsNull)
126126
override def nullable: Boolean = child.nullable
@@ -137,7 +137,7 @@ case class ArrayStructGetField(
137137
}
138138
}
139139

140-
abstract class OrdinalGetField extends GetField {
140+
abstract class ExtractValueWithOrdinal extends ExtractValue {
141141
self: Product =>
142142

143143
def ordinal: Expression
@@ -168,8 +168,8 @@ abstract class OrdinalGetField extends GetField {
168168
/**
169169
* Returns the element at `ordinal` in the Array `child`
170170
*/
171-
case class ArrayOrdinalGetField(child: Expression, ordinal: Expression)
172-
extends OrdinalGetField {
171+
case class GetArrayItem(child: Expression, ordinal: Expression)
172+
extends ExtractValueWithOrdinal {
173173

174174
override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
175175

@@ -192,8 +192,8 @@ case class ArrayOrdinalGetField(child: Expression, ordinal: Expression)
192192
/**
193193
* Returns the value of key `ordinal` in Map `child`
194194
*/
195-
case class MapOrdinalGetField(child: Expression, ordinal: Expression)
196-
extends OrdinalGetField {
195+
case class GetMapValue(child: Expression, ordinal: Expression)
196+
extends ExtractValueWithOrdinal {
197197

198198
override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
199199

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,8 @@ object NullPropagation extends Rule[LogicalPlan] {
227227
case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType)
228228
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
229229
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
230-
case e @ GetField(Literal(null, _), _) => Literal.create(null, e.dataType)
231-
case e @ GetField(_, Literal(null, _)) => Literal.create(null, e.dataType)
230+
case e @ ExtractValue(Literal(null, _), _) => Literal.create(null, e.dataType)
231+
case e @ ExtractValue(_, Literal(null, _)) => Literal.create(null, e.dataType)
232232
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
233233
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
234234
case e @ Count(expr) if !expr.nullable => Count(Literal(1))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ object PartialAggregation {
160160
// resolving struct field accesses, because `ExtractValue` is not a `NamedExpression`.
161161
// (Should we just turn `ExtractValue` into a `NamedExpression`?)
162162
namedGroupingExpressions
163-
.get(e.transform { case Alias(g: GetField, _) => g })
163+
.get(e.transform { case Alias(g: ExtractValue, _) => g })
164164
.map(_.toAttribute)
165165
.getOrElse(e)
166166
}).asInstanceOf[Seq[NamedExpression]]

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
210210
// Then this will add ExtractValue(ExtractValue(a, "b"), "c"), and alias
211211
// the final expression as "c".
212212
val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
213-
GetField(expr, Literal(fieldName), resolver))
213+
ExtractValue(expr, Literal(fieldName), resolver))
214214
val aliasName = nestedFields.last
215215
Some(Alias(fieldExprs, aliasName)())
216216
} catch {

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ abstract class DataType {
4343
/**
4444
* Enables matching against DataType for expressions:
4545
* {{{
46-
* case Cast(child @ DataType(), StringType) =>
46+
* case Cast(child @ BinaryType(), StringType) =>
4747
* ...
4848
* }}}
4949
*/

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.scalatest.FunSuite
2626
import org.scalatest.Matchers._
2727

2828
import org.apache.spark.sql.catalyst.CatalystTypeConverters
29-
import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
29+
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
3030
import org.apache.spark.sql.catalyst.dsl.expressions._
3131
import org.apache.spark.sql.catalyst.expressions.mathfuncs._
3232
import org.apache.spark.sql.types._
@@ -891,57 +891,55 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
891891
val typeMap = MapType(StringType, StringType)
892892
val typeArray = ArrayType(StringType)
893893

894-
checkEvaluation(MapOrdinalGetField(BoundReference(3, typeMap, true),
894+
checkEvaluation(GetMapValue(BoundReference(3, typeMap, true),
895895
Literal("aa")), "bb", row)
896-
checkEvaluation(MapOrdinalGetField(Literal.create(null, typeMap), Literal("aa")), null, row)
896+
checkEvaluation(GetMapValue(Literal.create(null, typeMap), Literal("aa")), null, row)
897897
checkEvaluation(
898-
MapOrdinalGetField(Literal.create(null, typeMap),
899-
Literal.create(null, StringType)), null, row)
900-
checkEvaluation(MapOrdinalGetField(BoundReference(3, typeMap, true),
898+
GetMapValue(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row)
899+
checkEvaluation(GetMapValue(BoundReference(3, typeMap, true),
901900
Literal.create(null, StringType)), null, row)
902901

903-
checkEvaluation(ArrayOrdinalGetField(BoundReference(4, typeArray, true),
902+
checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true),
904903
Literal(1)), "bb", row)
905-
checkEvaluation(ArrayOrdinalGetField(Literal.create(null, typeArray), Literal(1)), null, row)
904+
checkEvaluation(GetArrayItem(Literal.create(null, typeArray), Literal(1)), null, row)
906905
checkEvaluation(
907-
ArrayOrdinalGetField(Literal.create(null, typeArray),
908-
Literal.create(null, IntegerType)), null, row)
909-
checkEvaluation(ArrayOrdinalGetField(BoundReference(4, typeArray, true),
906+
GetArrayItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row)
907+
checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true),
910908
Literal.create(null, IntegerType)), null, row)
911909

912-
def quickBuildGetField(expr: Expression, fieldName: String): GetField = {
910+
def getStructField(expr: Expression, fieldName: String): ExtractValue = {
913911
expr.dataType match {
914912
case StructType(fields) =>
915913
val field = fields.find(_.name == fieldName).get
916-
SimpleStructGetField(expr, field, fields.indexOf(field))
914+
GetStructField(expr, field, fields.indexOf(field))
917915
}
918916
}
919917

920-
def resolveGetField(u: UnresolvedGetField): GetField = {
921-
GetField(u.child, u.fieldExpr, _ == _)
918+
def quickResolve(u: UnresolvedExtractValue): ExtractValue = {
919+
ExtractValue(u.child, u.extraction, _ == _)
922920
}
923921

924-
checkEvaluation(quickBuildGetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
925-
checkEvaluation(quickBuildGetField(Literal.create(null, typeS), "a"), null, row)
922+
checkEvaluation(getStructField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
923+
checkEvaluation(getStructField(Literal.create(null, typeS), "a"), null, row)
926924

927925
val typeS_notNullable = StructType(
928926
StructField("a", StringType, nullable = false)
929927
:: StructField("b", StringType, nullable = false) :: Nil
930928
)
931929

932-
assert(quickBuildGetField(BoundReference(2,typeS, nullable = true), "a").nullable === true)
933-
assert(quickBuildGetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable
930+
assert(getStructField(BoundReference(2,typeS, nullable = true), "a").nullable === true)
931+
assert(getStructField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable
934932
=== false)
935933

936-
assert(quickBuildGetField(Literal.create(null, typeS), "a").nullable === true)
937-
assert(quickBuildGetField(Literal.create(null, typeS_notNullable), "a").nullable === true)
934+
assert(getStructField(Literal.create(null, typeS), "a").nullable === true)
935+
assert(getStructField(Literal.create(null, typeS_notNullable), "a").nullable === true)
938936

939-
checkEvaluation(resolveGetField('c.map(typeMap).at(3).getItem("aa")), "bb", row)
940-
checkEvaluation(resolveGetField('c.array(typeArray.elementType).at(4).getItem(1)), "bb", row)
941-
checkEvaluation(resolveGetField('c.struct(typeS).at(2).getField("a")), "aa", row)
937+
checkEvaluation(quickResolve('c.map(typeMap).at(3).getItem("aa")), "bb", row)
938+
checkEvaluation(quickResolve('c.array(typeArray.elementType).at(4).getItem(1)), "bb", row)
939+
checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row)
942940
}
943941

944-
test("error message of GetField") {
942+
test("error message of ExtractValue") {
945943
val structType = StructType(StructField("a", StringType, true) :: Nil)
946944
val arrayStructType = ArrayType(structType)
947945
val arrayType = ArrayType(StringType)
@@ -952,7 +950,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
952950
fieldDataType: DataType,
953951
errorMesage: String): Unit = {
954952
val e = intercept[org.apache.spark.sql.AnalysisException] {
955-
GetField(
953+
ExtractValue(
956954
Literal.create(null, childDataType),
957955
Literal.create(null, fieldDataType),
958956
_ == _)
@@ -963,7 +961,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
963961
checkErrorMessage(structType, IntegerType, "Field name should be String Literal")
964962
checkErrorMessage(arrayStructType, BooleanType, "Field name should be String Literal")
965963
checkErrorMessage(arrayType, StringType, "Array index should be integral type")
966-
checkErrorMessage(otherType, StringType, "Can't get field on")
964+
checkErrorMessage(otherType, StringType, "Can't extract value from")
967965
}
968966

969967
test("arithmetic") {

0 commit comments

Comments
 (0)