72 lines
2.7 KiB
Python
72 lines
2.7 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from uuid import uuid4
|
|
|
|
from app.modules.agent.llm import AgentLlmService
|
|
from app.modules.chat.evidence_gate import CodeExplainEvidenceGate
|
|
from app.modules.chat.session_resolver import ChatSessionResolver
|
|
from app.modules.chat.task_store import TaskState, TaskStore
|
|
from app.modules.rag.explain import CodeExplainRetrieverV2, PromptBudgeter
|
|
from app.schemas.chat import ChatMessageRequest, TaskQueuedResponse, TaskResultType, TaskStatus
|
|
|
|
# Module-level logger, named after this module's import path.
LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
class CodeExplainChatService:
    """Handle a code-explanation chat message in one direct pass.

    For each request the service resolves the session ids, records a task,
    builds a retrieval pack, lets the evidence gate decide whether the pack
    is sufficient, and then either asks the LLM for an answer or falls back
    to the answer supplied by the gate decision.
    """

    def __init__(
        self,
        retriever: CodeExplainRetrieverV2,
        llm: AgentLlmService,
        session_resolver: ChatSessionResolver,
        task_store: TaskStore,
        message_sink,
        budgeter: PromptBudgeter | None = None,
        evidence_gate: CodeExplainEvidenceGate | None = None,
    ) -> None:
        """Store the collaborators used by ``handle_message``.

        Args:
            retriever: Builds the retrieval pack for a message.
            llm: Generates the final answer text from the prompt input.
            session_resolver: Maps a request to a
                ``(dialog_session_id, rag_session_id)`` pair.
            task_store: Persists task state via ``save``.
            message_sink: Callable invoked as
                ``message_sink(dialog_session_id, role, text, task_id=...)``.
            budgeter: Builds the prompt input; a default ``PromptBudgeter``
                is created when ``None``.
            evidence_gate: Evaluates the retrieval pack; a default
                ``CodeExplainEvidenceGate`` is created when ``None``.
        """
        self._retriever = retriever
        self._llm = llm
        self._session_resolver = session_resolver
        self._task_store = task_store
        self._message_sink = message_sink
        # Compare against None explicitly: an injected collaborator that
        # happens to evaluate as falsy (e.g. defines __len__/__bool__) must
        # not be silently swapped for a default instance.
        if budgeter is None:
            budgeter = PromptBudgeter()
        if evidence_gate is None:
            evidence_gate = CodeExplainEvidenceGate()
        self._budgeter = budgeter
        self._evidence_gate = evidence_gate

    async def handle_message(self, request: ChatMessageRequest) -> TaskQueuedResponse:
        """Answer ``request`` and return the identifier of the finished task.

        The task is saved as RUNNING before any work starts and saved again
        as DONE with the answer attached once the assistant message has been
        passed to the message sink.

        Args:
            request: Incoming chat message; ``request.message`` and
                ``request.files`` are read here.

        Returns:
            A ``TaskQueuedResponse`` carrying the new task id and the DONE
            status value.
        """
        # NOTE(review): nothing in this method catches exceptions, so a
        # failure in retrieval or generation leaves the stored task in
        # RUNNING. Also, the calls below are made synchronously inside this
        # coroutine - confirm they are not blocking the event loop.
        dialog_session_id, rag_session_id = self._session_resolver.resolve(request)
        task_id = str(uuid4())
        task = TaskState(task_id=task_id, status=TaskStatus.RUNNING)
        self._task_store.save(task)
        self._message_sink(dialog_session_id, "user", request.message, task_id=task_id)

        pack = self._retriever.build_pack(
            rag_session_id,
            request.message,
            file_candidates=[item.model_dump(mode="json") for item in request.files],
        )
        answer = self._compose_answer(request, pack)
        self._message_sink(dialog_session_id, "assistant", answer, task_id=task_id)

        task.status = TaskStatus.DONE
        task.result_type = TaskResultType.ANSWER
        task.answer = answer
        self._task_store.save(task)

        # NOTE(review): logged at WARNING although this is the normal
        # completion path; level kept as in the original.
        LOGGER.warning(
            "direct code explain response: task_id=%s rag_session_id=%s excerpts=%s missing=%s",
            task_id,
            rag_session_id,
            len(pack.code_excerpts),
            pack.missing,
        )
        return TaskQueuedResponse(
            task_id=task_id,
            status=TaskStatus.DONE.value,
        )

    def _compose_answer(self, request: ChatMessageRequest, pack):
        """Return the LLM answer if the evidence gate passes, else the gate's answer."""
        decision = self._evidence_gate.evaluate(pack)
        if not decision.passed:
            return decision.answer
        prompt_input = self._budgeter.build_prompt_input(request.message, pack)
        return self._llm.generate(
            "code_explain_answer_v2",
            prompt_input,
            log_context="chat.code_explain.direct",
        ).strip()
|