329 lines
12 KiB
Python
329 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import TYPE_CHECKING
|
|
|
|
from app.modules.rag.contracts.enums import RagLayer
|
|
from app.modules.rag.explain.intent_builder import ExplainIntentBuilder
|
|
from app.modules.rag.explain.layered_gateway import LayerRetrievalResult, LayeredRetrievalGateway
|
|
from app.modules.rag.explain.models import CodeExcerpt, EvidenceItem, ExplainPack, LayeredRetrievalItem
|
|
from app.modules.rag.explain.source_excerpt_fetcher import SourceExcerptFetcher
|
|
from app.modules.rag.explain.trace_builder import TraceBuilder
|
|
from app.modules.rag.retrieval.test_filter import exclude_tests_default, is_test_path
|
|
|
|
# Module-level logger for retrieval diagnostics emitted by the retriever below.
LOGGER = logging.getLogger(__name__)


# Minimum number of code excerpts a pack should carry before the retriever
# re-runs a lexical pass that also includes test files (see build_pack).
_MIN_EXCERPTS = 2


if TYPE_CHECKING:
    # Imported only for static type checking; kept out of the runtime import
    # graph (presumably to avoid an import cycle — TODO confirm).
    from app.modules.rag.explain.graph_repository import CodeGraphRepository
    from app.modules.rag.explain.models import ExplainIntent
