@@ -23,9 +23,40 @@ import scala.language.implicitConversions
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.analysis.Star
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
 import org.apache.spark.sql.types.NumericType
 
+/**
+ * Companion object for [[GroupedData]].
+ */
+private[sql] object GroupedData {
+  def apply(
+      df: DataFrame,
+      groupingExprs: Seq[Expression],
+      groupType: GroupType): GroupedData = {
+    new GroupedData(df, groupingExprs, groupType: GroupType)
+  }
+
+  /**
+   * The grouping type.
+   */
+  trait GroupType
+
+  /**
+   * Indicates a plain GROUP BY.
+   */
+  object GroupByType extends GroupType
+
+  /**
+   * Indicates a CUBE.
+   */
+  object CubeType extends GroupType
+
+  /**
+   * Indicates a ROLLUP.
+   */
+  object RollupType extends GroupType
+}
 
 /**
  * :: Experimental ::
@@ -34,19 +65,37 @@ import org.apache.spark.sql.types.NumericType
  * @since 1.3.0
  */
 @Experimental
-class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) {
+class GroupedData protected[sql](
+    df: DataFrame,
+    groupingExprs: Seq[Expression],
+    private val groupType: GroupedData.GroupType) {
 
-  private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
-    val namedGroupingExprs = groupingExprs.map {
-      case expr: NamedExpression => expr
-      case expr: Expression => Alias(expr, expr.prettyString)()
+  private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
+    val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
+      val retainedExprs = groupingExprs.map {
+        case expr: NamedExpression => expr
+        case expr: Expression => Alias(expr, expr.prettyString)()
+      }
+      retainedExprs ++ aggExprs
+    } else {
+      aggExprs
+    }
+
+    groupType match {
+      case GroupedData.GroupByType =>
+        DataFrame(
+          df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan))
+      case GroupedData.RollupType =>
+        DataFrame(
+          df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates))
+      case GroupedData.CubeType =>
+        DataFrame(
+          df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates))
     }
-    DataFrame(
-      df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan))
   }
 
   private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression)
-    : Seq[NamedExpression] = {
+    : DataFrame = {
 
     val columnExprs = if (colNames.isEmpty) {
       // No columns specified. Use all numeric columns.
@@ -63,10 +112,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
         namedExpr
       }
     }
-    columnExprs.map { c =>
+    toDF(columnExprs.map { c =>
       val a = f(c)
       Alias(a, a.prettyString)()
-    }
+    })
   }
 
   private[this] def strToExpr(expr: String): (Expression => Expression) = {
@@ -119,10 +168,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
    * @since 1.3.0
    */
   def agg(exprs: Map[String, String]): DataFrame = {
-    exprs.map { case (colName, expr) =>
+    toDF(exprs.map { case (colName, expr) =>
       val a = strToExpr(expr)(df(colName).expr)
       Alias(a, a.prettyString)()
-    }.toSeq
+    }.toSeq)
   }
 
   /**
@@ -175,19 +224,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
    */
   @scala.annotation.varargs
   def agg(expr: Column, exprs: Column*): DataFrame = {
-    val aggExprs = (expr +: exprs).map(_.expr).map {
+    toDF((expr +: exprs).map(_.expr).map {
       case expr: NamedExpression => expr
       case expr: Expression => Alias(expr, expr.prettyString)()
-    }
-    if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
-      val retainedExprs = groupingExprs.map {
-        case expr: NamedExpression => expr
-        case expr: Expression => Alias(expr, expr.prettyString)()
-      }
-      DataFrame(df.sqlContext, Aggregate(groupingExprs, retainedExprs ++ aggExprs, df.logicalPlan))
-    } else {
-      DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
-    }
+    })
   }
 
   /**
@@ -196,7 +236,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
    *
    * @since 1.3.0
    */
-  def count(): DataFrame = Seq(Alias(Count(Literal(1)), "count")())
+  def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)), "count")()))
 
   /**
    * Compute the average value for each numeric columns for each group. This is an alias for `avg`.
@@ -256,5 +296,5 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
   @scala.annotation.varargs
   def sum(colNames: String*): DataFrame = {
     aggregateNumericColumns(colNames:_*)(Sum)
-  }
+  }
 }
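
For reference, below is a minimal sketch of how the new grouping types introduced above could be exercised. It is illustrative only: the `sales` table and its `dept`, `year`, and `amount` columns are hypothetical, and because `GroupedData` and its companion are `private[sql]`, code like this would only compile inside the `org.apache.spark.sql` package; user code would normally reach cube and rollup through `DataFrame` entry points rather than the companion's `apply`.

import org.apache.spark.sql.{DataFrame, GroupedData}

// Hypothetical table; assumes a SQLContext named sqlContext is in scope.
val sales: DataFrame = sqlContext.table("sales")

// Grouping expressions shared by the three examples.
val keys = Seq(sales("dept").expr, sales("year").expr)

// Plain GROUP BY: one output row per (dept, year) pair.
val grouped = GroupedData(sales, keys, GroupedData.GroupByType)
  .agg(Map("amount" -> "sum"))

// ROLLUP additionally produces the hierarchical subtotals (dept, null) and (null, null).
val rolledUp = GroupedData(sales, keys, GroupedData.RollupType)
  .sum("amount")

// CUBE produces every combination of the grouping columns, including (null, year).
val cubed = GroupedData(sales, keys, GroupedData.CubeType)
  .count()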