первый коммит
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,9 +0,0 @@
|
||||
from app.modules.shared.gigachat.client import GigaChatClient
|
||||
|
||||
|
||||
class GigaChatEmbedder:
    """Thin adapter exposing GigaChatClient's embedding call to the RAG layer."""

    def __init__(self, client: GigaChatClient) -> None:
        # The client is injected so it can be shared and mocked; no I/O happens here.
        self._client = client

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding vector per input text (delegates to the client)."""
        return self._client.embed(texts)
|
||||
@@ -1,141 +0,0 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
|
||||
from app.schemas.common import ErrorPayload, ModuleName
|
||||
from app.schemas.indexing import IndexJobStatus
|
||||
from app.modules.contracts import RagIndexer
|
||||
from app.modules.rag.job_store import IndexJob, IndexJobStore
|
||||
from app.modules.shared.event_bus import EventBus
|
||||
from app.modules.shared.retry_executor import RetryExecutor
|
||||
|
||||
|
||||
class IndexingOrchestrator:
    """Runs indexing jobs in the background, serialized per RAG session.

    Jobs are persisted through ``IndexJobStore`` and executed in fire-and-forget
    asyncio tasks; their lifecycle is broadcast on the event bus as
    ``index_status``, ``index_progress`` and ``terminal`` events keyed by job id.
    """

    def __init__(
        self,
        store: IndexJobStore,
        rag: RagIndexer,
        events: EventBus,
        retry: RetryExecutor,
    ) -> None:
        self._store = store
        self._rag = rag
        self._events = events
        self._retry = retry
        # One lock per rag_session_id so two jobs for the same session never
        # run concurrently; defaultdict creates each lock lazily on first use.
        self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)

    async def enqueue_snapshot(self, rag_session_id: str, files: list[dict]) -> IndexJob:
        """Create a queued full-snapshot job and start processing it in the background."""
        job = self._store.create(rag_session_id)
        # NOTE(review): the task reference is not retained; per asyncio docs a
        # fire-and-forget task may be garbage-collected mid-flight — confirm.
        asyncio.create_task(self._process_snapshot(job.index_job_id, rag_session_id, files))
        return job

    async def enqueue_changes(self, rag_session_id: str, changed_files: list[dict]) -> IndexJob:
        """Create a queued incremental job and start processing it in the background."""
        job = self._store.create(rag_session_id)
        asyncio.create_task(self._process_changes(job.index_job_id, rag_session_id, changed_files))
        return job

    async def _process_snapshot(self, job_id: str, rag_session_id: str, files: list[dict]) -> None:
        # Wrap the snapshot indexing call in the shared lock/status machinery.
        await self._run_with_project_lock(
            job_id=job_id,
            rag_session_id=rag_session_id,
            total_files=len(files),
            operation=lambda progress_cb: self._rag.index_snapshot(
                rag_session_id=rag_session_id,
                files=files,
                progress_cb=progress_cb,
            ),
        )

    async def _process_changes(self, job_id: str, rag_session_id: str, changed_files: list[dict]) -> None:
        # Same machinery as snapshots, but delegating to incremental indexing.
        await self._run_with_project_lock(
            job_id=job_id,
            rag_session_id=rag_session_id,
            total_files=len(changed_files),
            operation=lambda progress_cb: self._rag.index_changes(
                rag_session_id=rag_session_id,
                changed_files=changed_files,
                progress_cb=progress_cb,
            ),
        )

    async def _run_with_project_lock(self, job_id: str, rag_session_id: str, total_files: int, operation) -> None:
        """Execute ``operation`` under the session lock, publishing job lifecycle events.

        ``operation`` receives an async progress callback and must return an
        ``(indexed, failed)`` pair; it is retried via the RetryExecutor.
        """
        lock = self._locks[rag_session_id]
        async with lock:
            job = self._store.get(job_id)
            if not job:
                # Job vanished (or was never persisted) — nothing to report.
                return
            job.status = IndexJobStatus.RUNNING
            self._store.save(job)
            await self._events.publish(
                job_id,
                "index_status",
                {"index_job_id": job_id, "status": job.status.value, "total_files": total_files},
            )
            try:
                async def progress_cb(current_file_index: int, total: int, current_file_name: str) -> None:
                    # Per-file progress; path and name fields carry the same
                    # value here, as do current_file_index / processed_files.
                    await self._events.publish(
                        job_id,
                        "index_progress",
                        {
                            "index_job_id": job_id,
                            "current_file_index": current_file_index,
                            "total_files": total,
                            "processed_files": current_file_index,
                            "current_file_path": current_file_name,
                            "current_file_name": current_file_name,
                        },
                    )

                indexed, failed = await self._retry.run(lambda: operation(progress_cb))
                job.status = IndexJobStatus.DONE
                job.indexed_files = indexed
                job.failed_files = failed
                self._store.save(job)
                await self._events.publish(
                    job_id,
                    "index_status",
                    {
                        "index_job_id": job_id,
                        "status": job.status.value,
                        "indexed_files": indexed,
                        "failed_files": failed,
                        "total_files": total_files,
                    },
                )
                # "terminal" tells SSE consumers they can close the stream.
                await self._events.publish(
                    job_id,
                    "terminal",
                    {
                        "index_job_id": job_id,
                        "status": "done",
                        "indexed_files": indexed,
                        "failed_files": failed,
                        "total_files": total_files,
                    },
                )
            # NOTE(review): only transient/transport errors are caught; any
            # other exception escapes and leaves the job stuck in RUNNING with
            # no terminal event — confirm whether that is intentional.
            except (TimeoutError, ConnectionError, OSError) as exc:
                job.status = IndexJobStatus.ERROR
                job.error = ErrorPayload(
                    code="index_retry_exhausted",
                    desc=f"Temporary indexing failure after retries: {exc}",
                    module=ModuleName.RAG,
                )
                self._store.save(job)
                await self._events.publish(
                    job_id,
                    "index_status",
                    {"index_job_id": job_id, "status": job.status.value, "total_files": total_files},
                )
                await self._events.publish(
                    job_id,
                    "terminal",
                    {
                        "index_job_id": job_id,
                        "status": "error",
                        "total_files": total_files,
                        "error": {
                            "code": job.error.code,
                            "desc": job.error.desc,
                            "module": job.error.module.value,
                        },
                    },
                )
|
||||
@@ -1,66 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import uuid4
|
||||
|
||||
from app.modules.rag.repository import RagRepository
|
||||
from app.schemas.common import ErrorPayload, ModuleName
|
||||
from app.schemas.indexing import IndexJobStatus
|
||||
|
||||
|
||||
@dataclass
class IndexJob:
    """In-memory view of one indexing job's persisted state."""

    index_job_id: str
    rag_session_id: str
    # New jobs start queued; the orchestrator moves them through RUNNING to DONE/ERROR.
    status: IndexJobStatus = IndexJobStatus.QUEUED
    indexed_files: int = 0
    failed_files: int = 0
    # Populated only when the job ended in error.
    error: ErrorPayload | None = None
