Commit 28f3d81

Author: Davies Liu (committed)

Merge branch 'master' of github.com:apache/spark into string

Conflicts:
    sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
    sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
    sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
    sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
    sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala

2 parents: 28d6f32 + 5db8912 (commit 28f3d81)

18 files changed (+161, -82 lines)

18 files changed

+161
-82
lines changed

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 1 addition & 3 deletions
@@ -433,6 +433,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
   // Thread Local variable that can be used by users to pass information down the stack
   private val localProperties = new InheritableThreadLocal[Properties] {
     override protected def childValue(parent: Properties): Properties = new Properties(parent)
+    override protected def initialValue(): Properties = new Properties()
   }
 
   /**
@@ -474,9 +475,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
   * Spark fair scheduler pool.
   */
  def setLocalProperty(key: String, value: String) {
-    if (localProperties.get() == null) {
-      localProperties.set(new Properties())
-    }
    if (value == null) {
      localProperties.get.remove(key)
    } else {
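
The fix relies on InheritableThreadLocal.initialValue(): once localProperties.get() can never return null, the explicit null check in setLocalProperty is redundant. A minimal standalone sketch of the same pattern (the names below are illustrative, not Spark's):

import java.util.Properties

object ThreadLocalDefaultSketch {
  // Child threads inherit a Properties chained to the parent's values;
  // initialValue() guarantees get() never returns null, even before any set().
  private val localProps = new InheritableThreadLocal[Properties] {
    override protected def childValue(parent: Properties): Properties = new Properties(parent)
    override protected def initialValue(): Properties = new Properties()
  }

  def setLocalProperty(key: String, value: String): Unit = {
    // No null check needed any more: get() always returns an instance.
    if (value == null) localProps.get.remove(key)
    else localProps.get.setProperty(key, value)
  }

  def main(args: Array[String]): Unit = {
    setLocalProperty("spark.job.description", "demo")
    println(localProps.get.getProperty("spark.job.description"))  // prints "demo"
  }
}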

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 0 additions & 1 deletion
@@ -605,7 +605,6 @@ private[spark] object PythonRDD extends Logging {
    */
   private def serveIterator[T](items: Iterator[T], threadName: String): Int = {
     val serverSocket = new ServerSocket(0, 1)
-    serverSocket.setReuseAddress(true)
     // Close the socket if no connection in 3 seconds
     serverSocket.setSoTimeout(3000)
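
For context, serveIterator binds an ephemeral port (port 0, backlog 1) and gives the Python side three seconds to connect; because the OS picks a fresh port each time, the socket is never re-bound and setReuseAddress(true) adds nothing. A rough standalone sketch of the accept-with-timeout pattern, not the actual Spark code:

import java.net.{ServerSocket, Socket, SocketTimeoutException}

object EphemeralServeSketch {
  def main(args: Array[String]): Unit = {
    val serverSocket = new ServerSocket(0, 1)  // OS-assigned port, backlog of 1
    serverSocket.setSoTimeout(3000)            // fail if no client connects within 3 seconds
    val port = serverSocket.getLocalPort
    println(s"serving on port $port")

    // The client side (in Spark, the Python worker) connects to the advertised port.
    new Thread(new Runnable {
      override def run(): Unit = new Socket("localhost", port).close()
    }).start()

    try {
      val conn = serverSocket.accept()         // throws SocketTimeoutException after 3 s
      println("client connected from " + conn.getInetAddress)
      conn.close()
    } catch {
      case _: SocketTimeoutException => println("no client connected in time")
    } finally {
      serverSocket.close()
    }
  }
}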

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 5 additions & 5 deletions
@@ -493,7 +493,7 @@ class DAGScheduler(
       callSite: CallSite,
       allowLocal: Boolean,
       resultHandler: (Int, U) => Unit,
-      properties: Properties = null): JobWaiter[U] = {
+      properties: Properties): JobWaiter[U] = {
     // Check to make sure we are not launching a task on a partition that does not exist.
     val maxPartitions = rdd.partitions.length
     partitions.find(p => p >= maxPartitions || p < 0).foreach { p =>
@@ -522,7 +522,7 @@ class DAGScheduler(
       callSite: CallSite,
       allowLocal: Boolean,
       resultHandler: (Int, U) => Unit,
-      properties: Properties = null): Unit = {
+      properties: Properties): Unit = {
     val start = System.nanoTime
     val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
     waiter.awaitResult() match {
@@ -542,7 +542,7 @@ class DAGScheduler(
       evaluator: ApproximateEvaluator[U, R],
       callSite: CallSite,
       timeout: Long,
-      properties: Properties = null): PartialResult[R] = {
+      properties: Properties): PartialResult[R] = {
     val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
     val partitions = (0 until rdd.partitions.size).toArray
@@ -689,7 +689,7 @@ class DAGScheduler(
     // Cancel all jobs belonging to this job group.
     // First finds all active jobs with this group id, and then kill stages for them.
     val activeInGroup = activeJobs.filter(activeJob =>
-      groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
+      Option(activeJob.properties).exists(_.get(SparkContext.SPARK_JOB_GROUP_ID) == groupId))
     val jobIds = activeInGroup.map(_.jobId)
     jobIds.foreach(handleJobCancellation(_, "part of cancelled job group %s".format(groupId)))
     submitWaitingStages()
@@ -736,7 +736,7 @@ class DAGScheduler(
       allowLocal: Boolean,
       callSite: CallSite,
       listener: JobListener,
-      properties: Properties = null) {
+      properties: Properties) {
     var finalStage: ResultStage = null
     try {
       // New stage creation may throw an exception if, for example, jobs are run on a
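
The behavioural change is in the job-group filter: activeJob.properties may be null for jobs submitted without local properties, and the old groupId == activeJob.properties.get(...) could throw a NullPointerException inside the scheduler (SPARK-6414). Wrapping the nullable reference in Option makes the comparison null-safe. A small self-contained sketch of the idiom, using a hypothetical job record rather than Spark's ActiveJob:

import java.util.Properties

object NullSafeGroupFilterSketch {
  // Hypothetical stand-in for a scheduler's active job; properties may be null.
  case class JobInfo(jobId: Int, properties: Properties)

  val GroupIdKey = "spark.jobGroup.id"  // plays the role of SparkContext.SPARK_JOB_GROUP_ID

  def jobsInGroup(jobs: Seq[JobInfo], groupId: String): Seq[JobInfo] =
    // Option(null) is None, so jobs without properties are skipped instead of throwing.
    jobs.filter(job => Option(job.properties).exists(_.get(GroupIdKey) == groupId))

  def main(args: Array[String]): Unit = {
    val tagged = new Properties()
    tagged.setProperty(GroupIdKey, "etl")
    val jobs = Seq(JobInfo(1, tagged), JobInfo(2, null))
    println(jobsInGroup(jobs, "etl").map(_.jobId))  // List(1), no NullPointerException for job 2
  }
}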

core/src/test/scala/org/apache/spark/SparkContextSuite.scala

Lines changed: 19 additions & 1 deletion
@@ -18,16 +18,19 @@
 package org.apache.spark
 
 import java.io.File
+import java.util.concurrent.TimeUnit
 
 import com.google.common.base.Charsets._
 import com.google.common.io.Files
 
 import org.scalatest.FunSuite
 
 import org.apache.hadoop.io.BytesWritable
-
 import org.apache.spark.util.Utils
 
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+
 class SparkContextSuite extends FunSuite with LocalSparkContext {
 
   test("Only one SparkContext may be active at a time") {
@@ -173,4 +176,19 @@ class SparkContextSuite extends FunSuite with LocalSparkContext {
       sc.stop()
     }
   }
+
+  test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") {
+    try {
+      sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+      val future = sc.parallelize(Seq(0)).foreachAsync(_ => {Thread.sleep(1000L)})
+      sc.cancelJobGroup("nonExistGroupId")
+      Await.ready(future, Duration(2, TimeUnit.SECONDS))
+
+      // In SPARK-6414, sc.cancelJobGroup will cause NullPointerException and cause
+      // SparkContext to shutdown, so the following assertion will fail.
+      assert(sc.parallelize(1 to 10).count() == 10L)
+    } finally {
+      sc.stop()
+    }
+  }
 }
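
The new test calls cancelJobGroup with a group id that no running job carries, which previously triggered the NullPointerException above. For reference, the normal workflow pairs setJobGroup with cancelJobGroup; a rough usage sketch assuming a local SparkContext:

import org.apache.spark.{SparkConf, SparkContext}

object JobGroupUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("job-group-sketch").setMaster("local"))
    try {
      // Jobs submitted from this thread are tagged with the group id...
      sc.setJobGroup("nightly-etl", "jobs that may be cancelled as a unit")
      println(sc.parallelize(1 to 100).count())
      // ...and can later be cancelled together, typically from another thread.
      sc.cancelJobGroup("nightly-etl")
    } finally {
      sc.stop()
    }
  }
}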

python/pyspark/rdd.py

Lines changed: 1 addition & 0 deletions
@@ -113,6 +113,7 @@ def _parse_memory(s):
 
 def _load_from_socket(port, serializer):
     sock = socket.socket()
+    sock.settimeout(3)
     try:
         sock.connect(("localhost", port))
         rf = sock.makefile("rb", 65536)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 4 additions & 4 deletions
@@ -70,15 +70,15 @@ trait ScalaReflection {
         p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) =>
           convertToCatalyst(elem, field.dataType)
         }.toArray)
+    case (d: BigDecimal, _) => Decimal(d)
+    case (d: java.math.BigDecimal, _) => Decimal(d)
+    case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d)
+    case (s: String, _) => UTF8String(s)
     case (r: Row, structType: StructType) =>
       new GenericRow(
         r.toSeq.zip(structType.fields).map { case (elem, field) =>
           convertToCatalyst(elem, field.dataType)
         }.toArray)
-    case (d: BigDecimal, _) => Decimal(d)
-    case (d: java.math.BigDecimal, _) => Decimal(d)
-    case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d)
-    case (s: String, _) => UTF8String(s)
     case (other, _) => other
   }
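
The reordering moves the leaf conversions (BigDecimal, Date, String) ahead of the recursive Row case; since those patterns cannot match a Row, the same case still fires for every input, and the change mainly resolves the merge conflict noted in the commit message. A tiny illustration of first-match-wins ordering in a Scala match, using plain types rather than the Catalyst ones:

object MatchOrderingSketch {
  // A Scala match tries cases top to bottom; the first pattern that matches wins,
  // so specific conversions go before the recursive and catch-all cases.
  def convert(value: Any): Any = value match {
    case d: BigDecimal => d.underlying()   // specific leaf conversion
    case s: String     => s.toUpperCase    // stand-in for a specific String conversion (UTF8String in Catalyst)
    case xs: Seq[_]    => xs.map(convert)  // recursive case, analogous to the Row case
    case other         => other            // generic fallback stays last
  }

  def main(args: Array[String]): Unit = {
    println(convert(Seq(BigDecimal("1.5"), "hi", 42)))  // List(1.5, HI, 42)
  }
}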

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 3 additions & 2 deletions
@@ -903,8 +903,9 @@ class DataFrame private[sql](
    * @group rdd
    */
   override def repartition(numPartitions: Int): DataFrame = {
-    val repartitioned = queryExecution.toRdd.map(_.copy()).repartition(numPartitions)
-    DataFrame(sqlContext, LogicalRDD(schema.toAttributes, repartitioned)(sqlContext))
+    sqlContext.createDataFrame(
+      queryExecution.toRdd.map(_.copy()).repartition(numPartitions),
+      schema, needsConversion = false)
   }
 
   /**
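
The rows produced by queryExecution.toRdd are already Catalyst rows, so the rebuilt DataFrame is created with needsConversion = false to avoid converting them a second time; the public repartition signature is unchanged. A minimal usage sketch, assuming an existing SQLContext named sqlContext and a registered table named "events":

// Assumed setup: an existing SQLContext called sqlContext and a registered table "events".
val events = sqlContext.table("events")
val compacted = events.repartition(4)       // same public API as before
println(compacted.rdd.partitions.length)    // 4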

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 14 additions & 23 deletions
@@ -392,33 +392,24 @@ class SQLContext(@transient val sparkContext: SparkContext)
    */
   @DeveloperApi
   def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = {
-    // TODO: use MutableProjection when rowRDD is another DataFrame and the applied
-    // schema differs from the existing schema on any field data type.
-    def needsConversion(dt: DataType): Boolean = dt match {
-      case StringType => true
-      case DateType => true
-      case DecimalType() => true
-      case dt: ArrayType => needsConversion(dt.elementType)
-      case dt: MapType => needsConversion(dt.keyType) || needsConversion(dt.valueType)
-      case dt: StructType => !dt.fields.forall(f => !needsConversion(f.dataType))
-      case other => false
-    }
-    val convertedRdd = if (needsConversion(schema)) {
-      RDDConversions.rowToRowRdd(rowRDD, schema)
-    } else {
-      rowRDD
-    }
-    DataFrame(this, LogicalRDD(schema.toAttributes, convertedRdd)(self))
+    createDataFrame(rowRDD, schema, needsConversion = true)
   }
 
   /**
-   * An internal API to apply a new schema on existing DataFrame without do the
-   * conversion for Rows.
+   * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be
+   * converted to Catalyst rows.
    */
-  private[sql] def createDataFrame(df: DataFrame, schema: StructType): DataFrame = {
+  private[sql]
+  def createDataFrame(rowRDD: RDD[Row], schema: StructType, needsConversion: Boolean) = {
     // TODO: use MutableProjection when rowRDD is another DataFrame and the applied
     // schema differs from the existing schema on any field data type.
-    DataFrame(this, LogicalRDD(schema.toAttributes, df.queryExecution.toRdd)(self))
+    val catalystRows = if (needsConversion) {
+      rowRDD.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row])
+    } else {
+      rowRDD
+    }
+    val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self)
+    DataFrame(this, logicalPlan)
   }
 
   /**
@@ -627,7 +618,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       JsonRDD.nullTypeToStringType(
         JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
     val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
-    createDataFrame(rowRDD, appliedSchema)
+    createDataFrame(rowRDD, appliedSchema, needsConversion = false)
   }
 
   /**
@@ -656,7 +647,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       JsonRDD.nullTypeToStringType(
         JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
     val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
-    createDataFrame(rowRDD, appliedSchema)
+    createDataFrame(rowRDD, appliedSchema, needsConversion = false)
   }
 
   /**
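
The public createDataFrame(rowRDD, schema) keeps its old behaviour by delegating with needsConversion = true, while internal callers that already hold Catalyst rows (the JSON readers here, plus the Parquet and data-source write paths below) pass needsConversion = false to skip the per-row conversion. A hedged sketch of the two call shapes; the variable names rows and schema are assumptions, and the three-argument overload is private[sql], so only code inside org.apache.spark.sql can call it:

// Public API: rows built from ordinary Scala/Java objects go through Catalyst conversion.
val df1 = sqlContext.createDataFrame(rows, schema)   // equivalent to needsConversion = true

// Internal callers that already have Catalyst rows skip the conversion:
val df2 = sqlContext.createDataFrame(
  df1.queryExecution.toRdd, df1.schema, needsConversion = false)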

sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala

Lines changed: 10 additions & 5 deletions
@@ -119,7 +119,11 @@ private[sql] class DefaultSource
     val relation = if (doInsertion) {
       // This is a hack. We always set nullable/containsNull/valueContainsNull to true
       // for the schema of a parquet data.
-      val df = sqlContext.createDataFrame(data, data.schema.asNullable)
+      val df =
+        sqlContext.createDataFrame(
+          data.queryExecution.toRdd,
+          data.schema.asNullable,
+          needsConversion = false)
       val createdRelation =
         createRelation(sqlContext, parameters, df.schema).asInstanceOf[ParquetRelation2]
       createdRelation.insert(df, overwrite = mode == SaveMode.Overwrite)
@@ -433,17 +437,18 @@ private[sql] case class ParquetRelation2(
       FileInputFormat.setInputPaths(job, selectedFiles.map(_.getPath): _*)
     }
 
-    // Push down filters when possible. Notice that not all filters can be converted to Parquet
-    // filter predicate. Here we try to convert each individual predicate and only collect those
-    // convertible ones.
+    // Try to push down filters when filter push-down is enabled.
     if (sqlContext.conf.parquetFilterPushDown) {
+      val partitionColNames = partitionColumns.map(_.name).toSet
       predicates
         // Don't push down predicates which reference partition columns
         .filter { pred =>
-          val partitionColNames = partitionColumns.map(_.name).toSet
           val referencedColNames = pred.references.map(_.name).toSet
           referencedColNames.intersect(partitionColNames).isEmpty
         }
+        // Collects all converted Parquet filter predicates. Notice that not all predicates can be
+        // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
+        // is used here.
        .flatMap(ParquetFilters.createFilter)
        .reduceOption(FilterApi.and)
        .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _))
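
The new comment spells out the shape of the pushdown pipeline: ParquetFilters.createFilter returns an Option, so flatMap keeps only the convertible predicates, and reduceOption(FilterApi.and) combines whatever survives (yielding None, and hence no filter, when nothing converts). The same collect-then-combine pattern in plain Scala, with strings standing in for Parquet FilterPredicates:

object OptionalFilterCombineSketch {
  // Stand-in for ParquetFilters.createFilter: not every predicate is convertible.
  def createFilter(pred: String): Option[String] =
    if (pred.startsWith("col")) Some(s"parquet($pred)") else None

  def main(args: Array[String]): Unit = {
    val predicates = Seq("col1 > 5", "udf(col2)", "col3 = 'x'")

    val combined: Option[String] =
      predicates
        .flatMap(createFilter)                   // drop predicates that cannot be converted
        .reduceOption((l, r) => s"and($l, $r)")  // None if nothing was convertible

    combined.foreach(println)                    // and(parquet(col1 > 5), parquet(col3 = 'x'))
  }
}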

sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala

Lines changed: 2 additions & 1 deletion
@@ -31,7 +31,8 @@ private[sql] case class InsertIntoDataSource(
     val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
     val data = DataFrame(sqlContext, query)
     // Apply the schema of the existing table to the new data.
-    val df = sqlContext.createDataFrame(data, logicalRelation.schema)
+    val df = sqlContext.createDataFrame(
+      data.queryExecution.toRdd, logicalRelation.schema, needsConversion = false)
     relation.insert(df, overwrite)
 
     // Invalidate the cache.
