From e899f54f04487e64a2ef40a53199ea3b760e5a29 Mon Sep 17 00:00:00 2001
From: zosimovaa
Date: Fri, 30 Jan 2026 22:53:01 +0300
Subject: [PATCH] =?UTF-8?q?=D0=94=D0=9E=D0=B1=D0=B0=D0=B2=D0=BB=D0=B5?=
 =?UTF-8?q?=D0=BD=D1=8B=20=D1=8D=D0=BC=D0=B1=D0=B5=D0=B4=D0=B4=D0=B8=D0=BD?=
 =?UTF-8?q?=D0=B3=D0=B8=20=D0=BD=D0=B0=20=D0=B1=D0=B0=D0=B7=D0=B5=20=D0=B3?=
 =?UTF-8?q?=D0=B8=D0=B3=D0=B0=D1=87=D0=B0=D1=82=D0=B0?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
Review note: the second hunk of src/rag_agent/index/embeddings.py was edited
after export (one GigaChat client for all batches, texts passed positionally
to embeddings(), explicit check of the response size, docstrings). The "index"
hashes below are the exported ones and do not describe the edited post-image.

 pyproject.toml                    |  2 +
 src/rag_agent/config.py           |  7 +++
 src/rag_agent/index/embeddings.py | 77 +++++++++++++++++++++++++++++++
 3 files changed, 86 insertions(+)

diff --git a/pyproject.toml b/pyproject.toml
index f6968af..764e13d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,6 +8,8 @@ dependencies = [
     "psycopg[binary]>=3.1.18",
     "pgvector>=0.2.5",
     "pydantic>=2.7.0",
+    "python-dotenv>=1.0.0",
+    "gigachat>=0.2.0",
 ]
 
 [project.scripts]
diff --git a/src/rag_agent/config.py b/src/rag_agent/config.py
index 3f0ff0e..e7025d4 100644
--- a/src/rag_agent/config.py
+++ b/src/rag_agent/config.py
@@ -2,8 +2,15 @@ from __future__ import annotations
 
 import os
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Iterable, Sequence
 
+from dotenv import load_dotenv
+
+# Load .env from repo root when config is used (e.g. for local runs)
+_repo_root = Path(__file__).resolve().parent.parent.parent
+load_dotenv(_repo_root / ".env")
+
 
 @dataclass(frozen=True)
 class AppConfig:
diff --git a/src/rag_agent/index/embeddings.py b/src/rag_agent/index/embeddings.py
index 19cd659..968d8b0 100644
--- a/src/rag_agent/index/embeddings.py
+++ b/src/rag_agent/index/embeddings.py
@@ -1,9 +1,17 @@
 from __future__ import annotations
 
 import hashlib
+import os
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Iterable, Protocol
 
+from dotenv import load_dotenv
+
+# Ensure .env is loaded when resolving embedding client (e.g. GIGACHAT_CREDENTIALS)
+_repo_root = Path(__file__).resolve().parent.parent.parent.parent
+load_dotenv(_repo_root / ".env")
+
 
 class EmbeddingClient(Protocol):
     def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
@@ -25,5 +33,74 @@ class StubEmbeddingClient:
         return vectors
 
 
+_GIGACHAT_BATCH_SIZE = 50
+
+
+class GigaChatEmbeddingClient:
+    """Embedding client backed by the GigaChat API.
+
+    Texts are sent in batches of ``_GIGACHAT_BATCH_SIZE`` and the returned
+    vectors are ordered to match the input order.
+    """
+
+    def __init__(
+        self,
+        credentials: str,
+        model: str = "Embeddings",
+        verify_ssl_certs: bool = False,
+    ) -> None:
+        # NOTE(review): TLS verification is off by default here; confirm this
+        # is intended outside local runs (see GIGACHAT_VERIFY_SSL).
+        self._credentials = credentials.strip()
+        self._model = model
+        self._verify_ssl_certs = verify_ssl_certs
+
+    def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
+        """Return one embedding per input text, in input order.
+
+        Raises:
+            RuntimeError: if the API returns a different number of
+                embeddings than texts were sent in a batch.
+        """
+        # Imported at call time, so gigachat is only needed when this client is used.
+        from gigachat import GigaChat
+
+        texts_list = list(texts)
+        if not texts_list:
+            return []
+
+        result: list[list[float]] = []
+        # One client for all batches instead of a new session per batch.
+        with GigaChat(
+            credentials=self._credentials,
+            verify_ssl_certs=self._verify_ssl_certs,
+        ) as giga:
+            for start in range(0, len(texts_list), _GIGACHAT_BATCH_SIZE):
+                batch = texts_list[start : start + _GIGACHAT_BATCH_SIZE]
+                # The SDK takes the texts as its first argument.
+                response = giga.embeddings(batch, model=self._model)
+                if len(response.data) != len(batch):
+                    raise RuntimeError(
+                        f"GigaChat returned {len(response.data)} embeddings "
+                        f"for a batch of {len(batch)} texts"
+                    )
+                # The response is not assumed to be ordered; sort by index.
+                ordered = sorted(response.data, key=lambda item: item.index)
+                result.extend(item.embedding for item in ordered)
+        return result
+
+
 def get_embedding_client(dim: int) -> EmbeddingClient:
+    """Return a GigaChat client if GIGACHAT_CREDENTIALS is set, else the stub.
+
+    GIGACHAT_EMBEDDINGS_MODEL and GIGACHAT_VERIFY_SSL tune the GigaChat client.
+    """
+    credentials = os.getenv("GIGACHAT_CREDENTIALS", "").strip()
+    if credentials:
+        return GigaChatEmbeddingClient(
+            credentials=credentials,
+            model=os.getenv("GIGACHAT_EMBEDDINGS_MODEL", "Embeddings"),
+            verify_ssl_certs=os.getenv("GIGACHAT_VERIFY_SSL", "false").lower()
+            in ("1", "true", "yes"),
+        )
     return StubEmbeddingClient(dim=dim)