Skip to content

Commit 4eceb4a

Browse files
committed
Add Nexus type error false-negative tests
1 parent 33b4a43 commit 4eceb4a

File tree

5 files changed

+419
-35
lines changed

5 files changed

+419
-35
lines changed

temporalio/workflow.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5280,13 +5280,16 @@ async def execute_operation(
52805280
headers: Optional[Mapping[str, str]] = None,
52815281
) -> OutputT: ...
52825282

5283+
# TODO(nexus-preview): in practice, both these overloads match an async def sync
5284+
# operation (i.e. either can be deleted without causing a type error).
5285+
52835286
# Overload for sync_operation methods (async def)
52845287
@overload
52855288
@abstractmethod
52865289
async def execute_operation(
52875290
self,
52885291
operation: Callable[
5289-
[ServiceHandlerT, nexusrpc.handler.StartOperationContext, InputT],
5292+
[ServiceT, nexusrpc.handler.StartOperationContext, InputT],
52905293
Awaitable[OutputT],
52915294
],
52925295
input: InputT,
@@ -5302,7 +5305,7 @@ async def execute_operation(
53025305
async def execute_operation(
53035306
self,
53045307
operation: Callable[
5305-
[ServiceHandlerT, nexusrpc.handler.StartOperationContext, InputT],
5308+
[ServiceT, nexusrpc.handler.StartOperationContext, InputT],
53065309
OutputT,
53075310
],
53085311
input: InputT,

tests/nexus/test_type_checking.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

tests/nexus/test_type_errors.py

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
"""
2+
This file exists to test for type-checker false positives and false negatives.
3+
It doesn't contain any test functions.
4+
"""
5+
6+
from dataclasses import dataclass
7+
8+
import nexusrpc
9+
10+
import temporalio.nexus
11+
from temporalio import workflow
12+
13+
14+
@dataclass
15+
class MyInput:
16+
pass
17+
18+
19+
@dataclass
20+
class MyOutput:
21+
pass
22+
23+
24+
@nexusrpc.service
25+
class MyService:
26+
my_sync_operation: nexusrpc.Operation[MyInput, MyOutput]
27+
my_workflow_run_operation: nexusrpc.Operation[MyInput, MyOutput]
28+
29+
30+
@nexusrpc.handler.service_handler(service=MyService)
31+
class MyServiceHandler:
32+
@nexusrpc.handler.sync_operation
33+
async def my_sync_operation(
34+
self, _ctx: nexusrpc.handler.StartOperationContext, _input: MyInput
35+
) -> MyOutput:
36+
raise NotImplementedError
37+
38+
@temporalio.nexus.workflow_run_operation
39+
async def my_workflow_run_operation(
40+
self, _ctx: temporalio.nexus.WorkflowRunOperationContext, _input: MyInput
41+
) -> temporalio.nexus.WorkflowHandle[MyOutput]:
42+
raise NotImplementedError
43+
44+
45+
@nexusrpc.handler.service_handler(service=MyService)
46+
class MyServiceHandler2:
47+
@nexusrpc.handler.sync_operation
48+
async def my_sync_operation(
49+
self, _ctx: nexusrpc.handler.StartOperationContext, _input: MyInput
50+
) -> MyOutput:
51+
raise NotImplementedError
52+
53+
@temporalio.nexus.workflow_run_operation
54+
async def my_workflow_run_operation(
55+
self, _ctx: temporalio.nexus.WorkflowRunOperationContext, _input: MyInput
56+
) -> temporalio.nexus.WorkflowHandle[MyOutput]:
57+
raise NotImplementedError
58+
59+
60+
@nexusrpc.handler.service_handler
61+
class MyServiceHandlerWithoutServiceDefinition:
62+
@nexusrpc.handler.sync_operation
63+
async def my_sync_operation(
64+
self, _ctx: nexusrpc.handler.StartOperationContext, _input: MyInput
65+
) -> MyOutput:
66+
raise NotImplementedError
67+
68+
@temporalio.nexus.workflow_run_operation
69+
async def my_workflow_run_operation(
70+
self, _ctx: temporalio.nexus.WorkflowRunOperationContext, _input: MyInput
71+
) -> temporalio.nexus.WorkflowHandle[MyOutput]:
72+
raise NotImplementedError
73+
74+
75+
@workflow.defn
76+
class MyWorkflow1:
77+
@workflow.run
78+
async def test_invoke_by_operation_definition_happy_path(self) -> None:
79+
"""
80+
When a nexus client calls an operation by referencing an operation definition on
81+
a service definition, the output type is inferred correctly.
82+
"""
83+
nexus_client = workflow.create_nexus_client(
84+
service=MyService,
85+
endpoint="fake-endpoint",
86+
)
87+
input = MyInput()
88+
89+
# sync operation
90+
_output_1: MyOutput = await nexus_client.execute_operation(
91+
MyService.my_sync_operation, input
92+
)
93+
_handle_1: workflow.NexusOperationHandle[
94+
MyOutput
95+
] = await nexus_client.start_operation(MyService.my_sync_operation, input)
96+
_output_1_1: MyOutput = await _handle_1
97+
98+
# workflow run operation
99+
_output_2: MyOutput = await nexus_client.execute_operation(
100+
MyService.my_workflow_run_operation, input
101+
)
102+
_handle_2: workflow.NexusOperationHandle[
103+
MyOutput
104+
] = await nexus_client.start_operation(
105+
MyService.my_workflow_run_operation, input
106+
)
107+
_output_2_1: MyOutput = await _handle_2
108+
109+
110+
@workflow.defn
111+
class MyWorkflow2:
112+
@workflow.run
113+
async def test_invoke_by_operation_handler_happy_path(self) -> None:
114+
"""
115+
When a nexus client calls an operation by referencing an operation handler on a
116+
service handler, the output type is inferred correctly.
117+
"""
118+
nexus_client = workflow.create_nexus_client(
119+
service=MyServiceHandler, # MyService would also work
120+
endpoint="fake-endpoint",
121+
)
122+
input = MyInput()
123+
124+
# sync operation
125+
_output_1: MyOutput = await nexus_client.execute_operation(
126+
MyServiceHandler.my_sync_operation, input
127+
)
128+
_handle_1: workflow.NexusOperationHandle[
129+
MyOutput
130+
] = await nexus_client.start_operation(
131+
MyServiceHandler.my_sync_operation, input
132+
)
133+
_output_1_1: MyOutput = await _handle_1
134+
135+
# workflow run operation
136+
_output_2: MyOutput = await nexus_client.execute_operation(
137+
MyServiceHandler.my_workflow_run_operation, input
138+
)
139+
_handle_2: workflow.NexusOperationHandle[
140+
MyOutput
141+
] = await nexus_client.start_operation(
142+
MyServiceHandler.my_workflow_run_operation, input
143+
)
144+
_output_2_1: MyOutput = await _handle_2
145+
146+
147+
@workflow.defn
148+
class MyWorkflow3:
149+
@workflow.run
150+
async def test_invoke_by_operation_name_happy_path(self) -> None:
151+
"""
152+
When a nexus client calls an operation by referencing an operation name, the
153+
output type is inferred as Unknown.
154+
"""
155+
nexus_client = workflow.create_nexus_client(
156+
service=MyServiceHandler,
157+
endpoint="fake-endpoint",
158+
)
159+
input = MyInput()
160+
# TODO: mypy fails these since no type is inferred, so we're forced to add a
161+
# `type: ignore`. As a result this function doesn't currently prove anything, but
162+
# one can confirm the inferred type is Unknown in an IDE.
163+
_output_1 = await nexus_client.execute_operation("my_sync_operation", input) # type: ignore[var-annotated]
164+
_output_2 = await nexus_client.execute_operation( # type: ignore[var-annotated]
165+
"my_workflow_run_operation", input
166+
)
167+
168+
169+
@workflow.defn
170+
class MyWorkflow4:
171+
@workflow.run
172+
async def test_invoke_by_operation_definition_wrong_input_type(self) -> None:
173+
"""
174+
When a nexus client calls an operation by referencing an operation definition on
175+
a service definition, there is a type error if the input type is wrong.
176+
"""
177+
nexus_client = workflow.create_nexus_client(
178+
service=MyService,
179+
endpoint="fake-endpoint",
180+
)
181+
# assert-type-error-pyright: 'No overloads for "execute_operation" match'
182+
await nexus_client.execute_operation( # type: ignore
183+
MyService.my_sync_operation,
184+
# assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "input"'
185+
"wrong-input-type", # type: ignore
186+
)
187+
188+
189+
@workflow.defn
190+
class MyWorkflow5:
191+
@workflow.run
192+
async def test_invoke_by_operation_handler_wrong_input_type(self) -> None:
193+
"""
194+
When a nexus client calls an operation by referencing an operation handler on a
195+
service handler, there is a type error if the input type is wrong.
196+
"""
197+
nexus_client = workflow.create_nexus_client(
198+
service=MyServiceHandler,
199+
endpoint="fake-endpoint",
200+
)
201+
# assert-type-error-pyright: 'No overloads for "execute_operation" match'
202+
await nexus_client.execute_operation( # type: ignore
203+
MyServiceHandler.my_sync_operation,
204+
# assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "input"'
205+
"wrong-input-type", # type: ignore
206+
)
207+
208+
209+
@workflow.defn
210+
class MyWorkflow6:
211+
@workflow.run
212+
async def test_invoke_by_operation_handler_method_on_wrong_service(self) -> None:
213+
"""
214+
When a nexus client calls an operation by referencing an operation handler method
215+
on a service handler, there is a type error if the method does not belong to the
216+
service for which the client was created.
217+
218+
(This form of type safety is not available when referencing an operation definition)
219+
"""
220+
nexus_client = workflow.create_nexus_client(
221+
service=MyServiceHandler,
222+
endpoint="fake-endpoint",
223+
)
224+
# assert-type-error-pyright: 'No overloads for "execute_operation" match'
225+
await nexus_client.execute_operation( # type: ignore
226+
# assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "operation"'
227+
MyServiceHandler2.my_sync_operation, # type: ignore
228+
MyInput(),
229+
)

0 commit comments

Comments
 (0)