Files
RagAgent/src/rag_agent/cli.py

224 lines
6.7 KiB
Python

from __future__ import annotations
import argparse
import hashlib
from pathlib import Path
from rag_agent.config import load_config
from rag_agent.ingest.chunker import chunk_text_by_lines
from rag_agent.ingest.file_loader import iter_text_files
from rag_agent.ingest.git_watcher import (
filter_existing,
filter_removed,
get_changed_files,
get_merge_base,
read_file_at_ref,
)
from rag_agent.index.embeddings import get_embedding_client
from rag_agent.index.postgres import (
connect,
delete_document,
ensure_schema,
get_or_create_story,
get_story_id,
replace_chunks,
update_story_indexed_range,
upsert_document,
)
from rag_agent.agent.pipeline import answer_query, get_llm_client
def _file_version(path: Path) -> str:
stat = path.stat()
payload = f"{path.as_posix()}:{stat.st_mtime_ns}:{stat.st_size}"
return hashlib.sha256(payload.encode("utf-8")).hexdigest()
def _chunk_lines(text, config):
    """Split *text* into line-based chunks using the configured size and overlap."""
    return chunk_text_by_lines(
        text,
        config.chunk_size_lines,
        config.chunk_overlap_lines,
    )


def cmd_index(args: argparse.Namespace) -> None:
    """Index repository text files into Postgres under the story ``args.story``.

    Without ``--changed`` every file under ``config.repo_path`` is considered.
    With ``--changed`` only files changed between ``args.base_ref`` and
    ``args.head_ref`` are considered. A base ref of ``auto`` resolves to the
    merge-base of ``args.default_branch`` and the head ref, falling back to
    the default branch itself. In that mode, documents of files removed in
    the range are deleted, and the processed range is recorded for the story.

    Only files yielded by ``iter_text_files`` for the configured extensions
    are indexed. Files that produce no chunks are skipped.

    Raises:
        SystemExit: if ``config.repo_path`` does not exist (full mode only).
    """
    config = load_config()
    conn = connect(config.db_dsn)
    ensure_schema(conn, config.embeddings_dim)
    story_id = get_or_create_story(conn, args.story)
    embedding_client = get_embedding_client(config.embeddings_dim)
    base_ref = args.base_ref.strip()
    head_ref = args.head_ref.strip()

    if args.changed:
        if base_ref.lower() == "auto":
            base_ref = (
                get_merge_base(config.repo_path, args.default_branch, head_ref)
                or args.default_branch
            )
        changed_files = get_changed_files(config.repo_path, base_ref, head_ref)
        removed = filter_removed(changed_files)
        existing = filter_existing(changed_files)
    else:
        removed = []
        repo_path = Path(config.repo_path)
        if not repo_path.exists():
            raise SystemExit(f"RAG_REPO_PATH does not exist: {config.repo_path}")
        existing = [p for p in repo_path.rglob("*") if p.is_file()]

    # iter_text_files yields (path, text) pairs. Materialize them once and
    # reuse the list below, so every file is read only once.
    text_files = list(iter_text_files(existing, config.allowed_extensions))
    print(
        f"repo={config.repo_path} all_files={len(existing)} "
        f"allowed={len(text_files)} ext={config.allowed_extensions}"
    )

    # Apply deletions before any early exit. A range that only removes files
    # must still drop their stale documents.
    for path in removed:
        delete_document(conn, story_id, str(path))

    if not text_files:
        print("No files to index (check path and extensions .md, .txt, .rst)")
    else:
        indexed = 0
        for path, text in text_files:
            chunks = _chunk_lines(text, config)
            if not chunks:
                continue
            base_chunks = None
            if args.changed:
                # When the file existed at base_ref, its chunks at that ref are
                # handed to replace_chunks alongside the new ones.
                base_text = read_file_at_ref(config.repo_path, path, base_ref)
                if base_text is not None:
                    base_chunks = _chunk_lines(base_text, config)
            embeddings = embedding_client.embed_texts(
                [chunk.text for chunk in chunks]
            )
            document_id = upsert_document(
                conn, story_id, str(path), _file_version(path)
            )
            replace_chunks(
                conn, document_id, chunks, embeddings, base_chunks=base_chunks
            )
            indexed += 1
        print(f"Indexed {indexed} documents for story={args.story}")

    # The range has been fully processed, even when it contained no indexable
    # files, so record it in --changed mode.
    if args.changed:
        update_story_indexed_range(conn, story_id, base_ref, head_ref)
