290 lines
10 KiB
Python
290 lines
10 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from typing import TYPE_CHECKING, Callable
|
|
|
|
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
|
|
from app.modules.rag.retrieval.test_filter import build_test_filters, debug_disable_test_filter
|
|
|
|
# Imports needed only by static type checkers (annotations are lazy strings via
# `from __future__ import annotations`); they are never executed at runtime.
if TYPE_CHECKING:
    from app.modules.rag.persistence.repository import RagRepository
    from app.modules.rag_session.embedding.gigachat_embedder import GigaChatEmbedder

# Module-level logger named after this module.
LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class LayerRetrievalResult:
|
|
items: list[LayeredRetrievalItem]
|
|
missing: list[str] = field(default_factory=list)
|
|
|
|
|
|
class LayeredRetrievalGateway:
    """Run per-layer and lexical retrieval against a RAG repository.

    Failures of the embedding step or of the repository call inside the public
    entry points are reported through ``LayerRetrievalResult.missing`` rather
    than raised. When a repository call fails while test-path filters were
    active, it is retried once with those filters removed (and
    ``prefer_non_tests=True``); a successful retry is flagged in ``missing``.
    """

    def __init__(self, repository: RagRepository, embedder: GigaChatEmbedder) -> None:
        """Store the repository used for queries and the embedder used for query vectors."""
        self._repository = repository
        self._embedder = embedder

    def retrieve_layer(
        self,
        rag_session_id: str,
        query: str,
        layer: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        prefer_non_tests: bool = False,
        include_spans: bool = False,
    ) -> LayerRetrievalResult:
        """Embed ``query`` and retrieve up to ``limit`` rows from a single ``layer``.

        Args:
            rag_session_id: Session whose indexed content is searched.
            query: Query text; it is embedded and also passed as ``query_text``.
            layer: The only layer the repository is asked to search.
            limit: Maximum number of rows requested from the repository.
            path_prefixes: Optional path prefixes forwarded to the repository.
            exclude_tests: Apply the filters from ``build_test_filters()`` unless
                ``debug_disable_test_filter()`` turns them off.
            prefer_non_tests: Forwarded to the repository; forced on whenever
                the test filters are not applied.
            include_spans: Attach a ``CodeLocation`` built from each row's
                ``span_start``/``span_end`` columns.

        Returns:
            The retrieved items. If embedding fails, or retrieval fails and the
            unfiltered retry is not attempted or also fails, ``items`` is empty
            and ``missing`` holds ``"layer:<layer> retrieval_failed:<summary>"``.
        """
        effective_exclude_tests = exclude_tests and not debug_disable_test_filter()
        filter_args = self._filter_args(effective_exclude_tests)
        label = "layered retrieval"
        missing_prefix = f"layer:{layer} retrieval_failed"

        try:
            query_embedding = self._embedder.embed([query])[0]
        except Exception as exc:
            # Without an embedding there is nothing to retry: report and stop.
            self._log_failure(
                label=label,
                rag_session_id=rag_session_id,
                layer=layer,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
                exc=exc,
            )
            return self._failed_result(missing_prefix, exc)

        def fetch(
            *,
            exclude_path_prefixes: list[str] | None,
            exclude_like_patterns: list[str] | None,
            prefer: bool,
        ) -> list[dict]:
            # Single definition of the repository call, shared by the primary
            # attempt and the unfiltered retry so the two cannot drift apart.
            return self._repository.retrieve(
                rag_session_id,
                query_embedding,
                query_text=query,
                limit=limit,
                layers=[layer],
                path_prefixes=path_prefixes,
                exclude_path_prefixes=exclude_path_prefixes,
                exclude_like_patterns=exclude_like_patterns,
                prefer_non_tests=prefer,
            )

        try:
            rows = fetch(
                exclude_path_prefixes=filter_args["exclude_path_prefixes"],
                exclude_like_patterns=filter_args["exclude_like_patterns"],
                # When test paths are not filtered out, ask the repository to
                # favour non-test rows instead (presumably a ranking hint).
                prefer=prefer_non_tests or not effective_exclude_tests,
            )
            return self._success_result(
                rows,
                rag_session_id=rag_session_id,
                label=label,
                include_spans=include_spans,
                layer=layer,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
            )
        except Exception as exc:
            retry_result = self._retry_without_test_filter(
                operation=lambda: fetch(
                    exclude_path_prefixes=None,
                    exclude_like_patterns=None,
                    prefer=True,
                ),
                label=label,
                rag_session_id=rag_session_id,
                include_spans=include_spans,
                layer=layer,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
                exc=exc,
                missing_prefix=missing_prefix,
            )
            if retry_result is not None:
                return retry_result
            return self._failed_result(missing_prefix, exc)

    def retrieve_lexical_code(
        self,
        rag_session_id: str,
        query: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        include_spans: bool = False,
    ) -> LayerRetrievalResult:
        """Run the repository's lexical code search for ``query``.

        Args:
            rag_session_id: Session whose indexed content is searched.
            query: Query text passed as ``query_text`` (no embedding is computed).
            limit: Maximum number of rows requested from the repository.
            path_prefixes: Optional path prefixes forwarded to the repository.
            exclude_tests: Apply the filters from ``build_test_filters()`` unless
                ``debug_disable_test_filter()`` turns them off.
            include_spans: Attach a ``CodeLocation`` to each item.

        Returns:
            The retrieved items. On failure (after the optional unfiltered
            retry), ``items`` is empty and ``missing`` holds
            ``"layer:C0 lexical_retrieval_failed:<summary>"``.
        """
        effective_exclude_tests = exclude_tests and not debug_disable_test_filter()
        filter_args = self._filter_args(effective_exclude_tests)
        label = "lexical retrieval"
        missing_prefix = "layer:C0 lexical_retrieval_failed"

        def fetch(
            *,
            exclude_path_prefixes: list[str] | None,
            exclude_like_patterns: list[str] | None,
            prefer: bool,
        ) -> list[dict]:
            # Shared by the primary attempt and the unfiltered retry.
            return self._repository.retrieve_lexical_code(
                rag_session_id,
                query_text=query,
                limit=limit,
                path_prefixes=path_prefixes,
                exclude_path_prefixes=exclude_path_prefixes,
                exclude_like_patterns=exclude_like_patterns,
                prefer_non_tests=prefer,
            )

        try:
            rows = fetch(
                exclude_path_prefixes=filter_args["exclude_path_prefixes"],
                exclude_like_patterns=filter_args["exclude_like_patterns"],
                prefer=not effective_exclude_tests,
            )
            return self._success_result(
                rows,
                rag_session_id=rag_session_id,
                label=label,
                include_spans=include_spans,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
            )
        except Exception as exc:
            retry_result = self._retry_without_test_filter(
                operation=lambda: fetch(
                    exclude_path_prefixes=None,
                    exclude_like_patterns=None,
                    prefer=True,
                ),
                label=label,
                rag_session_id=rag_session_id,
                include_spans=include_spans,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
                exc=exc,
                missing_prefix=missing_prefix,
            )
            if retry_result is not None:
                return retry_result
            return self._failed_result(missing_prefix, exc)

    def _retry_without_test_filter(
        self,
        *,
        operation: Callable[[], list[dict]],
        label: str,
        rag_session_id: str,
        include_spans: bool,
        exclude_tests: bool,
        path_prefixes: list[str] | None,
        exc: Exception,
        missing_prefix: str,
        layer: str | None = None,
    ) -> LayerRetrievalResult | None:
        """Log the failure ``exc`` and, if test filters were active, retry once without them.

        ``operation`` must perform the same repository call with the test
        filters removed.

        Returns:
            The retry's result, with ``"<missing_prefix>:retried_without_test_filter"``
            appended to ``missing``. Returns ``None`` when no retry was attempted
            (``exclude_tests`` is false) or when the retry also raised.
        """
        # A retry only makes sense if there was a filter to drop.
        self._log_failure(
            label=label,
            rag_session_id=rag_session_id,
            layer=layer,
            exclude_tests=exclude_tests,
            path_prefixes=path_prefixes,
            exc=exc,
            retried_without_test_filter=exclude_tests,
        )
        if not exclude_tests:
            return None

        retry_label = f"{label} retry"
        try:
            rows = operation()
        except Exception as retry_exc:
            self._log_failure(
                label=retry_label,
                rag_session_id=rag_session_id,
                layer=layer,
                exclude_tests=False,
                path_prefixes=path_prefixes,
                exc=retry_exc,
            )
            return None

        result = self._success_result(
            rows,
            rag_session_id=rag_session_id,
            label=retry_label,
            include_spans=include_spans,
            layer=layer,
            exclude_tests=False,
            path_prefixes=path_prefixes,
        )
        result.missing.append(f"{missing_prefix}:retried_without_test_filter")
        return result

    def _success_result(
        self,
        rows: list[dict],
        *,
        rag_session_id: str,
        label: str,
        include_spans: bool,
        exclude_tests: bool,
        path_prefixes: list[str] | None,
        layer: str | None = None,
    ) -> LayerRetrievalResult:
        """Convert repository rows to items and log a summary (count, first three sources)."""
        items = [self._to_item(row, include_spans=include_spans) for row in rows]
        # NOTE(review): successful retrievals are logged at WARNING level; if this
        # is only for debugging visibility, INFO/DEBUG may be more appropriate.
        LOGGER.warning(
            "%s: rag_session_id=%s layer=%s exclude_tests=%s path_prefixes=%s returned_count=%s top_paths=%s",
            label,
            rag_session_id,
            layer,
            exclude_tests,
            path_prefixes or [],
            len(items),
            [item.source for item in items[:3]],
        )
        return LayerRetrievalResult(items=items)

    def _failed_result(self, prefix: str, exc: Exception) -> LayerRetrievalResult:
        """Build an empty result whose only ``missing`` entry summarizes ``exc``."""
        return LayerRetrievalResult(items=[], missing=[self._failure_missing(prefix, exc)])

    def _log_failure(
        self,
        *,
        label: str,
        rag_session_id: str,
        exclude_tests: bool,
        path_prefixes: list[str] | None,
        exc: Exception,
        layer: str | None = None,
        retried_without_test_filter: bool = False,
    ) -> None:
        """Log a failed retrieval step at WARNING level, including ``exc``'s traceback."""
        LOGGER.warning(
            "%s failed: rag_session_id=%s layer=%s exclude_tests=%s path_prefixes=%s retried_without_test_filter=%s error=%s",
            label,
            rag_session_id,
            layer,
            exclude_tests,
            path_prefixes or [],
            retried_without_test_filter,
            self._exception_summary(exc),
            # Pass the exception explicitly instead of `True`, so the logged
            # traceback is always the one for `exc` rather than whatever
            # exception happens to be in flight when this is called.
            exc_info=exc,
        )

    def _filter_args(self, exclude_tests: bool) -> dict[str, list[str] | None]:
        """Return the test-exclusion repository kwargs; both are ``None`` when exclusion is off."""
        test_filters = build_test_filters() if exclude_tests else None
        return {
            "exclude_path_prefixes": test_filters.exclude_path_prefixes if test_filters else None,
            "exclude_like_patterns": test_filters.exclude_like_patterns if test_filters else None,
        }

    def _failure_missing(self, prefix: str, exc: Exception) -> str:
        """Format a ``missing`` entry as ``"<prefix>:<ExceptionType>:<message>"``."""
        return f"{prefix}:{self._exception_summary(exc)}"

    def _exception_summary(self, exc: Exception) -> str:
        """Return ``"<ExceptionType>:<message>"`` with whitespace collapsed and length capped.

        Messages longer than 180 characters are cut to 177 characters plus
        ``"..."``; an empty message becomes ``"no_message"``.
        """
        message = " ".join(str(exc).split())
        if len(message) > 180:
            message = message[:177] + "..."
        return f"{type(exc).__name__}:{message or 'no_message'}"

    def _to_item(self, row: dict, *, include_spans: bool) -> LayeredRetrievalItem:
        """Map one repository row to a ``LayeredRetrievalItem``.

        Missing or falsy text columns become ``""``; ``score`` is taken from the
        row's ``distance`` column unchanged.
        """
        path = str(row.get("path") or "")
        location = (
            CodeLocation(
                path=path,
                start_line=row.get("span_start"),
                end_line=row.get("span_end"),
            )
            if include_spans
            else None
        )
        return LayeredRetrievalItem(
            source=path,
            content=str(row.get("content") or ""),
            layer=str(row.get("layer") or ""),
            title=str(row.get("title") or ""),
            metadata=dict(row.get("metadata", {}) or {}),
            score=row.get("distance"),
            location=location,
        )
|