Skip to content

Commit 2443c50

Browse files
authored
Type-checking (#962)
1 parent 28f43f1 commit 2443c50

18 files changed

+95
-97
lines changed

pyproject.toml

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ dev = [
4545
"psutil>=5.9.3,<6",
4646
"pydocstyle>=6.3.0,<7",
4747
"pydoctor>=24.11.1,<25",
48-
"pyright==1.1.402",
48+
"pyright==1.1.403",
4949
"pytest~=7.4",
5050
"pytest-asyncio>=0.21,<0.22",
5151
"pytest-timeout~=2.2",
@@ -69,14 +69,16 @@ lint = [
6969
{cmd = "uv run ruff check --select I"},
7070
{cmd = "uv run ruff format --check"},
7171
{ref = "lint-types"},
72-
{cmd = "uv run pyright"},
7372
{ref = "lint-docs"},
7473
]
7574
bridge-lint = { cmd = "cargo clippy -- -D warnings", cwd = "temporalio/bridge" }
7675
# TODO(cretz): Why does pydocstyle complain about @overload missing docs after
7776
# https://github.com/PyCQA/pydocstyle/pull/511?
7877
lint-docs = "uv run pydocstyle --ignore-decorators=overload"
79-
lint-types = "uv run mypy --namespace-packages --check-untyped-defs ."
78+
lint-types = [
79+
{ cmd = "uv run pyright"},
80+
{ cmd = "uv run mypy --namespace-packages --check-untyped-defs ."},
81+
]
8082
run-bench = "uv run python scripts/run_bench.py"
8183
test = "uv run pytest"
8284

@@ -100,7 +102,7 @@ filterwarnings = [
100102
[tool.cibuildwheel]
101103
before-all = "pip install protoc-wheel-0"
102104
build = "cp39-win_amd64 cp39-manylinux_x86_64 cp39-manylinux_aarch64 cp39-macosx_x86_64 cp39-macosx_arm64"
103-
build-verbosity = "1"
105+
build-verbosity = 1
104106

105107
[tool.cibuildwheel.macos]
106108
environment = { MACOSX_DEPLOYMENT_TARGET = "10.12" }
@@ -158,16 +160,24 @@ project-name = "Temporal Python"
158160
sidebar-expand-depth = 2
159161

160162
[tool.pyright]
161-
reportUnknownVariableType = "none"
162-
reportUnknownParameterType = "none"
163-
reportUnusedCallResult = "none"
164-
reportImplicitStringConcatenation = "none"
165-
reportPrivateUsage = "none"
163+
enableTypeIgnoreComments = true
164+
reportAny = "none"
165+
reportCallInDefaultInitializer = "none"
166166
reportExplicitAny = "none"
167+
reportIgnoreCommentWithoutRule = "none"
168+
reportImplicitOverride = "none"
169+
reportImplicitStringConcatenation = "none"
170+
reportImportCycles = "none"
167171
reportMissingTypeArgument = "none"
168-
reportAny = "none"
169-
enableTypeIgnoreComments = true
170-
172+
reportPrivateUsage = "none"
173+
reportUnannotatedClassAttribute = "none"
174+
reportUnknownArgumentType = "none"
175+
reportUnknownMemberType = "none"
176+
reportUnknownParameterType = "none"
177+
reportUnknownVariableType = "none"
178+
reportUnnecessaryIsInstance = "none"
179+
reportUnnecessaryTypeIgnoreComment = "none"
180+
reportUnusedCallResult = "none"
171181
include = ["temporalio", "tests"]
172182
exclude = [
173183
"temporalio/api",

temporalio/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5772,7 +5772,7 @@ async def get_worker_task_reachability(
57725772

57735773

57745774
class _ClientImpl(OutboundInterceptor):
5775-
def __init__(self, client: Client) -> None:
5775+
def __init__(self, client: Client) -> None: # type: ignore
57765776
# We are intentionally not calling the base class's __init__ here
57775777
self._client = client
57785778

temporalio/contrib/openai_agents/_invoke_model_activity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from agents.models.multi_provider import MultiProvider
2727
from typing_extensions import Required, TypedDict
2828

29-
from temporalio import activity, workflow
29+
from temporalio import activity
3030
from temporalio.contrib.openai_agents._heartbeat_decorator import _auto_heartbeater
3131

3232

@@ -106,7 +106,7 @@ class ActivityModelInput(TypedDict, total=False):
106106

107107
model_name: Optional[str]
108108
system_instructions: Optional[str]
109-
input: Required[Union[str, list[TResponseInputItem]]] # type: ignore
109+
input: Required[Union[str, list[TResponseInputItem]]]
110110
model_settings: Required[ModelSettings]
111111
tools: list[ToolInput]
112112
output_schema: Optional[AgentOutputSchemaInput]

temporalio/contrib/openai_agents/_temporal_model_stub.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
11
from __future__ import annotations
22

33
import logging
4-
from datetime import timedelta
54
from typing import Optional
65

76
from temporalio import workflow
8-
from temporalio.common import Priority, RetryPolicy
97
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
10-
from temporalio.workflow import ActivityCancellationType, VersioningIntent
118

129
logger = logging.getLogger(__name__)
1310

14-
from typing import Any, AsyncIterator, Optional, Sequence, Union, cast
11+
from typing import Any, AsyncIterator, Sequence, Union, cast
1512

1613
from agents import (
1714
AgentOutputSchema,
@@ -57,7 +54,7 @@ def __init__(
5754
async def get_response(
5855
self,
5956
system_instructions: Optional[str],
60-
input: Union[str, list[TResponseInputItem]],
57+
input: Union[str, list[TResponseInputItem], dict[str, str]],
6158
model_settings: ModelSettings,
6259
tools: list[Tool],
6360
output_schema: Optional[AgentOutputSchemaBase],
@@ -67,7 +64,9 @@ async def get_response(
6764
previous_response_id: Optional[str],
6865
prompt: Optional[ResponsePromptParam],
6966
) -> ModelResponse:
70-
def get_summary(input: Union[str, list[TResponseInputItem]]) -> str:
67+
def get_summary(
68+
input: Union[str, list[TResponseInputItem], dict[str, str]],
69+
) -> str:
7170
### Activity summary shown in the UI
7271
try:
7372
max_size = 100
@@ -88,21 +87,18 @@ def get_summary(input: Union[str, list[TResponseInputItem]]) -> str:
8887
return ""
8988

9089
def make_tool_info(tool: Tool) -> ToolInput:
91-
if isinstance(tool, FileSearchTool):
92-
return cast(FileSearchTool, tool)
93-
elif isinstance(tool, WebSearchTool):
94-
return cast(WebSearchTool, tool)
90+
if isinstance(tool, (FileSearchTool, WebSearchTool)):
91+
return tool
9592
elif isinstance(tool, ComputerTool):
9693
raise NotImplementedError(
9794
"Computer search preview is not supported in Temporal model"
9895
)
9996
elif isinstance(tool, FunctionTool):
100-
t = cast(FunctionToolInput, tool)
10197
return FunctionToolInput(
102-
name=t.name,
103-
description=t.description,
104-
params_json_schema=t.params_json_schema,
105-
strict_json_schema=t.strict_json_schema,
98+
name=tool.name,
99+
description=tool.description,
100+
params_json_schema=tool.params_json_schema,
101+
strict_json_schema=tool.strict_json_schema,
106102
)
107103
else:
108104
raise ValueError(f"Unknown tool type: {tool.name}")
@@ -141,7 +137,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
141137
activity_input = ActivityModelInput(
142138
model_name=self.model_name,
143139
system_instructions=system_instructions,
144-
input=input,
140+
input=cast(Union[str, list[TResponseInputItem]], input),
145141
model_settings=model_settings,
146142
tools=tool_infos,
147143
output_schema=output_schema_input,
@@ -169,7 +165,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
169165
def stream_response(
170166
self,
171167
system_instructions: Optional[str],
172-
input: Union[str, list][TResponseInputItem], # type: ignore
168+
input: Union[str, list[TResponseInputItem]],
173169
model_settings: ModelSettings,
174170
tools: list[Tool],
175171
output_schema: Optional[AgentOutputSchemaBase],

temporalio/contrib/openai_agents/_trace_interceptor.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33
from __future__ import annotations
44

55
from contextlib import contextmanager
6-
from typing import Any, Mapping, Protocol, Type, cast
6+
from typing import Any, Mapping, Protocol, Type
77

8-
from agents import CustomSpanData, custom_span, get_current_span, trace
8+
from agents import custom_span, get_current_span, trace
99
from agents.tracing import (
1010
get_trace_provider,
1111
)
12-
from agents.tracing.provider import DefaultTraceProvider
13-
from agents.tracing.spans import NoOpSpan, SpanImpl
12+
from agents.tracing.spans import NoOpSpan
1413

1514
import temporalio.activity
1615
import temporalio.api.common.v1
@@ -116,7 +115,7 @@ class OpenAIAgentsTracingInterceptor(
116115
worker = Worker(client, task_queue="my-task-queue", interceptors=[interceptor])
117116
"""
118117

119-
def __init__(
118+
def __init__( # type: ignore[reportMissingSuperCall]
120119
self,
121120
payload_converter: temporalio.converter.PayloadConverter = temporalio.converter.default().payload_converter,
122121
) -> None:
@@ -189,7 +188,7 @@ async def start_workflow(
189188
**({"temporal:workflowId": input.id} if input.id else {}),
190189
}
191190
data = {"workflowId": input.id} if input.id else None
192-
span_name = f"temporal:startWorkflow"
191+
span_name = "temporal:startWorkflow"
193192
if get_trace_provider().get_current_trace() is None:
194193
with trace(
195194
span_name + ":" + input.workflow, metadata=metadata, group_id=input.id
@@ -208,7 +207,7 @@ async def query_workflow(self, input: temporalio.client.QueryWorkflowInput) -> A
208207
**({"temporal:workflowId": input.id} if input.id else {}),
209208
}
210209
data = {"workflowId": input.id, "query": input.query}
211-
span_name = f"temporal:queryWorkflow"
210+
span_name = "temporal:queryWorkflow"
212211
if get_trace_provider().get_current_trace() is None:
213212
with trace(span_name, metadata=metadata, group_id=input.id):
214213
with custom_span(name=span_name, data=data):
@@ -227,7 +226,7 @@ async def signal_workflow(
227226
**({"temporal:workflowId": input.id} if input.id else {}),
228227
}
229228
data = {"workflowId": input.id, "signal": input.signal}
230-
span_name = f"temporal:signalWorkflow"
229+
span_name = "temporal:signalWorkflow"
231230
if get_trace_provider().get_current_trace() is None:
232231
with trace(span_name, metadata=metadata, group_id=input.id):
233232
with custom_span(name=span_name, data=data):

temporalio/nexus/_operation_context.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,14 @@
22

33
import dataclasses
44
import logging
5+
from collections.abc import Awaitable, Mapping, MutableMapping, Sequence
56
from contextvars import ContextVar
67
from dataclasses import dataclass
78
from datetime import timedelta
89
from typing import (
910
Any,
10-
Awaitable,
1111
Callable,
12-
Mapping,
13-
MutableMapping,
1412
Optional,
15-
Sequence,
16-
Type,
1713
Union,
1814
overload,
1915
)
@@ -305,7 +301,7 @@ async def start_workflow(
305301
args: Sequence[Any] = [],
306302
id: str,
307303
task_queue: Optional[str] = None,
308-
result_type: Optional[Type[ReturnType]] = None,
304+
result_type: Optional[type[ReturnType]] = None,
309305
execution_timeout: Optional[timedelta] = None,
310306
run_timeout: Optional[timedelta] = None,
311307
task_timeout: Optional[timedelta] = None,
@@ -340,7 +336,7 @@ async def start_workflow(
340336
args: Sequence[Any] = [],
341337
id: str,
342338
task_queue: Optional[str] = None,
343-
result_type: Optional[Type] = None,
339+
result_type: Optional[type] = None,
344340
execution_timeout: Optional[timedelta] = None,
345341
run_timeout: Optional[timedelta] = None,
346342
task_timeout: Optional[timedelta] = None,

temporalio/runtime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def _on_logs(
227227
# We can't access logging module's start time and it's not worth
228228
# doing difference math to get relative time right here, so
229229
# we'll make time relative to _our_ module's start time
230-
self.relativeCreated = (record.created - _module_start_time) * 1000
230+
self.relativeCreated = (record.created - _module_start_time) * 1000 # type: ignore[reportUninitializedInstanceVariable]
231231
# Log the record
232232
self.logger.handle(record)
233233

temporalio/worker/_worker.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,9 @@
2424

2525
from typing_extensions import TypeAlias, TypedDict
2626

27-
import temporalio.activity
28-
import temporalio.api.common.v1
29-
import temporalio.bridge.client
30-
import temporalio.bridge.proto
31-
import temporalio.bridge.proto.activity_result
32-
import temporalio.bridge.proto.activity_task
33-
import temporalio.bridge.proto.common
3427
import temporalio.bridge.worker
3528
import temporalio.client
36-
import temporalio.converter
37-
import temporalio.exceptions
29+
import temporalio.common
3830
import temporalio.runtime
3931
import temporalio.service
4032
from temporalio.common import (
@@ -578,8 +570,8 @@ def config(self) -> WorkerConfig:
578570
Configuration, shallow-copied.
579571
"""
580572
config = self._config.copy()
581-
config["activities"] = list(config["activities"])
582-
config["workflows"] = list(config["workflows"])
573+
config["activities"] = list(config.get("activities", []))
574+
config["workflows"] = list(config.get("workflows", []))
583575
return config
584576

585577
@property

temporalio/worker/_workflow_instance.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2514,15 +2514,15 @@ def get_debug(self) -> bool:
25142514

25152515

25162516
class _WorkflowInboundImpl(WorkflowInboundInterceptor):
2517-
def __init__(
2517+
def __init__( # type: ignore
25182518
self,
25192519
instance: _WorkflowInstanceImpl,
25202520
) -> None:
25212521
# We are intentionally not calling the base class's __init__ here
25222522
self._instance = instance
25232523

25242524
def init(self, outbound: WorkflowOutboundInterceptor) -> None:
2525-
self._outbound = outbound
2525+
self._outbound = outbound # type: ignore
25262526

25272527
async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any:
25282528
args = [self._instance._object] + list(input.args)
@@ -2572,7 +2572,7 @@ async def handle_update_handler(self, input: HandleUpdateInput) -> Any:
25722572

25732573

25742574
class _WorkflowOutboundImpl(WorkflowOutboundInterceptor):
2575-
def __init__(self, instance: _WorkflowInstanceImpl) -> None:
2575+
def __init__(self, instance: _WorkflowInstanceImpl) -> None: # type: ignore
25762576
# We are intentionally not calling the base class's __init__ here
25772577
self._instance = instance
25782578

0 commit comments

Comments
 (0)