52 lines
1.4 KiB
Python
52 lines
1.4 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Protocol
|
|
|
|
import psycopg
|
|
|
|
from rag_agent.config import AppConfig
|
|
from rag_agent.index.embeddings import EmbeddingClient
|
|
from rag_agent.retrieval.search import search_similar
|
|
|
|
|
|
class LLMClient(Protocol):
    """Structural interface for a text-generation backend.

    Any object exposing a ``generate`` method with this signature satisfies
    the protocol; inheriting from ``LLMClient`` is optional.
    """

    def generate(self, prompt: str, model: str) -> str:
        """Produce a completion for ``prompt`` using the model named ``model``.

        Args:
            prompt: Full prompt text to send to the backend.
            model: Identifier of the model the backend should use.

        Returns:
            The generated text.

        Raises:
            NotImplementedError: Always, in this base definition. A class
                that explicitly inherits from ``LLMClient`` without
                overriding ``generate`` fails loudly instead of silently
                returning ``None``.
        """
        raise NotImplementedError
|
|
|
|
|
|
@dataclass
class StubLLMClient:
    """Placeholder LLM client that performs no generation.

    It has no fields; ``generate`` always returns the same fixed notice so a
    caller can run end to end before a real backend is wired in.
    """

    def generate(self, prompt: str, model: str) -> str:
        """Return a fixed "not configured" notice.

        Args:
            prompt: Ignored.
            model: Ignored.

        Returns:
            A constant message telling the caller to replace this stub.
        """
        return (
            "LLM client is not configured. "
            "Replace StubLLMClient with a real implementation."
        )
|
|
|
|
|
|
def build_prompt(question: str, contexts: list[str]) -> str:
    """Assemble the prompt text sent to the LLM.

    The prompt consists of a fixed instruction line, a ``Context:`` section
    holding every context snippet separated by a blank line, and the question
    followed by an ``Answer:`` cue for the model to complete.

    Args:
        question: The user's question, inserted verbatim.
        contexts: Context snippets, inserted verbatim in the given order.
            An empty list yields an empty ``Context:`` section.

    Returns:
        The complete prompt string.
    """
    joined = "\n\n".join(contexts)
    return (
        "You are a RAG assistant. Use the context below to answer the question.\n\n"
        f"Context:\n{joined}\n\n"
        f"Question: {question}\nAnswer:"
    )
|
|
|
|
|
|
def answer_query(
    conn: psycopg.Connection,
    config: AppConfig,
    embedding_client: EmbeddingClient,
    llm_client: LLMClient,
    question: str,
    top_k: int = 5,
    story_id: int | None = None,
) -> str:
    """Answer ``question`` via retrieval followed by LLM generation.

    The question is embedded, ``search_similar`` is queried with that
    embedding, each hit is rendered as a ``Source: <path>`` header followed
    by its content, and the resulting prompt is handed to ``llm_client``.

    Args:
        conn: Database connection forwarded to ``search_similar``.
        config: Application config; only ``llm_model`` is read here.
        embedding_client: Client used to embed the question.
        llm_client: Client used to generate the final answer.
        question: The user's question.
        top_k: Forwarded to ``search_similar`` as ``top_k``; presumably the
            maximum number of hits to return (confirm against that helper).
        story_id: Forwarded to ``search_similar`` as ``story_id``; presumably
            restricts the search to one story when not ``None`` (confirm
            against that helper).

    Returns:
        Whatever ``llm_client.generate`` returns for the assembled prompt.
    """
    # embed_texts takes a batch; send a one-element batch and unwrap it.
    query_embedding = embedding_client.embed_texts([question])[0]
    results = search_similar(
        conn, query_embedding, top_k=top_k, story_id=story_id
    )
    # Prefix each hit with its source path so it appears in the prompt text.
    contexts = [f"Source: {item.path}\n{item.content}" for item in results]
    prompt = build_prompt(question, contexts)
    return llm_client.generate(prompt, model=config.llm_model)
|