Перенес workflow
This commit is contained in:
Binary file not shown.
@@ -13,4 +13,4 @@ class ApplicationModule(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def register(self, registry: ModuleRegistry) -> None:
|
||||
"""Register workers, queues, handlers, services, and health contributors."""
|
||||
"""Register workers, services, and health contributors."""
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from app_runtime.contracts.tasks import Task
|
||||
|
||||
|
||||
class TaskQueue(ABC):
|
||||
@abstractmethod
|
||||
def publish(self, task: Task) -> None:
|
||||
"""Push a task into the queue."""
|
||||
|
||||
@abstractmethod
|
||||
def consume(self, timeout: float = 0.1) -> Task | None:
|
||||
"""Return the next available task or None."""
|
||||
|
||||
@abstractmethod
|
||||
def ack(self, task: Task) -> None:
|
||||
"""Confirm successful task processing."""
|
||||
|
||||
@abstractmethod
|
||||
def nack(self, task: Task, retry_delay: float | None = None) -> None:
|
||||
"""Signal failed task processing."""
|
||||
|
||||
@abstractmethod
|
||||
def stats(self) -> dict[str, Any]:
|
||||
"""Return transport-level queue statistics."""
|
||||
@@ -1,18 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Task:
|
||||
name: str
|
||||
payload: dict[str, Any]
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class TaskHandler(ABC):
|
||||
@abstractmethod
|
||||
def handle(self, task: Task) -> None:
|
||||
"""Execute domain logic for a task."""
|
||||
Binary file not shown.
@@ -1,8 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app_runtime.contracts.health import HealthContributor
|
||||
from app_runtime.contracts.queue import TaskQueue
|
||||
from app_runtime.contracts.tasks import TaskHandler
|
||||
from app_runtime.contracts.worker import Worker
|
||||
from app_runtime.core.service_container import ServiceContainer
|
||||
|
||||
@@ -10,8 +8,6 @@ from app_runtime.core.service_container import ServiceContainer
|
||||
class ModuleRegistry:
|
||||
def __init__(self, services: ServiceContainer) -> None:
|
||||
self.services = services
|
||||
self.queues: dict[str, TaskQueue] = {}
|
||||
self.handlers: dict[str, TaskHandler] = {}
|
||||
self.workers: list[Worker] = []
|
||||
self.health_contributors: list[HealthContributor] = []
|
||||
self.modules: list[str] = []
|
||||
@@ -19,12 +15,6 @@ class ModuleRegistry:
|
||||
def register_module(self, name: str) -> None:
|
||||
self.modules.append(name)
|
||||
|
||||
def add_queue(self, name: str, queue: TaskQueue) -> None:
|
||||
self.queues[name] = queue
|
||||
|
||||
def add_handler(self, name: str, handler: TaskHandler) -> None:
|
||||
self.handlers[name] = handler
|
||||
|
||||
def add_worker(self, worker: Worker) -> None:
|
||||
self.workers.append(worker)
|
||||
|
||||
|
||||
Binary file not shown.
@@ -1,43 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from queue import Empty, Queue
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from app_runtime.contracts.queue import TaskQueue
|
||||
from app_runtime.contracts.tasks import Task
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class InMemoryTaskQueue(TaskQueue):
|
||||
class InMemoryTaskQueue(Generic[T]):
|
||||
def __init__(self) -> None:
|
||||
self._queue: Queue[Task] = Queue()
|
||||
self._published = 0
|
||||
self._acked = 0
|
||||
self._nacked = 0
|
||||
self._queue: Queue[T] = Queue()
|
||||
self._put_count = 0
|
||||
self._get_count = 0
|
||||
|
||||
def publish(self, task: Task) -> None:
|
||||
self._published += 1
|
||||
self._queue.put(task)
|
||||
def put(self, item: T) -> None:
|
||||
self._put_count += 1
|
||||
self._queue.put(item)
|
||||
|
||||
def consume(self, timeout: float = 0.1) -> Task | None:
|
||||
def get(self, timeout: float = 0.1) -> T | None:
|
||||
try:
|
||||
return self._queue.get(timeout=timeout)
|
||||
item = self._queue.get(timeout=timeout)
|
||||
except Empty:
|
||||
return None
|
||||
self._get_count += 1
|
||||
return item
|
||||
|
||||
def ack(self, task: Task) -> None:
|
||||
del task
|
||||
self._acked += 1
|
||||
def task_done(self) -> None:
|
||||
self._queue.task_done()
|
||||
|
||||
def nack(self, task: Task, retry_delay: float | None = None) -> None:
|
||||
del retry_delay
|
||||
self._nacked += 1
|
||||
self._queue.put(task)
|
||||
self._queue.task_done()
|
||||
def qsize(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
def stats(self) -> dict[str, int]:
|
||||
return {
|
||||
"published": self._published,
|
||||
"acked": self._acked,
|
||||
"nacked": self._nacked,
|
||||
"put": self._put_count,
|
||||
"got": self._get_count,
|
||||
"queued": self._queue.qsize(),
|
||||
}
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -99,7 +99,7 @@ class TraceService(TraceContextFactory):
|
||||
self._write_message("ERROR", message, status, attrs)
|
||||
|
||||
def new_root(self, operation: str) -> TraceContext:
|
||||
trace_id = self.create_context(alias=operation, kind="source", attrs={"operation": operation})
|
||||
trace_id = self.create_context(alias=operation, kind="operation", attrs={"operation": operation})
|
||||
return TraceContext(trace_id=trace_id, span_id=trace_id, attributes={"operation": operation})
|
||||
|
||||
def child_of(self, parent: TraceContext, operation: str) -> TraceContext:
|
||||
@@ -116,22 +116,22 @@ class TraceService(TraceContextFactory):
|
||||
attributes={"operation": operation},
|
||||
)
|
||||
|
||||
def attach(self, task_metadata: dict[str, object], context: TraceContext) -> dict[str, object]:
|
||||
updated = dict(task_metadata)
|
||||
def attach(self, metadata: dict[str, object], context: TraceContext) -> dict[str, object]:
|
||||
updated = dict(metadata)
|
||||
updated["trace_id"] = context.trace_id
|
||||
updated["span_id"] = context.span_id
|
||||
updated["parent_span_id"] = context.parent_span_id
|
||||
return updated
|
||||
|
||||
def resume(self, task_metadata: dict[str, object], operation: str) -> TraceContext:
|
||||
trace_id = str(task_metadata.get("trace_id") or uuid4().hex)
|
||||
span_id = str(task_metadata.get("span_id") or trace_id)
|
||||
parent_id = task_metadata.get("parent_span_id")
|
||||
def resume(self, metadata: dict[str, object], operation: str) -> TraceContext:
|
||||
trace_id = str(metadata.get("trace_id") or uuid4().hex)
|
||||
span_id = str(metadata.get("span_id") or trace_id)
|
||||
parent_id = metadata.get("parent_span_id")
|
||||
self.create_context(
|
||||
alias=operation,
|
||||
parent_id=str(parent_id) if parent_id else None,
|
||||
kind="handler",
|
||||
attrs=dict(task_metadata),
|
||||
kind="worker",
|
||||
attrs=dict(metadata),
|
||||
)
|
||||
return TraceContext(
|
||||
trace_id=trace_id,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from app_runtime.workers.queue_worker import QueueWorker
|
||||
from app_runtime.workers.supervisor import WorkerSupervisor
|
||||
|
||||
__all__ = ["QueueWorker", "WorkerSupervisor"]
|
||||
__all__ = ["WorkerSupervisor"]
|
||||
|
||||
Binary file not shown.
@@ -1,125 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from threading import Event, Lock, Thread
|
||||
|
||||
from app_runtime.contracts.queue import TaskQueue
|
||||
from app_runtime.contracts.tasks import TaskHandler
|
||||
from app_runtime.contracts.worker import Worker, WorkerHealth, WorkerStatus
|
||||
from app_runtime.tracing.service import TraceService
|
||||
|
||||
|
||||
class QueueWorker(Worker):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
queue: TaskQueue,
|
||||
handler: TaskHandler,
|
||||
traces: TraceService,
|
||||
*,
|
||||
concurrency: int = 1,
|
||||
critical: bool = True,
|
||||
) -> None:
|
||||
self._name = name
|
||||
self._queue = queue
|
||||
self._handler = handler
|
||||
self._traces = traces
|
||||
self._concurrency = concurrency
|
||||
self._critical = critical
|
||||
self._threads: list[Thread] = []
|
||||
self._stop_requested = Event()
|
||||
self._force_stop = Event()
|
||||
self._lock = Lock()
|
||||
self._started = False
|
||||
self._in_flight = 0
|
||||
self._processed = 0
|
||||
self._failures = 0
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def critical(self) -> bool:
|
||||
return self._critical
|
||||
|
||||
def start(self) -> None:
|
||||
if any(thread.is_alive() for thread in self._threads):
|
||||
return
|
||||
self._threads.clear()
|
||||
self._stop_requested.clear()
|
||||
self._force_stop.clear()
|
||||
self._started = True
|
||||
for index in range(self._concurrency):
|
||||
thread = Thread(target=self._run_loop, name=f"{self._name}-{index + 1}", daemon=True)
|
||||
self._threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
def stop(self, force: bool = False) -> None:
|
||||
self._stop_requested.set()
|
||||
if force:
|
||||
self._force_stop.set()
|
||||
|
||||
def health(self) -> WorkerHealth:
|
||||
status = self.status()
|
||||
if self._started and not self._stop_requested.is_set() and self._alive_threads() == 0:
|
||||
return WorkerHealth(self.name, "unhealthy", self.critical, "worker threads are not running", status.meta)
|
||||
if self._failures > 0:
|
||||
return WorkerHealth(self.name, "degraded", self.critical, "worker has processing failures", status.meta)
|
||||
return WorkerHealth(self.name, "ok", self.critical, meta=status.meta)
|
||||
|
||||
def status(self) -> WorkerStatus:
|
||||
alive_threads = self._alive_threads()
|
||||
with self._lock:
|
||||
in_flight = self._in_flight
|
||||
processed = self._processed
|
||||
failures = self._failures
|
||||
if self._started and alive_threads == 0:
|
||||
state = "stopped"
|
||||
elif self._stop_requested.is_set():
|
||||
state = "stopping" if alive_threads > 0 else "stopped"
|
||||
elif not self._started:
|
||||
state = "stopped"
|
||||
elif in_flight > 0:
|
||||
state = "busy"
|
||||
else:
|
||||
state = "idle"
|
||||
return WorkerStatus(
|
||||
name=self.name,
|
||||
state=state,
|
||||
in_flight=in_flight,
|
||||
meta={
|
||||
"alive_threads": alive_threads,
|
||||
"concurrency": self._concurrency,
|
||||
"processed": processed,
|
||||
"failures": failures,
|
||||
},
|
||||
)
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
while True:
|
||||
if self._force_stop.is_set() or self._stop_requested.is_set():
|
||||
return
|
||||
task = self._queue.consume(timeout=0.1)
|
||||
if task is None:
|
||||
continue
|
||||
with self._lock:
|
||||
self._in_flight += 1
|
||||
self._traces.resume(task.metadata, f"worker:{self.name}")
|
||||
try:
|
||||
self._handler.handle(task)
|
||||
except Exception:
|
||||
with self._lock:
|
||||
self._failures += 1
|
||||
self._queue.nack(task)
|
||||
else:
|
||||
with self._lock:
|
||||
self._processed += 1
|
||||
self._queue.ack(task)
|
||||
finally:
|
||||
with self._lock:
|
||||
self._in_flight -= 1
|
||||
if self._stop_requested.is_set():
|
||||
return
|
||||
|
||||
def _alive_threads(self) -> int:
|
||||
return sum(1 for thread in self._threads if thread.is_alive())
|
||||
@@ -0,0 +1 @@
|
||||
__all__: list[str] = []
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1 @@
|
||||
__all__: list[str] = []
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class WorkflowContext:
|
||||
payload: dict[str, Any]
|
||||
state: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def snapshot(self) -> dict[str, Any]:
|
||||
return {
|
||||
"payload": dict(self.payload),
|
||||
"state": dict(self.state),
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class StepResult:
|
||||
transition: str = "success"
|
||||
updates: dict[str, Any] = field(default_factory=dict)
|
||||
status: str = "completed"
|
||||
@@ -0,0 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from app_runtime.workflow.contracts.context import WorkflowContext
|
||||
from app_runtime.workflow.contracts.result import StepResult
|
||||
|
||||
|
||||
class WorkflowStep(ABC):
|
||||
@abstractmethod
|
||||
def run(self, context: WorkflowContext) -> StepResult:
|
||||
"""Run the step and return transition metadata."""
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from app_runtime.workflow.contracts.step import WorkflowStep
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class WorkflowNode:
|
||||
name: str
|
||||
step: WorkflowStep
|
||||
transitions: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class WorkflowDefinition:
|
||||
name: str
|
||||
start_at: str
|
||||
nodes: dict[str, WorkflowNode]
|
||||
@@ -0,0 +1 @@
|
||||
__all__: list[str] = []
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app_runtime.workflow.contracts.context import WorkflowContext
|
||||
|
||||
|
||||
class WorkflowEngineHooks:
|
||||
def on_step_started(self, context: WorkflowContext, step: str) -> None:
|
||||
del context, step
|
||||
|
||||
def on_step_finished(self, context: WorkflowContext, step: str) -> None:
|
||||
del context, step
|
||||
|
||||
def on_step_failed(self, context: WorkflowContext, step: str) -> None:
|
||||
del context, step
|
||||
@@ -0,0 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app_runtime.workflow.contracts.result import StepResult
|
||||
from app_runtime.workflow.contracts.workflow import WorkflowNode
|
||||
|
||||
|
||||
class TransitionResolver:
|
||||
def resolve(self, node: WorkflowNode, result: StepResult) -> str | None:
|
||||
return node.transitions.get(result.transition)
|
||||
@@ -0,0 +1,68 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from app_runtime.workflow.contracts.context import WorkflowContext
|
||||
from app_runtime.workflow.engine.hooks import WorkflowEngineHooks
|
||||
from app_runtime.workflow.engine.transition_resolver import TransitionResolver
|
||||
|
||||
|
||||
class WorkflowEngine:
|
||||
def __init__(self, workflow, persistence, *, traces, hooks: WorkflowEngineHooks | None = None) -> None:
|
||||
self._workflow = workflow
|
||||
self._persistence = persistence
|
||||
self._transition_resolver = TransitionResolver()
|
||||
self._traces = traces
|
||||
self._hooks = hooks or WorkflowEngineHooks()
|
||||
self._logger = logging.getLogger(__name__)
|
||||
|
||||
def run(self, context: WorkflowContext) -> dict[str, object]:
|
||||
run_id = self._persistence.start_run(
|
||||
self._workflow.definition.name,
|
||||
self._workflow.definition.start_at,
|
||||
context.snapshot(),
|
||||
)
|
||||
context.state.setdefault("runtime", {})
|
||||
context.state["runtime"]["workflow_run_id"] = run_id
|
||||
self._traces.step("workflow")
|
||||
self._traces.info("Workflow started.", status="started", attrs={"workflow_run_id": run_id})
|
||||
current_name = self._workflow.definition.start_at
|
||||
while current_name is not None:
|
||||
node = self._workflow.definition.nodes[current_name]
|
||||
self._logger.info("Workflow run %s: step '%s' started.", run_id, node.name)
|
||||
self._hooks.on_step_started(context, node.name)
|
||||
self._persistence.start_step(run_id, node.name, context.snapshot())
|
||||
self._traces.step(node.name)
|
||||
self._traces.info(f"Step '{node.name}' started.", status="started")
|
||||
try:
|
||||
result = node.step.run(context)
|
||||
except Exception as error:
|
||||
self._persistence.fail_step(run_id, node.name, context.snapshot(), error)
|
||||
self._persistence.fail_run(run_id, context.snapshot())
|
||||
self._traces.error(
|
||||
f"Step '{node.name}' failed: {error}",
|
||||
status="failed",
|
||||
attrs={"exception_type": type(error).__name__},
|
||||
)
|
||||
self._logger.exception("Workflow run %s: step '%s' failed.", run_id, node.name)
|
||||
self._hooks.on_step_failed(context, node.name)
|
||||
raise
|
||||
context.state.update(result.updates)
|
||||
self._persistence.complete_step(run_id, node.name, result.status, result.transition, context.snapshot())
|
||||
self._traces.info(
|
||||
f"Step '{node.name}' completed with transition '{result.transition}'.",
|
||||
status=result.status,
|
||||
)
|
||||
self._logger.info(
|
||||
"Workflow run %s: step '%s' completed with transition '%s'.",
|
||||
run_id,
|
||||
node.name,
|
||||
result.transition,
|
||||
)
|
||||
self._hooks.on_step_finished(context, node.name)
|
||||
current_name = self._transition_resolver.resolve(node, result)
|
||||
self._persistence.complete_run(run_id, context.snapshot())
|
||||
self._traces.step("workflow")
|
||||
self._traces.info("Workflow completed.", status="completed", attrs={"workflow_run_id": run_id})
|
||||
self._logger.info("Workflow run %s completed.", run_id)
|
||||
return {"run_id": run_id, "status": "completed", "context": context.snapshot()}
|
||||
@@ -0,0 +1,3 @@
|
||||
from app_runtime.workflow.persistence.workflow_persistence import WorkflowPersistence
|
||||
|
||||
__all__ = ["WorkflowPersistence"]
|
||||
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class CheckpointRepository:
|
||||
def __init__(self, connection_factory: object | None = None) -> None:
|
||||
self._connection_factory = connection_factory
|
||||
self._checkpoints: list[dict[str, Any]] = []
|
||||
|
||||
def save(
|
||||
self,
|
||||
workflow_run_id: int,
|
||||
node_name: str,
|
||||
checkpoint_kind: str,
|
||||
snapshot: dict[str, Any],
|
||||
) -> None:
|
||||
if self._use_memory():
|
||||
self._checkpoints.append(
|
||||
{
|
||||
"workflow_run_id": workflow_run_id,
|
||||
"node_name": node_name,
|
||||
"checkpoint_kind": checkpoint_kind,
|
||||
"snapshot": snapshot,
|
||||
}
|
||||
)
|
||||
return
|
||||
query = """
|
||||
INSERT INTO workflow_checkpoints (
|
||||
workflow_run_id, node_name, checkpoint_kind, snapshot_json, created_at
|
||||
) VALUES (%s, %s, %s, %s, UTC_TIMESTAMP(6))
|
||||
"""
|
||||
with self._connection_factory.connect() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
query,
|
||||
(
|
||||
workflow_run_id,
|
||||
node_name,
|
||||
checkpoint_kind,
|
||||
self._connection_factory.dumps(snapshot),
|
||||
),
|
||||
)
|
||||
|
||||
def _use_memory(self) -> bool:
|
||||
return self._connection_factory is None or not self._connection_factory.is_configured()
|
||||
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class WorkflowSnapshotSanitizer:
|
||||
def sanitize(self, snapshot: dict[str, Any]) -> dict[str, Any]:
|
||||
payload = dict(snapshot.get("payload", {}))
|
||||
state = dict(snapshot.get("state", {}))
|
||||
return {
|
||||
"payload": self._sanitize_dict(payload),
|
||||
"state": self._sanitize_dict(state),
|
||||
}
|
||||
|
||||
def _sanitize_dict(self, value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return {str(key): self._sanitize_dict(item) for key, item in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [self._sanitize_dict(item) for item in value]
|
||||
return value
|
||||
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app_runtime.workflow.persistence.checkpoint_repository import CheckpointRepository
|
||||
from app_runtime.workflow.persistence.snapshot_sanitizer import WorkflowSnapshotSanitizer
|
||||
from app_runtime.workflow.persistence.workflow_repository import WorkflowRepository
|
||||
|
||||
|
||||
class WorkflowPersistence:
|
||||
def __init__(self, workflow_repository, checkpoint_repository) -> None:
|
||||
self._workflow_repository = workflow_repository
|
||||
self._checkpoint_repository = checkpoint_repository
|
||||
self._snapshot_sanitizer = WorkflowSnapshotSanitizer()
|
||||
|
||||
@classmethod
|
||||
def create_default(cls, connection_factory=None) -> "WorkflowPersistence":
|
||||
return cls(
|
||||
workflow_repository=WorkflowRepository(connection_factory),
|
||||
checkpoint_repository=CheckpointRepository(connection_factory),
|
||||
)
|
||||
|
||||
def start_run(self, workflow_name: str, start_at: str, snapshot: dict[str, object]) -> int:
|
||||
sanitized = self._snapshot_sanitizer.sanitize(snapshot)
|
||||
run_id = self._workflow_repository.create_run(workflow_name, sanitized)
|
||||
self._checkpoint_repository.save(run_id, start_at, "workflow_started", sanitized)
|
||||
return run_id
|
||||
|
||||
def start_step(self, run_id: int, node_name: str, snapshot: dict[str, object]) -> None:
|
||||
self._checkpoint_repository.save(run_id, node_name, "step_started", self._snapshot_sanitizer.sanitize(snapshot))
|
||||
|
||||
def complete_step(
|
||||
self,
|
||||
run_id: int,
|
||||
node_name: str,
|
||||
status: str,
|
||||
transition: str,
|
||||
snapshot: dict[str, object],
|
||||
) -> None:
|
||||
sanitized = self._snapshot_sanitizer.sanitize(snapshot)
|
||||
self._workflow_repository.record_step(run_id, node_name, status, transition, sanitized)
|
||||
self._checkpoint_repository.save(run_id, node_name, "step_completed", sanitized)
|
||||
|
||||
def fail_step(self, run_id: int, node_name: str, snapshot: dict[str, object], error: Exception) -> None:
|
||||
sanitized = self._snapshot_sanitizer.sanitize(snapshot)
|
||||
self._workflow_repository.fail_step(run_id, node_name, sanitized, error)
|
||||
self._checkpoint_repository.save(run_id, node_name, "step_failed", sanitized)
|
||||
|
||||
def complete_run(self, run_id: int, snapshot: dict[str, object]) -> None:
|
||||
sanitized = self._snapshot_sanitizer.sanitize(snapshot)
|
||||
self._workflow_repository.complete_run(run_id, sanitized)
|
||||
self._checkpoint_repository.save(run_id, "workflow_done", "workflow_finished", sanitized)
|
||||
|
||||
def fail_run(self, run_id: int, snapshot: dict[str, object]) -> None:
|
||||
self._workflow_repository.fail_run(run_id, self._snapshot_sanitizer.sanitize(snapshot))
|
||||
@@ -0,0 +1,140 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from itertools import count
|
||||
from typing import Any
|
||||
|
||||
|
||||
class WorkflowRepository:
|
||||
def __init__(self, connection_factory: object | None = None) -> None:
|
||||
self._connection_factory = connection_factory
|
||||
self._counter = count(1)
|
||||
self._runs: dict[int, dict[str, Any]] = {}
|
||||
|
||||
def create_run(self, workflow_name: str, snapshot: dict[str, Any]) -> int:
|
||||
if self._use_memory():
|
||||
run_id = next(self._counter)
|
||||
self._runs[run_id] = {
|
||||
"workflow_name": workflow_name,
|
||||
"status": "running",
|
||||
"snapshot": snapshot,
|
||||
"steps": [],
|
||||
}
|
||||
return run_id
|
||||
payload = self._build_run_payload(workflow_name, snapshot)
|
||||
query = """
|
||||
INSERT INTO workflow_runs (
|
||||
workflow_name, workflow_version, business_key, queue_task_id, inbox_message_id,
|
||||
current_node, status, context_json, trace_id, started_at, created_at, updated_at
|
||||
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, UTC_TIMESTAMP(6), UTC_TIMESTAMP(6), UTC_TIMESTAMP(6))
|
||||
"""
|
||||
with self._connection_factory.connect() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(query, payload)
|
||||
return int(cursor.lastrowid)
|
||||
|
||||
def record_step(
|
||||
self,
|
||||
run_id: int,
|
||||
node_name: str,
|
||||
status: str,
|
||||
transition: str,
|
||||
snapshot: dict[str, Any],
|
||||
) -> None:
|
||||
if self._use_memory():
|
||||
self._runs[run_id]["steps"].append(
|
||||
{"node_name": node_name, "status": status, "transition": transition, "snapshot": snapshot}
|
||||
)
|
||||
return
|
||||
context_json = self._connection_factory.dumps(snapshot)
|
||||
insert_query = """
|
||||
INSERT INTO workflow_steps (
|
||||
workflow_run_id, node_name, status, transition_name, input_json, output_json, created_at, started_at, finished_at
|
||||
) VALUES (%s, %s, %s, %s, %s, %s, UTC_TIMESTAMP(6), UTC_TIMESTAMP(6), UTC_TIMESTAMP(6))
|
||||
"""
|
||||
update_query = """
|
||||
UPDATE workflow_runs
|
||||
SET current_node = %s, status = %s, context_json = %s, updated_at = UTC_TIMESTAMP(6)
|
||||
WHERE id = %s
|
||||
"""
|
||||
with self._connection_factory.connect() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(insert_query, (run_id, node_name, status, transition, context_json, context_json))
|
||||
cursor.execute(update_query, (node_name, "running", context_json, run_id))
|
||||
|
||||
def complete_run(self, run_id: int, snapshot: dict[str, Any]) -> None:
|
||||
if self._use_memory():
|
||||
self._runs[run_id]["status"] = "completed"
|
||||
self._runs[run_id]["snapshot"] = snapshot
|
||||
return
|
||||
query = """
|
||||
UPDATE workflow_runs
|
||||
SET current_node = NULL, status = 'completed', context_json = %s, finished_at = UTC_TIMESTAMP(6), updated_at = UTC_TIMESTAMP(6)
|
||||
WHERE id = %s
|
||||
"""
|
||||
with self._connection_factory.connect() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(query, (self._connection_factory.dumps(snapshot), run_id))
|
||||
|
||||
def fail_step(self, run_id: int, node_name: str, snapshot: dict[str, Any], error: Exception) -> None:
|
||||
if self._use_memory():
|
||||
self._runs[run_id]["steps"].append(
|
||||
{
|
||||
"node_name": node_name,
|
||||
"status": "failed",
|
||||
"transition": "error",
|
||||
"snapshot": snapshot,
|
||||
"error": str(error),
|
||||
}
|
||||
)
|
||||
self._runs[run_id]["status"] = "failed"
|
||||
return
|
||||
snapshot_json = self._connection_factory.dumps(snapshot)
|
||||
error_json = self._connection_factory.dumps({"message": str(error), "exception_type": type(error).__name__})
|
||||
insert_query = """
|
||||
INSERT INTO workflow_steps (
|
||||
workflow_run_id, node_name, status, transition_name, input_json, output_json, error_json, created_at, started_at, finished_at
|
||||
) VALUES (%s, %s, 'failed', 'error', %s, %s, %s, UTC_TIMESTAMP(6), UTC_TIMESTAMP(6), UTC_TIMESTAMP(6))
|
||||
"""
|
||||
update_query = """
|
||||
UPDATE workflow_runs
|
||||
SET current_node = %s, status = 'failed', context_json = %s, updated_at = UTC_TIMESTAMP(6)
|
||||
WHERE id = %s
|
||||
"""
|
||||
with self._connection_factory.connect() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(insert_query, (run_id, node_name, snapshot_json, snapshot_json, error_json))
|
||||
cursor.execute(update_query, (node_name, snapshot_json, run_id))
|
||||
|
||||
def fail_run(self, run_id: int, snapshot: dict[str, Any]) -> None:
|
||||
if self._use_memory():
|
||||
self._runs[run_id]["status"] = "failed"
|
||||
self._runs[run_id]["snapshot"] = snapshot
|
||||
return
|
||||
query = """
|
||||
UPDATE workflow_runs
|
||||
SET status = 'failed', context_json = %s, finished_at = UTC_TIMESTAMP(6), updated_at = UTC_TIMESTAMP(6)
|
||||
WHERE id = %s
|
||||
"""
|
||||
with self._connection_factory.connect() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(query, (self._connection_factory.dumps(snapshot), run_id))
|
||||
|
||||
def _build_run_payload(self, workflow_name: str, snapshot: dict[str, Any]) -> tuple:
|
||||
payload = snapshot.get("payload", {})
|
||||
state = snapshot.get("state", {})
|
||||
runtime = state.get("runtime", {})
|
||||
business_key = payload.get("inbox_message", {}).get("external_message_id") or str(next(self._counter))
|
||||
return (
|
||||
workflow_name,
|
||||
"v1",
|
||||
business_key,
|
||||
runtime.get("queue_task_id"),
|
||||
payload.get("inbox_message", {}).get("id"),
|
||||
None,
|
||||
"running",
|
||||
self._connection_factory.dumps(snapshot),
|
||||
runtime.get("email_trace_id"),
|
||||
)
|
||||
|
||||
def _use_memory(self) -> bool:
|
||||
return self._connection_factory is None or not self._connection_factory.is_configured()
|
||||
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app_runtime.workflow.engine.workflow_engine import WorkflowEngine
|
||||
from app_runtime.workflow.persistence import WorkflowPersistence
|
||||
|
||||
|
||||
class WorkflowRuntimeFactory:
|
||||
def __init__(self, connection_factory=None, *, traces, hooks=None) -> None:
|
||||
self._connection_factory = connection_factory
|
||||
self._traces = traces
|
||||
self._hooks = hooks
|
||||
|
||||
def create_engine(self, workflow) -> WorkflowEngine:
|
||||
persistence = WorkflowPersistence.create_default(self._connection_factory)
|
||||
return WorkflowEngine(workflow, persistence, traces=self._traces, hooks=self._hooks)
|
||||
+19
-8
@@ -5,9 +5,6 @@ from plba.contracts import (
|
||||
ApplicationModule,
|
||||
ConfigProvider,
|
||||
HealthContributor,
|
||||
Task,
|
||||
TaskHandler,
|
||||
TaskQueue,
|
||||
TraceContext,
|
||||
TraceContextRecord,
|
||||
TraceLogMessage,
|
||||
@@ -21,7 +18,17 @@ from plba.health import HealthRegistry
|
||||
from plba.logging import LogManager
|
||||
from plba.queue import InMemoryTaskQueue
|
||||
from plba.tracing import MySqlTraceTransport, NoOpTraceTransport, TraceService
|
||||
from plba.workers import QueueWorker, WorkerSupervisor
|
||||
from plba.workflow import (
|
||||
StepResult,
|
||||
WorkflowContext,
|
||||
WorkflowDefinition,
|
||||
WorkflowEngine,
|
||||
WorkflowEngineHooks,
|
||||
WorkflowNode,
|
||||
WorkflowRuntimeFactory,
|
||||
WorkflowStep,
|
||||
)
|
||||
from plba.workers import WorkerSupervisor
|
||||
|
||||
__all__ = [
|
||||
"ApplicationModule",
|
||||
@@ -40,19 +47,23 @@ __all__ = [
|
||||
"LogManager",
|
||||
"MySqlTraceTransport",
|
||||
"NoOpTraceTransport",
|
||||
"QueueWorker",
|
||||
"RuntimeManager",
|
||||
"ServiceContainer",
|
||||
"Task",
|
||||
"TaskHandler",
|
||||
"TaskQueue",
|
||||
"TraceContext",
|
||||
"TraceContextRecord",
|
||||
"TraceLogMessage",
|
||||
"TraceService",
|
||||
"TraceTransport",
|
||||
"StepResult",
|
||||
"Worker",
|
||||
"WorkerHealth",
|
||||
"WorkerStatus",
|
||||
"WorkflowContext",
|
||||
"WorkflowDefinition",
|
||||
"WorkflowEngine",
|
||||
"WorkflowEngineHooks",
|
||||
"WorkflowNode",
|
||||
"WorkflowRuntimeFactory",
|
||||
"WorkflowStep",
|
||||
"WorkerSupervisor",
|
||||
]
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,8 +1,6 @@
|
||||
from app_runtime.contracts.application import ApplicationModule
|
||||
from app_runtime.contracts.config import ConfigProvider
|
||||
from app_runtime.contracts.health import HealthContributor
|
||||
from app_runtime.contracts.queue import TaskQueue
|
||||
from app_runtime.contracts.tasks import Task, TaskHandler
|
||||
from app_runtime.contracts.trace import (
|
||||
TraceContext,
|
||||
TraceContextRecord,
|
||||
@@ -15,9 +13,6 @@ __all__ = [
|
||||
"ApplicationModule",
|
||||
"ConfigProvider",
|
||||
"HealthContributor",
|
||||
"Task",
|
||||
"TaskHandler",
|
||||
"TaskQueue",
|
||||
"TraceContext",
|
||||
"TraceContextRecord",
|
||||
"TraceLogMessage",
|
||||
|
||||
+1
-2
@@ -1,4 +1,3 @@
|
||||
from app_runtime.workers.queue_worker import QueueWorker
|
||||
from app_runtime.workers.supervisor import WorkerSupervisor
|
||||
|
||||
__all__ = ["QueueWorker", "WorkerSupervisor"]
|
||||
__all__ = ["WorkerSupervisor"]
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
from app_runtime.workflow.contracts.context import WorkflowContext
|
||||
from app_runtime.workflow.contracts.result import StepResult
|
||||
from app_runtime.workflow.contracts.step import WorkflowStep
|
||||
from app_runtime.workflow.contracts.workflow import WorkflowDefinition, WorkflowNode
|
||||
from app_runtime.workflow.engine.hooks import WorkflowEngineHooks
|
||||
from app_runtime.workflow.engine.workflow_engine import WorkflowEngine
|
||||
from app_runtime.workflow.runtime_factory import WorkflowRuntimeFactory
|
||||
|
||||
__all__ = [
|
||||
"StepResult",
|
||||
"WorkflowContext",
|
||||
"WorkflowDefinition",
|
||||
"WorkflowEngine",
|
||||
"WorkflowEngineHooks",
|
||||
"WorkflowNode",
|
||||
"WorkflowRuntimeFactory",
|
||||
"WorkflowStep",
|
||||
]
|
||||
Reference in New Issue
Block a user