from __future__ import annotations import argparse import hashlib from pathlib import Path from rag_agent.config import load_config from rag_agent.ingest.chunker import chunk_text_by_lines from rag_agent.ingest.file_loader import iter_text_files from rag_agent.ingest.git_watcher import ( filter_existing, filter_removed, get_changed_files, get_merge_base, read_file_at_ref, ) from rag_agent.index.embeddings import get_embedding_client from rag_agent.index.postgres import ( connect, delete_document, ensure_schema, get_or_create_story, get_story_id, replace_chunks, update_story_indexed_range, upsert_document, ) from rag_agent.agent.pipeline import StubLLMClient, answer_query def _file_version(path: Path) -> str: stat = path.stat() payload = f"{path.as_posix()}:{stat.st_mtime_ns}:{stat.st_size}" return hashlib.sha256(payload.encode("utf-8")).hexdigest() def cmd_index(args: argparse.Namespace) -> None: config = load_config() conn = connect(config.db_dsn) ensure_schema(conn, config.embeddings_dim) story_id = get_or_create_story(conn, args.story) embedding_client = get_embedding_client(config.embeddings_dim) base_ref = args.base_ref.strip() head_ref = args.head_ref.strip() if args.changed: if base_ref.lower() == "auto": base_ref = get_merge_base( config.repo_path, args.default_branch, head_ref ) or args.default_branch changed_files = get_changed_files( config.repo_path, base_ref, head_ref ) removed = filter_removed(changed_files) existing = filter_existing(changed_files) else: removed = [] existing = [ p for p in Path(config.repo_path).rglob("*") if p.is_file() ] for path in removed: delete_document(conn, story_id, str(path)) for path, text in iter_text_files(existing, config.allowed_extensions): chunks = chunk_text_by_lines( text, config.chunk_size_lines, config.chunk_overlap_lines, ) if not chunks: continue base_chunks = None if args.changed: base_text = read_file_at_ref(config.repo_path, path, base_ref) if base_text is not None: base_chunks = chunk_text_by_lines( base_text, config.chunk_size_lines, config.chunk_overlap_lines, ) embeddings = embedding_client.embed_texts( [chunk.text for chunk in chunks] ) document_id = upsert_document( conn, story_id, str(path), _file_version(path) ) replace_chunks( conn, document_id, chunks, embeddings, base_chunks=base_chunks ) if args.changed: update_story_indexed_range(conn, story_id, base_ref, head_ref) def cmd_serve(args: argparse.Namespace) -> None: import uvicorn uvicorn.run( "rag_agent.webhook:app", host=args.host, port=args.port, log_level="info", ) def cmd_ask(args: argparse.Namespace) -> None: config = load_config() conn = connect(config.db_dsn) ensure_schema(conn, config.embeddings_dim) story_id = None if args.story: story_id = get_story_id(conn, args.story) if story_id is None: raise SystemExit(f"Story not found: {args.story}") embedding_client = get_embedding_client(config.embeddings_dim) llm_client = StubLLMClient() answer = answer_query( conn, config, embedding_client, llm_client, args.question, top_k=args.top_k, story_id=story_id, ) print(answer) def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(prog="rag-agent") sub = parser.add_subparsers(dest="command", required=True) index_parser = sub.add_parser("index", help="Index files into Postgres") index_parser.add_argument( "--story", required=True, help="Story slug (e.g. branch name or story id); documents are tied to this story", ) index_parser.add_argument( "--changed", action="store_true", help="Index only files changed in the story range (base-ref..head-ref); all commits in range belong to the story", ) index_parser.add_argument( "--base-ref", default="main", help="Start of story range (e.g. main). Use 'auto' for merge-base(default-branch, head-ref). All commits from base to head are the story.", ) index_parser.add_argument( "--head-ref", default="HEAD", help="End of story range (e.g. current branch tip)", ) index_parser.add_argument( "--default-branch", default="main", help="Default branch name for --base-ref auto", ) index_parser.set_defaults(func=cmd_index) ask_parser = sub.add_parser("ask", help="Ask a question") ask_parser.add_argument("question", help="Question text") ask_parser.add_argument( "--story", default=None, help="Limit retrieval to this story slug (optional)", ) ask_parser.add_argument("--top-k", type=int, default=5, help="Top K chunks to retrieve") ask_parser.set_defaults(func=cmd_ask) serve_parser = sub.add_parser( "serve", help="Run webhook server: on push to remote repo, pull and index changes", ) serve_parser.add_argument( "--host", default="0.0.0.0", help="Bind host (default: 0.0.0.0)", ) serve_parser.add_argument( "--port", type=int, default=8000, help="Bind port (default: 8000)", ) serve_parser.set_defaults(func=cmd_serve) return parser def main() -> None: parser = build_parser() args = parser.parse_args() args.func(args) if __name__ == "__main__": main()