132 changes: 75 additions & 57 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -241,18 +241,6 @@ class DataFrame private[sql](
sb.toString()
}

private[sql] def sortInternal(global: Boolean, sortExprs: Seq[Column]): DataFrame = {
val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
col.expr match {
case expr: SortOrder =>
expr
case expr: Expression =>
SortOrder(expr, Ascending)
}
}
Sort(sortOrder, global = global, logicalPlan)
}

override def toString: String = {
try {
schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]")
@@ -619,6 +607,32 @@ class DataFrame private[sql](
plan.copy(condition = cond)
}

/**
* Returns a new [[DataFrame]] with each partition sorted by the given expressions.
*
* This is the same operation as "SORT BY" in SQL (Hive QL).
*
* @group dfops
* @since 1.6.0
*/
@scala.annotation.varargs
def sortWithinPartitions(sortCol: String, sortCols: String*): DataFrame = {
sortWithinPartitions(sortCol, sortCols : _*)
Contributor (inline review comment):

@rxin This causes an infinite loop, which isn't caught by the unit tests since DataFrameSuite only tests the Column* overload.

Contributor (inline review comment):

@ankurdave can you create a jira?

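For reference, one possible shape of the fix (a sketch only, not part of this diff): convert the names to Columns and delegate to the Column* overload instead of recursing into the same String overload.

// Hypothetical follow-up fix: build Column objects from the names and call
// the Column* overload, avoiding the self-recursive call above.
def sortWithinPartitions(sortCol: String, sortCols: String*): DataFrame = {
  sortWithinPartitions((sortCol +: sortCols).map(Column(_)) : _*)
}
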
}

/**
* Returns a new [[DataFrame]] with each partition sorted by the given expressions.
*
* This is the same operation as "SORT BY" in SQL (Hive QL).
*
* @group dfops
* @since 1.6.0
*/
@scala.annotation.varargs
def sortWithinPartitions(sortExprs: Column*): DataFrame = {
sortInternal(global = false, sortExprs)
}
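
As a usage illustration (the DataFrame `df`, table `t`, and column `key` are hypothetical, and `$` assumes `import sqlContext.implicits._`):

// Sort rows within each partition by "key"; the ordering is per-partition, not global.
val locallySorted = df.sortWithinPartitions($"key".asc)
// Roughly the same operation as Hive QL: SELECT * FROM t SORT BY key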

/**
* Returns a new [[DataFrame]] sorted by the specified column, all in ascending order.
* {{{
@@ -645,7 +659,7 @@
*/
@scala.annotation.varargs
def sort(sortExprs: Column*): DataFrame = {
sortInternal(true, sortExprs)
sortInternal(global = true, sortExprs)
}

/**
@@ -666,44 +680,6 @@
@scala.annotation.varargs
def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs : _*)

/**
* Returns a new [[DataFrame]] partitioned by the given partitioning expressions into
* `numPartitions`. The resulting DataFrame is hash partitioned.
* @group dfops
* @since 1.6.0
*/
def distributeBy(partitionExprs: Seq[Column], numPartitions: Int): DataFrame = {
RepartitionByExpression(partitionExprs.map { _.expr }, logicalPlan, Some(numPartitions))
}

/**
* Returns a new [[DataFrame]] partitioned by the given partitioning expressions preserving
* the existing number of partitions. The resulting DataFrame is hash partitioned.
* @group dfops
* @since 1.6.0
*/
def distributeBy(partitionExprs: Seq[Column]): DataFrame = {
RepartitionByExpression(partitionExprs.map { _.expr }, logicalPlan, None)
}

/**
* Returns a new [[DataFrame]] with each partition sorted by the given expressions.
* @group dfops
* @since 1.6.0
*/
@scala.annotation.varargs
def localSort(sortCol: String, sortCols: String*): DataFrame = localSort(sortCol, sortCols : _*)

/**
* Returns a new [[DataFrame]] with each partition sorted by the given expressions.
* @group dfops
* @since 1.6.0
*/
@scala.annotation.varargs
def localSort(sortExprs: Column*): DataFrame = {
sortInternal(false, sortExprs)
}

/**
* Selects column based on the column name and return it as a [[Column]].
* Note that the column name can also reference to a nested column like `a.b`.
@@ -798,7 +774,9 @@ class DataFrame private[sql](
* SQL expressions.
*
* {{{
* // The following are equivalent:
* df.selectExpr("colA", "colB as newName", "abs(colC)")
* df.select(expr("colA"), expr("colB as newName"), expr("abs(colC)"))
* }}}
* @group dfops
* @since 1.3.0
@@ -1524,13 +1502,41 @@ class DataFrame private[sql](

/**
* Returns a new [[DataFrame]] that has exactly `numPartitions` partitions.
* @group rdd
* @group dfops
* @since 1.3.0
*/
def repartition(numPartitions: Int): DataFrame = {
Repartition(numPartitions, shuffle = true, logicalPlan)
}

/**
* Returns a new [[DataFrame]] partitioned by the given partitioning expressions into
* `numPartitions`. The resulting DataFrame is hash partitioned.
*
* This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
*
* @group dfops
* @since 1.6.0
*/
@scala.annotation.varargs
def repartition(numPartitions: Int, partitionExprs: Column*): DataFrame = {
RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, Some(numPartitions))
}

/**
* Returns a new [[DataFrame]] partitioned by the given partitioning expressions preserving
* the existing number of partitions. The resulting DataFrame is hash partitioned.
*
* This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
*
* @group dfops
* @since 1.6.0
*/
@scala.annotation.varargs
def repartition(partitionExprs: Column*): DataFrame = {
RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions = None)
}
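
For illustration, a sketch of both overloads (again with a hypothetical `df`, table `t`, and column `key`, assuming `import sqlContext.implicits._`):

// Hash-partition by "key" into 10 partitions.
val byKeyTen = df.repartition(10, $"key")
// Hash-partition by "key", preserving the current number of partitions.
val byKey = df.repartition($"key")
// Roughly the same operation as Hive QL: SELECT * FROM t DISTRIBUTE BY key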

/**
* Returns a new [[DataFrame]] that has exactly `numPartitions` partitions.
* Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g.
@@ -2016,6 +2022,12 @@ class DataFrame private[sql](
write.mode(SaveMode.Append).insertInto(tableName)
}

////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////
// End of deprecated methods
////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////

/**
* Wrap a DataFrame action to track all Spark jobs in the body so that we can connect them with
* an execution.
@@ -2045,10 +2057,16 @@ class DataFrame private[sql](
}
}

////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////
// End of deprecated methods
////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////
private def sortInternal(global: Boolean, sortExprs: Seq[Column]): DataFrame = {
val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
col.expr match {
case expr: SortOrder =>
expr
case expr: Expression =>
SortOrder(expr, Ascending)
}
}
Sort(sortOrder, global = global, logicalPlan)
}

}
sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -379,8 +379,8 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
// Set up two tables distributed in the same way. Try this with the data distributed into
// different numbers of partitions.
for (numPartitions <- 1 until 10 by 4) {
testData.distributeBy(Column("key") :: Nil, numPartitions).registerTempTable("t1")
testData2.distributeBy(Column("a") :: Nil, numPartitions).registerTempTable("t2")
testData.repartition(numPartitions, $"key").registerTempTable("t1")
testData2.repartition(numPartitions, $"a").registerTempTable("t2")
sqlContext.cacheTable("t1")
sqlContext.cacheTable("t2")

@@ -401,8 +401,20 @@
}

// Distribute the tables into non-matching number of partitions. Need to shuffle.
testData.distributeBy(Column("key") :: Nil, 6).registerTempTable("t1")
testData2.distributeBy(Column("a") :: Nil, 3).registerTempTable("t2")
testData.repartition(6, $"key").registerTempTable("t1")
testData2.repartition(3, $"a").registerTempTable("t2")
sqlContext.cacheTable("t1")
sqlContext.cacheTable("t2")

verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 2)
sqlContext.uncacheTable("t1")
sqlContext.uncacheTable("t2")
sqlContext.dropTempTable("t1")
sqlContext.dropTempTable("t2")

// One side of join is not partitioned in the desired way. Need to shuffle.
testData.repartition(6, $"value").registerTempTable("t1")
testData2.repartition(6, $"a").registerTempTable("t2")
sqlContext.cacheTable("t1")
sqlContext.cacheTable("t2")

44 changes: 22 additions & 22 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1044,79 +1044,79 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("distributeBy and localSort") {
val original = testData.repartition(1)
assert(original.rdd.partitions.length == 1)
val df = original.distributeBy(Column("key") :: Nil, 5)
assert(df.rdd.partitions.length == 5)
val df = original.repartition(5, $"key")
assert(df.rdd.partitions.length == 5)
checkAnswer(original.select(), df.select())

val df2 = original.distributeBy(Column("key") :: Nil, 10)
assert(df2.rdd.partitions.length == 10)
val df2 = original.repartition(10, $"key")
assert(df2.rdd.partitions.length == 10)
checkAnswer(original.select(), df2.select())

// Group by the column we are distributed by. This should generate a plan with no exchange
// between the aggregates
val df3 = testData.distributeBy(Column("key") :: Nil).groupBy("key").count()
val df3 = testData.repartition($"key").groupBy("key").count()
verifyNonExchangingAgg(df3)
verifyNonExchangingAgg(testData.distributeBy(Column("key") :: Column("value") :: Nil)
verifyNonExchangingAgg(testData.repartition($"key", $"value")
.groupBy("key", "value").count())

// Grouping by just the first distributeBy expr, need to exchange.
verifyExchangingAgg(testData.distributeBy(Column("key") :: Column("value") :: Nil)
verifyExchangingAgg(testData.repartition($"key", $"value")
.groupBy("key").count())

val data = sqlContext.sparkContext.parallelize(
(1 to 100).map(i => TestData2(i % 10, i))).toDF()

// Distribute and order by.
val df4 = data.distributeBy(Column("a") :: Nil).localSort($"b".desc)
val df4 = data.repartition($"a").sortWithinPartitions($"b".desc)
// Walk each partition and verify that it is sorted descending and does not contain all
// the values.
df4.rdd.foreachPartition(p => {
df4.rdd.foreachPartition { p =>
var previousValue: Int = -1
var allSequential: Boolean = true
p.foreach(r => {
p.foreach { r =>
val v: Int = r.getInt(1)
if (previousValue != -1) {
if (previousValue < v) throw new SparkException("Partition is not ordered.")
if (v + 1 != previousValue) allSequential = false
}
previousValue = v
})
}
if (allSequential) throw new SparkException("Partition should not be globally ordered")
})
}

// Distribute and order by with multiple order bys
val df5 = data.distributeBy(Column("a") :: Nil, 2).localSort($"b".asc, $"a".asc)
val df5 = data.repartition(2, $"a").sortWithinPartitions($"b".asc, $"a".asc)
// Walk each partition and verify that it is sorted ascending
df5.rdd.foreachPartition(p => {
df5.rdd.foreachPartition { p =>
var previousValue: Int = -1
var allSequential: Boolean = true
p.foreach(r => {
p.foreach { r =>
val v: Int = r.getInt(1)
if (previousValue != -1) {
if (previousValue > v) throw new SparkException("Partition is not ordered.")
if (v - 1 != previousValue) allSequential = false
}
previousValue = v
})
}
if (allSequential) throw new SparkException("Partition should not be all sequential")
})
}

// Distribute into one partition and order by. This partition should contain all the values.
val df6 = data.distributeBy(Column("a") :: Nil, 1).localSort($"b".asc)
val df6 = data.repartition(1, $"a").sortWithinPartitions($"b".asc)
// Walk each partition and verify that it is sorted ascending and contains all the values.
df6.rdd.foreachPartition(p => {
df6.rdd.foreachPartition { p =>
var previousValue: Int = -1
var allSequential: Boolean = true
p.foreach(r => {
p.foreach { r =>
val v: Int = r.getInt(1)
if (previousValue != -1) {
if (previousValue > v) throw new SparkException("Partition is not ordered.")
if (v - 1 != previousValue) allSequential = false
}
previousValue = v
})
}
if (!allSequential) throw new SparkException("Partition should contain all sequential values")
})
}
}

test("fix case sensitivity of partition by") {