Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import temporalio.converter
import temporalio.exceptions
import temporalio.nexus
import temporalio.nexus._operation_context
import temporalio.runtime
import temporalio.service
import temporalio.workflow
Expand Down Expand Up @@ -5877,6 +5878,12 @@ async def _build_start_workflow_execution_request(
)
# Links are duplicated on request for compatibility with older server versions.
req.links.extend(links)

if temporalio.nexus._operation_context._in_nexus_backing_workflow_start_context():
req.on_conflict_options.attach_request_id = True
req.on_conflict_options.attach_completion_callbacks = True
req.on_conflict_options.attach_links = True

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Nexus Signal-with-Start Missing Conflict Options

The _build_signal_with_start_workflow_execution_request method is missing the on_conflict_options logic for Nexus operations. This logic, which sets options for attaching request IDs, callbacks, and links, is present in _build_start_workflow_execution_request. Without it, Nexus operations using signal-with-start won't properly attach to existing workflows, causing inconsistent behavior.

Fix in Cursor Fix in Web

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a bad review comment, but we don't need to support signal-with-start in this context yet.

return req

async def _build_signal_with_start_workflow_execution_request(
Expand Down Expand Up @@ -5932,6 +5939,7 @@ async def _populate_start_workflow_execution_request(
"temporalio.api.enums.v1.WorkflowIdConflictPolicy.ValueType",
int(input.id_conflict_policy),
)

if input.retry_policy is not None:
input.retry_policy.apply_to_proto(req.retry_policy)
req.cron_schedule = input.cron_schedule
Expand Down
107 changes: 60 additions & 47 deletions temporalio/nexus/_operation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import dataclasses
import logging
from collections.abc import Awaitable, Mapping, MutableMapping, Sequence
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from datetime import timedelta
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generator,
Optional,
Union,
overload,
Expand Down Expand Up @@ -47,6 +49,14 @@
ContextVar("temporal-cancel-operation-context")
)

# A Nexus start handler may start zero or more ordinary workflows using a Temporal client. In
# addition, it may start one "nexus-backing" workflow via
# WorkflowRunOperationContext.start_workflow. This context is active only while that call is in
# progress, so it is narrower than _temporal_start_operation_context.
_temporal_nexus_backing_workflow_start_context: ContextVar[bool] = ContextVar(
"temporal-nexus-backing-workflow-start-context"
)


