38 lines
1.3 KiB
Python
38 lines
1.3 KiB
Python
from __future__ import annotations
|
|
|
|
from sqlalchemy import text
|
|
|
|
from app.modules.shared.db import get_engine
|
|
|
|
|
|
class RagSessionRepository:
    """Read and write rows of the ``rag_sessions`` table.

    Each method obtains an engine via ``get_engine()`` and opens its own
    connection inside a ``with`` block, so no connection is held on the
    instance between calls.
    """

    def upsert_session(self, rag_session_id: str, project_id: str) -> None:
        """Insert a session row, or update its ``project_id`` if it already exists.

        Args:
            rag_session_id: Session identifier. It is the ``ON CONFLICT``
                target, so the table needs a unique constraint or unique
                index on this column for the upsert to work.
            project_id: Project to associate with the session. If a row
                with the same ``rag_session_id`` exists, this value
                overwrites the stored ``project_id``.
        """
        # engine.begin() commits when the block exits normally and rolls back
        # if execute() raises; this replaces the manual conn.commit() call.
        with get_engine().begin() as conn:
            conn.execute(
                text(
                    """
                    INSERT INTO rag_sessions (rag_session_id, project_id)
                    VALUES (:sid, :pid)
                    ON CONFLICT (rag_session_id) DO UPDATE SET project_id = EXCLUDED.project_id
                    """
                ),
                {"sid": rag_session_id, "pid": project_id},
            )

    def session_exists(self, rag_session_id: str) -> bool:
        """Return True if a ``rag_sessions`` row with this id exists.

        Args:
            rag_session_id: Session identifier to look up.
        """
        with get_engine().connect() as conn:
            row = conn.execute(
                text("SELECT 1 FROM rag_sessions WHERE rag_session_id = :sid"),
                {"sid": rag_session_id},
            ).fetchone()
        # fetchone() yields None when no row matched.
        return row is not None

    def get_session(self, rag_session_id: str) -> dict | None:
        """Fetch a single session row as a plain dict.

        Args:
            rag_session_id: Session identifier to look up.

        Returns:
            A dict with keys ``"rag_session_id"`` and ``"project_id"`` for
            the matching row, or None when no row has this id.
        """
        with get_engine().connect() as conn:
            row = (
                conn.execute(
                    text(
                        "SELECT rag_session_id, project_id "
                        "FROM rag_sessions WHERE rag_session_id = :sid"
                    ),
                    {"sid": rag_session_id},
                )
                .mappings()
                .fetchone()
            )
        return dict(row) if row is not None else None