diff --git a/temporalio/workflow.py b/temporalio/workflow.py index a627cec1d..75cedd18c 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -5280,13 +5280,16 @@ async def execute_operation( headers: Optional[Mapping[str, str]] = None, ) -> OutputT: ... + # TODO(nexus-preview): in practice, both these overloads match an async def sync + # operation (i.e. either can be deleted without causing a type error). + # Overload for sync_operation methods (async def) @overload @abstractmethod async def execute_operation( self, operation: Callable[ - [ServiceHandlerT, nexusrpc.handler.StartOperationContext, InputT], + [ServiceT, nexusrpc.handler.StartOperationContext, InputT], Awaitable[OutputT], ], input: InputT, @@ -5302,7 +5305,7 @@ async def execute_operation( async def execute_operation( self, operation: Callable[ - [ServiceHandlerT, nexusrpc.handler.StartOperationContext, InputT], + [ServiceT, nexusrpc.handler.StartOperationContext, InputT], OutputT, ], input: InputT, diff --git a/tests/nexus/test_type_checking.py b/tests/nexus/test_type_checking.py deleted file mode 100644 index e43b289c9..000000000 --- a/tests/nexus/test_type_checking.py +++ /dev/null @@ -1,31 +0,0 @@ -import nexusrpc - -import temporalio.nexus -from temporalio import workflow - - -def _(): - @nexusrpc.handler.service_handler - class MyService: - @nexusrpc.handler.sync_operation - async def my_sync_operation( - self, ctx: nexusrpc.handler.StartOperationContext, input: int - ) -> str: - raise NotImplementedError - - @temporalio.nexus.workflow_run_operation - async def my_workflow_run_operation( - self, ctx: temporalio.nexus.WorkflowRunOperationContext, input: int - ) -> temporalio.nexus.WorkflowHandle[str]: - raise NotImplementedError - - @workflow.defn(sandboxed=False) - class MyWorkflow: - @workflow.run - async def invoke_nexus_op_and_assert_error(self) -> None: - nexus_client = workflow.create_nexus_client( - service=MyService, - endpoint="fake-endpoint", - ) - await nexus_client.execute_operation(MyService.my_sync_operation, 1) - await nexus_client.execute_operation(MyService.my_workflow_run_operation, 1) diff --git a/tests/nexus/test_type_errors.py b/tests/nexus/test_type_errors.py new file mode 100644 index 000000000..1f5d3e2a7 --- /dev/null +++ b/tests/nexus/test_type_errors.py @@ -0,0 +1,207 @@ +""" +This file exists to test for type-checker false positives and false negatives. +It doesn't contain any test functions. +""" + +from dataclasses import dataclass + +import nexusrpc + +import temporalio.nexus +from temporalio import workflow + + +@dataclass +class MyInput: + pass + + +@dataclass +class MyOutput: + pass + + +@nexusrpc.service +class MyService: + my_sync_operation: nexusrpc.Operation[MyInput, MyOutput] + my_workflow_run_operation: nexusrpc.Operation[MyInput, MyOutput] + + +@nexusrpc.handler.service_handler(service=MyService) +class MyServiceHandler: + @nexusrpc.handler.sync_operation + async def my_sync_operation( + self, _ctx: nexusrpc.handler.StartOperationContext, _input: MyInput + ) -> MyOutput: + raise NotImplementedError + + @temporalio.nexus.workflow_run_operation + async def my_workflow_run_operation( + self, _ctx: temporalio.nexus.WorkflowRunOperationContext, _input: MyInput + ) -> temporalio.nexus.WorkflowHandle[MyOutput]: + raise NotImplementedError + + +@nexusrpc.handler.service_handler(service=MyService) +class MyServiceHandler2: + @nexusrpc.handler.sync_operation + async def my_sync_operation( + self, _ctx: nexusrpc.handler.StartOperationContext, _input: MyInput + ) -> MyOutput: + raise NotImplementedError + + @temporalio.nexus.workflow_run_operation + async def my_workflow_run_operation( + self, _ctx: temporalio.nexus.WorkflowRunOperationContext, _input: MyInput + ) -> temporalio.nexus.WorkflowHandle[MyOutput]: + raise NotImplementedError + + +@nexusrpc.handler.service_handler +class MyServiceHandlerWithoutServiceDefinition: + @nexusrpc.handler.sync_operation + async def my_sync_operation( + self, _ctx: nexusrpc.handler.StartOperationContext, _input: MyInput + ) -> MyOutput: + raise NotImplementedError + + @temporalio.nexus.workflow_run_operation + async def my_workflow_run_operation( + self, _ctx: temporalio.nexus.WorkflowRunOperationContext, _input: MyInput + ) -> temporalio.nexus.WorkflowHandle[MyOutput]: + raise NotImplementedError + + +@workflow.defn +class MyWorkflow1: + @workflow.run + async def test_invoke_by_operation_definition_happy_path(self) -> None: + """ + When a nexus client calls an operation by referencing an operation definition on + a service definition, the output type is inferred correctly. + """ + nexus_client = workflow.create_nexus_client( + service=MyService, + endpoint="fake-endpoint", + ) + input = MyInput() + + # sync operation + _output_1: MyOutput = await nexus_client.execute_operation( + MyService.my_sync_operation, input + ) + _handle_1: workflow.NexusOperationHandle[ + MyOutput + ] = await nexus_client.start_operation(MyService.my_sync_operation, input) + _output_1_1: MyOutput = await _handle_1 + + # workflow run operation + _output_2: MyOutput = await nexus_client.execute_operation( + MyService.my_workflow_run_operation, input + ) + _handle_2: workflow.NexusOperationHandle[ + MyOutput + ] = await nexus_client.start_operation( + MyService.my_workflow_run_operation, input + ) + _output_2_1: MyOutput = await _handle_2 + + +@workflow.defn +class MyWorkflow2: + @workflow.run + async def test_invoke_by_operation_handler_happy_path(self) -> None: + """ + When a nexus client calls an operation by referencing an operation handler on a + service handler, the output type is inferred correctly. + """ + nexus_client = workflow.create_nexus_client( + service=MyServiceHandler, # MyService would also work + endpoint="fake-endpoint", + ) + input = MyInput() + + # sync operation + _output_1: MyOutput = await nexus_client.execute_operation( + MyServiceHandler.my_sync_operation, input + ) + _handle_1: workflow.NexusOperationHandle[ + MyOutput + ] = await nexus_client.start_operation( + MyServiceHandler.my_sync_operation, input + ) + _output_1_1: MyOutput = await _handle_1 + + # workflow run operation + _output_2: MyOutput = await nexus_client.execute_operation( + MyServiceHandler.my_workflow_run_operation, input + ) + _handle_2: workflow.NexusOperationHandle[ + MyOutput + ] = await nexus_client.start_operation( + MyServiceHandler.my_workflow_run_operation, input + ) + _output_2_1: MyOutput = await _handle_2 + + +@workflow.defn +class MyWorkflow3: + @workflow.run + async def test_invoke_by_operation_definition_wrong_input_type(self) -> None: + """ + When a nexus client calls an operation by referencing an operation definition on + a service definition, there is a type error if the input type is wrong. + """ + nexus_client = workflow.create_nexus_client( + service=MyService, + endpoint="fake-endpoint", + ) + # assert-type-error-pyright: 'No overloads for "execute_operation" match' + await nexus_client.execute_operation( # type: ignore + MyService.my_sync_operation, + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "input"' + "wrong-input-type", # type: ignore + ) + + +@workflow.defn +class MyWorkflow4: + @workflow.run + async def test_invoke_by_operation_handler_wrong_input_type(self) -> None: + """ + When a nexus client calls an operation by referencing an operation handler on a + service handler, there is a type error if the input type is wrong. + """ + nexus_client = workflow.create_nexus_client( + service=MyServiceHandler, + endpoint="fake-endpoint", + ) + # assert-type-error-pyright: 'No overloads for "execute_operation" match' + await nexus_client.execute_operation( # type: ignore + MyServiceHandler.my_sync_operation, # type: ignore[arg-type] + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "input"' + "wrong-input-type", # type: ignore + ) + + +@workflow.defn +class MyWorkflow5: + @workflow.run + async def test_invoke_by_operation_handler_method_on_wrong_service(self) -> None: + """ + When a nexus client calls an operation by referencing an operation handler method + on a service handler, there is a type error if the method does not belong to the + service for which the client was created. + + (This form of type safety is not available when referencing an operation definition) + """ + nexus_client = workflow.create_nexus_client( + service=MyServiceHandler, + endpoint="fake-endpoint", + ) + # assert-type-error-pyright: 'No overloads for "execute_operation" match' + await nexus_client.execute_operation( # type: ignore + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "operation"' + MyServiceHandler2.my_sync_operation, # type: ignore + MyInput(), + ) diff --git a/tests/test_client_type_errors.py b/tests/test_client_type_errors.py new file mode 100644 index 000000000..d0f3b6de5 --- /dev/null +++ b/tests/test_client_type_errors.py @@ -0,0 +1,245 @@ +""" +This file exists to test for type-checker false positives and false negatives. +It doesn't contain any test functions. +""" + +from dataclasses import dataclass +from unittest.mock import Mock + +from temporalio import workflow +from temporalio.client import ( + Client, + WithStartWorkflowOperation, + WorkflowHandle, + WorkflowUpdateHandle, + WorkflowUpdateStage, +) +from temporalio.common import WorkflowIDConflictPolicy +from temporalio.service import ServiceClient + + +@dataclass +class WorkflowInput: + pass + + +@dataclass +class SignalInput: + pass + + +@dataclass +class QueryInput: + pass + + +@dataclass +class UpdateInput: + pass + + +@dataclass +class WorkflowOutput: + pass + + +@dataclass +class QueryOutput: + pass + + +@dataclass +class UpdateOutput: + pass + + +@workflow.defn +class TestWorkflow: + @workflow.run + async def run(self, _: WorkflowInput) -> WorkflowOutput: + return WorkflowOutput() + + @workflow.signal + async def signal(self, _: SignalInput) -> None: + pass + + @workflow.query + async def query(self, _: QueryInput) -> QueryOutput: + return QueryOutput() + + @workflow.update + async def update(self, _: UpdateInput) -> UpdateOutput: + return UpdateOutput() + + +@workflow.defn +class TestWorkflow2(TestWorkflow): + @workflow.run + async def run(self, _: WorkflowInput) -> WorkflowOutput: + return WorkflowOutput() + + +async def _start_and_execute_workflow_code_for_type_checking_test(): + client = Client(service_client=Mock(spec=ServiceClient)) + + # Good + _handle: WorkflowHandle[TestWorkflow, WorkflowOutput] = await client.start_workflow( + TestWorkflow.run, WorkflowInput(), id="wid", task_queue="tq" + ) + + # id and task_queue are required + # TODO: this type error is misleading: it's resolving to an unexpected overload. + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "workflow" of type "str"' + await client.start_workflow(TestWorkflow.run, id="wid", task_queue="tq") # type: ignore + # assert-type-error-pyright: 'No overloads for "start_workflow" match' + await client.start_workflow( + TestWorkflow.run, + # assert-type-error-pyright: 'Argument of type "SignalInput" cannot be assigned to parameter' + SignalInput(), # type: ignore + id="wid", + task_queue="tq", + ) + + # Good + _output: WorkflowOutput = await client.execute_workflow( + TestWorkflow.run, WorkflowInput(), id="wid", task_queue="tq" + ) + # TODO: this type error is misleading: it's resolving to an unexpected overload. + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "workflow" of type "str"' + await client.execute_workflow(TestWorkflow.run, id="wid", task_queue="tq") # type: ignore + # assert-type-error-pyright: 'No overloads for "execute_workflow" match' + await client.execute_workflow( + TestWorkflow.run, + # assert-type-error-pyright: 'Argument of type "SignalInput" cannot be assigned to parameter' + SignalInput(), # type: ignore + id="wid", + task_queue="tq", + ) + + +async def _signal_workflow_code_for_type_checking_test(): + client = Client(service_client=Mock(spec=ServiceClient)) + handle: WorkflowHandle[TestWorkflow, WorkflowOutput] = await client.start_workflow( + TestWorkflow.run, WorkflowInput(), id="wid", task_queue="tq" + ) + + # Good + await handle.signal(TestWorkflow.signal, SignalInput()) + # TODO: this type error is misleading: it's resolving to an unexpected overload. + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "signal" of type "str"' + await handle.signal(TestWorkflow.signal) # type: ignore + + # TODO: this type error is misleading: it's resolving to an unexpected overload. + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "signal" of type "str"' + await handle.signal(TestWorkflow2.signal, SignalInput()) # type: ignore + + +async def _query_workflow_code_for_type_checking_test(): + client = Client(service_client=Mock(spec=ServiceClient)) + handle: WorkflowHandle[TestWorkflow, WorkflowOutput] = await client.start_workflow( + TestWorkflow.run, WorkflowInput(), id="wid", task_queue="tq" + ) + # Good + _: QueryOutput = await handle.query(TestWorkflow.query, QueryInput()) + # TODO: this type error is misleading: it's resolving to an unexpected overload. + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "query" of type "str"' + await handle.query(TestWorkflow.query) # type: ignore + # assert-type-error-pyright: 'Argument of type "SignalInput" cannot be assigned to parameter' + await handle.query(TestWorkflow.query, SignalInput()) # type: ignore + + # TODO: this type error is misleading: it's resolving to an unexpected overload. + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "query" of type "str"' + await handle.query(TestWorkflow2.query, QueryInput()) # type: ignore + + +async def _update_workflow_code_for_type_checking_test(): + client = Client(service_client=Mock(spec=ServiceClient)) + handle: WorkflowHandle[TestWorkflow, WorkflowOutput] = await client.start_workflow( + TestWorkflow.run, WorkflowInput(), id="wid", task_queue="tq" + ) + + # Good + _handle: WorkflowUpdateHandle[UpdateOutput] = await handle.start_update( + TestWorkflow.update, UpdateInput(), wait_for_stage=WorkflowUpdateStage.ACCEPTED + ) + # wait_for_stage is required + # assert-type-error-pyright: 'No overloads for "start_update" match' + await handle.start_update(TestWorkflow.update, UpdateInput()) # type: ignore + + # assert-type-error-pyright: 'No overloads for "start_update" match the provided arguments' + await handle.start_update(TestWorkflow2.update, UpdateInput()) # type: ignore + + # Good + _result: UpdateOutput = await handle.execute_update( + TestWorkflow.update, UpdateInput() + ) + # assert-type-error-pyright: 'No overloads for "execute_update" match' + await handle.execute_update( + TestWorkflow.update, + wait_for_stage=WorkflowUpdateStage.ACCEPTED, # type: ignore + ) + # assert-type-error-pyright: 'Argument of type "SignalInput" cannot be assigned to parameter' + await handle.execute_update(TestWorkflow.update, SignalInput()) # type: ignore + + # TODO: this type error is misleading: it's resolving to an unecpected overload. + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "update" of type "str"' + await handle.execute_update(TestWorkflow2.update, UpdateInput()) # type: ignore + + +async def _update_with_start_workflow_code_for_type_checking_test(): + client = Client(service_client=Mock(spec=ServiceClient)) + + # Good + with_start = WithStartWorkflowOperation( + TestWorkflow.run, + WorkflowInput(), + id="wid", + task_queue="tq", + id_conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING, + ) + _update_handle: WorkflowUpdateHandle[ + UpdateOutput + ] = await client.start_update_with_start_workflow( + TestWorkflow.update, + UpdateInput(), + wait_for_stage=WorkflowUpdateStage.ACCEPTED, + start_workflow_operation=with_start, + ) + _update_result: UpdateOutput = await _update_handle.result() + + _wf_handle: WorkflowHandle[ + TestWorkflow, WorkflowOutput + ] = await with_start.workflow_handle() + + _wf_result: WorkflowOutput = await _wf_handle.result() + + # id_conflict_policy is required + # assert-type-error-pyright: 'No overloads for "__init__" match' + with_start = WithStartWorkflowOperation( # type: ignore + TestWorkflow.run, + WorkflowInput(), + id="wid", + task_queue="tq", + ) + + # wait_for_stage is required + # assert-type-error-pyright: 'No overloads for "start_update_with_start_workflow" match' + await client.start_update_with_start_workflow( # type: ignore + TestWorkflow.update, UpdateInput(), start_workflow_operation=with_start + ) + + # Good + _update_result_2: UpdateOutput = await client.execute_update_with_start_workflow( + TestWorkflow.update, + UpdateInput(), + start_workflow_operation=with_start, + ) + + # cannot supply wait_for_stage + # assert-type-error-pyright: 'No overloads for "execute_update_with_start_workflow" match' + await client.execute_update_with_start_workflow( # type: ignore + TestWorkflow.update, + UpdateInput(), + start_workflow_operation=with_start, + wait_for_stage=WorkflowUpdateStage.ACCEPTED, + ) diff --git a/tests/test_type_errors.py b/tests/test_type_errors.py new file mode 100644 index 000000000..3c700ff37 --- /dev/null +++ b/tests/test_type_errors.py @@ -0,0 +1,178 @@ +""" +This file contains a test allowing assertions to be made that an expected type error is in +fact produced by the type-checker. I.e. that the type checker is not delivering a false +negative. + +To use the test, add a comment of the following form to your test code: + + # assert-type-error-pyright: 'No overloads for "execute_operation" match' await + nexus_client.execute_operation( # type: ignore + +The `type: ignore` is only necessary if your test code is being type-checked. + +This is a copy of https://github.com/nexus-rpc/sdk-python/blob/main/tests/test_type_errors.py + +Until a shared library is created, please keep the two in sync. +""" + +import itertools +import json +import os +import platform +import re +import subprocess +import tempfile +from pathlib import Path + +import pytest + + +def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: + """Dynamically generate test cases for files with type error assertions.""" + if metafunc.function.__name__ in [ + "test_type_errors_pyright", + "test_type_errors_mypy", + ]: + tests_dir = Path(__file__).parent + files_with_assertions = [] + + for test_file in tests_dir.rglob("test_*.py"): + if test_file.name == "test_type_errors.py": + continue + + if _has_type_error_assertions(test_file): + files_with_assertions.append(test_file) + + metafunc.parametrize( + "test_file", + files_with_assertions, + ids=lambda f: str(f.relative_to(tests_dir)), + ) + + +@pytest.mark.skipif(platform.system() == "Windows", reason="TODO: broken on Windows") +def test_type_errors_pyright(test_file: Path): + """ + Validate type error assertions in a single test file using pyright. + + For each line with a comment of the form `# assert-type-error-pyright: "regex"`, + verify that pyright reports an error on the next non-comment line matching the regex. + Also verify that there are no unexpected type errors. + """ + _test_type_errors( + test_file, + _get_expected_errors(test_file, "pyright"), + _get_pyright_errors(test_file), + ) + + +def _test_type_errors( + test_file: Path, + expected_errors: dict[int, str], + actual_errors: dict[int, str], +) -> None: + for line_num, expected_pattern in sorted(expected_errors.items()): + if line_num not in actual_errors: + pytest.fail( + f"{test_file}:{line_num}: Expected type error matching '{expected_pattern}' but no error found" + ) + + actual_msg = actual_errors[line_num] + if not re.search(expected_pattern, actual_msg): + pytest.fail( + f"{test_file}:{line_num}: Expected error matching '{expected_pattern}' but got '{actual_msg}'" + ) + + for line_num, actual_msg in sorted(actual_errors.items()): + if line_num not in expected_errors: + pytest.fail(f"{test_file}:{line_num}: Unexpected type error: {actual_msg}") + + +def _has_type_error_assertions(test_file: Path) -> bool: + """Check if a file contains any type error assertions.""" + with open(test_file) as f: + return any(re.search(r"# assert-type-error-\w+:", line) for line in f) + + +def _get_expected_errors(test_file: Path, type_checker: str) -> dict[int, str]: + """Parse expected type errors from comments in a file for the specified type checker.""" + expected_errors = {} + + with open(test_file) as f: + lines = zip(itertools.count(1), f) + for line_num, line in lines: + if match := re.search( + rf'# assert-type-error-{re.escape(type_checker)}:\s*["\'](.+)["\']', + line, + ): + pattern = match.group(1) + for line_num, line in lines: + if line := line.strip(): + if not line.startswith("#"): + expected_errors[line_num] = pattern + break + + return expected_errors + + +def _get_pyright_errors(test_file: Path) -> dict[int, str]: + """Run pyright on a file and parse the actual type errors.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + # Create a temporary config file to disable type ignore comments + config_data = {"enableTypeIgnoreComments": False} + json.dump(config_data, f) + config_path = f.name + + try: + result = subprocess.run( + ["uv", "run", "pyright", "--project", config_path, str(test_file)], + capture_output=True, + text=True, + ) + + actual_errors = {} + abs_path = test_file.resolve() + + for line in result.stdout.splitlines(): + # pyright output format: /full/path/to/file.py:line:column - error: message (error_code) + if match := re.match( + rf"\s*{re.escape(str(abs_path))}:(\d+):\d+\s*-\s*error:\s*(.+)", line + ): + line_num = int(match.group(1)) + error_msg = match.group(2).strip() + # Remove error code in parentheses if present + error_msg = re.sub(r"\s*\([^)]+\)$", "", error_msg) + actual_errors[line_num] = error_msg + + return actual_errors + finally: + if os.path.exists(config_path): + os.unlink(config_path) + + +def _get_mypy_errors(test_file: Path) -> dict[int, str]: # pyright: ignore[reportUnusedFunction] + """Run mypy on a file and parse the actual type errors. + + Note: mypy does not have a direct equivalent to pyright's enableTypeIgnoreComments=false, + so type ignore comments will still be respected by mypy. Users should avoid placing + # type: ignore comments on lines they want to test, or manually remove them for testing. + """ + result = subprocess.run( + ["uv", "run", "mypy", str(test_file)], + capture_output=True, + text=True, + ) + + actual_errors = {} + abs_path = test_file.resolve() + + for line in result.stdout.splitlines(): + # mypy output format: file.py:line: error: message + if match := re.match( + rf"{re.escape(str(abs_path))}:(\d+):\s*error:\s*(.+)", line + ): + line_num = int(match.group(1)) + error_msg = match.group(2).strip() + actual_errors[line_num] = error_msg + + return actual_errors diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 21dd7680e..cfeeb91b7 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -225,7 +225,8 @@ def update2(self, arg1: str): pass # Intentionally missing decorator - def base_update(self): + # assert-type-error-pyright: "overrides symbol of same name" + def base_update(self): # type: ignore pass @@ -288,7 +289,8 @@ def run(self): def test_workflow_defn_non_async_run(): with pytest.raises(ValueError) as err: - workflow.run(NonAsyncRun.run) + # assert-type-error-pyright: "Argument .+ cannot be assigned to parameter" + workflow.run(NonAsyncRun.run) # type: ignore assert "must be an async function" in str(err.value) @@ -451,6 +453,7 @@ class BadUpdateValidator: def my_update(self, a: str): pass + # assert-type-error-pyright: "Argument of type .+ cannot be assigned to parameter" @my_update.validator # type: ignore def my_validator(self, a: int): pass