Первый коммит
This commit is contained in:
0
app/modules/rag/__init__.py
Normal file
0
app/modules/rag/__init__.py
Normal file
BIN
app/modules/rag/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
app/modules/rag/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/modules/rag/__pycache__/indexing_service.cpython-312.pyc
Normal file
BIN
app/modules/rag/__pycache__/indexing_service.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/modules/rag/__pycache__/job_store.cpython-312.pyc
Normal file
BIN
app/modules/rag/__pycache__/job_store.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/modules/rag/__pycache__/module.cpython-312.pyc
Normal file
BIN
app/modules/rag/__pycache__/module.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/modules/rag/__pycache__/repository.cpython-312.pyc
Normal file
BIN
app/modules/rag/__pycache__/repository.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/modules/rag/__pycache__/service.cpython-312.pyc
Normal file
BIN
app/modules/rag/__pycache__/service.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/modules/rag/__pycache__/session_store.cpython-312.pyc
Normal file
BIN
app/modules/rag/__pycache__/session_store.cpython-312.pyc
Normal file
Binary file not shown.
0
app/modules/rag/embedding/__init__.py
Normal file
0
app/modules/rag/embedding/__init__.py
Normal file
BIN
app/modules/rag/embedding/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
app/modules/rag/embedding/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
9
app/modules/rag/embedding/gigachat_embedder.py
Normal file
9
app/modules/rag/embedding/gigachat_embedder.py
Normal file
from app.modules.shared.gigachat.client import GigaChatClient


