diff --git a/temporalio/client.py b/temporalio/client.py index 4eb1dc868..8edae6fb1 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -447,7 +447,7 @@ async def start_workflow( args: Sequence[Any] = [], id: str, task_queue: str, - result_type: Optional[Type] = None, + result_type: Optional[Type[ReturnType]] = None, execution_timeout: Optional[timedelta] = None, run_timeout: Optional[timedelta] = None, task_timeout: Optional[timedelta] = None, @@ -472,7 +472,7 @@ async def start_workflow( request_eager_start: bool = False, priority: temporalio.common.Priority = temporalio.common.Priority.default, versioning_override: Optional[temporalio.common.VersioningOverride] = None, - ) -> WorkflowHandle[Any, Any]: ... + ) -> WorkflowHandle[Any, ReturnType]: ... async def start_workflow( self, @@ -728,7 +728,7 @@ async def execute_workflow( args: Sequence[Any] = [], id: str, task_queue: str, - result_type: Optional[Type] = None, + result_type: Optional[Type[ReturnType]] = None, execution_timeout: Optional[timedelta] = None, run_timeout: Optional[timedelta] = None, task_timeout: Optional[timedelta] = None, @@ -753,7 +753,7 @@ async def execute_workflow( request_eager_start: bool = False, priority: temporalio.common.Priority = temporalio.common.Priority.default, versioning_override: Optional[temporalio.common.VersioningOverride] = None, - ) -> Any: ... + ) -> ReturnType: ... async def execute_workflow( self, @@ -941,10 +941,10 @@ async def execute_update_with_start_workflow( start_workflow_operation: WithStartWorkflowOperation[Any, Any], args: Sequence[Any] = [], id: Optional[str] = None, - result_type: Optional[Type] = None, + result_type: Optional[Type[LocalReturnType]] = None, rpc_metadata: Mapping[str, str] = {}, rpc_timeout: Optional[timedelta] = None, - ) -> Any: ... + ) -> LocalReturnType: ... async def execute_update_with_start_workflow( self, @@ -1062,10 +1062,10 @@ async def start_update_with_start_workflow( wait_for_stage: WorkflowUpdateStage, args: Sequence[Any] = [], id: Optional[str] = None, - result_type: Optional[Type] = None, + result_type: Optional[Type[LocalReturnType]] = None, rpc_metadata: Mapping[str, str] = {}, rpc_timeout: Optional[timedelta] = None, - ) -> WorkflowUpdateHandle[Any]: ... + ) -> WorkflowUpdateHandle[LocalReturnType]: ... async def start_update_with_start_workflow( self, diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 862f9a456..1818e6dc4 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -204,7 +204,7 @@ async def run(self, param1: int, param2: str) -> str: async def test_workflow_multi_param(client: Client): # This test is mostly just here to confirm MyPy type checks the multi-param - # overload approach properly + # overload approach properly, and infers result type from result_type. async with new_worker( client, MultiParamWorkflow, activities=[multi_param_activity] ) as worker: @@ -216,6 +216,15 @@ async def test_workflow_multi_param(client: Client): ) assert result == "param1: 123, param2: val1" + result_via_name_overload = await client.execute_workflow( + "MultiParamWorkflow", + args=[123, "val1"], + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + result_type=str, + ) + assert result_via_name_overload == "param1: 123, param2: val1" + @workflow.defn class InfoWorkflow: