Files
agent/app/modules/agent/repository.py

287 lines
12 KiB
Python

from __future__ import annotations
import json
from sqlalchemy import text
from app.modules.agent.engine.router.schemas import RouterContext
from app.modules.shared.db import get_engine
class AgentRepository:
    """Persistence for agent router state and answer-quality metrics.

    Each public method opens its own connection via ``get_engine()`` and,
    when it writes, commits explicitly before the connection is closed.
    The SQL is PostgreSQL-specific (``TIMESTAMPTZ``, ``BIGSERIAL``,
    ``JSONB``, ``ON CONFLICT``, ``ADD COLUMN IF NOT EXISTS``).
    """

    # These ALTERs are no-ops when the columns already exist (the CREATE TABLE
    # below includes them). They only take effect on router_context tables
    # created earlier without these columns — presumably an older schema.
    _ROUTER_CONTEXT_COLUMN_MIGRATIONS = (
        "ALTER TABLE router_context ADD COLUMN IF NOT EXISTS active_domain_id VARCHAR(64) NULL",
        "ALTER TABLE router_context ADD COLUMN IF NOT EXISTS active_process_id VARCHAR(64) NULL",
        "ALTER TABLE router_context ADD COLUMN IF NOT EXISTS dialog_started BOOLEAN NOT NULL DEFAULT FALSE",
        "ALTER TABLE router_context ADD COLUMN IF NOT EXISTS turn_index INTEGER NOT NULL DEFAULT 0",
    )

    def ensure_tables(self) -> None:
        """Create the router-context and quality-metrics schema if missing.

        Creates ``router_context``, ``agent_quality_metrics`` and the two
        metrics indexes, then applies the router_context column migrations.
        Every statement is guarded with ``IF NOT EXISTS``, so repeated calls
        are harmless. All statements are committed together at the end.
        """
        with get_engine().connect() as conn:
            conn.execute(
                text(
                    """
                    CREATE TABLE IF NOT EXISTS router_context (
                        conversation_key VARCHAR(64) PRIMARY KEY,
                        last_domain_id VARCHAR(64) NULL,
                        last_process_id VARCHAR(64) NULL,
                        active_domain_id VARCHAR(64) NULL,
                        active_process_id VARCHAR(64) NULL,
                        dialog_started BOOLEAN NOT NULL DEFAULT FALSE,
                        turn_index INTEGER NOT NULL DEFAULT 0,
                        message_history_json TEXT NOT NULL DEFAULT '[]',
                        updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
                    )
                    """
                )
            )
            conn.execute(
                text(
                    """
                    CREATE TABLE IF NOT EXISTS agent_quality_metrics (
                        id BIGSERIAL PRIMARY KEY,
                        task_id VARCHAR(64) NOT NULL,
                        dialog_session_id VARCHAR(64) NOT NULL,
                        rag_session_id VARCHAR(64) NOT NULL,
                        scenario VARCHAR(64) NOT NULL,
                        domain_id VARCHAR(64) NOT NULL,
                        process_id VARCHAR(64) NOT NULL,
                        faithfulness_score DOUBLE PRECISION NOT NULL,
                        coverage_score DOUBLE PRECISION NOT NULL,
                        faithfulness_claims_total INTEGER NOT NULL,
                        faithfulness_claims_supported INTEGER NOT NULL,
                        coverage_required_items INTEGER NOT NULL,
                        coverage_covered_items INTEGER NOT NULL,
                        quality_status VARCHAR(32) NOT NULL,
                        metrics_json JSONB NOT NULL,
                        created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
                    )
                    """
                )
            )
            # Indexes support the "latest first" reads by task and by scenario.
            conn.execute(
                text(
                    """
                    CREATE INDEX IF NOT EXISTS idx_agent_quality_metrics_task
                    ON agent_quality_metrics(task_id, created_at DESC)
                    """
                )
            )
            conn.execute(
                text(
                    """
                    CREATE INDEX IF NOT EXISTS idx_agent_quality_metrics_scenario
                    ON agent_quality_metrics(scenario, created_at DESC)
                    """
                )
            )
            self._ensure_router_context_columns(conn)
            conn.commit()

    def _ensure_router_context_columns(self, conn) -> None:
        """Add any missing router_context columns on the given connection.

        Does not commit; the caller owns the transaction.
        """
        for statement in self._ROUTER_CONTEXT_COLUMN_MIGRATIONS:
            conn.execute(text(statement))

    def get_router_context(self, conversation_key: str) -> RouterContext:
        """Load the stored router state for ``conversation_key``.

        Returns a default ``RouterContext()`` when no row exists.

        Stored values are sanitised as follows:
        - The message history is decoded from JSON. Undecodable or non-list
          values yield an empty history.
        - Only dict entries with role ``"user"``/``"assistant"`` and non-empty
          content are kept.
        - ``active_intent`` falls back to the last routing when no active
          intent is recorded.
        """
        with get_engine().connect() as conn:
            row = conn.execute(
                text(
                    """
                    SELECT last_domain_id, last_process_id, active_domain_id, active_process_id,
                           dialog_started, turn_index, message_history_json
                    FROM router_context
                    WHERE conversation_key = :key
                    """
                ),
                {"key": conversation_key},
            ).fetchone()
        if not row:
            return RouterContext()

        (
            last_domain_id,
            last_process_id,
            active_domain_id,
            active_process_id,
            dialog_started,
            turn_index,
            history_raw,
        ) = row
        last = self._intent(last_domain_id, last_process_id)
        active = self._intent(active_domain_id, active_process_id)
        return RouterContext(
            last_routing=last,
            message_history=self._parse_history(history_raw),
            active_intent=active or last,
            dialog_started=bool(dialog_started),
            turn_index=int(turn_index or 0),
        )

    @staticmethod
    def _intent(domain_id: object, process_id: object) -> dict[str, str] | None:
        """Build a ``{"domain_id", "process_id"}`` dict, or None unless both are truthy."""
        if domain_id and process_id:
            return {"domain_id": str(domain_id), "process_id": str(process_id)}
        return None

    @staticmethod
    def _parse_history(raw: object) -> list[dict[str, str]]:
        """Decode stored history JSON, keeping only well-formed user/assistant turns."""
        try:
            decoded = json.loads(raw or "[]")
        except (ValueError, TypeError):
            # ValueError covers JSONDecodeError and UnicodeDecodeError. TypeError
            # covers a non-str/bytes value. Either way, corrupt stored history
            # degrades to empty instead of breaking routing for the conversation.
            return []
        if not isinstance(decoded, list):
            return []
        clean: list[dict[str, str]] = []
        for item in decoded:
            if not isinstance(item, dict):
                continue
            role = str(item.get("role") or "")
            content = str(item.get("content") or "")
            if role in {"user", "assistant"} and content:
                clean.append({"role": role, "content": content})
        return clean

    def update_router_context(
        self,
        conversation_key: str,
        *,
        domain_id: str,
        process_id: str,
        user_message: str,
        assistant_message: str,
        decision_type: str,
        max_history: int,
    ) -> None:
        """Record one exchange and upsert the routing state for a conversation.

        Args:
            conversation_key: Primary key of the ``router_context`` row.
            domain_id: Domain of this turn; always stored as ``last_domain_id``.
            process_id: Process of this turn; always stored as ``last_process_id``.
            user_message: Appended to history when non-empty. A non-empty value
                also increments ``turn_index``.
            assistant_message: Appended to history when non-empty.
            decision_type: ``"start"`` or ``"switch"`` makes
                (domain_id, process_id) the new active intent. Any other value
                keeps the current active intent. That falls back to the last
                routing, then to (domain_id, process_id).
            max_history: When > 0, only the newest ``max_history`` messages are
                kept. When <= 0, history is not trimmed.

        ``dialog_started`` is always written as True.
        """
        current = self.get_router_context(conversation_key)
        # NOTE(review): this read and the upsert below run on separate
        # connections, so two concurrent updates for the same key can read the
        # same history and one appended exchange may be lost.
        history = list(current.message_history)
        if user_message:
            history.append({"role": "user", "content": user_message})
        if assistant_message:
            history.append({"role": "assistant", "content": assistant_message})
        if max_history > 0:
            history = history[-max_history:]

        routed = {"domain_id": domain_id, "process_id": process_id}
        if decision_type in {"start", "switch"}:
            next_active = routed
        else:
            next_active = current.active_intent or current.last_routing or routed
        next_turn_index = max(0, int(current.turn_index or 0)) + (1 if user_message else 0)

        with get_engine().connect() as conn:
            conn.execute(
                text(
                    """
                    INSERT INTO router_context (
                        conversation_key, last_domain_id, last_process_id, active_domain_id, active_process_id,
                        dialog_started, turn_index, message_history_json
                    ) VALUES (:key, :domain, :process, :active_domain, :active_process, :dialog_started, :turn_index, :history)
                    ON CONFLICT (conversation_key) DO UPDATE SET
                        last_domain_id = EXCLUDED.last_domain_id,
                        last_process_id = EXCLUDED.last_process_id,
                        active_domain_id = EXCLUDED.active_domain_id,
                        active_process_id = EXCLUDED.active_process_id,
                        dialog_started = EXCLUDED.dialog_started,
                        turn_index = EXCLUDED.turn_index,
                        message_history_json = EXCLUDED.message_history_json,
                        updated_at = CURRENT_TIMESTAMP
                    """
                ),
                {
                    "key": conversation_key,
                    "domain": domain_id,
                    "process": process_id,
                    "active_domain": str(next_active["domain_id"]),
                    "active_process": str(next_active["process_id"]),
                    "dialog_started": True,
                    "turn_index": next_turn_index,
                    "history": json.dumps(history, ensure_ascii=False),
                },
            )
            conn.commit()

    @staticmethod
    def _section(payload: dict, key: str) -> dict:
        """Return ``payload[key]`` when it is a dict, otherwise an empty dict."""
        value = payload.get(key)
        return value if isinstance(value, dict) else {}

    def save_quality_metrics(
        self,
        *,
        task_id: str,
        dialog_session_id: str,
        rag_session_id: str,
        scenario: str,
        domain_id: str,
        process_id: str,
        quality: dict,
    ) -> None:
        """Insert one ``agent_quality_metrics`` row derived from a quality report.

        Reads ``quality["faithfulness"]`` (``score``, ``claims_total``,
        ``claims_supported``), ``quality["coverage"]`` (``score``,
        ``required_count``, ``covered_count``) and ``quality["status"]``.

        Missing, falsy or non-dict sections default to zero / ``"unknown"``.
        The whole ``quality`` dict (or ``{}`` when it is not a dict) is stored
        in ``metrics_json``.

        Raises:
            ValueError / TypeError: if a present numeric field cannot be
                converted with ``float()`` / ``int()``.
        """
        payload = quality if isinstance(quality, dict) else {}
        # Guard each section: a stored None or non-dict value previously
        # caused AttributeError on the .get() calls below.
        faithfulness = self._section(payload, "faithfulness")
        coverage = self._section(payload, "coverage")
        status = str(payload.get("status") or "unknown")

        params = {
            "task_id": task_id,
            "dialog_session_id": dialog_session_id,
            "rag_session_id": rag_session_id,
            "scenario": scenario,
            "domain_id": domain_id,
            "process_id": process_id,
            "faithfulness_score": float(faithfulness.get("score", 0.0) or 0.0),
            "coverage_score": float(coverage.get("score", 0.0) or 0.0),
            "faithfulness_claims_total": int(faithfulness.get("claims_total", 0) or 0),
            "faithfulness_claims_supported": int(faithfulness.get("claims_supported", 0) or 0),
            "coverage_required_items": int(coverage.get("required_count", 0) or 0),
            "coverage_covered_items": int(coverage.get("covered_count", 0) or 0),
            "quality_status": status,
            "metrics_json": json.dumps(payload, ensure_ascii=False),
        }
        with get_engine().connect() as conn:
            conn.execute(
                text(
                    """
                    INSERT INTO agent_quality_metrics (
                        task_id,
                        dialog_session_id,
                        rag_session_id,
                        scenario,
                        domain_id,
                        process_id,
                        faithfulness_score,
                        coverage_score,
                        faithfulness_claims_total,
                        faithfulness_claims_supported,
                        coverage_required_items,
                        coverage_covered_items,
                        quality_status,
                        metrics_json
                    ) VALUES (
                        :task_id,
                        :dialog_session_id,
                        :rag_session_id,
                        :scenario,
                        :domain_id,
                        :process_id,
                        :faithfulness_score,
                        :coverage_score,
                        :faithfulness_claims_total,
                        :faithfulness_claims_supported,
                        :coverage_required_items,
                        :coverage_covered_items,
                        :quality_status,
                        CAST(:metrics_json AS JSONB)
                    )
                    """
                ),
                params,
            )
            conn.commit()

    def get_quality_metrics(self, *, limit: int = 50, scenario: str | None = None) -> list[dict]:
        """Return the newest quality-metric rows as plain dicts keyed by column name.

        Args:
            limit: Maximum number of rows; values below 1 are clamped to 1.
            scenario: When truthy, only rows with this scenario are returned.

        Rows are ordered by ``created_at`` descending.
        """
        query = """
            SELECT
                task_id,
                dialog_session_id,
                rag_session_id,
                scenario,
                domain_id,
                process_id,
                faithfulness_score,
                coverage_score,
                faithfulness_claims_total,
                faithfulness_claims_supported,
                coverage_required_items,
                coverage_covered_items,
                quality_status,
                metrics_json,
                created_at
            FROM agent_quality_metrics
        """
        params: dict[str, object] = {"limit": max(1, int(limit))}
        if scenario:
            # Bound parameter, not string interpolation, so scenario is never spliced into SQL.
            query += " WHERE scenario = :scenario"
            params["scenario"] = scenario
        query += " ORDER BY created_at DESC LIMIT :limit"
        with get_engine().connect() as conn:
            rows = conn.execute(text(query), params).mappings().fetchall()
        return [dict(row) for row in rows]