Skip to content

Commit 223df5d

Browse files
committed
[SPARK-24397][PYSPARK] Added TaskContext.getLocalProperty(key) in Python
## What changes were proposed in this pull request? This adds a new API `TaskContext.getLocalProperty(key)` to the Python TaskContext. It mirrors the Java TaskContext API of returning a string value if the key exists, or None if the key does not exist. ## How was this patch tested? New test added. Author: Tathagata Das <[email protected]> Closes #21437 from tdas/SPARK-24397.
1 parent 7a82e93 commit 223df5d

File tree

4 files changed

+34
-0
lines changed

4 files changed

+34
-0
lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
183183
dataOut.writeInt(context.partitionId())
184184
dataOut.writeInt(context.attemptNumber())
185185
dataOut.writeLong(context.taskAttemptId())
186+
val localProps = context.asInstanceOf[TaskContextImpl].getLocalProperties.asScala
187+
dataOut.writeInt(localProps.size)
188+
localProps.foreach { case (k, v) =>
189+
PythonRDD.writeUTF(k, dataOut)
190+
PythonRDD.writeUTF(v, dataOut)
191+
}
192+
186193
// sparkFilesDir
187194
PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
188195
// Python includes (*.zip and *.egg files)

python/pyspark/taskcontext.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class TaskContext(object):
3434
_partitionId = None
3535
_stageId = None
3636
_taskAttemptId = None
37+
_localProperties = None
3738

3839
def __new__(cls):
3940
"""Even if users construct TaskContext instead of using get, give them the singleton."""
@@ -88,3 +89,9 @@ def taskAttemptId(self):
8889
TaskAttemptID.
8990
"""
9091
return self._taskAttemptId
92+
93+
def getLocalProperty(self, key):
94+
"""
95+
Get a local property set upstream in the driver, or None if it is missing.
96+
"""
97+
return self._localProperties.get(key, None)

python/pyspark/tests.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,20 @@ def test_tc_on_driver(self):
574574
tc = TaskContext.get()
575575
self.assertTrue(tc is None)
576576

577+
def test_get_local_property(self):
578+
"""Verify that local properties set on the driver are available in TaskContext."""
579+
key = "testkey"
580+
value = "testvalue"
581+
self.sc.setLocalProperty(key, value)
582+
try:
583+
rdd = self.sc.parallelize(range(1), 1)
584+
prop1 = rdd.map(lambda x: TaskContext.get().getLocalProperty(key)).collect()[0]
585+
self.assertEqual(prop1, value)
586+
prop2 = rdd.map(lambda x: TaskContext.get().getLocalProperty("otherkey")).collect()[0]
587+
self.assertTrue(prop2 is None)
588+
finally:
589+
self.sc.setLocalProperty(key, None)
590+
577591

578592
class RDDTests(ReusedPySparkTestCase):
579593

python/pyspark/worker.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,12 @@ def main(infile, outfile):
222222
taskContext._partitionId = read_int(infile)
223223
taskContext._attemptNumber = read_int(infile)
224224
taskContext._taskAttemptId = read_long(infile)
225+
taskContext._localProperties = dict()
226+
for i in range(read_int(infile)):
227+
k = utf8_deserializer.loads(infile)
228+
v = utf8_deserializer.loads(infile)
229+
taskContext._localProperties[k] = v
230+
225231
shuffle.MemoryBytesSpilled = 0
226232
shuffle.DiskBytesSpilled = 0
227233
_accumulatorRegistry.clear()

0 commit comments

Comments
 (0)