ДОбавлены эмбеддинги на базе гигачата
This commit is contained in:
@@ -1,9 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Protocol
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Ensure .env is loaded when resolving embedding client (e.g. GIGACHAT_CREDENTIALS)
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
load_dotenv(_repo_root / ".env")
|
||||
|
||||
|
||||
class EmbeddingClient(Protocol):
|
||||
def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
|
||||
@@ -25,5 +33,50 @@ class StubEmbeddingClient:
|
||||
return vectors
|
||||
|
||||
|
||||
_GIGACHAT_BATCH_SIZE = 50
|
||||
|
||||
|
||||
class GigaChatEmbeddingClient:
|
||||
"""Embeddings via GigaChat API. Credentials from env GIGACHAT_CREDENTIALS."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
credentials: str,
|
||||
model: str = "Embeddings",
|
||||
verify_ssl_certs: bool = False,
|
||||
) -> None:
|
||||
self._credentials = credentials.strip()
|
||||
self._model = model
|
||||
self._verify_ssl_certs = verify_ssl_certs
|
||||
|
||||
def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
|
||||
from gigachat import GigaChat
|
||||
|
||||
texts_list = list(texts)
|
||||
if not texts_list:
|
||||
return []
|
||||
|
||||
result: list[list[float]] = []
|
||||
for i in range(0, len(texts_list), _GIGACHAT_BATCH_SIZE):
|
||||
batch = texts_list[i : i + _GIGACHAT_BATCH_SIZE]
|
||||
with GigaChat(
|
||||
credentials=self._credentials,
|
||||
verify_ssl_certs=self._verify_ssl_certs,
|
||||
) as giga:
|
||||
response = giga.embeddings(model=self._model, input=batch)
|
||||
# Preserve order by index
|
||||
by_index = {item.index: item.embedding for item in response.data}
|
||||
result.extend(by_index[j] for j in range(len(batch)))
|
||||
return result
|
||||
|
||||
|
||||
def get_embedding_client(dim: int) -> EmbeddingClient:
|
||||
credentials = os.getenv("GIGACHAT_CREDENTIALS", "").strip()
|
||||
if credentials:
|
||||
return GigaChatEmbeddingClient(
|
||||
credentials=credentials,
|
||||
model=os.getenv("GIGACHAT_EMBEDDINGS_MODEL", "Embeddings"),
|
||||
verify_ssl_certs=os.getenv("GIGACHAT_VERIFY_SSL", "false").lower()
|
||||
in ("1", "true", "yes"),
|
||||
)
|
||||
return StubEmbeddingClient(dim=dim)
|
||||
|
||||
Reference in New Issue
Block a user