Первый коммит

This commit is contained in:
2026-02-25 14:47:19 +03:00
commit 1e376aff24
170 changed files with 4893 additions and 0 deletions

View File

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

View File

@@ -0,0 +1,9 @@
from app.modules.shared.gigachat.client import GigaChatClient
class GigaChatEmbedder:
    """Embedding adapter that delegates vectorization to a GigaChatClient."""

    def __init__(self, client: GigaChatClient) -> None:
        # Keep the transport client private; the embedder exposes only embed().
        self._gigachat = client

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding vector (list of floats) per input text."""
        return self._gigachat.embed(texts)

View File

@@ -0,0 +1,141 @@
import asyncio
from collections import defaultdict
from app.schemas.common import ErrorPayload, ModuleName
from app.schemas.indexing import IndexJobStatus
from app.modules.contracts import RagIndexer
from app.modules.rag.job_store import IndexJob, IndexJobStore
from app.modules.shared.event_bus import EventBus
from app.modules.shared.retry_executor import RetryExecutor
class IndexingOrchestrator:
    """Runs indexing jobs in the background, serialized per RAG session.

    Each enqueue_* call persists a job via the IndexJobStore, then schedules
    the actual indexing as an asyncio task. Status, per-file progress, and a
    final "terminal" event are published on the EventBus keyed by job id so
    SSE subscribers can follow the job.
    """

    def __init__(
        self,
        store: IndexJobStore,
        rag: RagIndexer,
        events: EventBus,
        retry: RetryExecutor,
    ) -> None:
        self._store = store
        self._rag = rag
        self._events = events
        self._retry = retry
        # One lock per rag_session_id: jobs for the same session run one at a
        # time, while jobs for different sessions may index concurrently.
        self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
        # Strong references to in-flight background tasks. The event loop
        # keeps only weak references to tasks, so a fire-and-forget
        # create_task() result can be garbage-collected mid-run unless we
        # hold it here until completion.
        self._tasks: set[asyncio.Task] = set()

    def _spawn(self, coro) -> None:
        """Schedule *coro* in the background, holding a reference until done."""
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    async def enqueue_snapshot(self, rag_session_id: str, files: list[dict]) -> IndexJob:
        """Create a full-snapshot indexing job and return it immediately."""
        job = self._store.create(rag_session_id)
        self._spawn(self._process_snapshot(job.index_job_id, rag_session_id, files))
        return job

    async def enqueue_changes(self, rag_session_id: str, changed_files: list[dict]) -> IndexJob:
        """Create an incremental indexing job and return it immediately."""
        job = self._store.create(rag_session_id)
        self._spawn(self._process_changes(job.index_job_id, rag_session_id, changed_files))
        return job

    async def _process_snapshot(self, job_id: str, rag_session_id: str, files: list[dict]) -> None:
        """Background driver for a snapshot job."""
        await self._run_with_project_lock(
            job_id=job_id,
            rag_session_id=rag_session_id,
            total_files=len(files),
            operation=lambda progress_cb: self._rag.index_snapshot(
                rag_session_id=rag_session_id,
                files=files,
                progress_cb=progress_cb,
            ),
        )

    async def _process_changes(self, job_id: str, rag_session_id: str, changed_files: list[dict]) -> None:
        """Background driver for an incremental-changes job."""
        await self._run_with_project_lock(
            job_id=job_id,
            rag_session_id=rag_session_id,
            total_files=len(changed_files),
            operation=lambda progress_cb: self._rag.index_changes(
                rag_session_id=rag_session_id,
                changed_files=changed_files,
                progress_cb=progress_cb,
            ),
        )

    async def _run_with_project_lock(self, job_id: str, rag_session_id: str, total_files: int, operation) -> None:
        """Run *operation* under the session lock, publishing lifecycle events.

        *operation* receives an async progress callback and must return a
        (indexed_files, failed_files) tuple. Transient failures (timeouts,
        connection/OS errors) that survive the retry executor mark the job
        ERROR and publish a terminal error event; other exceptions propagate
        to the task.
        """
        lock = self._locks[rag_session_id]
        async with lock:
            job = self._store.get(job_id)
            if not job:
                # Job vanished from the store before we got the lock.
                return
            job.status = IndexJobStatus.RUNNING
            self._store.save(job)
            await self._events.publish(
                job_id,
                "index_status",
                {"index_job_id": job_id, "status": job.status.value, "total_files": total_files},
            )
            try:

                async def progress_cb(current_file_index: int, total: int, current_file_name: str) -> None:
                    # Relay per-file progress to SSE subscribers.
                    await self._events.publish(
                        job_id,
                        "index_progress",
                        {
                            "index_job_id": job_id,
                            "current_file_index": current_file_index,
                            "total_files": total,
                            "processed_files": current_file_index,
                            "current_file_path": current_file_name,
                            "current_file_name": current_file_name,
                        },
                    )

                indexed, failed = await self._retry.run(lambda: operation(progress_cb))
                job.status = IndexJobStatus.DONE
                job.indexed_files = indexed
                job.failed_files = failed
                self._store.save(job)
                await self._events.publish(
                    job_id,
                    "index_status",
                    {
                        "index_job_id": job_id,
                        "status": job.status.value,
                        "indexed_files": indexed,
                        "failed_files": failed,
                        "total_files": total_files,
                    },
                )
                # "terminal" tells SSE consumers the stream is complete.
                await self._events.publish(
                    job_id,
                    "terminal",
                    {
                        "index_job_id": job_id,
                        "status": "done",
                        "indexed_files": indexed,
                        "failed_files": failed,
                        "total_files": total_files,
                    },
                )
            except (TimeoutError, ConnectionError, OSError) as exc:
                # Retries are exhausted; persist the error and close the stream.
                job.status = IndexJobStatus.ERROR
                job.error = ErrorPayload(
                    code="index_retry_exhausted",
                    desc=f"Temporary indexing failure after retries: {exc}",
                    module=ModuleName.RAG,
                )
                self._store.save(job)
                await self._events.publish(
                    job_id,
                    "index_status",
                    {"index_job_id": job_id, "status": job.status.value, "total_files": total_files},
                )
                await self._events.publish(
                    job_id,
                    "terminal",
                    {
                        "index_job_id": job_id,
                        "status": "error",
                        "total_files": total_files,
                        "error": {
                            "code": job.error.code,
                            "desc": job.error.desc,
                            "module": job.error.module.value,
                        },
                    },
                )

View File

@@ -0,0 +1,66 @@
from dataclasses import dataclass
from uuid import uuid4
from app.modules.rag.repository import RagRepository
from app.schemas.common import ErrorPayload, ModuleName
from app.schemas.indexing import IndexJobStatus
@dataclass
class IndexJob:
    """Mutable state of a single indexing job for one RAG session."""

    index_job_id: str  # UUID identifying the job
    rag_session_id: str  # session whose chunks this job writes
    status: IndexJobStatus = IndexJobStatus.QUEUED  # lifecycle state
    indexed_files: int = 0  # files indexed successfully so far
    failed_files: int = 0  # files that failed chunking/embedding
    error: ErrorPayload | None = None  # terminal error details, if any
class IndexJobStore:
    """Maps IndexJob dataclasses to rows persisted through RagRepository."""

    def __init__(self, repository: RagRepository) -> None:
        self._repo = repository

    def create(self, rag_session_id: str) -> IndexJob:
        """Persist and return a new queued job with a fresh UUID."""
        new_job = IndexJob(index_job_id=str(uuid4()), rag_session_id=rag_session_id)
        self._repo.create_job(new_job.index_job_id, rag_session_id, new_job.status.value)
        return new_job

    def get(self, index_job_id: str) -> IndexJob | None:
        """Load a job by id, or None when no such row exists."""
        row = self._repo.get_job(index_job_id)
        if not row:
            return None
        return IndexJob(
            index_job_id=row.index_job_id,
            rag_session_id=row.rag_session_id,
            status=IndexJobStatus(row.status),
            indexed_files=row.indexed_files,
            failed_files=row.failed_files,
            error=self._error_from_row(row),
        )

    @staticmethod
    def _error_from_row(row) -> ErrorPayload | None:
        """Rebuild the stored error payload; default the module to RAG on
        missing or unrecognized module values."""
        if not row.error_code:
            return None
        module = ModuleName.RAG
        if row.error_module:
            try:
                module = ModuleName(row.error_module)
            except ValueError:
                module = ModuleName.RAG
        return ErrorPayload(
            code=row.error_code,
            desc=row.error_desc or "",
            module=module,
        )

    def save(self, job: IndexJob) -> None:
        """Write the job's mutable fields (status, counters, error) back."""
        err = job.error
        self._repo.update_job(
            job.index_job_id,
            status=job.status.value,
            indexed_files=job.indexed_files,
            failed_files=job.failed_files,
            error_code=err.code if err else None,
            error_desc=err.desc if err else None,
            error_module=err.module.value if err else None,
        )

