from __future__ import annotations import hashlib import os from dataclasses import dataclass from pathlib import Path from typing import Iterable, Protocol from dotenv import load_dotenv # Ensure .env is loaded when resolving embedding client (e.g. GIGACHAT_CREDENTIALS) _repo_root = Path(__file__).resolve().parent.parent.parent.parent load_dotenv(_repo_root / ".env") class EmbeddingClient(Protocol): def embed_texts(self, texts: Iterable[str]) -> list[list[float]]: raise NotImplementedError @dataclass class StubEmbeddingClient: dim: int def embed_texts(self, texts: Iterable[str]) -> list[list[float]]: vectors: list[list[float]] = [] for text in texts: digest = hashlib.sha256(text.encode("utf-8")).digest() values = [b / 255.0 for b in digest] if len(values) < self.dim: values = (values * ((self.dim // len(values)) + 1))[: self.dim] vectors.append(values[: self.dim]) return vectors _GIGACHAT_BATCH_SIZE = 50 # GigaChat embeddings: max 514 tokens per input; Russian/English ~3 chars/token → truncate to stay under _GIGACHAT_MAX_CHARS_PER_INPUT = 1200 class GigaChatEmbeddingClient: """Embeddings via GigaChat API. Credentials from env GIGACHAT_CREDENTIALS.""" def __init__( self, credentials: str, model: str = "Embeddings", verify_ssl_certs: bool = False, ) -> None: self._credentials = credentials.strip() self._model = model self._verify_ssl_certs = verify_ssl_certs def embed_texts(self, texts: Iterable[str]) -> list[list[float]]: from gigachat import GigaChat texts_list = list(texts) if not texts_list: return [] result: list[list[float]] = [] try: for i in range(0, len(texts_list), _GIGACHAT_BATCH_SIZE): raw_batch = texts_list[i : i + _GIGACHAT_BATCH_SIZE] batch = [ t[: _GIGACHAT_MAX_CHARS_PER_INPUT] if len(t) > _GIGACHAT_MAX_CHARS_PER_INPUT else t for t in raw_batch ] with GigaChat( credentials=self._credentials, model=self._model, verify_ssl_certs=self._verify_ssl_certs, ) as giga: # API: embeddings(texts: list[str]) — single positional argument (gigachat 0.2+) response = giga.embeddings(batch) # Preserve order by index by_index = {item.index: item.embedding for item in response.data} result.extend(by_index[j] for j in range(len(batch))) except Exception as e: from gigachat.exceptions import ResponseError, RequestEntityTooLargeError is_402 = ( isinstance(e, ResponseError) and (getattr(e, "status_code", None) == 402 or "402" in str(e) or "Payment Required" in str(e)) ) if is_402: raise ValueError( "GigaChat: недостаточно средств (402 Payment Required). " "Пополните баланс в кабинете GigaChat или отключите GIGACHAT_CREDENTIALS для stub-режима." ) if isinstance(e, RequestEntityTooLargeError) or ( isinstance(e, ResponseError) and getattr(e, "status_code", None) == 413 ): raise ValueError( "GigaChat: превышен лимит токенов на один запрос (413). " "Уменьшите RAG_CHUNK_SIZE_LINES или RAG_CHUNK_SIZE в .env (текущий лимит ~514 токенов на чанк)." ) raise return result def get_embedding_client(dim: int) -> EmbeddingClient: credentials = os.getenv("GIGACHAT_CREDENTIALS", "").strip() if credentials: return GigaChatEmbeddingClient( credentials=credentials, model=os.getenv("GIGACHAT_EMBEDDINGS_MODEL", "Embeddings"), verify_ssl_certs=os.getenv("GIGACHAT_VERIFY_SSL", "false").lower() in ("1", "true", "yes"), ) return StubEmbeddingClient(dim=dim)