feat: add workflow snapshot codec registry and strict sanitization
This commit is contained in:
@@ -1,3 +1,11 @@
|
||||
from app_runtime.workflow.persistence.codec_registry import CodecRegistry
|
||||
from app_runtime.workflow.persistence.entity_codec import EntityCodec
|
||||
from app_runtime.workflow.persistence.snapshot_sanitizer import WorkflowSnapshotSanitizer
|
||||
from app_runtime.workflow.persistence.workflow_persistence import WorkflowPersistence
|
||||
|
||||
__all__ = ["WorkflowPersistence"]
|
||||
__all__ = [
|
||||
"CodecRegistry",
|
||||
"EntityCodec",
|
||||
"WorkflowPersistence",
|
||||
"WorkflowSnapshotSanitizer",
|
||||
]
|
||||
|
||||
24
src/app_runtime/workflow/persistence/codec_registry.py
Normal file
24
src/app_runtime/workflow/persistence/codec_registry.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app_runtime.workflow.persistence.entity_codec import EntityCodec
|
||||
|
||||
|
||||
class CodecRegistry:
|
||||
def __init__(self, codecs: list[EntityCodec] | None = None) -> None:
|
||||
self._codecs = list(codecs or [])
|
||||
self._by_type_id = {codec.type_id: codec for codec in self._codecs}
|
||||
|
||||
def register(self, codec: EntityCodec) -> None:
|
||||
self._codecs.append(codec)
|
||||
self._by_type_id[codec.type_id] = codec
|
||||
|
||||
def find_for_value(self, value: Any) -> EntityCodec | None:
|
||||
for codec in self._codecs:
|
||||
if codec.can_encode(value):
|
||||
return codec
|
||||
return None
|
||||
|
||||
def find_by_type_id(self, type_id: str) -> EntityCodec | None:
|
||||
return self._by_type_id.get(type_id)
|
||||
18
src/app_runtime/workflow/persistence/entity_codec.py
Normal file
18
src/app_runtime/workflow/persistence/entity_codec.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class EntityCodec(Protocol):
|
||||
type_id: str
|
||||
schema_version: int
|
||||
|
||||
def can_encode(self, value: Any) -> bool:
|
||||
...
|
||||
|
||||
def to_dict(self, value: Any) -> dict[str, Any]:
|
||||
...
|
||||
|
||||
def from_dict(self, payload: dict[str, Any]) -> Any:
|
||||
...
|
||||
@@ -1,12 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime
|
||||
from typing import Any
|
||||
|
||||
from app_runtime.workflow.persistence.codec_registry import CodecRegistry
|
||||
|
||||
|
||||
class WorkflowSnapshotSanitizer:
|
||||
_RUNTIME_STATE_KEY = "runtime_only"
|
||||
_TAG_TYPE_KEY = "__type__"
|
||||
_TAG_VERSION_KEY = "__v__"
|
||||
_TAG_DATA_KEY = "data"
|
||||
|
||||
def __init__(self, codec_registry: CodecRegistry | None = None, *, strict: bool = True) -> None:
|
||||
self._codec_registry = codec_registry or CodecRegistry()
|
||||
self._strict = strict
|
||||
|
||||
def sanitize(self, snapshot: dict[str, Any]) -> dict[str, Any]:
|
||||
payload = dict(snapshot.get("payload", {}))
|
||||
state = dict(snapshot.get("state", {}))
|
||||
state.pop(self._RUNTIME_STATE_KEY, None)
|
||||
return {
|
||||
"payload": self._sanitize_dict(payload),
|
||||
"state": self._sanitize_dict(state),
|
||||
@@ -17,4 +30,45 @@ class WorkflowSnapshotSanitizer:
|
||||
return {str(key): self._sanitize_dict(item) for key, item in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [self._sanitize_dict(item) for item in value]
|
||||
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||
return value
|
||||
if isinstance(value, (datetime, date)):
|
||||
return value.isoformat()
|
||||
codec = self._codec_registry.find_for_value(value)
|
||||
if codec is not None:
|
||||
return {
|
||||
self._TAG_TYPE_KEY: codec.type_id,
|
||||
self._TAG_VERSION_KEY: codec.schema_version,
|
||||
self._TAG_DATA_KEY: self._sanitize_dict(codec.to_dict(value)),
|
||||
}
|
||||
if self._strict:
|
||||
raise TypeError(f"Unsupported snapshot value type: {type(value).__name__}")
|
||||
return value
|
||||
|
||||
def hydrate(self, snapshot: dict[str, Any]) -> dict[str, Any]:
|
||||
payload = dict(snapshot.get("payload", {}))
|
||||
state = dict(snapshot.get("state", {}))
|
||||
return {
|
||||
"payload": self._hydrate_value(payload),
|
||||
"state": self._hydrate_value(state),
|
||||
}
|
||||
|
||||
def _hydrate_value(self, value: Any) -> Any:
|
||||
if isinstance(value, list):
|
||||
return [self._hydrate_value(item) for item in value]
|
||||
if isinstance(value, dict):
|
||||
if (
|
||||
self._TAG_TYPE_KEY in value
|
||||
and self._TAG_VERSION_KEY in value
|
||||
and self._TAG_DATA_KEY in value
|
||||
):
|
||||
type_id = str(value.get(self._TAG_TYPE_KEY))
|
||||
codec = self._codec_registry.find_by_type_id(type_id)
|
||||
if codec is None:
|
||||
if self._strict:
|
||||
raise TypeError(f"Unknown snapshot entity codec type: {type_id}")
|
||||
return value
|
||||
payload = self._hydrate_value(value.get(self._TAG_DATA_KEY, {}))
|
||||
return codec.from_dict(payload)
|
||||
return {str(key): self._hydrate_value(item) for key, item in value.items()}
|
||||
return value
|
||||
|
||||
@@ -6,16 +6,17 @@ from app_runtime.workflow.persistence.workflow_repository import WorkflowReposit
|
||||
|
||||
|
||||
class WorkflowPersistence:
|
||||
def __init__(self, workflow_repository, checkpoint_repository) -> None:
|
||||
def __init__(self, workflow_repository, checkpoint_repository, *, snapshot_sanitizer=None) -> None:
|
||||
self._workflow_repository = workflow_repository
|
||||
self._checkpoint_repository = checkpoint_repository
|
||||
self._snapshot_sanitizer = WorkflowSnapshotSanitizer()
|
||||
self._snapshot_sanitizer = snapshot_sanitizer or WorkflowSnapshotSanitizer()
|
||||
|
||||
@classmethod
|
||||
def create_default(cls, connection_factory=None) -> "WorkflowPersistence":
|
||||
def create_default(cls, connection_factory=None, *, snapshot_sanitizer=None) -> "WorkflowPersistence":
|
||||
return cls(
|
||||
workflow_repository=WorkflowRepository(connection_factory),
|
||||
checkpoint_repository=CheckpointRepository(connection_factory),
|
||||
snapshot_sanitizer=snapshot_sanitizer,
|
||||
)
|
||||
|
||||
def start_run(self, workflow_name: str, start_at: str, snapshot: dict[str, object]) -> int:
|
||||
|
||||
@@ -5,11 +5,15 @@ from app_runtime.workflow.persistence import WorkflowPersistence
|
||||
|
||||
|
||||
class WorkflowRuntimeFactory:
|
||||
def __init__(self, connection_factory=None, *, traces, hooks=None) -> None:
|
||||
def __init__(self, connection_factory=None, *, traces, hooks=None, snapshot_sanitizer=None) -> None:
|
||||
self._connection_factory = connection_factory
|
||||
self._traces = traces
|
||||
self._hooks = hooks
|
||||
self._snapshot_sanitizer = snapshot_sanitizer
|
||||
|
||||
def create_engine(self, workflow) -> WorkflowEngine:
|
||||
persistence = WorkflowPersistence.create_default(self._connection_factory)
|
||||
persistence = WorkflowPersistence.create_default(
|
||||
self._connection_factory,
|
||||
snapshot_sanitizer=self._snapshot_sanitizer,
|
||||
)
|
||||
return WorkflowEngine(workflow, persistence, traces=self._traces, hooks=self._hooks)
|
||||
|
||||
Reference in New Issue
Block a user