|
||||
|
||||
|
||||
class IndexJobStore:
    """Persists IndexJob state through the RAG repository."""

    def __init__(self, repository: RagRepository) -> None:
        self._repo = repository

    def create(self, rag_session_id: str) -> IndexJob:
        """Create a new queued job and persist its initial row."""
        new_job = IndexJob(index_job_id=str(uuid4()), rag_session_id=rag_session_id)
        self._repo.create_job(new_job.index_job_id, rag_session_id, new_job.status.value)
        return new_job

    def get(self, index_job_id: str) -> IndexJob | None:
        """Load a job by id, rebuilding its error payload when one was stored."""
        row = self._repo.get_job(index_job_id)
        if not row:
            return None
        error_payload = None
        if row.error_code:
            # Fall back to the RAG module if the stored module name is
            # missing or no longer a valid ModuleName value.
            try:
                module = ModuleName(row.error_module) if row.error_module else ModuleName.RAG
            except ValueError:
                module = ModuleName.RAG
            error_payload = ErrorPayload(
                code=row.error_code,
                desc=row.error_desc or "",
                module=module,
            )
        return IndexJob(
            index_job_id=row.index_job_id,
            rag_session_id=row.rag_session_id,
            status=IndexJobStatus(row.status),
            indexed_files=row.indexed_files,
            failed_files=row.failed_files,
            error=error_payload,
        )

    def save(self, job: IndexJob) -> None:
        """Write the job's current status, counters and error back to storage."""
        err = job.error
        self._repo.update_job(
            job.index_job_id,
            status=job.status.value,
            indexed_files=job.indexed_files,
            failed_files=job.failed_files,
            error_code=err.code if err else None,
            error_desc=err.desc if err else None,
            error_module=err.module.value if err else None,
        )
