112 lines
4.2 KiB
Python
112 lines
4.2 KiB
Python
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import os
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Iterable, Protocol
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
# Load the project's .env file at import time so that environment lookups made
# while resolving the embedding client (e.g. GIGACHAT_CREDENTIALS) can see it.
# _repo_root is the fourth ancestor of this file's resolved path; it is assumed
# to be the repository root — re-check the number of .parent hops if this
# module is ever moved.
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
load_dotenv(_repo_root / ".env")
|
|
|
|
|
|
class EmbeddingClient(Protocol):
    """Structural interface for objects that turn texts into embedding vectors.

    Any class exposing a compatible ``embed_texts`` method satisfies this
    protocol; explicit inheritance is not required.
    """

    def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
        """Return one embedding vector (a list of floats) per input text.

        Args:
            texts: The texts to embed.

        Raises:
            NotImplementedError: Always, when this placeholder body itself is
                invoked; concrete clients provide the real implementation.
        """
        raise NotImplementedError
|
|
|
|
|
|
@dataclass
class StubEmbeddingClient:
    """Offline embedding client that derives vectors from a SHA-256 hash.

    The same text always yields the same vector, and no network access or
    credentials are needed.
    """

    # Number of components in every vector returned by embed_texts.
    dim: int

    def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
        """Return one hash-derived vector per input text, in input order.

        Args:
            texts: The texts to embed.

        Returns:
            A list with one vector per text; each component is a digest byte
            scaled into the range [0.0, 1.0].
        """
        return [self._vector_for(item) for item in texts]

    def _vector_for(self, text: str) -> list[float]:
        """Build the vector for a single text from its UTF-8 SHA-256 digest."""
        digest_bytes = hashlib.sha256(text.encode("utf-8")).digest()
        scaled = [byte / 255.0 for byte in digest_bytes]
        if self.dim > len(scaled):
            # The digest is shorter than the requested dimension: repeat it
            # enough times to cover ``dim`` components, then cut to size.
            repeats = self.dim // len(scaled) + 1
            scaled = (scaled * repeats)[: self.dim]
        return scaled[: self.dim]
|
|
|
|
|
|
# Maximum number of texts sent to GigaChat in a single embeddings request.
_GIGACHAT_BATCH_SIZE = 50
# GigaChat embeddings accept at most 514 tokens per input. Assuming roughly
# 3 characters per token for Russian/English text, each input is truncated to
# 1200 characters (about 400 tokens) to stay under that limit.
_GIGACHAT_MAX_CHARS_PER_INPUT = 1200
|
|
|
|
|
|
class GigaChatEmbeddingClient:
    """Embedding client backed by the GigaChat API.

    Texts are sent in batches of ``_GIGACHAT_BATCH_SIZE`` and every text is
    truncated to ``_GIGACHAT_MAX_CHARS_PER_INPUT`` characters before sending.
    The credentials are passed in by the caller (the factory in this module
    reads them from the GIGACHAT_CREDENTIALS environment variable).
    """

    def __init__(
        self,
        credentials: str,
        model: str = "Embeddings",
        verify_ssl_certs: bool = False,
    ) -> None:
        """Store the connection settings; no network call is made here.

        Args:
            credentials: GigaChat credentials string; surrounding whitespace
                is stripped.
            model: Name of the embeddings model passed to the GigaChat client.
            verify_ssl_certs: Forwarded unchanged to the GigaChat client.
        """
        # NOTE(review): the default of False presumably turns TLS certificate
        # verification off in the gigachat client. It is kept only for
        # backward compatibility — consider passing True in production.
        self._credentials = credentials.strip()
        self._model = model
        self._verify_ssl_certs = verify_ssl_certs

    def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
        """Embed ``texts`` with GigaChat and return vectors in input order.

        Args:
            texts: The texts to embed. The iterable is fully consumed up front.

        Returns:
            One embedding per input text; an empty list for empty input
            (in which case no client is created).

        Raises:
            ValueError: When GigaChat answers 402 (insufficient funds) or 413
                (token limit exceeded); the original error is attached as
                ``__cause__``.
            Exception: Any other error is re-raised unchanged.
        """
        # Imported inside the method rather than at module level, so importing
        # this module does not itself require the gigachat package.
        from gigachat import GigaChat

        texts_list = list(texts)
        if not texts_list:
            return []

        result: list[list[float]] = []
        try:
            # Open a single client for the whole call instead of building and
            # tearing one down for every batch.
            with GigaChat(
                credentials=self._credentials,
                model=self._model,
                verify_ssl_certs=self._verify_ssl_certs,
            ) as giga:
                for start in range(0, len(texts_list), _GIGACHAT_BATCH_SIZE):
                    # Slicing already leaves shorter strings untouched.
                    batch = [
                        text[:_GIGACHAT_MAX_CHARS_PER_INPUT]
                        for text in texts_list[start : start + _GIGACHAT_BATCH_SIZE]
                    ]
                    # API: embeddings(texts: list[str]) — single positional
                    # argument (gigachat 0.2+).
                    response = giga.embeddings(batch)
                    # Re-order by the index reported for each item so the
                    # output lines up with the input batch.
                    by_index = {item.index: item.embedding for item in response.data}
                    result.extend(by_index[j] for j in range(len(batch)))
        except Exception as e:
            # Imported only on the failure path, as in the original code.
            from gigachat.exceptions import ResponseError, RequestEntityTooLargeError

            status_code = getattr(e, "status_code", None)
            message = str(e)
            is_402 = isinstance(e, ResponseError) and (
                status_code == 402 or "402" in message or "Payment Required" in message
            )
            if is_402:
                raise ValueError(
                    "GigaChat: недостаточно средств (402 Payment Required). "
                    "Пополните баланс в кабинете GigaChat или отключите GIGACHAT_CREDENTIALS для stub-режима."
                ) from e
            if isinstance(e, RequestEntityTooLargeError) or (
                isinstance(e, ResponseError) and status_code == 413
            ):
                raise ValueError(
                    "GigaChat: превышен лимит токенов на один запрос (413). "
                    "Уменьшите RAG_CHUNK_SIZE_LINES или RAG_CHUNK_SIZE в .env (текущий лимит ~514 токенов на чанк)."
                ) from e
            raise
        return result
|
|
|
|
|
|
def get_embedding_client(dim: int) -> EmbeddingClient:
    """Pick the embedding client based on the environment.

    When GIGACHAT_CREDENTIALS is set to a non-blank value, a
    GigaChatEmbeddingClient is returned, configured from
    GIGACHAT_EMBEDDINGS_MODEL (default "Embeddings") and GIGACHAT_VERIFY_SSL
    (enabled only for "1", "true" or "yes", case-insensitive; default "false").
    Otherwise a StubEmbeddingClient is returned.

    Args:
        dim: Vector dimension for the stub client; not used when the GigaChat
            client is selected.
    """
    gigachat_credentials = os.getenv("GIGACHAT_CREDENTIALS", "").strip()
    if not gigachat_credentials:
        # No credentials configured: fall back to the offline stub.
        return StubEmbeddingClient(dim=dim)

    model_name = os.getenv("GIGACHAT_EMBEDDINGS_MODEL", "Embeddings")
    verify_flag = os.getenv("GIGACHAT_VERIFY_SSL", "false").lower()
    return GigaChatEmbeddingClient(
        credentials=gigachat_credentials,
        model=model_name,
        verify_ssl_certs=verify_flag in ("1", "true", "yes"),
    )
|