Initial commit
@@ -0,0 +1,218 @@
# rag_session module

## 1. Module functions

- Creates and maintains a session-scoped RAG index over the files uploaded by the user.
- Indexes a full snapshot and incremental changes.
- Stores chunks, retrieves context, and tracks the status of index jobs.
- Publishes indexing progress over SSE.

## 2. Class and relationship diagram

```mermaid
classDiagram
    class RagModule
    class RagService
    class RagRepository
    class RagSessionStore
    class IndexJobStore
    class IndexingOrchestrator
    class TextChunker
    class GigaChatEmbedder
    class EventBus

    RagModule --> RagService
    RagModule --> RagRepository
    RagModule --> RagSessionStore
    RagModule --> IndexJobStore
    RagModule --> IndexingOrchestrator
    RagService --> RagRepository
    RagService --> TextChunker
    RagService --> GigaChatEmbedder
    IndexingOrchestrator --> IndexJobStore
    IndexingOrchestrator --> RagService
    IndexingOrchestrator --> EventBus
```

## 3. Class descriptions

- `RagModule`: composition root for the session RAG and its API.
  Methods: `__init__` assembles the indexing/retrieval services; `public_router` exposes the external endpoints; `internal_router` exposes the internal endpoints.
- `RagService`: domain service for indexing and retrieval.
  Methods: `index_snapshot` indexes the full set of files; `index_changes` indexes only the changes; `retrieve` returns the chunks relevant to a query.
- `RagRepository`: database access layer for sessions, jobs, and chunks.
  Methods: `ensure_tables` creates/updates the schema; `upsert_session/get_session/session_exists` handle sessions; `create_job/update_job/get_job` handle index jobs; `replace_chunks/apply_changes/retrieve/fallback_chunks` handle chunk data.
- `RagSessionStore`: manages the lifecycle of a `rag_session`.
  Methods: `create` creates a new session; `put` upserts a session with an externally supplied id; `get` reads a session.
- `IndexJobStore`: manages `index_job` records at the application level.
  Methods: `create` creates an index job; `get` reads a job; `save` updates its status/error.
- `IndexingOrchestrator`: asynchronous orchestrator of index jobs.
  Methods: `enqueue_snapshot` queues a full indexing run; `enqueue_changes` queues an incremental indexing run.
- `TextChunker`: splits a file's text into chunks for embedding.
  Methods: `chunk` returns the list of chunks for the given text.
- `GigaChatEmbedder`: adapter for the embeddings model.
  Methods: `embed` returns vectors for a batch of texts.
- `EventBus`: delivers indexing progress events.
  Methods: `publish` sends an event; `subscribe/unsubscribe` manage SSE subscriptions.

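The `chunk` contract described for `TextChunker` can be illustrated with a minimal sketch. The real chunking strategy is not shown in this commit, so the class name `SimpleChunker`, the character-based window, and the `size`/`overlap` parameters are all assumptions made for illustration.

```python
class SimpleChunker:
    """Hypothetical stand-in for TextChunker: fixed-size character windows with overlap."""

    def __init__(self, size: int = 200, overlap: int = 50) -> None:
        if size <= 0 or not 0 <= overlap < size:
            raise ValueError("size must be positive and 0 <= overlap < size")
        self._size = size
        self._step = size - overlap

    def chunk(self, text: str) -> list[str]:
        # Slide a window of `size` characters, advancing by `size - overlap`.
        chunks: list[str] = []
        for start in range(0, len(text), self._step):
            chunks.append(text[start:start + self._size])
            # Stop once the current window already reaches the end of the text.
            if start + self._size >= len(text):
                break
        return chunks
```

Overlapping windows are one common way to keep a sentence that straddles a boundary retrievable from either side; whether the module does this is not stated in the source.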
## 4. API sequence diagrams

### POST /api/rag/sessions
Purpose: creates a new `rag_session` and starts background indexing of the full file set.
```mermaid
sequenceDiagram
    participant Router as RagModule.APIRouter
    participant Sessions as RagSessionStore
    participant Indexing as IndexingOrchestrator

    Router->>Sessions: create(project_id)
    Sessions-->>Router: rag_session_id
    Router->>Indexing: enqueue_snapshot(rag_session_id, files)
    Indexing-->>Router: index_job_id,status
```

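The "create a job, schedule the work, return immediately" step in this flow can be sketched with plain `asyncio`. This is an in-memory illustration, not the module's classes: `Job`, `MiniOrchestrator`, `drain`, and the `path`/`content` keys of the file dict are assumptions.

```python
import asyncio
from dataclasses import dataclass
from uuid import uuid4


@dataclass
class Job:
    index_job_id: str
    rag_session_id: str
    status: str = "queued"


class MiniOrchestrator:
    """In-memory stand-in for the enqueue step (hypothetical, for illustration)."""

    def __init__(self) -> None:
        self.jobs: dict[str, Job] = {}
        self._tasks: set[asyncio.Task] = set()

    async def enqueue_snapshot(self, rag_session_id: str, files: list[dict]) -> Job:
        job = Job(index_job_id=str(uuid4()), rag_session_id=rag_session_id)
        self.jobs[job.index_job_id] = job
        # Schedule the work and return at once; hold a reference so the task
        # cannot be garbage-collected before it finishes.
        task = asyncio.create_task(self._process(job, files))
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return job

    async def _process(self, job: Job, files: list[dict]) -> None:
        job.status = "running"
        await asyncio.sleep(0)  # placeholder for chunking and embedding
        job.status = "done"

    async def drain(self) -> None:
        # Test helper: wait for every scheduled job to finish.
        await asyncio.gather(*self._tasks)


async def demo() -> tuple[str, str]:
    orch = MiniOrchestrator()
    job = await orch.enqueue_snapshot("session-1", [{"path": "a.md", "content": "hi"}])
    status_at_return = job.status  # the caller sees the job before it starts
    await orch.drain()
    return status_at_return, job.status
```

The caller receives the job while it is still queued and follows its progress through the job and events endpoints.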
### POST /api/rag/sessions/{rag_session_id}/changes
Purpose: queues incremental re-indexing of changes for an existing `rag_session`.
```mermaid
sequenceDiagram
    participant Router as RagModule.APIRouter
    participant Sessions as RagSessionStore
    participant Indexing as IndexingOrchestrator

    Router->>Sessions: get(rag_session_id)
    Sessions-->>Router: session
    Router->>Indexing: enqueue_changes(rag_session_id, changed_files)
    Indexing-->>Router: index_job_id,status
```

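Several jobs can be queued for the same session, and the orchestrator in this commit serializes them with one `asyncio.Lock` per `rag_session_id`. The sketch below shows that pattern in isolation; `run_job`, the job names, and the sleep standing in for indexing work are illustrative assumptions.

```python
import asyncio
from collections import defaultdict


async def demo() -> list[str]:
    # One lock per session id, created lazily on first use.
    locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
    log: list[str] = []

    async def run_job(session_id: str, name: str) -> None:
        async with locks[session_id]:
            log.append(f"{name}:start")
            await asyncio.sleep(0.01)  # placeholder for indexing work
            log.append(f"{name}:end")

    await asyncio.gather(
        run_job("s1", "job-a"),
        run_job("s1", "job-b"),  # same session: waits for job-a
        run_job("s2", "job-c"),  # different session: runs concurrently
    )
    return log
```

Jobs for one session never interleave, while jobs for different sessions still overlap.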
### GET /api/rag/sessions/{rag_session_id}/jobs/{index_job_id}
Purpose: returns the state and statistics of a specific index job.
```mermaid
sequenceDiagram
    participant Router as RagModule.APIRouter
    participant Jobs as IndexJobStore

    Router->>Jobs: get(index_job_id)
    Jobs-->>Router: job_state
```

### GET /api/rag/sessions/{rag_session_id}/jobs/{index_job_id}/events
Purpose: provides an SSE stream of progress events for an index job.
```mermaid
sequenceDiagram
    participant Router as RagModule.APIRouter
    participant Jobs as IndexJobStore
    participant Events as EventBus

    Router->>Jobs: get(index_job_id)
    Router->>Events: subscribe(index_job_id, replay=True)
    loop until terminal
        Events-->>Router: index event
    end
    Router->>Events: unsubscribe(index_job_id)
```

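The "loop until terminal" part of this flow, including the keepalive comment sent when no event arrives in time, can be sketched as an async generator. The `Event` dataclass and the `as_sse` helper are assumptions standing in for the real `EventBus` types, which this commit does not show; the framing follows the standard `event:` / `data:` SSE fields.

```python
import asyncio
import json
from dataclasses import dataclass


@dataclass
class Event:
    name: str
    payload: dict


def as_sse(event: Event) -> str:
    # Standard SSE framing: named event, JSON data, blank-line terminator.
    return f"event: {event.name}\ndata: {json.dumps(event.payload)}\n\n"


async def event_stream(queue: asyncio.Queue, heartbeat: float = 10.0):
    while True:
        try:
            event = await asyncio.wait_for(queue.get(), timeout=heartbeat)
        except asyncio.TimeoutError:
            # Nothing arrived within the heartbeat window: emit an SSE comment
            # so proxies and clients keep the connection open.
            yield ": keepalive\n\n"
            continue
        yield as_sse(event)
        if event.name == "terminal":
            break


async def demo() -> list[str]:
    queue: asyncio.Queue = asyncio.Queue()
    await queue.put(Event("index_progress", {"processed_files": 1}))
    await queue.put(Event("terminal", {"status": "done"}))
    return [frame async for frame in event_stream(queue, heartbeat=0.05)]
```

Lines starting with `:` are comments in the SSE format, so clients ignore the keepalive frames.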
### POST /api/index/snapshot (legacy)
Purpose: legacy entry point for full project indexing; a session is created automatically from `project_id`.
```mermaid
sequenceDiagram
    participant Router as LegacyAPIRouter
    participant Sessions as RagSessionStore
    participant Indexing as IndexingOrchestrator

    Router->>Sessions: put(project_id, project_id)
    Router->>Indexing: enqueue_snapshot(project_id, files)
    Indexing-->>Router: index_job_id,status
```

### POST /api/index/changes (legacy)
Purpose: legacy entry point for incremental indexing of changes by `project_id`.
```mermaid
sequenceDiagram
    participant Router as LegacyAPIRouter
    participant Sessions as RagSessionStore
    participant Indexing as IndexingOrchestrator

    Router->>Sessions: get(project_id)
    alt missing
        Router->>Sessions: put(project_id, project_id)
    end
    Router->>Indexing: enqueue_changes(project_id, changed_files)
    Indexing-->>Router: index_job_id,status
```

### GET /api/index/jobs/{index_job_id} (legacy)
Purpose: legacy read of an index job's status by `index_job_id`.
```mermaid
sequenceDiagram
    participant Router as LegacyAPIRouter
    participant Jobs as IndexJobStore

    Router->>Jobs: get(index_job_id)
    Jobs-->>Router: job_state
```

### GET /api/index/jobs/{index_job_id}/events (legacy)
Purpose: legacy SSE stream of events for an index job.
```mermaid
sequenceDiagram
    participant Router as LegacyAPIRouter
    participant Jobs as IndexJobStore
    participant Events as EventBus

    Router->>Jobs: get(index_job_id)
    Router->>Events: subscribe(index_job_id, replay=True)
    loop until terminal
        Events-->>Router: index event
    end
    Router->>Events: unsubscribe(index_job_id)
```

### POST /internal/rag/index/snapshot
Purpose: internal synchronous full indexing for service-to-service scenarios.
```mermaid
sequenceDiagram
    participant Router as InternalRagRouter
    participant Sessions as RagSessionStore
    participant RagService as RagService

    Router->>Sessions: get(project_id)
    alt missing
        Router->>Sessions: put(project_id, project_id)
    end
    Router->>RagService: index_snapshot(project_id, files)
    RagService-->>Router: indexed_files,failed_files
```

### POST /internal/rag/index/changes
Purpose: internal synchronous indexing of changes.
```mermaid
sequenceDiagram
    participant Router as InternalRagRouter
    participant RagService as RagService

    Router->>RagService: index_changes(project_id, changed_files)
    RagService-->>Router: indexed_files,failed_files
```

### GET /internal/rag/index/jobs/{index_job_id}
Purpose: internal lookup of an index job's status and error for orchestration services.
```mermaid
sequenceDiagram
    participant Router as InternalRagRouter
    participant Jobs as IndexJobStore

    Router->>Jobs: get(index_job_id)
    Jobs-->>Router: job_state
```

### POST /internal/rag/retrieve
Purpose: internal retrieval of relevant chunks from a `rag_session` for a text query.
```mermaid
sequenceDiagram
    participant Router as InternalRagRouter
    participant RagService as RagService
    participant RagRepo as RagRepository

    Router->>RagService: retrieve(rag_session_id, query)
    RagService->>RagRepo: retrieve/fallback_chunks
    RagRepo-->>RagService: chunks
    RagService-->>Router: items
```

Binary file not shown.
@@ -0,0 +1,9 @@
from app.modules.shared.gigachat.client import GigaChatClient


class GigaChatEmbedder:
    def __init__(self, client: GigaChatClient) -> None:
        self._client = client

    def embed(self, texts: list[str]) -> list[list[float]]:
        return self._client.embed(texts)
@@ -0,0 +1,147 @@
import asyncio
from collections import defaultdict

from app.schemas.common import ErrorPayload, ModuleName
from app.schemas.indexing import IndexJobStatus
from app.modules.contracts import RagIndexer
from app.modules.rag_session.job_store import IndexJob, IndexJobStore
from app.modules.shared.event_bus import EventBus
from app.modules.shared.retry_executor import RetryExecutor


class IndexingOrchestrator:
    def __init__(
        self,
        store: IndexJobStore,
        rag: RagIndexer,
        events: EventBus,
        retry: RetryExecutor,
    ) -> None:
        self._store = store
        self._rag = rag
        self._events = events
        self._retry = retry
        # One lock per rag_session_id so jobs for the same session run one at a time.
        self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
        # Strong references to in-flight tasks; the event loop itself keeps only weak ones.
        self._tasks: set[asyncio.Task] = set()

    async def enqueue_snapshot(self, rag_session_id: str, files: list[dict]) -> IndexJob:
        job = self._store.create(rag_session_id)
        self._spawn(self._process_snapshot(job.index_job_id, rag_session_id, files))
        return job

    async def enqueue_changes(self, rag_session_id: str, changed_files: list[dict]) -> IndexJob:
        job = self._store.create(rag_session_id)
        self._spawn(self._process_changes(job.index_job_id, rag_session_id, changed_files))
        return job

    def _spawn(self, coro) -> None:
        # Keep the task referenced until it completes so it cannot be garbage-collected mid-run.
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    async def _process_snapshot(self, job_id: str, rag_session_id: str, files: list[dict]) -> None:
        await self._run_with_project_lock(
            job_id=job_id,
            rag_session_id=rag_session_id,
            total_files=len(files),
            operation=lambda progress_cb: self._rag.index_snapshot(
                rag_session_id=rag_session_id,
                files=files,
                progress_cb=progress_cb,
            ),
        )

    async def _process_changes(self, job_id: str, rag_session_id: str, changed_files: list[dict]) -> None:
        await self._run_with_project_lock(
            job_id=job_id,
            rag_session_id=rag_session_id,
            total_files=len(changed_files),
            operation=lambda progress_cb: self._rag.index_changes(
                rag_session_id=rag_session_id,
                changed_files=changed_files,
                progress_cb=progress_cb,
            ),
        )

    async def _run_with_project_lock(self, job_id: str, rag_session_id: str, total_files: int, operation) -> None:
        lock = self._locks[rag_session_id]
        async with lock:
            job = self._store.get(job_id)
            if not job:
                return
            job.status = IndexJobStatus.RUNNING
            self._store.save(job)
            await self._events.publish(
                job_id,
                "index_status",
                {"index_job_id": job_id, "status": job.status.value, "total_files": total_files},
            )
            try:
                async def progress_cb(current_file_index: int, total: int, current_file_name: str) -> None:
                    await self._events.publish(
                        job_id,
                        "index_progress",
                        {
                            "index_job_id": job_id,
                            "current_file_index": current_file_index,
                            "total_files": total,
                            "processed_files": current_file_index,
                            "current_file_path": current_file_name,
                            "current_file_name": current_file_name,
                        },
                    )

                indexed, failed, cache_hits, cache_misses = await self._retry.run(lambda: operation(progress_cb))
                job.status = IndexJobStatus.DONE
                job.indexed_files = indexed
                job.failed_files = failed
                job.cache_hit_files = cache_hits
                job.cache_miss_files = cache_misses
                self._store.save(job)
                await self._events.publish(
                    job_id,
                    "index_status",
                    {
                        "index_job_id": job_id,
                        "status": job.status.value,
                        "indexed_files": indexed,
                        "failed_files": failed,
                        "cache_hit_files": cache_hits,
                        "cache_miss_files": cache_misses,
                        "total_files": total_files,
                    },
                )
                await self._events.publish(
                    job_id,
                    "terminal",
                    {
                        "index_job_id": job_id,
                        "status": "done",
                        "indexed_files": indexed,
                        "failed_files": failed,
                        "cache_hit_files": cache_hits,
                        "cache_miss_files": cache_misses,
                        "total_files": total_files,
                    },
                )
            # Only transient I/O errors are handled here; any other exception
            # propagates out of the task and the stored job stays in RUNNING.
            except (TimeoutError, ConnectionError, OSError) as exc:
                job.status = IndexJobStatus.ERROR
                job.error = ErrorPayload(
                    code="index_retry_exhausted",
                    desc=f"Temporary indexing failure after retries: {exc}",
                    module=ModuleName.RAG,
                )
                self._store.save(job)
                await self._events.publish(
                    job_id,
                    "index_status",
                    {"index_job_id": job_id, "status": job.status.value, "total_files": total_files},
                )
                await self._events.publish(
                    job_id,
                    "terminal",
                    {
                        "index_job_id": job_id,
                        "status": "error",
                        "total_files": total_files,
                        "error": {
                            "code": job.error.code,
                            "desc": job.error.desc,
                            "module": job.error.module.value,
                        },
                    },
                )
@@ -0,0 +1,72 @@
from dataclasses import dataclass
from uuid import uuid4

from app.modules.rag_session.repository import RagRepository
from app.schemas.common import ErrorPayload, ModuleName
from app.schemas.indexing import IndexJobStatus


@dataclass
class IndexJob:
    index_job_id: str
    rag_session_id: str
    status: IndexJobStatus = IndexJobStatus.QUEUED
    indexed_files: int = 0
    failed_files: int = 0
    cache_hit_files: int = 0
    cache_miss_files: int = 0
    error: ErrorPayload | None = None


class IndexJobStore:
    def __init__(self, repository: RagRepository) -> None:
        self._repo = repository

    def create(self, rag_session_id: str) -> IndexJob:
        job = IndexJob(index_job_id=str(uuid4()), rag_session_id=rag_session_id)
        self._repo.create_job(job.index_job_id, rag_session_id, job.status.value)
        return job

    def get(self, index_job_id: str) -> IndexJob | None:
        row = self._repo.get_job(index_job_id)
        if not row:
            return None
        payload = None
        if row.error_code:
            module = ModuleName.RAG
            if row.error_module:
                try:
                    module = ModuleName(row.error_module)
                except ValueError:
                    module = ModuleName.RAG
            payload = ErrorPayload(
                code=row.error_code,
                desc=row.error_desc or "",
                module=module,
            )
        return IndexJob(
            index_job_id=row.index_job_id,
            rag_session_id=row.rag_session_id,
            status=IndexJobStatus(row.status),
            indexed_files=row.indexed_files,
            failed_files=row.failed_files,
            cache_hit_files=row.cache_hit_files,
            cache_miss_files=row.cache_miss_files,
            error=payload,
        )

    def save(self, job: IndexJob) -> None:
        error_code = job.error.code if job.error else None
        error_desc = job.error.desc if job.error else None
        error_module = job.error.module.value if job.error else None
        self._repo.update_job(
            job.index_job_id,
            status=job.status.value,
            indexed_files=job.indexed_files,
            failed_files=job.failed_files,
            cache_hit_files=job.cache_hit_files,
            cache_miss_files=job.cache_miss_files,
            error_code=error_code,
            error_desc=error_desc,
            error_module=error_module,
        )
@@ -0,0 +1,263 @@
from fastapi import APIRouter
from fastapi.responses import StreamingResponse

from app.core.exceptions import AppError
from app.modules.rag_session.embedding.gigachat_embedder import GigaChatEmbedder
from app.modules.rag_session.indexing_service import IndexingOrchestrator
from app.modules.rag_session.job_store import IndexJobStore
from app.modules.rag_session.repository import RagRepository
from app.modules.rag_session.retrieval.chunker import TextChunker
from app.modules.rag_session.session_store import RagSessionStore
from app.modules.rag_session.service import RagService
from app.modules.shared.event_bus import EventBus
from app.modules.shared.gigachat.client import GigaChatClient
from app.modules.shared.gigachat.settings import GigaChatSettings
from app.modules.shared.gigachat.token_provider import GigaChatTokenProvider
from app.modules.shared.retry_executor import RetryExecutor
from app.schemas.common import ModuleName
from app.schemas.indexing import (
    IndexChangesRequest,
    IndexJobQueuedResponse,
    IndexJobResponse,
    IndexSnapshotRequest,
)
from app.schemas.rag_sessions import (
    RagSessionChangesRequest,
    RagSessionCreateRequest,
    RagSessionCreateResponse,
    RagSessionJobResponse,
)


class RagModule:
    def __init__(self, event_bus: EventBus, retry: RetryExecutor, repository: RagRepository) -> None:
        self._events = event_bus
        self.repository = repository
        settings = GigaChatSettings.from_env()
        token_provider = GigaChatTokenProvider(settings)
        client = GigaChatClient(settings, token_provider)
        embedder = GigaChatEmbedder(client)
        self.rag = RagService(embedder=embedder, repository=repository, chunker=TextChunker())
        self.sessions = RagSessionStore(repository)
        self.jobs = IndexJobStore(repository)
        self.indexing = IndexingOrchestrator(
            store=self.jobs,
            rag=self.rag,
            events=event_bus,
            retry=retry,
        )

    def public_router(self) -> APIRouter:
        router = APIRouter(tags=["rag"])

        @router.post("/api/rag/sessions", response_model=RagSessionCreateResponse)
        async def create_rag_session(request: RagSessionCreateRequest) -> RagSessionCreateResponse:
            session = self.sessions.create(request.project_id)
            job = await self.indexing.enqueue_snapshot(
                rag_session_id=session.rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return RagSessionCreateResponse(
                rag_session_id=session.rag_session_id,
                index_job_id=job.index_job_id,
                status=job.status,
            )

        @router.post("/api/rag/sessions/{rag_session_id}/changes", response_model=IndexJobQueuedResponse)
        async def rag_session_changes(
            rag_session_id: str,
            request: RagSessionChangesRequest,
        ) -> IndexJobQueuedResponse:
            session = self.sessions.get(rag_session_id)
            if not session:
                raise AppError("not_found", f"RAG session not found: {rag_session_id}", ModuleName.RAG)
            job = await self.indexing.enqueue_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @router.get("/api/rag/sessions/{rag_session_id}/jobs/{index_job_id}", response_model=RagSessionJobResponse)
        async def rag_session_job(rag_session_id: str, index_job_id: str) -> RagSessionJobResponse:
            job = self.jobs.get(index_job_id)
            if not job or job.rag_session_id != rag_session_id:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            return RagSessionJobResponse(
                rag_session_id=rag_session_id,
                index_job_id=job.index_job_id,
                status=job.status,
                indexed_files=job.indexed_files,
                failed_files=job.failed_files,
                cache_hit_files=job.cache_hit_files,
                cache_miss_files=job.cache_miss_files,
                error=job.error.model_dump(mode="json") if job.error else None,
            )

        @router.get("/api/rag/sessions/{rag_session_id}/jobs/{index_job_id}/events")
        async def rag_session_job_events(rag_session_id: str, index_job_id: str) -> StreamingResponse:
            job = self.jobs.get(index_job_id)
            if not job or job.rag_session_id != rag_session_id:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            queue = await self._events.subscribe(index_job_id, replay=True)

            async def event_stream():
                import asyncio

                heartbeat = 10
                try:
                    while True:
                        try:
                            event = await asyncio.wait_for(queue.get(), timeout=heartbeat)
                            yield EventBus.as_sse(event)
                            if event.name == "terminal":
                                break
                        except asyncio.TimeoutError:
                            yield ": keepalive\n\n"
                finally:
                    await self._events.unsubscribe(index_job_id, queue)

            return StreamingResponse(
                event_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache, no-transform",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                },
            )

        # Legacy compatibility endpoints.
        legacy = APIRouter(prefix="/api/index", tags=["index"])

        @legacy.post("/snapshot", response_model=IndexJobQueuedResponse)
        async def index_snapshot(request: IndexSnapshotRequest) -> IndexJobQueuedResponse:
            session = self.sessions.put(
                rag_session_id=request.project_id,
                project_id=request.project_id,
            )
            job = await self.indexing.enqueue_snapshot(
                rag_session_id=session.rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @legacy.post("/changes", response_model=IndexJobQueuedResponse)
        async def index_changes(request: IndexChangesRequest) -> IndexJobQueuedResponse:
            rag_session_id = request.project_id
            if not self.sessions.get(rag_session_id):
                self.sessions.put(rag_session_id=rag_session_id, project_id=rag_session_id)
            job = await self.indexing.enqueue_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @legacy.get("/jobs/{index_job_id}", response_model=IndexJobResponse)
        async def get_index_job(index_job_id: str) -> IndexJobResponse:
            job = self.jobs.get(index_job_id)
            if not job:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            return IndexJobResponse(
                index_job_id=job.index_job_id,
                status=job.status,
                indexed_files=job.indexed_files,
                failed_files=job.failed_files,
                cache_hit_files=job.cache_hit_files,
                cache_miss_files=job.cache_miss_files,
                error=job.error,
            )

        @legacy.get("/jobs/{index_job_id}/events")
        async def get_index_job_events(index_job_id: str) -> StreamingResponse:
            job = self.jobs.get(index_job_id)
            if not job:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            queue = await self._events.subscribe(index_job_id, replay=True)

            async def event_stream():
                import asyncio

                heartbeat = 10
                try:
                    while True:
                        try:
                            event = await asyncio.wait_for(queue.get(), timeout=heartbeat)
                            yield EventBus.as_sse(event)
                            if event.name == "terminal":
                                break
                        except asyncio.TimeoutError:
                            yield ": keepalive\n\n"
                finally:
                    await self._events.unsubscribe(index_job_id, queue)

            return StreamingResponse(
                event_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache, no-transform",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                },
            )

        router.include_router(legacy)
        return router

    def internal_router(self) -> APIRouter:
        router = APIRouter(prefix="/internal/rag", tags=["internal-rag"])

        @router.post("/index/snapshot")
        async def index_snapshot(request: IndexSnapshotRequest) -> dict:
            rag_session_id = request.project_id
            if not self.sessions.get(rag_session_id):
                self.sessions.put(rag_session_id=rag_session_id, project_id=rag_session_id)
            indexed, failed, cache_hits, cache_misses = await self.rag.index_snapshot(
                rag_session_id=rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return {
                "indexed_files": indexed,
                "failed_files": failed,
                "cache_hit_files": cache_hits,
                "cache_miss_files": cache_misses,
            }

        @router.post("/index/changes")
        async def index_changes(request: IndexChangesRequest) -> dict:
            rag_session_id = request.project_id
            indexed, failed, cache_hits, cache_misses = await self.rag.index_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return {
                "indexed_files": indexed,
                "failed_files": failed,
                "cache_hit_files": cache_hits,
                "cache_miss_files": cache_misses,
            }

        @router.get("/index/jobs/{index_job_id}")
        async def get_job(index_job_id: str) -> dict:
            job = self.jobs.get(index_job_id)
            if not job:
                return {"status": "not_found"}
            return {
                "index_job_id": job.index_job_id,
                "status": job.status.value,
                "indexed_files": job.indexed_files,
                "failed_files": job.failed_files,
                "cache_hit_files": job.cache_hit_files,
                "cache_miss_files": job.cache_miss_files,
                "error": job.error.model_dump(mode="json") if job.error else None,
            }

        @router.post("/retrieve")
        async def retrieve(payload: dict) -> dict:
            rag_session_id = payload.get("rag_session_id") or payload.get("project_id", "")
            ctx = await self.rag.retrieve(
                rag_session_id=rag_session_id,
                query=payload.get("query", ""),
            )
            return {"items": ctx}

        return router
@@ -0,0 +1,660 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import hashlib
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.modules.shared.db import get_engine
|
||||
|
||||
|
||||
@dataclass
|
||||
class RagJobRow:
|
||||
index_job_id: str
|
||||
rag_session_id: str
|
||||
status: str
|
||||
indexed_files: int
|
||||
failed_files: int
|
||||
cache_hit_files: int
|
||||
cache_miss_files: int
|
||||
error_code: str | None
|
||||
error_desc: str | None
|
||||
error_module: str | None
|
||||
|
||||
|
||||
class RagRepository:
|
||||
def ensure_tables(self) -> None:
|
||||
engine = get_engine()
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS rag_sessions (
|
||||
rag_session_id VARCHAR(64) PRIMARY KEY,
|
||||
project_id VARCHAR(512) NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS rag_index_jobs (
|
||||
index_job_id VARCHAR(64) PRIMARY KEY,
|
||||
rag_session_id VARCHAR(64) NOT NULL,
|
||||
status VARCHAR(16) NOT NULL,
|
||||
indexed_files INTEGER NOT NULL DEFAULT 0,
|
||||
failed_files INTEGER NOT NULL DEFAULT 0,
|
||||
cache_hit_files INTEGER NOT NULL DEFAULT 0,
|
||||
cache_miss_files INTEGER NOT NULL DEFAULT 0,
|
||||
error_code VARCHAR(128) NULL,
|
||||
error_desc TEXT NULL,
|
||||
error_module VARCHAR(64) NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS rag_chunks (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
rag_session_id VARCHAR(64) NOT NULL,
|
||||
path TEXT NOT NULL,
|
||||
chunk_index INTEGER NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
embedding vector NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
conn.execute(text("ALTER TABLE rag_chunks ADD COLUMN IF NOT EXISTS artifact_type VARCHAR(16) NULL"))
|
||||
            conn.execute(text("ALTER TABLE rag_chunks ADD COLUMN IF NOT EXISTS section TEXT NULL"))
            conn.execute(text("ALTER TABLE rag_chunks ADD COLUMN IF NOT EXISTS doc_id TEXT NULL"))
            conn.execute(text("ALTER TABLE rag_chunks ADD COLUMN IF NOT EXISTS doc_version TEXT NULL"))
            conn.execute(text("ALTER TABLE rag_chunks ADD COLUMN IF NOT EXISTS owner TEXT NULL"))
            conn.execute(text("ALTER TABLE rag_chunks ADD COLUMN IF NOT EXISTS system_component TEXT NULL"))
            conn.execute(text("ALTER TABLE rag_chunks ADD COLUMN IF NOT EXISTS last_modified TIMESTAMPTZ NULL"))
            conn.execute(text("ALTER TABLE rag_chunks ADD COLUMN IF NOT EXISTS staleness_score DOUBLE PRECISION NULL"))
            conn.execute(
                text(
                    """
                    ALTER TABLE rag_chunks
                    ADD COLUMN IF NOT EXISTS created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
                    """
                )
            )
            conn.execute(
                text(
                    """
                    ALTER TABLE rag_chunks
                    ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
                    """
                )
            )
            conn.execute(
                text(
                    """
                    CREATE TABLE IF NOT EXISTS rag_blob_cache (
                        id BIGSERIAL PRIMARY KEY,
                        repo_id VARCHAR(512) NOT NULL,
                        blob_sha VARCHAR(128) NOT NULL,
                        path TEXT NOT NULL,
                        artifact_type VARCHAR(16) NULL,
                        section TEXT NULL,
                        doc_id TEXT NULL,
                        doc_version TEXT NULL,
                        owner TEXT NULL,
                        system_component TEXT NULL,
                        last_modified TIMESTAMPTZ NULL,
                        staleness_score DOUBLE PRECISION NULL,
                        created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
                        CONSTRAINT uq_rag_blob_cache UNIQUE (repo_id, blob_sha, path)
                    )
                    """
                )
            )
            conn.execute(
                text(
                    """
                    CREATE TABLE IF NOT EXISTS rag_chunk_cache (
                        id BIGSERIAL PRIMARY KEY,
                        repo_id VARCHAR(512) NOT NULL,
                        blob_sha VARCHAR(128) NOT NULL,
                        chunk_index INTEGER NOT NULL,
                        content TEXT NOT NULL,
                        embedding vector NULL,
                        section TEXT NULL,
                        created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
                        CONSTRAINT uq_rag_chunk_cache UNIQUE (repo_id, blob_sha, chunk_index)
                    )
                    """
                )
            )
            conn.execute(
                text(
                    """
                    CREATE TABLE IF NOT EXISTS rag_session_chunk_map (
                        id BIGSERIAL PRIMARY KEY,
                        rag_session_id VARCHAR(64) NOT NULL,
                        repo_id VARCHAR(512) NOT NULL,
                        blob_sha VARCHAR(128) NOT NULL,
                        chunk_index INTEGER NOT NULL,
                        path TEXT NOT NULL,
                        created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
                    )
                    """
                )
            )
            conn.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_chunks_session ON rag_chunks (rag_session_id)"))
            conn.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_chunks_artifact_type ON rag_chunks (artifact_type)"))
            conn.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_chunks_doc ON rag_chunks (doc_id, doc_version)"))
            conn.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_chunks_component ON rag_chunks (system_component)"))
            conn.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_chunks_path ON rag_chunks (path)"))
            conn.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_blob_cache_repo_blob ON rag_blob_cache (repo_id, blob_sha)"))
            conn.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_chunk_cache_repo_blob ON rag_chunk_cache (repo_id, blob_sha, chunk_index)"))
            conn.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_session_chunk_map_session ON rag_session_chunk_map (rag_session_id, created_at DESC)"))
            conn.execute(text("ALTER TABLE rag_index_jobs ADD COLUMN IF NOT EXISTS cache_hit_files INTEGER NOT NULL DEFAULT 0"))
            conn.execute(text("ALTER TABLE rag_index_jobs ADD COLUMN IF NOT EXISTS cache_miss_files INTEGER NOT NULL DEFAULT 0"))
            conn.commit()

    def upsert_session(self, rag_session_id: str, project_id: str) -> None:
        with get_engine().connect() as conn:
            conn.execute(
                text(
                    """
                    INSERT INTO rag_sessions (rag_session_id, project_id)
                    VALUES (:sid, :pid)
                    ON CONFLICT (rag_session_id) DO UPDATE SET project_id = EXCLUDED.project_id
                    """
                ),
                {"sid": rag_session_id, "pid": project_id},
            )
            conn.commit()

    def session_exists(self, rag_session_id: str) -> bool:
        with get_engine().connect() as conn:
            row = conn.execute(
                text("SELECT 1 FROM rag_sessions WHERE rag_session_id = :sid"),
                {"sid": rag_session_id},
            ).fetchone()
            return bool(row)

    def get_session(self, rag_session_id: str) -> dict | None:
        with get_engine().connect() as conn:
            row = conn.execute(
                text("SELECT rag_session_id, project_id FROM rag_sessions WHERE rag_session_id = :sid"),
                {"sid": rag_session_id},
            ).mappings().fetchone()
            return dict(row) if row else None

    def create_job(self, index_job_id: str, rag_session_id: str, status: str) -> None:
        with get_engine().connect() as conn:
            conn.execute(
                text(
                    """
                    INSERT INTO rag_index_jobs (index_job_id, rag_session_id, status)
                    VALUES (:jid, :sid, :status)
                    """
                ),
                {"jid": index_job_id, "sid": rag_session_id, "status": status},
            )
            conn.commit()

    def update_job(
        self,
        index_job_id: str,
        *,
        status: str,
        indexed_files: int,
        failed_files: int,
        cache_hit_files: int = 0,
        cache_miss_files: int = 0,
        error_code: str | None = None,
        error_desc: str | None = None,
        error_module: str | None = None,
    ) -> None:
        with get_engine().connect() as conn:
            conn.execute(
                text(
                    """
                    UPDATE rag_index_jobs
                    SET status = :status,
                        indexed_files = :indexed,
                        failed_files = :failed,
                        cache_hit_files = :cache_hit_files,
                        cache_miss_files = :cache_miss_files,
                        error_code = :ecode,
                        error_desc = :edesc,
                        error_module = :emodule,
                        updated_at = CURRENT_TIMESTAMP
                    WHERE index_job_id = :jid
                    """
                ),
                {
                    "jid": index_job_id,
                    "status": status,
                    "indexed": indexed_files,
                    "failed": failed_files,
                    "cache_hit_files": cache_hit_files,
                    "cache_miss_files": cache_miss_files,
                    "ecode": error_code,
                    "edesc": error_desc,
                    "emodule": error_module,
                },
            )
            conn.commit()

    def get_job(self, index_job_id: str) -> RagJobRow | None:
        with get_engine().connect() as conn:
            row = conn.execute(
                text(
                    """
                    SELECT index_job_id, rag_session_id, status, indexed_files, failed_files,
                           cache_hit_files, cache_miss_files, error_code, error_desc, error_module
                    FROM rag_index_jobs
                    WHERE index_job_id = :jid
                    """
                ),
                {"jid": index_job_id},
            ).mappings().fetchone()
            if not row:
                return None
            return RagJobRow(**dict(row))

    def replace_chunks(self, rag_session_id: str, items: list[dict]) -> None:
        with get_engine().connect() as conn:
            conn.execute(text("DELETE FROM rag_chunks WHERE rag_session_id = :sid"), {"sid": rag_session_id})
            conn.execute(text("DELETE FROM rag_session_chunk_map WHERE rag_session_id = :sid"), {"sid": rag_session_id})
            self._insert_chunks(conn, rag_session_id, items)
            conn.commit()

    def apply_changes(self, rag_session_id: str, delete_paths: list[str], upserts: list[dict]) -> None:
        with get_engine().connect() as conn:
            if delete_paths:
                conn.execute(
                    text("DELETE FROM rag_chunks WHERE rag_session_id = :sid AND path = ANY(:paths)"),
                    {"sid": rag_session_id, "paths": delete_paths},
                )
                conn.execute(
                    text("DELETE FROM rag_session_chunk_map WHERE rag_session_id = :sid AND path = ANY(:paths)"),
                    {"sid": rag_session_id, "paths": delete_paths},
                )
            if upserts:
                paths = sorted({str(x["path"]) for x in upserts})
                conn.execute(
                    text("DELETE FROM rag_chunks WHERE rag_session_id = :sid AND path = ANY(:paths)"),
                    {"sid": rag_session_id, "paths": paths},
                )
                conn.execute(
                    text("DELETE FROM rag_session_chunk_map WHERE rag_session_id = :sid AND path = ANY(:paths)"),
                    {"sid": rag_session_id, "paths": paths},
                )
                self._insert_chunks(conn, rag_session_id, upserts)
            conn.commit()

    def get_cached_chunks(self, repo_id: str, blob_sha: str) -> list[dict]:
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(
                    """
                    SELECT chunk_index, content, embedding::text AS embedding_txt, section
                    FROM rag_chunk_cache
                    WHERE repo_id = :repo_id AND blob_sha = :blob_sha
                    ORDER BY chunk_index ASC
                    """
                ),
                {"repo_id": repo_id, "blob_sha": blob_sha},
            ).mappings().fetchall()
            output: list[dict] = []
            for row in rows:
                output.append(
                    {
                        "chunk_index": int(row["chunk_index"]),
                        "content": str(row["content"] or ""),
                        "embedding": self._parse_vector(str(row["embedding_txt"] or "")),
                        "section": row.get("section"),
                    }
                )
            return output

    def record_repo_cache(
        self,
        *,
        project_id: str,
        commit_sha: str | None,
        changed_files: list[str],
        summary: str,
    ) -> None:
        repo_session_id = f"repo:{project_id}"
        with get_engine().connect() as conn:
            for path in changed_files:
                key = f"{commit_sha or 'no-commit'}:{path}"
                blob_sha = hashlib.sha256(key.encode("utf-8")).hexdigest()
                conn.execute(
                    text(
                        """
                        INSERT INTO rag_blob_cache (
                            repo_id,
                            blob_sha,
                            path,
                            artifact_type,
                            section
                        )
                        VALUES (
                            :repo_id,
                            :blob_sha,
                            :path,
                            :artifact_type,
                            :section
                        )
                        ON CONFLICT (repo_id, blob_sha, path) DO UPDATE SET
                            updated_at = CURRENT_TIMESTAMP
                        """
                    ),
                    {
                        "repo_id": project_id,
                        "blob_sha": blob_sha,
                        "path": path,
                        "artifact_type": "CODE",
                        "section": "repo_webhook",
                    },
                )
                conn.execute(
                    text(
                        """
                        INSERT INTO rag_chunk_cache (
                            repo_id,
                            blob_sha,
                            chunk_index,
                            content,
                            embedding,
                            section
                        )
                        VALUES (
                            :repo_id,
                            :blob_sha,
                            0,
                            :content,
                            NULL,
                            :section
                        )
                        ON CONFLICT (repo_id, blob_sha, chunk_index) DO UPDATE SET
                            content = EXCLUDED.content,
                            section = EXCLUDED.section,
                            updated_at = CURRENT_TIMESTAMP
                        """
                    ),
                    {
                        "repo_id": project_id,
                        "blob_sha": blob_sha,
                        "content": f"repo_webhook:{path}:{summary[:300]}",
                        "section": "repo_webhook",
                    },
                )
                conn.execute(
                    text(
                        """
                        INSERT INTO rag_session_chunk_map (
                            rag_session_id,
                            repo_id,
                            blob_sha,
                            chunk_index,
                            path
                        )
                        VALUES (
                            :rag_session_id,
                            :repo_id,
                            :blob_sha,
                            0,
                            :path
                        )
                        """
                    ),
                    {
                        "rag_session_id": repo_session_id,
                        "repo_id": project_id,
                        "blob_sha": blob_sha,
                        "path": path,
                    },
                )
            conn.commit()

    def cache_file_chunks(self, repo_id: str, path: str, blob_sha: str, items: list[dict]) -> None:
        if not items:
            return
        meta = items[0]
        with get_engine().connect() as conn:
            conn.execute(
                text(
                    """
                    INSERT INTO rag_blob_cache (
                        repo_id,
                        blob_sha,
                        path,
                        artifact_type,
                        section,
                        doc_id,
                        doc_version,
                        owner,
                        system_component,
                        last_modified,
                        staleness_score
                    )
                    VALUES (
                        :repo_id,
                        :blob_sha,
                        :path,
                        :artifact_type,
                        :section,
                        :doc_id,
                        :doc_version,
                        :owner,
                        :system_component,
                        :last_modified,
                        :staleness_score
                    )
                    ON CONFLICT (repo_id, blob_sha, path) DO UPDATE SET
                        artifact_type = EXCLUDED.artifact_type,
                        section = EXCLUDED.section,
                        doc_id = EXCLUDED.doc_id,
                        doc_version = EXCLUDED.doc_version,
                        owner = EXCLUDED.owner,
                        system_component = EXCLUDED.system_component,
                        last_modified = EXCLUDED.last_modified,
                        staleness_score = EXCLUDED.staleness_score,
                        updated_at = CURRENT_TIMESTAMP
                    """
                ),
                {
                    "repo_id": repo_id,
                    "blob_sha": blob_sha,
                    "path": path,
                    "artifact_type": meta.get("artifact_type"),
                    "section": meta.get("section"),
                    "doc_id": meta.get("doc_id"),
                    "doc_version": meta.get("doc_version"),
                    "owner": meta.get("owner"),
                    "system_component": meta.get("system_component"),
                    "last_modified": meta.get("last_modified"),
                    "staleness_score": meta.get("staleness_score"),
                },
            )
            for item in items:
                emb = item.get("embedding") or []
                emb_str = "[" + ",".join(str(x) for x in emb) + "]" if emb else None
                conn.execute(
                    text(
                        """
                        INSERT INTO rag_chunk_cache (
                            repo_id,
                            blob_sha,
                            chunk_index,
                            content,
                            embedding,
                            section
                        )
                        VALUES (
                            :repo_id,
                            :blob_sha,
                            :chunk_index,
                            :content,
                            CAST(:embedding AS vector),
                            :section
                        )
                        ON CONFLICT (repo_id, blob_sha, chunk_index) DO UPDATE SET
                            content = EXCLUDED.content,
                            embedding = EXCLUDED.embedding,
                            section = EXCLUDED.section,
                            updated_at = CURRENT_TIMESTAMP
                        """
                    ),
                    {
                        "repo_id": repo_id,
                        "blob_sha": blob_sha,
                        "chunk_index": int(item["chunk_index"]),
                        "content": item["content"],
                        "embedding": emb_str,
                        "section": item.get("section"),
                    },
                )
            conn.commit()

    def retrieve(self, rag_session_id: str, query_embedding: list[float], limit: int = 5) -> list[dict]:
        emb = "[" + ",".join(str(x) for x in query_embedding) + "]"
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(
                    """
                    SELECT path, content
                    FROM rag_chunks
                    WHERE rag_session_id = :sid
                    ORDER BY embedding <=> CAST(:emb AS vector)
                    LIMIT :lim
                    """
                ),
                {"sid": rag_session_id, "emb": emb, "lim": limit},
            ).mappings().fetchall()
            return [dict(x) for x in rows]

    def fallback_chunks(self, rag_session_id: str, limit: int = 5) -> list[dict]:
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(
                    """
                    SELECT path, content
                    FROM rag_chunks
                    WHERE rag_session_id = :sid
                    ORDER BY id DESC
                    LIMIT :lim
                    """
                ),
                {"sid": rag_session_id, "lim": limit},
            ).mappings().fetchall()
            return [dict(x) for x in rows]

    def _insert_chunks(self, conn, rag_session_id: str, items: list[dict]) -> None:
        for item in items:
            emb = item.get("embedding") or []
            emb_str = "[" + ",".join(str(x) for x in emb) + "]" if emb else None
            conn.execute(
                text(
                    """
                    INSERT INTO rag_chunks (
                        rag_session_id,
                        path,
                        chunk_index,
                        content,
                        embedding,
                        artifact_type,
                        section,
                        doc_id,
                        doc_version,
                        owner,
                        system_component,
                        last_modified,
                        staleness_score,
                        created_at,
                        updated_at
                    )
                    VALUES (
                        :sid,
                        :path,
                        :idx,
                        :content,
                        CAST(:emb AS vector),
                        :artifact_type,
                        :section,
                        :doc_id,
                        :doc_version,
                        :owner,
                        :system_component,
                        :last_modified,
                        :staleness_score,
                        CURRENT_TIMESTAMP,
                        CURRENT_TIMESTAMP
                    )
                    """
                ),
                {
                    "sid": rag_session_id,
                    "path": item["path"],
                    "idx": int(item["chunk_index"]),
                    "content": item["content"],
                    "emb": emb_str,
                    "artifact_type": item.get("artifact_type"),
                    "section": item.get("section"),
                    "doc_id": item.get("doc_id"),
                    "doc_version": item.get("doc_version"),
                    "owner": item.get("owner"),
                    "system_component": item.get("system_component"),
                    "last_modified": item.get("last_modified"),
                    "staleness_score": item.get("staleness_score"),
                },
            )
            repo_id = str(item.get("repo_id") or "").strip()
            blob_sha = str(item.get("blob_sha") or "").strip()
            if repo_id and blob_sha:
                conn.execute(
                    text(
                        """
                        INSERT INTO rag_session_chunk_map (
                            rag_session_id,
                            repo_id,
                            blob_sha,
                            chunk_index,
                            path
                        ) VALUES (
                            :sid,
                            :repo_id,
                            :blob_sha,
                            :chunk_index,
                            :path
                        )
                        """
                    ),
                    {
                        "sid": rag_session_id,
                        "repo_id": repo_id,
                        "blob_sha": blob_sha,
                        "chunk_index": int(item["chunk_index"]),
                        "path": item["path"],
                    },
                )

    def _parse_vector(self, value: str) -> list[float]:
        text_value = value.strip()
        if not text_value:
            return []
        if text_value.startswith("[") and text_value.endswith("]"):
            text_value = text_value[1:-1]
        if not text_value:
            return []
        return [float(part.strip()) for part in text_value.split(",") if part.strip()]
Binary file not shown.
@@ -0,0 +1,20 @@
class TextChunker:
    def __init__(self, chunk_size: int = 900, overlap: int = 120) -> None:
        # overlap must be smaller than chunk_size; otherwise `start` in chunk()
        # never advances and the loop does not terminate.
        if chunk_size <= 0 or not 0 <= overlap < chunk_size:
            raise ValueError("chunk_size must be positive and 0 <= overlap < chunk_size")
        self._chunk_size = chunk_size
        self._overlap = overlap

    def chunk(self, text: str) -> list[str]:
        cleaned = text.replace("\r\n", "\n")
        if not cleaned.strip():
            return []
        chunks: list[str] = []
        start = 0
        while start < len(cleaned):
            end = min(len(cleaned), start + self._chunk_size)
            piece = cleaned[start:end].strip()
            if piece:
                chunks.append(piece)
            if end == len(cleaned):
                break
            start = max(0, end - self._overlap)
        return chunks
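To make the sliding window concrete, the sketch below computes only the `(start, end)` offsets that `TextChunker.chunk` visits, without slicing any text. `chunk_spans` is a hypothetical helper written for illustration; its argument check is an assumption of this sketch.

```python
def chunk_spans(length: int, chunk_size: int = 900, overlap: int = 120) -> list[tuple[int, int]]:
    """Return the (start, end) offsets of each window over a text of `length` characters."""
    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        raise ValueError("require chunk_size > 0 and 0 <= overlap < chunk_size")
    spans: list[tuple[int, int]] = []
    start = 0
    while start < length:
        end = min(length, start + chunk_size)
        spans.append((start, end))
        if end == length:
            break
        # Step back so that neighbouring windows share `overlap` characters.
        start = end - overlap
    return spans


spans = chunk_spans(2000)
```

With the default settings each window starts `chunk_size - overlap = 780` characters after the previous one, so a 2000-character text yields three windows and the last one is shorter than `chunk_size`.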
@@ -0,0 +1,12 @@
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    if not a or not b or len(a) != len(b):
        return -1.0
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return -1.0
    return dot / (norm_a * norm_b)
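The repository's `retrieve` orders rows with pgvector's `<=>` operator, which is cosine distance, i.e. `1 - cosine_similarity`. A rough in-memory equivalent of that ordering is sketched below. `rank_by_cosine_distance` and the sample data are hypothetical and only illustrate the relationship between the two measures.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Same contract as the helper above: -1.0 for empty, mismatched or zero vectors."""
    if not a or not b or len(a) != len(b):
        return -1.0
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return -1.0
    return dot / (norm_a * norm_b)


def rank_by_cosine_distance(
    query: list[float],
    chunks: list[tuple[str, list[float]]],
    limit: int = 5,
) -> list[str]:
    """Return chunk paths ordered by ascending cosine distance to the query."""
    scored = [(1.0 - cosine_similarity(query, emb), path) for path, emb in chunks]
    scored.sort(key=lambda pair: pair[0])
    return [path for _, path in scored[:limit]]


chunks = [("a.md", [1.0, 0.0]), ("b.md", [0.0, 1.0]), ("c.md", [1.0, 1.0])]
ranked = rank_by_cosine_distance([1.0, 0.1], chunks, limit=2)
```

Because invalid inputs map to a similarity of `-1.0`, they get the largest possible distance (`2.0`) and sort last in this sketch.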
@@ -0,0 +1,211 @@
import asyncio
import hashlib
import os
from collections.abc import Awaitable, Callable
from inspect import isawaitable

from app.modules.rag_session.embedding.gigachat_embedder import GigaChatEmbedder
from app.modules.rag_session.repository import RagRepository
from app.modules.rag_session.retrieval.chunker import TextChunker


class RagService:
    def __init__(
        self,
        embedder: GigaChatEmbedder,
        repository: RagRepository,
        chunker: TextChunker | None = None,
    ) -> None:
        self._embedder = embedder
        self._repo = repository
        self._chunker = chunker or TextChunker()

    async def index_snapshot(
        self,
        rag_session_id: str,
        files: list[dict],
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None = None,
    ) -> tuple[int, int, int, int]:
        total_files = len(files)
        indexed_files = 0
        failed_files = 0
        cache_hit_files = 0
        cache_miss_files = 0
        all_chunks: list[dict] = []
        # The session lookup hits the database, so keep it off the event loop.
        repo_id = await asyncio.to_thread(self._resolve_repo_id, rag_session_id)
        for index, file in enumerate(files, start=1):
            path = str(file.get("path", ""))
            try:
                blob_sha = self._blob_sha(file)
                cached = await asyncio.to_thread(self._repo.get_cached_chunks, repo_id, blob_sha)
                if cached:
                    all_chunks.extend(self._build_cached_items(path, file, repo_id, blob_sha, cached))
                    cache_hit_files += 1
                else:
                    chunks = self._build_chunks_for_file(file)
                    embedded_chunks = await asyncio.to_thread(self._embed_chunks, chunks, file, repo_id, blob_sha)
                    all_chunks.extend(embedded_chunks)
                    await asyncio.to_thread(self._repo.cache_file_chunks, repo_id, path, blob_sha, embedded_chunks)
                    cache_miss_files += 1
                indexed_files += 1
            except Exception:
                # Per-file failures are counted rather than raised, so one bad file does not abort the job.
                failed_files += 1
            await self._notify_progress(progress_cb, index, total_files, path)
        await asyncio.to_thread(self._repo.replace_chunks, rag_session_id, all_chunks)
        return indexed_files, failed_files, cache_hit_files, cache_miss_files

    async def index_changes(
        self,
        rag_session_id: str,
        changed_files: list[dict],
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None = None,
    ) -> tuple[int, int, int, int]:
        total_files = len(changed_files)
        indexed_files = 0
        failed_files = 0
        cache_hit_files = 0
        cache_miss_files = 0
        delete_paths: list[str] = []
        upsert_chunks: list[dict] = []
        # The session lookup hits the database, so keep it off the event loop.
        repo_id = await asyncio.to_thread(self._resolve_repo_id, rag_session_id)

        for index, file in enumerate(changed_files, start=1):
            path = str(file.get("path", ""))
            op = str(file.get("op", ""))
            try:
                if op == "delete":
                    delete_paths.append(path)
                    indexed_files += 1
                    await self._notify_progress(progress_cb, index, total_files, path)
                    continue
                if op == "upsert" and file.get("content") is not None:
                    blob_sha = self._blob_sha(file)
                    cached = await asyncio.to_thread(self._repo.get_cached_chunks, repo_id, blob_sha)
                    if cached:
                        upsert_chunks.extend(self._build_cached_items(path, file, repo_id, blob_sha, cached))
                        cache_hit_files += 1
                    else:
                        chunks = self._build_chunks_for_file(file)
                        embedded_chunks = await asyncio.to_thread(self._embed_chunks, chunks, file, repo_id, blob_sha)
                        upsert_chunks.extend(embedded_chunks)
                        await asyncio.to_thread(self._repo.cache_file_chunks, repo_id, path, blob_sha, embedded_chunks)
                        cache_miss_files += 1
                    indexed_files += 1
                    await self._notify_progress(progress_cb, index, total_files, path)
                    continue
                # Unknown op, or an upsert without content, counts as a failure.
                failed_files += 1
            except Exception:
                failed_files += 1
            await self._notify_progress(progress_cb, index, total_files, path)

        await asyncio.to_thread(
            self._repo.apply_changes,
            rag_session_id,
            delete_paths,
            upsert_chunks,
        )
        return indexed_files, failed_files, cache_hit_files, cache_miss_files

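    async def retrieve(self, rag_session_id: str, query: str) -> list[dict]:
        # The embedder and repository calls are blocking, so run them in worker
        # threads (as the indexing methods do) to keep the event loop responsive.
        try:
            vectors = await asyncio.to_thread(self._embedder.embed, [query])
            rows = await asyncio.to_thread(self._repo.retrieve, rag_session_id, vectors[0], 5)
        except Exception:
            # Embedding or vector search failed: fall back to the most recently inserted chunks.
            rows = await asyncio.to_thread(self._repo.fallback_chunks, rag_session_id, 5)
        return [{"source": row["path"], "content": row["content"]} for row in rows]
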
    def _build_chunks_for_file(self, file: dict) -> list[tuple[str, int, str]]:
        path = str(file.get("path", ""))
        content = str(file.get("content", ""))
        output: list[tuple[str, int, str]] = []
        for idx, chunk in enumerate(self._chunker.chunk(content)):
            output.append((path, idx, chunk))
        return output

    def _embed_chunks(self, raw_chunks: list[tuple[str, int, str]], file: dict, repo_id: str, blob_sha: str) -> list[dict]:
        if not raw_chunks:
            return []
        batch_size = max(1, int(os.getenv("RAG_EMBED_BATCH_SIZE", "16")))
        metadata = self._chunk_metadata(file)

        indexed: list[dict] = []
        for i in range(0, len(raw_chunks), batch_size):
            batch = raw_chunks[i : i + batch_size]
            texts = [x[2] for x in batch]
            vectors = self._embedder.embed(texts)
            for (path, chunk_index, content), vector in zip(batch, vectors):
                indexed.append(
                    {
                        "path": path,
                        "chunk_index": chunk_index,
                        "content": content,
                        "embedding": vector,
                        "repo_id": repo_id,
                        "blob_sha": blob_sha,
                        **metadata,
                    }
                )
        return indexed

    def _build_cached_items(
        self,
        path: str,
        file: dict,
        repo_id: str,
        blob_sha: str,
        cached: list[dict],
    ) -> list[dict]:
        metadata = self._chunk_metadata(file)
        output: list[dict] = []
        for item in cached:
            output.append(
                {
                    "path": path,
                    "chunk_index": int(item["chunk_index"]),
                    "content": str(item["content"]),
                    "embedding": item.get("embedding") or [],
                    "repo_id": repo_id,
                    "blob_sha": blob_sha,
                    **metadata,
                    "section": item.get("section") or metadata.get("section"),
                }
            )
        return output

    def _resolve_repo_id(self, rag_session_id: str) -> str:
        session = self._repo.get_session(rag_session_id)
        if not session:
            return rag_session_id
        return str(session.get("project_id") or rag_session_id)

    def _blob_sha(self, file: dict) -> str:
        raw = str(file.get("content_hash") or "").strip()
        if raw:
            return raw
        content = str(file.get("content") or "")
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    def _chunk_metadata(self, file: dict) -> dict:
        return {
            "artifact_type": file.get("artifact_type"),
            "section": file.get("section"),
            "doc_id": file.get("doc_id"),
            "doc_version": file.get("doc_version"),
            "owner": file.get("owner"),
            "system_component": file.get("system_component"),
            "last_modified": file.get("last_modified"),
            "staleness_score": file.get("staleness_score"),
        }

    async def _notify_progress(
        self,
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None,
        current_file_index: int,
        total_files: int,
        current_file_name: str,
    ) -> None:
        if not progress_cb:
            return
        result = progress_cb(current_file_index, total_files, current_file_name)
        if isawaitable(result):
            await result
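The cache-hit and cache-miss counters in `RagService` come from a content-addressed lookup: chunks are keyed by `(repo_id, blob_sha)`, where `blob_sha` is the supplied `content_hash` or a SHA-256 of the content. A simplified sketch of that flow follows. `InMemoryChunkCache`, `fake_embed` and `index_files` are hypothetical stand-ins for the database tables, the GigaChat embedder and the indexing loop; chunking is reduced to one chunk per file.

```python
import hashlib


class InMemoryChunkCache:
    """Hypothetical stand-in for the rag_chunk_cache table, keyed by (repo_id, blob_sha)."""

    def __init__(self) -> None:
        self._rows: dict[tuple[str, str], list[dict]] = {}

    def get(self, repo_id: str, blob_sha: str) -> list[dict]:
        return self._rows.get((repo_id, blob_sha), [])

    def put(self, repo_id: str, blob_sha: str, items: list[dict]) -> None:
        if items:  # like cache_file_chunks, nothing is stored for an empty file
            self._rows[(repo_id, blob_sha)] = items


def blob_sha(file: dict) -> str:
    """Prefer a caller-supplied content_hash, otherwise hash the content."""
    raw = str(file.get("content_hash") or "").strip()
    if raw:
        return raw
    return hashlib.sha256(str(file.get("content") or "").encode("utf-8")).hexdigest()


def fake_embed(texts: list[str]) -> list[list[float]]:
    """Placeholder embedder for illustration only: one number per text."""
    return [[float(len(t))] for t in texts]


def index_files(repo_id: str, files: list[dict], cache: InMemoryChunkCache) -> tuple[int, int]:
    """Return (cache_hit_files, cache_miss_files) for one indexing pass."""
    hits = misses = 0
    for file in files:
        sha = blob_sha(file)
        if cache.get(repo_id, sha):
            hits += 1
            continue
        content = str(file.get("content") or "")
        items = []
        if content:
            items = [{"chunk_index": 0, "content": content, "embedding": fake_embed([content])[0]}]
        cache.put(repo_id, sha, items)
        misses += 1
    return hits, misses


cache = InMemoryChunkCache()
files = [
    {"path": "a.py", "content": "print('a')"},
    {"path": "copy_of_a.py", "content": "print('a')"},
]
first = index_files("proj-1", files, cache)
second = index_files("proj-1", files, cache)
```

Since the path is not part of the key, a second file with identical content is already a hit on the first pass, and a repeated pass over unchanged files needs no embedding calls at all.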
@@ -0,0 +1,34 @@
from dataclasses import dataclass
from uuid import uuid4

from app.modules.rag_session.repository import RagRepository


@dataclass
class RagSession:
    rag_session_id: str
    project_id: str


class RagSessionStore:
    def __init__(self, repository: RagRepository) -> None:
        self._repo = repository

    def create(self, project_id: str) -> RagSession:
        session = RagSession(rag_session_id=str(uuid4()), project_id=project_id)
        self._repo.upsert_session(session.rag_session_id, session.project_id)
        return session

    def put(self, rag_session_id: str, project_id: str) -> RagSession:
        session = RagSession(rag_session_id=rag_session_id, project_id=project_id)
        self._repo.upsert_session(rag_session_id, project_id)
        return session

    def get(self, rag_session_id: str) -> RagSession | None:
        row = self._repo.get_session(rag_session_id)
        if not row:
            return None
        return RagSession(
            rag_session_id=str(row["rag_session_id"]),
            project_id=str(row["project_id"]),
        )