Files
agent/app/modules/rag/explain/layered_gateway.py

290 lines
10 KiB
Python

from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
from app.modules.rag.retrieval.test_filter import build_test_filters, debug_disable_test_filter
LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING:
from app.modules.rag.persistence.repository import RagRepository
from app.modules.rag_session.embedding.gigachat_embedder import GigaChatEmbedder
@dataclass(slots=True)
class LayerRetrievalResult:
items: list[LayeredRetrievalItem]
missing: list[str] = field(default_factory=list)
class LayeredRetrievalGateway:
    """Gateway that runs layer-scoped and lexical retrieval against the RAG repository.

    Every public method returns a ``LayerRetrievalResult`` instead of raising:
    failures are logged and reported through ``LayerRetrievalResult.missing``.
    When a query that excluded test files fails, it is retried once with the
    test filter removed.
    """

    def __init__(self, repository: RagRepository, embedder: GigaChatEmbedder) -> None:
        """Store the repository used for queries and the embedder used for query vectors."""
        self._repository = repository
        self._embedder = embedder

    def retrieve_layer(
        self,
        rag_session_id: str,
        query: str,
        layer: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        prefer_non_tests: bool = False,
        include_spans: bool = False,
    ) -> LayerRetrievalResult:
        """Embed ``query`` and retrieve up to ``limit`` rows from a single layer.

        Args:
            rag_session_id: Session identifier forwarded to the repository.
            query: Text that is embedded and also passed as ``query_text``.
            layer: The only layer requested from the repository.
            limit: Row limit forwarded to the repository.
            path_prefixes: Optional path prefixes forwarded to the repository.
            exclude_tests: Apply the test filters unless the debug switch
                ``debug_disable_test_filter()`` turns them off.
            prefer_non_tests: Forwarded to the repository; forced to ``True``
                whenever the test filter is not in effect.
            include_spans: Attach a ``CodeLocation`` to each returned item.

        Returns:
            The retrieved items, or an empty result whose ``missing`` list
            carries a ``layer:<layer> retrieval_failed`` marker on failure.
        """
        effective_exclude_tests = exclude_tests and not debug_disable_test_filter()
        filter_args = self._filter_args(effective_exclude_tests)
        missing_prefix = f"layer:{layer} retrieval_failed"

        # Embedding failures are handled on their own: without a query vector
        # there is nothing to retry against the repository.
        try:
            query_embedding = self._embedder.embed([query])[0]
        except Exception as exc:
            self._log_failure(
                label="layered retrieval",
                rag_session_id=rag_session_id,
                layer=layer,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
                exc=exc,
            )
            return self._failure_result(missing_prefix, exc)

        try:
            rows = self._repository.retrieve(
                rag_session_id,
                query_embedding,
                query_text=query,
                limit=limit,
                layers=[layer],
                path_prefixes=path_prefixes,
                exclude_path_prefixes=filter_args["exclude_path_prefixes"],
                exclude_like_patterns=filter_args["exclude_like_patterns"],
                prefer_non_tests=prefer_non_tests or not effective_exclude_tests,
            )
            return self._success_result(
                rows,
                rag_session_id=rag_session_id,
                label="layered retrieval",
                include_spans=include_spans,
                layer=layer,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
            )
        except Exception as exc:
            retry_result = self._retry_without_test_filter(
                operation=lambda: self._repository.retrieve(
                    rag_session_id,
                    query_embedding,
                    query_text=query,
                    limit=limit,
                    layers=[layer],
                    path_prefixes=path_prefixes,
                    exclude_path_prefixes=None,
                    exclude_like_patterns=None,
                    prefer_non_tests=True,
                ),
                label="layered retrieval",
                rag_session_id=rag_session_id,
                include_spans=include_spans,
                layer=layer,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
                exc=exc,
                missing_prefix=missing_prefix,
            )
            if retry_result is not None:
                return retry_result
            return self._failure_result(missing_prefix, exc)

    def retrieve_lexical_code(
        self,
        rag_session_id: str,
        query: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        include_spans: bool = False,
    ) -> LayerRetrievalResult:
        """Run the repository's lexical code retrieval for ``query``.

        Args:
            rag_session_id: Session identifier forwarded to the repository.
            query: Text passed to the repository as ``query_text``.
            limit: Row limit forwarded to the repository.
            path_prefixes: Optional path prefixes forwarded to the repository.
            exclude_tests: Apply the test filters unless the debug switch
                ``debug_disable_test_filter()`` turns them off.
            include_spans: Attach a ``CodeLocation`` to each returned item.

        Returns:
            The retrieved items, or an empty result whose ``missing`` list
            carries a ``layer:C0 lexical_retrieval_failed`` marker on failure.
        """
        effective_exclude_tests = exclude_tests and not debug_disable_test_filter()
        filter_args = self._filter_args(effective_exclude_tests)
        missing_prefix = "layer:C0 lexical_retrieval_failed"

        try:
            rows = self._repository.retrieve_lexical_code(
                rag_session_id,
                query_text=query,
                limit=limit,
                path_prefixes=path_prefixes,
                exclude_path_prefixes=filter_args["exclude_path_prefixes"],
                exclude_like_patterns=filter_args["exclude_like_patterns"],
                prefer_non_tests=not effective_exclude_tests,
            )
            return self._success_result(
                rows,
                rag_session_id=rag_session_id,
                label="lexical retrieval",
                include_spans=include_spans,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
            )
        except Exception as exc:
            retry_result = self._retry_without_test_filter(
                operation=lambda: self._repository.retrieve_lexical_code(
                    rag_session_id,
                    query_text=query,
                    limit=limit,
                    path_prefixes=path_prefixes,
                    exclude_path_prefixes=None,
                    exclude_like_patterns=None,
                    prefer_non_tests=True,
                ),
                label="lexical retrieval",
                rag_session_id=rag_session_id,
                include_spans=include_spans,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
                exc=exc,
                missing_prefix=missing_prefix,
            )
            if retry_result is not None:
                return retry_result
            return self._failure_result(missing_prefix, exc)

    def _retry_without_test_filter(
        self,
        *,
        operation: Callable[[], list[dict]],
        label: str,
        rag_session_id: str,
        include_spans: bool,
        exclude_tests: bool,
        path_prefixes: list[str] | None,
        exc: Exception,
        missing_prefix: str,
        layer: str | None = None,
    ) -> LayerRetrievalResult | None:
        """Log the original failure and, if the test filter was on, retry once without it.

        Args:
            operation: Zero-argument callable that re-runs the query with the
                test filter removed and returns the raw rows.
            label: Log label of the failed call; the retry is logged as
                ``"<label> retry"``.
            rag_session_id: Session identifier, used for logging only.
            include_spans: Forwarded to the row-to-item conversion.
            exclude_tests: Whether the failed attempt had the test filter on.
            path_prefixes: Path prefixes of the call, used for logging only.
            exc: The exception raised by the failed attempt.
            missing_prefix: Prefix of the marker appended to ``missing`` when
                the retry succeeds.
            layer: Layer name, used for logging only.

        Returns:
            The retry result with a ``:retried_without_test_filter`` marker in
            ``missing``, or ``None`` when no retry was attempted (the filter
            was already off) or the retry failed as well.
        """
        # A retry happens exactly when the failed attempt used the test filter.
        will_retry = bool(exclude_tests)
        self._log_failure(
            label=label,
            rag_session_id=rag_session_id,
            layer=layer,
            exclude_tests=exclude_tests,
            path_prefixes=path_prefixes,
            exc=exc,
            retried_without_test_filter=will_retry,
        )
        if not will_retry:
            return None

        retry_label = f"{label} retry"
        try:
            rows = operation()
        except Exception as retry_exc:
            self._log_failure(
                label=retry_label,
                rag_session_id=rag_session_id,
                layer=layer,
                exclude_tests=False,
                path_prefixes=path_prefixes,
                exc=retry_exc,
            )
            return None

        result = self._success_result(
            rows,
            rag_session_id=rag_session_id,
            label=retry_label,
            include_spans=include_spans,
            layer=layer,
            exclude_tests=False,
            path_prefixes=path_prefixes,
        )
        # Tell the caller that the result may include test files.
        result.missing.append(f"{missing_prefix}:retried_without_test_filter")
        return result

    def _success_result(
        self,
        rows: list[dict],
        *,
        rag_session_id: str,
        label: str,
        include_spans: bool,
        exclude_tests: bool,
        path_prefixes: list[str] | None,
        layer: str | None = None,
    ) -> LayerRetrievalResult:
        """Convert raw rows to items, log a summary, and wrap them in a result."""
        items = [self._to_item(row, include_spans=include_spans) for row in rows]
        # NOTE(review): a successful retrieval is logged at WARNING level;
        # kept as-is, confirm whether INFO/DEBUG was intended.
        LOGGER.warning(
            "%s: rag_session_id=%s layer=%s exclude_tests=%s path_prefixes=%s returned_count=%s top_paths=%s",
            label,
            rag_session_id,
            layer,
            exclude_tests,
            path_prefixes or [],
            len(items),
            [item.source for item in items[:3]],
        )
        return LayerRetrievalResult(items=items)

    def _failure_result(self, prefix: str, exc: Exception) -> LayerRetrievalResult:
        """Build an empty result whose ``missing`` entry describes ``exc``."""
        return LayerRetrievalResult(items=[], missing=[self._failure_missing(prefix, exc)])

    def _log_failure(
        self,
        *,
        label: str,
        rag_session_id: str,
        exclude_tests: bool,
        path_prefixes: list[str] | None,
        exc: Exception,
        layer: str | None = None,
        retried_without_test_filter: bool = False,
    ) -> None:
        """Log a failed retrieval attempt together with the traceback of ``exc``."""
        LOGGER.warning(
            "%s failed: rag_session_id=%s layer=%s exclude_tests=%s path_prefixes=%s retried_without_test_filter=%s error=%s",
            label,
            rag_session_id,
            layer,
            exclude_tests,
            path_prefixes or [],
            retried_without_test_filter,
            self._exception_summary(exc),
            # Bind the traceback to the exception being reported rather than
            # to whatever exception happens to be active at call time.
            exc_info=exc,
        )

    def _filter_args(self, exclude_tests: bool) -> dict[str, list[str] | None]:
        """Return the repository exclusion arguments; both are ``None`` when tests are kept."""
        if not exclude_tests:
            return {"exclude_path_prefixes": None, "exclude_like_patterns": None}
        test_filters = build_test_filters()
        return {
            "exclude_path_prefixes": test_filters.exclude_path_prefixes if test_filters else None,
            "exclude_like_patterns": test_filters.exclude_like_patterns if test_filters else None,
        }

    def _failure_missing(self, prefix: str, exc: Exception) -> str:
        """Format a ``missing`` marker as ``<prefix>:<exception summary>``."""
        return f"{prefix}:{self._exception_summary(exc)}"

    def _exception_summary(self, exc: Exception) -> str:
        """Return ``<ExceptionType>:<message>`` on a single line.

        Whitespace runs in the message are collapsed to single spaces, the
        message is capped at 180 characters (ending in ``...`` when cut), and
        an empty message is rendered as ``no_message``.
        """
        message = " ".join(str(exc).split())
        if len(message) > 180:
            message = message[:177] + "..."
        return f"{type(exc).__name__}:{message or 'no_message'}"

    def _to_item(self, row: dict, *, include_spans: bool) -> LayeredRetrievalItem:
        """Map a repository row (a dict) onto a ``LayeredRetrievalItem``.

        Missing or falsy text fields become empty strings; the row's
        ``distance`` value is passed through unchanged as the item's score.
        """
        path = str(row.get("path") or "")
        location = None
        if include_spans:
            location = CodeLocation(
                path=path,
                start_line=row.get("span_start"),
                end_line=row.get("span_end"),
            )
        return LayeredRetrievalItem(
            source=path,
            content=str(row.get("content") or ""),
            layer=str(row.get("layer") or ""),
            title=str(row.get("title") or ""),
            metadata=dict(row.get("metadata", {}) or {}),
            score=row.get("distance"),
            location=location,
        )