|
||||
@@ -1,247 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.core.exceptions import AppError
|
||||
from app.modules.rag.embedding.gigachat_embedder import GigaChatEmbedder
|
||||
from app.modules.rag.indexing_service import IndexingOrchestrator
|
||||
from app.modules.rag.job_store import IndexJobStore
|
||||
from app.modules.rag.repository import RagRepository
|
||||
from app.modules.rag.retrieval.chunker import TextChunker
|
||||
from app.modules.rag.session_store import RagSessionStore
|
||||
from app.modules.rag.service import RagService
|
||||
from app.modules.shared.event_bus import EventBus
|
||||
from app.modules.shared.gigachat.client import GigaChatClient
|
||||
from app.modules.shared.gigachat.settings import GigaChatSettings
|
||||
from app.modules.shared.gigachat.token_provider import GigaChatTokenProvider
|
||||
from app.modules.shared.retry_executor import RetryExecutor
|
||||
from app.schemas.common import ModuleName
|
||||
from app.schemas.indexing import (
|
||||
IndexChangesRequest,
|
||||
IndexJobQueuedResponse,
|
||||
IndexJobResponse,
|
||||
IndexSnapshotRequest,
|
||||
)
|
||||
from app.schemas.rag_sessions import (
|
||||
RagSessionChangesRequest,
|
||||
RagSessionCreateRequest,
|
||||
RagSessionCreateResponse,
|
||||
RagSessionJobResponse,
|
||||
)
|
||||
|
||||
|
||||
class RagModule:
    """Wires the RAG feature together: embeddings, storage, indexing jobs and HTTP routes."""

    def __init__(self, event_bus: EventBus, retry: RetryExecutor, repository: RagRepository) -> None:
        self._events = event_bus
        self.repository = repository
        # The GigaChat stack is assembled from environment configuration at startup.
        settings = GigaChatSettings.from_env()
        token_provider = GigaChatTokenProvider(settings)
        client = GigaChatClient(settings, token_provider)
        embedder = GigaChatEmbedder(client)
        self.rag = RagService(embedder=embedder, repository=repository, chunker=TextChunker())
        self.sessions = RagSessionStore(repository)
        self.jobs = IndexJobStore(repository)
        self.indexing = IndexingOrchestrator(
            store=self.jobs,
            rag=self.rag,
            events=event_bus,
            retry=retry,
        )

    def public_router(self) -> APIRouter:
        """Build the public API router: session endpoints plus legacy /api/index aliases."""
        router = APIRouter(tags=["rag"])

        @router.post("/api/rag/sessions", response_model=RagSessionCreateResponse)
        async def create_rag_session(request: RagSessionCreateRequest) -> RagSessionCreateResponse:
            # Creating a session immediately queues a full snapshot index job.
            session = self.sessions.create(request.project_id)
            job = await self.indexing.enqueue_snapshot(
                rag_session_id=session.rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return RagSessionCreateResponse(
                rag_session_id=session.rag_session_id,
                index_job_id=job.index_job_id,
                status=job.status,
            )

        @router.post("/api/rag/sessions/{rag_session_id}/changes", response_model=IndexJobQueuedResponse)
        async def rag_session_changes(
            rag_session_id: str,
            request: RagSessionChangesRequest,
        ) -> IndexJobQueuedResponse:
            """Queue an incremental indexing job for an existing session."""
            session = self.sessions.get(rag_session_id)
            if not session:
                raise AppError("not_found", f"RAG session not found: {rag_session_id}", ModuleName.RAG)
            job = await self.indexing.enqueue_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @router.get("/api/rag/sessions/{rag_session_id}/jobs/{index_job_id}", response_model=RagSessionJobResponse)
        async def rag_session_job(rag_session_id: str, index_job_id: str) -> RagSessionJobResponse:
            """Return current job state; 404s when the job belongs to another session."""
            job = self.jobs.get(index_job_id)
            if not job or job.rag_session_id != rag_session_id:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            return RagSessionJobResponse(
                rag_session_id=rag_session_id,
                index_job_id=job.index_job_id,
                status=job.status,
                indexed_files=job.indexed_files,
                failed_files=job.failed_files,
                error=job.error.model_dump(mode="json") if job.error else None,
            )

        @router.get("/api/rag/sessions/{rag_session_id}/jobs/{index_job_id}/events")
        async def rag_session_job_events(rag_session_id: str, index_job_id: str) -> StreamingResponse:
            """Stream job events as SSE, replaying events published before subscription."""
            job = self.jobs.get(index_job_id)
            if not job or job.rag_session_id != rag_session_id:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            queue = await self._events.subscribe(index_job_id, replay=True)

            async def event_stream():
                import asyncio

                # Emit an SSE comment as keepalive when no event arrives in time.
                heartbeat = 10
                try:
                    while True:
                        try:
                            event = await asyncio.wait_for(queue.get(), timeout=heartbeat)
                            yield EventBus.as_sse(event)
                            if event.name == "terminal":
                                break
                        except asyncio.TimeoutError:
                            yield ": keepalive\n\n"
                finally:
                    # Always detach the queue, even if the client disconnects.
                    await self._events.unsubscribe(index_job_id, queue)

            return StreamingResponse(
                event_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache, no-transform",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                },
            )

        # Legacy compatibility endpoints.
        legacy = APIRouter(prefix="/api/index", tags=["index"])

        @legacy.post("/snapshot", response_model=IndexJobQueuedResponse)
        async def index_snapshot(request: IndexSnapshotRequest) -> IndexJobQueuedResponse:
            # Legacy API has no session concept: the project id doubles as the session id.
            session = self.sessions.put(
                rag_session_id=request.project_id,
                project_id=request.project_id,
            )
            job = await self.indexing.enqueue_snapshot(
                rag_session_id=session.rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @legacy.post("/changes", response_model=IndexJobQueuedResponse)
        async def index_changes(request: IndexChangesRequest) -> IndexJobQueuedResponse:
            rag_session_id = request.project_id
            # Auto-create the implicit session on first use.
            if not self.sessions.get(rag_session_id):
                self.sessions.put(rag_session_id=rag_session_id, project_id=rag_session_id)
            job = await self.indexing.enqueue_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @legacy.get("/jobs/{index_job_id}", response_model=IndexJobResponse)
        async def get_index_job(index_job_id: str) -> IndexJobResponse:
            job = self.jobs.get(index_job_id)
            if not job:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            return IndexJobResponse(
                index_job_id=job.index_job_id,
                status=job.status,
                indexed_files=job.indexed_files,
                failed_files=job.failed_files,
                error=job.error,
            )

        @legacy.get("/jobs/{index_job_id}/events")
        async def get_index_job_events(index_job_id: str) -> StreamingResponse:
            # NOTE(review): duplicates rag_session_job_events' streaming loop
            # verbatim — consider extracting a shared SSE helper.
            job = self.jobs.get(index_job_id)
            if not job:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            queue = await self._events.subscribe(index_job_id, replay=True)

            async def event_stream():
                import asyncio

                heartbeat = 10
                try:
                    while True:
                        try:
                            event = await asyncio.wait_for(queue.get(), timeout=heartbeat)
                            yield EventBus.as_sse(event)
                            if event.name == "terminal":
                                break
                        except asyncio.TimeoutError:
                            yield ": keepalive\n\n"
                finally:
                    await self._events.unsubscribe(index_job_id, queue)

            return StreamingResponse(
                event_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache, no-transform",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                },
            )

        router.include_router(legacy)
        return router

    def internal_router(self) -> APIRouter:
        """Build the internal (service-to-service) router; indexing here runs inline, not as a job."""
        router = APIRouter(prefix="/internal/rag", tags=["internal-rag"])

        @router.post("/index/snapshot")
        async def index_snapshot(request: IndexSnapshotRequest) -> dict:
            rag_session_id = request.project_id
            if not self.sessions.get(rag_session_id):
                self.sessions.put(rag_session_id=rag_session_id, project_id=rag_session_id)
            # Synchronous (awaited) indexing — no job record or events are produced.
            indexed, failed = await self.rag.index_snapshot(
                rag_session_id=rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return {"indexed_files": indexed, "failed_files": failed}

        @router.post("/index/changes")
        async def index_changes(request: IndexChangesRequest) -> dict:
            rag_session_id = request.project_id
            indexed, failed = await self.rag.index_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return {"indexed_files": indexed, "failed_files": failed}

        @router.get("/index/jobs/{index_job_id}")
        async def get_job(index_job_id: str) -> dict:
            # Returns a plain status dict instead of raising on unknown ids.
            job = self.jobs.get(index_job_id)
            if not job:
                return {"status": "not_found"}
            return {
                "index_job_id": job.index_job_id,
                "status": job.status.value,
                "indexed_files": job.indexed_files,
                "failed_files": job.failed_files,
                "error": job.error.model_dump(mode="json") if job.error else None,
            }

        @router.post("/retrieve")
        async def retrieve(payload: dict) -> dict:
            # Accept either the new id field or the legacy project_id.
            rag_session_id = payload.get("rag_session_id") or payload.get("project_id", "")
            ctx = await self.rag.retrieve(
                rag_session_id=rag_session_id,
                query=payload.get("query", ""),
            )
            return {"items": ctx}

        return router
|
||||
@@ -1,261 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.modules.shared.db import get_engine
|
||||
|
||||
|
||||
@dataclass
class RagJobRow:
    """Flat row image of one rag_index_jobs record as read from the database."""

    index_job_id: str
    rag_session_id: str
    # Raw status string; parsed into IndexJobStatus by the job store.
    status: str
    indexed_files: int
    failed_files: int
    # Error columns are NULL (None) when the job has not failed.
    error_code: str | None
    error_desc: str | None
    error_module: str | None
|
||||
|
||||
|
||||
class RagRepository:
    """Raw-SQL persistence layer for RAG sessions, index jobs and chunk vectors.

    Every method opens a short-lived connection on the shared SQLAlchemy
    engine and commits explicitly. Embeddings live in a pgvector ``vector``
    column on ``rag_chunks``.
    """

    def ensure_tables(self) -> None:
        """Create the pgvector extension, tables and index if missing (idempotent)."""
        engine = get_engine()
        with engine.connect() as conn:
            conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
            conn.execute(
                text(
                    """
                    CREATE TABLE IF NOT EXISTS rag_sessions (
                        rag_session_id VARCHAR(64) PRIMARY KEY,
                        project_id VARCHAR(512) NOT NULL,
                        created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
                    )
                    """
                )
            )
            conn.execute(
                text(
                    """
                    CREATE TABLE IF NOT EXISTS rag_index_jobs (
                        index_job_id VARCHAR(64) PRIMARY KEY,
                        rag_session_id VARCHAR(64) NOT NULL,
                        status VARCHAR(16) NOT NULL,
                        indexed_files INTEGER NOT NULL DEFAULT 0,
                        failed_files INTEGER NOT NULL DEFAULT 0,
                        error_code VARCHAR(128) NULL,
                        error_desc TEXT NULL,
                        error_module VARCHAR(64) NULL,
                        created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
                    )
                    """
                )
            )
            conn.execute(
                text(
                    """
                    CREATE TABLE IF NOT EXISTS rag_chunks (
                        id BIGSERIAL PRIMARY KEY,
                        rag_session_id VARCHAR(64) NOT NULL,
                        path TEXT NOT NULL,
                        chunk_index INTEGER NOT NULL,
                        content TEXT NOT NULL,
                        embedding vector NULL,
                        created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
                    )
                    """
                )
            )
            # Back-fill timestamp columns on databases created before they existed.
            conn.execute(
                text(
                    """
                    ALTER TABLE rag_chunks
                    ADD COLUMN IF NOT EXISTS created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
                    """
                )
            )
            conn.execute(
                text(
                    """
                    ALTER TABLE rag_chunks
                    ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
                    """
                )
            )
            conn.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_chunks_session ON rag_chunks (rag_session_id)"))
            conn.commit()

    def upsert_session(self, rag_session_id: str, project_id: str) -> None:
        """Insert a session row, or repoint an existing one at a new project id."""
        with get_engine().connect() as conn:
            conn.execute(
                text(
                    """
                    INSERT INTO rag_sessions (rag_session_id, project_id)
                    VALUES (:sid, :pid)
                    ON CONFLICT (rag_session_id) DO UPDATE SET project_id = EXCLUDED.project_id
                    """
                ),
                {"sid": rag_session_id, "pid": project_id},
            )
            conn.commit()

    def session_exists(self, rag_session_id: str) -> bool:
        """Return True when a session row with this id exists."""
        with get_engine().connect() as conn:
            row = conn.execute(
                text("SELECT 1 FROM rag_sessions WHERE rag_session_id = :sid"),
                {"sid": rag_session_id},
            ).fetchone()
            return bool(row)

    def get_session(self, rag_session_id: str) -> dict | None:
        """Fetch the session row as a dict, or None when absent."""
        with get_engine().connect() as conn:
            row = conn.execute(
                text("SELECT rag_session_id, project_id FROM rag_sessions WHERE rag_session_id = :sid"),
                {"sid": rag_session_id},
            ).mappings().fetchone()
            return dict(row) if row else None

    def create_job(self, index_job_id: str, rag_session_id: str, status: str) -> None:
        """Insert the initial row for a new indexing job."""
        with get_engine().connect() as conn:
            conn.execute(
                text(
                    """
                    INSERT INTO rag_index_jobs (index_job_id, rag_session_id, status)
                    VALUES (:jid, :sid, :status)
                    """
                ),
                {"jid": index_job_id, "sid": rag_session_id, "status": status},
            )
            conn.commit()

    def update_job(
        self,
        index_job_id: str,
        *,
        status: str,
        indexed_files: int,
        failed_files: int,
        error_code: str | None = None,
        error_desc: str | None = None,
        error_module: str | None = None,
    ) -> None:
        """Overwrite a job's mutable columns and bump updated_at."""
        with get_engine().connect() as conn:
            conn.execute(
                text(
                    """
                    UPDATE rag_index_jobs
                    SET status = :status,
                        indexed_files = :indexed,
                        failed_files = :failed,
                        error_code = :ecode,
                        error_desc = :edesc,
                        error_module = :emodule,
                        updated_at = CURRENT_TIMESTAMP
                    WHERE index_job_id = :jid
                    """
                ),
                {
                    "jid": index_job_id,
                    "status": status,
                    "indexed": indexed_files,
                    "failed": failed_files,
                    "ecode": error_code,
                    "edesc": error_desc,
                    "emodule": error_module,
                },
            )
            conn.commit()

    def get_job(self, index_job_id: str) -> RagJobRow | None:
        """Fetch a job row, or None when absent."""
        with get_engine().connect() as conn:
            row = conn.execute(
                text(
                    """
                    SELECT index_job_id, rag_session_id, status, indexed_files, failed_files,
                           error_code, error_desc, error_module
                    FROM rag_index_jobs
                    WHERE index_job_id = :jid
                    """
                ),
                {"jid": index_job_id},
            ).mappings().fetchone()
            if not row:
                return None
            return RagJobRow(**dict(row))

    def replace_chunks(self, rag_session_id: str, items: list[dict]) -> None:
        """Atomically replace all of a session's chunks with ``items`` (snapshot semantics)."""
        with get_engine().connect() as conn:
            conn.execute(text("DELETE FROM rag_chunks WHERE rag_session_id = :sid"), {"sid": rag_session_id})
            self._insert_chunks(conn, rag_session_id, items)
            conn.commit()

    def apply_changes(self, rag_session_id: str, delete_paths: list[str], upserts: list[dict]) -> None:
        """Apply an incremental change set: delete listed paths, then re-insert upserted ones.

        Upserted paths are deleted first so stale chunk rows for a file never
        survive a re-index of that file.
        """
        with get_engine().connect() as conn:
            if delete_paths:
                conn.execute(
                    text("DELETE FROM rag_chunks WHERE rag_session_id = :sid AND path = ANY(:paths)"),
                    {"sid": rag_session_id, "paths": delete_paths},
                )
            if upserts:
                paths = sorted({str(x["path"]) for x in upserts})
                conn.execute(
                    text("DELETE FROM rag_chunks WHERE rag_session_id = :sid AND path = ANY(:paths)"),
                    {"sid": rag_session_id, "paths": paths},
                )
                self._insert_chunks(conn, rag_session_id, upserts)
            conn.commit()

    def retrieve(self, rag_session_id: str, query_embedding: list[float], limit: int = 5) -> list[dict]:
        """Return up to ``limit`` chunks nearest to the query embedding.

        ``<=>`` is pgvector's cosine-distance operator; the embedding is passed
        as a pgvector literal string and CAST to vector server-side.
        """
        emb = "[" + ",".join(str(x) for x in query_embedding) + "]"
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(
                    """
                    SELECT path, content
                    FROM rag_chunks
                    WHERE rag_session_id = :sid
                    ORDER BY embedding <=> CAST(:emb AS vector)
                    LIMIT :lim
                    """
                ),
                {"sid": rag_session_id, "emb": emb, "lim": limit},
            ).mappings().fetchall()
            return [dict(x) for x in rows]

    def fallback_chunks(self, rag_session_id: str, limit: int = 5) -> list[dict]:
        """Return the most recently inserted chunks — used when similarity search fails."""
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(
                    """
                    SELECT path, content
                    FROM rag_chunks
                    WHERE rag_session_id = :sid
                    ORDER BY id DESC
                    LIMIT :lim
                    """
                ),
                {"sid": rag_session_id, "lim": limit},
            ).mappings().fetchall()
            return [dict(x) for x in rows]

    def _insert_chunks(self, conn, rag_session_id: str, items: list[dict]) -> None:
        """Insert chunk rows on an open connection; the caller owns the commit.

        NOTE(review): inserts one row per execute() call — a bulk/executemany
        insert would cut round-trips for large snapshots.
        """
        for item in items:
            emb = item.get("embedding") or []
            # Missing embeddings are stored as NULL rather than an empty vector.
            emb_str = "[" + ",".join(str(x) for x in emb) + "]" if emb else None
            conn.execute(
                text(
                    """
                    INSERT INTO rag_chunks (rag_session_id, path, chunk_index, content, embedding, created_at, updated_at)
                    VALUES (:sid, :path, :idx, :content, CAST(:emb AS vector), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
                    """
                ),
                {
                    "sid": rag_session_id,
                    "path": item["path"],
                    "idx": int(item["chunk_index"]),
                    "content": item["content"],
                    "emb": emb_str,
                },
            )
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,20 +0,0 @@
|
||||
class TextChunker:
    """Splits text into overlapping fixed-size chunks for embedding.

    Consecutive chunks share ``overlap`` characters so sentences cut at a
    chunk boundary still appear intact in at least one chunk.
    """

    def __init__(self, chunk_size: int = 900, overlap: int = 120) -> None:
        """Configure window size and overlap.

        Raises:
            ValueError: if ``chunk_size`` is not positive, ``overlap`` is
                negative, or ``overlap >= chunk_size`` — in that last case the
                window would never advance and ``chunk()`` would loop forever.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        if overlap < 0 or overlap >= chunk_size:
            raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
        self._chunk_size = chunk_size
        self._overlap = overlap

    def chunk(self, text: str) -> list[str]:
        """Return stripped, non-empty chunks of ``text``; [] for blank input."""
        # Normalize Windows line endings so chunk boundaries are stable.
        cleaned = text.replace("\r\n", "\n")
        if not cleaned.strip():
            return []
        chunks: list[str] = []
        start = 0
        length = len(cleaned)
        while start < length:
            end = min(length, start + self._chunk_size)
            piece = cleaned[start:end].strip()
            if piece:
                chunks.append(piece)
            if end == length:
                break
            # Step back by `overlap` so adjacent chunks share trailing context.
            start = max(0, end - self._overlap)
        return chunks
|
||||
@@ -1,12 +0,0 @@
|
||||
import math
|
||||
|
||||
|
||||
def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two equal-length vectors.

    Returns -1.0 as a sentinel for empty, length-mismatched, or zero-norm
    inputs, where the similarity is undefined.
    """
    if not a or not b or len(a) != len(b):
        return -1.0
    dot = 0.0
    sq_a = 0.0
    sq_b = 0.0
    # Single pass accumulating the dot product and both squared norms.
    for x, y in zip(a, b):
        dot += x * y
        sq_a += x * x
        sq_b += y * y
    denom = math.sqrt(sq_a) * math.sqrt(sq_b)
    if denom == 0:
        return -1.0
    return dot / denom
|
||||
@@ -1,134 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import Awaitable, Callable
|
||||
from inspect import isawaitable
|
||||
|
||||
from app.modules.rag.embedding.gigachat_embedder import GigaChatEmbedder
|
||||
from app.modules.rag.repository import RagRepository
|
||||
from app.modules.rag.retrieval.chunker import TextChunker
|
||||
|
||||
|
||||
class RagService:
    """Chunks, embeds and stores file contents, and retrieves context for queries."""

    def __init__(
        self,
        embedder: GigaChatEmbedder,
        repository: RagRepository,
        chunker: TextChunker | None = None,
    ) -> None:
        self._embedder = embedder
        self._repo = repository
        # Default chunker uses TextChunker's own defaults.
        self._chunker = chunker or TextChunker()

    async def index_snapshot(
        self,
        rag_session_id: str,
        files: list[dict],
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None = None,
    ) -> tuple[int, int]:
        """Re-index a full file snapshot; returns (indexed_files, failed_files).

        All chunks are accumulated in memory and swapped in atomically via
        replace_chunks at the end. Blocking work (embedding HTTP calls, DB
        writes) is pushed off the event loop with asyncio.to_thread.
        """
        total_files = len(files)
        indexed_files = 0
        failed_files = 0
        all_chunks: list[dict] = []
        for index, file in enumerate(files, start=1):
            path = str(file.get("path", ""))
            try:
                chunks = self._build_chunks_for_file(file)
                embedded_chunks = await asyncio.to_thread(self._embed_chunks, chunks)
                all_chunks.extend(embedded_chunks)
                indexed_files += 1
            except Exception:
                # Best-effort per file: one bad file only bumps the failure count.
                failed_files += 1
            await self._notify_progress(progress_cb, index, total_files, path)
        await asyncio.to_thread(self._repo.replace_chunks, rag_session_id, all_chunks)
        return indexed_files, failed_files

    async def index_changes(
        self,
        rag_session_id: str,
        changed_files: list[dict],
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None = None,
    ) -> tuple[int, int]:
        """Apply an incremental change set of {"op": "upsert"|"delete", ...} entries.

        Returns (indexed_files, failed_files). Entries with an unknown op, or
        an upsert without content, are counted as failed.
        """
        total_files = len(changed_files)
        indexed_files = 0
        failed_files = 0
        delete_paths: list[str] = []
        upsert_chunks: list[dict] = []

        for index, file in enumerate(changed_files, start=1):
            path = str(file.get("path", ""))
            op = str(file.get("op", ""))
            try:
                if op == "delete":
                    delete_paths.append(path)
                    indexed_files += 1
                    await self._notify_progress(progress_cb, index, total_files, path)
                    continue
                if op == "upsert" and file.get("content") is not None:
                    chunks = self._build_chunks_for_file(file)
                    embedded_chunks = await asyncio.to_thread(self._embed_chunks, chunks)
                    upsert_chunks.extend(embedded_chunks)
                    indexed_files += 1
                    await self._notify_progress(progress_cb, index, total_files, path)
                    continue
                # Unrecognized op or upsert without content.
                failed_files += 1
            except Exception:
                failed_files += 1
            await self._notify_progress(progress_cb, index, total_files, path)

        await asyncio.to_thread(
            self._repo.apply_changes,
            rag_session_id,
            delete_paths,
            upsert_chunks,
        )
        return indexed_files, failed_files

    async def retrieve(self, rag_session_id: str, query: str) -> list[dict]:
        """Return context items ({"source", "content"}) for a query.

        Falls back to the most recent chunks when embedding or similarity
        search fails for any reason.
        """
        try:
            query_embedding = self._embedder.embed([query])[0]
            rows = self._repo.retrieve(rag_session_id, query_embedding, limit=5)
        except Exception:
            rows = self._repo.fallback_chunks(rag_session_id, limit=5)
        return [{"source": row["path"], "content": row["content"]} for row in rows]

    def _build_chunks_for_file(self, file: dict) -> list[tuple[str, int, str]]:
        """Split one file dict into (path, chunk_index, chunk_text) triples."""
        path = str(file.get("path", ""))
        content = str(file.get("content", ""))
        output: list[tuple[str, int, str]] = []
        for idx, chunk in enumerate(self._chunker.chunk(content)):
            output.append((path, idx, chunk))
        return output

    def _embed_chunks(self, raw_chunks: list[tuple[str, int, str]]) -> list[dict]:
        """Embed chunk texts in batches and return repository-ready dicts.

        Batch size comes from RAG_EMBED_BATCH_SIZE (default 16, minimum 1).
        Runs synchronously; callers wrap it in asyncio.to_thread.
        """
        if not raw_chunks:
            return []
        batch_size = max(1, int(os.getenv("RAG_EMBED_BATCH_SIZE", "16")))

        indexed: list[dict] = []
        for i in range(0, len(raw_chunks), batch_size):
            batch = raw_chunks[i : i + batch_size]
            texts = [x[2] for x in batch]
            vectors = self._embedder.embed(texts)
            for (path, chunk_index, content), vector in zip(batch, vectors):
                indexed.append(
                    {
                        "path": path,
                        "chunk_index": chunk_index,
                        "content": content,
                        "embedding": vector,
                    }
                )
        return indexed

    async def _notify_progress(
        self,
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None,
        current_file_index: int,
        total_files: int,
        current_file_name: str,
    ) -> None:
        """Invoke the optional progress callback, awaiting it if it is async."""
        if not progress_cb:
            return
        result = progress_cb(current_file_index, total_files, current_file_name)
        if isawaitable(result):
            await result
|
||||
@@ -1,34 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import uuid4
|
||||
|
||||
from app.modules.rag.repository import RagRepository
|
||||
|
||||
|
||||
@dataclass
class RagSession:
    """Binding between a RAG session id and the project it indexes."""

    rag_session_id: str
    project_id: str
|
||||
|
||||
|
||||
class RagSessionStore:
    """CRUD facade over rag_sessions rows, working in RagSession objects."""

    def __init__(self, repository: RagRepository) -> None:
        self._repo = repository

    def create(self, project_id: str) -> RagSession:
        """Create a session with a fresh random id and persist it."""
        return self.put(rag_session_id=str(uuid4()), project_id=project_id)

    def put(self, rag_session_id: str, project_id: str) -> RagSession:
        """Insert or update the session mapping and return the resulting session."""
        self._repo.upsert_session(rag_session_id, project_id)
        return RagSession(rag_session_id=rag_session_id, project_id=project_id)

    def get(self, rag_session_id: str) -> RagSession | None:
        """Load a session by id; None when the id is unknown."""
        row = self._repo.get_session(rag_session_id)
        if not row:
            return None
        return RagSession(
            rag_session_id=str(row["rag_session_id"]),
            project_id=str(row["project_id"]),
        )
|
||||
Reference in New Issue
Block a user