from __future__ import annotations import logging from dataclasses import dataclass, field from typing import TYPE_CHECKING, Callable from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem from app.modules.rag.retrieval.test_filter import build_test_filters, debug_disable_test_filter LOGGER = logging.getLogger(__name__) if TYPE_CHECKING: from app.modules.rag.persistence.repository import RagRepository from app.modules.rag_session.embedding.gigachat_embedder import GigaChatEmbedder @dataclass(slots=True) class LayerRetrievalResult: items: list[LayeredRetrievalItem] missing: list[str] = field(default_factory=list) class LayeredRetrievalGateway: def __init__(self, repository: RagRepository, embedder: GigaChatEmbedder) -> None: self._repository = repository self._embedder = embedder def retrieve_layer( self, rag_session_id: str, query: str, layer: str, *, limit: int, path_prefixes: list[str] | None = None, exclude_tests: bool = True, prefer_non_tests: bool = False, include_spans: bool = False, ) -> LayerRetrievalResult: effective_exclude_tests = exclude_tests and not debug_disable_test_filter() filter_args = self._filter_args(effective_exclude_tests) query_embedding: list[float] | None = None try: query_embedding = self._embedder.embed([query])[0] rows = self._repository.retrieve( rag_session_id, query_embedding, query_text=query, limit=limit, layers=[layer], path_prefixes=path_prefixes, exclude_path_prefixes=filter_args["exclude_path_prefixes"], exclude_like_patterns=filter_args["exclude_like_patterns"], prefer_non_tests=prefer_non_tests or not effective_exclude_tests, ) return self._success_result( rows, rag_session_id=rag_session_id, label="layered retrieval", include_spans=include_spans, layer=layer, exclude_tests=effective_exclude_tests, path_prefixes=path_prefixes, ) except Exception as exc: if query_embedding is None: self._log_failure( label="layered retrieval", rag_session_id=rag_session_id, layer=layer, exclude_tests=effective_exclude_tests, path_prefixes=path_prefixes, exc=exc, ) return LayerRetrievalResult(items=[], missing=[self._failure_missing(f"layer:{layer} retrieval_failed", exc)]) retry_result = self._retry_without_test_filter( operation=lambda: self._repository.retrieve( rag_session_id, query_embedding, query_text=query, limit=limit, layers=[layer], path_prefixes=path_prefixes, exclude_path_prefixes=None, exclude_like_patterns=None, prefer_non_tests=True, ), label="layered retrieval", rag_session_id=rag_session_id, include_spans=include_spans, layer=layer, exclude_tests=effective_exclude_tests, path_prefixes=path_prefixes, exc=exc, missing_prefix=f"layer:{layer} retrieval_failed", ) if retry_result is not None: return retry_result return LayerRetrievalResult(items=[], missing=[self._failure_missing(f"layer:{layer} retrieval_failed", exc)]) def retrieve_lexical_code( self, rag_session_id: str, query: str, *, limit: int, path_prefixes: list[str] | None = None, exclude_tests: bool = True, include_spans: bool = False, ) -> LayerRetrievalResult: effective_exclude_tests = exclude_tests and not debug_disable_test_filter() filter_args = self._filter_args(effective_exclude_tests) try: rows = self._repository.retrieve_lexical_code( rag_session_id, query_text=query, limit=limit, path_prefixes=path_prefixes, exclude_path_prefixes=filter_args["exclude_path_prefixes"], exclude_like_patterns=filter_args["exclude_like_patterns"], prefer_non_tests=not effective_exclude_tests, ) return self._success_result( rows, rag_session_id=rag_session_id, label="lexical retrieval", include_spans=include_spans, exclude_tests=effective_exclude_tests, path_prefixes=path_prefixes, ) except Exception as exc: retry_result = self._retry_without_test_filter( operation=lambda: self._repository.retrieve_lexical_code( rag_session_id, query_text=query, limit=limit, path_prefixes=path_prefixes, exclude_path_prefixes=None, exclude_like_patterns=None, prefer_non_tests=True, ), label="lexical retrieval", rag_session_id=rag_session_id, include_spans=include_spans, exclude_tests=effective_exclude_tests, path_prefixes=path_prefixes, exc=exc, missing_prefix="layer:C0 lexical_retrieval_failed", ) if retry_result is not None: return retry_result return LayerRetrievalResult(items=[], missing=[self._failure_missing("layer:C0 lexical_retrieval_failed", exc)]) def _retry_without_test_filter( self, *, operation: Callable[[], list[dict]], label: str, rag_session_id: str, include_spans: bool, exclude_tests: bool, path_prefixes: list[str] | None, exc: Exception, missing_prefix: str, layer: str | None = None, ) -> LayerRetrievalResult | None: if not exclude_tests: self._log_failure( label=label, rag_session_id=rag_session_id, layer=layer, exclude_tests=exclude_tests, path_prefixes=path_prefixes, exc=exc, ) return None self._log_failure( label=label, rag_session_id=rag_session_id, layer=layer, exclude_tests=exclude_tests, path_prefixes=path_prefixes, exc=exc, retried_without_test_filter=True, ) try: rows = operation() except Exception as retry_exc: self._log_failure( label=f"{label} retry", rag_session_id=rag_session_id, layer=layer, exclude_tests=False, path_prefixes=path_prefixes, exc=retry_exc, ) return None result = self._success_result( rows, rag_session_id=rag_session_id, label=f"{label} retry", include_spans=include_spans, layer=layer, exclude_tests=False, path_prefixes=path_prefixes, ) result.missing.append(f"{missing_prefix}:retried_without_test_filter") return result def _success_result( self, rows: list[dict], *, rag_session_id: str, label: str, include_spans: bool, exclude_tests: bool, path_prefixes: list[str] | None, layer: str | None = None, ) -> LayerRetrievalResult: items = [self._to_item(row, include_spans=include_spans) for row in rows] LOGGER.warning( "%s: rag_session_id=%s layer=%s exclude_tests=%s path_prefixes=%s returned_count=%s top_paths=%s", label, rag_session_id, layer, exclude_tests, path_prefixes or [], len(items), [item.source for item in items[:3]], ) return LayerRetrievalResult(items=items) def _log_failure( self, *, label: str, rag_session_id: str, exclude_tests: bool, path_prefixes: list[str] | None, exc: Exception, layer: str | None = None, retried_without_test_filter: bool = False, ) -> None: LOGGER.warning( "%s failed: rag_session_id=%s layer=%s exclude_tests=%s path_prefixes=%s retried_without_test_filter=%s error=%s", label, rag_session_id, layer, exclude_tests, path_prefixes or [], retried_without_test_filter, self._exception_summary(exc), exc_info=True, ) def _filter_args(self, exclude_tests: bool) -> dict[str, list[str] | None]: test_filters = build_test_filters() if exclude_tests else None return { "exclude_path_prefixes": test_filters.exclude_path_prefixes if test_filters else None, "exclude_like_patterns": test_filters.exclude_like_patterns if test_filters else None, } def _failure_missing(self, prefix: str, exc: Exception) -> str: return f"{prefix}:{self._exception_summary(exc)}" def _exception_summary(self, exc: Exception) -> str: message = " ".join(str(exc).split()) if len(message) > 180: message = message[:177] + "..." return f"{type(exc).__name__}:{message or 'no_message'}" def _to_item(self, row: dict, *, include_spans: bool) -> LayeredRetrievalItem: location = None if include_spans: location = CodeLocation( path=str(row.get("path") or ""), start_line=row.get("span_start"), end_line=row.get("span_end"), ) return LayeredRetrievalItem( source=str(row.get("path") or ""), content=str(row.get("content") or ""), layer=str(row.get("layer") or ""), title=str(row.get("title") or ""), metadata=dict(row.get("metadata", {}) or {}), score=row.get("distance"), location=location, )