class GigaChatEmbedder:
    """Thin adapter exposing GigaChat text embeddings through a single call."""

    def __init__(self, client: GigaChatClient) -> None:
        # All work is delegated to the transport client.
        self._client = client

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding vector per input text, preserving order."""
        return self._client.embed(texts)
141
app/modules/rag/indexing_service.py
Normal file
141
app/modules/rag/indexing_service.py
Normal file
import asyncio
from collections import defaultdict

from app.schemas.common import ErrorPayload, ModuleName
from app.schemas.indexing import IndexJobStatus
from app.modules.contracts import RagIndexer
from app.modules.rag.job_store import IndexJob, IndexJobStore
from app.modules.shared.event_bus import EventBus
from app.modules.shared.retry_executor import RetryExecutor


class IndexingOrchestrator:
    """Schedules indexing jobs as background tasks, serialized per RAG session.

    Each ``rag_session_id`` gets its own :class:`asyncio.Lock`, so jobs for the
    same session run one at a time while different sessions index concurrently.
    Progress and terminal states are published on the EventBus keyed by job id.
    """

    def __init__(
        self,
        store: IndexJobStore,
        rag: RagIndexer,
        events: EventBus,
        retry: RetryExecutor,
    ) -> None:
        self._store = store
        self._rag = rag
        self._events = events
        self._retry = retry
        # One lock per rag_session_id; defaultdict creates it on first use.
        self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
        # Strong references to in-flight tasks.  The event loop keeps only
        # weak references to tasks, so without this set a running job could
        # be garbage-collected before it finishes (see asyncio.create_task
        # docs: "Save a reference to the result of this function").
        self._tasks: set[asyncio.Task] = set()

    def _spawn(self, coro) -> None:
        """Run *coro* in the background, holding a reference until it is done."""
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    async def enqueue_snapshot(self, rag_session_id: str, files: list[dict]) -> IndexJob:
        """Create a job record and start full-snapshot indexing in the background."""
        job = self._store.create(rag_session_id)
        self._spawn(self._process_snapshot(job.index_job_id, rag_session_id, files))
        return job

    async def enqueue_changes(self, rag_session_id: str, changed_files: list[dict]) -> IndexJob:
        """Create a job record and start incremental indexing in the background."""
        job = self._store.create(rag_session_id)
        self._spawn(self._process_changes(job.index_job_id, rag_session_id, changed_files))
        return job

    async def _process_snapshot(self, job_id: str, rag_session_id: str, files: list[dict]) -> None:
        await self._run_with_project_lock(
            job_id=job_id,
            rag_session_id=rag_session_id,
            total_files=len(files),
            operation=lambda progress_cb: self._rag.index_snapshot(
                rag_session_id=rag_session_id,
                files=files,
                progress_cb=progress_cb,
            ),
        )

    async def _process_changes(self, job_id: str, rag_session_id: str, changed_files: list[dict]) -> None:
        await self._run_with_project_lock(
            job_id=job_id,
            rag_session_id=rag_session_id,
            total_files=len(changed_files),
            operation=lambda progress_cb: self._rag.index_changes(
                rag_session_id=rag_session_id,
                changed_files=changed_files,
                progress_cb=progress_cb,
            ),
        )

    async def _run_with_project_lock(self, job_id: str, rag_session_id: str, total_files: int, operation) -> None:
        """Run *operation* under the session lock, publishing lifecycle events.

        *operation* receives an async progress callback and must return an
        ``(indexed, failed)`` pair.  Transient failures (timeouts, connection
        or OS errors) that survive the retry policy mark the job as ERROR;
        any other exception type is intentionally left to propagate.
        """
        lock = self._locks[rag_session_id]
        async with lock:
            job = self._store.get(job_id)
            if not job:
                # Job row vanished between enqueue and execution; nothing to do.
                return
            job.status = IndexJobStatus.RUNNING
            self._store.save(job)
            await self._events.publish(
                job_id,
                "index_status",
                {"index_job_id": job_id, "status": job.status.value, "total_files": total_files},
            )
            try:
                async def progress_cb(current_file_index: int, total: int, current_file_name: str) -> None:
                    # Relay per-file progress to SSE subscribers.
                    await self._events.publish(
                        job_id,
                        "index_progress",
                        {
                            "index_job_id": job_id,
                            "current_file_index": current_file_index,
                            "total_files": total,
                            "processed_files": current_file_index,
                            "current_file_path": current_file_name,
                            "current_file_name": current_file_name,
                        },
                    )

                indexed, failed = await self._retry.run(lambda: operation(progress_cb))
                job.status = IndexJobStatus.DONE
                job.indexed_files = indexed
                job.failed_files = failed
                self._store.save(job)
                await self._events.publish(
                    job_id,
                    "index_status",
                    {
                        "index_job_id": job_id,
                        "status": job.status.value,
                        "indexed_files": indexed,
                        "failed_files": failed,
                        "total_files": total_files,
                    },
                )
                # "terminal" tells SSE streams they may close the connection.
                await self._events.publish(
                    job_id,
                    "terminal",
                    {
                        "index_job_id": job_id,
                        "status": "done",
                        "indexed_files": indexed,
                        "failed_files": failed,
                        "total_files": total_files,
                    },
                )
            except (TimeoutError, ConnectionError, OSError) as exc:
                job.status = IndexJobStatus.ERROR
                job.error = ErrorPayload(
                    code="index_retry_exhausted",
                    desc=f"Temporary indexing failure after retries: {exc}",
                    module=ModuleName.RAG,
                )
                self._store.save(job)
                await self._events.publish(
                    job_id,
                    "index_status",
                    {"index_job_id": job_id, "status": job.status.value, "total_files": total_files},
                )
                await self._events.publish(
                    job_id,
                    "terminal",
                    {
                        "index_job_id": job_id,
                        "status": "error",
                        "total_files": total_files,
                        "error": {
                            "code": job.error.code,
                            "desc": job.error.desc,
                            "module": job.error.module.value,
                        },
                    },
                )
66
app/modules/rag/job_store.py
Normal file
66
app/modules/rag/job_store.py
Normal file
from dataclasses import dataclass
from uuid import uuid4

from app.modules.rag.repository import RagRepository
from app.schemas.common import ErrorPayload, ModuleName
from app.schemas.indexing import IndexJobStatus


@dataclass
class IndexJob:
    """In-memory view of one indexing job row."""

    index_job_id: str
    rag_session_id: str
    status: IndexJobStatus = IndexJobStatus.QUEUED
    indexed_files: int = 0
    failed_files: int = 0
    error: ErrorPayload | None = None


class IndexJobStore:
    """Persists :class:`IndexJob` records through :class:`RagRepository`."""

    def __init__(self, repository: RagRepository) -> None:
        self._repo = repository

    def create(self, rag_session_id: str) -> IndexJob:
        """Create a new QUEUED job with a random id and persist it."""
        job = IndexJob(index_job_id=str(uuid4()), rag_session_id=rag_session_id)
        self._repo.create_job(job.index_job_id, rag_session_id, job.status.value)
        return job

    def get(self, index_job_id: str) -> IndexJob | None:
        """Load a job by id; ``None`` when no such row exists."""
        row = self._repo.get_job(index_job_id)
        if not row:
            return None
        return IndexJob(
            index_job_id=row.index_job_id,
            rag_session_id=row.rag_session_id,
            status=IndexJobStatus(row.status),
            indexed_files=row.indexed_files,
            failed_files=row.failed_files,
            error=self._error_from_row(row),
        )

    @staticmethod
    def _error_from_row(row) -> ErrorPayload | None:
        """Rebuild the error payload stored on *row*, if any."""
        if not row.error_code:
            return None
        # An unknown module string degrades gracefully to RAG.
        module = ModuleName.RAG
        if row.error_module:
            try:
                module = ModuleName(row.error_module)
            except ValueError:
                module = ModuleName.RAG
        return ErrorPayload(
            code=row.error_code,
            desc=row.error_desc or "",
            module=module,
        )

    def save(self, job: IndexJob) -> None:
        """Write the job's mutable fields (status, counters, error) back."""
        err = job.error
        self._repo.update_job(
            job.index_job_id,
            status=job.status.value,
            indexed_files=job.indexed_files,
            failed_files=job.failed_files,
            error_code=err.code if err else None,
            error_desc=err.desc if err else None,
            error_module=err.module.value if err else None,
        )
