287 lines
12 KiB
Python
287 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
|
|
from sqlalchemy import text
|
|
|
|
from app.modules.agent.engine.router.schemas import RouterContext
|
|
from app.modules.shared.db import get_engine
|
|
|
|
|
|
class AgentRepository:
    """Data access for the agent's router context and answer-quality metrics.

    Every public method opens its own connection via ``get_engine()`` and
    commits its own writes. The SQL relies on PostgreSQL features
    (``BIGSERIAL``, ``JSONB``, ``TIMESTAMPTZ``, ``ADD COLUMN IF NOT EXISTS``,
    ``INSERT ... ON CONFLICT``).
    """

    def ensure_tables(self) -> None:
        """Create the repository's tables and indexes if they are missing.

        Every statement uses ``IF NOT EXISTS``, so this is safe to call on
        each startup. It also adds router_context columns that an older copy
        of the table may lack, then commits everything at once.
        """
        ddl_statements = (
            """
            CREATE TABLE IF NOT EXISTS router_context (
                conversation_key VARCHAR(64) PRIMARY KEY,
                last_domain_id VARCHAR(64) NULL,
                last_process_id VARCHAR(64) NULL,
                active_domain_id VARCHAR(64) NULL,
                active_process_id VARCHAR(64) NULL,
                dialog_started BOOLEAN NOT NULL DEFAULT FALSE,
                turn_index INTEGER NOT NULL DEFAULT 0,
                message_history_json TEXT NOT NULL DEFAULT '[]',
                updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
            )
            """,
            """
            CREATE TABLE IF NOT EXISTS agent_quality_metrics (
                id BIGSERIAL PRIMARY KEY,
                task_id VARCHAR(64) NOT NULL,
                dialog_session_id VARCHAR(64) NOT NULL,
                rag_session_id VARCHAR(64) NOT NULL,
                scenario VARCHAR(64) NOT NULL,
                domain_id VARCHAR(64) NOT NULL,
                process_id VARCHAR(64) NOT NULL,
                faithfulness_score DOUBLE PRECISION NOT NULL,
                coverage_score DOUBLE PRECISION NOT NULL,
                faithfulness_claims_total INTEGER NOT NULL,
                faithfulness_claims_supported INTEGER NOT NULL,
                coverage_required_items INTEGER NOT NULL,
                coverage_covered_items INTEGER NOT NULL,
                quality_status VARCHAR(32) NOT NULL,
                metrics_json JSONB NOT NULL,
                created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
            )
            """,
            """
            CREATE INDEX IF NOT EXISTS idx_agent_quality_metrics_task
            ON agent_quality_metrics(task_id, created_at DESC)
            """,
            """
            CREATE INDEX IF NOT EXISTS idx_agent_quality_metrics_scenario
            ON agent_quality_metrics(scenario, created_at DESC)
            """,
        )
        with get_engine().connect() as conn:
            for statement in ddl_statements:
                conn.execute(text(statement))
            self._ensure_router_context_columns(conn)
            conn.commit()

    def _ensure_router_context_columns(self, conn) -> None:
        """Add router_context columns that may be missing from an existing table.

        The CREATE TABLE above already declares these columns; the ALTERs only
        matter for a table created earlier without them (presumably by an
        older schema version). Does not commit; the caller owns the transaction.
        """
        for statement in (
            "ALTER TABLE router_context ADD COLUMN IF NOT EXISTS active_domain_id VARCHAR(64) NULL",
            "ALTER TABLE router_context ADD COLUMN IF NOT EXISTS active_process_id VARCHAR(64) NULL",
            "ALTER TABLE router_context ADD COLUMN IF NOT EXISTS dialog_started BOOLEAN NOT NULL DEFAULT FALSE",
            "ALTER TABLE router_context ADD COLUMN IF NOT EXISTS turn_index INTEGER NOT NULL DEFAULT 0",
        ):
            conn.execute(text(statement))

    def get_router_context(self, conversation_key: str) -> RouterContext:
        """Load the stored routing state for ``conversation_key``.

        Returns ``RouterContext()`` when no row exists. ``last_routing`` and
        ``active_intent`` are ``{"domain_id", "process_id"}`` dicts, set only
        when both ids are stored. ``active_intent`` falls back to
        ``last_routing`` when no active pair is stored. The history is decoded
        leniently (see ``_clean_history``).
        """
        with get_engine().connect() as conn:
            row = conn.execute(
                text(
                    """
                    SELECT last_domain_id, last_process_id, active_domain_id, active_process_id, dialog_started, turn_index, message_history_json
                    FROM router_context
                    WHERE conversation_key = :key
                    """
                ),
                {"key": conversation_key},
            ).fetchone()

        if not row:
            return RouterContext()

        last = self._routing_pair(row[0], row[1])
        active = self._routing_pair(row[2], row[3])
        return RouterContext(
            last_routing=last,
            message_history=self._clean_history(row[6]),
            active_intent=active or last,
            dialog_started=bool(row[4]),
            turn_index=int(row[5] or 0),
        )

    @staticmethod
    def _routing_pair(domain_id, process_id) -> dict[str, str] | None:
        """Return ``{"domain_id": ..., "process_id": ...}``, or None unless both ids are truthy."""
        if domain_id and process_id:
            return {"domain_id": str(domain_id), "process_id": str(process_id)}
        return None

    @staticmethod
    def _clean_history(raw) -> list[dict[str, str]]:
        """Decode stored history JSON, keeping only well-formed chat turns.

        A falsy ``raw`` or malformed JSON yields ``[]``, as does a payload
        that is not a list. Within the list, only dict entries whose role is
        ``"user"`` or ``"assistant"`` and whose content is non-empty are kept.
        Those entries are reduced to exactly the ``role``/``content`` keys.
        """
        try:
            parsed = json.loads(raw or "[]")
        except json.JSONDecodeError:
            return []
        if not isinstance(parsed, list):
            return []

        cleaned: list[dict[str, str]] = []
        for item in parsed:
            if not isinstance(item, dict):
                continue
            role = str(item.get("role") or "")
            content = str(item.get("content") or "")
            if role in {"user", "assistant"} and content:
                cleaned.append({"role": role, "content": content})
        return cleaned

    def update_router_context(
        self,
        conversation_key: str,
        *,
        domain_id: str,
        process_id: str,
        user_message: str,
        assistant_message: str,
        decision_type: str,
        max_history: int,
    ) -> None:
        """Record one dialog turn and upsert the routing state for a conversation.

        - Non-empty ``user_message``/``assistant_message`` are appended to the
          stored history.
        - When ``max_history > 0``, only the newest ``max_history`` entries
          are kept.
        - ``turn_index`` grows by one only when ``user_message`` is non-empty.
        - ``last_domain_id``/``last_process_id`` always take the given ids.
        - The active pair is replaced only for ``decision_type`` ``"start"``
          or ``"switch"``. Otherwise it keeps the previously active (or last)
          pair, falling back to the given ids when nothing is stored.
        - ``dialog_started`` is always written as True.

        NOTE(review): the read (``get_router_context``) and the upsert use two
        separate connections without locking. Concurrent updates for the same
        key can therefore overwrite each other's history.
        """
        current = self.get_router_context(conversation_key)
        history = list(current.message_history)
        if user_message:
            history.append({"role": "user", "content": user_message})
        if assistant_message:
            history.append({"role": "assistant", "content": assistant_message})
        if max_history > 0:
            history = history[-max_history:]

        current_active = current.active_intent or current.last_routing or {"domain_id": domain_id, "process_id": process_id}
        next_active = (
            {"domain_id": domain_id, "process_id": process_id}
            if decision_type in {"start", "switch"}
            else current_active
        )
        next_turn_index = max(0, int(current.turn_index or 0)) + (1 if user_message else 0)

        with get_engine().connect() as conn:
            conn.execute(
                text(
                    """
                    INSERT INTO router_context (
                        conversation_key, last_domain_id, last_process_id, active_domain_id, active_process_id,
                        dialog_started, turn_index, message_history_json
                    ) VALUES (:key, :domain, :process, :active_domain, :active_process, :dialog_started, :turn_index, :history)
                    ON CONFLICT (conversation_key) DO UPDATE SET
                        last_domain_id = EXCLUDED.last_domain_id,
                        last_process_id = EXCLUDED.last_process_id,
                        active_domain_id = EXCLUDED.active_domain_id,
                        active_process_id = EXCLUDED.active_process_id,
                        dialog_started = EXCLUDED.dialog_started,
                        turn_index = EXCLUDED.turn_index,
                        message_history_json = EXCLUDED.message_history_json,
                        updated_at = CURRENT_TIMESTAMP
                    """
                ),
                {
                    "key": conversation_key,
                    "domain": domain_id,
                    "process": process_id,
                    "active_domain": str(next_active["domain_id"]),
                    "active_process": str(next_active["process_id"]),
                    "dialog_started": True,
                    "turn_index": next_turn_index,
                    "history": json.dumps(history, ensure_ascii=False),
                },
            )
            conn.commit()

    def save_quality_metrics(
        self,
        *,
        task_id: str,
        dialog_session_id: str,
        rag_session_id: str,
        scenario: str,
        domain_id: str,
        process_id: str,
        quality: dict,
    ) -> None:
        """Insert one row into ``agent_quality_metrics``.

        ``quality`` is read as ``{"faithfulness": {...}, "coverage": {...},
        "status": ...}``. Missing, null or non-dict parts are stored as zero
        counts/scores or status ``"unknown"``. The whole mapping is also kept
        in ``metrics_json``; a non-dict ``quality`` is stored as ``{}``.
        """
        params = self._quality_metrics_params(
            task_id=task_id,
            dialog_session_id=dialog_session_id,
            rag_session_id=rag_session_id,
            scenario=scenario,
            domain_id=domain_id,
            process_id=process_id,
            quality=quality,
        )
        with get_engine().connect() as conn:
            conn.execute(
                text(
                    """
                    INSERT INTO agent_quality_metrics (
                        task_id,
                        dialog_session_id,
                        rag_session_id,
                        scenario,
                        domain_id,
                        process_id,
                        faithfulness_score,
                        coverage_score,
                        faithfulness_claims_total,
                        faithfulness_claims_supported,
                        coverage_required_items,
                        coverage_covered_items,
                        quality_status,
                        metrics_json
                    ) VALUES (
                        :task_id,
                        :dialog_session_id,
                        :rag_session_id,
                        :scenario,
                        :domain_id,
                        :process_id,
                        :faithfulness_score,
                        :coverage_score,
                        :faithfulness_claims_total,
                        :faithfulness_claims_supported,
                        :coverage_required_items,
                        :coverage_covered_items,
                        :quality_status,
                        CAST(:metrics_json AS JSONB)
                    )
                    """
                ),
                params,
            )
            conn.commit()

    @staticmethod
    def _quality_metrics_params(
        *,
        task_id: str,
        dialog_session_id: str,
        rag_session_id: str,
        scenario: str,
        domain_id: str,
        process_id: str,
        quality: object,
    ) -> dict:
        """Build the bind parameters for the ``agent_quality_metrics`` INSERT.

        Pure function (no I/O) so the normalisation rules can be tested
        without a database.
        """
        report = quality if isinstance(quality, dict) else {}
        # A section can be present but null (or a non-dict). ``.get(key, {})``
        # would return that value and crash on the next ``.get``, so treat it
        # the same as a missing section.
        faithfulness = report.get("faithfulness")
        if not isinstance(faithfulness, dict):
            faithfulness = {}
        coverage = report.get("coverage")
        if not isinstance(coverage, dict):
            coverage = {}
        # quality_status is NOT NULL; a null status means "unknown", not "None".
        raw_status = report.get("status")
        status = "unknown" if raw_status is None else str(raw_status)

        return {
            "task_id": task_id,
            "dialog_session_id": dialog_session_id,
            "rag_session_id": rag_session_id,
            "scenario": scenario,
            "domain_id": domain_id,
            "process_id": process_id,
            "faithfulness_score": float(faithfulness.get("score", 0.0) or 0.0),
            "coverage_score": float(coverage.get("score", 0.0) or 0.0),
            "faithfulness_claims_total": int(faithfulness.get("claims_total", 0) or 0),
            "faithfulness_claims_supported": int(faithfulness.get("claims_supported", 0) or 0),
            "coverage_required_items": int(coverage.get("required_count", 0) or 0),
            "coverage_covered_items": int(coverage.get("covered_count", 0) or 0),
            "quality_status": status,
            "metrics_json": json.dumps(report, ensure_ascii=False),
        }

    def get_quality_metrics(self, *, limit: int = 50, scenario: str | None = None) -> list[dict]:
        """Return the most recent quality-metric rows, newest first.

        ``limit`` is coerced with ``int()`` and clamped to at least 1. A truthy
        ``scenario`` restricts the result to that scenario. Each row is
        returned as a plain dict keyed by column name.
        """
        query = """
            SELECT
                task_id,
                dialog_session_id,
                rag_session_id,
                scenario,
                domain_id,
                process_id,
                faithfulness_score,
                coverage_score,
                faithfulness_claims_total,
                faithfulness_claims_supported,
                coverage_required_items,
                coverage_covered_items,
                quality_status,
                metrics_json,
                created_at
            FROM agent_quality_metrics
        """
        params: dict = {"limit": max(1, int(limit))}
        if scenario:
            query += " WHERE scenario = :scenario"
            params["scenario"] = scenario
        query += " ORDER BY created_at DESC LIMIT :limit"

        with get_engine().connect() as conn:
            rows = conn.execute(text(query), params).mappings().fetchall()
        return [dict(row) for row in rows]
|