feat: add workflow snapshot codec registry and strict sanitization

This commit is contained in:
2026-03-07 17:11:21 +03:00
parent 80a0ea8956
commit 3c6f74dadd
6 changed files with 115 additions and 6 deletions

View File

@@ -1,3 +1,11 @@
from app_runtime.workflow.persistence.codec_registry import CodecRegistry
from app_runtime.workflow.persistence.entity_codec import EntityCodec
from app_runtime.workflow.persistence.snapshot_sanitizer import WorkflowSnapshotSanitizer
from app_runtime.workflow.persistence.workflow_persistence import WorkflowPersistence
__all__ = ["WorkflowPersistence"]
__all__ = [
"CodecRegistry",
"EntityCodec",
"WorkflowPersistence",
"WorkflowSnapshotSanitizer",
]

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
from typing import Any
from app_runtime.workflow.persistence.entity_codec import EntityCodec
class CodecRegistry:
    """Collection of entity codecs, searchable by value or by type id.

    Codecs are kept in registration order for ``find_for_value`` and are
    indexed by ``type_id`` for ``find_by_type_id``. Each ``type_id`` is
    represented by exactly one codec in both views.
    """

    def __init__(self, codecs: list[EntityCodec] | None = None) -> None:
        """Create a registry, optionally pre-populated with ``codecs``.

        Args:
            codecs: Codecs to register, in order. ``None`` or an empty
                sequence yields an empty registry. The sequence itself is
                not retained.
        """
        self._codecs: list[EntityCodec] = []
        self._by_type_id: dict[str, EntityCodec] = {}
        # Going through register() gives the initial codecs the same
        # duplicate handling as later registrations.
        for codec in codecs or []:
            self.register(codec)

    def register(self, codec: EntityCodec) -> None:
        """Add ``codec``, replacing any codec that has the same ``type_id``.

        A replacement takes over the position of the codec it supersedes,
        so the lookup order of the remaining codecs is unchanged.

        Args:
            codec: The codec to add.
        """
        superseded = self._by_type_id.get(codec.type_id)
        if superseded is None:
            self._codecs.append(codec)
        else:
            # Swap in place: leaving the old codec in the ordered list would
            # let find_for_value() keep encoding with it while
            # find_by_type_id() decodes with the new one.
            position = next(
                index
                for index, existing in enumerate(self._codecs)
                if existing is superseded
            )
            self._codecs[position] = codec
        self._by_type_id[codec.type_id] = codec

    def find_for_value(self, value: Any) -> EntityCodec | None:
        """Return the first registered codec whose ``can_encode`` accepts ``value``.

        Args:
            value: The object to find a codec for.

        Returns:
            The matching codec, or ``None`` when no codec accepts ``value``.
        """
        return next(
            (codec for codec in self._codecs if codec.can_encode(value)),
            None,
        )

    def find_by_type_id(self, type_id: str) -> EntityCodec | None:
        """Return the codec registered under ``type_id``, or ``None``."""
        return self._by_type_id.get(type_id)

View File

@@ -0,0 +1,18 @@
from __future__ import annotations
from typing import Any, Protocol, runtime_checkable
@runtime_checkable
class EntityCodec(Protocol):
    """Structural interface for objects that convert a value to and from a dict.

    Because the protocol is runtime-checkable, ``isinstance(obj, EntityCodec)``
    can be used to test that an object provides these members.
    """

    # Identifier naming the kind of entity this codec handles.
    type_id: str
    # Version number of the dictionary layout this codec produces.
    schema_version: int

    def can_encode(self, value: Any) -> bool:
        """Return whether this codec handles ``value``."""

    def to_dict(self, value: Any) -> dict[str, Any]:
        """Return a dictionary representation of ``value``."""

    def from_dict(self, payload: dict[str, Any]) -> Any:
        """Rebuild a value from the dictionary ``payload``."""

View File

