Files
RagAgent/src/rag_agent/index/embeddings.py

112 lines
4.2 KiB
Python

from __future__ import annotations
import hashlib
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Protocol
from dotenv import load_dotenv
# Load the repository-level .env at import time so that settings such as
# GIGACHAT_CREDENTIALS are visible when get_embedding_client() reads os.environ.
# This module sits four levels below the repo root
# (<repo>/src/rag_agent/index/embeddings.py -> <repo>).
_repo_root = Path(__file__).resolve().parents[3]
load_dotenv(_repo_root / ".env")
class EmbeddingClient(Protocol):
    """Structural interface implemented by every embedding backend here.

    Any object exposing a compatible ``embed_texts`` method satisfies this
    protocol; explicit inheritance is not required.
    """

    def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
        """Return one embedding vector per input text, in input order."""
        raise NotImplementedError
@dataclass
class StubEmbeddingClient:
    """Deterministic, offline embedding client used when no API is configured.

    Each text is hashed with SHA-256; the 32 digest bytes are scaled into
    ``[0, 1]`` and then repeated/truncated to exactly ``dim`` values.
    Identical texts always map to identical vectors, but the vectors carry no
    semantic meaning.

    Attributes:
        dim: Length of every returned vector. Must be non-negative.

    Raises:
        ValueError: On construction, if ``dim`` is negative.
    """

    dim: int

    def __post_init__(self) -> None:
        # A negative dim previously produced vectors of the wrong length via a
        # negative slice; fail fast instead of silently corrupting an index.
        if self.dim < 0:
            raise ValueError(f"dim must be non-negative, got {self.dim}")

    def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
        """Return one ``dim``-length pseudo-embedding per text, in input order."""
        vectors: list[list[float]] = []
        for text in texts:
            digest = hashlib.sha256(text.encode("utf-8")).digest()
            base = [byte / 255.0 for byte in digest]
            # Tile the 32-value pattern enough times to cover dim, then cut.
            repeats = self.dim // len(base) + 1
            vectors.append((base * repeats)[: self.dim])
        return vectors
# Maximum number of texts sent to GigaChat in a single embeddings request.
_GIGACHAT_BATCH_SIZE = 50

