35 lines
1.1 KiB
Python
35 lines
1.1 KiB
Python
from dataclasses import dataclass
|
|
from uuid import uuid4
|
|
|
|
from app.modules.rag_session.repository import RagRepository
|
|
|
|
|
|
@dataclass
class RagSession:
    """Pair of identifiers for one RAG session and the project it is tied to."""

    # Session identifier; RagSessionStore.create generates it as str(uuid4()).
    rag_session_id: str
    # Identifier of the project this session is associated with.
    project_id: str
|
|
|
|
|
|
class RagSessionStore:
    """Create, overwrite and look up ``RagSession`` records via a ``RagRepository``."""

    def __init__(self, repository: RagRepository) -> None:
        # All persistence is delegated to this repository object.
        self._repo = repository

    def _save(self, rag_session_id: str, project_id: str) -> RagSession:
        """Upsert the (session id, project id) pair and return it as a ``RagSession``."""
        self._repo.upsert_session(rag_session_id, project_id)
        return RagSession(rag_session_id=rag_session_id, project_id=project_id)

    def create(self, project_id: str) -> RagSession:
        """Persist a new session for ``project_id`` under a freshly generated UUID4 id.

        Returns:
            The ``RagSession`` that was written.
        """
        fresh_id = str(uuid4())
        return self._save(fresh_id, project_id)

    def put(self, rag_session_id: str, project_id: str) -> RagSession:
        """Upsert a session with a caller-chosen id and return it.

        Returns:
            A ``RagSession`` built from the given arguments.
        """
        return self._save(rag_session_id, project_id)

    def get(self, rag_session_id: str) -> RagSession | None:
        """Fetch a session by id.

        Returns:
            The stored session, or ``None`` when the repository returns a falsy
            row (e.g. ``None``). Both stored values are coerced with ``str()``.
        """
        record = self._repo.get_session(rag_session_id)
        if record:
            # NOTE(review): assumes the row supports key access by column name — confirm
            # against RagRepository.get_session.
            return RagSession(
                rag_session_id=str(record["rag_session_id"]),
                project_id=str(record["project_id"]),
            )
        return None
|