247
app/modules/rag/module.py
Normal file
247
app/modules/rag/module.py
Normal file
import asyncio

from fastapi import APIRouter
from fastapi.responses import StreamingResponse

from app.core.exceptions import AppError
from app.modules.rag.embedding.gigachat_embedder import GigaChatEmbedder
from app.modules.rag.indexing_service import IndexingOrchestrator
from app.modules.rag.job_store import IndexJobStore
from app.modules.rag.repository import RagRepository
from app.modules.rag.retrieval.chunker import TextChunker
from app.modules.rag.session_store import RagSessionStore
from app.modules.rag.service import RagService
from app.modules.shared.event_bus import EventBus
from app.modules.shared.gigachat.client import GigaChatClient
from app.modules.shared.gigachat.settings import GigaChatSettings
from app.modules.shared.gigachat.token_provider import GigaChatTokenProvider
from app.modules.shared.retry_executor import RetryExecutor
from app.schemas.common import ModuleName
from app.schemas.indexing import (
    IndexChangesRequest,
    IndexJobQueuedResponse,
    IndexJobResponse,
    IndexSnapshotRequest,
)
from app.schemas.rag_sessions import (
    RagSessionChangesRequest,
    RagSessionCreateRequest,
    RagSessionCreateResponse,
    RagSessionJobResponse,
)


