sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -482,13 +482,13 @@ case class MapPartitions[T, U](
 }
 
 /** Factory for constructing new `AppendColumn` nodes. */
-object AppendColumn {
+object AppendColumns {
   def apply[T, U : Encoder](
       func: T => U,
       tEncoder: ExpressionEncoder[T],
-      child: LogicalPlan): AppendColumn[T, U] = {
+      child: LogicalPlan): AppendColumns[T, U] = {
     val attrs = encoderFor[U].schema.toAttributes
-    new AppendColumn[T, U](func, tEncoder, encoderFor[U], attrs, child)
+    new AppendColumns[T, U](func, tEncoder, encoderFor[U], attrs, child)
   }
 }
 
@@ -497,7 +497,7 @@ object AppendColumn {
  * resulting columns at the end of the input row. tEncoder/uEncoder are used respectively to
  * decode/encode from the JVM object representation expected by `func.`
  */
-case class AppendColumn[T, U](
+case class AppendColumns[T, U](
     func: T => U,
     tEncoder: ExpressionEncoder[T],
     uEncoder: ExpressionEncoder[U],
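
For orientation: `AppendColumns` (renamed from `AppendColumn` above) is the logical node that `Dataset.groupBy` plans in order to evaluate the key function and append the computed key columns to each input row. A minimal sketch of the user-facing path that exercises it, assuming a Spark 1.6-era `SQLContext` in scope as `sqlContext`:

    import sqlContext.implicits._

    // groupBy(func) appends the computed key via an AppendColumns node,
    // then groups on the appended attributes.
    val ds = Seq("a", "foo", "bar").toDS()
    val byLength = ds.groupBy(_.length)
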
3 changes: 2 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -58,10 +58,11 @@ class TypedColumn[-T, U](
   private[sql] def withInputType(
       inputEncoder: ExpressionEncoder[_],
       schema: Seq[Attribute]): TypedColumn[T, U] = {
+    val boundEncoder = inputEncoder.bind(schema).asInstanceOf[ExpressionEncoder[Any]]
     new TypedColumn[T, U] (expr transform {
       case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
         ta.copy(
-          aEncoder = Some(inputEncoder.asInstanceOf[ExpressionEncoder[Any]]),
+          aEncoder = Some(boundEncoder),
           children = schema)
     }, encoder)
   }
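
The substance of this change: `withInputType` now binds the input encoder to the concrete input attributes before handing it to `TypedAggregateExpression`, so a typed aggregator resolves its input fields against the actual child schema rather than assuming the declared field order. A sketch of the kind of aggregator that relies on this, modeled loosely on the `ClassInputAgg` used in the test suite below (`AggData` and `SumOfA` are illustrative names, not the suite's exact definitions):

    import org.apache.spark.sql.expressions.Aggregator

    case class AggData(a: Int, b: String)

    // Sums the `a` field of each input object. Resolving `a` against the
    // child plan's output, rather than by declared position, is what the
    // bound encoder provides.
    object SumOfA extends Aggregator[AggData, Int, Int] {
      def zero: Int = 0
      def reduce(sum: Int, in: AggData): Int = sum + in.a
      def merge(s1: Int, s2: Int): Int = s1 + s2
      def finish(sum: Int): Int = sum
    }
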
10 changes: 3 additions & 7 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -299,7 +299,7 @@ class Dataset[T] private[sql](
    */
  def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
     val inputPlan = queryExecution.analyzed
-    val withGroupingKey = AppendColumn(func, resolvedTEncoder, inputPlan)
+    val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan)
     val executed = sqlContext.executePlan(withGroupingKey)
 
     new GroupedDataset(
@@ -364,13 +364,11 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
-    // We use an unbound encoder since the expression will make up its own schema.
-    // TODO: This probably doesn't work if we are relying on reordering of the input class fields.
     new Dataset[U1](
       sqlContext,
       Project(
         c1.withInputType(
-          resolvedTEncoder.bind(queryExecution.analyzed.output),
+          resolvedTEncoder,
           queryExecution.analyzed.output).named :: Nil,
         logicalPlan))
   }
@@ -382,10 +380,8 @@ class Dataset[T] private[sql](
    */
   protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
-    // We use an unbound encoder since the expression will make up its own schema.
-    // TODO: This probably doesn't work if we are relying on reordering of the input class fields.
     val namedColumns =
-      columns.map(_.withInputType(unresolvedTEncoder, queryExecution.analyzed.output).named)
+      columns.map(_.withInputType(resolvedTEncoder, queryExecution.analyzed.output).named)
     val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
 
     new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
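
Net effect of the two hunks above: both `select` paths now hand `withInputType` the resolved but unbound encoder and let it do the binding itself, and `selectUntyped` no longer passes the `unresolvedTEncoder` at all. A usage sketch, reusing the illustrative `AggData` and `SumOfA` from above together with the assumed `sqlContext`:

    import org.apache.spark.sql.functions.expr
    import sqlContext.implicits._

    val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
    // An untyped SQL expression column next to a typed aggregator column;
    // both are resolved and bound against ds's analyzed output.
    val res = ds.select(expr("avg(a)").as[Double], SumOfA.toColumn)  // Dataset[(Double, Int)]
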
sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -89,7 +89,7 @@ class GroupedDataset[K, T] private[sql](
   }
 
   /**
-   * Applies the given function to each group of data. For each unique group, the function will
+   * Applies the given function to each group of data.  For each unique group, the function will
    * be passed the group key and an iterator that contains all of the elements in the group. The
    * function can return an iterator containing elements of an arbitrary type which will be returned
    * as a new [[Dataset]].
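
In code, the contract this scaladoc describes looks roughly as follows; a sketch assuming the `flatMapGroups` spelling of the 1.6 API, since the method signature itself lies outside the hunk:

    import sqlContext.implicits._

    val words = Seq("a", "foo", "bar").toDS()
    // One output element per group: the key plus the group's members.
    val summary = words.groupBy(_.length).flatMapGroups { (len, iter) =>
      Iterator(len + ": " + iter.mkString(", "))
    }
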
@@ -162,7 +162,7 @@ class GroupedDataset[K, T] private[sql](
     val encoders = columns.map(_.encoder)
     val namedColumns =
       columns.map(
-        _.withInputType(resolvedTEncoder.bind(dataAttributes), dataAttributes).named)
+        _.withInputType(resolvedTEncoder, dataAttributes).named)
     val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan)
     val execution = new QueryExecution(sqlContext, aggregate)
 
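
The same binding fix applied to the grouped-aggregation path; note the encoder is bound against `dataAttributes`, which exclude the grouping columns. A sketch, again reusing the illustrative `SumOfA`:

    import sqlContext.implicits._

    val ds = Seq(AggData(1, "one"), AggData(2, "one"), AggData(3, "two")).toDS()
    // The typed aggregator runs per group; its input encoder is bound to
    // the data attributes, not the grouping attributes.
    val perKey = ds.groupBy(_.b).agg(SumOfA.toColumn)  // Dataset[(String, Int)]
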
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -321,7 +321,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
       case logical.MapPartitions(f, tEnc, uEnc, output, child) =>
         execution.MapPartitions(f, tEnc, uEnc, output, planLater(child)) :: Nil
-      case logical.AppendColumn(f, tEnc, uEnc, newCol, child) =>
+      case logical.AppendColumns(f, tEnc, uEnc, newCol, child) =>
         execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil
       case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) =>
         execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil
sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -157,6 +157,7 @@ public Integer call(Integer v1, Integer v2) throws Exception {
     Assert.assertEquals(6, reduced);
   }
 
+  @Test
   public void testGroupBy() {
     List<String> data = Arrays.asList("a", "foo", "bar");
     Dataset<String> ds = context.createDataset(data, Encoders.STRING());
sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -161,6 +161,10 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       ds.select(ClassInputAgg.toColumn),
       1)
 
+    checkAnswer(
+      ds.select(expr("avg(a)").as[Double], ClassInputAgg.toColumn),
+      (1.0, 1))
+
     checkAnswer(
       ds.groupBy(_.b).agg(ClassInputAgg.toColumn),
       ("one", 1))