Commit 3fdce81

aokolnychyi authored and gatorsmile committed
[SPARK-16046][DOCS] Aggregations in the Spark SQL programming guide
## What changes were proposed in this pull request?

- A separate subsection for Aggregations under “Getting Started” in the Spark SQL programming guide. It mentions which aggregate functions are predefined and how users can create their own.
- Examples of using the `UserDefinedAggregateFunction` abstract class for untyped aggregations in Java and Scala.
- Examples of using the `Aggregator` abstract class for type-safe aggregations in Java and Scala.
- Python is not covered.
- The PR might not resolve the ticket since I do not know what exactly was planned by the author.

In total, there are four new standalone examples that can be executed via `spark-submit` or `run-example`. The updated Spark SQL programming guide references these examples and does not contain hard-coded snippets.

## How was this patch tested?

The patch was tested locally by building the docs. The examples were run as well.

![image](https://cloud.githubusercontent.com/assets/6235869/21292915/04d9d084-c515-11e6-811a-999d598dffba.png)

Author: aokolnychyi <[email protected]>

Closes #16329 from aokolnychyi/SPARK-16046.
1 parent 40a4cfc commit 3fdce81

6 files changed: +533 -0 lines changed

docs/sql-programming-guide.md

Lines changed: 46 additions & 0 deletions
@@ -382,6 +382,52 @@ For example:
 
 </div>
 
+## Aggregations
+
+The [built-in DataFrames functions](api/scala/index.html#org.apache.spark.sql.functions$) provide common
+aggregations such as `count()`, `countDistinct()`, `avg()`, `max()`, `min()`, etc.
+While those functions are designed for DataFrames, Spark SQL also has type-safe versions for some of them in
+[Scala](api/scala/index.html#org.apache.spark.sql.expressions.scalalang.typed$) and
+[Java](api/java/org/apache/spark/sql/expressions/javalang/typed.html) to work with strongly typed Datasets.
+Moreover, users are not limited to the predefined aggregate functions and can create their own.
+
+### Untyped User-Defined Aggregate Functions
+
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+
+Users have to extend the [UserDefinedAggregateFunction](api/scala/index.html#org.apache.spark.sql.expressions.UserDefinedAggregateFunction)
+abstract class to implement a custom untyped aggregate function. For example, a user-defined average
+can look like:
+
+{% include_example untyped_custom_aggregation scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala%}
+</div>
+
+<div data-lang="java" markdown="1">
+
+{% include_example untyped_custom_aggregation java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java%}
+</div>
+
+</div>
+
+### Type-Safe User-Defined Aggregate Functions
+
+User-defined aggregations for strongly typed Datasets revolve around the [Aggregator](api/scala/index.html#org.apache.spark.sql.expressions.Aggregator) abstract class.
+For example, a type-safe user-defined average can look like:
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+
+{% include_example typed_custom_aggregation scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala%}
+</div>
+
+<div data-lang="java" markdown="1">
+
+{% include_example typed_custom_aggregation java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java%}
+</div>
+
+</div>
 
 # Data Sources
 
Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.examples.sql;
+
+// $example on:typed_custom_aggregation$
+import java.io.Serializable;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.TypedColumn;
+import org.apache.spark.sql.expressions.Aggregator;
+// $example off:typed_custom_aggregation$
+
+public class JavaUserDefinedTypedAggregation {
+
+  // $example on:typed_custom_aggregation$
+  public static class Employee implements Serializable {
+    private String name;
+    private long salary;
+
+    // Constructors, getters, setters...
+    // $example off:typed_custom_aggregation$
+    public String getName() {
+      return name;
+    }
+
+    public void setName(String name) {
+      this.name = name;
+    }
+
+    public long getSalary() {
+      return salary;
+    }
+
+    public void setSalary(long salary) {
+      this.salary = salary;
+    }
+    // $example on:typed_custom_aggregation$
+  }
+
+  public static class Average implements Serializable {
+    private long sum;
+    private long count;
+
+    // Constructors, getters, setters...
+    // $example off:typed_custom_aggregation$
+    public Average() {
+    }
+
+    public Average(long sum, long count) {
+      this.sum = sum;
+      this.count = count;
+    }
+
+    public long getSum() {
+      return sum;
+    }
+
+    public void setSum(long sum) {
+      this.sum = sum;
+    }
+
+    public long getCount() {
+      return count;
+    }
+
+    public void setCount(long count) {
+      this.count = count;
+    }
+    // $example on:typed_custom_aggregation$
+  }
+
+  public static class MyAverage extends Aggregator<Employee, Average, Double> {
+    // A zero value for this aggregation. Should satisfy the property that any b + zero = b
+    public Average zero() {
+      return new Average(0L, 0L);
+    }
+    // Combine two values to produce a new value. For performance, the function may modify `buffer`
+    // and return it instead of constructing a new object
+    public Average reduce(Average buffer, Employee employee) {
+      long newSum = buffer.getSum() + employee.getSalary();
+      long newCount = buffer.getCount() + 1;
+      buffer.setSum(newSum);
+      buffer.setCount(newCount);
+      return buffer;
+    }
+    // Merge two intermediate values
+    public Average merge(Average b1, Average b2) {
+      long mergedSum = b1.getSum() + b2.getSum();
+      long mergedCount = b1.getCount() + b2.getCount();
+      b1.setSum(mergedSum);
+      b1.setCount(mergedCount);
+      return b1;
+    }
+    // Transform the output of the reduction
+    public Double finish(Average reduction) {
+      return ((double) reduction.getSum()) / reduction.getCount();
+    }
+    // Specifies the Encoder for the intermediate value type
+    public Encoder<Average> bufferEncoder() {
+      return Encoders.bean(Average.class);
+    }
+    // Specifies the Encoder for the final output value type
+    public Encoder<Double> outputEncoder() {
+      return Encoders.DOUBLE();
+    }
+  }
+  // $example off:typed_custom_aggregation$
+
+  public static void main(String[] args) {
+    SparkSession spark = SparkSession
+      .builder()
+      .appName("Java Spark SQL user-defined Datasets aggregation example")
+      .getOrCreate();
+
+    // $example on:typed_custom_aggregation$
+    Encoder<Employee> employeeEncoder = Encoders.bean(Employee.class);
+    String path = "examples/src/main/resources/employees.json";
+    Dataset<Employee> ds = spark.read().json(path).as(employeeEncoder);
+    ds.show();
+    // +-------+------+
+    // |   name|salary|
+    // +-------+------+
+    // |Michael|  3000|
+    // |   Andy|  4500|
+    // | Justin|  3500|
+    // |  Berta|  4000|
+    // +-------+------+
+
+    MyAverage myAverage = new MyAverage();
+    // Convert the function to a `TypedColumn` and give it a name
+    TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
+    Dataset<Double> result = ds.select(averageSalary);
+    result.show();
+    // +--------------+
+    // |average_salary|
+    // +--------------+
+    // |        3750.0|
+    // +--------------+
+    // $example off:typed_custom_aggregation$
+    spark.stop();
+  }
+
+}
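
The `zero`/`reduce`/`merge`/`finish` contract used by `MyAverage` above can be tried without a Spark cluster. The following is a minimal plain-Java sketch, not Spark code: the class `AggregatorContractSketch`, its simplified `Average` buffer, and the helper `averageOverPartitions` are hypothetical names introduced only for illustration, and splitting the four salaries into two "partitions" is an assumption about how the work might be divided. It shows that reducing each partition separately and then merging the partial buffers yields the same average as the example's expected output.

```java
import java.util.Arrays;
import java.util.List;

/** Plain-Java sketch of the zero/reduce/merge/finish contract; no Spark involved. */
final class AggregatorContractSketch {

  /** Intermediate buffer: running sum and count, mirroring the example's Average bean. */
  static final class Average {
    long sum;
    long count;

    Average(long sum, long count) {
      this.sum = sum;
      this.count = count;
    }
  }

  // Identity buffer: merging it into any buffer b leaves b unchanged.
  static Average zero() {
    return new Average(0L, 0L);
  }

  // Fold one input value into the buffer, mutating and returning it.
  static Average reduce(Average buffer, long salary) {
    buffer.sum += salary;
    buffer.count += 1;
    return buffer;
  }

  // Combine two partial buffers, for instance from different partitions.
  static Average merge(Average b1, Average b2) {
    b1.sum += b2.sum;
    b1.count += b2.count;
    return b1;
  }

  // Turn the final buffer into the result.
  static double finish(Average reduction) {
    return ((double) reduction.sum) / reduction.count;
  }

  // Reduce each partition separately, then merge the partial buffers.
  static double averageOverPartitions(List<long[]> partitions) {
    Average total = zero();
    for (long[] partition : partitions) {
      Average partial = zero();
      for (long salary : partition) {
        partial = reduce(partial, salary);
      }
      total = merge(total, partial);
    }
    return finish(total);
  }

  public static void main(String[] args) {
    // Salaries from employees.json, split into two hypothetical partitions.
    List<long[]> partitions =
        Arrays.asList(new long[] {3000L, 4500L}, new long[] {3500L, 4000L});
    System.out.println(averageOverPartitions(partitions)); // prints 3750.0
  }
}
```

Because `merge` only adds sums and counts, the result does not depend on how the inputs are grouped, which is the property the `zero` comment in the example alludes to.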
Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.examples.sql;
+
+// $example on:untyped_custom_aggregation$
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.expressions.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+// $example off:untyped_custom_aggregation$
+
+public class JavaUserDefinedUntypedAggregation {
+
+  // $example on:untyped_custom_aggregation$
+  public static class MyAverage extends UserDefinedAggregateFunction {
+
+    private StructType inputSchema;
+    private StructType bufferSchema;
+
+    public MyAverage() {
+      List<StructField> inputFields = new ArrayList<>();
+      inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
+      inputSchema = DataTypes.createStructType(inputFields);
+
+      List<StructField> bufferFields = new ArrayList<>();
+      bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
+      bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
+      bufferSchema = DataTypes.createStructType(bufferFields);
+    }
+    // Data types of input arguments of this aggregate function
+    public StructType inputSchema() {
+      return inputSchema;
+    }
+    // Data types of values in the aggregation buffer
+    public StructType bufferSchema() {
+      return bufferSchema;
+    }
+    // The data type of the returned value
+    public DataType dataType() {
+      return DataTypes.DoubleType;
+    }
+    // Whether this function always returns the same output on the identical input
+    public boolean deterministic() {
+      return true;
+    }
+    // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
+    // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
+    // the opportunity to update its values. Note that arrays and maps inside the buffer are still
+    // immutable.
+    public void initialize(MutableAggregationBuffer buffer) {
+      buffer.update(0, 0L);
+      buffer.update(1, 0L);
+    }
+    // Updates the given aggregation buffer `buffer` with new input data from `input`
+    public void update(MutableAggregationBuffer buffer, Row input) {
+      if (!input.isNullAt(0)) {
+        long updatedSum = buffer.getLong(0) + input.getLong(0);
+        long updatedCount = buffer.getLong(1) + 1;
+        buffer.update(0, updatedSum);
+        buffer.update(1, updatedCount);
+      }
+    }
+    // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
+    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+      long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
+      long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
+      buffer1.update(0, mergedSum);
+      buffer1.update(1, mergedCount);
+    }
+    // Calculates the final result
+    public Double evaluate(Row buffer) {
+      return ((double) buffer.getLong(0)) / buffer.getLong(1);
+    }
+  }
+  // $example off:untyped_custom_aggregation$
+
+  public static void main(String[] args) {
+    SparkSession spark = SparkSession
+      .builder()
+      .appName("Java Spark SQL user-defined DataFrames aggregation example")
+      .getOrCreate();
+
+    // $example on:untyped_custom_aggregation$
+    // Register the function to access it
+    spark.udf().register("myAverage", new MyAverage());
+
+    Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json");
+    df.createOrReplaceTempView("employees");
+    df.show();
+    // +-------+------+
+    // |   name|salary|
+    // +-------+------+
+    // |Michael|  3000|
+    // |   Andy|  4500|
+    // | Justin|  3500|
+    // |  Berta|  4000|
+    // +-------+------+
+
+    Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
+    result.show();
+    // +--------------+
+    // |average_salary|
+    // +--------------+
+    // |        3750.0|
+    // +--------------+
+    // $example off:untyped_custom_aggregation$
+
+    spark.stop();
+  }
+}
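
The index-based buffer protocol of the untyped `MyAverage` above (`initialize`, `update`, `merge`, `evaluate`) can likewise be sketched in plain Java. This is an illustration under stated assumptions, not Spark code: the class `UntypedBufferSketch` and its helper `average` are hypothetical, a `long[]` stands in for `MutableAggregationBuffer`, and a nullable `Long` stands in for an input `Row`. It highlights two consequences of the logic shown in the example: null inputs are skipped entirely, and because `evaluate` uses floating-point division, a buffer whose count is still zero produces `NaN` instead of throwing.

```java
/** Plain-Java sketch of the index-based buffer protocol; a long[] stands in for the buffer. */
final class UntypedBufferSketch {

  // Slot 0 holds the running sum, slot 1 the running count.
  static void initialize(long[] buffer) {
    buffer[0] = 0L;
    buffer[1] = 0L;
  }

  // Null inputs are skipped, so they affect neither the sum nor the count.
  static void update(long[] buffer, Long input) {
    if (input != null) {
      buffer[0] += input;
      buffer[1] += 1;
    }
  }

  // Fold the second buffer into the first.
  static void merge(long[] buffer1, long[] buffer2) {
    buffer1[0] += buffer2[0];
    buffer1[1] += buffer2[1];
  }

  // Same arithmetic as the example: floating-point division, so a zero count gives NaN.
  static double evaluate(long[] buffer) {
    return ((double) buffer[0]) / buffer[1];
  }

  // Hypothetical driver: run the whole protocol over a sequence of nullable inputs.
  static double average(Long... inputs) {
    long[] buffer = new long[2];
    initialize(buffer);
    for (Long input : inputs) {
      update(buffer, input);
    }
    return evaluate(buffer);
  }

  public static void main(String[] args) {
    System.out.println(average(3000L, null, 4500L)); // prints 3750.0
    System.out.println(average((Long) null));        // prints NaN
  }
}
```

Whether `NaN` is the desired result for an empty or all-null group is a design decision the commit does not discuss; a variant could check the count before dividing.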
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+{"name":"Michael", "salary":3000}
+{"name":"Andy", "salary":4500}
+{"name":"Justin", "salary":3500}
+{"name":"Berta", "salary":4000}