@@ -1,12 +1,25 @@
from __future__ import annotations
from datetime import date, datetime
from typing import Any
from app_runtime.workflow.persistence.codec_registry import CodecRegistry
class WorkflowSnapshotSanitizer:
# Key removed from the snapshot's "state" mapping before it is sanitized.
_RUNTIME_STATE_KEY = "runtime_only"
# Keys of the tagged dict that stands in for a codec-encoded value:
# the codec's type id, the codec's schema version, and the encoded data.
_TAG_TYPE_KEY = "__type__"
_TAG_VERSION_KEY = "__v__"
_TAG_DATA_KEY = "data"
def __init__(self, codec_registry: CodecRegistry | None = None, *, strict: bool = True) -> None:
self._codec_registry = codec_registry or CodecRegistry()
self._strict = strict
def sanitize(self, snapshot: dict[str, Any]) -> dict[str, Any]:
payload = dict(snapshot.get("payload", {}))
state = dict(snapshot.get("state", {}))
state.pop(self._RUNTIME_STATE_KEY, None)
return {
"payload": self._sanitize_dict(payload),
"state": self._sanitize_dict(state),
@@ -17,4 +30,45 @@ class WorkflowSnapshotSanitizer:
return {str(key): self._sanitize_dict(item) for key, item in value.items()}
if isinstance(value, list):
return [self._sanitize_dict(item) for item in value]
if isinstance(value, (str, int, float, bool)) or value is None:
return value
if isinstance(value, (datetime, date)):
return value.isoformat()
codec = self._codec_registry.find_for_value(value)
if codec is not None:
return {
self._TAG_TYPE_KEY: codec.type_id,
self._TAG_VERSION_KEY: codec.schema_version,
self._TAG_DATA_KEY: self._sanitize_dict(codec.to_dict(value)),
}
if self._strict:
raise TypeError(f"Unsupported snapshot value type: {type(value).__name__}")
return value
def hydrate(self, snapshot: dict[str, Any]) -> dict[str, Any]:
payload = dict(snapshot.get("payload", {}))
state = dict(snapshot.get("state", {}))
return {
"payload": self._hydrate_value(payload),
"state": self._hydrate_value(state),
}
def _hydrate_value(self, value: Any) -> Any:
if isinstance(value, list):
return [self._hydrate_value(item) for item in value]
if isinstance(value, dict):
if (
self._TAG_TYPE_KEY in value
and self._TAG_VERSION_KEY in value
and self._TAG_DATA_KEY in value
):
type_id = str(value.get(self._TAG_TYPE_KEY))
codec = self._codec_registry.find_by_type_id(type_id)
if codec is None:
if self._strict:
raise TypeError(f"Unknown snapshot entity codec type: {type_id}")
return value
payload = self._hydrate_value(value.get(self._TAG_DATA_KEY, {}))
return codec.from_dict(payload)
return {str(key): self._hydrate_value(item) for key, item in value.items()}
return value

View File

@@ -6,16 +6,17 @@ from app_runtime.workflow.persistence.workflow_repository import WorkflowReposit
class WorkflowPersistence:
def __init__(self, workflow_repository, checkpoint_repository) -> None:
def __init__(self, workflow_repository, checkpoint_repository, *, snapshot_sanitizer=None) -> None:
self._workflow_repository = workflow_repository
self._checkpoint_repository = checkpoint_repository
self._snapshot_sanitizer = WorkflowSnapshotSanitizer()
self._snapshot_sanitizer = snapshot_sanitizer or WorkflowSnapshotSanitizer()
@classmethod
def create_default(cls, connection_factory=None, *, snapshot_sanitizer=None) -> "WorkflowPersistence":
    """Build an instance backed by newly created repositories.

    Args:
        connection_factory: Passed unchanged to both ``WorkflowRepository``
            and ``CheckpointRepository``.
        snapshot_sanitizer: Forwarded to the constructor; ``None`` lets the
            constructor pick its default sanitizer.

    Returns:
        The new ``cls`` instance.
    """
    return cls(
        workflow_repository=WorkflowRepository(connection_factory),
        checkpoint_repository=CheckpointRepository(connection_factory),
        snapshot_sanitizer=snapshot_sanitizer,
    )
def start_run(self, workflow_name: str, start_at: str, snapshot: dict[str, object]) -> int:

View File

@@ -5,11 +5,15 @@ from app_runtime.workflow.persistence import WorkflowPersistence
class WorkflowRuntimeFactory:
def __init__(self, connection_factory=None, *, traces, hooks=None) -> None:
def __init__(self, connection_factory=None, *, traces, hooks=None, snapshot_sanitizer=None) -> None:
self._connection_factory = connection_factory
self._traces = traces
self._hooks = hooks
self._snapshot_sanitizer = snapshot_sanitizer
def create_engine(self, workflow) -> WorkflowEngine:
    """Create an engine for ``workflow`` with its own persistence object.

    Args:
        workflow: The workflow handed to the engine.

    Returns:
        A ``WorkflowEngine`` built from ``workflow``, a default
        ``WorkflowPersistence``, and this factory's traces and hooks.
    """
    # Build persistence once, passing along the configured sanitizer.
    persistence = WorkflowPersistence.create_default(
        self._connection_factory,
        snapshot_sanitizer=self._snapshot_sanitizer,
    )
    return WorkflowEngine(workflow, persistence, traces=self._traces, hooks=self._hooks)