Compare commits
1 Commits
304e4dae6d
...
95715dcae7
| Author | SHA1 | Date | |
|---|---|---|---|
| 95715dcae7 |
11
.env
Normal file
11
.env
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
RAG_DB_DSN=postgresql://rag:rag_secret@localhost:5432/rag
|
||||||
|
RAG_REPO_PATH=/Users/alex/Dev_projects_v2/documentation/
|
||||||
|
|
||||||
|
GIGACHAT_CREDENTIALS=MGMyOGExMzctZDY1YS00OGNkLTk3NGYtYzFkZWVjOTEzM2RkOjFjOTc0YjFlLWNlMDUtNDM4Zi04ZDA2LWZkODA5MjRhZTY3NA==
|
||||||
|
GIGACHAT_EMBEDDINGS_MODEL=Embeddings
|
||||||
|
GIGACHAT_VERIFY_SSL=false
|
||||||
|
|
||||||
|
RAG_CHUNK_SIZE_LINES=20
|
||||||
|
RAG_CHUNK_SIZE=300
|
||||||
|
|
||||||
|
RAG_EMBEDDINGS_DIM=1024
|
||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,4 +1 @@
|
|||||||
src/rag_agent/.env
|
src/rag_agent/.env
|
||||||
.env
|
|
||||||
docker/ssh
|
|
||||||
docker/postgres_test_data
|
|
||||||
16
README.md
16
README.md
@@ -123,21 +123,7 @@ Scripts: `scripts/create_db.py` (Python, uses `ensure_schema` and `RAG_*` env),
|
|||||||
|
|
||||||
If `GIGACHAT_CREDENTIALS` is set (e.g. in `.env` for local runs), embeddings use the GigaChat API; otherwise the stub client is used. Optional env: `GIGACHAT_EMBEDDINGS_MODEL` (default `Embeddings`), `GIGACHAT_VERIFY_SSL` (`true`/`false`). Ensure `RAG_EMBEDDINGS_DIM` matches the model output (see GigaChat docs).
|
If `GIGACHAT_CREDENTIALS` is set (e.g. in `.env` for local runs), embeddings use the GigaChat API; otherwise the stub client is used. Optional env: `GIGACHAT_EMBEDDINGS_MODEL` (default `Embeddings`), `GIGACHAT_VERIFY_SSL` (`true`/`false`). Ensure `RAG_EMBEDDINGS_DIM` matches the model output (see GigaChat docs).
|
||||||
|
|
||||||
## Agent (GigaChat)
|
|
||||||
|
|
||||||
Ответы на вопросы формирует агент на базе GigaChat: поиск по базе знаний (RAG) + генерация текста. Если задана переменная `GIGACHAT_CREDENTIALS`, используется `GigaChatLLMClient` в `src/rag_agent/agent/pipeline.py`; иначе — заглушка. Модель чата задаётся через `RAG_LLM_MODEL` (по умолчанию `GigaChat`).
|
|
||||||
|
|
||||||
## Telegram-бот
|
|
||||||
|
|
||||||
Общение с пользователем через бота в Telegram: бот отвечает на текстовые сообщения, используя знания из базы (RAG + GigaChat).
|
|
||||||
|
|
||||||
1. Создайте бота через [@BotFather](https://t.me/BotFather) и получите токен.
|
|
||||||
2. Добавьте в `.env`: `TELEGRAM_BOT_TOKEN=<токен>`.
|
|
||||||
3. Запуск: `rag-agent bot` (или `python -m rag_agent.telegram_bot`).
|
|
||||||
4. Через Docker: `docker compose up -d` поднимает БД, вебхук-сервер и бота в отдельных контейнерах; в `.env` должен быть задан `TELEGRAM_BOT_TOKEN`.
|
|
||||||
|
|
||||||
Требуются: `RAG_DB_DSN`, `RAG_REPO_PATH`, `GIGACHAT_CREDENTIALS`, `TELEGRAM_BOT_TOKEN`. Расширенное логирование (входящие сообщения, число эмбеддингов, число чанков из БД, ответ LLM): `RAG_BOT_VERBOSE_LOGGING=true|false` (по умолчанию `true` для отладки).
|
|
||||||
|
|
||||||
## Notes
|
## Notes
|
||||||
|
|
||||||
|
- The LLM client is still a stub; replace it in `src/rag_agent/agent/pipeline.py` to get real answers.
|
||||||
- This project requires Postgres with the `pgvector` extension.
|
- This project requires Postgres with the `pgvector` extension.
|
||||||
|
|||||||
@@ -58,31 +58,6 @@ services:
|
|||||||
networks:
|
networks:
|
||||||
- rag_net
|
- rag_net
|
||||||
|
|
||||||
bot:
|
|
||||||
build:
|
|
||||||
context: .
|
|
||||||
dockerfile: Dockerfile
|
|
||||||
image: rag-agent:latest
|
|
||||||
container_name: rag-bot
|
|
||||||
restart: unless-stopped
|
|
||||||
depends_on:
|
|
||||||
postgres:
|
|
||||||
condition: service_healthy
|
|
||||||
environment:
|
|
||||||
RAG_DB_DSN: "postgresql://${POSTGRES_USER:-rag}:${POSTGRES_PASSWORD:-rag_secret}@postgres:5432/${POSTGRES_DB:-rag}"
|
|
||||||
RAG_REPO_PATH: "/data"
|
|
||||||
RAG_EMBEDDINGS_DIM: ${RAG_EMBEDDINGS_DIM:-1024}
|
|
||||||
GIGACHAT_CREDENTIALS: ${GIGACHAT_CREDENTIALS:-}
|
|
||||||
GIGACHAT_EMBEDDINGS_MODEL: ${GIGACHAT_EMBEDDINGS_MODEL:-Embeddings}
|
|
||||||
TELEGRAM_BOT_TOKEN: ${TELEGRAM_BOT_TOKEN:-}
|
|
||||||
RAG_BOT_VERBOSE_LOGGING: ${RAG_BOT_VERBOSE_LOGGING:-true}
|
|
||||||
volumes:
|
|
||||||
- ${RAG_REPO_HOST:-${RAG_REPO_PATH:-./data}}:/data
|
|
||||||
entrypoint: ["rag-agent"]
|
|
||||||
command: ["bot"]
|
|
||||||
networks:
|
|
||||||
- rag_net
|
|
||||||
|
|
||||||
networks:
|
networks:
|
||||||
rag_net:
|
rag_net:
|
||||||
driver: bridge
|
driver: bridge
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ dependencies = [
|
|||||||
"gigachat>=0.2.0",
|
"gigachat>=0.2.0",
|
||||||
"fastapi>=0.115.0",
|
"fastapi>=0.115.0",
|
||||||
"uvicorn[standard]>=0.32.0",
|
"uvicorn[standard]>=0.32.0",
|
||||||
"python-telegram-bot>=21.0",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
@@ -1,21 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
import psycopg
|
import psycopg
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
from rag_agent.config import AppConfig
|
from rag_agent.config import AppConfig
|
||||||
from rag_agent.index.embeddings import EmbeddingClient
|
from rag_agent.index.embeddings import EmbeddingClient
|
||||||
from rag_agent.retrieval.search import search_similar
|
from rag_agent.retrieval.search import search_similar
|
||||||
|
|
||||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
|
||||||
load_dotenv(_repo_root / ".env")
|
|
||||||
|
|
||||||
|
|
||||||
class LLMClient(Protocol):
|
class LLMClient(Protocol):
|
||||||
def generate(self, prompt: str, model: str) -> str:
|
def generate(self, prompt: str, model: str) -> str:
|
||||||
@@ -27,49 +20,10 @@ class StubLLMClient:
|
|||||||
def generate(self, prompt: str, model: str) -> str:
|
def generate(self, prompt: str, model: str) -> str:
|
||||||
return (
|
return (
|
||||||
"LLM client is not configured. "
|
"LLM client is not configured. "
|
||||||
"Set GIGACHAT_CREDENTIALS in .env for GigaChat answers."
|
"Replace StubLLMClient with a real implementation."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class GigaChatLLMClient:
|
|
||||||
"""LLM generation via GigaChat API. Credentials from env GIGACHAT_CREDENTIALS."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
credentials: str,
|
|
||||||
model: str = "GigaChat",
|
|
||||||
verify_ssl_certs: bool = False,
|
|
||||||
) -> None:
|
|
||||||
self._credentials = credentials.strip()
|
|
||||||
self._model = model
|
|
||||||
self._verify_ssl_certs = verify_ssl_certs
|
|
||||||
|
|
||||||
def generate(self, prompt: str, model: str) -> str:
|
|
||||||
from gigachat import GigaChat
|
|
||||||
|
|
||||||
use_model = model or self._model
|
|
||||||
with GigaChat(
|
|
||||||
credentials=self._credentials,
|
|
||||||
model=use_model,
|
|
||||||
verify_ssl_certs=self._verify_ssl_certs,
|
|
||||||
) as giga:
|
|
||||||
response = giga.chat(prompt)
|
|
||||||
return (response.choices[0].message.content or "").strip()
|
|
||||||
|
|
||||||
|
|
||||||
def get_llm_client(config: AppConfig) -> LLMClient:
|
|
||||||
"""Return GigaChat LLM client if credentials set, else stub."""
|
|
||||||
credentials = os.getenv("GIGACHAT_CREDENTIALS", "").strip()
|
|
||||||
if credentials:
|
|
||||||
return GigaChatLLMClient(
|
|
||||||
credentials=credentials,
|
|
||||||
model=config.llm_model,
|
|
||||||
verify_ssl_certs=os.getenv("GIGACHAT_VERIFY_SSL", "false").lower()
|
|
||||||
in ("1", "true", "yes"),
|
|
||||||
)
|
|
||||||
return StubLLMClient()
|
|
||||||
|
|
||||||
|
|
||||||
def build_prompt(question: str, contexts: list[str]) -> str:
|
def build_prompt(question: str, contexts: list[str]) -> str:
|
||||||
joined = "\n\n".join(contexts)
|
joined = "\n\n".join(contexts)
|
||||||
return (
|
return (
|
||||||
@@ -88,32 +42,10 @@ def answer_query(
|
|||||||
top_k: int = 5,
|
top_k: int = 5,
|
||||||
story_id: int | None = None,
|
story_id: int | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
answer, _ = answer_query_with_stats(
|
query_embedding = embedding_client.embed_texts([question])[0]
|
||||||
conn, config, embedding_client, llm_client, question, top_k, story_id
|
|
||||||
)
|
|
||||||
return answer
|
|
||||||
|
|
||||||
|
|
||||||
def answer_query_with_stats(
|
|
||||||
conn: psycopg.Connection,
|
|
||||||
config: AppConfig,
|
|
||||||
embedding_client: EmbeddingClient,
|
|
||||||
llm_client: LLMClient,
|
|
||||||
question: str,
|
|
||||||
top_k: int = 5,
|
|
||||||
story_id: int | None = None,
|
|
||||||
) -> tuple[str, dict]:
|
|
||||||
"""Like answer_query but returns (answer, stats) for logging. stats: query_embeddings, chunks_found, answer."""
|
|
||||||
query_embeddings = embedding_client.embed_texts([question])
|
|
||||||
results = search_similar(
|
results = search_similar(
|
||||||
conn, query_embeddings[0], top_k=top_k, story_id=story_id
|
conn, query_embedding, top_k=top_k, story_id=story_id
|
||||||
)
|
)
|
||||||
contexts = [f"Source: {item.path}\n{item.content}" for item in results]
|
contexts = [f"Source: {item.path}\n{item.content}" for item in results]
|
||||||
prompt = build_prompt(question, contexts)
|
prompt = build_prompt(question, contexts)
|
||||||
answer = llm_client.generate(prompt, model=config.llm_model)
|
return llm_client.generate(prompt, model=config.llm_model)
|
||||||
stats = {
|
|
||||||
"query_embeddings": len(query_embeddings),
|
|
||||||
"chunks_found": len(results),
|
|
||||||
"answer": answer,
|
|
||||||
}
|
|
||||||
return answer, stats
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from rag_agent.index.postgres import (
|
|||||||
update_story_indexed_range,
|
update_story_indexed_range,
|
||||||
upsert_document,
|
upsert_document,
|
||||||
)
|
)
|
||||||
from rag_agent.agent.pipeline import answer_query, get_llm_client
|
from rag_agent.agent.pipeline import StubLLMClient, answer_query
|
||||||
|
|
||||||
|
|
||||||
def _file_version(path: Path) -> str:
|
def _file_version(path: Path) -> str:
|
||||||
@@ -55,24 +55,13 @@ def cmd_index(args: argparse.Namespace) -> None:
|
|||||||
existing = filter_existing(changed_files)
|
existing = filter_existing(changed_files)
|
||||||
else:
|
else:
|
||||||
removed = []
|
removed = []
|
||||||
repo_path = Path(config.repo_path)
|
existing = [
|
||||||
if not repo_path.exists():
|
p for p in Path(config.repo_path).rglob("*") if p.is_file()
|
||||||
raise SystemExit(f"RAG_REPO_PATH does not exist: {config.repo_path}")
|
]
|
||||||
existing = [p for p in repo_path.rglob("*") if p.is_file()]
|
|
||||||
|
|
||||||
allowed = list(iter_text_files(existing, config.allowed_extensions))
|
|
||||||
print(
|
|
||||||
f"repo={config.repo_path} all_files={len(existing)} "
|
|
||||||
f"allowed={len(allowed)} ext={config.allowed_extensions}"
|
|
||||||
)
|
|
||||||
if not allowed:
|
|
||||||
print("No files to index (check path and extensions .md, .txt, .rst)")
|
|
||||||
return
|
|
||||||
|
|
||||||
for path in removed:
|
for path in removed:
|
||||||
delete_document(conn, story_id, str(path))
|
delete_document(conn, story_id, str(path))
|
||||||
|
|
||||||
indexed = 0
|
|
||||||
for path, text in iter_text_files(existing, config.allowed_extensions):
|
for path, text in iter_text_files(existing, config.allowed_extensions):
|
||||||
chunks = chunk_text_by_lines(
|
chunks = chunk_text_by_lines(
|
||||||
text,
|
text,
|
||||||
@@ -99,18 +88,11 @@ def cmd_index(args: argparse.Namespace) -> None:
|
|||||||
replace_chunks(
|
replace_chunks(
|
||||||
conn, document_id, chunks, embeddings, base_chunks=base_chunks
|
conn, document_id, chunks, embeddings, base_chunks=base_chunks
|
||||||
)
|
)
|
||||||
indexed += 1
|
|
||||||
|
|
||||||
print(f"Indexed {indexed} documents for story={args.story}")
|
|
||||||
if args.changed:
|
if args.changed:
|
||||||
update_story_indexed_range(conn, story_id, base_ref, head_ref)
|
update_story_indexed_range(conn, story_id, base_ref, head_ref)
|
||||||
|
|
||||||
|
|
||||||
def cmd_bot(args: argparse.Namespace) -> None:
|
|
||||||
from rag_agent.telegram_bot import main as bot_main
|
|
||||||
bot_main()
|
|
||||||
|
|
||||||
|
|
||||||
def cmd_serve(args: argparse.Namespace) -> None:
|
def cmd_serve(args: argparse.Namespace) -> None:
|
||||||
import uvicorn
|
import uvicorn
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
@@ -131,7 +113,7 @@ def cmd_ask(args: argparse.Namespace) -> None:
|
|||||||
if story_id is None:
|
if story_id is None:
|
||||||
raise SystemExit(f"Story not found: {args.story}")
|
raise SystemExit(f"Story not found: {args.story}")
|
||||||
embedding_client = get_embedding_client(config.embeddings_dim)
|
embedding_client = get_embedding_client(config.embeddings_dim)
|
||||||
llm_client = get_llm_client(config)
|
llm_client = StubLLMClient()
|
||||||
answer = answer_query(
|
answer = answer_query(
|
||||||
conn,
|
conn,
|
||||||
config,
|
config,
|
||||||
@@ -203,12 +185,6 @@ def build_parser() -> argparse.ArgumentParser:
|
|||||||
)
|
)
|
||||||
serve_parser.set_defaults(func=cmd_serve)
|
serve_parser.set_defaults(func=cmd_serve)
|
||||||
|
|
||||||
bot_parser = sub.add_parser(
|
|
||||||
"bot",
|
|
||||||
help="Run Telegram bot: answers questions using RAG (requires TELEGRAM_BOT_TOKEN)",
|
|
||||||
)
|
|
||||||
bot_parser.set_defaults(func=cmd_bot)
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@@ -220,4 +196,3 @@ def main() -> None:
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ def load_config() -> AppConfig:
|
|||||||
chunk_overlap_lines=_env_int("RAG_CHUNK_OVERLAP_LINES", 8),
|
chunk_overlap_lines=_env_int("RAG_CHUNK_OVERLAP_LINES", 8),
|
||||||
embeddings_dim=_env_int("RAG_EMBEDDINGS_DIM", 1024), # GigaChat Embeddings = 1024; OpenAI = 1536
|
embeddings_dim=_env_int("RAG_EMBEDDINGS_DIM", 1024), # GigaChat Embeddings = 1024; OpenAI = 1536
|
||||||
embeddings_model=os.getenv("RAG_EMBEDDINGS_MODEL", "stub-embeddings"),
|
embeddings_model=os.getenv("RAG_EMBEDDINGS_MODEL", "stub-embeddings"),
|
||||||
llm_model=os.getenv("RAG_LLM_MODEL", "GigaChat"),
|
llm_model=os.getenv("RAG_LLM_MODEL", "stub-llm"),
|
||||||
allowed_extensions=tuple(
|
allowed_extensions=tuple(
|
||||||
_env_list("RAG_ALLOWED_EXTENSIONS", [".md", ".txt", ".rst"])
|
_env_list("RAG_ALLOWED_EXTENSIONS", [".md", ".txt", ".rst"])
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from datetime import datetime, timezone
|
|||||||
from typing import Iterable, Sequence
|
from typing import Iterable, Sequence
|
||||||
|
|
||||||
import psycopg
|
import psycopg
|
||||||
from pgvector import Vector
|
|
||||||
from pgvector.psycopg import register_vector
|
from pgvector.psycopg import register_vector
|
||||||
|
|
||||||
from rag_agent.ingest.chunker import TextChunk
|
from rag_agent.ingest.chunker import TextChunk
|
||||||
@@ -114,20 +113,16 @@ def ensure_schema(conn: psycopg.Connection, embeddings_dim: int) -> None:
|
|||||||
except psycopg.ProgrammingError:
|
except psycopg.ProgrammingError:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
pass
|
pass
|
||||||
cur.execute(
|
try:
|
||||||
"""
|
cur.execute(
|
||||||
DO $$
|
"""
|
||||||
BEGIN
|
|
||||||
IF NOT EXISTS (
|
|
||||||
SELECT 1 FROM pg_constraint
|
|
||||||
WHERE conrelid = 'chunks'::regclass AND conname = 'chunks_change_type_check'
|
|
||||||
) THEN
|
|
||||||
ALTER TABLE chunks ADD CONSTRAINT chunks_change_type_check
|
ALTER TABLE chunks ADD CONSTRAINT chunks_change_type_check
|
||||||
CHECK (change_type IN ('added', 'modified', 'unchanged'));
|
CHECK (change_type IN ('added', 'modified', 'unchanged'));
|
||||||
END IF;
|
"""
|
||||||
END $$;
|
)
|
||||||
"""
|
except psycopg.ProgrammingError:
|
||||||
)
|
conn.rollback()
|
||||||
|
pass # constraint may already exist
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
CREATE INDEX IF NOT EXISTS idx_documents_story_id
|
CREATE INDEX IF NOT EXISTS idx_documents_story_id
|
||||||
@@ -348,7 +343,6 @@ def fetch_similar(
|
|||||||
top_k: int,
|
top_k: int,
|
||||||
story_id: int | None = None,
|
story_id: int | None = None,
|
||||||
) -> list[tuple[str, str, float]]:
|
) -> list[tuple[str, str, float]]:
|
||||||
vec = Vector(query_embedding)
|
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
if story_id is not None:
|
if story_id is not None:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
@@ -360,7 +354,7 @@ def fetch_similar(
|
|||||||
ORDER BY c.embedding <=> %s
|
ORDER BY c.embedding <=> %s
|
||||||
LIMIT %s;
|
LIMIT %s;
|
||||||
""",
|
""",
|
||||||
(vec, story_id, vec, top_k),
|
(query_embedding, story_id, query_embedding, top_k),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
@@ -371,7 +365,7 @@ def fetch_similar(
|
|||||||
ORDER BY c.embedding <=> %s
|
ORDER BY c.embedding <=> %s
|
||||||
LIMIT %s;
|
LIMIT %s;
|
||||||
""",
|
""",
|
||||||
(vec, vec, top_k),
|
(query_embedding, query_embedding, top_k),
|
||||||
)
|
)
|
||||||
rows = cur.fetchall()
|
rows = cur.fetchall()
|
||||||
return [(row[0], row[1], row[2]) for row in rows]
|
return [(row[0], row[1], row[2]) for row in rows]
|
||||||
|
|||||||
@@ -1,156 +0,0 @@
|
|||||||
"""Telegram bot: answers user questions using RAG (retrieval + GigaChat)."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
_repo_root = Path(__file__).resolve().parent.parent
|
|
||||||
load_dotenv(_repo_root / ".env")
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Расширенное логирование: входящие сообщения, число эмбеддингов, число чанков из БД, ответ LLM.
|
|
||||||
# Включить/выключить: RAG_BOT_VERBOSE_LOGGING=true|false (по умолчанию true для отладки).
|
|
||||||
VERBOSE_LOGGING_MAX_ANSWER_CHARS = 500
|
|
||||||
|
|
||||||
|
|
||||||
def _verbose_logging_enabled() -> bool:
|
|
||||||
return os.getenv("RAG_BOT_VERBOSE_LOGGING", "true").lower() in ("1", "true", "yes")
|
|
||||||
|
|
||||||
|
|
||||||
def _run_rag(
|
|
||||||
question: str,
|
|
||||||
top_k: int = 5,
|
|
||||||
story_id: int | None = None,
|
|
||||||
with_stats: bool = False,
|
|
||||||
) -> str | tuple[str, dict]:
|
|
||||||
"""Synchronous RAG call: retrieval + LLM. Used from thread.
|
|
||||||
If with_stats=True, returns (answer, stats); else returns answer only.
|
|
||||||
"""
|
|
||||||
from rag_agent.config import load_config
|
|
||||||
from rag_agent.index.embeddings import get_embedding_client
|
|
||||||
from rag_agent.index.postgres import connect, ensure_schema
|
|
||||||
from rag_agent.agent.pipeline import answer_query, answer_query_with_stats, get_llm_client
|
|
||||||
|
|
||||||
config = load_config()
|
|
||||||
conn = connect(config.db_dsn)
|
|
||||||
try:
|
|
||||||
ensure_schema(conn, config.embeddings_dim)
|
|
||||||
embedding_client = get_embedding_client(config.embeddings_dim)
|
|
||||||
llm_client = get_llm_client(config)
|
|
||||||
if with_stats:
|
|
||||||
return answer_query_with_stats(
|
|
||||||
conn,
|
|
||||||
config,
|
|
||||||
embedding_client,
|
|
||||||
llm_client,
|
|
||||||
question,
|
|
||||||
top_k=top_k,
|
|
||||||
story_id=story_id,
|
|
||||||
)
|
|
||||||
return answer_query(
|
|
||||||
conn,
|
|
||||||
config,
|
|
||||||
embedding_client,
|
|
||||||
llm_client,
|
|
||||||
question,
|
|
||||||
top_k=top_k,
|
|
||||||
story_id=story_id,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
def run_bot() -> None:
|
|
||||||
token = os.getenv("TELEGRAM_BOT_TOKEN", "").strip()
|
|
||||||
if not token:
|
|
||||||
logger.error(
|
|
||||||
"TELEGRAM_BOT_TOKEN is required. Set it in .env or environment. "
|
|
||||||
"Container will stay up; restart after setting the token."
|
|
||||||
)
|
|
||||||
import time
|
|
||||||
while True:
|
|
||||||
time.sleep(3600)
|
|
||||||
|
|
||||||
from telegram import Update
|
|
||||||
from telegram.ext import Application, ContextTypes, MessageHandler, filters
|
|
||||||
|
|
||||||
verbose = _verbose_logging_enabled()
|
|
||||||
|
|
||||||
async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
|
||||||
if not update.message or not update.message.text:
|
|
||||||
return
|
|
||||||
question = update.message.text.strip()
|
|
||||||
if not question:
|
|
||||||
await update.message.reply_text("Напишите вопрос текстом.")
|
|
||||||
return
|
|
||||||
user_id = update.effective_user.id if update.effective_user else None
|
|
||||||
chat_id = update.effective_chat.id if update.effective_chat else None
|
|
||||||
if verbose:
|
|
||||||
logger.info(
|
|
||||||
"received message user_id=%s chat_id=%s text=%s",
|
|
||||||
user_id,
|
|
||||||
chat_id,
|
|
||||||
repr(question[:200] + ("…" if len(question) > 200 else "")),
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
result = await loop.run_in_executor(
|
|
||||||
None,
|
|
||||||
lambda: _run_rag(
|
|
||||||
question,
|
|
||||||
top_k=5,
|
|
||||||
story_id=None,
|
|
||||||
with_stats=verbose,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if verbose:
|
|
||||||
answer, stats = result
|
|
||||||
logger.info(
|
|
||||||
"query_embeddings=%s chunks_found=%s",
|
|
||||||
stats["query_embeddings"],
|
|
||||||
stats["chunks_found"],
|
|
||||||
)
|
|
||||||
answer_preview = stats["answer"]
|
|
||||||
if len(answer_preview) > VERBOSE_LOGGING_MAX_ANSWER_CHARS:
|
|
||||||
answer_preview = (
|
|
||||||
answer_preview[:VERBOSE_LOGGING_MAX_ANSWER_CHARS] + "…"
|
|
||||||
)
|
|
||||||
logger.info("llm_response=%s", repr(answer_preview))
|
|
||||||
else:
|
|
||||||
answer = result
|
|
||||||
if len(answer) > 4096:
|
|
||||||
answer = answer[:4090] + "\n…"
|
|
||||||
await update.message.reply_text(answer)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("RAG error")
|
|
||||||
await update.message.reply_text(
|
|
||||||
f"Не удалось получить ответ: {e!s}. "
|
|
||||||
"Проверьте RAG_DB_DSN и GIGACHAT_CREDENTIALS."
|
|
||||||
)
|
|
||||||
|
|
||||||
app = Application.builder().token(token).build()
|
|
||||||
app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_message))
|
|
||||||
logger.info("Telegram bot started (polling)")
|
|
||||||
app.run_polling(drop_pending_updates=True)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
logging.basicConfig(
|
|
||||||
format="%(asctime)s %(levelname)s %(name)s %(message)s",
|
|
||||||
level=logging.INFO,
|
|
||||||
)
|
|
||||||
# Убрать из лога пустые HTTP-ответы polling (без сообщений от пользователя)
|
|
||||||
for name in ("telegram", "httpx", "httpcore"):
|
|
||||||
logging.getLogger(name).setLevel(logging.WARNING)
|
|
||||||
try:
|
|
||||||
run_bot()
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
pass
|
|
||||||
except ValueError as e:
|
|
||||||
raise SystemExit(e) from e
|
|
||||||
@@ -34,14 +34,6 @@ def _branch_from_ref(ref: str) -> str | None:
|
|||||||
return ref.removeprefix("refs/heads/")
|
return ref.removeprefix("refs/heads/")
|
||||||
|
|
||||||
|
|
||||||
# GitHub/GitLab send null SHA as "before" when a branch is first created.
|
|
||||||
_NULL_SHA = "0000000000000000000000000000000000000000"
|
|
||||||
|
|
||||||
|
|
||||||
def _is_null_sha(sha: str | None) -> bool:
|
|
||||||
return sha is not None and sha == _NULL_SHA
|
|
||||||
|
|
||||||
|
|
||||||
def _verify_github_signature(body: bytes, secret: str, signature_header: str | None) -> bool:
|
def _verify_github_signature(body: bytes, secret: str, signature_header: str | None) -> bool:
|
||||||
if not secret or not signature_header or not signature_header.startswith("sha256="):
|
if not secret or not signature_header or not signature_header.startswith("sha256="):
|
||||||
return not secret
|
return not secret
|
||||||
@@ -126,36 +118,9 @@ def _pull_and_index(
|
|||||||
logger.warning("git checkout %s failed: %s", branch, _decode_stderr(e.stderr))
|
logger.warning("git checkout %s failed: %s", branch, _decode_stderr(e.stderr))
|
||||||
return
|
return
|
||||||
|
|
||||||
# Branch deletion: after is null SHA → nothing to index
|
|
||||||
if _is_null_sha(payload_after):
|
|
||||||
logger.info("webhook: branch deletion detected (after is null SHA) for branch=%s; skipping index", branch)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Prefer commit range from webhook payload (GitHub/GitLab before/after) so we index every push
|
# Prefer commit range from webhook payload (GitHub/GitLab before/after) so we index every push
|
||||||
# even when the clone is the same dir as the one that was pushed from (HEAD already at new commit).
|
# even when the clone is the same dir as the one that was pushed from (HEAD already at new commit).
|
||||||
if payload_before and payload_after and payload_before != payload_after:
|
if payload_before and payload_after and payload_before != payload_after:
|
||||||
# Update working tree to new commits (fetch only fetches refs; index reads files from disk)
|
|
||||||
origin_ref = f"origin/{branch}"
|
|
||||||
merge_proc = subprocess.run(
|
|
||||||
["git", "-C", repo_path, "merge", "--ff-only", origin_ref],
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
timeout=60,
|
|
||||||
)
|
|
||||||
if merge_proc.returncode != 0:
|
|
||||||
logger.warning(
|
|
||||||
"webhook: git merge --ff-only failed (branch=%s); stderr=%s",
|
|
||||||
branch, _decode_stderr(merge_proc.stderr),
|
|
||||||
)
|
|
||||||
return
|
|
||||||
# New branch: before is null SHA → use auto (merge-base with default branch)
|
|
||||||
if _is_null_sha(payload_before):
|
|
||||||
logger.info(
|
|
||||||
"webhook: new branch detected (before is null SHA), using --base-ref auto story=%s head=%s",
|
|
||||||
branch, payload_after,
|
|
||||||
)
|
|
||||||
_run_index(repo_path, story=branch, base_ref="auto", head_ref=payload_after)
|
|
||||||
return
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"webhook: running index from payload story=%s %s..%s",
|
"webhook: running index from payload story=%s %s..%s",
|
||||||
branch, payload_before, payload_after,
|
branch, payload_before, payload_after,
|
||||||
|
|||||||
Reference in New Issue
Block a user