Фиксация изменений
This commit is contained in:
289
app/modules/rag/explain/layered_gateway.py
Normal file
289
app/modules/rag/explain/layered_gateway.py
Normal file
@@ -0,0 +1,289 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
|
||||
from app.modules.rag.retrieval.test_filter import build_test_filters, debug_disable_test_filter
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.modules.rag.persistence.repository import RagRepository
|
||||
from app.modules.rag_session.embedding.gigachat_embedder import GigaChatEmbedder
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class LayerRetrievalResult:
|
||||
items: list[LayeredRetrievalItem]
|
||||
missing: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class LayeredRetrievalGateway:
|
||||
def __init__(self, repository: RagRepository, embedder: GigaChatEmbedder) -> None:
|
||||
self._repository = repository
|
||||
self._embedder = embedder
|
||||
|
||||
def retrieve_layer(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
query: str,
|
||||
layer: str,
|
||||
*,
|
||||
limit: int,
|
||||
path_prefixes: list[str] | None = None,
|
||||
exclude_tests: bool = True,
|
||||
prefer_non_tests: bool = False,
|
||||
include_spans: bool = False,
|
||||
) -> LayerRetrievalResult:
|
||||
effective_exclude_tests = exclude_tests and not debug_disable_test_filter()
|
||||
filter_args = self._filter_args(effective_exclude_tests)
|
||||
query_embedding: list[float] | None = None
|
||||
try:
|
||||
query_embedding = self._embedder.embed([query])[0]
|
||||
rows = self._repository.retrieve(
|
||||
rag_session_id,
|
||||
query_embedding,
|
||||
query_text=query,
|
||||
limit=limit,
|
||||
layers=[layer],
|
||||
path_prefixes=path_prefixes,
|
||||
exclude_path_prefixes=filter_args["exclude_path_prefixes"],
|
||||
exclude_like_patterns=filter_args["exclude_like_patterns"],
|
||||
prefer_non_tests=prefer_non_tests or not effective_exclude_tests,
|
||||
)
|
||||
return self._success_result(
|
||||
rows,
|
||||
rag_session_id=rag_session_id,
|
||||
label="layered retrieval",
|
||||
include_spans=include_spans,
|
||||
layer=layer,
|
||||
exclude_tests=effective_exclude_tests,
|
||||
path_prefixes=path_prefixes,
|
||||
)
|
||||
except Exception as exc:
|
||||
if query_embedding is None:
|
||||
self._log_failure(
|
||||
label="layered retrieval",
|
||||
rag_session_id=rag_session_id,
|
||||
layer=layer,
|
||||
exclude_tests=effective_exclude_tests,
|
||||
path_prefixes=path_prefixes,
|
||||
exc=exc,
|
||||
)
|
||||
return LayerRetrievalResult(items=[], missing=[self._failure_missing(f"layer:{layer} retrieval_failed", exc)])
|
||||
retry_result = self._retry_without_test_filter(
|
||||
operation=lambda: self._repository.retrieve(
|
||||
rag_session_id,
|
||||
query_embedding,
|
||||
query_text=query,
|
||||
limit=limit,
|
||||
layers=[layer],
|
||||
path_prefixes=path_prefixes,
|
||||
exclude_path_prefixes=None,
|
||||
exclude_like_patterns=None,
|
||||
prefer_non_tests=True,
|
||||
),
|
||||
label="layered retrieval",
|
||||
rag_session_id=rag_session_id,
|
||||
include_spans=include_spans,
|
||||
layer=layer,
|
||||
exclude_tests=effective_exclude_tests,
|
||||
path_prefixes=path_prefixes,
|
||||
exc=exc,
|
||||
missing_prefix=f"layer:{layer} retrieval_failed",
|
||||
)
|
||||
if retry_result is not None:
|
||||
return retry_result
|
||||
return LayerRetrievalResult(items=[], missing=[self._failure_missing(f"layer:{layer} retrieval_failed", exc)])
|
||||
|
||||
def retrieve_lexical_code(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
query: str,
|
||||
*,
|
||||
limit: int,
|
||||
path_prefixes: list[str] | None = None,
|
||||
exclude_tests: bool = True,
|
||||
include_spans: bool = False,
|
||||
) -> LayerRetrievalResult:
|
||||
effective_exclude_tests = exclude_tests and not debug_disable_test_filter()
|
||||
filter_args = self._filter_args(effective_exclude_tests)
|
||||
try:
|
||||
rows = self._repository.retrieve_lexical_code(
|
||||
rag_session_id,
|
||||
query_text=query,
|
||||
limit=limit,
|
||||
path_prefixes=path_prefixes,
|
||||
exclude_path_prefixes=filter_args["exclude_path_prefixes"],
|
||||
exclude_like_patterns=filter_args["exclude_like_patterns"],
|
||||
prefer_non_tests=not effective_exclude_tests,
|
||||
)
|
||||
return self._success_result(
|
||||
rows,
|
||||
rag_session_id=rag_session_id,
|
||||
label="lexical retrieval",
|
||||
include_spans=include_spans,
|
||||
exclude_tests=effective_exclude_tests,
|
||||
path_prefixes=path_prefixes,
|
||||
)
|
||||
except Exception as exc:
|
||||
retry_result = self._retry_without_test_filter(
|
||||
operation=lambda: self._repository.retrieve_lexical_code(
|
||||
rag_session_id,
|
||||
query_text=query,
|
||||
limit=limit,
|
||||
path_prefixes=path_prefixes,
|
||||
exclude_path_prefixes=None,
|
||||
exclude_like_patterns=None,
|
||||
prefer_non_tests=True,
|
||||
),
|
||||
label="lexical retrieval",
|
||||
rag_session_id=rag_session_id,
|
||||
include_spans=include_spans,
|
||||
exclude_tests=effective_exclude_tests,
|
||||
path_prefixes=path_prefixes,
|
||||
exc=exc,
|
||||
missing_prefix="layer:C0 lexical_retrieval_failed",
|
||||
)
|
||||
if retry_result is not None:
|
||||
return retry_result
|
||||
return LayerRetrievalResult(items=[], missing=[self._failure_missing("layer:C0 lexical_retrieval_failed", exc)])
|
||||
|
||||
def _retry_without_test_filter(
|
||||
self,
|
||||
*,
|
||||
operation: Callable[[], list[dict]],
|
||||
label: str,
|
||||
rag_session_id: str,
|
||||
include_spans: bool,
|
||||
exclude_tests: bool,
|
||||
path_prefixes: list[str] | None,
|
||||
exc: Exception,
|
||||
missing_prefix: str,
|
||||
layer: str | None = None,
|
||||
) -> LayerRetrievalResult | None:
|
||||
if not exclude_tests:
|
||||
self._log_failure(
|
||||
label=label,
|
||||
rag_session_id=rag_session_id,
|
||||
layer=layer,
|
||||
exclude_tests=exclude_tests,
|
||||
path_prefixes=path_prefixes,
|
||||
exc=exc,
|
||||
)
|
||||
return None
|
||||
self._log_failure(
|
||||
label=label,
|
||||
rag_session_id=rag_session_id,
|
||||
layer=layer,
|
||||
exclude_tests=exclude_tests,
|
||||
path_prefixes=path_prefixes,
|
||||
exc=exc,
|
||||
retried_without_test_filter=True,
|
||||
)
|
||||
try:
|
||||
rows = operation()
|
||||
except Exception as retry_exc:
|
||||
self._log_failure(
|
||||
label=f"{label} retry",
|
||||
rag_session_id=rag_session_id,
|
||||
layer=layer,
|
||||
exclude_tests=False,
|
||||
path_prefixes=path_prefixes,
|
||||
exc=retry_exc,
|
||||
)
|
||||
return None
|
||||
result = self._success_result(
|
||||
rows,
|
||||
rag_session_id=rag_session_id,
|
||||
label=f"{label} retry",
|
||||
include_spans=include_spans,
|
||||
layer=layer,
|
||||
exclude_tests=False,
|
||||
path_prefixes=path_prefixes,
|
||||
)
|
||||
result.missing.append(f"{missing_prefix}:retried_without_test_filter")
|
||||
return result
|
||||
|
||||
def _success_result(
|
||||
self,
|
||||
rows: list[dict],
|
||||
*,
|
||||
rag_session_id: str,
|
||||
label: str,
|
||||
include_spans: bool,
|
||||
exclude_tests: bool,
|
||||
path_prefixes: list[str] | None,
|
||||
layer: str | None = None,
|
||||
) -> LayerRetrievalResult:
|
||||
items = [self._to_item(row, include_spans=include_spans) for row in rows]
|
||||
LOGGER.warning(
|
||||
"%s: rag_session_id=%s layer=%s exclude_tests=%s path_prefixes=%s returned_count=%s top_paths=%s",
|
||||
label,
|
||||
rag_session_id,
|
||||
layer,
|
||||
exclude_tests,
|
||||
path_prefixes or [],
|
||||
len(items),
|
||||
[item.source for item in items[:3]],
|
||||
)
|
||||
return LayerRetrievalResult(items=items)
|
||||
|
||||
def _log_failure(
|
||||
self,
|
||||
*,
|
||||
label: str,
|
||||
rag_session_id: str,
|
||||
exclude_tests: bool,
|
||||
path_prefixes: list[str] | None,
|
||||
exc: Exception,
|
||||
layer: str | None = None,
|
||||
retried_without_test_filter: bool = False,
|
||||
) -> None:
|
||||
LOGGER.warning(
|
||||
"%s failed: rag_session_id=%s layer=%s exclude_tests=%s path_prefixes=%s retried_without_test_filter=%s error=%s",
|
||||
label,
|
||||
rag_session_id,
|
||||
layer,
|
||||
exclude_tests,
|
||||
path_prefixes or [],
|
||||
retried_without_test_filter,
|
||||
self._exception_summary(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def _filter_args(self, exclude_tests: bool) -> dict[str, list[str] | None]:
|
||||
test_filters = build_test_filters() if exclude_tests else None
|
||||
return {
|
||||
"exclude_path_prefixes": test_filters.exclude_path_prefixes if test_filters else None,
|
||||
"exclude_like_patterns": test_filters.exclude_like_patterns if test_filters else None,
|
||||
}
|
||||
|
||||
def _failure_missing(self, prefix: str, exc: Exception) -> str:
|
||||
return f"{prefix}:{self._exception_summary(exc)}"
|
||||
|
||||
def _exception_summary(self, exc: Exception) -> str:
|
||||
message = " ".join(str(exc).split())
|
||||
if len(message) > 180:
|
||||
message = message[:177] + "..."
|
||||
return f"{type(exc).__name__}:{message or 'no_message'}"
|
||||
|
||||
def _to_item(self, row: dict, *, include_spans: bool) -> LayeredRetrievalItem:
|
||||
location = None
|
||||
if include_spans:
|
||||
location = CodeLocation(
|
||||
path=str(row.get("path") or ""),
|
||||
start_line=row.get("span_start"),
|
||||
end_line=row.get("span_end"),
|
||||
)
|
||||
return LayeredRetrievalItem(
|
||||
source=str(row.get("path") or ""),
|
||||
content=str(row.get("content") or ""),
|
||||
layer=str(row.get("layer") or ""),
|
||||
title=str(row.get("title") or ""),
|
||||
metadata=dict(row.get("metadata", {}) or {}),
|
||||
score=row.get("distance"),
|
||||
location=location,
|
||||
)
|
||||
Reference in New Issue
Block a user