Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,7 @@ def on_start_error(
start_workflow_input=start_workflow_operation._start_workflow_input,
update_workflow_input=update_input,
_on_start=on_start,
headers={},
_on_start_error=on_start_error,
)

Expand Down Expand Up @@ -5538,6 +5539,7 @@ class StartWorkflowUpdateWithStartInput:
[temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse], None
]
_on_start_error: Callable[[BaseException], None]
headers: Mapping[str, temporalio.api.common.v1.Payload]


@dataclass
Expand Down Expand Up @@ -6361,6 +6363,10 @@ def on_start(

err: Optional[BaseException] = None

# fan headers out to both operations
input.start_workflow_input.headers = input.headers
input.update_workflow_input.headers = input.headers

try:
return await self._start_workflow_update_with_start(
input.start_workflow_input, input.update_workflow_input, on_start
Expand Down
17 changes: 17 additions & 0 deletions temporalio/contrib/opentelemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,23 @@ async def start_workflow_update(
):
return await super().start_workflow_update(input)

async def start_update_with_start_workflow(
self, input: temporalio.client.StartWorkflowUpdateWithStartInput
) -> temporalio.client.WorkflowUpdateHandle[Any]:
attrs = {
"temporalWorkflowID": input.start_workflow_input.id,
}
if input.update_workflow_input.update_id is not None:
attrs["temporalUpdateID"] = input.update_workflow_input.update_id

with self.root._start_as_current_span(
f"StartUpdateWithStartWorkflow:{input.start_workflow_input.workflow}",
attributes=attrs,
input=input,
kind=opentelemetry.trace.SpanKind.CLIENT,
):
return await super().start_update_with_start_workflow(input)


class _TracingActivityInboundInterceptor(temporalio.worker.ActivityInboundInterceptor):
def __init__(
Expand Down
107 changes: 105 additions & 2 deletions tests/contrib/test_opentelemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from opentelemetry.trace import StatusCode, get_tracer

from temporalio import activity, workflow
from temporalio.client import Client
from temporalio.common import RetryPolicy
from temporalio.client import Client, WithStartWorkflowOperation, WorkflowUpdateStage
from temporalio.common import RetryPolicy, WorkflowIDConflictPolicy
from temporalio.contrib.opentelemetry import TracingInterceptor
from temporalio.contrib.opentelemetry import workflow as otel_workflow
from temporalio.exceptions import ApplicationError, ApplicationErrorCategory
Expand Down Expand Up @@ -55,6 +55,7 @@ class TracingWorkflowAction:
continue_as_new: Optional[TracingWorkflowActionContinueAsNew] = None
wait_until_signal_count: int = 0
wait_and_do_update: bool = False
wait_and_do_start_with_update: bool = False


@dataclass
Expand All @@ -79,13 +80,15 @@ class TracingWorkflowActionContinueAsNew:


ready_for_update: asyncio.Semaphore
ready_for_update_with_start: asyncio.Semaphore


@workflow.defn
class TracingWorkflow:
def __init__(self) -> None:
self._signal_count = 0
self._did_update = False
self._did_update_with_start = False

@workflow.run
async def run(self, param: TracingWorkflowParam) -> None:
Expand Down Expand Up @@ -140,6 +143,9 @@ async def run(self, param: TracingWorkflowParam) -> None:
if action.wait_and_do_update:
ready_for_update.release()
await workflow.wait_condition(lambda: self._did_update)
if action.wait_and_do_start_with_update:
ready_for_update_with_start.release()
await workflow.wait_condition(lambda: self._did_update_with_start)

async def _raise_on_non_replay(self) -> None:
replaying = workflow.unsafe.is_replaying()
Expand All @@ -161,6 +167,10 @@ def signal(self) -> None:
def update(self) -> None:
self._did_update = True

@workflow.update
def update_with_start(self) -> None:
self._did_update_with_start = True

@update.validator
def update_validator(self) -> None:
pass
Expand Down Expand Up @@ -301,6 +311,99 @@ async def test_opentelemetry_tracing(client: Client, env: WorkflowEnvironment):
]


async def test_opentelemetry_tracing_update_with_start(
client: Client, env: WorkflowEnvironment
):
if env.supports_time_skipping:
pytest.skip(
"Java test server: https://github.com/temporalio/sdk-java/issues/1424"
)
global ready_for_update_with_start
ready_for_update_with_start = asyncio.Semaphore(0)
# Create a tracer that has an in-memory exporter
exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(exporter))
tracer = get_tracer(__name__, tracer_provider=provider)
# Create new client with tracer interceptor
client_config = client.config()
client_config["interceptors"] = [TracingInterceptor(tracer)]
client = Client(**client_config)

task_queue = f"task_queue_{uuid.uuid4()}"
async with Worker(
client,
task_queue=task_queue,
workflows=[TracingWorkflow],
activities=[tracing_activity],
# Needed so we can wait to send update at the right time
workflow_runner=UnsandboxedWorkflowRunner(),
):
# Run workflow with various actions
workflow_id = f"workflow_{uuid.uuid4()}"
workflow_params = TracingWorkflowParam(
actions=[
# Wait for update
TracingWorkflowAction(wait_and_do_start_with_update=True),
]
)
handle = await client.start_workflow(
TracingWorkflow.run,
workflow_params,
id=workflow_id,
task_queue=task_queue,
)
async with ready_for_update_with_start:
start_op = WithStartWorkflowOperation(
TracingWorkflow.run,
workflow_params,
id=handle.id,
task_queue=task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING,
)
await client.start_update_with_start_workflow(
TracingWorkflow.update_with_start,
start_workflow_operation=start_op,
id=handle.id,
wait_for_stage=WorkflowUpdateStage.ACCEPTED,
)
await handle.result()

# issue update with start again to trigger a new workflow
workflow_id = f"workflow_{uuid.uuid4()}"
start_op = WithStartWorkflowOperation(
TracingWorkflow.run,
TracingWorkflowParam(actions=[]),
id=workflow_id,
task_queue=task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING,
)
await client.execute_update_with_start_workflow(
update=TracingWorkflow.update_with_start,
start_workflow_operation=start_op,
id=workflow_id,
)

# Dump debug with attributes, but do string assertion test without
logging.debug(
"Spans:\n%s",
"\n".join(dump_spans(exporter.get_finished_spans(), with_attributes=False)),
)
assert dump_spans(exporter.get_finished_spans(), with_attributes=False) == [
"StartWorkflow:TracingWorkflow",
" RunWorkflow:TracingWorkflow",
" MyCustomSpan",
" HandleUpdate:update_with_start (links: StartUpdateWithStartWorkflow:TracingWorkflow)",
" CompleteWorkflow:TracingWorkflow",
"StartUpdateWithStartWorkflow:TracingWorkflow",
"StartUpdateWithStartWorkflow:TracingWorkflow",
" HandleUpdate:update_with_start (links: StartUpdateWithStartWorkflow:TracingWorkflow)",
" RunWorkflow:TracingWorkflow",
" MyCustomSpan",
" CompleteWorkflow:TracingWorkflow",
]


def dump_spans(
spans: Iterable[ReadableSpan],
*,
Expand Down
Loading