Skip to content

Commit 4e9cbab

Browse files
authored
Fix search attribute skipping on protos which don't use the SearchAttributes message type (#1131)
* Fix search attribute skipping on protos which don't use the SearchAttributes message type * Remove invalid testing * Add test * Formatting * Use existing search attribute, the server hits limit * Only run on local * Add env type
1 parent 02322ad commit 4e9cbab

File tree

3 files changed

+122
-7
lines changed

3 files changed

+122
-7
lines changed

scripts/gen_payload_visitor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ def emit_loop(
3232
if not self.skip_headers:
3333
for v in {iter_expr}:
3434
await self._visit_{child_method}(fs, v)"""
35+
elif field_name == "search_attributes":
36+
return f"""\
37+
if not self.skip_search_attributes:
38+
for v in {iter_expr}:
39+
await self._visit_{child_method}(fs, v)"""
3540
else:
3641
return f"""\
3742
for v in {iter_expr}:
@@ -197,7 +202,7 @@ def walk(self, desc: Descriptor) -> bool:
197202
# Process regular fields first
198203
for field in regular_fields:
199204
# Repeated fields (including maps which are represented as repeated messages)
200-
if field.label == FieldDescriptor.LABEL_REPEATED:
205+
if field.is_repeated:
201206
if (
202207
field.message_type is not None
203208
and field.message_type.GetOptions().map_entry

temporalio/bridge/_visitor.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,9 @@ async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution(
320320
if not self.skip_headers:
321321
for v in o.headers.values():
322322
await self._visit_temporal_api_common_v1_Payload(fs, v)
323-
for v in o.search_attributes.values():
324-
await self._visit_temporal_api_common_v1_Payload(fs, v)
323+
if not self.skip_search_attributes:
324+
for v in o.search_attributes.values():
325+
await self._visit_temporal_api_common_v1_Payload(fs, v)
325326

326327
async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, o):
327328
await self._visit_payload_container(fs, o.input)
@@ -330,8 +331,9 @@ async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs,
330331
await self._visit_temporal_api_common_v1_Payload(fs, v)
331332
for v in o.memo.values():
332333
await self._visit_temporal_api_common_v1_Payload(fs, v)
333-
for v in o.search_attributes.values():
334-
await self._visit_temporal_api_common_v1_Payload(fs, v)
334+
if not self.skip_search_attributes:
335+
for v in o.search_attributes.values():
336+
await self._visit_temporal_api_common_v1_Payload(fs, v)
335337

336338
async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution(
337339
self, fs, o
@@ -350,8 +352,9 @@ async def _visit_coresdk_workflow_commands_ScheduleLocalActivity(self, fs, o):
350352
async def _visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes(
351353
self, fs, o
352354
):
353-
for v in o.search_attributes.values():
354-
await self._visit_temporal_api_common_v1_Payload(fs, v)
355+
if not self.skip_search_attributes:
356+
for v in o.search_attributes.values():
357+
await self._visit_temporal_api_common_v1_Payload(fs, v)
355358

356359
async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties(self, fs, o):
357360
if o.HasField("upserted_memo"):

tests/worker/test_workflow.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import temporalio.activity
4545
import temporalio.api.sdk.v1
4646
import temporalio.client
47+
import temporalio.converter
4748
import temporalio.worker
4849
import temporalio.workflow
4950
from temporalio import activity, workflow
@@ -8369,3 +8370,109 @@ async def test_previous_run_failure(client: Client):
83698370
)
83708371
result = await handle.result()
83718372
assert result == "Done"
8373+
8374+
8375+
class EncryptionCodec(PayloadCodec):
8376+
def __init__(
8377+
self,
8378+
key_id: str = "test-key-id",
8379+
key: bytes = b"test-key-test-key-test-key-test!",
8380+
) -> None:
8381+
super().__init__()
8382+
self.key_id = key_id
8383+
8384+
async def encode(self, payloads: Sequence[Payload]) -> List[Payload]:
8385+
# We blindly encode all payloads with the key and set the metadata
8386+
# saying which key we used
8387+
return [
8388+
Payload(
8389+
metadata={
8390+
"encoding": b"binary/encrypted",
8391+
"encryption-key-id": self.key_id.encode(),
8392+
},
8393+
data=self.encrypt(p.SerializeToString()),
8394+
)
8395+
for p in payloads
8396+
]
8397+
8398+
async def decode(self, payloads: Sequence[Payload]) -> List[Payload]:
8399+
ret: List[Payload] = []
8400+
for p in payloads:
8401+
# Ignore ones w/out our expected encoding
8402+
if p.metadata.get("encoding", b"").decode() != "binary/encrypted":
8403+
ret.append(p)
8404+
continue
8405+
# Confirm our key ID is the same
8406+
key_id = p.metadata.get("encryption-key-id", b"").decode()
8407+
if key_id != self.key_id:
8408+
raise ValueError(
8409+
f"Unrecognized key ID {key_id}. Current key ID is {self.key_id}."
8410+
)
8411+
# Decrypt and append
8412+
ret.append(Payload.FromString(self.decrypt(p.data)))
8413+
return ret
8414+
8415+
def encrypt(self, data: bytes) -> bytes:
8416+
nonce = os.urandom(12)
8417+
return data
8418+
8419+
def decrypt(self, data: bytes) -> bytes:
8420+
return data
8421+
8422+
8423+
@workflow.defn
8424+
class SearchAttributeCodecParentWorkflow:
8425+
@workflow.run
8426+
async def run(self, name: str) -> str:
8427+
print(
8428+
await workflow.execute_child_workflow(
8429+
workflow=SearchAttributeCodecChildWorkflow.run,
8430+
arg=name,
8431+
id=f"child-{name}",
8432+
search_attributes=workflow.info().typed_search_attributes,
8433+
)
8434+
)
8435+
return f"Hello, {name}"
8436+
8437+
8438+
@workflow.defn
8439+
class SearchAttributeCodecChildWorkflow:
8440+
@workflow.run
8441+
async def run(self, name: str) -> str:
8442+
return f"Hello from child, {name}"
8443+
8444+
8445+
async def test_search_attribute_codec(client: Client, env_type: str):
8446+
if env_type != "local":
8447+
pytest.skip("Only testing search attributes on local which disables cache")
8448+
await ensure_search_attributes_present(
8449+
client,
8450+
SearchAttributeWorkflow.text_attribute,
8451+
)
8452+
8453+
config = client.config()
8454+
config["data_converter"] = dataclasses.replace(
8455+
temporalio.converter.default(), payload_codec=EncryptionCodec()
8456+
)
8457+
client = Client(**config)
8458+
8459+
# Run a worker for the workflow
8460+
async with new_worker(
8461+
client,
8462+
SearchAttributeCodecParentWorkflow,
8463+
SearchAttributeCodecChildWorkflow,
8464+
) as worker:
8465+
# Run workflow
8466+
result = await client.execute_workflow(
8467+
SearchAttributeCodecParentWorkflow.run,
8468+
"Temporal",
8469+
id=f"encryption-workflow-id",
8470+
task_queue=worker.task_queue,
8471+
search_attributes=TypedSearchAttributes(
8472+
[
8473+
SearchAttributePair(
8474+
SearchAttributeWorkflow.text_attribute, "test_text"
8475+
)
8476+
]
8477+
),
8478+
)

0 commit comments

Comments
 (0)