|
|
|
|
|
|
class CodeExplainRetrieverV2:
    """Layered retrieval orchestrator that assembles an ``ExplainPack`` for a code-explain query.

    Pipeline per query: build an intent from the user text, retrieve
    entrypoints (layer C3), retrieve seed symbols (C1, merged with graph
    handlers of the selected entrypoints), expand seeds into trace paths via
    the graph, fetch source excerpts for those paths, and fall back to
    lexical code retrieval (C0) whenever the earlier layers produce nothing.
    """

    def __init__(
        self,
        gateway: LayeredRetrievalGateway,
        graph_repository: CodeGraphRepository,
        intent_builder: ExplainIntentBuilder | None = None,
        trace_builder: TraceBuilder | None = None,
        excerpt_fetcher: SourceExcerptFetcher | None = None,
    ) -> None:
        """Wire collaborators; optional ones default to graph-backed implementations.

        Args:
            gateway: Layered/lexical retrieval backend.
            graph_repository: Code-graph lookups (symbols by id, traces, excerpts).
            intent_builder: Optional query-to-intent parser; defaulted when omitted.
            trace_builder: Optional path expander; defaults to one over ``graph_repository``.
            excerpt_fetcher: Optional excerpt loader; defaults to one over ``graph_repository``.
        """
        self._gateway = gateway
        self._graph = graph_repository
        self._intent_builder = intent_builder or ExplainIntentBuilder()
        self._trace_builder = trace_builder or TraceBuilder(graph_repository)
        self._excerpt_fetcher = excerpt_fetcher or SourceExcerptFetcher(graph_repository)

    def build_pack(
        self,
        rag_session_id: str,
        user_query: str,
        *,
        file_candidates: list[dict] | None = None,
    ) -> ExplainPack:
        """Build the explain pack for ``user_query`` within one RAG session.

        Runs the main retrieval pass with tests excluded (when both the global
        default and the intent say so); if that pass yields fewer than
        ``_MIN_EXCERPTS`` excerpts, merges in a test-inclusive lexical
        fallback. Always logs a one-line pack summary before returning.

        Args:
            rag_session_id: Session scoping all retrieval and graph lookups.
            user_query: Raw user question to explain.
            file_candidates: Optional dicts with a ``"path"`` key used to seed
                path-prefix filters.
        """
        intent = self._intent_builder.build(user_query)
        path_prefixes = _path_prefixes(intent, file_candidates or [])
        # Tests are excluded only when the environment default says so AND the
        # intent did not explicitly ask to include them.
        exclude_tests = exclude_tests_default() and not intent.include_tests
        pack = self._run_pass(rag_session_id, intent, path_prefixes, exclude_tests=exclude_tests)
        if exclude_tests and len(pack.code_excerpts) < _MIN_EXCERPTS:
            self._merge_test_fallback(pack, rag_session_id, intent, path_prefixes)
        self._log_pack(rag_session_id, pack)
        return pack

    def _run_pass(
        self,
        rag_session_id: str,
        intent: ExplainIntent,
        path_prefixes: list[str],
        *,
        exclude_tests: bool,
    ) -> ExplainPack:
        """Execute one full retrieval pass and assemble an ``ExplainPack``.

        Collects ``missing`` markers from every layer; "layer:C3/C1/C0 empty"
        markers are appended when a layer returns nothing. C0 (lexical) runs
        only when trace-based excerpt fetching produced no excerpts.
        """
        missing: list[str] = []
        entrypoints_result = self._entrypoints(rag_session_id, intent, path_prefixes, exclude_tests=exclude_tests)
        missing.extend(entrypoints_result.missing)
        selected_entrypoints = self._filter_entrypoints(intent, entrypoints_result.items)
        if not selected_entrypoints:
            missing.append("layer:C3 empty")
        seed_result = self._seed_symbols(rag_session_id, intent, path_prefixes, selected_entrypoints, exclude_tests=exclude_tests)
        missing.extend(seed_result.missing)
        seed_symbols = seed_result.items
        if not seed_symbols:
            missing.append("layer:C1 empty")
        # Intent depth maps to graph traversal depth: deep=4, medium=3, else 2.
        depth = 4 if intent.depth == "deep" else 3 if intent.depth == "medium" else 2
        trace_paths = self._trace_builder.build_paths(rag_session_id, seed_symbols, max_depth=depth) if seed_symbols else []
        excerpts, excerpt_evidence = self._excerpt_fetcher.fetch(rag_session_id, trace_paths) if trace_paths else ([], {})
        if not excerpts:
            # Trace-based excerpts came up empty: fall back to lexical code search.
            lexical_result = self._gateway.retrieve_lexical_code(
                rag_session_id,
                intent.normalized_query,
                limit=6,
                path_prefixes=path_prefixes or None,
                exclude_tests=exclude_tests,
                include_spans=True,
            )
            missing.extend(lexical_result.missing)
            excerpts, excerpt_evidence = _lexical_excerpts(lexical_result.items)
            if not excerpts:
                missing.append("layer:C0 empty")
        evidence_index = _evidence_index(selected_entrypoints, seed_symbols)
        # Excerpt evidence is merged last so "excerpt_*" keys join the
        # entrypoint/symbol entries.
        evidence_index.update(excerpt_evidence)
        missing.extend(_missing(selected_entrypoints, seed_symbols, trace_paths, excerpts))
        return ExplainPack(
            intent=intent,
            selected_entrypoints=selected_entrypoints,
            seed_symbols=seed_symbols,
            trace_paths=trace_paths,
            evidence_index=evidence_index,
            code_excerpts=excerpts,
            missing=_cleanup_missing(_dedupe(missing), has_excerpts=bool(excerpts)),
            conflicts=[],
        )

    def _merge_test_fallback(
        self,
        pack: ExplainPack,
        rag_session_id: str,
        intent: ExplainIntent,
        path_prefixes: list[str],
    ) -> None:
        """Mutate ``pack`` in place with excerpts from a test-inclusive lexical pass.

        New excerpt evidence ids continue after the highest existing
        ``excerpt_*`` key; excerpts already present (same path/span/content)
        are skipped. ``pack.missing`` is re-deduplicated either way.
        """
        lexical_result = self._gateway.retrieve_lexical_code(
            rag_session_id,
            intent.normalized_query,
            limit=6,
            path_prefixes=path_prefixes or None,
            exclude_tests=False,
            include_spans=True,
        )
        # Continue excerpt numbering after the ids minted by the first pass.
        excerpt_offset = len([key for key in pack.evidence_index if key.startswith("excerpt_")])
        excerpts, evidence = _lexical_excerpts(
            lexical_result.items,
            start_index=excerpt_offset,
            is_test_fallback=True,
        )
        if not excerpts:
            pack.missing = _dedupe(pack.missing + lexical_result.missing)
            return
        # Dedupe against what the pack already holds, keyed on the full span.
        seen = {(item.path, item.start_line, item.end_line, item.content) for item in pack.code_excerpts}
        for excerpt in excerpts:
            key = (excerpt.path, excerpt.start_line, excerpt.end_line, excerpt.content)
            if key in seen:
                continue
            pack.code_excerpts.append(excerpt)
            seen.add(key)
        pack.evidence_index.update(evidence)
        pack.missing = _cleanup_missing(_dedupe(pack.missing + lexical_result.missing), has_excerpts=bool(pack.code_excerpts))

    def _entrypoints(
        self,
        rag_session_id: str,
        intent: ExplainIntent,
        path_prefixes: list[str],
        *,
        exclude_tests: bool,
    ) -> LayerRetrievalResult:
        """Retrieve up to 6 candidate entrypoints from the C3 layer."""
        return self._gateway.retrieve_layer(
            rag_session_id,
            intent.normalized_query,
            RagLayer.CODE_ENTRYPOINTS,
            limit=6,
            path_prefixes=path_prefixes or None,
            exclude_tests=exclude_tests,
            prefer_non_tests=True,
            include_spans=True,
        )

    def _filter_entrypoints(self, intent: ExplainIntent, items: list[LayeredRetrievalItem]) -> list[LayeredRetrievalItem]:
        """Keep at most 3 entrypoints, preferring those matching the intent's expected entry types.

        Falls back to the unfiltered top 3 when no item matches.
        """
        if not intent.expected_entry_types:
            return items[:3]
        filtered = [item for item in items if str(item.metadata.get("entry_type") or "") in intent.expected_entry_types]
        return filtered[:3] or items[:3]

    def _seed_symbols(
        self,
        rag_session_id: str,
        intent: ExplainIntent,
        path_prefixes: list[str],
        entrypoints: list[LayeredRetrievalItem],
        *,
        exclude_tests: bool,
    ) -> LayerRetrievalResult:
        """Retrieve seed symbols: entrypoint handlers first, then C1 catalog hits.

        Handler symbols resolved from the graph take priority over symbol-
        catalog results; duplicates (by ``symbol_id``) and items without a
        symbol id are dropped, and the result is capped at 8 seeds.
        """
        symbol_result = self._gateway.retrieve_layer(
            rag_session_id,
            intent.normalized_query,
            RagLayer.CODE_SYMBOL_CATALOG,
            limit=12,
            path_prefixes=path_prefixes or None,
            exclude_tests=exclude_tests,
            prefer_non_tests=True,
            include_spans=True,
        )
        handlers: list[LayeredRetrievalItem] = []
        handler_ids = [str(item.metadata.get("handler_symbol_id") or "") for item in entrypoints]
        if handler_ids:
            handlers = self._graph.get_symbols_by_ids(rag_session_id, [item for item in handler_ids if item])
        seeds: list[LayeredRetrievalItem] = []
        seen: set[str] = set()
        # Handlers come first so entrypoint-linked symbols win the dedupe.
        for item in handlers + symbol_result.items:
            symbol_id = str(item.metadata.get("symbol_id") or "")
            if not symbol_id or symbol_id in seen:
                continue
            seen.add(symbol_id)
            seeds.append(item)
            if len(seeds) >= 8:
                break
        return LayerRetrievalResult(items=seeds, missing=list(symbol_result.missing))

    def _log_pack(self, rag_session_id: str, pack: ExplainPack) -> None:
        """Emit a one-line summary of the assembled pack.

        NOTE(review): this logs at WARNING for a routine summary — presumably
        to guarantee visibility in production log levels; confirm intent, as
        INFO would be the conventional level for non-error diagnostics.
        """
        prod_excerpt_count = len([excerpt for excerpt in pack.code_excerpts if not _is_test_excerpt(excerpt)])
        test_excerpt_count = len(pack.code_excerpts) - prod_excerpt_count
        LOGGER.warning(
            "code explain pack: rag_session_id=%s entrypoints=%s seeds=%s paths=%s excerpts=%s prod_excerpt_count=%s test_excerpt_count=%s missing=%s",
            rag_session_id,
            len(pack.selected_entrypoints),
            len(pack.seed_symbols),
            len(pack.trace_paths),
            len(pack.code_excerpts),
            prod_excerpt_count,
            test_excerpt_count,
            pack.missing,
        )
