Skip to content

Commit 90e1c47

Browse files
committed
Preserve contextvars during comm offload
Helps with setting the current client in worker while deserializing. Implementation referenced from python/cpython#9688
1 parent 1898673 commit 90e1c47

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

distributed/tests/test_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import array
22
import asyncio
3+
import contextvars
34
import functools
45
import io
56
import os
@@ -554,6 +555,22 @@ async def test_offload():
554555
assert (await offload(lambda x, y: x + y, 1, y=2)) == 3
555556

556557

558+
@pytest.mark.asyncio
559+
async def test_offload_preserves_contextvars():
560+
var = contextvars.ContextVar("var", default="foo")
561+
562+
def change_var():
563+
var.set("bar")
564+
return var.get()
565+
566+
o1 = offload(var.get)
567+
o2 = offload(change_var)
568+
569+
r1, r2 = await asyncio.gather(o1, o2)
570+
assert (r1, r2) == ("foo", "bar")
571+
assert var.get() == "foo"
572+
573+
557574
def test_serialize_for_cli_deprecated():
558575
with pytest.warns(FutureWarning, match="serialize_for_cli is deprecated"):
559576
from distributed.utils import serialize_for_cli

distributed/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import contextvars
45
import functools
56
import importlib
67
import inspect
@@ -1322,7 +1323,11 @@ def import_term(name: str):
13221323

13231324
async def offload(fn, *args, **kwargs):
13241325
loop = asyncio.get_event_loop()
1325-
return await loop.run_in_executor(_offload_executor, lambda: fn(*args, **kwargs))
1326+
# Retain context vars while deserializing; see https://bugs.python.org/issue34014
1327+
context = contextvars.copy_context()
1328+
return await loop.run_in_executor(
1329+
_offload_executor, lambda: context.run(fn, *args, **kwargs)
1330+
)
13261331

13271332

13281333
class EmptyContext:

0 commit comments

Comments
 (0)