85 lines
3.2 KiB
Python
85 lines
3.2 KiB
Python
import time
|
|
|
|
import requests
|
|
|
|
from app.core.constants import MAX_RETRIES
|
|
from app.modules.shared.gigachat.errors import GigaChatError
|
|
from app.modules.shared.gigachat.settings import GigaChatSettings
|
|
from app.modules.shared.gigachat.token_provider import GigaChatTokenProvider
|
|
|
|
|
|
class GigaChatClient:
    """HTTP client for the GigaChat chat-completion and embeddings endpoints.

    Every request carries a bearer token obtained from the injected token
    provider and goes through ``_post_with_retry``, which retries transport
    failures and retryable HTTP statuses (429 and >= 500) for up to
    ``MAX_RETRIES`` attempts. Request failures and malformed response bodies
    are reported as ``GigaChatError``.
    """

    def __init__(self, settings: GigaChatSettings, token_provider: GigaChatTokenProvider) -> None:
        """Store connection settings and the access-token source.

        Args:
            settings: Read for ``api_url``, ``model``, ``embedding_model`` and
                ``ssl_verify`` when building requests.
            token_provider: Queried via ``get_access_token()`` once per public call.
        """
        self._settings = settings
        self._tokens = token_provider

    def complete(self, system_prompt: str, user_prompt: str) -> str:
        """Run one chat completion and return the first choice's message text.

        Args:
            system_prompt: Content of the ``system`` message.
            user_prompt: Content of the ``user`` message.

        Returns:
            ``choices[0]["message"]["content"]`` coerced to ``str``, or ``""`` when
            the response has no choices or the first choice has no content.

        Raises:
            GigaChatError: If the request ultimately fails, or the response body
                is not JSON / does not have the expected structure.
        """
        token = self._tokens.get_access_token()
        payload = {
            "model": self._settings.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        }
        response = self._post_with_retry(
            "/chat/completions", payload, token=token, timeout=90, operation_name="completion"
        )
        data = self._json_object(response, operation_name="completion")

        choices = data.get("choices") or []
        if not choices:
            return ""
        # Guard the shape before indexing so malformed bodies surface as
        # GigaChatError instead of AttributeError/KeyError.
        if not isinstance(choices, list) or not isinstance(choices[0], dict):
            raise GigaChatError("Unexpected GigaChat completion response")
        message = choices[0].get("message") or {}
        if not isinstance(message, dict):
            raise GigaChatError("Unexpected GigaChat completion response")
        return str(message.get("content") or "")

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Request embeddings for ``texts`` using the configured embedding model.

        Args:
            texts: Strings sent as the ``input`` field of the request.

        Returns:
            One list of floats per item in the response's ``data`` array, in the
            order the server returned them. An item without an ``embedding``
            yields an empty list.

        Raises:
            GigaChatError: If the request ultimately fails, or the response body
                is not JSON / does not have the expected structure.
        """
        token = self._tokens.get_access_token()
        payload = {
            "model": self._settings.embedding_model,
            "input": texts,
        }
        response = self._post_with_retry(
            "/embeddings", payload, token=token, timeout=90, operation_name="embeddings"
        )
        data = self._json_object(response, operation_name="embeddings")

        items = data.get("data")
        if not isinstance(items, list):
            raise GigaChatError("Unexpected GigaChat embeddings response")

        vectors: list[list[float]] = []
        for item in items:
            vector = item.get("embedding") if isinstance(item, dict) else None
            if not isinstance(item, dict) or not isinstance(vector or [], list):
                raise GigaChatError("Unexpected GigaChat embeddings response")
            try:
                vectors.append([float(value) for value in vector or []])
            except (TypeError, ValueError) as exc:
                raise GigaChatError("Unexpected GigaChat embeddings response") from exc
        return vectors

    def _post_with_retry(
        self,
        path: str,
        payload: dict,
        *,
        token: str,
        timeout: int,
        operation_name: str,
    ) -> requests.Response:
        """POST ``payload`` as JSON to ``api_url + path``, retrying transient failures.

        Transport errors (``requests.RequestException``) and responses whose
        status satisfies ``_is_retryable_status`` are retried, for at most
        ``MAX_RETRIES`` attempts in total, sleeping ``0.1 * attempt`` seconds
        between attempts. Any other status >= 400 fails immediately.

        Returns:
            The first response whose status code is below 400.

        Raises:
            GigaChatError: Immediately for a non-retryable HTTP error, otherwise
                describing the last failure once attempts are exhausted. For
                transport failures the original exception is kept as ``__cause__``.
        """
        # Invariant across attempts, so build once.
        url = f"{self._settings.api_url.rstrip('/')}{path}"
        headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }

        last_error: GigaChatError | None = None
        for attempt in range(1, MAX_RETRIES + 1):
            try:
                response = requests.post(
                    url,
                    json=payload,
                    headers=headers,
                    timeout=timeout,
                    verify=self._settings.ssl_verify,
                )
            except requests.RequestException as exc:
                last_error = GigaChatError(f"GigaChat {operation_name} request failed: {exc}")
                # The final raise happens outside this handler; record the
                # transport error explicitly so its traceback is not lost.
                last_error.__cause__ = exc
            else:
                if response.status_code < 400:
                    return response
                last_error = GigaChatError(f"GigaChat {operation_name} error {response.status_code}: {response.text}")
                if not self._is_retryable_status(response.status_code):
                    raise last_error

            # No point sleeping after the final attempt.
            if attempt < MAX_RETRIES:
                time.sleep(0.1 * attempt)

        if last_error is None:
            # Only reachable when MAX_RETRIES < 1 and no attempt was made.
            raise GigaChatError(f"GigaChat {operation_name} failed without response")
        raise last_error

    @staticmethod
    def _json_object(response: requests.Response, *, operation_name: str) -> dict:
        """Decode ``response`` as a JSON object or raise ``GigaChatError``.

        ``ValueError`` is caught because JSON decode errors raised by
        ``Response.json()`` derive from it.
        """
        try:
            data = response.json()
        except ValueError as exc:
            raise GigaChatError(f"GigaChat {operation_name} returned a non-JSON response") from exc
        if not isinstance(data, dict):
            raise GigaChatError(f"Unexpected GigaChat {operation_name} response")
        return data

    def _is_retryable_status(self, status_code: int) -> bool:
        """Return True for HTTP 429 (rate limited) and any 5xx server error."""
        return status_code == 429 or status_code >= 500
|