Первый коммит
This commit is contained in:
0
app/modules/shared/__init__.py
Normal file
0
app/modules/shared/__init__.py
Normal file
BIN
app/modules/shared/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
app/modules/shared/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/modules/shared/__pycache__/bootstrap.cpython-312.pyc
Normal file
BIN
app/modules/shared/__pycache__/bootstrap.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/modules/shared/__pycache__/checkpointer.cpython-312.pyc
Normal file
BIN
app/modules/shared/__pycache__/checkpointer.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/modules/shared/__pycache__/db.cpython-312.pyc
Normal file
BIN
app/modules/shared/__pycache__/db.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/modules/shared/__pycache__/event_bus.cpython-312.pyc
Normal file
BIN
app/modules/shared/__pycache__/event_bus.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/modules/shared/__pycache__/idempotency_store.cpython-312.pyc
Normal file
BIN
app/modules/shared/__pycache__/idempotency_store.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/modules/shared/__pycache__/retry_executor.cpython-312.pyc
Normal file
BIN
app/modules/shared/__pycache__/retry_executor.cpython-312.pyc
Normal file
Binary file not shown.
21
app/modules/shared/bootstrap.py
Normal file
21
app/modules/shared/bootstrap.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import time
|
||||
|
||||
from app.modules.shared.checkpointer import get_checkpointer
|
||||
|
||||
|
||||
def bootstrap_database(rag_repository, chat_repository, agent_repository) -> None:
|
||||
last_error: Exception | None = None
|
||||
for attempt in range(1, 16):
|
||||
try:
|
||||
rag_repository.ensure_tables()
|
||||
chat_repository.ensure_tables()
|
||||
agent_repository.ensure_tables()
|
||||
get_checkpointer()
|
||||
return
|
||||
except Exception as exc: # noqa: BLE001
|
||||
last_error = exc
|
||||
if attempt == 15:
|
||||
break
|
||||
time.sleep(1)
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
30
app/modules/shared/checkpointer.py
Normal file
30
app/modules/shared/checkpointer.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
|
||||
import psycopg
|
||||
from langgraph.checkpoint.postgres import PostgresSaver
|
||||
from psycopg.rows import dict_row
|
||||
|
||||
from app.modules.shared.db import database_url
|
||||
|
||||
_CHECKPOINTER: PostgresSaver | None = None
|
||||
_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def _conn_string() -> str:
|
||||
url = database_url()
|
||||
if url.startswith("postgresql+psycopg"):
|
||||
return url.replace("postgresql+psycopg", "postgresql", 1)
|
||||
return url
|
||||
|
||||
|
||||
def get_checkpointer() -> PostgresSaver:
|
||||
global _CHECKPOINTER
|
||||
with _LOCK:
|
||||
if _CHECKPOINTER is None:
|
||||
conn = psycopg.connect(_conn_string(), autocommit=True, row_factory=dict_row)
|
||||
cp = PostgresSaver(conn)
|
||||
cp.setup()
|
||||
_CHECKPOINTER = cp
|
||||
return _CHECKPOINTER
|
||||
29
app/modules/shared/db.py
Normal file
29
app/modules/shared/db.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
_ENGINE: Engine | None = None
|
||||
_SESSION_FACTORY: sessionmaker | None = None
|
||||
|
||||
|
||||
def database_url() -> str:
|
||||
return os.getenv("DATABASE_URL", "postgresql+psycopg://agent:agent@db:5432/agent")
|
||||
|
||||
|
||||
def get_engine() -> Engine:
|
||||
global _ENGINE
|
||||
if _ENGINE is None:
|
||||
_ENGINE = create_engine(database_url(), poolclass=NullPool, future=True)
|
||||
return _ENGINE
|
||||
|
||||
|
||||
def get_session_factory() -> sessionmaker:
|
||||
global _SESSION_FACTORY
|
||||
if _SESSION_FACTORY is None:
|
||||
_SESSION_FACTORY = sessionmaker(bind=get_engine(), autoflush=False, autocommit=False)
|
||||
return _SESSION_FACTORY
|
||||
57
app/modules/shared/event_bus.py
Normal file
57
app/modules/shared/event_bus.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Event:
|
||||
name: str
|
||||
payload: dict
|
||||
|
||||
|
||||
class EventBus:
|
||||
def __init__(self) -> None:
|
||||
self._channels: dict[str, list[asyncio.Queue[Event]]] = defaultdict(list)
|
||||
self._history: dict[str, list[Event]] = defaultdict(list)
|
||||
self._lock = asyncio.Lock()
|
||||
self._history_limit = 5000
|
||||
|
||||
async def subscribe(self, channel_id: str, replay: bool = True) -> asyncio.Queue[Event]:
|
||||
queue: asyncio.Queue[Event] = asyncio.Queue()
|
||||
snapshot: list[Event] = []
|
||||
async with self._lock:
|
||||
self._channels[channel_id].append(queue)
|
||||
if replay:
|
||||
snapshot = list(self._history.get(channel_id, []))
|
||||
for event in snapshot:
|
||||
await queue.put(event)
|
||||
return queue
|
||||
|
||||
async def unsubscribe(self, channel_id: str, queue: asyncio.Queue[Event]) -> None:
|
||||
async with self._lock:
|
||||
if channel_id not in self._channels:
|
||||
return
|
||||
items = self._channels[channel_id]
|
||||
if queue in items:
|
||||
items.remove(queue)
|
||||
if not items:
|
||||
del self._channels[channel_id]
|
||||
|
||||
async def publish(self, channel_id: str, name: str, payload: dict) -> None:
|
||||
event_payload = dict(payload)
|
||||
event_payload.setdefault("published_at_ms", int(time.time() * 1000))
|
||||
event = Event(name=name, payload=event_payload)
|
||||
async with self._lock:
|
||||
queues = list(self._channels.get(channel_id, []))
|
||||
history = self._history[channel_id]
|
||||
history.append(event)
|
||||
if len(history) > self._history_limit:
|
||||
del history[: len(history) - self._history_limit]
|
||||
for queue in queues:
|
||||
await queue.put(event)
|
||||
|
||||
@staticmethod
|
||||
def as_sse(event: Event) -> str:
|
||||
return f"event: {event.name}\ndata: {json.dumps(event.payload, ensure_ascii=False)}\n\n"
|
||||
0
app/modules/shared/gigachat/__init__.py
Normal file
0
app/modules/shared/gigachat/__init__.py
Normal file
BIN
app/modules/shared/gigachat/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
app/modules/shared/gigachat/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/modules/shared/gigachat/__pycache__/client.cpython-312.pyc
Normal file
BIN
app/modules/shared/gigachat/__pycache__/client.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/modules/shared/gigachat/__pycache__/errors.cpython-312.pyc
Normal file
BIN
app/modules/shared/gigachat/__pycache__/errors.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/modules/shared/gigachat/__pycache__/settings.cpython-312.pyc
Normal file
BIN
app/modules/shared/gigachat/__pycache__/settings.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
73
app/modules/shared/gigachat/client.py
Normal file
73
app/modules/shared/gigachat/client.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import requests
|
||||
|
||||
from app.modules.shared.gigachat.errors import GigaChatError
|
||||
from app.modules.shared.gigachat.settings import GigaChatSettings
|
||||
from app.modules.shared.gigachat.token_provider import GigaChatTokenProvider
|
||||
|
||||
|
||||
class GigaChatClient:
|
||||
def __init__(self, settings: GigaChatSettings, token_provider: GigaChatTokenProvider) -> None:
|
||||
self._settings = settings
|
||||
self._tokens = token_provider
|
||||
|
||||
def complete(self, system_prompt: str, user_prompt: str) -> str:
|
||||
token = self._tokens.get_access_token()
|
||||
payload = {
|
||||
"model": self._settings.model,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
}
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self._settings.api_url.rstrip('/')}/chat/completions",
|
||||
json=payload,
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
timeout=90,
|
||||
verify=self._settings.ssl_verify,
|
||||
)
|
||||
except requests.RequestException as exc:
|
||||
raise GigaChatError(f"GigaChat completion request failed: {exc}") from exc
|
||||
|
||||
if response.status_code >= 400:
|
||||
raise GigaChatError(f"GigaChat completion error {response.status_code}: {response.text}")
|
||||
|
||||
data = response.json()
|
||||
choices = data.get("choices") or []
|
||||
if not choices:
|
||||
return ""
|
||||
message = choices[0].get("message") or {}
|
||||
return str(message.get("content") or "")
|
||||
|
||||
def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
token = self._tokens.get_access_token()
|
||||
payload = {
|
||||
"model": self._settings.embedding_model,
|
||||
"input": texts,
|
||||
}
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self._settings.api_url.rstrip('/')}/embeddings",
|
||||
json=payload,
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
timeout=90,
|
||||
verify=self._settings.ssl_verify,
|
||||
)
|
||||
except requests.RequestException as exc:
|
||||
raise GigaChatError(f"GigaChat embeddings request failed: {exc}") from exc
|
||||
|
||||
if response.status_code >= 400:
|
||||
raise GigaChatError(f"GigaChat embeddings error {response.status_code}: {response.text}")
|
||||
|
||||
data = response.json()
|
||||
items = data.get("data")
|
||||
if not isinstance(items, list):
|
||||
raise GigaChatError("Unexpected GigaChat embeddings response")
|
||||
return [list(map(float, x.get("embedding") or [])) for x in items]
|
||||
2
app/modules/shared/gigachat/errors.py
Normal file
2
app/modules/shared/gigachat/errors.py
Normal file
@@ -0,0 +1,2 @@
|
||||
class GigaChatError(OSError):
|
||||
pass
|
||||
25
app/modules/shared/gigachat/settings.py
Normal file
25
app/modules/shared/gigachat/settings.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GigaChatSettings:
|
||||
auth_url: str
|
||||
api_url: str
|
||||
scope: str
|
||||
credentials: str
|
||||
ssl_verify: bool
|
||||
model: str
|
||||
embedding_model: str
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "GigaChatSettings":
|
||||
return cls(
|
||||
auth_url=os.getenv("GIGACHAT_AUTH_URL", "https://ngw.devices.sberbank.ru:9443/api/v2/oauth"),
|
||||
api_url=os.getenv("GIGACHAT_API_URL", "https://gigachat.devices.sberbank.ru/api/v1"),
|
||||
scope=os.getenv("GIGACHAT_SCOPE", "GIGACHAT_API_PERS"),
|
||||
credentials=os.getenv("GIGACHAT_TOKEN", "").strip(),
|
||||
ssl_verify=os.getenv("GIGACHAT_SSL_VERIFY", "true").lower() in {"1", "true", "yes"},
|
||||
model=os.getenv("GIGACHAT_MODEL", "GigaChat"),
|
||||
embedding_model=os.getenv("GIGACHAT_EMBEDDING_MODEL", "Embeddings"),
|
||||
)
|
||||
58
app/modules/shared/gigachat/token_provider.py
Normal file
58
app/modules/shared/gigachat/token_provider.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
|
||||
from app.modules.shared.gigachat.errors import GigaChatError
|
||||
from app.modules.shared.gigachat.settings import GigaChatSettings
|
||||
|
||||
|
||||
class GigaChatTokenProvider:
|
||||
def __init__(self, settings: GigaChatSettings) -> None:
|
||||
self._settings = settings
|
||||
self._lock = threading.Lock()
|
||||
self._token: str | None = None
|
||||
self._expires_at_ms: float = 0
|
||||
|
||||
def get_access_token(self) -> str:
|
||||
now_ms = time.time() * 1000
|
||||
with self._lock:
|
||||
if self._token and self._expires_at_ms - 300_000 > now_ms:
|
||||
return self._token
|
||||
|
||||
token, expires_at = self._fetch_token()
|
||||
with self._lock:
|
||||
self._token = token
|
||||
self._expires_at_ms = expires_at
|
||||
return token
|
||||
|
||||
def _fetch_token(self) -> tuple[str, float]:
|
||||
if not self._settings.credentials:
|
||||
raise GigaChatError("GIGACHAT_TOKEN is not set")
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Basic {self._settings.credentials}",
|
||||
"RqUID": str(uuid.uuid4()),
|
||||
}
|
||||
try:
|
||||
response = requests.post(
|
||||
self._settings.auth_url,
|
||||
headers=headers,
|
||||
data=f"scope={self._settings.scope}",
|
||||
timeout=30,
|
||||
verify=self._settings.ssl_verify,
|
||||
)
|
||||
except requests.RequestException as exc:
|
||||
raise GigaChatError(f"GigaChat auth request failed: {exc}") from exc
|
||||
|
||||
if response.status_code >= 400:
|
||||
raise GigaChatError(f"GigaChat auth error {response.status_code}: {response.text}")
|
||||
|
||||
payload = response.json()
|
||||
token = payload.get("access_token")
|
||||
expires_at = float(payload.get("expires_at", 0))
|
||||
if not token:
|
||||
raise GigaChatError("GigaChat auth: no access_token in response")
|
||||
return token, expires_at
|
||||
40
app/modules/shared/idempotency_store.py
Normal file
40
app/modules/shared/idempotency_store.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from threading import Lock
|
||||
|
||||
from app.core.constants import IDEMPOTENCY_TTL
|
||||
|
||||
|
||||
@dataclass
|
||||
class IdempotencyRecord:
|
||||
task_id: str
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class IdempotencyStore:
|
||||
def __init__(self) -> None:
|
||||
self._records: dict[str, IdempotencyRecord] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def get_task_id(self, key: str) -> str | None:
|
||||
now = datetime.now(timezone.utc)
|
||||
with self._lock:
|
||||
self._cleanup_locked(now)
|
||||
record = self._records.get(key)
|
||||
return record.task_id if record else None
|
||||
|
||||
def put(self, key: str, task_id: str) -> None:
|
||||
with self._lock:
|
||||
self._records[key] = IdempotencyRecord(
|
||||
task_id=task_id,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
def _cleanup_locked(self, now: datetime) -> None:
|
||||
expired = [
|
||||
key
|
||||
for key, rec in self._records.items()
|
||||
if now - rec.created_at > IDEMPOTENCY_TTL
|
||||
]
|
||||
for key in expired:
|
||||
del self._records[key]
|
||||
21
app/modules/shared/retry_executor.py
Normal file
21
app/modules/shared/retry_executor.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import asyncio
|
||||
from typing import Awaitable, Callable, TypeVar
|
||||
|
||||
from app.core.constants import MAX_RETRIES
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RetryExecutor:
|
||||
async def run(self, operation: Callable[[], Awaitable[T]]) -> T:
|
||||
last_error: Exception | None = None
|
||||
for attempt in range(1, MAX_RETRIES + 1):
|
||||
try:
|
||||
return await operation()
|
||||
except (TimeoutError, ConnectionError, OSError) as exc:
|
||||
last_error = exc
|
||||
if attempt == MAX_RETRIES:
|
||||
break
|
||||
await asyncio.sleep(0.1 * attempt)
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
Reference in New Issue
Block a user