def cmd_bot(args: argparse.Namespace) -> None:
    """Run the Telegram bot by delegating to ``rag_agent.telegram_bot.main``.

    ``args`` is not used. The bot module is imported here, inside the
    function, so it is only loaded when this subcommand actually runs.
    """
    import rag_agent.telegram_bot as telegram_bot

    telegram_bot.main()
def cmd_serve(args: argparse.Namespace) -> None:
    """Serve the app at ``rag_agent.webhook:app`` with uvicorn.

    The server binds to ``args.host`` and ``args.port`` and logs at ``info``
    level. uvicorn is imported lazily, so it is only loaded for this
    subcommand.
    """
    import uvicorn

    server_options = {
        "host": args.host,
        "port": args.port,
        "log_level": "info",
    }
    uvicorn.run("rag_agent.webhook:app", **server_options)
def cmd_ask(args: argparse.Namespace) -> None:
    """Answer ``args.question`` with the RAG pipeline and print the answer.

    Retrieves up to ``args.top_k`` chunks. When ``args.story`` is set, the
    story's id is passed to ``answer_query`` so retrieval is limited to it.

    Raises:
        SystemExit: if ``args.story`` is given but no such story exists.
    """
    config = load_config()
    conn = connect(config.db_dsn)
    ensure_schema(conn, config.embeddings_dim)

    story_id = get_story_id(conn, args.story) if args.story else None
    if args.story and story_id is None:
        raise SystemExit(f"Story not found: {args.story}")

    # The embedding client is built before the LLM client, as before.
    answer = answer_query(
        conn,
        config,
        get_embedding_client(config.embeddings_dim),
        get_llm_client(config),
        args.question,
        story_id=story_id,
        top_k=args.top_k,
    )
    print(answer)
def _positive_int(value: str) -> int:
    """Convert an argparse ``type=`` value, accepting only integers >= 1.

    Raises:
        argparse.ArgumentTypeError: if *value* is not an integer or is < 1.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}"
        ) from None
    if number < 1:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}"
        )
    return number


def build_parser() -> argparse.ArgumentParser:
    """Build the ``rag-agent`` argument parser.

    Defines the ``index``, ``ask``, ``serve`` and ``bot`` subcommands. Each
    subcommand stores its handler in ``args.func``, and exactly one
    subcommand is required.
    """
    parser = argparse.ArgumentParser(prog="rag-agent")
    sub = parser.add_subparsers(dest="command", required=True)

    index_parser = sub.add_parser("index", help="Index files into Postgres")
    index_parser.add_argument(
        "--story",
        required=True,
        help="Story slug (e.g. branch name or story id); documents are tied to this story",
    )
    index_parser.add_argument(
        "--changed",
        action="store_true",
        help=(
            "Index only files changed in the story range (base-ref..head-ref); "
            "all commits in range belong to the story"
        ),
    )
    index_parser.add_argument(
        "--base-ref",
        default="main",
        help=(
            "Start of story range (e.g. main). Use 'auto' for "
            "merge-base(default-branch, head-ref). All commits from base to "
            "head are the story."
        ),
    )
    index_parser.add_argument(
        "--head-ref",
        default="HEAD",
        help="End of story range (e.g. current branch tip)",
    )
    index_parser.add_argument(
        "--default-branch",
        default="main",
        help="Default branch name for --base-ref auto",
    )
    index_parser.set_defaults(func=cmd_index)

    ask_parser = sub.add_parser("ask", help="Ask a question")
    ask_parser.add_argument("question", help="Question text")
    ask_parser.add_argument(
        "--story",
        default=None,
        help="Limit retrieval to this story slug (optional)",
    )
    # Zero or negative values would make retrieval meaningless, so reject
    # them at parse time.
    ask_parser.add_argument(
        "--top-k", type=_positive_int, default=5, help="Top K chunks to retrieve"
    )
    ask_parser.set_defaults(func=cmd_ask)

    serve_parser = sub.add_parser(
        "serve",
        help="Run webhook server: on push to remote repo, pull and index changes",
    )
    serve_parser.add_argument(
        "--host",
        default="0.0.0.0",
        help="Bind host (default: 0.0.0.0)",
    )
    serve_parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Bind port (default: 8000)",
    )
    serve_parser.set_defaults(func=cmd_serve)

    bot_parser = sub.add_parser(
        "bot",
        help="Run Telegram bot: answers questions using RAG (requires TELEGRAM_BOT_TOKEN)",
    )
    bot_parser.set_defaults(func=cmd_bot)
    return parser
def main() -> None:
    """Parse the command line and dispatch to the chosen subcommand handler."""
    args = build_parser().parse_args()
    handler = args.func
    handler(args)
# Run the CLI only when this module is executed directly, not when imported.
if __name__ == "__main__":
    main()