from __future__ import annotations

from app_runtime.workflow.persistence.checkpoint_repository import CheckpointRepository
from app_runtime.workflow.persistence.snapshot_sanitizer import WorkflowSnapshotSanitizer
from app_runtime.workflow.persistence.workflow_repository import WorkflowRepository


class WorkflowPersistence:
    """Coordinate workflow-run and checkpoint persistence.

    Each public method passes the incoming snapshot through the snapshot
    sanitizer and hands the sanitized result to the workflow repository,
    the checkpoint repository, or both.
    """

    def __init__(
        self,
        workflow_repository: WorkflowRepository,
        checkpoint_repository: CheckpointRepository,
        *,
        snapshot_sanitizer: WorkflowSnapshotSanitizer | None = None,
    ) -> None:
        """Store the collaborating repositories and the sanitizer.

        Args:
            workflow_repository: Receives run and step records.
            checkpoint_repository: Receives checkpoint records.
            snapshot_sanitizer: Sanitizer applied to every snapshot; a
                ``WorkflowSnapshotSanitizer()`` is created when ``None``.
        """
        self._workflow_repository = workflow_repository
        self._checkpoint_repository = checkpoint_repository
        # Compare against None explicitly: an injected sanitizer that happens
        # to be falsy (e.g. defines __bool__/__len__) must not be discarded.
        if snapshot_sanitizer is None:
            snapshot_sanitizer = WorkflowSnapshotSanitizer()
        self._snapshot_sanitizer = snapshot_sanitizer

    @classmethod
    def create_default(
        cls,
        connection_factory=None,
        *,
        snapshot_sanitizer: WorkflowSnapshotSanitizer | None = None,
    ) -> WorkflowPersistence:
        """Build an instance backed by the default repository classes.

        Args:
            connection_factory: Forwarded unchanged to both the
                ``WorkflowRepository`` and ``CheckpointRepository``
                constructors (expected type is not visible in this module).
            snapshot_sanitizer: Optional sanitizer passed to the constructor.

        Returns:
            A new instance of ``cls``.
        """
        return cls(
            workflow_repository=WorkflowRepository(connection_factory),
            checkpoint_repository=CheckpointRepository(connection_factory),
            snapshot_sanitizer=snapshot_sanitizer,
        )

    def start_run(self, workflow_name: str, start_at: str, snapshot: dict[str, object]) -> int:
        """Create a run and save its ``workflow_started`` checkpoint.

        Args:
            workflow_name: Name passed to the workflow repository.
            start_at: Value stored in the checkpoint's node position.
            snapshot: Snapshot to sanitize and persist.

        Returns:
            The run id returned by the workflow repository's ``create_run``.
        """
        sanitized = self._snapshot_sanitizer.sanitize(snapshot)
        run_id = self._workflow_repository.create_run(workflow_name, sanitized)
        self._checkpoint_repository.save(run_id, start_at, "workflow_started", sanitized)
        return run_id

    def start_step(self, run_id: int, node_name: str, snapshot: dict[str, object]) -> None:
        """Save a ``step_started`` checkpoint for ``node_name``.

        Only the checkpoint repository is written to here.
        """
        sanitized = self._snapshot_sanitizer.sanitize(snapshot)
        self._checkpoint_repository.save(run_id, node_name, "step_started", sanitized)

    def complete_step(
        self,
        run_id: int,
        node_name: str,
        status: str,
        transition: str,
        snapshot: dict[str, object],
    ) -> None:
        """Record a finished step, then save a ``step_completed`` checkpoint.

        Args:
            run_id: Run the step belongs to.
            node_name: Name of the step's node.
            status: Status forwarded to the workflow repository.
            transition: Transition forwarded to the workflow repository.
            snapshot: Snapshot to sanitize and persist.
        """
        sanitized = self._snapshot_sanitizer.sanitize(snapshot)
        self._workflow_repository.record_step(run_id, node_name, status, transition, sanitized)
        self._checkpoint_repository.save(run_id, node_name, "step_completed", sanitized)

    def fail_step(
        self,
        run_id: int,
        node_name: str,
        snapshot: dict[str, object],
        error: Exception,
    ) -> None:
        """Record a failed step, then save a ``step_failed`` checkpoint.

        ``error`` is forwarded to the workflow repository only; it is not
        passed to the checkpoint repository.
        """
        sanitized = self._snapshot_sanitizer.sanitize(snapshot)
        self._workflow_repository.fail_step(run_id, node_name, sanitized, error)
        self._checkpoint_repository.save(run_id, node_name, "step_failed", sanitized)

    def complete_run(self, run_id: int, snapshot: dict[str, object]) -> None:
        """Mark the run complete, then save a ``workflow_finished`` checkpoint.

        The checkpoint is saved under the fixed node name ``workflow_done``.
        """
        sanitized = self._snapshot_sanitizer.sanitize(snapshot)
        self._workflow_repository.complete_run(run_id, sanitized)
        self._checkpoint_repository.save(run_id, "workflow_done", "workflow_finished", sanitized)

    def fail_run(self, run_id: int, snapshot: dict[str, object]) -> None:
        """Mark the run as failed in the workflow repository."""
        # NOTE(review): unlike complete_run and fail_step, no checkpoint is
        # saved here — confirm this asymmetry is intentional.
        sanitized = self._snapshot_sanitizer.sanitize(snapshot)
        self._workflow_repository.fail_run(run_id, sanitized)