Скелет проекта
This commit is contained in:
117
src/rag_agent/cli.py
Normal file
117
src/rag_agent/cli.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
from rag_agent.config import load_config
|
||||
from rag_agent.ingest.chunker import chunk_text
|
||||
from rag_agent.ingest.file_loader import iter_text_files
|
||||
from rag_agent.ingest.git_watcher import filter_existing, filter_removed, get_changed_files
|
||||
from rag_agent.index.embeddings import get_embedding_client
|
||||
from rag_agent.index.postgres import (
|
||||
connect,
|
||||
delete_document,
|
||||
ensure_schema,
|
||||
get_or_create_story,
|
||||
get_story_id,
|
||||
replace_chunks,
|
||||
upsert_document,
|
||||
)
|
||||
from rag_agent.agent.pipeline import StubLLMClient, answer_query
|
||||
|
||||
|
||||
def _file_version(path: Path) -> str:
|
||||
stat = path.stat()
|
||||
payload = f"{path.as_posix()}:{stat.st_mtime_ns}:{stat.st_size}"
|
||||
return hashlib.sha256(payload.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def cmd_index(args: argparse.Namespace) -> None:
|
||||
config = load_config()
|
||||
conn = connect(config.db_dsn)
|
||||
ensure_schema(conn, config.embeddings_dim)
|
||||
story_id = get_or_create_story(conn, args.story)
|
||||
embedding_client = get_embedding_client(config.embeddings_dim)
|
||||
|
||||
if args.changed:
|
||||
changed_files = get_changed_files(config.repo_path, args.base_ref, args.head_ref)
|
||||
removed = filter_removed(changed_files)
|
||||
existing = filter_existing(changed_files)
|
||||
else:
|
||||
removed = []
|
||||
existing = [path for path in Path(config.repo_path).rglob("*") if path.is_file()]
|
||||
|
||||
for path in removed:
|
||||
delete_document(conn, story_id, str(path))
|
||||
|
||||
for path, text in iter_text_files(existing, config.allowed_extensions):
|
||||
chunks = chunk_text(text, config.chunk_size, config.chunk_overlap)
|
||||
if not chunks:
|
||||
continue
|
||||
embeddings = embedding_client.embed_texts([chunk.text for chunk in chunks])
|
||||
document_id = upsert_document(
|
||||
conn, story_id, str(path), _file_version(path)
|
||||
)
|
||||
replace_chunks(conn, document_id, chunks, embeddings)
|
||||
|
||||
|
||||
def cmd_ask(args: argparse.Namespace) -> None:
|
||||
config = load_config()
|
||||
conn = connect(config.db_dsn)
|
||||
ensure_schema(conn, config.embeddings_dim)
|
||||
story_id = None
|
||||
if args.story:
|
||||
story_id = get_story_id(conn, args.story)
|
||||
if story_id is None:
|
||||
raise SystemExit(f"Story not found: {args.story}")
|
||||
embedding_client = get_embedding_client(config.embeddings_dim)
|
||||
llm_client = StubLLMClient()
|
||||
answer = answer_query(
|
||||
conn,
|
||||
config,
|
||||
embedding_client,
|
||||
llm_client,
|
||||
args.question,
|
||||
top_k=args.top_k,
|
||||
story_id=story_id,
|
||||
)
|
||||
print(answer)
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(prog="rag-agent")
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
index_parser = sub.add_parser("index", help="Index files into Postgres")
|
||||
index_parser.add_argument(
|
||||
"--story",
|
||||
required=True,
|
||||
help="Story slug (e.g. branch name or story id); documents are tied to this story",
|
||||
)
|
||||
index_parser.add_argument("--changed", action="store_true", help="Index only changed files")
|
||||
index_parser.add_argument("--base-ref", default="HEAD~1", help="Base git ref")
|
||||
index_parser.add_argument("--head-ref", default="HEAD", help="Head git ref")
|
||||
index_parser.set_defaults(func=cmd_index)
|
||||
|
||||
ask_parser = sub.add_parser("ask", help="Ask a question")
|
||||
ask_parser.add_argument("question", help="Question text")
|
||||
ask_parser.add_argument(
|
||||
"--story",
|
||||
default=None,
|
||||
help="Limit retrieval to this story slug (optional)",
|
||||
)
|
||||
ask_parser.add_argument("--top-k", type=int, default=5, help="Top K chunks to retrieve")
|
||||
ask_parser.set_defaults(func=cmd_ask)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
args.func(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user