|
44 | 44 | import temporalio.activity
|
45 | 45 | import temporalio.api.sdk.v1
|
46 | 46 | import temporalio.client
|
| 47 | +import temporalio.converter |
47 | 48 | import temporalio.worker
|
48 | 49 | import temporalio.workflow
|
49 | 50 | from temporalio import activity, workflow
|
@@ -8369,3 +8370,109 @@ async def test_previous_run_failure(client: Client):
|
8369 | 8370 | )
|
8370 | 8371 | result = await handle.result()
|
8371 | 8372 | assert result == "Done"
|
| 8373 | + |
| 8374 | + |
| 8375 | +class EncryptionCodec(PayloadCodec): |
| 8376 | + def __init__( |
| 8377 | + self, |
| 8378 | + key_id: str = "test-key-id", |
| 8379 | + key: bytes = b"test-key-test-key-test-key-test!", |
| 8380 | + ) -> None: |
| 8381 | + super().__init__() |
| 8382 | + self.key_id = key_id |
| 8383 | + |
| 8384 | + async def encode(self, payloads: Sequence[Payload]) -> List[Payload]: |
| 8385 | + # We blindly encode all payloads with the key and set the metadata |
| 8386 | + # saying which key we used |
| 8387 | + return [ |
| 8388 | + Payload( |
| 8389 | + metadata={ |
| 8390 | + "encoding": b"binary/encrypted", |
| 8391 | + "encryption-key-id": self.key_id.encode(), |
| 8392 | + }, |
| 8393 | + data=self.encrypt(p.SerializeToString()), |
| 8394 | + ) |
| 8395 | + for p in payloads |
| 8396 | + ] |
| 8397 | + |
| 8398 | + async def decode(self, payloads: Sequence[Payload]) -> List[Payload]: |
| 8399 | + ret: List[Payload] = [] |
| 8400 | + for p in payloads: |
| 8401 | + # Ignore ones w/out our expected encoding |
| 8402 | + if p.metadata.get("encoding", b"").decode() != "binary/encrypted": |
| 8403 | + ret.append(p) |
| 8404 | + continue |
| 8405 | + # Confirm our key ID is the same |
| 8406 | + key_id = p.metadata.get("encryption-key-id", b"").decode() |
| 8407 | + if key_id != self.key_id: |
| 8408 | + raise ValueError( |
| 8409 | + f"Unrecognized key ID {key_id}. Current key ID is {self.key_id}." |
| 8410 | + ) |
| 8411 | + # Decrypt and append |
| 8412 | + ret.append(Payload.FromString(self.decrypt(p.data))) |
| 8413 | + return ret |
| 8414 | + |
| 8415 | + def encrypt(self, data: bytes) -> bytes: |
| 8416 | + nonce = os.urandom(12) |
| 8417 | + return data |
| 8418 | + |
| 8419 | + def decrypt(self, data: bytes) -> bytes: |
| 8420 | + return data |
| 8421 | + |
| 8422 | + |
| 8423 | +@workflow.defn |
| 8424 | +class SearchAttributeCodecParentWorkflow: |
| 8425 | + @workflow.run |
| 8426 | + async def run(self, name: str) -> str: |
| 8427 | + print( |
| 8428 | + await workflow.execute_child_workflow( |
| 8429 | + workflow=SearchAttributeCodecChildWorkflow.run, |
| 8430 | + arg=name, |
| 8431 | + id=f"child-{name}", |
| 8432 | + search_attributes=workflow.info().typed_search_attributes, |
| 8433 | + ) |
| 8434 | + ) |
| 8435 | + return f"Hello, {name}" |
| 8436 | + |
| 8437 | + |
| 8438 | +@workflow.defn |
| 8439 | +class SearchAttributeCodecChildWorkflow: |
| 8440 | + @workflow.run |
| 8441 | + async def run(self, name: str) -> str: |
| 8442 | + return f"Hello from child, {name}" |
| 8443 | + |
| 8444 | + |
| 8445 | +async def test_search_attribute_codec(client: Client, env_type: str): |
| 8446 | + if env_type != "local": |
| 8447 | + pytest.skip("Only testing search attributes on local which disables cache") |
| 8448 | + await ensure_search_attributes_present( |
| 8449 | + client, |
| 8450 | + SearchAttributeWorkflow.text_attribute, |
| 8451 | + ) |
| 8452 | + |
| 8453 | + config = client.config() |
| 8454 | + config["data_converter"] = dataclasses.replace( |
| 8455 | + temporalio.converter.default(), payload_codec=EncryptionCodec() |
| 8456 | + ) |
| 8457 | + client = Client(**config) |
| 8458 | + |
| 8459 | + # Run a worker for the workflow |
| 8460 | + async with new_worker( |
| 8461 | + client, |
| 8462 | + SearchAttributeCodecParentWorkflow, |
| 8463 | + SearchAttributeCodecChildWorkflow, |
| 8464 | + ) as worker: |
| 8465 | + # Run workflow |
| 8466 | + result = await client.execute_workflow( |
| 8467 | + SearchAttributeCodecParentWorkflow.run, |
| 8468 | + "Temporal", |
| 8469 | + id=f"encryption-workflow-id", |
| 8470 | + task_queue=worker.task_queue, |
| 8471 | + search_attributes=TypedSearchAttributes( |
| 8472 | + [ |
| 8473 | + SearchAttributePair( |
| 8474 | + SearchAttributeWorkflow.text_attribute, "test_text" |
| 8475 | + ) |
| 8476 | + ] |
| 8477 | + ), |
| 8478 | + ) |
0 commit comments