class RagModule:
    """Wires the RAG services together and exposes the public/internal routers."""

    def __init__(self, event_bus: EventBus, retry: RetryExecutor, repository: RagRepository) -> None:
        self._events = event_bus
        self.repository = repository
        settings = GigaChatSettings.from_env()
        token_provider = GigaChatTokenProvider(settings)
        client = GigaChatClient(settings, token_provider)
        embedder = GigaChatEmbedder(client)
        self.rag = RagService(embedder=embedder, repository=repository, chunker=TextChunker())
        self.sessions = RagSessionStore(repository)
        self.jobs = IndexJobStore(repository)
        self.indexing = IndexingOrchestrator(
            store=self.jobs,
            rag=self.rag,
            events=event_bus,
            retry=retry,
        )

    def _sse_response(self, index_job_id: str, queue) -> StreamingResponse:
        """Stream job events for *index_job_id* as Server-Sent Events.

        Emits a ``: keepalive`` comment every 10 seconds of silence and closes
        the stream after the "terminal" event.  The queue is always
        unsubscribed, even when the client disconnects mid-stream.
        """

        async def event_stream():
            heartbeat = 10
            try:
                while True:
                    try:
                        event = await asyncio.wait_for(queue.get(), timeout=heartbeat)
                        yield EventBus.as_sse(event)
                        if event.name == "terminal":
                            break
                    except asyncio.TimeoutError:
                        # SSE comment line: keeps proxies from closing the socket.
                        yield ": keepalive\n\n"
            finally:
                await self._events.unsubscribe(index_job_id, queue)

        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache, no-transform",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    def public_router(self) -> APIRouter:
        """Build the public API router (session endpoints + legacy /api/index)."""
        router = APIRouter(tags=["rag"])

        @router.post("/api/rag/sessions", response_model=RagSessionCreateResponse)
        async def create_rag_session(request: RagSessionCreateRequest) -> RagSessionCreateResponse:
            session = self.sessions.create(request.project_id)
            job = await self.indexing.enqueue_snapshot(
                rag_session_id=session.rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return RagSessionCreateResponse(
                rag_session_id=session.rag_session_id,
                index_job_id=job.index_job_id,
                status=job.status,
            )

        @router.post("/api/rag/sessions/{rag_session_id}/changes", response_model=IndexJobQueuedResponse)
        async def rag_session_changes(
            rag_session_id: str,
            request: RagSessionChangesRequest,
        ) -> IndexJobQueuedResponse:
            session = self.sessions.get(rag_session_id)
            if not session:
                raise AppError("not_found", f"RAG session not found: {rag_session_id}", ModuleName.RAG)
            job = await self.indexing.enqueue_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @router.get("/api/rag/sessions/{rag_session_id}/jobs/{index_job_id}", response_model=RagSessionJobResponse)
        async def rag_session_job(rag_session_id: str, index_job_id: str) -> RagSessionJobResponse:
            job = self.jobs.get(index_job_id)
            # A job from another session is treated as not found, not forbidden.
            if not job or job.rag_session_id != rag_session_id:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            return RagSessionJobResponse(
                rag_session_id=rag_session_id,
                index_job_id=job.index_job_id,
                status=job.status,
                indexed_files=job.indexed_files,
                failed_files=job.failed_files,
                error=job.error.model_dump(mode="json") if job.error else None,
            )

        @router.get("/api/rag/sessions/{rag_session_id}/jobs/{index_job_id}/events")
        async def rag_session_job_events(rag_session_id: str, index_job_id: str) -> StreamingResponse:
            job = self.jobs.get(index_job_id)
            if not job or job.rag_session_id != rag_session_id:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            queue = await self._events.subscribe(index_job_id, replay=True)
            return self._sse_response(index_job_id, queue)

        # Legacy compatibility endpoints.
        legacy = APIRouter(prefix="/api/index", tags=["index"])

        @legacy.post("/snapshot", response_model=IndexJobQueuedResponse)
        async def index_snapshot(request: IndexSnapshotRequest) -> IndexJobQueuedResponse:
            # Legacy callers have no session id; project_id doubles as one.
            session = self.sessions.put(
                rag_session_id=request.project_id,
                project_id=request.project_id,
            )
            job = await self.indexing.enqueue_snapshot(
                rag_session_id=session.rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @legacy.post("/changes", response_model=IndexJobQueuedResponse)
        async def index_changes(request: IndexChangesRequest) -> IndexJobQueuedResponse:
            rag_session_id = request.project_id
            if not self.sessions.get(rag_session_id):
                self.sessions.put(rag_session_id=rag_session_id, project_id=rag_session_id)
            job = await self.indexing.enqueue_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @legacy.get("/jobs/{index_job_id}", response_model=IndexJobResponse)
        async def get_index_job(index_job_id: str) -> IndexJobResponse:
            job = self.jobs.get(index_job_id)
            if not job:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            return IndexJobResponse(
                index_job_id=job.index_job_id,
                status=job.status,
                indexed_files=job.indexed_files,
                failed_files=job.failed_files,
                error=job.error,
            )

        @legacy.get("/jobs/{index_job_id}/events")
        async def get_index_job_events(index_job_id: str) -> StreamingResponse:
            job = self.jobs.get(index_job_id)
            if not job:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            queue = await self._events.subscribe(index_job_id, replay=True)
            return self._sse_response(index_job_id, queue)

        router.include_router(legacy)
        return router

    def internal_router(self) -> APIRouter:
        """Build the internal (service-to-service) router; indexing runs inline."""
        router = APIRouter(prefix="/internal/rag", tags=["internal-rag"])

        @router.post("/index/snapshot")
        async def index_snapshot(request: IndexSnapshotRequest) -> dict:
            rag_session_id = request.project_id
            if not self.sessions.get(rag_session_id):
                self.sessions.put(rag_session_id=rag_session_id, project_id=rag_session_id)
            indexed, failed = await self.rag.index_snapshot(
                rag_session_id=rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return {"indexed_files": indexed, "failed_files": failed}

        @router.post("/index/changes")
        async def index_changes(request: IndexChangesRequest) -> dict:
            rag_session_id = request.project_id
            indexed, failed = await self.rag.index_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return {"indexed_files": indexed, "failed_files": failed}

        @router.get("/index/jobs/{index_job_id}")
        async def get_job(index_job_id: str) -> dict:
            job = self.jobs.get(index_job_id)
            if not job:
                return {"status": "not_found"}
            return {
                "index_job_id": job.index_job_id,
                "status": job.status.value,
                "indexed_files": job.indexed_files,
                "failed_files": job.failed_files,
                "error": job.error.model_dump(mode="json") if job.error else None,
            }

        @router.post("/retrieve")
        async def retrieve(payload: dict) -> dict:
            # Accept both the new and the legacy key for the session id.
            rag_session_id = payload.get("rag_session_id") or payload.get("project_id", "")
            ctx = await self.rag.retrieve(
                rag_session_id=rag_session_id,
                query=payload.get("query", ""),
            )
            return {"items": ctx}

        return router
261
app/modules/rag/repository.py
Normal file
261
app/modules/rag/repository.py
Normal file
@@ -0,0 +1,261 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.modules.shared.db import get_engine
|
||||
|
||||
|
||||
@dataclass
|
||||
class RagJobRow:
|
||||
index_job_id: str
|
||||
rag_session_id: str
|
||||
status: str
|
||||
indexed_files: int
|
||||
failed_files: int
|
||||
error_code: str | None
|
||||
error_desc: str | None
|
||||
error_module: str | None
|
||||
|
||||
|
||||
class RagRepository:
|
||||
def ensure_tables(self) -> None:
|
||||
engine = get_engine()
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS rag_sessions (
|
||||
rag_session_id VARCHAR(64) PRIMARY KEY,
|
||||
project_id VARCHAR(512) NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS rag_index_jobs (
|
||||
index_job_id VARCHAR(64) PRIMARY KEY,
|
||||
rag_session_id VARCHAR(64) NOT NULL,
|
||||
status VARCHAR(16) NOT NULL,
|
||||
indexed_files INTEGER NOT NULL DEFAULT 0,
|
||||
failed_files INTEGER NOT NULL DEFAULT 0,
|
||||
error_code VARCHAR(128) NULL,
|
||||
error_desc TEXT NULL,
|
||||
error_module VARCHAR(64) NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS rag_chunks (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
rag_session_id VARCHAR(64) NOT NULL,
|
||||
path TEXT NOT NULL,
|
||||
chunk_index INTEGER NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
embedding vector NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
ALTER TABLE rag_chunks
|
||||
ADD COLUMN IF NOT EXISTS created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
|
||||
"""
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
ALTER TABLE rag_chunks
|
||||
ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
|
||||
"""
|
||||
)
|
||||
)
|
||||
conn.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_chunks_session ON rag_chunks (rag_session_id)"))
|
||||
conn.commit()
|
||||
|
||||
def upsert_session(self, rag_session_id: str, project_id: str) -> None:
|
||||
with get_engine().connect() as conn:
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO rag_sessions (rag_session_id, project_id)
|
||||
VALUES (:sid, :pid)
|
||||
ON CONFLICT (rag_session_id) DO UPDATE SET project_id = EXCLUDED.project_id
|
||||
"""
|
||||
),
|
||||
{"sid": rag_session_id, "pid": project_id},
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def session_exists(self, rag_session_id: str) -> bool:
|
||||
with get_engine().connect() as conn:
|
||||
row = conn.execute(
|
||||
text("SELECT 1 FROM rag_sessions WHERE rag_session_id = :sid"),
|
||||
{"sid": rag_session_id},
|
||||
).fetchone()
|
||||
return bool(row)
|
||||
|
||||
def get_session(self, rag_session_id: str) -> dict | None:
|
||||
with get_engine().connect() as conn:
|
||||
row = conn.execute(
|
||||
text("SELECT rag_session_id, project_id FROM rag_sessions WHERE rag_session_id = :sid"),
|
||||
{"sid": rag_session_id},
|
||||
).mappings().fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
def create_job(self, index_job_id: str, rag_session_id: str, status: str) -> None:
|
||||
with get_engine().connect() as conn:
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO rag_index_jobs (index_job_id, rag_session_id, status)
|
||||
VALUES (:jid, :sid, :status)
|
||||
"""
|
||||
),
|
||||
{"jid": index_job_id, "sid": rag_session_id, "status": status},
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def update_job(
|
||||
self,
|
||||
index_job_id: str,
|
||||
*,
|
||||
status: str,
|
||||
indexed_files: int,
|
||||
failed_files: int,
|
||||
error_code: str | None = None,
|
||||
error_desc: str | None = None,
|
||||
error_module: str | None = None,
|
||||
) -> None:
|
||||
with get_engine().connect() as conn:
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE rag_index_jobs
|
||||
SET status = :status,
|
||||
indexed_files = :indexed,
|
||||
failed_files = :failed,
|
||||
error_code = :ecode,
|
||||
error_desc = :edesc,
|
||||
error_module = :emodule,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE index_job_id = :jid
|
||||
"""
|
||||
),
|
||||
{
|
||||
"jid": index_job_id,
|
||||
"status": status,
|
||||
"indexed": indexed_files,
|
||||
"failed": failed_files,
|
||||
"ecode": error_code,
|
||||
"edesc": error_desc,
|
||||
"emodule": error_module,
|
||||
},
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def get_job(self, index_job_id: str) -> RagJobRow | None:
|
||||
with get_engine().connect() as conn:
|
||||
row = conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT index_job_id, rag_session_id, status, indexed_files, failed_files,
|
||||
error_code, error_desc, error_module
|
||||
FROM rag_index_jobs
|
||||
WHERE index_job_id = :jid
|
||||
"""
|
||||
),
|
||||
{"jid": index_job_id},
|
||||
).mappings().fetchone()
|
||||
if not row:
|
||||
return None
|
||||
return RagJobRow(**dict(row))
|
||||
|
||||
def replace_chunks(self, rag_session_id: str, items: list[dict]) -> None:
|
||||
with get_engine().connect() as conn:
|
||||
conn.execute(text("DELETE FROM rag_chunks WHERE rag_session_id = :sid"), {"sid": rag_session_id})
|
||||
self._insert_chunks(conn, rag_session_id, items)
|
||||
conn.commit()
|
||||
|
||||
def apply_changes(self, rag_session_id: str, delete_paths: list[str], upserts: list[dict]) -> None:
|
||||
with get_engine().connect() as conn:
|
||||
if delete_paths:
|
||||
conn.execute(
|
||||
text("DELETE FROM rag_chunks WHERE rag_session_id = :sid AND path = ANY(:paths)"),
|
||||
{"sid": rag_session_id, "paths": delete_paths},
|
||||
)
|
||||
if upserts:
|
||||
paths = sorted({str(x["path"]) for x in upserts})
|
||||
conn.execute(
|
||||
text("DELETE FROM rag_chunks WHERE rag_session_id = :sid AND path = ANY(:paths)"),
|
||||
{"sid": rag_session_id, "paths": paths},
|
||||
)
|
||||
self._insert_chunks(conn, rag_session_id, upserts)
|
||||
conn.commit()
|
||||
|
||||
def retrieve(self, rag_session_id: str, query_embedding: list[float], limit: int = 5) -> list[dict]:
|
||||
emb = "[" + ",".join(str(x) for x in query_embedding) + "]"
|
||||
with get_engine().connect() as conn:
|
||||
rows = conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT path, content
|
||||
FROM rag_chunks
|
||||
WHERE rag_session_id = :sid
|
||||
ORDER BY embedding <=> CAST(:emb AS vector)
|
||||
LIMIT :lim
|
||||
"""
|
||||
),
|
||||
{"sid": rag_session_id, "emb": emb, "lim": limit},
|
||||
).mappings().fetchall()
|
||||
return [dict(x) for x in rows]
|
||||
|
||||
def fallback_chunks(self, rag_session_id: str, limit: int = 5) -> list[dict]:
|
||||
with get_engine().connect() as conn:
|
||||
rows = conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT path, content
|
||||
FROM rag_chunks
|
||||
WHERE rag_session_id = :sid
|
||||
ORDER BY id DESC
|
||||
LIMIT :lim
|
||||
"""
|
||||
),
|
||||
{"sid": rag_session_id, "lim": limit},
|
||||
).mappings().fetchall()
|
||||
return [dict(x) for x in rows]
|
||||
|
||||
def _insert_chunks(self, conn, rag_session_id: str, items: list[dict]) -> None:
|
||||
for item in items:
|
||||
emb = item.get("embedding") or []
|
||||
emb_str = "[" + ",".join(str(x) for x in emb) + "]" if emb else None
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO rag_chunks (rag_session_id, path, chunk_index, content, embedding, created_at, updated_at)
|
||||
VALUES (:sid, :path, :idx, :content, CAST(:emb AS vector), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"sid": rag_session_id,
|
||||
"path": item["path"],
|
||||
"idx": int(item["chunk_index"]),
|
||||
"content": item["content"],
|
||||
"emb": emb_str,
|
||||
},
|
||||
)
|
||||
0
app/modules/rag/retrieval/__init__.py
Normal file
0
app/modules/rag/retrieval/__init__.py
Normal file
BIN
app/modules/rag/retrieval/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
app/modules/rag/retrieval/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/modules/rag/retrieval/__pycache__/chunker.cpython-312.pyc
Normal file
BIN
app/modules/rag/retrieval/__pycache__/chunker.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/modules/rag/retrieval/__pycache__/scoring.cpython-312.pyc
Normal file
BIN
app/modules/rag/retrieval/__pycache__/scoring.cpython-312.pyc
Normal file
Binary file not shown.
20
app/modules/rag/retrieval/chunker.py
Normal file
20
app/modules/rag/retrieval/chunker.py
Normal file
class TextChunker:
    """Splits text into fixed-size, overlapping chunks.

    Windows are ``chunk_size`` characters long and successive windows share
    ``overlap`` characters.  Whitespace-only pieces are dropped.
    """

    def __init__(self, chunk_size: int = 900, overlap: int = 120) -> None:
        self._chunk_size = chunk_size
        self._overlap = overlap

    def chunk(self, text: str) -> list[str]:
        """Return the list of non-empty chunks of *text* (CRLF normalized to LF)."""
        cleaned = text.replace("\r\n", "\n")
        if not cleaned.strip():
            return []
        chunks: list[str] = []
        start = 0
        while start < len(cleaned):
            end = min(len(cleaned), start + self._chunk_size)
            piece = cleaned[start:end].strip()
            if piece:
                chunks.append(piece)
            if end == len(cleaned):
                break
            # Step back by the overlap, but always advance by at least one
            # character: with overlap >= chunk_size the old `max(0, end - overlap)`
            # could leave `start` unchanged and loop forever.
            start = max(start + 1, end - self._overlap)
        return chunks