|
|
|
|
|
|
def _evidence_index(
    entrypoints: list[LayeredRetrievalItem],
    seed_symbols: list[LayeredRetrievalItem],
) -> dict[str, EvidenceItem]:
    """Build the base evidence index for entrypoints and seed symbols.

    Entrypoints get ids ``entrypoint_1..n`` supporting their
    ``handler_symbol_id``; seed symbols get ``symbol_1..n`` supporting their
    ``symbol_id``.
    """
    index: dict[str, EvidenceItem] = {}
    groups = (
        ("entrypoint", "handler_symbol_id", entrypoints),
        ("symbol", "symbol_id", seed_symbols),
    )
    for kind, metadata_key, items in groups:
        for position, item in enumerate(items, start=1):
            evidence_id = f"{kind}_{position}"
            index[evidence_id] = EvidenceItem(
                evidence_id=evidence_id,
                kind=kind,
                summary=item.title,
                location=item.location,
                supports=[str(item.metadata.get(metadata_key) or "")],
            )
    return index
|
|
|
|
|
|
def _missing(
|
|
entrypoints: list[LayeredRetrievalItem],
|
|
seed_symbols: list[LayeredRetrievalItem],
|
|
trace_paths,
|
|
excerpts,
|
|
) -> list[str]:
|
|
missing: list[str] = []
|
|
if not entrypoints:
|
|
missing.append("entrypoints")
|
|
if not seed_symbols:
|
|
missing.append("seed_symbols")
|
|
if not trace_paths:
|
|
missing.append("trace_paths")
|
|
if not excerpts:
|
|
missing.append("code_excerpts")
|
|
return missing
|
|
|
|
|
|
def _lexical_excerpts(
    items: list[LayeredRetrievalItem],
    *,
    start_index: int = 0,
    is_test_fallback: bool = False,
) -> tuple[list[CodeExcerpt], dict[str, EvidenceItem]]:
    """Convert lexical retrieval items into code excerpts plus evidence entries.

    Each item yields one ``CodeExcerpt`` and one ``EvidenceItem`` sharing an
    ``excerpt_N`` id, with N continuing from ``start_index``.

    Args:
        items: Lexical retrieval hits to convert.
        start_index: Offset for excerpt numbering, so a fallback pass does not
            collide with ids minted by an earlier pass.
        is_test_fallback: Kept for interface compatibility; it no longer
            affects the output (the previous ``elif is_test_fallback`` branch
            reassigned the already-set default focus and was a no-op).

    Returns:
        ``(excerpts, evidence_index)`` in item order.
    """
    excerpts: list[CodeExcerpt] = []
    evidence_index: dict[str, EvidenceItem] = {}
    for number, item in enumerate(items, start=start_index + 1):
        evidence_id = f"excerpt_{number}"
        location = item.location
        evidence_index[evidence_id] = EvidenceItem(
            evidence_id=evidence_id,
            kind="excerpt",
            summary=item.title or item.source,
            location=location,
            supports=[],
        )
        # Hits that resolve to test files are tagged so downstream counting
        # (_is_test_excerpt) can separate them from production excerpts.
        focus = "test:lexical" if _item_is_test(item) else "lexical"
        excerpts.append(
            CodeExcerpt(
                evidence_id=evidence_id,
                symbol_id=str(item.metadata.get("symbol_id") or "") or None,
                title=item.title or item.source,
                path=item.source,
                start_line=location.start_line if location else None,
                end_line=location.end_line if location else None,
                content=item.content,
                focus=focus,
            )
        )
    return excerpts, evidence_index
