Files
agent/app/modules/chat/direct_service.py

72 lines
2.7 KiB
Python

from __future__ import annotations

import logging
from collections.abc import Callable
from uuid import uuid4

from app.modules.agent.llm import AgentLlmService
from app.modules.chat.evidence_gate import CodeExplainEvidenceGate
from app.modules.chat.session_resolver import ChatSessionResolver
from app.modules.chat.task_store import TaskState, TaskStore
from app.modules.rag.explain import CodeExplainRetrieverV2, PromptBudgeter
from app.schemas.chat import ChatMessageRequest, TaskQueuedResponse, TaskResultType, TaskStatus
# Module-level logger, named after this module.
LOGGER: logging.Logger = logging.getLogger(__name__)
class CodeExplainChatService:
    """Answer code-explanation chat messages inline ("direct" mode).

    A task is created, executed and completed within a single
    ``handle_message`` call. The returned response therefore already reports
    ``TaskStatus.DONE`` rather than a queued state.
    """

    def __init__(
        self,
        retriever: CodeExplainRetrieverV2,
        llm: AgentLlmService,
        session_resolver: ChatSessionResolver,
        task_store: TaskStore,
        message_sink: Callable[..., object],
        budgeter: PromptBudgeter | None = None,
        evidence_gate: CodeExplainEvidenceGate | None = None,
    ) -> None:
        """Wire the service's collaborators.

        Args:
            retriever: Builds the evidence pack for a question; the pack
                exposes ``code_excerpts`` and ``missing``.
            llm: Generates the answer text from the budgeted prompt input.
            session_resolver: Maps a request to
                ``(dialog_session_id, rag_session_id)``.
            task_store: Persists ``TaskState`` snapshots via ``save``.
            message_sink: Called as
                ``message_sink(dialog_session_id, role, text, task_id=...)``
                to record the user and assistant messages.
            budgeter: Prompt budgeter. A falsy value (including the default
                ``None``) is replaced by a fresh ``PromptBudgeter()``.
            evidence_gate: Evidence gate. A falsy value (including the default
                ``None``) is replaced by a fresh ``CodeExplainEvidenceGate()``.
        """
        self._retriever = retriever
        self._llm = llm
        self._session_resolver = session_resolver
        self._task_store = task_store
        self._message_sink = message_sink
        self._budgeter = budgeter or PromptBudgeter()
        self._evidence_gate = evidence_gate or CodeExplainEvidenceGate()

    async def handle_message(self, request: ChatMessageRequest) -> TaskQueuedResponse:
        """Answer ``request`` inline and return an already-completed task.

        Flow:
            1. Create a RUNNING task.
            2. Record the user message.
            3. Build the answer (see ``_build_answer``).
            4. Record the assistant message.
            5. Store the task as DONE with the answer attached.

        Returns:
            A ``TaskQueuedResponse`` whose ``status`` is ``TaskStatus.DONE.value``.

        Raises:
            Any exception from retrieval, the evidence gate, the budgeter or
            the LLM. The failure is logged together with the task id before
            being re-raised. The stored task is left in RUNNING state.
            NOTE(review): consider persisting a failure status here. None is
            set because the available ``TaskStatus`` members are not visible
            from this module.
        """
        dialog_session_id, rag_session_id = self._session_resolver.resolve(request)
        task_id = str(uuid4())
        task = TaskState(task_id=task_id, status=TaskStatus.RUNNING)
        self._task_store.save(task)
        self._message_sink(dialog_session_id, "user", request.message, task_id=task_id)

        try:
            answer, pack = self._build_answer(request, rag_session_id)
        except Exception:
            # Without this, the orphaned RUNNING task could not be correlated
            # with the error that stopped it.
            LOGGER.exception(
                "direct code explain failed: task_id=%s rag_session_id=%s",
                task_id,
                rag_session_id,
            )
            raise

        self._message_sink(dialog_session_id, "assistant", answer, task_id=task_id)
        task.status = TaskStatus.DONE
        task.result_type = TaskResultType.ANSWER
        task.answer = answer
        self._task_store.save(task)
        # NOTE(review): a routine success is logged at WARNING, presumably so
        # it shows up under a WARNING log threshold. Consider INFO.
        LOGGER.warning(
            "direct code explain response: task_id=%s rag_session_id=%s excerpts=%s missing=%s",
            task_id,
            rag_session_id,
            len(pack.code_excerpts),
            pack.missing,
        )
        return TaskQueuedResponse(
            task_id=task_id,
            status=TaskStatus.DONE.value,
        )

    def _build_answer(self, request: ChatMessageRequest, rag_session_id) -> tuple[str, object]:
        """Retrieve evidence for ``request`` and produce the assistant answer.

        If the evidence gate passes, the LLM answers from the budgeted prompt
        and its output is whitespace-stripped. Otherwise the gate decision's
        own ``answer`` is used verbatim.

        Returns:
            ``(answer, pack)``. The pack is returned so the caller can log its
            excerpt count and ``missing`` entries.
        """
        # NOTE(review): build_pack / generate are not awaited, so they appear
        # to be synchronous and block the event loop of the async caller while
        # they run. Consider offloading them (e.g. asyncio.to_thread) once the
        # thread-safety of the retriever and LLM client is confirmed.
        # `or []` tolerates a request whose file list is None.
        file_candidates = [item.model_dump(mode="json") for item in (request.files or [])]
        pack = self._retriever.build_pack(
            rag_session_id,
            request.message,
            file_candidates=file_candidates,
        )
        decision = self._evidence_gate.evaluate(pack)
        if not decision.passed:
            return decision.answer, pack
        prompt_input = self._budgeter.build_prompt_input(request.message, pack)
        answer = self._llm.generate(
            "code_explain_answer_v2",
            prompt_input,
            log_context="chat.code_explain.direct",
        ).strip()
        return answer, pack