# GigaChat embeddings accept at most 514 tokens per input. At roughly 3
# characters per token for Russian/English text, 1200 characters stays below
# that limit with some margin; longer inputs are truncated before sending.
_GIGACHAT_MAX_CHARS_PER_INPUT = 1_200
class GigaChatEmbeddingClient:
    """Embedding client backed by the GigaChat embeddings API.

    Texts are sent in batches of ``_GIGACHAT_BATCH_SIZE``, and each text is
    truncated to ``_GIGACHAT_MAX_CHARS_PER_INPUT`` characters first. The
    ``gigachat`` package is imported lazily, only when there is something to
    embed.

    Args:
        credentials: GigaChat authorization credentials; surrounding
            whitespace is stripped. (``get_embedding_client`` takes them from
            the ``GIGACHAT_CREDENTIALS`` environment variable.)
        model: Name of the embeddings model to request.
        verify_ssl_certs: Passed to the SDK client to control TLS certificate
            verification. Defaults to ``False``.
    """

    def __init__(
        self,
        credentials: str,
        model: str = "Embeddings",
        verify_ssl_certs: bool = False,
    ) -> None:
        self._credentials = credentials.strip()
        self._model = model
        self._verify_ssl_certs = verify_ssl_certs

    def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
        """Embed ``texts`` via GigaChat and return the vectors in input order.

        Raises:
            ValueError: If the API reports 402 (Payment Required) or 413
                (request too large); the SDK error is chained as ``__cause__``.
            Exception: Any other SDK/transport error is re-raised unchanged.
        """
        texts_list = list(texts)
        if not texts_list:
            # Nothing to send: skip importing the SDK and opening a session.
            return []

        from gigachat import GigaChat

        result: list[list[float]] = []
        try:
            # One client for the whole call rather than one per batch, so the
            # session (and presumably its access token) is reused.
            with GigaChat(
                credentials=self._credentials,
                model=self._model,
                verify_ssl_certs=self._verify_ssl_certs,
            ) as giga:
                for start in range(0, len(texts_list), _GIGACHAT_BATCH_SIZE):
                    # Slicing leaves texts already within the limit unchanged.
                    batch = [
                        text[:_GIGACHAT_MAX_CHARS_PER_INPUT]
                        for text in texts_list[start : start + _GIGACHAT_BATCH_SIZE]
                    ]
                    result.extend(self._embed_batch(giga, batch))
        except Exception as exc:
            friendly = self._friendly_error(exc)
            if friendly is None:
                raise
            raise friendly from exc
        return result

    @staticmethod
    def _embed_batch(giga, batch: list[str]) -> list[list[float]]:
        """Embed one batch with an open client; result is ordered like ``batch``."""
        # API: embeddings(texts: list[str]) — single positional argument (gigachat 0.2+)
        response = giga.embeddings(batch)
        # Each item carries the position of its input; reorder by it instead
        # of trusting the order of response.data.
        by_index = {item.index: item.embedding for item in response.data}
        return [by_index[position] for position in range(len(batch))]

    @staticmethod
    def _status_code(exc: BaseException) -> int | None:
        """Best-effort HTTP status code of a GigaChat SDK error, or None."""
        code = getattr(exc, "status_code", None)
        if isinstance(code, int):
            return code
        # NOTE(review): assumes the SDK constructs ResponseError positionally as
        # (url, status_code, content, headers) — confirm for the installed version.
        args = exc.args
        if len(args) > 1 and isinstance(args[1], int):
            return args[1]
        return None

    @classmethod
    def _friendly_error(cls, exc: Exception) -> ValueError | None:
        """Map known GigaChat API failures to actionable ValueErrors.

        Returns None when ``exc`` is not a recognised 402/413 failure.
        """
        from gigachat import exceptions as gigachat_exceptions

        # Looked up defensively: the class may be missing in some SDK versions.
        too_large_cls = getattr(gigachat_exceptions, "RequestEntityTooLargeError", None)
        is_response_error = isinstance(exc, gigachat_exceptions.ResponseError)
        code = cls._status_code(exc) if is_response_error else None

        if is_response_error:
            text = str(exc)
            # Match on the message text only when no numeric status is known,
            # so e.g. a 500 whose body happens to contain "402" is not mislabelled.
            if code == 402 or (
                code is None and ("402" in text or "Payment Required" in text)
            ):
                return ValueError(
                    "GigaChat: недостаточно средств (402 Payment Required). "
                    "Пополните баланс в кабинете GigaChat или отключите GIGACHAT_CREDENTIALS для stub-режима."
                )
        if (too_large_cls is not None and isinstance(exc, too_large_cls)) or code == 413:
            return ValueError(
                "GigaChat: превышен лимит токенов на один запрос (413). "
                "Уменьшите RAG_CHUNK_SIZE_LINES или RAG_CHUNK_SIZE в .env (текущий лимит ~514 токенов на чанк)."
            )
        return None
def get_embedding_client(dim: int) -> EmbeddingClient:
    """Return the embedding client selected by environment variables.

    If ``GIGACHAT_CREDENTIALS`` is set to a non-blank value, a
    ``GigaChatEmbeddingClient`` is returned; otherwise a deterministic
    ``StubEmbeddingClient`` producing ``dim``-length vectors.

    Environment variables:
        GIGACHAT_CREDENTIALS: GigaChat credentials; blank/unset selects the stub.
        GIGACHAT_EMBEDDINGS_MODEL: embeddings model name; blank/unset means
            "Embeddings".
        GIGACHAT_VERIFY_SSL: "1", "true" or "yes" (case-insensitive) enables
            TLS certificate verification; anything else disables it.

    Note: ``dim`` is only used by the stub. NOTE(review): the GigaChat vector
    size is decided by the remote model — confirm it matches the index ``dim``.
    """
    credentials = os.getenv("GIGACHAT_CREDENTIALS", "").strip()
    if not credentials:
        return StubEmbeddingClient(dim=dim)
    # A blank value (e.g. copied from an .env template) falls back to the
    # default model instead of requesting a model named "".
    model = os.getenv("GIGACHAT_EMBEDDINGS_MODEL", "").strip() or "Embeddings"
    verify_ssl = os.getenv("GIGACHAT_VERIFY_SSL", "false").strip().lower() in (
        "1",
        "true",
        "yes",
    )
    return GigaChatEmbeddingClient(
        credentials=credentials,
        model=model,
        verify_ssl_certs=verify_ssl,
    )