from __future__ import annotations import logging from typing import TYPE_CHECKING from app.modules.rag.contracts.enums import RagLayer from app.modules.rag.explain.intent_builder import ExplainIntentBuilder from app.modules.rag.explain.layered_gateway import LayerRetrievalResult, LayeredRetrievalGateway from app.modules.rag.explain.models import CodeExcerpt, EvidenceItem, ExplainPack, LayeredRetrievalItem from app.modules.rag.explain.source_excerpt_fetcher import SourceExcerptFetcher from app.modules.rag.explain.trace_builder import TraceBuilder from app.modules.rag.retrieval.test_filter import exclude_tests_default, is_test_path LOGGER = logging.getLogger(__name__) _MIN_EXCERPTS = 2 if TYPE_CHECKING: from app.modules.rag.explain.graph_repository import CodeGraphRepository from app.modules.rag.explain.models import ExplainIntent class CodeExplainRetrieverV2: def __init__( self, gateway: LayeredRetrievalGateway, graph_repository: CodeGraphRepository, intent_builder: ExplainIntentBuilder | None = None, trace_builder: TraceBuilder | None = None, excerpt_fetcher: SourceExcerptFetcher | None = None, ) -> None: self._gateway = gateway self._graph = graph_repository self._intent_builder = intent_builder or ExplainIntentBuilder() self._trace_builder = trace_builder or TraceBuilder(graph_repository) self._excerpt_fetcher = excerpt_fetcher or SourceExcerptFetcher(graph_repository) def build_pack( self, rag_session_id: str, user_query: str, *, file_candidates: list[dict] | None = None, ) -> ExplainPack: intent = self._intent_builder.build(user_query) path_prefixes = _path_prefixes(intent, file_candidates or []) exclude_tests = exclude_tests_default() and not intent.include_tests pack = self._run_pass(rag_session_id, intent, path_prefixes, exclude_tests=exclude_tests) if exclude_tests and len(pack.code_excerpts) < _MIN_EXCERPTS: self._merge_test_fallback(pack, rag_session_id, intent, path_prefixes) self._log_pack(rag_session_id, pack) return pack def _run_pass( self, rag_session_id: str, intent: ExplainIntent, path_prefixes: list[str], *, exclude_tests: bool, ) -> ExplainPack: missing: list[str] = [] entrypoints_result = self._entrypoints(rag_session_id, intent, path_prefixes, exclude_tests=exclude_tests) missing.extend(entrypoints_result.missing) selected_entrypoints = self._filter_entrypoints(intent, entrypoints_result.items) if not selected_entrypoints: missing.append("layer:C3 empty") seed_result = self._seed_symbols(rag_session_id, intent, path_prefixes, selected_entrypoints, exclude_tests=exclude_tests) missing.extend(seed_result.missing) seed_symbols = seed_result.items if not seed_symbols: missing.append("layer:C1 empty") depth = 4 if intent.depth == "deep" else 3 if intent.depth == "medium" else 2 trace_paths = self._trace_builder.build_paths(rag_session_id, seed_symbols, max_depth=depth) if seed_symbols else [] excerpts, excerpt_evidence = self._excerpt_fetcher.fetch(rag_session_id, trace_paths) if trace_paths else ([], {}) if not excerpts: lexical_result = self._gateway.retrieve_lexical_code( rag_session_id, intent.normalized_query, limit=6, path_prefixes=path_prefixes or None, exclude_tests=exclude_tests, include_spans=True, ) missing.extend(lexical_result.missing) excerpts, excerpt_evidence = _lexical_excerpts(lexical_result.items) if not excerpts: missing.append("layer:C0 empty") evidence_index = _evidence_index(selected_entrypoints, seed_symbols) evidence_index.update(excerpt_evidence) missing.extend(_missing(selected_entrypoints, seed_symbols, trace_paths, excerpts)) return ExplainPack( intent=intent, selected_entrypoints=selected_entrypoints, seed_symbols=seed_symbols, trace_paths=trace_paths, evidence_index=evidence_index, code_excerpts=excerpts, missing=_cleanup_missing(_dedupe(missing), has_excerpts=bool(excerpts)), conflicts=[], ) def _merge_test_fallback( self, pack: ExplainPack, rag_session_id: str, intent: ExplainIntent, path_prefixes: list[str], ) -> None: lexical_result = self._gateway.retrieve_lexical_code( rag_session_id, intent.normalized_query, limit=6, path_prefixes=path_prefixes or None, exclude_tests=False, include_spans=True, ) excerpt_offset = len([key for key in pack.evidence_index if key.startswith("excerpt_")]) excerpts, evidence = _lexical_excerpts( lexical_result.items, start_index=excerpt_offset, is_test_fallback=True, ) if not excerpts: pack.missing = _dedupe(pack.missing + lexical_result.missing) return seen = {(item.path, item.start_line, item.end_line, item.content) for item in pack.code_excerpts} for excerpt in excerpts: key = (excerpt.path, excerpt.start_line, excerpt.end_line, excerpt.content) if key in seen: continue pack.code_excerpts.append(excerpt) seen.add(key) pack.evidence_index.update(evidence) pack.missing = _cleanup_missing(_dedupe(pack.missing + lexical_result.missing), has_excerpts=bool(pack.code_excerpts)) def _entrypoints( self, rag_session_id: str, intent: ExplainIntent, path_prefixes: list[str], *, exclude_tests: bool, ) -> LayerRetrievalResult: return self._gateway.retrieve_layer( rag_session_id, intent.normalized_query, RagLayer.CODE_ENTRYPOINTS, limit=6, path_prefixes=path_prefixes or None, exclude_tests=exclude_tests, prefer_non_tests=True, include_spans=True, ) def _filter_entrypoints(self, intent: ExplainIntent, items: list[LayeredRetrievalItem]) -> list[LayeredRetrievalItem]: if not intent.expected_entry_types: return items[:3] filtered = [item for item in items if str(item.metadata.get("entry_type") or "") in intent.expected_entry_types] return filtered[:3] or items[:3] def _seed_symbols( self, rag_session_id: str, intent: ExplainIntent, path_prefixes: list[str], entrypoints: list[LayeredRetrievalItem], *, exclude_tests: bool, ) -> LayerRetrievalResult: symbol_result = self._gateway.retrieve_layer( rag_session_id, intent.normalized_query, RagLayer.CODE_SYMBOL_CATALOG, limit=12, path_prefixes=path_prefixes or None, exclude_tests=exclude_tests, prefer_non_tests=True, include_spans=True, ) handlers: list[LayeredRetrievalItem] = [] handler_ids = [str(item.metadata.get("handler_symbol_id") or "") for item in entrypoints] if handler_ids: handlers = self._graph.get_symbols_by_ids(rag_session_id, [item for item in handler_ids if item]) seeds: list[LayeredRetrievalItem] = [] seen: set[str] = set() for item in handlers + symbol_result.items: symbol_id = str(item.metadata.get("symbol_id") or "") if not symbol_id or symbol_id in seen: continue seen.add(symbol_id) seeds.append(item) if len(seeds) >= 8: break return LayerRetrievalResult(items=seeds, missing=list(symbol_result.missing)) def _log_pack(self, rag_session_id: str, pack: ExplainPack) -> None: prod_excerpt_count = len([excerpt for excerpt in pack.code_excerpts if not _is_test_excerpt(excerpt)]) test_excerpt_count = len(pack.code_excerpts) - prod_excerpt_count LOGGER.warning( "code explain pack: rag_session_id=%s entrypoints=%s seeds=%s paths=%s excerpts=%s prod_excerpt_count=%s test_excerpt_count=%s missing=%s", rag_session_id, len(pack.selected_entrypoints), len(pack.seed_symbols), len(pack.trace_paths), len(pack.code_excerpts), prod_excerpt_count, test_excerpt_count, pack.missing, ) def _evidence_index( entrypoints: list[LayeredRetrievalItem], seed_symbols: list[LayeredRetrievalItem], ) -> dict[str, EvidenceItem]: result: dict[str, EvidenceItem] = {} for index, item in enumerate(entrypoints, start=1): evidence_id = f"entrypoint_{index}" result[evidence_id] = EvidenceItem( evidence_id=evidence_id, kind="entrypoint", summary=item.title, location=item.location, supports=[str(item.metadata.get("handler_symbol_id") or "")], ) for index, item in enumerate(seed_symbols, start=1): evidence_id = f"symbol_{index}" result[evidence_id] = EvidenceItem( evidence_id=evidence_id, kind="symbol", summary=item.title, location=item.location, supports=[str(item.metadata.get("symbol_id") or "")], ) return result def _missing( entrypoints: list[LayeredRetrievalItem], seed_symbols: list[LayeredRetrievalItem], trace_paths, excerpts, ) -> list[str]: missing: list[str] = [] if not entrypoints: missing.append("entrypoints") if not seed_symbols: missing.append("seed_symbols") if not trace_paths: missing.append("trace_paths") if not excerpts: missing.append("code_excerpts") return missing def _lexical_excerpts( items: list[LayeredRetrievalItem], *, start_index: int = 0, is_test_fallback: bool = False, ) -> tuple[list[CodeExcerpt], dict[str, EvidenceItem]]: excerpts: list[CodeExcerpt] = [] evidence_index: dict[str, EvidenceItem] = {} for item in items: evidence_id = f"excerpt_{start_index + len(evidence_index) + 1}" location = item.location evidence_index[evidence_id] = EvidenceItem( evidence_id=evidence_id, kind="excerpt", summary=item.title or item.source, location=location, supports=[], ) focus = "lexical" if _item_is_test(item): focus = "test:lexical" elif is_test_fallback: focus = "lexical" excerpts.append( CodeExcerpt( evidence_id=evidence_id, symbol_id=str(item.metadata.get("symbol_id") or "") or None, title=item.title or item.source, path=item.source, start_line=location.start_line if location else None, end_line=location.end_line if location else None, content=item.content, focus=focus, ) ) return excerpts, evidence_index def _item_is_test(item: LayeredRetrievalItem) -> bool: return bool(item.metadata.get("is_test")) or is_test_path(item.source) def _is_test_excerpt(excerpt: CodeExcerpt) -> bool: return excerpt.focus.startswith("test:") or is_test_path(excerpt.path) def _path_prefixes(intent: ExplainIntent, file_candidates: list[dict]) -> list[str]: values: list[str] = [] for path in intent.hints.paths: prefix = path.rsplit("/", 1)[0] if "/" in path else path if prefix and prefix not in values: values.append(prefix) for item in file_candidates[:6]: path = str(item.get("path") or "") prefix = path.rsplit("/", 1)[0] if "/" in path else "" if prefix and prefix not in values: values.append(prefix) return values def _cleanup_missing(values: list[str], *, has_excerpts: bool) -> list[str]: if not has_excerpts: return values return [value for value in values if value not in {"code_excerpts", "layer:C0 empty"}] def _dedupe(values: list[str]) -> list[str]: result: list[str] = [] for value in values: item = value.strip() if item and item not in result: result.append(item) return result