diff --git a/temporalio/client.py b/temporalio/client.py index 6c26d41ef..f1a16d39e 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -1208,6 +1208,7 @@ def on_start_error( start_workflow_input=start_workflow_operation._start_workflow_input, update_workflow_input=update_input, _on_start=on_start, + headers={}, _on_start_error=on_start_error, ) @@ -5538,6 +5539,7 @@ class StartWorkflowUpdateWithStartInput: [temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse], None ] _on_start_error: Callable[[BaseException], None] + headers: Mapping[str, temporalio.api.common.v1.Payload] @dataclass @@ -6361,6 +6363,10 @@ def on_start( err: Optional[BaseException] = None + # fan headers out to both operations + input.start_workflow_input.headers = input.headers + input.update_workflow_input.headers = input.headers + try: return await self._start_workflow_update_with_start( input.start_workflow_input, input.update_workflow_input, on_start diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index 04d40d544..34a81a6fa 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -292,6 +292,23 @@ async def start_workflow_update( ): return await super().start_workflow_update(input) + async def start_update_with_start_workflow( + self, input: temporalio.client.StartWorkflowUpdateWithStartInput + ) -> temporalio.client.WorkflowUpdateHandle[Any]: + attrs = { + "temporalWorkflowID": input.start_workflow_input.id, + } + if input.update_workflow_input.update_id is not None: + attrs["temporalUpdateID"] = input.update_workflow_input.update_id + + with self.root._start_as_current_span( + f"StartUpdateWithStartWorkflow:{input.start_workflow_input.workflow}", + attributes=attrs, + input=input, + kind=opentelemetry.trace.SpanKind.CLIENT, + ): + return await super().start_update_with_start_workflow(input) + class _TracingActivityInboundInterceptor(temporalio.worker.ActivityInboundInterceptor): def __init__( diff --git a/tests/contrib/test_opentelemetry.py b/tests/contrib/test_opentelemetry.py index 0b797f606..6ae3686cc 100644 --- a/tests/contrib/test_opentelemetry.py +++ b/tests/contrib/test_opentelemetry.py @@ -14,8 +14,8 @@ from opentelemetry.trace import StatusCode, get_tracer from temporalio import activity, workflow -from temporalio.client import Client -from temporalio.common import RetryPolicy +from temporalio.client import Client, WithStartWorkflowOperation, WorkflowUpdateStage +from temporalio.common import RetryPolicy, WorkflowIDConflictPolicy from temporalio.contrib.opentelemetry import TracingInterceptor from temporalio.contrib.opentelemetry import workflow as otel_workflow from temporalio.exceptions import ApplicationError, ApplicationErrorCategory @@ -55,6 +55,7 @@ class TracingWorkflowAction: continue_as_new: Optional[TracingWorkflowActionContinueAsNew] = None wait_until_signal_count: int = 0 wait_and_do_update: bool = False + wait_and_do_start_with_update: bool = False @dataclass @@ -79,6 +80,7 @@ class TracingWorkflowActionContinueAsNew: ready_for_update: asyncio.Semaphore +ready_for_update_with_start: asyncio.Semaphore @workflow.defn @@ -86,6 +88,7 @@ class TracingWorkflow: def __init__(self) -> None: self._signal_count = 0 self._did_update = False + self._did_update_with_start = False @workflow.run async def run(self, param: TracingWorkflowParam) -> None: @@ -140,6 +143,9 @@ async def run(self, param: TracingWorkflowParam) -> None: if action.wait_and_do_update: ready_for_update.release() await workflow.wait_condition(lambda: self._did_update) + if action.wait_and_do_start_with_update: + ready_for_update_with_start.release() + await workflow.wait_condition(lambda: self._did_update_with_start) async def _raise_on_non_replay(self) -> None: replaying = workflow.unsafe.is_replaying() @@ -161,6 +167,10 @@ def signal(self) -> None: def update(self) -> None: self._did_update = True + @workflow.update + def update_with_start(self) -> None: + self._did_update_with_start = True + @update.validator def update_validator(self) -> None: pass @@ -301,6 +311,99 @@ async def test_opentelemetry_tracing(client: Client, env: WorkflowEnvironment): ] +async def test_opentelemetry_tracing_update_with_start( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip( + "Java test server: https://github.com/temporalio/sdk-java/issues/1424" + ) + global ready_for_update_with_start + ready_for_update_with_start = asyncio.Semaphore(0) + # Create a tracer that has an in-memory exporter + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = get_tracer(__name__, tracer_provider=provider) + # Create new client with tracer interceptor + client_config = client.config() + client_config["interceptors"] = [TracingInterceptor(tracer)] + client = Client(**client_config) + + task_queue = f"task_queue_{uuid.uuid4()}" + async with Worker( + client, + task_queue=task_queue, + workflows=[TracingWorkflow], + activities=[tracing_activity], + # Needed so we can wait to send update at the right time + workflow_runner=UnsandboxedWorkflowRunner(), + ): + # Run workflow with various actions + workflow_id = f"workflow_{uuid.uuid4()}" + workflow_params = TracingWorkflowParam( + actions=[ + # Wait for update + TracingWorkflowAction(wait_and_do_start_with_update=True), + ] + ) + handle = await client.start_workflow( + TracingWorkflow.run, + workflow_params, + id=workflow_id, + task_queue=task_queue, + ) + async with ready_for_update_with_start: + start_op = WithStartWorkflowOperation( + TracingWorkflow.run, + workflow_params, + id=handle.id, + task_queue=task_queue, + id_conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING, + ) + await client.start_update_with_start_workflow( + TracingWorkflow.update_with_start, + start_workflow_operation=start_op, + id=handle.id, + wait_for_stage=WorkflowUpdateStage.ACCEPTED, + ) + await handle.result() + + # issue update with start again to trigger a new workflow + workflow_id = f"workflow_{uuid.uuid4()}" + start_op = WithStartWorkflowOperation( + TracingWorkflow.run, + TracingWorkflowParam(actions=[]), + id=workflow_id, + task_queue=task_queue, + id_conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING, + ) + await client.execute_update_with_start_workflow( + update=TracingWorkflow.update_with_start, + start_workflow_operation=start_op, + id=workflow_id, + ) + + # Dump debug with attributes, but do string assertion test without + logging.debug( + "Spans:\n%s", + "\n".join(dump_spans(exporter.get_finished_spans(), with_attributes=False)), + ) + assert dump_spans(exporter.get_finished_spans(), with_attributes=False) == [ + "StartWorkflow:TracingWorkflow", + " RunWorkflow:TracingWorkflow", + " MyCustomSpan", + " HandleUpdate:update_with_start (links: StartUpdateWithStartWorkflow:TracingWorkflow)", + " CompleteWorkflow:TracingWorkflow", + "StartUpdateWithStartWorkflow:TracingWorkflow", + "StartUpdateWithStartWorkflow:TracingWorkflow", + " HandleUpdate:update_with_start (links: StartUpdateWithStartWorkflow:TracingWorkflow)", + " RunWorkflow:TracingWorkflow", + " MyCustomSpan", + " CompleteWorkflow:TracingWorkflow", + ] + + def dump_spans( spans: Iterable[ReadableSpan], *,