@dataclass(frozen=True)
class Info:
Expand Down Expand Up @@ -96,6 +106,19 @@ def _try_temporal_context() -> (
return start_ctx or cancel_ctx


@contextmanager
def _nexus_backing_workflow_start_context() -> Generator[None, None, None]:
    """Mark the enclosed code as starting a nexus-backing workflow.

    Sets the ``_temporal_nexus_backing_workflow_start_context`` context variable
    to ``True`` for the duration of the ``with`` body and resets it to its prior
    state on exit, whether the body completes normally or raises.
    """
    flag = _temporal_nexus_backing_workflow_start_context
    reset_token = flag.set(True)
    try:
        yield
    finally:
        # Always undo the set() so the flag cannot leak past the with-block.
        flag.reset(reset_token)


def _in_nexus_backing_workflow_start_context() -> bool:
    """Return whether the nexus-backing workflow start context is active.

    Reads ``_temporal_nexus_backing_workflow_start_context``; yields ``False``
    when the variable has not been set in the current context.
    """
    is_active = _temporal_nexus_backing_workflow_start_context.get(False)
    return is_active


@dataclass
class _TemporalStartOperationContext:
"""Context for a Nexus start operation being handled by a Temporal Nexus Worker."""
Expand Down Expand Up @@ -396,56 +419,46 @@ async def start_workflow(
Nexus caller is itself a workflow, this means that the workflow in the caller
namespace web UI will contain links to the started workflow, and vice versa.
"""
# TODO(nexus-preview): When sdk-python supports on_conflict_options, Typescript does this:
# if (workflowOptions.workflowIdConflictPolicy === 'USE_EXISTING') {
# internalOptions.onConflictOptions = {
# attachLinks: true,
# attachCompletionCallbacks: true,
# attachRequestId: true,
# };
# }
if (
id_conflict_policy
== temporalio.common.WorkflowIDConflictPolicy.USE_EXISTING
):
raise RuntimeError(
"WorkflowIDConflictPolicy.USE_EXISTING is not yet supported when starting a workflow "
"that backs a Nexus operation (Python SDK Nexus support is at Pre-release stage)."
)

# We must pass nexus_completion_callbacks, workflow_event_links, and request_id,
# but these are deliberately not exposed in overloads, hence the type-check
# violation.
wf_handle = await self._temporal_context.client.start_workflow( # type: ignore
workflow=workflow,
arg=arg,
args=args,
id=id,
task_queue=task_queue or self._temporal_context.info().task_queue,
result_type=result_type,
execution_timeout=execution_timeout,
run_timeout=run_timeout,
task_timeout=task_timeout,
id_reuse_policy=id_reuse_policy,
id_conflict_policy=id_conflict_policy,
retry_policy=retry_policy,
cron_schedule=cron_schedule,
memo=memo,
search_attributes=search_attributes,
static_summary=static_summary,
static_details=static_details,
start_delay=start_delay,
start_signal=start_signal,
start_signal_args=start_signal_args,
rpc_metadata=rpc_metadata,
rpc_timeout=rpc_timeout,
request_eager_start=request_eager_start,
priority=priority,
versioning_override=versioning_override,
callbacks=self._temporal_context._get_callbacks(),
workflow_event_links=self._temporal_context._get_workflow_event_links(),
request_id=self._temporal_context.nexus_context.request_id,
)

# Here we are starting a "nexus-backing" workflow. That means that the StartWorkflow request
# contains nexus-specific data such as a completion callback (used by the handler server
# namespace to deliver the result to the caller namespace when the workflow reaches a
# terminal state) and inbound links to the caller workflow (attached to history events of
# the workflow started in the handler namespace, and displayed in the UI).
with _nexus_backing_workflow_start_context():
wf_handle = await self._temporal_context.client.start_workflow( # type: ignore
workflow=workflow,
arg=arg,
args=args,
id=id,
task_queue=task_queue or self._temporal_context.info().task_queue,
result_type=result_type,
execution_timeout=execution_timeout,
run_timeout=run_timeout,
task_timeout=task_timeout,
id_reuse_policy=id_reuse_policy,
id_conflict_policy=id_conflict_policy,
retry_policy=retry_policy,
cron_schedule=cron_schedule,
memo=memo,
search_attributes=search_attributes,
static_summary=static_summary,
static_details=static_details,
start_delay=start_delay,
start_signal=start_signal,
start_signal_args=start_signal_args,
rpc_metadata=rpc_metadata,
rpc_timeout=rpc_timeout,
request_eager_start=request_eager_start,
priority=priority,
versioning_override=versioning_override,
callbacks=self._temporal_context._get_callbacks(),
workflow_event_links=self._temporal_context._get_workflow_event_links(),
request_id=self._temporal_context.nexus_context.request_id,
)

self._temporal_context._add_outbound_links(wf_handle)

Expand Down
7 changes: 0 additions & 7 deletions temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3029,13 +3029,6 @@ def operation_token(self) -> Optional[str]:
def __await__(self) -> Generator[Any, Any, OutputT]:
    """Make this handle awaitable by delegating to the wrapped task's ``__await__``."""
    return self._task.__await__()

def __repr__(self) -> str:
    """Return a debugging string built from the start future, result future and task.

    The task portion reads the private attributes ``_state``, ``_fut_waiter`` and
    ``_must_cancel`` of ``self._task`` (hence the ``type: ignore``).
    """
    return (
        f"{self._start_fut} "
        f"{self._result_fut} "
        f"Task[{self._task._state}] fut_waiter = {self._task._fut_waiter}) ({self._task._must_cancel})" # type: ignore
    )

def cancel(self) -> bool:
    """Request cancellation of the wrapped task and return that call's result."""
    return self._task.cancel()

Expand Down
124 changes: 124 additions & 0 deletions tests/nexus/test_use_existing_conflict_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from __future__ import annotations

import asyncio
import uuid
from dataclasses import dataclass
from typing import Optional

import pytest
from nexusrpc.handler import service_handler

from temporalio import nexus, workflow
from temporalio.client import Client
from temporalio.common import WorkflowIDConflictPolicy
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import Worker
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name


@dataclass
class OpInput:
    """Input for ``NexusService.workflow_backed_operation``."""

    # Passed as ``id`` when the operation handler starts the backing workflow.
    workflow_id: str
    # Passed as ``id_conflict_policy`` when the handler starts the backing workflow.
    conflict_policy: WorkflowIDConflictPolicy


@workflow.defn
class HandlerWorkflow:
    """Workflow that idles until the ``complete`` signal hands it a result."""

    def __init__(self) -> None:
        # Stays None until the `complete` signal stores a value.
        self.result: Optional[str] = None

    @workflow.run
    async def run(self) -> str:
        """Block until a result has been stored, then return it."""

        def result_available() -> bool:
            return self.result is not None

        await workflow.wait_condition(result_available)
        assert self.result
        return self.result

    @workflow.signal
    def complete(self, result: str) -> None:
        """Store the value that ``run`` is waiting for."""
        self.result = result


@service_handler
class NexusService:
    """Nexus service exposing a single workflow-backed operation."""

    @nexus.workflow_run_operation
    async def workflow_backed_operation(
        self, ctx: nexus.WorkflowRunOperationContext, input: OpInput
    ) -> nexus.WorkflowHandle[str]:
        """Start ``HandlerWorkflow`` using the ID and conflict policy from ``input``."""
        backing_workflow = await ctx.start_workflow(
            HandlerWorkflow.run,
            id=input.workflow_id,
            id_conflict_policy=input.conflict_policy,
        )
        return backing_workflow


@dataclass
class CallerWorkflowInput:
    """Input for ``CallerWorkflow.run``."""

    # Workflow ID forwarded to every Nexus operation via ``OpInput.workflow_id``.
    workflow_id: str
    # Task queue from which the Nexus endpoint name is derived.
    task_queue: str
    # Number of Nexus operations the caller workflow starts.
    num_operations: int


@workflow.defn
class CallerWorkflow:
    """Workflow that starts several Nexus operations sharing one handler workflow ID."""

    def __init__(self) -> None:
        # Set once every operation has been started; awaited by the update handler.
        self._nexus_operations_have_started = asyncio.Event()

    @workflow.run
    async def run(self, input: CallerWorkflowInput) -> list[str]:
        """Start ``input.num_operations`` operations and return all of their results."""
        endpoint_name = make_nexus_endpoint_name(input.task_queue)
        nexus_client = workflow.create_nexus_client(
            service=NexusService, endpoint=endpoint_name
        )

        # Every operation receives the same workflow ID with USE_EXISTING.
        shared_input = OpInput(
            workflow_id=input.workflow_id,
            conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING,
        )

        # Operations are started one after another; each start is awaited in turn.
        operation_handles = [
            await nexus_client.start_operation(
                NexusService.workflow_backed_operation, shared_input
            )
            for _ in range(input.num_operations)
        ]
        self._nexus_operations_have_started.set()
        return await asyncio.gather(*operation_handles)

    @workflow.update
    async def nexus_operations_have_started(self) -> None:
        """Complete once ``run`` has started all of its Nexus operations."""
        await self._nexus_operations_have_started.wait()


async def test_multiple_operation_invocations_can_connect_to_same_handler_workflow(
    client: Client, env: WorkflowEnvironment
):
    """Five operations using USE_EXISTING on one workflow ID all receive its result."""
    if env.supports_time_skipping:
        pytest.skip("Nexus tests don't work with time-skipping server")

    task_queue = str(uuid.uuid4())
    workflow_id = str(uuid.uuid4())
    operation_count = 5
    expected_result = "test-result"

    worker = Worker(
        client,
        nexus_service_handlers=[NexusService()],
        workflows=[CallerWorkflow, HandlerWorkflow],
        task_queue=task_queue,
    )
    async with worker:
        await create_nexus_endpoint(task_queue, client)

        caller_input = CallerWorkflowInput(
            workflow_id=workflow_id,
            task_queue=task_queue,
            num_operations=operation_count,
        )
        caller_handle = await client.start_workflow(
            CallerWorkflow.run,
            args=[caller_input],
            id=str(uuid.uuid4()),
            task_queue=task_queue,
        )

        # Wait until every operation has started before completing the handler workflow.
        await caller_handle.execute_update(CallerWorkflow.nexus_operations_have_started)

        handler_handle = client.get_workflow_handle(workflow_id)
        await handler_handle.signal(HandlerWorkflow.complete, expected_result)

        assert await caller_handle.result() == [expected_result] * operation_count
Loading