12
app/modules/rag/retrieval/scoring.py
Normal file
12
app/modules/rag/retrieval/scoring.py
Normal file
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two equal-length vectors.

    Returns -1.0 as a "not comparable" sentinel when either vector is empty,
    the lengths differ, or either vector has zero magnitude.
    """
    if not a or not b or len(a) != len(b):
        return -1.0
    # Single pass: accumulate the dot product and both squared norms together.
    dot = 0.0
    sq_a = 0.0
    sq_b = 0.0
    for x, y in zip(a, b):
        dot += x * y
        sq_a += x * x
        sq_b += y * y
    if sq_a == 0 or sq_b == 0:
        return -1.0
    return dot / (math.sqrt(sq_a) * math.sqrt(sq_b))
134
app/modules/rag/service.py
Normal file
134
app/modules/rag/service.py
Normal file
import asyncio
import os
from collections.abc import Awaitable, Callable
from inspect import isawaitable

from app.modules.rag.embedding.gigachat_embedder import GigaChatEmbedder
from app.modules.rag.repository import RagRepository
from app.modules.rag.retrieval.chunker import TextChunker


class RagService:
    """Chunks, embeds and stores project files; answers retrieval queries."""

    def __init__(
        self,
        embedder: GigaChatEmbedder,
        repository: RagRepository,
        chunker: TextChunker | None = None,
    ) -> None:
        self._embedder = embedder
        self._repo = repository
        self._chunker = chunker or TextChunker()

    async def index_snapshot(
        self,
        rag_session_id: str,
        files: list[dict],
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None = None,
    ) -> tuple[int, int]:
        """Index a full snapshot, replacing every stored chunk for the session.

        Returns ``(indexed_files, failed_files)``.  A file that raises during
        chunking/embedding is counted as failed and the rest continue.
        """
        total = len(files)
        ok = 0
        bad = 0
        collected: list[dict] = []
        for position, entry in enumerate(files, start=1):
            file_path = str(entry.get("path", ""))
            try:
                raw = self._build_chunks_for_file(entry)
                # Embedding is blocking HTTP work; keep the event loop free.
                collected.extend(await asyncio.to_thread(self._embed_chunks, raw))
                ok += 1
            except Exception:
                bad += 1
            await self._notify_progress(progress_cb, position, total, file_path)
        await asyncio.to_thread(self._repo.replace_chunks, rag_session_id, collected)
        return ok, bad

    async def index_changes(
        self,
        rag_session_id: str,
        changed_files: list[dict],
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None = None,
    ) -> tuple[int, int]:
        """Apply incremental file changes ("delete" / "upsert" ops).

        Entries with an unknown op, or an upsert without content, count as
        failed.  Returns ``(indexed_files, failed_files)``.
        """
        total = len(changed_files)
        ok = 0
        bad = 0
        to_delete: list[str] = []
        to_upsert: list[dict] = []

        for position, entry in enumerate(changed_files, start=1):
            file_path = str(entry.get("path", ""))
            operation = str(entry.get("op", ""))
            try:
                if operation == "delete":
                    to_delete.append(file_path)
                    ok += 1
                    await self._notify_progress(progress_cb, position, total, file_path)
                    continue
                if operation == "upsert" and entry.get("content") is not None:
                    raw = self._build_chunks_for_file(entry)
                    to_upsert.extend(await asyncio.to_thread(self._embed_chunks, raw))
                    ok += 1
                    await self._notify_progress(progress_cb, position, total, file_path)
                    continue
                bad += 1
            except Exception:
                bad += 1
            await self._notify_progress(progress_cb, position, total, file_path)

        await asyncio.to_thread(
            self._repo.apply_changes,
            rag_session_id,
            to_delete,
            to_upsert,
        )
        return ok, bad

    async def retrieve(self, rag_session_id: str, query: str) -> list[dict]:
        """Return context items for *query*; falls back to recent chunks on any failure."""
        try:
            query_embedding = self._embedder.embed([query])[0]
            rows = self._repo.retrieve(rag_session_id, query_embedding, limit=5)
        except Exception:
            # Best-effort: serve the latest chunks rather than fail the caller.
            rows = self._repo.fallback_chunks(rag_session_id, limit=5)
        return [{"source": row["path"], "content": row["content"]} for row in rows]

    def _build_chunks_for_file(self, file: dict) -> list[tuple[str, int, str]]:
        """Split one file dict into ``(path, chunk_index, chunk_text)`` tuples."""
        file_path = str(file.get("path", ""))
        body = str(file.get("content", ""))
        return [(file_path, idx, piece) for idx, piece in enumerate(self._chunker.chunk(body))]

    def _embed_chunks(self, raw_chunks: list[tuple[str, int, str]]) -> list[dict]:
        """Embed chunk texts in batches (size from RAG_EMBED_BATCH_SIZE, default 16)."""
        if not raw_chunks:
            return []
        batch_size = max(1, int(os.getenv("RAG_EMBED_BATCH_SIZE", "16")))

        embedded: list[dict] = []
        for offset in range(0, len(raw_chunks), batch_size):
            window = raw_chunks[offset : offset + batch_size]
            vectors = self._embedder.embed([item[2] for item in window])
            for (file_path, chunk_index, body), vector in zip(window, vectors):
                embedded.append(
                    {
                        "path": file_path,
                        "chunk_index": chunk_index,
                        "content": body,
                        "embedding": vector,
                    }
                )
        return embedded

    async def _notify_progress(
        self,
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None,
        current_file_index: int,
        total_files: int,
        current_file_name: str,
    ) -> None:
        """Invoke the optional progress callback, awaiting it when it is async."""
        if not progress_cb:
            return
        outcome = progress_cb(current_file_index, total_files, current_file_name)
        if isawaitable(outcome):
            await outcome
34
app/modules/rag/session_store.py
Normal file
34
app/modules/rag/session_store.py
Normal file
from dataclasses import dataclass
from uuid import uuid4

from app.modules.rag.repository import RagRepository


@dataclass
class RagSession:
    """A RAG session: a vector namespace bound to one project."""

    rag_session_id: str
    project_id: str


class RagSessionStore:
    """Persists :class:`RagSession` records through :class:`RagRepository`."""

    def __init__(self, repository: RagRepository) -> None:
        self._repo = repository

    def create(self, project_id: str) -> RagSession:
        """Create a session with a fresh random id and persist it."""
        fresh = RagSession(rag_session_id=str(uuid4()), project_id=project_id)
        self._repo.upsert_session(fresh.rag_session_id, fresh.project_id)
        return fresh

    def put(self, rag_session_id: str, project_id: str) -> RagSession:
        """Persist a session under a caller-chosen id (legacy endpoints)."""
        self._repo.upsert_session(rag_session_id, project_id)
        return RagSession(rag_session_id=rag_session_id, project_id=project_id)

    def get(self, rag_session_id: str) -> RagSession | None:
        """Load a session by id, or None when it does not exist."""
        record = self._repo.get_session(rag_session_id)
        if record is None or not record:
            return None
        return RagSession(
            rag_session_id=str(record["rag_session_id"]),
            project_id=str(record["project_id"]),
        )
Reference in New Issue
Block a user