Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4794,6 +4794,9 @@ async def get_worker_task_reachability(
namespace=self._client.namespace,
build_ids=input.build_ids,
task_queues=input.task_queues,
reachability=input.reachability._to_proto()
if input.reachability
else temporalio.api.enums.v1.TaskReachability.TASK_REACHABILITY_UNSPECIFIED,
)
resp = await self._client.workflow_service.get_worker_task_reachability(
req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout
Expand Down Expand Up @@ -5170,3 +5173,25 @@ def _from_proto(
return TaskReachabilityType.CLOSED_WORKFLOWS
else:
raise ValueError(f"Cannot convert reachability type: {reachability}")

def _to_proto(self) -> temporalio.api.enums.v1.TaskReachability.ValueType:
if self == TaskReachabilityType.NEW_WORKFLOWS:
return (
temporalio.api.enums.v1.TaskReachability.TASK_REACHABILITY_NEW_WORKFLOWS
)
elif self == TaskReachabilityType.EXISTING_WORKFLOWS:
return (
temporalio.api.enums.v1.TaskReachability.TASK_REACHABILITY_EXISTING_WORKFLOWS
)
elif self == TaskReachabilityType.OPEN_WORKFLOWS:
return (
temporalio.api.enums.v1.TaskReachability.TASK_REACHABILITY_OPEN_WORKFLOWS
)
elif self == TaskReachabilityType.CLOSED_WORKFLOWS:
return (
temporalio.api.enums.v1.TaskReachability.TASK_REACHABILITY_CLOSED_WORKFLOWS
)
else:
return (
temporalio.api.enums.v1.TaskReachability.TASK_REACHABILITY_UNSPECIFIED
)
12 changes: 11 additions & 1 deletion tests/worker/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import temporalio.worker._worker
from temporalio import activity, workflow
from temporalio.client import BuildIdOpAddNewDefault, Client
from temporalio.client import BuildIdOpAddNewDefault, Client, TaskReachabilityType
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import Worker
from temporalio.workflow import VersioningIntent
Expand Down Expand Up @@ -198,6 +198,16 @@ async def test_worker_versioning(client: Client, env: WorkflowEnvironment):
build_id="2.0",
use_worker_versioning=True,
):
# Confirm reachability type parameter is respected. If it wasn't, list would have
# `OPEN_WORKFLOWS` in it.
reachability = await client.get_worker_task_reachability(
build_ids=["2.0"],
reachability_type=TaskReachabilityType.CLOSED_WORKFLOWS,
)
assert reachability.build_id_reachability["2.0"].task_queue_reachability[
task_queue
] == [TaskReachabilityType.NEW_WORKFLOWS]

await wf1.signal(WaitOnSignalWorkflow.my_signal, "finish")
await wf2.signal(WaitOnSignalWorkflow.my_signal, "finish")
await wf1.result()
Expand Down