Skip to content

Commit 1097275

Browse files
committed
Infer result type from result_type arg under string name overload
1 parent 9054679 commit 1097275

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

temporalio/client.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ async def start_workflow(
446446
args: Sequence[Any] = [],
447447
id: str,
448448
task_queue: str,
449-
result_type: Optional[Type] = None,
449+
result_type: Optional[Type[ReturnType]] = None,
450450
execution_timeout: Optional[timedelta] = None,
451451
run_timeout: Optional[timedelta] = None,
452452
task_timeout: Optional[timedelta] = None,
@@ -471,7 +471,7 @@ async def start_workflow(
471471
request_eager_start: bool = False,
472472
priority: temporalio.common.Priority = temporalio.common.Priority.default,
473473
versioning_override: Optional[temporalio.common.VersioningOverride] = None,
474-
) -> WorkflowHandle[Any, Any]: ...
474+
) -> WorkflowHandle[Any, ReturnType]: ...
475475

476476
async def start_workflow(
477477
self,
@@ -727,7 +727,7 @@ async def execute_workflow(
727727
args: Sequence[Any] = [],
728728
id: str,
729729
task_queue: str,
730-
result_type: Optional[Type] = None,
730+
result_type: Optional[Type[ReturnType]] = None,
731731
execution_timeout: Optional[timedelta] = None,
732732
run_timeout: Optional[timedelta] = None,
733733
task_timeout: Optional[timedelta] = None,
@@ -752,7 +752,7 @@ async def execute_workflow(
752752
request_eager_start: bool = False,
753753
priority: temporalio.common.Priority = temporalio.common.Priority.default,
754754
versioning_override: Optional[temporalio.common.VersioningOverride] = None,
755-
) -> Any: ...
755+
) -> ReturnType: ...
756756

757757
async def execute_workflow(
758758
self,
@@ -940,10 +940,10 @@ async def execute_update_with_start_workflow(
940940
start_workflow_operation: WithStartWorkflowOperation[Any, Any],
941941
args: Sequence[Any] = [],
942942
id: Optional[str] = None,
943-
result_type: Optional[Type] = None,
943+
result_type: Optional[Type[LocalReturnType]] = None,
944944
rpc_metadata: Mapping[str, str] = {},
945945
rpc_timeout: Optional[timedelta] = None,
946-
) -> Any: ...
946+
) -> LocalReturnType: ...
947947

948948
async def execute_update_with_start_workflow(
949949
self,
@@ -1061,10 +1061,10 @@ async def start_update_with_start_workflow(
10611061
wait_for_stage: WorkflowUpdateStage,
10621062
args: Sequence[Any] = [],
10631063
id: Optional[str] = None,
1064-
result_type: Optional[Type] = None,
1064+
result_type: Optional[Type[LocalReturnType]] = None,
10651065
rpc_metadata: Mapping[str, str] = {},
10661066
rpc_timeout: Optional[timedelta] = None,
1067-
) -> WorkflowUpdateHandle[Any]: ...
1067+
) -> WorkflowUpdateHandle[LocalReturnType]: ...
10681068

10691069
async def start_update_with_start_workflow(
10701070
self,

tests/worker/test_workflow.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ async def run(self, param1: int, param2: str) -> str:
204204

205205
async def test_workflow_multi_param(client: Client):
206206
# This test is mostly just here to confirm MyPy type checks the multi-param
207-
# overload approach properly
207+
# overload approach properly, and infers result type from result_type.
208208
async with new_worker(
209209
client, MultiParamWorkflow, activities=[multi_param_activity]
210210
) as worker:
@@ -216,6 +216,15 @@ async def test_workflow_multi_param(client: Client):
216216
)
217217
assert result == "param1: 123, param2: val1"
218218

219+
result_via_name_overload = await client.execute_workflow(
220+
"MultiParamWorkflow",
221+
args=[123, "val1"],
222+
id=f"workflow-{uuid.uuid4()}",
223+
task_queue=worker.task_queue,
224+
result_type=str,
225+
)
226+
assert result_via_name_overload == "param1: 123, param2: val1"
227+
219228

220229
@workflow.defn
221230
class InfoWorkflow:

0 commit comments

Comments
 (0)