|
|
|
|
|
|
def _item_is_test(item: LayeredRetrievalItem) -> bool:
    """Return True when the item is flagged as a test or lives under a test path."""
    if item.metadata.get("is_test"):
        return True
    return is_test_path(item.source)
|
|
|
|
|
|
def _is_test_excerpt(excerpt: CodeExcerpt) -> bool:
    """Return True when the excerpt is test-focused or drawn from a test path."""
    if excerpt.focus.startswith("test:"):
        return True
    return is_test_path(excerpt.path)
|
|
|
|
|
|
def _path_prefixes(intent: ExplainIntent, file_candidates: list[dict]) -> list[str]:
|
|
values: list[str] = []
|
|
for path in intent.hints.paths:
|
|
prefix = path.rsplit("/", 1)[0] if "/" in path else path
|
|
if prefix and prefix not in values:
|
|
values.append(prefix)
|
|
for item in file_candidates[:6]:
|
|
path = str(item.get("path") or "")
|
|
prefix = path.rsplit("/", 1)[0] if "/" in path else ""
|
|
if prefix and prefix not in values:
|
|
values.append(prefix)
|
|
return values
|
|
|
|
|
|
def _cleanup_missing(values: list[str], *, has_excerpts: bool) -> list[str]:
|
|
if not has_excerpts:
|
|
return values
|
|
return [value for value in values if value not in {"code_excerpts", "layer:C0 empty"}]
|
|
|
|
|
|
def _dedupe(values: list[str]) -> list[str]:
|
|
result: list[str] = []
|
|
for value in values:
|
|
item = value.strip()
|
|
if item and item not in result:
|
|
result.append(item)
|
|
return result
|