Skip to content

Commit dc23266

Browse files
committed
Finally! Type system defeated!
1 parent 668532d commit dc23266

File tree

5 files changed

+179
-65
lines changed

5 files changed

+179
-65
lines changed

temporalio/client.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1615,6 +1615,62 @@ async def terminate(
16151615
)
16161616
)
16171617

1618+
# Overload for no-param update
1619+
@overload
1620+
async def update(
1621+
self,
1622+
update: temporalio.workflow.UpdateMethodMultiArg[[SelfType], LocalReturnType],
1623+
*,
1624+
id: Optional[str] = None,
1625+
rpc_metadata: Mapping[str, str] = {},
1626+
rpc_timeout: Optional[timedelta] = None,
1627+
) -> LocalReturnType:
1628+
...
1629+
1630+
# Overload for single-param update
1631+
@overload
1632+
async def update(
1633+
self,
1634+
update: temporalio.workflow.UpdateMethodMultiArg[
1635+
[SelfType, ParamType], LocalReturnType
1636+
],
1637+
arg: ParamType,
1638+
*,
1639+
id: Optional[str] = None,
1640+
rpc_metadata: Mapping[str, str] = {},
1641+
rpc_timeout: Optional[timedelta] = None,
1642+
) -> LocalReturnType:
1643+
...
1644+
1645+
@overload
1646+
async def update(
1647+
self,
1648+
update: temporalio.workflow.UpdateMethodMultiArg[
1649+
MultiParamSpec, LocalReturnType
1650+
],
1651+
*,
1652+
args: MultiParamSpec.args,
1653+
id: Optional[str] = None,
1654+
rpc_metadata: Mapping[str, str] = {},
1655+
rpc_timeout: Optional[timedelta] = None,
1656+
) -> LocalReturnType:
1657+
...
1658+
1659+
# Overload for string-name update
1660+
@overload
1661+
async def update(
1662+
self,
1663+
update: str,
1664+
arg: Any = temporalio.common._arg_unset,
1665+
*,
1666+
args: Sequence[Any] = [],
1667+
id: Optional[str] = None,
1668+
result_type: Optional[Type] = None,
1669+
rpc_metadata: Mapping[str, str] = {},
1670+
rpc_timeout: Optional[timedelta] = None,
1671+
) -> Any:
1672+
...
1673+
16181674
async def update(
16191675
self,
16201676
update: Union[str, Callable],
@@ -1701,15 +1757,16 @@ async def start_update(
17011757
"""
17021758
update_name: str
17031759
ret_type = result_type
1704-
if callable(update):
1705-
if not isinstance(update, temporalio.workflow.update):
1760+
if isinstance(update, temporalio.workflow.UpdateMethodMultiArg):
1761+
defn = update._defn
1762+
if not defn:
17061763
raise RuntimeError(
17071764
f"Update definition not found on {update.__qualname__}, "
17081765
"is it decorated with @workflow.update?"
17091766
)
1710-
defn = update._defn
1711-
if not defn.name:
1767+
elif not defn.name:
17121768
raise RuntimeError("Cannot invoke dynamic update definition")
1769+
# TODO(cretz): Check count/type of args at runtime?
17131770
update_name = defn.name
17141771
ret_type = defn.ret_type
17151772
else:

temporalio/worker/_workflow_instance.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ async def run_update() -> None:
439439
job.input,
440440
defn.name,
441441
defn.arg_types,
442-
False,
442+
defn.dynamic_vararg,
443443
)
444444
handler_input = HandleUpdateInput(
445445
# TODO: update id vs proto instance id
@@ -1013,6 +1013,10 @@ def workflow_set_update_handler(
10131013
if validator is not None:
10141014
defn.set_validator(validator)
10151015
self._updates[name] = defn
1016+
if defn.dynamic_vararg:
1017+
raise RuntimeError(
1018+
"Dynamic updates do not support a vararg third param, use Sequence[RawValue]",
1019+
)
10161020
else:
10171021
self._updates.pop(name, None)
10181022

temporalio/workflow.py

Lines changed: 103 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,13 @@
3636
overload,
3737
)
3838

39-
from typing_extensions import Concatenate, Literal, TypedDict
39+
from typing_extensions import (
40+
Concatenate,
41+
Literal,
42+
Protocol,
43+
TypedDict,
44+
runtime_checkable,
45+
)
4046

4147
import temporalio.api.common.v1
4248
import temporalio.bridge.proto.child_workflow
@@ -64,6 +70,7 @@
6470
MethodSyncSingleParam,
6571
MultiParamSpec,
6672
ParamType,
73+
ProtocolReturnType,
6774
ReturnType,
6875
SelfType,
6976
)
@@ -764,8 +771,64 @@ def time_ns() -> int:
764771
return _Runtime.current().workflow_time_ns()
765772

766773

767-
# noinspection PyPep8Naming
768-
class update(object):
774+
# Needs to be defined here to avoid a circular import
775+
@runtime_checkable
776+
class UpdateMethodMultiArg(Protocol[MultiParamSpec, ProtocolReturnType]):
777+
"""Decorated workflow update functions implement this."""
778+
779+
_defn: temporalio.workflow._UpdateDefinition
780+
781+
def __call__(
782+
self, *args: MultiParamSpec.args, **kwargs: MultiParamSpec.kwargs
783+
) -> Union[ProtocolReturnType, Awaitable[ProtocolReturnType]]:
784+
"""Generic callable type callback."""
785+
...
786+
787+
def validator(self, vfunc: Callable[MultiParamSpec, None]) -> None:
788+
"""Use to decorate a function to validate the arguments passed to the update handler."""
789+
...
790+
791+
792+
@overload
793+
def update(
794+
fn: Callable[MultiParamSpec, Awaitable[ReturnType]]
795+
) -> UpdateMethodMultiArg[MultiParamSpec, ReturnType]:
796+
...
797+
798+
799+
@overload
800+
def update(
801+
fn: Callable[MultiParamSpec, ReturnType]
802+
) -> UpdateMethodMultiArg[MultiParamSpec, ReturnType]:
803+
...
804+
805+
806+
@overload
807+
def update(
808+
*, name: str
809+
) -> Callable[
810+
[Callable[MultiParamSpec, ReturnType]],
811+
UpdateMethodMultiArg[MultiParamSpec, ReturnType],
812+
]:
813+
...
814+
815+
816+
@overload
817+
def update(
818+
*, dynamic: Literal[True]
819+
) -> Callable[
820+
[Callable[MultiParamSpec, ReturnType]],
821+
UpdateMethodMultiArg[MultiParamSpec, ReturnType],
822+
]:
823+
...
824+
825+
826+
def update(
827+
fn: Optional[CallableSyncOrAsyncType] = None,
828+
*,
829+
name: Optional[str] = None,
830+
dynamic: Optional[bool] = False,
831+
):
769832
"""Decorator for a workflow update handler method.
770833
771834
This is set on any async or non-async method that you wish to be called upon
@@ -791,44 +854,33 @@ class update(object):
791854
present.
792855
"""
793856

794-
def __init__(
795-
self,
796-
fn: Optional[CallableSyncOrAsyncType] = None,
797-
*,
798-
name: Optional[str] = None,
799-
dynamic: Optional[bool] = False,
800-
):
801-
"""See :py:class:`update`."""
802-
if name is not None or dynamic:
803-
if name is not None and dynamic:
804-
raise RuntimeError("Cannot provide name and dynamic boolean")
805-
self._fn = fn
806-
self._name = (
807-
name if name is not None else self._fn.__name__ if self._fn else None
808-
)
809-
self._dynamic = dynamic
810-
if self._fn is not None:
811-
# Only bother to assign the definition if we are given a function. The function is not provided when
812-
# extra arguments are specified - in that case, the __call__ method is invoked instead.
813-
self._assign_defn()
814-
815-
def __call__(self, fn: CallableSyncOrAsyncType):
816-
"""Call the update decorator (as when passing optional arguments)."""
817-
self._fn = fn
818-
self._assign_defn()
819-
return self
820-
821-
def _assign_defn(self) -> None:
822-
assert self._fn is not None
823-
self._defn = _UpdateDefinition(name=self._name, fn=self._fn, is_method=True)
824-
825-
def validator(self, fn: Callable[..., None]):
826-
"""Decorator for a workflow update validator method. Apply this decorator to a function to have it run before
827-
the update handler. If it throws an error, the update will be rejected. The validator must not mutate workflow
828-
state at all, and cannot call workflow functions which would schedule new commands (ex: starting an
829-
activity).
830-
"""
831-
self._defn.set_validator(fn)
857+
def with_name(
858+
name: Optional[str], fn: CallableSyncOrAsyncType
859+
) -> CallableSyncOrAsyncType:
860+
defn = _UpdateDefinition(name=name, fn=fn, is_method=True)
861+
if defn.dynamic_vararg:
862+
raise RuntimeError(
863+
"Dynamic updates do not support a vararg third param, use Sequence[RawValue]",
864+
)
865+
setattr(fn, "_defn", defn)
866+
setattr(fn, "validator", partial(_update_validator, defn))
867+
return fn
868+
869+
if name is not None or dynamic:
870+
if name is not None and dynamic:
871+
raise RuntimeError("Cannot provide name and dynamic boolean")
872+
return partial(with_name, name)
873+
if fn is None:
874+
raise RuntimeError("Cannot create update without function or name or dynamic")
875+
return with_name(fn.__name__, fn)
876+
877+
878+
def _update_validator(
879+
update_def: _UpdateDefinition, fn: Optional[Callable[..., None]] = None
880+
):
881+
"""Decorator for a workflow update validator method."""
882+
if fn is not None:
883+
update_def.set_validator(fn)
832884

833885

834886
def upsert_search_attributes(attributes: temporalio.common.SearchAttributes) -> None:
@@ -1132,7 +1184,7 @@ def _apply_to_class(
11321184
)
11331185
else:
11341186
queries[query_defn.name] = query_defn
1135-
elif isinstance(member, update):
1187+
elif isinstance(member, UpdateMethodMultiArg):
11361188
update_defn = member._defn
11371189
if update_defn.name in updates:
11381190
defn_name = update_defn.name or "<dynamic>"
@@ -1350,16 +1402,19 @@ class _UpdateDefinition:
13501402
arg_types: Optional[List[Type]] = None
13511403
ret_type: Optional[Type] = None
13521404
validator: Optional[Callable[..., None]] = None
1405+
dynamic_vararg: bool = False
13531406

13541407
def __post_init__(self) -> None:
13551408
if self.arg_types is None:
13561409
arg_types, ret_type = temporalio.common._type_hints_from_func(self.fn)
1357-
# Disallow dynamic varargs
1358-
if not self.name and not _assert_dynamic_handler_args(
1359-
self.fn, arg_types, self.is_method
1360-
):
1361-
raise RuntimeError(
1362-
"Dynamic updates do not support a vararg third param, use Sequence[RawValue]",
1410+
# If dynamic, assert it
1411+
if not self.name:
1412+
object.__setattr__(
1413+
self,
1414+
"dynamic_vararg",
1415+
not _assert_dynamic_handler_args(
1416+
self.fn, arg_types, self.is_method
1417+
),
13631418
)
13641419
object.__setattr__(self, "arg_types", arg_types)
13651420
object.__setattr__(self, "ret_type", ret_type)

tests/test_workflow.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,19 +103,18 @@ def test_workflow_defn_good():
103103
name="base_query", fn=GoodDefnBase.base_query, is_method=True
104104
),
105105
},
106-
# Since updates use class-based decorators we need to pass the inner _fn for the fn param
107106
updates={
108107
"update1": workflow._UpdateDefinition(
109-
name="update1", fn=GoodDefn.update1._fn, is_method=True
108+
name="update1", fn=GoodDefn.update1, is_method=True
110109
),
111110
"update-custom": workflow._UpdateDefinition(
112-
name="update-custom", fn=GoodDefn.update2._fn, is_method=True
111+
name="update-custom", fn=GoodDefn.update2, is_method=True
113112
),
114113
None: workflow._UpdateDefinition(
115-
name=None, fn=GoodDefn.update3._fn, is_method=True
114+
name=None, fn=GoodDefn.update3, is_method=True
116115
),
117116
"base_update": workflow._UpdateDefinition(
118-
name="base_update", fn=GoodDefnBase.base_update._fn, is_method=True
117+
name="base_update", fn=GoodDefnBase.base_update, is_method=True
119118
),
120119
},
121120
sandboxed=True,

tests/worker/test_workflow.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3553,6 +3553,10 @@ async def runs_activity(self, name: str) -> str:
35533553
await act
35543554
return "done"
35553555

3556+
@workflow.update(name="renamed")
3557+
async def async_named(self) -> str:
3558+
return "named"
3559+
35563560
@workflow.update
35573561
async def bad_validator(self) -> str:
35583562
return "done"
@@ -3579,10 +3583,6 @@ def dynavalidator(name: str, _args: Sequence[RawValue]) -> None:
35793583
workflow.set_dynamic_update_handler(dynahandler, validator=dynavalidator)
35803584
return "set"
35813585

3582-
@workflow.update(name="name_override")
3583-
async def not_the_name(self) -> str:
3584-
return "name_overridden"
3585-
35863586

35873587
async def test_workflow_update_handlers_happy(client: Client):
35883588
async with new_worker(
@@ -3611,9 +3611,8 @@ async def test_workflow_update_handlers_happy(client: Client):
36113611
await handle.update(UpdateHandlersWorkflow.set_dynamic)
36123612
assert "dynahandler - made_up" == await handle.update("made_up")
36133613

3614-
assert "name_overridden" == await handle.update(
3615-
UpdateHandlersWorkflow.not_the_name
3616-
)
3614+
# Name overload
3615+
assert "named" == await handle.update(UpdateHandlersWorkflow.async_named)
36173616

36183617

36193618
async def test_workflow_update_handlers_unhappy(client: Client):

0 commit comments

Comments
 (0)