Skip to content

Commit d8b5dff

Browse files
committed
Require identical parameters
1 parent 1f81d84 commit d8b5dff

File tree

2 files changed

+24
-73
lines changed

2 files changed

+24
-73
lines changed

temporalio/workflow.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,6 +1335,7 @@ def _apply_to_class(
13351335
issues: List[str] = []
13361336

13371337
# Collect run fn and all signal/query/update fns
1338+
init_fn: Optional[Callable[..., None]] = None
13381339
run_fn: Optional[Callable[..., Awaitable[Any]]] = None
13391340
seen_run_attr = False
13401341
signals: Dict[Optional[str], _SignalDefinition] = {}
@@ -1381,6 +1382,8 @@ def _apply_to_class(
13811382
)
13821383
else:
13831384
queries[query_defn.name] = query_defn
1385+
elif name == "__init__" and hasattr(member, "__temporal_workflow_init"):
1386+
init_fn = member
13841387
elif isinstance(member, UpdateMethodMultiParam):
13851388
update_defn = member._defn
13861389
if update_defn.name in updates:
@@ -1433,6 +1436,11 @@ def _apply_to_class(
14331436

14341437
if not seen_run_attr:
14351438
issues.append("Missing @workflow.run method")
1439+
if init_fn and run_fn:
1440+
if not _parameters_identical_up_to_naming(init_fn, run_fn):
1441+
issues.append(
1442+
"@workflow.init and @workflow.run method parameters do not match"
1443+
)
14361444
if issues:
14371445
if len(issues) == 1:
14381446
raise ValueError(f"Invalid workflow class: {issues[0]}")
@@ -1471,6 +1479,19 @@ def __post_init__(self) -> None:
14711479
object.__setattr__(self, "ret_type", ret_type)
14721480

14731481

1482+
def _parameters_identical_up_to_naming(fn1: Callable, fn2: Callable) -> bool:
1483+
"""Return True if the functions have identical parameter lists, ignoring parameter names."""
1484+
1485+
def params(fn: Callable) -> List[inspect.Parameter]:
1486+
# Ignore name when comparing parameters (remaining fields are kind,
1487+
# default, and annotation).
1488+
return [p.replace(name="x") for p in inspect.signature(fn).parameters.values()]
1489+
1490+
# We require that any type annotations present match exactly; i.e. we do
1491+
# not support any notion of subtype compatibility.
1492+
return params(fn1) == params(fn2)
1493+
1494+
14741495
# Async safe version of partial
14751496
def _bind_method(obj: Any, fn: Callable[..., Any]) -> Callable[..., Any]:
14761497
# Curry instance on the definition function since that represents an

tests/worker/test_workflow.py

Lines changed: 3 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -3902,7 +3902,7 @@ def matches_metric_line(
39023902
return False
39033903
# Must have labels (don't escape for this test)
39043904
for k, v in at_least_labels.items():
3905-
if not f'{k}="{v}"' in line:
3905+
if f'{k}="{v}"' not in line:
39063906
return False
39073907
return line.endswith(f" {value}")
39083908

@@ -4856,7 +4856,7 @@ async def assert_scenario(
48564856
update_scenario: Optional[FailureTypesScenario] = None,
48574857
) -> None:
48584858
logging.debug(
4859-
f"Asserting scenario %s",
4859+
"Asserting scenario %s",
48604860
{
48614861
"workflow": workflow,
48624862
"expect_task_fail": expect_task_fail,
@@ -6054,7 +6054,7 @@ class WorkflowWithWorkflowInit:
60546054
_expected_update_result = "workflow input value"
60556055

60566056
@workflow.init
6057-
def __init__(self, arg: str = "from parameter default") -> None:
6057+
def __init__(self, arg: str) -> None:
60586058
self.value = arg
60596059

60606060
@workflow.update
@@ -6084,82 +6084,12 @@ async def run(self, _: str) -> str:
60846084
return self.value
60856085

60866086

6087-
@workflow.defn(name="MyWorkflow")
6088-
class WorkflowWithWorkflowInitBaseDecorated:
6089-
use_workflow_init = True
6090-
6091-
@workflow.init
6092-
def __init__(
6093-
self, required_param_that_will_be_supplied_by_child_init_method
6094-
) -> None:
6095-
self.value = required_param_that_will_be_supplied_by_child_init_method
6096-
6097-
if use_workflow_init:
6098-
__init__ = workflow.init(__init__)
6099-
6100-
@workflow.run
6101-
async def run(self, _: str): ...
6102-
6103-
@workflow.update
6104-
async def my_update(self) -> str:
6105-
return self.value
6106-
6107-
6108-
class WorkflowWithWorkflowInitBaseUndecorated(WorkflowWithWorkflowInitBaseDecorated):
6109-
# The base class does not need the @workflow.init decorator
6110-
use_workflow_init = False
6111-
6112-
6113-
@workflow.defn(name="MyWorkflow")
6114-
class WorkflowWithWorkflowInitChild(WorkflowWithWorkflowInitBaseDecorated):
6115-
use_workflow_init = True
6116-
_expected_update_result = "workflow input value"
6117-
6118-
def __init__(self, arg: str = "from parameter default") -> None:
6119-
super().__init__("from child __init__")
6120-
self.value = arg
6121-
6122-
if use_workflow_init:
6123-
__init__ = workflow.init(__init__)
6124-
6125-
@workflow.run
6126-
async def run(self, _: str) -> str:
6127-
self.value = "set in run method"
6128-
return self.value
6129-
6130-
6131-
@workflow.defn(name="MyWorkflow")
6132-
class WorkflowWithWorkflowInitChildNoWorkflowInit(
6133-
WorkflowWithWorkflowInitBaseDecorated
6134-
):
6135-
use_workflow_init = False
6136-
_expected_update_result = "from parameter default"
6137-
6138-
def __init__(self, arg: str = "from parameter default") -> None:
6139-
super().__init__("from child __init__")
6140-
self.value = arg
6141-
6142-
if use_workflow_init:
6143-
__init__ = workflow.init(__init__)
6144-
6145-
@workflow.run
6146-
async def run(self, _: str) -> str:
6147-
self.value = "set in run method"
6148-
return self.value
6149-
6150-
61516087
@pytest.mark.parametrize(
61526088
["client_cls", "worker_cls"],
61536089
[
61546090
(WorkflowWithoutInit, WorkflowWithoutInit),
61556091
(WorkflowWithNonWorkflowInitInit, WorkflowWithNonWorkflowInitInit),
61566092
(WorkflowWithWorkflowInit, WorkflowWithWorkflowInit),
6157-
(WorkflowWithWorkflowInitBaseDecorated, WorkflowWithWorkflowInitChild),
6158-
(WorkflowWithWorkflowInitBaseUndecorated, WorkflowWithWorkflowInitChild),
6159-
(
6160-
WorkflowWithWorkflowInitBaseUndecorated,
6161-
WorkflowWithWorkflowInitChildNoWorkflowInit,
6162-
),
61636093
],
61646094
)
61656095
async def test_update_in_first_wft_sees_workflow_init(

0 commit comments

Comments
 (0)