247
app/modules/rag/module.py Normal file
View File

@@ -0,0 +1,247 @@
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from app.core.exceptions import AppError
from app.modules.rag.embedding.gigachat_embedder import GigaChatEmbedder
from app.modules.rag.indexing_service import IndexingOrchestrator
from app.modules.rag.job_store import IndexJobStore
from app.modules.rag.repository import RagRepository
from app.modules.rag.retrieval.chunker import TextChunker
from app.modules.rag.session_store import RagSessionStore
from app.modules.rag.service import RagService
from app.modules.shared.event_bus import EventBus
from app.modules.shared.gigachat.client import GigaChatClient
from app.modules.shared.gigachat.settings import GigaChatSettings
from app.modules.shared.gigachat.token_provider import GigaChatTokenProvider
from app.modules.shared.retry_executor import RetryExecutor
from app.schemas.common import ModuleName
from app.schemas.indexing import (
IndexChangesRequest,
IndexJobQueuedResponse,
IndexJobResponse,
IndexSnapshotRequest,
)
from app.schemas.rag_sessions import (
RagSessionChangesRequest,
RagSessionCreateRequest,
RagSessionCreateResponse,
RagSessionJobResponse,
)
class RagModule:
    """Composition root for the RAG subsystem.

    Builds the GigaChat embedding stack, the service/store layer, and the
    indexing orchestrator, and exposes three routers: session-scoped public
    routes, legacy /api/index routes, and internal routes.
    """

    def __init__(self, event_bus: EventBus, retry: RetryExecutor, repository: RagRepository) -> None:
        self._events = event_bus
        self.repository = repository
        # GigaChat wiring: settings -> token provider -> client -> embedder.
        settings = GigaChatSettings.from_env()
        token_provider = GigaChatTokenProvider(settings)
        client = GigaChatClient(settings, token_provider)
        embedder = GigaChatEmbedder(client)
        self.rag = RagService(embedder=embedder, repository=repository, chunker=TextChunker())
        self.sessions = RagSessionStore(repository)
        self.jobs = IndexJobStore(repository)
        self.indexing = IndexingOrchestrator(
            store=self.jobs,
            rag=self.rag,
            events=event_bus,
            retry=retry,
        )

    async def _job_events_response(self, index_job_id: str) -> StreamingResponse:
        """Build the SSE response that replays and follows a job's events.

        Shared by the session-scoped and legacy event endpoints. Emits a
        ": keepalive" comment after 10s of inactivity and closes the stream
        after the "terminal" event.
        """
        queue = await self._events.subscribe(index_job_id, replay=True)

        async def event_stream():
            import asyncio

            heartbeat = 10  # seconds of silence before a keepalive comment
            try:
                while True:
                    try:
                        event = await asyncio.wait_for(queue.get(), timeout=heartbeat)
                        yield EventBus.as_sse(event)
                        if event.name == "terminal":
                            break
                    except asyncio.TimeoutError:
                        yield ": keepalive\n\n"
            finally:
                await self._events.unsubscribe(index_job_id, queue)

        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache, no-transform",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    def public_router(self) -> APIRouter:
        """Return the public router (session routes + legacy /api/index routes)."""
        router = APIRouter(tags=["rag"])

        @router.post("/api/rag/sessions", response_model=RagSessionCreateResponse)
        async def create_rag_session(request: RagSessionCreateRequest) -> RagSessionCreateResponse:
            # Creating a session immediately queues a full snapshot index.
            session = self.sessions.create(request.project_id)
            job = await self.indexing.enqueue_snapshot(
                rag_session_id=session.rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return RagSessionCreateResponse(
                rag_session_id=session.rag_session_id,
                index_job_id=job.index_job_id,
                status=job.status,
            )

        @router.post("/api/rag/sessions/{rag_session_id}/changes", response_model=IndexJobQueuedResponse)
        async def rag_session_changes(
            rag_session_id: str,
            request: RagSessionChangesRequest,
        ) -> IndexJobQueuedResponse:
            session = self.sessions.get(rag_session_id)
            if not session:
                raise AppError("not_found", f"RAG session not found: {rag_session_id}", ModuleName.RAG)
            job = await self.indexing.enqueue_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @router.get("/api/rag/sessions/{rag_session_id}/jobs/{index_job_id}", response_model=RagSessionJobResponse)
        async def rag_session_job(rag_session_id: str, index_job_id: str) -> RagSessionJobResponse:
            job = self.jobs.get(index_job_id)
            # The job must exist AND belong to the addressed session.
            if not job or job.rag_session_id != rag_session_id:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            return RagSessionJobResponse(
                rag_session_id=rag_session_id,
                index_job_id=job.index_job_id,
                status=job.status,
                indexed_files=job.indexed_files,
                failed_files=job.failed_files,
                error=job.error.model_dump(mode="json") if job.error else None,
            )

        @router.get("/api/rag/sessions/{rag_session_id}/jobs/{index_job_id}/events")
        async def rag_session_job_events(rag_session_id: str, index_job_id: str) -> StreamingResponse:
            job = self.jobs.get(index_job_id)
            if not job or job.rag_session_id != rag_session_id:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            return await self._job_events_response(index_job_id)

        # Legacy compatibility endpoints.
        legacy = APIRouter(prefix="/api/index", tags=["index"])

        @legacy.post("/snapshot", response_model=IndexJobQueuedResponse)
        async def index_snapshot(request: IndexSnapshotRequest) -> IndexJobQueuedResponse:
            # Legacy routes use project_id as the session id.
            session = self.sessions.put(
                rag_session_id=request.project_id,
                project_id=request.project_id,
            )
            job = await self.indexing.enqueue_snapshot(
                rag_session_id=session.rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @legacy.post("/changes", response_model=IndexJobQueuedResponse)
        async def index_changes(request: IndexChangesRequest) -> IndexJobQueuedResponse:
            rag_session_id = request.project_id
            if not self.sessions.get(rag_session_id):
                self.sessions.put(rag_session_id=rag_session_id, project_id=rag_session_id)
            job = await self.indexing.enqueue_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @legacy.get("/jobs/{index_job_id}", response_model=IndexJobResponse)
        async def get_index_job(index_job_id: str) -> IndexJobResponse:
            job = self.jobs.get(index_job_id)
            if not job:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            return IndexJobResponse(
                index_job_id=job.index_job_id,
                status=job.status,
                indexed_files=job.indexed_files,
                failed_files=job.failed_files,
                error=job.error,
            )

        @legacy.get("/jobs/{index_job_id}/events")
        async def get_index_job_events(index_job_id: str) -> StreamingResponse:
            job = self.jobs.get(index_job_id)
            if not job:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            return await self._job_events_response(index_job_id)

        router.include_router(legacy)
        return router

    def internal_router(self) -> APIRouter:
        """Return the internal router; these routes index synchronously
        (no background job) and return raw dict payloads."""
        router = APIRouter(prefix="/internal/rag", tags=["internal-rag"])

        @router.post("/index/snapshot")
        async def index_snapshot(request: IndexSnapshotRequest) -> dict:
            rag_session_id = request.project_id
            if not self.sessions.get(rag_session_id):
                self.sessions.put(rag_session_id=rag_session_id, project_id=rag_session_id)
            indexed, failed = await self.rag.index_snapshot(
                rag_session_id=rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return {"indexed_files": indexed, "failed_files": failed}

        @router.post("/index/changes")
        async def index_changes(request: IndexChangesRequest) -> dict:
            rag_session_id = request.project_id
            indexed, failed = await self.rag.index_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return {"indexed_files": indexed, "failed_files": failed}

        @router.get("/index/jobs/{index_job_id}")
        async def get_job(index_job_id: str) -> dict:
            job = self.jobs.get(index_job_id)
            if not job:
                return {"status": "not_found"}
            return {
                "index_job_id": job.index_job_id,
                "status": job.status.value,
                "indexed_files": job.indexed_files,
                "failed_files": job.failed_files,
                "error": job.error.model_dump(mode="json") if job.error else None,
            }

        @router.post("/retrieve")
        async def retrieve(payload: dict) -> dict:
            # Accept either the new or the legacy key for the session id.
            rag_session_id = payload.get("rag_session_id") or payload.get("project_id", "")
            ctx = await self.rag.retrieve(
                rag_session_id=rag_session_id,
                query=payload.get("query", ""),
            )
            return {"items": ctx}

        return router

View File

@@ -0,0 +1,261 @@
from __future__ import annotations
from dataclasses import dataclass
from sqlalchemy import text
from app.modules.shared.db import get_engine
@dataclass
class RagJobRow:
    """Raw row loaded from the rag_index_jobs table."""

    index_job_id: str
    rag_session_id: str
    status: str  # stored as plain text; parsed to IndexJobStatus by callers
    indexed_files: int
    failed_files: int
    error_code: str | None  # NULL when the job has no terminal error
    error_desc: str | None
    error_module: str | None
class RagRepository:
    """Raw-SQL persistence for RAG sessions, index jobs, and text chunks.

    Embeddings are stored in a pgvector "vector" column. Write methods run
    inside ``engine.begin()`` so each call is one atomic transaction
    (commit on success, rollback on error); read methods use plain
    connections.
    """

    def ensure_tables(self) -> None:
        """Idempotently create the pgvector extension, tables, and indexes."""
        engine = get_engine()
        with engine.begin() as conn:
            conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
            conn.execute(
                text(
                    """
                    CREATE TABLE IF NOT EXISTS rag_sessions (
                        rag_session_id VARCHAR(64) PRIMARY KEY,
                        project_id VARCHAR(512) NOT NULL,
                        created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
                    )
                    """
                )
            )
            conn.execute(
                text(
                    """
                    CREATE TABLE IF NOT EXISTS rag_index_jobs (
                        index_job_id VARCHAR(64) PRIMARY KEY,
                        rag_session_id VARCHAR(64) NOT NULL,
                        status VARCHAR(16) NOT NULL,
                        indexed_files INTEGER NOT NULL DEFAULT 0,
                        failed_files INTEGER NOT NULL DEFAULT 0,
                        error_code VARCHAR(128) NULL,
                        error_desc TEXT NULL,
                        error_module VARCHAR(64) NULL,
                        created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
                    )
                    """
                )
            )
            conn.execute(
                text(
                    """
                    CREATE TABLE IF NOT EXISTS rag_chunks (
                        id BIGSERIAL PRIMARY KEY,
                        rag_session_id VARCHAR(64) NOT NULL,
                        path TEXT NOT NULL,
                        chunk_index INTEGER NOT NULL,
                        content TEXT NOT NULL,
                        embedding vector NULL,
                        created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
                    )
                    """
                )
            )
            # Migration for pre-existing rag_chunks tables missing timestamps.
            conn.execute(
                text(
                    """
                    ALTER TABLE rag_chunks
                    ADD COLUMN IF NOT EXISTS created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
                    """
                )
            )
            conn.execute(
                text(
                    """
                    ALTER TABLE rag_chunks
                    ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
                    """
                )
            )
            conn.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_chunks_session ON rag_chunks (rag_session_id)"))

    def upsert_session(self, rag_session_id: str, project_id: str) -> None:
        """Insert a session row, or update its project_id if it exists."""
        with get_engine().begin() as conn:
            conn.execute(
                text(
                    """
                    INSERT INTO rag_sessions (rag_session_id, project_id)
                    VALUES (:sid, :pid)
                    ON CONFLICT (rag_session_id) DO UPDATE SET project_id = EXCLUDED.project_id
                    """
                ),
                {"sid": rag_session_id, "pid": project_id},
            )

    def session_exists(self, rag_session_id: str) -> bool:
        """Return True when a session row with this id exists."""
        with get_engine().connect() as conn:
            row = conn.execute(
                text("SELECT 1 FROM rag_sessions WHERE rag_session_id = :sid"),
                {"sid": rag_session_id},
            ).fetchone()
            return bool(row)

    def get_session(self, rag_session_id: str) -> dict | None:
        """Return the session row as a dict, or None when not found."""
        with get_engine().connect() as conn:
            row = conn.execute(
                text("SELECT rag_session_id, project_id FROM rag_sessions WHERE rag_session_id = :sid"),
                {"sid": rag_session_id},
            ).mappings().fetchone()
            return dict(row) if row else None

    def create_job(self, index_job_id: str, rag_session_id: str, status: str) -> None:
        """Insert a new index-job row with the given initial status."""
        with get_engine().begin() as conn:
            conn.execute(
                text(
                    """
                    INSERT INTO rag_index_jobs (index_job_id, rag_session_id, status)
                    VALUES (:jid, :sid, :status)
                    """
                ),
                {"jid": index_job_id, "sid": rag_session_id, "status": status},
            )

    def update_job(
        self,
        index_job_id: str,
        *,
        status: str,
        indexed_files: int,
        failed_files: int,
        error_code: str | None = None,
        error_desc: str | None = None,
        error_module: str | None = None,
    ) -> None:
        """Overwrite a job's mutable columns and touch updated_at."""
        with get_engine().begin() as conn:
            conn.execute(
                text(
                    """
                    UPDATE rag_index_jobs
                    SET status = :status,
                        indexed_files = :indexed,
                        failed_files = :failed,
                        error_code = :ecode,
                        error_desc = :edesc,
                        error_module = :emodule,
                        updated_at = CURRENT_TIMESTAMP
                    WHERE index_job_id = :jid
                    """
                ),
                {
                    "jid": index_job_id,
                    "status": status,
                    "indexed": indexed_files,
                    "failed": failed_files,
                    "ecode": error_code,
                    "edesc": error_desc,
                    "emodule": error_module,
                },
            )

    def get_job(self, index_job_id: str) -> RagJobRow | None:
        """Load one job row, or None when not found."""
        with get_engine().connect() as conn:
            row = conn.execute(
                text(
                    """
                    SELECT index_job_id, rag_session_id, status, indexed_files, failed_files,
                           error_code, error_desc, error_module
                    FROM rag_index_jobs
                    WHERE index_job_id = :jid
                    """
                ),
                {"jid": index_job_id},
            ).mappings().fetchone()
            if not row:
                return None
            return RagJobRow(**dict(row))

    def replace_chunks(self, rag_session_id: str, items: list[dict]) -> None:
        """Atomically replace ALL chunks of a session with *items*."""
        with get_engine().begin() as conn:
            conn.execute(text("DELETE FROM rag_chunks WHERE rag_session_id = :sid"), {"sid": rag_session_id})
            self._insert_chunks(conn, rag_session_id, items)

    def apply_changes(self, rag_session_id: str, delete_paths: list[str], upserts: list[dict]) -> None:
        """Atomically apply incremental changes: drop deleted paths, then
        delete-and-reinsert every upserted path's chunks."""
        with get_engine().begin() as conn:
            if delete_paths:
                conn.execute(
                    text("DELETE FROM rag_chunks WHERE rag_session_id = :sid AND path = ANY(:paths)"),
                    {"sid": rag_session_id, "paths": delete_paths},
                )
            if upserts:
                # Upsert = delete the path's old chunks, then insert anew.
                paths = sorted({str(x["path"]) for x in upserts})
                conn.execute(
                    text("DELETE FROM rag_chunks WHERE rag_session_id = :sid AND path = ANY(:paths)"),
                    {"sid": rag_session_id, "paths": paths},
                )
                self._insert_chunks(conn, rag_session_id, upserts)

    def retrieve(self, rag_session_id: str, query_embedding: list[float], limit: int = 5) -> list[dict]:
        """Return the *limit* chunks nearest to *query_embedding*.

        Ordered by the pgvector `<=>` operator (cosine distance) against the
        query vector serialized in pgvector's "[x,y,...]" text form.
        """
        emb = "[" + ",".join(str(x) for x in query_embedding) + "]"
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(
                    """
                    SELECT path, content
                    FROM rag_chunks
                    WHERE rag_session_id = :sid
                    ORDER BY embedding <=> CAST(:emb AS vector)
                    LIMIT :lim
                    """
                ),
                {"sid": rag_session_id, "emb": emb, "lim": limit},
            ).mappings().fetchall()
            return [dict(x) for x in rows]

    def fallback_chunks(self, rag_session_id: str, limit: int = 5) -> list[dict]:
        """Return the most recently inserted chunks; used when similarity
        search is unavailable (e.g. embedding failed)."""
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(
                    """
                    SELECT path, content
                    FROM rag_chunks
                    WHERE rag_session_id = :sid
                    ORDER BY id DESC
                    LIMIT :lim
                    """
                ),
                {"sid": rag_session_id, "lim": limit},
            ).mappings().fetchall()
            return [dict(x) for x in rows]

    def _insert_chunks(self, conn, rag_session_id: str, items: list[dict]) -> None:
        """Insert chunk rows on the caller's connection/transaction.

        Each item is a dict with path, chunk_index, content, and optional
        embedding; a missing/empty embedding is stored as NULL.
        """
        for item in items:
            emb = item.get("embedding") or []
            emb_str = "[" + ",".join(str(x) for x in emb) + "]" if emb else None
            conn.execute(
                text(
                    """
                    INSERT INTO rag_chunks (rag_session_id, path, chunk_index, content, embedding, created_at, updated_at)
                    VALUES (:sid, :path, :idx, :content, CAST(:emb AS vector), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
                    """
                ),
                {
                    "sid": rag_session_id,
                    "path": item["path"],
                    "idx": int(item["chunk_index"]),
                    "content": item["content"],
                    "emb": emb_str,
                },
            )

View File

View File

@@ -0,0 +1,20 @@
class TextChunker:
    """Splits text into overlapping, fixed-size chunks for embedding."""

    def __init__(self, chunk_size: int = 900, overlap: int = 120) -> None:
        """Configure the sliding window.

        Args:
            chunk_size: maximum characters per chunk; must be positive.
            overlap: characters shared between adjacent chunks; must satisfy
                0 <= overlap < chunk_size.

        Raises:
            ValueError: if chunk_size <= 0, or overlap is negative or
                >= chunk_size — in the latter case the window in chunk()
                would never advance and loop forever.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        if overlap < 0 or overlap >= chunk_size:
            raise ValueError("overlap must be in [0, chunk_size)")
        self._chunk_size = chunk_size
        self._overlap = overlap

    def chunk(self, text: str) -> list[str]:
        """Return stripped, non-empty chunks of *text*; [] for blank input.

        CRLF line endings are normalized to LF before chunking.
        """
        cleaned = text.replace("\r\n", "\n")
        if not cleaned.strip():
            return []
        chunks: list[str] = []
        start = 0
        length = len(cleaned)
        while start < length:
            end = min(length, start + self._chunk_size)
            piece = cleaned[start:end].strip()
            if piece:
                chunks.append(piece)
            if end == length:
                break
            # Step back by the overlap so adjacent chunks share context.
            # Guaranteed to advance because overlap < chunk_size.
            start = end - self._overlap
        return chunks

View File

@@ -0,0 +1,12 @@
import math
def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two equal-length vectors.

    Returns -1.0 as a sentinel when the vectors are empty, have different
    lengths, or either has zero magnitude.
    """
    if len(a) != len(b) or not a:
        return -1.0
    dot = 0.0
    sq_a = 0.0
    sq_b = 0.0
    for x, y in zip(a, b):
        dot += x * y
        sq_a += x * x
        sq_b += y * y
    if sq_a == 0.0 or sq_b == 0.0:
        return -1.0
    return dot / (math.sqrt(sq_a) * math.sqrt(sq_b))

134
app/modules/rag/service.py Normal file
View File

@@ -0,0 +1,134 @@
import asyncio
import os
from collections.abc import Awaitable, Callable
from inspect import isawaitable
from app.modules.rag.embedding.gigachat_embedder import GigaChatEmbedder
from app.modules.rag.repository import RagRepository
from app.modules.rag.retrieval.chunker import TextChunker
class RagService:
    """Chunk → embed → store pipeline for RAG indexing and retrieval."""

    def __init__(
        self,
        embedder: GigaChatEmbedder,
        repository: RagRepository,
        chunker: TextChunker | None = None,
    ) -> None:
        self._embedder = embedder
        self._repo = repository
        self._chunker = chunker or TextChunker()

    async def index_snapshot(
        self,
        rag_session_id: str,
        files: list[dict],
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None = None,
    ) -> tuple[int, int]:
        """Re-index a full project snapshot.

        Replaces ALL stored chunks of the session with chunks built from
        *files*. Per-file failures are counted but do not abort the snapshot
        (best-effort per file). Chunking/embedding and the DB write run in a
        worker thread to keep the event loop responsive.

        Returns:
            (indexed_files, failed_files)
        """
        total_files = len(files)
        indexed_files = 0
        failed_files = 0
        all_chunks: list[dict] = []
        for index, file in enumerate(files, start=1):
            path = str(file.get("path", ""))
            try:
                chunks = self._build_chunks_for_file(file)
                embedded_chunks = await asyncio.to_thread(self._embed_chunks, chunks)
                all_chunks.extend(embedded_chunks)
                indexed_files += 1
            except Exception:
                # Best-effort: a bad file must not poison the whole snapshot.
                failed_files += 1
            await self._notify_progress(progress_cb, index, total_files, path)
        await asyncio.to_thread(self._repo.replace_chunks, rag_session_id, all_chunks)
        return indexed_files, failed_files

    async def index_changes(
        self,
        rag_session_id: str,
        changed_files: list[dict],
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None = None,
    ) -> tuple[int, int]:
        """Apply incremental changes: "delete" removes a path's chunks,
        "upsert" (with content) replaces them; anything else counts as failed.

        Returns:
            (indexed_files, failed_files)
        """
        total_files = len(changed_files)
        indexed_files = 0
        failed_files = 0
        delete_paths: list[str] = []
        upsert_chunks: list[dict] = []
        for index, file in enumerate(changed_files, start=1):
            path = str(file.get("path", ""))
            op = str(file.get("op", ""))
            try:
                if op == "delete":
                    delete_paths.append(path)
                    indexed_files += 1
                    await self._notify_progress(progress_cb, index, total_files, path)
                    continue
                if op == "upsert" and file.get("content") is not None:
                    chunks = self._build_chunks_for_file(file)
                    embedded_chunks = await asyncio.to_thread(self._embed_chunks, chunks)
                    upsert_chunks.extend(embedded_chunks)
                    indexed_files += 1
                    await self._notify_progress(progress_cb, index, total_files, path)
                    continue
                # Unknown op, or upsert without content.
                failed_files += 1
            except Exception:
                failed_files += 1
            await self._notify_progress(progress_cb, index, total_files, path)
        await asyncio.to_thread(
            self._repo.apply_changes,
            rag_session_id,
            delete_paths,
            upsert_chunks,
        )
        return indexed_files, failed_files

    async def retrieve(self, rag_session_id: str, query: str) -> list[dict]:
        """Return up to 5 relevant chunks as {source, content} dicts.

        Falls back to the most recent chunks when embedding or similarity
        search fails, so callers always get *some* context.
        """
        try:
            query_embedding = self._embedder.embed([query])[0]
            rows = self._repo.retrieve(rag_session_id, query_embedding, limit=5)
        except Exception:
            rows = self._repo.fallback_chunks(rag_session_id, limit=5)
        return [{"source": row["path"], "content": row["content"]} for row in rows]

    def _build_chunks_for_file(self, file: dict) -> list[tuple[str, int, str]]:
        """Split one file dict into (path, chunk_index, content) tuples."""
        path = str(file.get("path", ""))
        raw = file.get("content")
        # Treat missing/None content as empty: str(None) would otherwise
        # index the literal string "None" as the file's content.
        content = "" if raw is None else str(raw)
        output: list[tuple[str, int, str]] = []
        for idx, chunk in enumerate(self._chunker.chunk(content)):
            output.append((path, idx, chunk))
        return output

    def _embed_chunks(self, raw_chunks: list[tuple[str, int, str]]) -> list[dict]:
        """Embed chunks in batches (size from RAG_EMBED_BATCH_SIZE, default
        16) and return storage-ready dicts. Runs synchronously; call from a
        worker thread."""
        if not raw_chunks:
            return []
        batch_size = max(1, int(os.getenv("RAG_EMBED_BATCH_SIZE", "16")))
        indexed: list[dict] = []
        for i in range(0, len(raw_chunks), batch_size):
            batch = raw_chunks[i : i + batch_size]
            texts = [x[2] for x in batch]
            vectors = self._embedder.embed(texts)
            for (path, chunk_index, content), vector in zip(batch, vectors):
                indexed.append(
                    {
                        "path": path,
                        "chunk_index": chunk_index,
                        "content": content,
                        "embedding": vector,
                    }
                )
        return indexed

    async def _notify_progress(
        self,
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None,
        current_file_index: int,
        total_files: int,
        current_file_name: str,
    ) -> None:
        """Invoke the optional progress callback, awaiting it if needed."""
        if not progress_cb:
            return
        result = progress_cb(current_file_index, total_files, current_file_name)
        if isawaitable(result):
            await result

View File

@@ -0,0 +1,34 @@
from dataclasses import dataclass
from uuid import uuid4
from app.modules.rag.repository import RagRepository
@dataclass
class RagSession:
    """Binding between a RAG session id and the project it indexes."""

    rag_session_id: str  # unique session identifier (UUID or project id)
    project_id: str  # project the session belongs to
class RagSessionStore:
    """Thin persistence facade over RagRepository for RAG sessions."""

    def __init__(self, repository: RagRepository) -> None:
        self._repo = repository

    def create(self, project_id: str) -> RagSession:
        """Create and persist a session under a freshly generated UUID."""
        new_id = str(uuid4())
        self._repo.upsert_session(new_id, project_id)
        return RagSession(rag_session_id=new_id, project_id=project_id)

    def put(self, rag_session_id: str, project_id: str) -> RagSession:
        """Persist (insert or update) a session under a caller-chosen id."""
        self._repo.upsert_session(rag_session_id, project_id)
        return RagSession(rag_session_id=rag_session_id, project_id=project_id)

    def get(self, rag_session_id: str) -> RagSession | None:
        """Load a session by id; None when it does not exist."""
        row = self._repo.get_session(rag_session_id)
        if not row:
            return None
        return RagSession(
            rag_session_id=str(row["rag_session_id"]),
            project_id=str(row["project_id"]),
        )