Добавлены эмбеддинги на базе GigaChat

This commit is contained in:
2026-01-30 22:53:01 +03:00
parent d07578f489
commit e899f54f04
3 changed files with 62 additions and 0 deletions

View File

@@ -8,6 +8,8 @@ dependencies = [
"psycopg[binary]>=3.1.18", "psycopg[binary]>=3.1.18",
"pgvector>=0.2.5", "pgvector>=0.2.5",
"pydantic>=2.7.0", "pydantic>=2.7.0",
"python-dotenv>=1.0.0",
"gigachat>=0.2.0",
] ]
[project.scripts] [project.scripts]

View File

@@ -2,8 +2,15 @@ from __future__ import annotations
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Sequence from typing import Iterable, Sequence
from dotenv import load_dotenv
# Load environment variables from the repository-root `.env` file so that
# local runs pick up settings without exporting them manually.
# NOTE: this executes at import time, i.e. as a side effect of importing
# this config module.
# Assumes this file lives exactly three directory levels below the
# repository root (three `.parent` hops from the resolved file path)
# -- TODO confirm if the module is ever moved.
_repo_root = Path(__file__).resolve().parent.parent.parent
# NOTE(review): with python-dotenv's default arguments, variables already
# present in the process environment are presumably not overridden -- confirm.
load_dotenv(_repo_root / ".env")
@dataclass(frozen=True) @dataclass(frozen=True)
class AppConfig: class AppConfig:

View File

@@ -1,9 +1,17 @@
from __future__ import annotations from __future__ import annotations
import hashlib import hashlib
import os
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Protocol from typing import Iterable, Protocol
from dotenv import load_dotenv
# Ensure the repository-root `.env` file is loaded before the embedding
# client is resolved (e.g. so GIGACHAT_CREDENTIALS is visible to os.getenv).
# NOTE: this executes at import time, i.e. as a side effect of importing
# this module.
# Assumes this file lives exactly four directory levels below the
# repository root (four `.parent` hops from the resolved file path)
# -- TODO confirm if the module is ever moved.
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
# NOTE(review): with python-dotenv's default arguments, variables already
# present in the process environment are presumably not overridden -- confirm.
load_dotenv(_repo_root / ".env")
class EmbeddingClient(Protocol): class EmbeddingClient(Protocol):
def embed_texts(self, texts: Iterable[str]) -> list[list[float]]: def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
@@ -25,5 +33,50 @@ class StubEmbeddingClient:
return vectors return vectors
# Maximum number of texts sent to the GigaChat embeddings endpoint per request.
_GIGACHAT_BATCH_SIZE = 50


class GigaChatEmbeddingClient:
    """Embedding client backed by the GigaChat API.

    Texts are sent in batches of at most ``_GIGACHAT_BATCH_SIZE`` items and
    the resulting vectors are returned in the same order as the input texts.
    """

    def __init__(
        self,
        credentials: str,
        model: str = "Embeddings",
        verify_ssl_certs: bool = False,
    ) -> None:
        """Store connection settings; no network access happens here.

        Args:
            credentials: GigaChat authorization key. Surrounding whitespace
                is stripped.
            model: Name of the embeddings model to request.
            verify_ssl_certs: Whether the GigaChat client should verify TLS
                certificates.
        """
        self._credentials = credentials.strip()
        self._model = model
        self._verify_ssl_certs = verify_ssl_certs

    def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
        """Return one embedding vector per input text, preserving order.

        Args:
            texts: Texts to embed. The iterable is consumed once.

        Returns:
            A list with one vector per text; empty if ``texts`` is empty.

        Raises:
            KeyError: If a response does not contain an embedding for every
                position of the submitted batch.
        """
        texts_list = list(texts)
        if not texts_list:
            # Nothing to embed: skip both the SDK import and any network I/O.
            return []

        # Imported lazily so this module stays importable when the optional
        # `gigachat` package is not installed.
        from gigachat import GigaChat

        result: list[list[float]] = []
        # Open a single client for all batches instead of one per batch, so
        # the session (and its authorization) is reused across requests.
        with GigaChat(
            credentials=self._credentials,
            verify_ssl_certs=self._verify_ssl_certs,
        ) as giga:
            for start in range(0, len(texts_list), _GIGACHAT_BATCH_SIZE):
                batch = texts_list[start : start + _GIGACHAT_BATCH_SIZE]
                # The batch is passed positionally: the SDK names this
                # parameter `texts`, so an `input=` keyword is not accepted.
                response = giga.embeddings(batch, model=self._model)
                # Re-order by the index reported for each item rather than
                # trusting the order of `response.data`.
                by_index = {item.index: item.embedding for item in response.data}
                result.extend(by_index[pos] for pos in range(len(batch)))
        return result
def get_embedding_client(dim: int) -> EmbeddingClient:
    """Pick the embedding client based on environment configuration.

    When ``GIGACHAT_CREDENTIALS`` is set to a non-blank value, a
    ``GigaChatEmbeddingClient`` is returned, configured from
    ``GIGACHAT_EMBEDDINGS_MODEL`` (default ``"Embeddings"``) and
    ``GIGACHAT_VERIFY_SSL`` (treated as true only for ``1``/``true``/``yes``,
    case-insensitive; default ``false``). Otherwise a
    ``StubEmbeddingClient`` of dimension ``dim`` is returned.

    Args:
        dim: Vector dimension passed to the stub client; not used when the
            GigaChat client is selected.
    """
    giga_credentials = os.getenv("GIGACHAT_CREDENTIALS", "").strip()
    if not giga_credentials:
        # No credentials configured: fall back to the stub implementation.
        return StubEmbeddingClient(dim=dim)

    model_name = os.getenv("GIGACHAT_EMBEDDINGS_MODEL", "Embeddings")
    ssl_flag = os.getenv("GIGACHAT_VERIFY_SSL", "false").lower()
    return GigaChatEmbeddingClient(
        credentials=giga_credentials,
        model=model_name,
        verify_ssl_certs=ssl_flag in ("1", "true", "yes"),
    )