feat: add workflow snapshot codec registry and strict sanitization
This commit is contained in:
@@ -1,3 +1,11 @@
|
|||||||
|
from app_runtime.workflow.persistence.codec_registry import CodecRegistry
|
||||||
|
from app_runtime.workflow.persistence.entity_codec import EntityCodec
|
||||||
|
from app_runtime.workflow.persistence.snapshot_sanitizer import WorkflowSnapshotSanitizer
|
||||||
from app_runtime.workflow.persistence.workflow_persistence import WorkflowPersistence
|
from app_runtime.workflow.persistence.workflow_persistence import WorkflowPersistence
|
||||||
|
|
||||||
__all__ = ["WorkflowPersistence"]
|
__all__ = [
|
||||||
|
"CodecRegistry",
|
||||||
|
"EntityCodec",
|
||||||
|
"WorkflowPersistence",
|
||||||
|
"WorkflowSnapshotSanitizer",
|
||||||
|
]
|
||||||
|
|||||||
24
src/app_runtime/workflow/persistence/codec_registry.py
Normal file
24
src/app_runtime/workflow/persistence/codec_registry.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app_runtime.workflow.persistence.entity_codec import EntityCodec
|
||||||
|
|
||||||
|
|
||||||
|
class CodecRegistry:
|
||||||
|
def __init__(self, codecs: list[EntityCodec] | None = None) -> None:
|
||||||
|
self._codecs = list(codecs or [])
|
||||||
|
self._by_type_id = {codec.type_id: codec for codec in self._codecs}
|
||||||
|
|
||||||
|
def register(self, codec: EntityCodec) -> None:
|
||||||
|
self._codecs.append(codec)
|
||||||
|
self._by_type_id[codec.type_id] = codec
|
||||||
|
|
||||||
|
def find_for_value(self, value: Any) -> EntityCodec | None:
|
||||||
|
for codec in self._codecs:
|
||||||
|
if codec.can_encode(value):
|
||||||
|
return codec
|
||||||
|
return None
|
||||||
|
|
||||||
|
def find_by_type_id(self, type_id: str) -> EntityCodec | None:
|
||||||
|
return self._by_type_id.get(type_id)
|
||||||
18
src/app_runtime/workflow/persistence/entity_codec.py
Normal file
18
src/app_runtime/workflow/persistence/entity_codec.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class EntityCodec(Protocol):
|
||||||
|
type_id: str
|
||||||
|
schema_version: int
|
||||||
|
|
||||||
|
def can_encode(self, value: Any) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
def to_dict(self, value: Any) -> dict[str, Any]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def from_dict(self, payload: dict[str, Any]) -> Any:
|
||||||
|
...
|
||||||
@@ -1,12 +1,25 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from app_runtime.workflow.persistence.codec_registry import CodecRegistry
|
||||||
|
|
||||||
|
|
||||||
class WorkflowSnapshotSanitizer:
|
class WorkflowSnapshotSanitizer:
|
||||||
|
_RUNTIME_STATE_KEY = "runtime_only"
|
||||||
|
_TAG_TYPE_KEY = "__type__"
|
||||||
|
_TAG_VERSION_KEY = "__v__"
|
||||||
|
_TAG_DATA_KEY = "data"
|
||||||
|
|
||||||
|
def __init__(self, codec_registry: CodecRegistry | None = None, *, strict: bool = True) -> None:
|
||||||
|
self._codec_registry = codec_registry or CodecRegistry()
|
||||||
|
self._strict = strict
|
||||||
|
|
||||||
def sanitize(self, snapshot: dict[str, Any]) -> dict[str, Any]:
|
def sanitize(self, snapshot: dict[str, Any]) -> dict[str, Any]:
|
||||||
payload = dict(snapshot.get("payload", {}))
|
payload = dict(snapshot.get("payload", {}))
|
||||||
state = dict(snapshot.get("state", {}))
|
state = dict(snapshot.get("state", {}))
|
||||||
|
state.pop(self._RUNTIME_STATE_KEY, None)
|
||||||
return {
|
return {
|
||||||
"payload": self._sanitize_dict(payload),
|
"payload": self._sanitize_dict(payload),
|
||||||
"state": self._sanitize_dict(state),
|
"state": self._sanitize_dict(state),
|
||||||
@@ -17,4 +30,45 @@ class WorkflowSnapshotSanitizer:
|
|||||||
return {str(key): self._sanitize_dict(item) for key, item in value.items()}
|
return {str(key): self._sanitize_dict(item) for key, item in value.items()}
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
return [self._sanitize_dict(item) for item in value]
|
return [self._sanitize_dict(item) for item in value]
|
||||||
|
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||||
|
return value
|
||||||
|
if isinstance(value, (datetime, date)):
|
||||||
|
return value.isoformat()
|
||||||
|
codec = self._codec_registry.find_for_value(value)
|
||||||
|
if codec is not None:
|
||||||
|
return {
|
||||||
|
self._TAG_TYPE_KEY: codec.type_id,
|
||||||
|
self._TAG_VERSION_KEY: codec.schema_version,
|
||||||
|
self._TAG_DATA_KEY: self._sanitize_dict(codec.to_dict(value)),
|
||||||
|
}
|
||||||
|
if self._strict:
|
||||||
|
raise TypeError(f"Unsupported snapshot value type: {type(value).__name__}")
|
||||||
|
return value
|
||||||
|
|
||||||
|
def hydrate(self, snapshot: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
payload = dict(snapshot.get("payload", {}))
|
||||||
|
state = dict(snapshot.get("state", {}))
|
||||||
|
return {
|
||||||
|
"payload": self._hydrate_value(payload),
|
||||||
|
"state": self._hydrate_value(state),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _hydrate_value(self, value: Any) -> Any:
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [self._hydrate_value(item) for item in value]
|
||||||
|
if isinstance(value, dict):
|
||||||
|
if (
|
||||||
|
self._TAG_TYPE_KEY in value
|
||||||
|
and self._TAG_VERSION_KEY in value
|
||||||
|
and self._TAG_DATA_KEY in value
|
||||||
|
):
|
||||||
|
type_id = str(value.get(self._TAG_TYPE_KEY))
|
||||||
|
codec = self._codec_registry.find_by_type_id(type_id)
|
||||||
|
if codec is None:
|
||||||
|
if self._strict:
|
||||||
|
raise TypeError(f"Unknown snapshot entity codec type: {type_id}")
|
||||||
|
return value
|
||||||
|
payload = self._hydrate_value(value.get(self._TAG_DATA_KEY, {}))
|
||||||
|
return codec.from_dict(payload)
|
||||||
|
return {str(key): self._hydrate_value(item) for key, item in value.items()}
|
||||||
return value
|
return value
|
||||||
|
|||||||
@@ -6,16 +6,17 @@ from app_runtime.workflow.persistence.workflow_repository import WorkflowReposit
|
|||||||
|
|
||||||
|
|
||||||
class WorkflowPersistence:
|
class WorkflowPersistence:
|
||||||
def __init__(self, workflow_repository, checkpoint_repository) -> None:
|
def __init__(self, workflow_repository, checkpoint_repository, *, snapshot_sanitizer=None) -> None:
|
||||||
self._workflow_repository = workflow_repository
|
self._workflow_repository = workflow_repository
|
||||||
self._checkpoint_repository = checkpoint_repository
|
self._checkpoint_repository = checkpoint_repository
|
||||||
self._snapshot_sanitizer = WorkflowSnapshotSanitizer()
|
self._snapshot_sanitizer = snapshot_sanitizer or WorkflowSnapshotSanitizer()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_default(cls, connection_factory=None) -> "WorkflowPersistence":
|
def create_default(cls, connection_factory=None, *, snapshot_sanitizer=None) -> "WorkflowPersistence":
|
||||||
return cls(
|
return cls(
|
||||||
workflow_repository=WorkflowRepository(connection_factory),
|
workflow_repository=WorkflowRepository(connection_factory),
|
||||||
checkpoint_repository=CheckpointRepository(connection_factory),
|
checkpoint_repository=CheckpointRepository(connection_factory),
|
||||||
|
snapshot_sanitizer=snapshot_sanitizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
def start_run(self, workflow_name: str, start_at: str, snapshot: dict[str, object]) -> int:
|
def start_run(self, workflow_name: str, start_at: str, snapshot: dict[str, object]) -> int:
|
||||||
|
|||||||
@@ -5,11 +5,15 @@ from app_runtime.workflow.persistence import WorkflowPersistence
|
|||||||
|
|
||||||
|
|
||||||
class WorkflowRuntimeFactory:
|
class WorkflowRuntimeFactory:
|
||||||
def __init__(self, connection_factory=None, *, traces, hooks=None) -> None:
|
def __init__(self, connection_factory=None, *, traces, hooks=None, snapshot_sanitizer=None) -> None:
|
||||||
self._connection_factory = connection_factory
|
self._connection_factory = connection_factory
|
||||||
self._traces = traces
|
self._traces = traces
|
||||||
self._hooks = hooks
|
self._hooks = hooks
|
||||||
|
self._snapshot_sanitizer = snapshot_sanitizer
|
||||||
|
|
||||||
def create_engine(self, workflow) -> WorkflowEngine:
|
def create_engine(self, workflow) -> WorkflowEngine:
|
||||||
persistence = WorkflowPersistence.create_default(self._connection_factory)
|
persistence = WorkflowPersistence.create_default(
|
||||||
|
self._connection_factory,
|
||||||
|
snapshot_sanitizer=self._snapshot_sanitizer,
|
||||||
|
)
|
||||||
return WorkflowEngine(workflow, persistence, traces=self._traces, hooks=self._hooks)
|
return WorkflowEngine(workflow, persistence, traces=self._traces, hooks=self._hooks)
|
||||||
|
|||||||
Reference in New Issue
Block a user