Фиксация изменений
This commit is contained in:
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib import import_module
|
||||
|
||||
# Public names of this package. Each entry is resolved on first access by the
# module-level ``__getattr__`` below, so importing the package itself does not
# import the submodules that define these names.
__all__ = [
    "CodeExcerpt",
    "CodeExplainRetrieverV2",
    "CodeGraphRepository",
    "EvidenceItem",
    "ExplainIntent",
    "ExplainIntentBuilder",
    "ExplainPack",
    "LayeredRetrievalGateway",
    "PromptBudgeter",
    "TracePath",
]
|
||||
|
||||
|
||||
# Maps every lazily exported name to the module that defines it. Built once at
# import time instead of on every attribute lookup.
_LAZY_EXPORTS = {
    "CodeExcerpt": "app.modules.rag.explain.models",
    "EvidenceItem": "app.modules.rag.explain.models",
    "ExplainIntent": "app.modules.rag.explain.models",
    "ExplainPack": "app.modules.rag.explain.models",
    "TracePath": "app.modules.rag.explain.models",
    "ExplainIntentBuilder": "app.modules.rag.explain.intent_builder",
    "PromptBudgeter": "app.modules.rag.explain.budgeter",
    "LayeredRetrievalGateway": "app.modules.rag.explain.layered_gateway",
    "CodeGraphRepository": "app.modules.rag.explain.graph_repository",
    "CodeExplainRetrieverV2": "app.modules.rag.explain.retriever_v2",
}


def __getattr__(name: str):
    """Resolve a lazily exported package attribute (PEP 562 hook).

    Args:
        name: Attribute requested on this package.

    Returns:
        The object called ``name`` from the submodule that defines it.

    Raises:
        AttributeError: If ``name`` is not one of the lazy exports.
    """
    module_name = _LAZY_EXPORTS.get(name)
    if module_name is None:
        # Same wording as the interpreter's own message so tracebacks and
        # ``hasattr`` based probes look familiar.
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    value = getattr(import_module(module_name), name)
    # Cache on the module so later lookups bypass this hook entirely.
    globals()[name] = value
    return value
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,62 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from app.modules.rag.explain.models import ExplainPack
|
||||
|
||||
|
||||
class PromptBudgeter:
    """Serialises an explain pack into a JSON prompt payload under size limits.

    The limits bound how many trace paths, symbols and code excerpts are
    emitted, and how many excerpt characters are spent in total.
    """

    def __init__(
        self,
        *,
        max_paths: int = 3,
        max_symbols: int = 25,
        max_excerpts: int = 40,
        max_chars: int = 30000,
    ) -> None:
        """Store the budget limits used by ``build_prompt_input``."""
        self._max_paths, self._max_symbols = max_paths, max_symbols
        self._max_excerpts, self._max_chars = max_excerpts, max_chars

    def build_prompt_input(self, question: str, pack: ExplainPack) -> str:
        """Return the prompt payload for ``pack`` as an indented JSON string.

        Args:
            question: The user question, copied verbatim into the payload.
            pack: Explain pack whose parts are trimmed to the configured limits.

        Returns:
            A JSON document (non-ASCII preserved, 2-space indent).
        """
        leading_paths = pack.trace_paths[: self._max_paths]

        # Symbol ids reachable through the retained paths, in first-seen order.
        allowed_ids: list[str] = []
        for trace in leading_paths:
            for candidate in trace.symbol_ids:
                if not candidate or candidate in allowed_ids:
                    continue
                if len(allowed_ids) < self._max_symbols:
                    allowed_ids.append(candidate)

        chosen: list[dict] = []
        spent = 0
        for item in pack.code_excerpts:
            # Excerpts bound to a symbol outside the retained paths are dropped;
            # excerpts without a symbol id always pass this filter.
            if allowed_ids and item.symbol_id and item.symbol_id not in allowed_ids:
                continue
            snippet = item.content.strip()
            budget_left = self._max_chars - spent
            if budget_left <= 0 or len(chosen) >= self._max_excerpts:
                break
            if len(snippet) > budget_left:
                # The marker is appended after cutting, so it is counted
                # against the budget only on the next iteration.
                snippet = snippet[:budget_left].rstrip() + "...[truncated]"
            chosen.append(
                dict(
                    evidence_id=item.evidence_id,
                    title=item.title,
                    path=item.path,
                    start_line=item.start_line,
                    end_line=item.end_line,
                    focus=item.focus,
                    content=snippet,
                )
            )
            spent += len(snippet)

        document = {
            "question": question,
            "intent": pack.intent.model_dump(mode="json"),
            "selected_entrypoints": [
                entry.model_dump(mode="json") for entry in pack.selected_entrypoints[:5]
            ],
            "seed_symbols": [
                seed.model_dump(mode="json") for seed in pack.seed_symbols[: self._max_symbols]
            ],
            "trace_paths": [trace.model_dump(mode="json") for trace in leading_paths],
            "evidence_index": {
                evidence_id: evidence.model_dump(mode="json")
                for evidence_id, evidence in pack.evidence_index.items()
            },
            "code_excerpts": chosen,
            "missing": pack.missing,
            "conflicts": pack.conflicts,
        }
        return json.dumps(document, ensure_ascii=False, indent=2)
|
||||
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.modules.rag.explain.models import CodeExcerpt, LayeredRetrievalItem
|
||||
|
||||
|
||||
class ExcerptPlanner:
    """Turns a retrieved chunk into an overview excerpt plus an optional focus excerpt."""

    # A line containing any of these substrings (case-insensitive) anchors the
    # focus excerpt.
    _FOCUS_TOKENS = ("raise", "except", "db", "select", "insert", "update", "delete", "http", "publish", "emit")

    def plan(self, chunk: LayeredRetrievalItem, *, evidence_id: str, symbol_id: str | None) -> list[CodeExcerpt]:
        """Return the excerpts for ``chunk``; empty when the chunk has no location."""
        where = chunk.location
        if where is None:
            return []
        overview = CodeExcerpt(
            evidence_id=evidence_id,
            symbol_id=symbol_id,
            title=chunk.title,
            path=where.path,
            start_line=where.start_line,
            end_line=where.end_line,
            content=chunk.content.strip(),
            focus="overview",
        )
        detail = self._focus_excerpt(chunk, evidence_id=evidence_id, symbol_id=symbol_id)
        return [overview] if detail is None else [overview, detail]

    def _focus_excerpt(
        self,
        chunk: LayeredRetrievalItem,
        *,
        evidence_id: str,
        symbol_id: str | None,
    ) -> CodeExcerpt | None:
        """Build a small window around the first line mentioning a focus token.

        Returns ``None`` when the chunk has no location, no line matches, or
        the window would cover the whole chunk (the overview already does).
        """
        where = chunk.location
        if where is None:
            return None
        rows = chunk.content.splitlines()
        hit = next(
            (
                position
                for position, row in enumerate(rows)
                if any(token in row.lower() for token in self._FOCUS_TOKENS)
            ),
            None,
        )
        if hit is None:
            return None
        # Two lines of context before and after the matching line.
        first = max(0, hit - 2)
        stop = min(len(rows), hit + 3)
        if stop - first >= len(rows):
            return None
        base_line = where.start_line or 1
        return CodeExcerpt(
            evidence_id=evidence_id,
            symbol_id=symbol_id,
            title=f"{chunk.title}:focus",
            path=where.path,
            start_line=base_line + first,
            end_line=base_line + stop - 1,
            content="\n".join(rows[first:stop]).strip(),
            focus="focus",
        )
|
||||
@@ -0,0 +1,216 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
|
||||
from app.modules.shared.db import get_engine
|
||||
|
||||
|
||||
class CodeGraphRepository:
    """Read-only queries over ``rag_chunks`` for the symbol, graph and source layers.

    Every method opens its own connection from ``get_engine()`` and converts
    the resulting rows into ``LayeredRetrievalItem`` objects.
    """

    # Escape character for LIKE patterns built from caller-supplied references.
    _LIKE_ESCAPE = "!"

    def get_out_edges(
        self,
        rag_session_id: str,
        src_symbol_ids: list[str],
        edge_types: list[str],
        limit_per_src: int,
    ) -> list[LayeredRetrievalItem]:
        """Return dependency-graph edges leaving the given symbols.

        At most ``limit_per_src`` edges are kept per source symbol, in
        ``(path, span_start)`` order. An empty id list short-circuits to ``[]``.
        """
        if not src_symbol_ids:
            return []
        rows = self._fetch(
            """
            SELECT path, content, layer, title, metadata_json, span_start, span_end
            FROM rag_chunks
            WHERE rag_session_id = :sid
              AND layer = 'C2_DEPENDENCY_GRAPH'
              AND CAST(metadata_json AS jsonb)->>'src_symbol_id' = ANY(:src_ids)
              AND CAST(metadata_json AS jsonb)->>'edge_type' = ANY(:edge_types)
            ORDER BY path, span_start
            """,
            {"sid": rag_session_id, "src_ids": src_symbol_ids, "edge_types": edge_types},
        )
        kept = self._take_per_key(rows, "src_symbol_id", limit_per_src)
        return [self._to_item(row, metadata) for row, metadata in kept]

    def get_in_edges(
        self,
        rag_session_id: str,
        dst_symbol_ids: list[str],
        edge_types: list[str],
        limit_per_dst: int,
    ) -> list[LayeredRetrievalItem]:
        """Return dependency-graph edges arriving at the given symbols.

        At most ``limit_per_dst`` edges are kept per destination symbol, in
        ``(path, span_start)`` order. An empty id list short-circuits to ``[]``.
        """
        if not dst_symbol_ids:
            return []
        rows = self._fetch(
            """
            SELECT path, content, layer, title, metadata_json, span_start, span_end
            FROM rag_chunks
            WHERE rag_session_id = :sid
              AND layer = 'C2_DEPENDENCY_GRAPH'
              AND CAST(metadata_json AS jsonb)->>'dst_symbol_id' = ANY(:dst_ids)
              AND CAST(metadata_json AS jsonb)->>'edge_type' = ANY(:edge_types)
            ORDER BY path, span_start
            """,
            {"sid": rag_session_id, "dst_ids": dst_symbol_ids, "edge_types": edge_types},
        )
        kept = self._take_per_key(rows, "dst_symbol_id", limit_per_dst)
        return [self._to_item(row, metadata) for row, metadata in kept]

    def resolve_symbol_by_ref(
        self,
        rag_session_id: str,
        dst_ref: str,
        package_hint: str | None = None,
    ) -> LayeredRetrievalItem | None:
        """Find the best symbol-catalog entry for a textual reference.

        Candidates match ``dst_ref`` exactly on ``qname`` or ``title``, or have
        a ``qname`` ending with it. They are scored (exact qname +3, exact
        title +2, package starts with the hint +3, hint occurs in path +1) and
        the first highest-scoring candidate wins.

        Returns:
            The best candidate, or ``None`` for a blank reference or no match.
        """
        ref = (dst_ref or "").strip()
        if not ref:
            return None
        rows = self._fetch(
            """
            SELECT path, content, layer, title, metadata_json, span_start, span_end, qname
            FROM rag_chunks
            WHERE rag_session_id = :sid
              AND layer = 'C1_SYMBOL_CATALOG'
              AND (qname = :ref OR title = :ref OR qname LIKE :tail ESCAPE '!')
            ORDER BY path
            LIMIT 12
            """,
            # The reference is escaped so that "_" and "%" inside identifiers
            # are matched literally instead of acting as LIKE wildcards.
            {"sid": rag_session_id, "ref": ref, "tail": f"%{self._escape_like(ref)}"},
        )
        best_row = None
        best_metadata: dict = {}
        best_score = -1
        for row in rows:
            metadata = self._loads(row.get("metadata_json"))
            score = self._ref_score(row, metadata, ref, package_hint)
            # Strict comparison keeps the earliest row among equal scores.
            if score > best_score:
                best_row, best_metadata, best_score = row, metadata, score
        if best_row is None:
            return None
        return self._to_item(best_row, best_metadata)

    def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]) -> list[LayeredRetrievalItem]:
        """Return the symbol-catalog entries whose ``symbol_id`` is in ``symbol_ids``."""
        if not symbol_ids:
            return []
        rows = self._fetch(
            """
            SELECT path, content, layer, title, metadata_json, span_start, span_end
            FROM rag_chunks
            WHERE rag_session_id = :sid
              AND layer = 'C1_SYMBOL_CATALOG'
              AND symbol_id = ANY(:symbol_ids)
            ORDER BY path, span_start
            """,
            {"sid": rag_session_id, "symbol_ids": symbol_ids},
        )
        return [self._to_item(row, self._loads(row.get("metadata_json"))) for row in rows]

    def get_chunks_by_symbol_ids(
        self,
        rag_session_id: str,
        symbol_ids: list[str],
        prefer_chunk_type: str = "symbol_block",
    ) -> list[LayeredRetrievalItem]:
        """Return one source chunk per resolvable symbol, in symbol order.

        Symbols without a location or without an overlapping source chunk are
        skipped. Note that this issues one query per symbol.
        """
        chunks: list[LayeredRetrievalItem] = []
        for symbol in self.get_symbols_by_ids(rag_session_id, symbol_ids):
            if symbol.location is None:
                continue
            chunk = self._chunk_for_symbol(rag_session_id, symbol, prefer_chunk_type=prefer_chunk_type)
            if chunk is not None:
                chunks.append(chunk)
        return chunks

    def _chunk_for_symbol(
        self,
        rag_session_id: str,
        symbol: LayeredRetrievalItem,
        *,
        prefer_chunk_type: str,
    ) -> LayeredRetrievalItem | None:
        """Pick the source chunk overlapping ``symbol``'s span.

        Chunks whose ``chunk_type`` equals ``prefer_chunk_type`` sort first,
        then the chunk starting closest to the symbol's start line.
        """
        location = symbol.location
        if location is None:
            return None
        rows = self._fetch(
            """
            SELECT path, content, layer, title, metadata_json, span_start, span_end
            FROM rag_chunks
            WHERE rag_session_id = :sid
              AND layer = 'C0_SOURCE_CHUNKS'
              AND path = :path
              AND COALESCE(span_start, 0) <= :end_line
              AND COALESCE(span_end, 999999) >= :start_line
            ORDER BY
              CASE WHEN CAST(metadata_json AS jsonb)->>'chunk_type' = :prefer_chunk_type THEN 0 ELSE 1 END,
              ABS(COALESCE(span_start, 0) - :start_line)
            LIMIT 1
            """,
            {
                "sid": rag_session_id,
                "path": location.path,
                # Missing bounds widen the overlap window to the whole file.
                "start_line": location.start_line or 0,
                "end_line": location.end_line or 999999,
                "prefer_chunk_type": prefer_chunk_type,
            },
        )
        if not rows:
            return None
        row = rows[0]
        return self._to_item(row, self._loads(row.get("metadata_json")))

    def _fetch(self, sql: str, params: dict) -> list:
        """Run ``sql`` with bound ``params`` and return all rows as mappings."""
        with get_engine().connect() as conn:
            return conn.execute(text(sql), params).mappings().fetchall()

    def _take_per_key(self, rows, key: str, limit: int) -> list[tuple]:
        """Keep at most ``limit`` rows per value of metadata field ``key``.

        Returns ``(row, decoded_metadata)`` pairs in input order. Rows whose
        metadata lacks ``key`` are grouped under the empty string.
        """
        counts: dict[str, int] = {}
        kept: list[tuple] = []
        for row in rows:
            metadata = self._loads(row.get("metadata_json"))
            group = str(metadata.get(key) or "")
            counts[group] = counts.get(group, 0) + 1
            if counts[group] <= limit:
                kept.append((row, metadata))
        return kept

    def _ref_score(self, row, metadata: dict, ref: str, package_hint: str | None) -> int:
        """Score how well a catalog row matches ``ref`` and the package hint."""
        package = str(metadata.get("package_or_module") or "")
        score = 0
        if str(row.get("qname") or "") == ref:
            score += 3
        if str(row.get("title") or "") == ref:
            score += 2
        if package_hint and package.startswith(package_hint):
            score += 3
        if package_hint and package_hint in str(row.get("path") or ""):
            score += 1
        return score

    def _escape_like(self, value: str) -> str:
        """Escape LIKE metacharacters in ``value`` using ``_LIKE_ESCAPE``."""
        esc = self._LIKE_ESCAPE
        # The escape character itself must be doubled first.
        return value.replace(esc, esc + esc).replace("%", esc + "%").replace("_", esc + "_")

    def _to_item(self, row, metadata: dict) -> LayeredRetrievalItem:
        """Convert a result row plus its decoded metadata into a retrieval item."""
        path = str(row.get("path") or "")
        return LayeredRetrievalItem(
            source=path,
            content=str(row.get("content") or ""),
            layer=str(row.get("layer") or ""),
            title=str(row.get("title") or ""),
            metadata=metadata,
            location=CodeLocation(
                path=path,
                start_line=row.get("span_start"),
                end_line=row.get("span_end"),
            ),
        )

    def _loads(self, value) -> dict:
        """Decode a ``metadata_json`` column value into a dict.

        Accepts ``None``/empty (returns ``{}``), an already-decoded dict (some
        drivers decode JSON columns themselves), or JSON text/bytes. JSON that
        is not an object yields ``{}`` so callers can always use ``.get``.

        Raises:
            json.JSONDecodeError: If the value is text but not valid JSON.
        """
        if not value:
            return {}
        if isinstance(value, dict):
            return value
        raw = value if isinstance(value, (str, bytes, bytearray)) else str(value)
        decoded = json.loads(raw)
        return decoded if isinstance(decoded, dict) else {}
|
||||
@@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from app.modules.rag.explain.models import ExplainHints, ExplainIntent
|
||||
from app.modules.rag.retrieval.query_terms import extract_query_terms
|
||||
|
||||
|
||||
class ExplainIntentBuilder:
    """Derives an ``ExplainIntent`` from a free-text user query using regex heuristics."""

    # A route must start at a left boundary: the leading "/" may not follow a
    # path/URL character. Without this, "app/modules/x.py", "and/or" or
    # "http://host/x" would each yield a bogus endpoint.
    _ROUTE_RE = re.compile(r"(?<![A-Za-z0-9_./{}:-])(/[A-Za-z0-9_./{}:-]+)")
    _FILE_RE = re.compile(r"([A-Za-z0-9_./-]+\.py)")
    _SYMBOL_RE = re.compile(r"\b([A-Z][A-Za-z0-9_]*\.[A-Za-z_][A-Za-z0-9_]*|[A-Z][A-Za-z0-9_]{2,}|[a-z_][A-Za-z0-9_]{2,})\b")
    _COMMAND_RE = re.compile(r"`([A-Za-z0-9:_-]+)`")
    # Substrings whose presence in the lowered query turns on test inclusion.
    # NOTE(review): matching is by substring, so "spec" also fires on
    # "inspect" and "test " on "latest " - confirm whether that is intended.
    _TEST_KEYWORDS = (
        "тест",
        "tests",
        "test ",
        "unit-test",
        "unit test",
        "юнит-тест",
        "pytest",
        "spec",
        "как покрыто тестами",
        "как проверяется",
        "how is it tested",
        "how it's tested",
    )

    def build(self, user_query: str) -> ExplainIntent:
        """Build the intent for ``user_query``.

        Whitespace is collapsed first; hints, keywords (capped at 12), test
        inclusion, expected entry types and depth are derived from the result.
        """
        normalized = " ".join((user_query or "").split())
        lowered = normalized.lower()
        keywords = self._keywords(normalized)
        hints = ExplainHints(
            paths=self._dedupe(self._FILE_RE.findall(normalized)),
            symbols=self._symbols(normalized),
            endpoints=self._dedupe(self._ROUTE_RE.findall(normalized)),
            commands=self._commands(normalized, lowered),
        )
        return ExplainIntent(
            raw_query=user_query,
            normalized_query=normalized,
            keywords=keywords[:12],
            hints=hints,
            include_tests=self._include_tests(lowered),
            expected_entry_types=self._entry_types(lowered, hints),
            depth=self._depth(lowered),
        )

    def _keywords(self, text: str) -> list[str]:
        """Return query terms, then symbols, then routes, de-duplicated in order."""
        return self._dedupe(
            [
                *extract_query_terms(text),
                *self._symbols(text),
                *self._ROUTE_RE.findall(text),
            ]
        )

    def _symbols(self, text: str) -> list[str]:
        """Return identifier-like tokens of at least three characters, minus ``*.py`` names."""
        candidates = (raw.strip() for raw in self._SYMBOL_RE.findall(text))
        return self._dedupe(
            [token for token in candidates if len(token) >= 3 and not token.endswith(".py")]
        )

    def _commands(self, text: str, lowered: str) -> list[str]:
        """Return command names: backticked tokens plus words after "command"/"cli".

        Backticked tokens keep their original case; the words following
        "command" or "cli" are taken from the lowered query.
        """
        values = list(self._COMMAND_RE.findall(text))
        padded = f" {lowered} "
        for marker in ("command", "cli"):
            if f" {marker} " in padded:
                values.extend(re.findall(marker + r"\s+([A-Za-z0-9:_-]+)", lowered))
        return self._dedupe(values)

    def _entry_types(self, lowered: str, hints: ExplainHints) -> list[str]:
        """Guess entry types; HTTP signals win over CLI, and no signal yields both."""
        if hints.endpoints or any(token in lowered for token in ("endpoint", "route", "handler", "http", "api")):
            return ["http"]
        if hints.commands or any(token in lowered for token in ("cli", "command", "click", "typer")):
            return ["cli"]
        return ["http", "cli"]

    def _depth(self, lowered: str) -> str:
        """Map depth cue words to "deep" or "high"; default is "medium"."""
        if any(token in lowered for token in ("deep", "подроб", "деталь", "full flow", "trace")):
            return "deep"
        if any(token in lowered for token in ("high level", "overview", "кратко", "summary")):
            return "high"
        return "medium"

    def _include_tests(self, lowered: str) -> bool:
        """Return True when the space-padded query contains any test keyword."""
        # Padding lets keywords with a trailing space match at the very end.
        padded = f" {lowered} "
        return any(token in padded for token in self._TEST_KEYWORDS)

    def _dedupe(self, values: list[str]) -> list[str]:
        """Strip each value and drop blanks and repeats, preserving first-seen order."""
        result: list[str] = []
        for value in values:
            item = value.strip()
            if item and item not in result:
                result.append(item)
        return result
|
||||
@@ -0,0 +1,289 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
|
||||
from app.modules.rag.retrieval.test_filter import build_test_filters, debug_disable_test_filter
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.modules.rag.persistence.repository import RagRepository
|
||||
from app.modules.rag_session.embedding.gigachat_embedder import GigaChatEmbedder
|
||||
|
||||
|
||||
@dataclass(slots=True)
class LayerRetrievalResult:
    """Outcome of a single ``LayeredRetrievalGateway`` retrieval call."""

    # Rows returned by the repository, converted to retrieval items.
    items: list[LayeredRetrievalItem]
    # Diagnostic markers for failures and fallbacks, e.g. a
    # "...:retried_without_test_filter" entry appended after a retry.
    missing: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class LayeredRetrievalGateway:
    """Runs repository retrieval with a fallback that drops the test filter.

    Retrieval errors never propagate: they are logged and reported through
    ``LayerRetrievalResult.missing`` so callers can degrade gracefully.
    """

    def __init__(self, repository: RagRepository, embedder: GigaChatEmbedder) -> None:
        """Keep the repository used for queries and the embedder used for query vectors."""
        self._repository = repository
        self._embedder = embedder

    def retrieve_layer(
        self,
        rag_session_id: str,
        query: str,
        layer: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        prefer_non_tests: bool = False,
        include_spans: bool = False,
    ) -> LayerRetrievalResult:
        """Embed ``query`` and retrieve up to ``limit`` items from one layer.

        If the repository call fails while the test filter is active, the call
        is retried once without the filter. If embedding itself fails, no retry
        is attempted.

        Returns:
            The retrieved items; on failure an empty result whose ``missing``
            carries a ``layer:<layer> retrieval_failed:...`` marker.
        """
        effective_exclude_tests = exclude_tests and not debug_disable_test_filter()
        filter_args = self._filter_args(effective_exclude_tests)
        missing_prefix = f"layer:{layer} retrieval_failed"
        query_embedding: list[float] | None = None
        try:
            query_embedding = self._embedder.embed([query])[0]
            rows = self._repository.retrieve(
                rag_session_id,
                query_embedding,
                query_text=query,
                limit=limit,
                layers=[layer],
                path_prefixes=path_prefixes,
                exclude_path_prefixes=filter_args["exclude_path_prefixes"],
                exclude_like_patterns=filter_args["exclude_like_patterns"],
                prefer_non_tests=prefer_non_tests or not effective_exclude_tests,
            )
            return self._success_result(
                rows,
                rag_session_id=rag_session_id,
                label="layered retrieval",
                include_spans=include_spans,
                layer=layer,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
            )
        except Exception as exc:  # boundary: log and degrade instead of raising
            if query_embedding is None:
                # Embedding failed, so there is nothing to retry with.
                self._log_failure(
                    label="layered retrieval",
                    rag_session_id=rag_session_id,
                    layer=layer,
                    exclude_tests=effective_exclude_tests,
                    path_prefixes=path_prefixes,
                    exc=exc,
                )
                return self._failure_result(missing_prefix, exc)
            retry_result = self._retry_without_test_filter(
                operation=lambda: self._repository.retrieve(
                    rag_session_id,
                    query_embedding,
                    query_text=query,
                    limit=limit,
                    layers=[layer],
                    path_prefixes=path_prefixes,
                    exclude_path_prefixes=None,
                    exclude_like_patterns=None,
                    prefer_non_tests=True,
                ),
                label="layered retrieval",
                rag_session_id=rag_session_id,
                include_spans=include_spans,
                layer=layer,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
                exc=exc,
                missing_prefix=missing_prefix,
            )
            if retry_result is not None:
                return retry_result
            return self._failure_result(missing_prefix, exc)

    def retrieve_lexical_code(
        self,
        rag_session_id: str,
        query: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        include_spans: bool = False,
    ) -> LayerRetrievalResult:
        """Retrieve up to ``limit`` code items by lexical match on ``query``.

        A failure with the test filter active is retried once without it.

        Returns:
            The retrieved items; on failure an empty result whose ``missing``
            carries a ``layer:C0 lexical_retrieval_failed:...`` marker.
        """
        effective_exclude_tests = exclude_tests and not debug_disable_test_filter()
        filter_args = self._filter_args(effective_exclude_tests)
        missing_prefix = "layer:C0 lexical_retrieval_failed"
        try:
            rows = self._repository.retrieve_lexical_code(
                rag_session_id,
                query_text=query,
                limit=limit,
                path_prefixes=path_prefixes,
                exclude_path_prefixes=filter_args["exclude_path_prefixes"],
                exclude_like_patterns=filter_args["exclude_like_patterns"],
                prefer_non_tests=not effective_exclude_tests,
            )
            return self._success_result(
                rows,
                rag_session_id=rag_session_id,
                label="lexical retrieval",
                include_spans=include_spans,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
            )
        except Exception as exc:  # boundary: log and degrade instead of raising
            retry_result = self._retry_without_test_filter(
                operation=lambda: self._repository.retrieve_lexical_code(
                    rag_session_id,
                    query_text=query,
                    limit=limit,
                    path_prefixes=path_prefixes,
                    exclude_path_prefixes=None,
                    exclude_like_patterns=None,
                    prefer_non_tests=True,
                ),
                label="lexical retrieval",
                rag_session_id=rag_session_id,
                include_spans=include_spans,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
                exc=exc,
                missing_prefix=missing_prefix,
            )
            if retry_result is not None:
                return retry_result
            return self._failure_result(missing_prefix, exc)

    def _retry_without_test_filter(
        self,
        *,
        operation: Callable[[], list[dict]],
        label: str,
        rag_session_id: str,
        include_spans: bool,
        exclude_tests: bool,
        path_prefixes: list[str] | None,
        exc: Exception,
        missing_prefix: str,
        layer: str | None = None,
    ) -> LayerRetrievalResult | None:
        """Log the original failure and, if the test filter was on, retry once.

        Args:
            operation: Zero-argument callable performing the unfiltered query.
            exc: The exception raised by the first attempt.
            missing_prefix: Prefix of the marker appended after a successful retry.

        Returns:
            The retry result with a ``<missing_prefix>:retried_without_test_filter``
            marker, or ``None`` when no retry was attempted or it also failed.
        """
        # A retry only makes sense when the first attempt used the filter.
        self._log_failure(
            label=label,
            rag_session_id=rag_session_id,
            layer=layer,
            exclude_tests=exclude_tests,
            path_prefixes=path_prefixes,
            exc=exc,
            retried_without_test_filter=bool(exclude_tests),
        )
        if not exclude_tests:
            return None
        try:
            rows = operation()
        except Exception as retry_exc:  # boundary: the caller reports the failure
            self._log_failure(
                label=f"{label} retry",
                rag_session_id=rag_session_id,
                layer=layer,
                exclude_tests=False,
                path_prefixes=path_prefixes,
                exc=retry_exc,
            )
            return None
        result = self._success_result(
            rows,
            rag_session_id=rag_session_id,
            label=f"{label} retry",
            include_spans=include_spans,
            layer=layer,
            exclude_tests=False,
            path_prefixes=path_prefixes,
        )
        result.missing.append(f"{missing_prefix}:retried_without_test_filter")
        return result

    def _success_result(
        self,
        rows: list[dict],
        *,
        rag_session_id: str,
        label: str,
        include_spans: bool,
        exclude_tests: bool,
        path_prefixes: list[str] | None,
        layer: str | None = None,
    ) -> LayerRetrievalResult:
        """Convert ``rows`` to items and log a summary of the successful call."""
        items = [self._to_item(row, include_spans=include_spans) for row in rows]
        # Success is routine, so it is logged at INFO; WARNING is reserved for
        # the failure paths.
        LOGGER.info(
            "%s: rag_session_id=%s layer=%s exclude_tests=%s path_prefixes=%s returned_count=%s top_paths=%s",
            label,
            rag_session_id,
            layer,
            exclude_tests,
            path_prefixes or [],
            len(items),
            [item.source for item in items[:3]],
        )
        return LayerRetrievalResult(items=items)

    def _failure_result(self, prefix: str, exc: Exception) -> LayerRetrievalResult:
        """Return an empty result whose ``missing`` describes the failure."""
        return LayerRetrievalResult(items=[], missing=[self._failure_missing(prefix, exc)])

    def _log_failure(
        self,
        *,
        label: str,
        rag_session_id: str,
        exclude_tests: bool,
        path_prefixes: list[str] | None,
        exc: Exception,
        layer: str | None = None,
        retried_without_test_filter: bool = False,
    ) -> None:
        """Log a retrieval failure at WARNING level with the traceback of ``exc``."""
        LOGGER.warning(
            "%s failed: rag_session_id=%s layer=%s exclude_tests=%s path_prefixes=%s retried_without_test_filter=%s error=%s",
            label,
            rag_session_id,
            layer,
            exclude_tests,
            path_prefixes or [],
            retried_without_test_filter,
            self._exception_summary(exc),
            # Pass the exception explicitly so the logged traceback belongs to
            # ``exc`` even if this is called outside its ``except`` block.
            exc_info=exc,
        )

    def _filter_args(self, exclude_tests: bool) -> dict[str, list[str] | None]:
        """Return the exclusion arguments for the repository; both are None when tests are kept."""
        test_filters = build_test_filters() if exclude_tests else None
        return {
            "exclude_path_prefixes": test_filters.exclude_path_prefixes if test_filters else None,
            "exclude_like_patterns": test_filters.exclude_like_patterns if test_filters else None,
        }

    def _failure_missing(self, prefix: str, exc: Exception) -> str:
        """Format a ``missing`` marker as ``<prefix>:<exception summary>``."""
        return f"{prefix}:{self._exception_summary(exc)}"

    def _exception_summary(self, exc: Exception) -> str:
        """Return ``<ExceptionType>:<message>`` on one line, capped at 180 message chars."""
        message = " ".join(str(exc).split())
        if len(message) > 180:
            message = message[:177] + "..."
        return f"{type(exc).__name__}:{message or 'no_message'}"

    def _to_item(self, row: dict, *, include_spans: bool) -> LayeredRetrievalItem:
        """Convert a repository row to a retrieval item; the location is set only when ``include_spans``."""
        location = None
        if include_spans:
            location = CodeLocation(
                path=str(row.get("path") or ""),
                start_line=row.get("span_start"),
                end_line=row.get("span_end"),
            )
        return LayeredRetrievalItem(
            source=str(row.get("path") or ""),
            content=str(row.get("content") or ""),
            layer=str(row.get("layer") or ""),
            title=str(row.get("title") or ""),
            metadata=dict(row.get("metadata", {}) or {}),
            # The row's "distance" value is stored as the item's score.
            score=row.get("distance"),
            location=location,
        )
|
||||
@@ -0,0 +1,91 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ExplainHints(BaseModel):
    """Explicit references extracted from a user query; every list defaults to empty."""

    # Unknown fields are rejected at validation time.
    model_config = ConfigDict(extra="forbid")

    # File paths mentioned in the query.
    paths: list[str] = Field(default_factory=list)
    # Identifier-like tokens mentioned in the query.
    symbols: list[str] = Field(default_factory=list)
    # Route-like tokens mentioned in the query.
    endpoints: list[str] = Field(default_factory=list)
    # Command names mentioned in the query.
    commands: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ExplainIntent(BaseModel):
    """Structured interpretation of a user's explain query."""

    # Unknown fields are rejected at validation time.
    model_config = ConfigDict(extra="forbid")

    # The query exactly as received.
    raw_query: str
    # The query after normalisation by the intent builder.
    normalized_query: str
    keywords: list[str] = Field(default_factory=list)
    hints: ExplainHints = Field(default_factory=ExplainHints)
    # Whether test code should be part of retrieval.
    include_tests: bool = False
    # Kinds of entry point the query is expected to concern.
    expected_entry_types: list[Literal["http", "cli"]] = Field(default_factory=list)
    # Requested level of detail.
    depth: Literal["high", "medium", "deep"] = "medium"
|
||||
|
||||
|
||||
class CodeLocation(BaseModel):
    """A file path with an optional line span."""

    # Unknown fields are rejected at validation time.
    model_config = ConfigDict(extra="forbid")

    path: str
    # Either bound may be None when the span is not known.
    start_line: int | None = None
    end_line: int | None = None
|
||||
|
||||
|
||||
class LayeredRetrievalItem(BaseModel):
    """One retrieved chunk together with its layer, metadata and optional location."""

    # Unknown fields are rejected at validation time.
    model_config = ConfigDict(extra="forbid")

    # Path of the file the chunk was taken from.
    source: str
    content: str
    # Name of the retrieval layer the chunk belongs to.
    layer: str
    title: str
    metadata: dict[str, Any] = Field(default_factory=dict)
    # Retrieval score when the producer supplies one.
    score: float | None = None
    location: CodeLocation | None = None
|
||||
|
||||
|
||||
class TracePath(BaseModel):
    """A sequence of symbol ids forming one trace through the code."""

    # Unknown fields are rejected at validation time.
    model_config = ConfigDict(extra="forbid")

    # Symbol ids in path order.
    symbol_ids: list[str] = Field(default_factory=list)
    score: float = 0.0
    # Id of the entry point this path belongs to, when there is one.
    entrypoint_id: str | None = None
    notes: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class EvidenceItem(BaseModel):
    """An entry of the evidence index: what was found and where."""

    # Unknown fields are rejected at validation time.
    model_config = ConfigDict(extra="forbid")

    evidence_id: str
    # Category of the evidence; restricted to the listed literals.
    kind: Literal["entrypoint", "symbol", "edge", "excerpt"]
    summary: str
    location: CodeLocation | None = None
    # NOTE(review): presumably identifiers of what this evidence backs - confirm with producers.
    supports: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class CodeExcerpt(BaseModel):
    """A piece of source text tied to an evidence id and, optionally, a symbol."""

    # Unknown fields are rejected at validation time.
    model_config = ConfigDict(extra="forbid")

    evidence_id: str
    # Symbol the excerpt belongs to; None when it is not bound to a symbol.
    symbol_id: str | None = None
    title: str
    path: str
    start_line: int | None = None
    end_line: int | None = None
    content: str
    # Excerpt flavour; the planner emits "overview" and "focus".
    focus: str = "overview"
|
||||
|
||||
|
||||
class ExplainPack(BaseModel):
    """Everything gathered for answering one explain query."""

    # Unknown fields are rejected at validation time.
    model_config = ConfigDict(extra="forbid")

    intent: ExplainIntent
    selected_entrypoints: list[LayeredRetrievalItem] = Field(default_factory=list)
    seed_symbols: list[LayeredRetrievalItem] = Field(default_factory=list)
    trace_paths: list[TracePath] = Field(default_factory=list)
    # Evidence items keyed by evidence id.
    evidence_index: dict[str, EvidenceItem] = Field(default_factory=dict)
    code_excerpts: list[CodeExcerpt] = Field(default_factory=list)
    # Markers for layers or steps that produced nothing or failed.
    missing: list[str] = Field(default_factory=list)
    conflicts: list[str] = Field(default_factory=list)
|
||||
@@ -0,0 +1,328 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.modules.rag.contracts.enums import RagLayer
|
||||
from app.modules.rag.explain.intent_builder import ExplainIntentBuilder
|
||||
from app.modules.rag.explain.layered_gateway import LayerRetrievalResult, LayeredRetrievalGateway
|
||||
from app.modules.rag.explain.models import CodeExcerpt, EvidenceItem, ExplainPack, LayeredRetrievalItem
|
||||
from app.modules.rag.explain.source_excerpt_fetcher import SourceExcerptFetcher
|
||||
from app.modules.rag.explain.trace_builder import TraceBuilder
|
||||
from app.modules.rag.retrieval.test_filter import exclude_tests_default, is_test_path
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
_MIN_EXCERPTS = 2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.modules.rag.explain.graph_repository import CodeGraphRepository
|
||||
from app.modules.rag.explain.models import ExplainIntent
|
||||
|
||||
|
||||
class CodeExplainRetrieverV2:
|
||||
def __init__(
    self,
    gateway: LayeredRetrievalGateway,
    graph_repository: CodeGraphRepository,
    intent_builder: ExplainIntentBuilder | None = None,
    trace_builder: TraceBuilder | None = None,
    excerpt_fetcher: SourceExcerptFetcher | None = None,
) -> None:
    """Wire the retriever's collaborators.

    Collaborators that are not supplied are created with defaults; the trace
    builder and excerpt fetcher defaults are built on ``graph_repository``.
    """
    self._gateway = gateway
    self._graph = graph_repository
    if not intent_builder:
        intent_builder = ExplainIntentBuilder()
    self._intent_builder = intent_builder
    if not trace_builder:
        trace_builder = TraceBuilder(graph_repository)
    self._trace_builder = trace_builder
    if not excerpt_fetcher:
        excerpt_fetcher = SourceExcerptFetcher(graph_repository)
    self._excerpt_fetcher = excerpt_fetcher
|
||||
|
||||
def build_pack(
    self,
    rag_session_id: str,
    user_query: str,
    *,
    file_candidates: list[dict] | None = None,
) -> ExplainPack:
    """Build the explain pack for ``user_query`` within a RAG session.

    Tests are excluded when the default says so and the intent does not ask
    for them. If that filtered pass yields fewer than ``_MIN_EXCERPTS`` code
    excerpts, the test fallback is merged into the pack before it is logged
    and returned.
    """
    parsed_intent = self._intent_builder.build(user_query)
    prefixes = _path_prefixes(parsed_intent, file_candidates or [])
    skip_tests = exclude_tests_default() and not parsed_intent.include_tests
    result = self._run_pass(rag_session_id, parsed_intent, prefixes, exclude_tests=skip_tests)
    too_few_excerpts = skip_tests and len(result.code_excerpts) < _MIN_EXCERPTS
    if too_few_excerpts:
        self._merge_test_fallback(result, rag_session_id, parsed_intent, prefixes)
    self._log_pack(rag_session_id, result)
    return result
|
||||
|
||||
def _run_pass(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
intent: ExplainIntent,
|
||||
path_prefixes: list[str],
|
||||
*,
|
||||
exclude_tests: bool,
|
||||
) -> ExplainPack:
|
||||
missing: list[str] = []
|
||||
entrypoints_result = self._entrypoints(rag_session_id, intent, path_prefixes, exclude_tests=exclude_tests)
|
||||
missing.extend(entrypoints_result.missing)
|
||||
selected_entrypoints = self._filter_entrypoints(intent, entrypoints_result.items)
|
||||
if not selected_entrypoints:
|
||||
missing.append("layer:C3 empty")
|
||||
seed_result = self._seed_symbols(rag_session_id, intent, path_prefixes, selected_entrypoints, exclude_tests=exclude_tests)
|
||||
missing.extend(seed_result.missing)
|
||||
seed_symbols = seed_result.items
|
||||
if not seed_symbols:
|
||||
missing.append("layer:C1 empty")
|
||||
depth = 4 if intent.depth == "deep" else 3 if intent.depth == "medium" else 2
|
||||
trace_paths = self._trace_builder.build_paths(rag_session_id, seed_symbols, max_depth=depth) if seed_symbols else []
|
||||
excerpts, excerpt_evidence = self._excerpt_fetcher.fetch(rag_session_id, trace_paths) if trace_paths else ([], {})
|
||||
if not excerpts:
|
||||
lexical_result = self._gateway.retrieve_lexical_code(
|
||||
rag_session_id,
|
||||
intent.normalized_query,
|
||||
limit=6,
|
||||
path_prefixes=path_prefixes or None,
|
||||
exclude_tests=exclude_tests,
|
||||
include_spans=True,
|
||||
)
|
||||
missing.extend(lexical_result.missing)
|
||||
excerpts, excerpt_evidence = _lexical_excerpts(lexical_result.items)
|
||||
if not excerpts:
|
||||
missing.append("layer:C0 empty")
|
||||
evidence_index = _evidence_index(selected_entrypoints, seed_symbols)
|
||||
evidence_index.update(excerpt_evidence)
|
||||
missing.extend(_missing(selected_entrypoints, seed_symbols, trace_paths, excerpts))
|
||||
return ExplainPack(
|
||||
intent=intent,
|
||||
selected_entrypoints=selected_entrypoints,
|
||||
seed_symbols=seed_symbols,
|
||||
trace_paths=trace_paths,
|
||||
evidence_index=evidence_index,
|
||||
code_excerpts=excerpts,
|
||||
missing=_cleanup_missing(_dedupe(missing), has_excerpts=bool(excerpts)),
|
||||
conflicts=[],
|
||||
)
|
||||
|
||||
def _merge_test_fallback(
|
||||
self,
|
||||
pack: ExplainPack,
|
||||
rag_session_id: str,
|
||||
intent: ExplainIntent,
|
||||
path_prefixes: list[str],
|
||||
) -> None:
|
||||
lexical_result = self._gateway.retrieve_lexical_code(
|
||||
rag_session_id,
|
||||
intent.normalized_query,
|
||||
limit=6,
|
||||
path_prefixes=path_prefixes or None,
|
||||
exclude_tests=False,
|
||||
include_spans=True,
|
||||
)
|
||||
excerpt_offset = len([key for key in pack.evidence_index if key.startswith("excerpt_")])
|
||||
excerpts, evidence = _lexical_excerpts(
|
||||
lexical_result.items,
|
||||
start_index=excerpt_offset,
|
||||
is_test_fallback=True,
|
||||
)
|
||||
if not excerpts:
|
||||
pack.missing = _dedupe(pack.missing + lexical_result.missing)
|
||||
return
|
||||
seen = {(item.path, item.start_line, item.end_line, item.content) for item in pack.code_excerpts}
|
||||
for excerpt in excerpts:
|
||||
key = (excerpt.path, excerpt.start_line, excerpt.end_line, excerpt.content)
|
||||
if key in seen:
|
||||
continue
|
||||
pack.code_excerpts.append(excerpt)
|
||||
seen.add(key)
|
||||
pack.evidence_index.update(evidence)
|
||||
pack.missing = _cleanup_missing(_dedupe(pack.missing + lexical_result.missing), has_excerpts=bool(pack.code_excerpts))
|
||||
|
||||
def _entrypoints(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
intent: ExplainIntent,
|
||||
path_prefixes: list[str],
|
||||
*,
|
||||
exclude_tests: bool,
|
||||
) -> LayerRetrievalResult:
|
||||
return self._gateway.retrieve_layer(
|
||||
rag_session_id,
|
||||
intent.normalized_query,
|
||||
RagLayer.CODE_ENTRYPOINTS,
|
||||
limit=6,
|
||||
path_prefixes=path_prefixes or None,
|
||||
exclude_tests=exclude_tests,
|
||||
prefer_non_tests=True,
|
||||
include_spans=True,
|
||||
)
|
||||
|
||||
def _filter_entrypoints(self, intent: ExplainIntent, items: list[LayeredRetrievalItem]) -> list[LayeredRetrievalItem]:
|
||||
if not intent.expected_entry_types:
|
||||
return items[:3]
|
||||
filtered = [item for item in items if str(item.metadata.get("entry_type") or "") in intent.expected_entry_types]
|
||||
return filtered[:3] or items[:3]
|
||||
|
||||
def _seed_symbols(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
intent: ExplainIntent,
|
||||
path_prefixes: list[str],
|
||||
entrypoints: list[LayeredRetrievalItem],
|
||||
*,
|
||||
exclude_tests: bool,
|
||||
) -> LayerRetrievalResult:
|
||||
symbol_result = self._gateway.retrieve_layer(
|
||||
rag_session_id,
|
||||
intent.normalized_query,
|
||||
RagLayer.CODE_SYMBOL_CATALOG,
|
||||
limit=12,
|
||||
path_prefixes=path_prefixes or None,
|
||||
exclude_tests=exclude_tests,
|
||||
prefer_non_tests=True,
|
||||
include_spans=True,
|
||||
)
|
||||
handlers: list[LayeredRetrievalItem] = []
|
||||
handler_ids = [str(item.metadata.get("handler_symbol_id") or "") for item in entrypoints]
|
||||
if handler_ids:
|
||||
handlers = self._graph.get_symbols_by_ids(rag_session_id, [item for item in handler_ids if item])
|
||||
seeds: list[LayeredRetrievalItem] = []
|
||||
seen: set[str] = set()
|
||||
for item in handlers + symbol_result.items:
|
||||
symbol_id = str(item.metadata.get("symbol_id") or "")
|
||||
if not symbol_id or symbol_id in seen:
|
||||
continue
|
||||
seen.add(symbol_id)
|
||||
seeds.append(item)
|
||||
if len(seeds) >= 8:
|
||||
break
|
||||
return LayerRetrievalResult(items=seeds, missing=list(symbol_result.missing))
|
||||
|
||||
def _log_pack(self, rag_session_id: str, pack: ExplainPack) -> None:
|
||||
prod_excerpt_count = len([excerpt for excerpt in pack.code_excerpts if not _is_test_excerpt(excerpt)])
|
||||
test_excerpt_count = len(pack.code_excerpts) - prod_excerpt_count
|
||||
LOGGER.warning(
|
||||
"code explain pack: rag_session_id=%s entrypoints=%s seeds=%s paths=%s excerpts=%s prod_excerpt_count=%s test_excerpt_count=%s missing=%s",
|
||||
rag_session_id,
|
||||
len(pack.selected_entrypoints),
|
||||
len(pack.seed_symbols),
|
||||
len(pack.trace_paths),
|
||||
len(pack.code_excerpts),
|
||||
prod_excerpt_count,
|
||||
test_excerpt_count,
|
||||
pack.missing,
|
||||
)
|
||||
|
||||
|
||||
def _evidence_index(
    entrypoints: list[LayeredRetrievalItem],
    seed_symbols: list[LayeredRetrievalItem],
) -> dict[str, EvidenceItem]:
    """Build evidence entries for the selected entrypoints and seed symbols.

    Entrypoints are keyed ``entrypoint_<n>`` and symbols ``symbol_<n>``, both
    numbered from 1. ``supports`` holds the related symbol id (the handler id
    for entrypoints); when the metadata carries no such id the list is left
    empty rather than holding an empty string.
    """
    result: dict[str, EvidenceItem] = {}
    groups = (
        ("entrypoint", "handler_symbol_id", entrypoints),
        ("symbol", "symbol_id", seed_symbols),
    )
    for kind, id_field, items in groups:
        for index, item in enumerate(items, start=1):
            evidence_id = f"{kind}_{index}"
            supported_id = str(item.metadata.get(id_field) or "")
            result[evidence_id] = EvidenceItem(
                evidence_id=evidence_id,
                kind=kind,
                summary=item.title,
                location=item.location,
                supports=[supported_id] if supported_id else [],
            )
    return result
|
||||
|
||||
|
||||
def _missing(
    entrypoints: list[LayeredRetrievalItem],
    seed_symbols: list[LayeredRetrievalItem],
    trace_paths,
    excerpts,
) -> list[str]:
    """Name every pipeline stage whose result is empty, in pipeline order."""
    stages = (
        ("entrypoints", entrypoints),
        ("seed_symbols", seed_symbols),
        ("trace_paths", trace_paths),
        ("code_excerpts", excerpts),
    )
    return [label for label, produced in stages if not produced]
|
||||
|
||||
|
||||
def _lexical_excerpts(
    items: list[LayeredRetrievalItem],
    *,
    start_index: int = 0,
    is_test_fallback: bool = False,
) -> tuple[list[CodeExcerpt], dict[str, EvidenceItem]]:
    """Convert lexical retrieval items into excerpts plus their evidence entries.

    Evidence ids are ``excerpt_<n>`` with ``n`` starting at ``start_index + 1``.
    Items recognised as tests get the focus ``"test:lexical"``; all others get
    ``"lexical"``. ``is_test_fallback`` is accepted for caller compatibility
    and does not alter the assigned focus.
    """
    excerpts: list[CodeExcerpt] = []
    evidence_index: dict[str, EvidenceItem] = {}
    for number, item in enumerate(items, start=start_index + 1):
        evidence_id = f"excerpt_{number}"
        location = item.location
        label = item.title or item.source
        evidence_index[evidence_id] = EvidenceItem(
            evidence_id=evidence_id,
            kind="excerpt",
            summary=label,
            location=location,
            supports=[],
        )
        focus = "test:lexical" if _item_is_test(item) else "lexical"
        excerpt = CodeExcerpt(
            evidence_id=evidence_id,
            symbol_id=str(item.metadata.get("symbol_id") or "") or None,
            title=label,
            path=item.source,
            start_line=location.start_line if location else None,
            end_line=location.end_line if location else None,
            content=item.content,
            focus=focus,
        )
        excerpts.append(excerpt)
    return excerpts, evidence_index
|
||||
|
||||
|
||||
def _item_is_test(item: LayeredRetrievalItem) -> bool:
    """Tell whether a retrieval item is test code.

    A truthy ``is_test`` metadata flag decides immediately; otherwise the
    item's source path is checked with ``is_test_path``.
    """
    if item.metadata.get("is_test"):
        return True
    return is_test_path(item.source)
|
||||
|
||||
|
||||
def _is_test_excerpt(excerpt: CodeExcerpt) -> bool:
    """Tell whether an excerpt is test code.

    A focus starting with ``"test:"`` decides immediately; otherwise the
    excerpt's path is checked with ``is_test_path``.
    """
    if excerpt.focus.startswith("test:"):
        return True
    return is_test_path(excerpt.path)
|
||||
|
||||
|
||||
def _path_prefixes(intent: ExplainIntent, file_candidates: list[dict]) -> list[str]:
    """Collect unique directory prefixes to scope retrieval, in first-seen order.

    Paths hinted by the intent come first: the part before the last ``/`` is
    used, or the whole path when it has no ``/``. Then the first six file
    candidates contribute the directory of their ``"path"`` value; candidates
    whose path has no ``/`` contribute nothing.
    """
    # A dict doubles as an insertion-ordered set.
    collected: dict[str, None] = {}
    for hinted in intent.hints.paths:
        head, separator, _ = hinted.rpartition("/")
        directory = head if separator else hinted
        if directory:
            collected.setdefault(directory, None)
    for candidate in file_candidates[:6]:
        directory = str(candidate.get("path") or "").rpartition("/")[0]
        if directory:
            collected.setdefault(directory, None)
    return list(collected)
|
||||
|
||||
|
||||
def _cleanup_missing(values: list[str], *, has_excerpts: bool) -> list[str]:
    """Drop the excerpt-related markers once excerpts are available.

    With ``has_excerpts`` false the input list is returned unchanged;
    otherwise a new list without ``"code_excerpts"`` and ``"layer:C0 empty"``
    is returned.
    """
    if has_excerpts:
        resolved = {"code_excerpts", "layer:C0 empty"}
        return [entry for entry in values if entry not in resolved]
    return values
|
||||
|
||||
|
||||
def _dedupe(values: list[str]) -> list[str]:
    """Strip each value, drop the empty ones and keep first occurrences in order."""
    stripped = (value.strip() for value in values)
    # dict.fromkeys keeps insertion order, so it acts as an ordered set here.
    return list(dict.fromkeys(text for text in stripped if text))
|
||||
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.modules.rag.explain.excerpt_planner import ExcerptPlanner
|
||||
from app.modules.rag.explain.models import CodeExcerpt, EvidenceItem, TracePath
|
||||
from app.modules.rag.retrieval.test_filter import is_test_path
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.modules.rag.explain.graph_repository import CodeGraphRepository
|
||||
|
||||
|
||||
class SourceExcerptFetcher:
    """Turn trace paths into source excerpts with a matching evidence index."""

    def __init__(self, graph_repository: CodeGraphRepository, planner: ExcerptPlanner | None = None) -> None:
        """Store the graph repository and the planner (a default one if omitted)."""
        self._graph = graph_repository
        self._planner = planner or ExcerptPlanner()

    def fetch(
        self,
        rag_session_id: str,
        trace_paths: list[TracePath],
        *,
        max_excerpts: int = 40,
    ) -> tuple[list[CodeExcerpt], dict[str, EvidenceItem]]:
        """Fetch chunks for the symbols on ``trace_paths`` and plan excerpts from them.

        Symbol ids are requested once each, in first-seen order. Every processed
        chunk gets an ``excerpt_<n>`` evidence entry, and the planner's excerpts
        are collected until ``max_excerpts`` is reached. Excerpts from test
        chunks get a ``"test:"`` focus prefix unless they already carry one.

        Returns:
            The collected excerpts and the evidence index keyed by evidence id.
        """
        symbol_ids = list(
            dict.fromkeys(
                symbol_id
                for trace in trace_paths
                for symbol_id in trace.symbol_ids
                if symbol_id
            )
        )
        chunks = self._graph.get_chunks_by_symbol_ids(rag_session_id, symbol_ids)

        collected: list[CodeExcerpt] = []
        evidence: dict[str, EvidenceItem] = {}
        for number, chunk in enumerate(chunks, start=1):
            owner_id = str(chunk.metadata.get("symbol_id") or "")
            evidence_id = f"excerpt_{number}"
            where = chunk.location
            evidence[evidence_id] = EvidenceItem(
                evidence_id=evidence_id,
                kind="excerpt",
                summary=chunk.title,
                location=where,
                supports=[owner_id] if owner_id else [],
            )
            from_test = bool(chunk.metadata.get("is_test")) or is_test_path(where.path if where else chunk.source)
            for excerpt in self._planner.plan(chunk, evidence_id=evidence_id, symbol_id=owner_id):
                if len(collected) >= max_excerpts:
                    break
                if from_test and not excerpt.focus.startswith("test:"):
                    excerpt.focus = f"test:{excerpt.focus}"
                collected.append(excerpt)
            # Stop visiting chunks once the excerpt budget is used up.
            if len(collected) >= max_excerpts:
                break
        return collected, evidence
|
||||
@@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.modules.rag.explain.models import LayeredRetrievalItem, TracePath
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.modules.rag.explain.graph_repository import CodeGraphRepository
|
||||
|
||||
|
||||
class TraceBuilder:
    """Walk outgoing code-graph edges from seed symbols to build ranked trace paths."""

    def __init__(self, graph_repository: CodeGraphRepository) -> None:
        """Store the graph repository used for edge and symbol lookups."""
        self._graph = graph_repository

    def build_paths(
        self,
        rag_session_id: str,
        seed_symbols: list[LayeredRetrievalItem],
        *,
        max_depth: int,
        max_paths: int = 3,
        edge_types: list[str] | None = None,
    ) -> list[TracePath]:
        """Explore the graph breadth-first from each seed and return the best paths.

        Args:
            rag_session_id: Session whose graph is queried.
            seed_symbols: Starting symbols; those without a ``symbol_id`` are skipped.
            max_depth: Maximum number of symbols on one path.
            max_paths: Number of paths returned. Exploration stops once
                ``max_paths * 3`` candidate paths have been recorded.
            edge_types: Edge kinds to follow; defaults to calls, imports and inherits.

        Returns:
            Unique paths sorted by descending score, at most ``max_paths`` of
            them. If none were found, a single-symbol path for the first seed
            is returned when that seed has an id, otherwise an empty list.
        """
        from collections import deque  # local import keeps this block self-contained

        edges_filter = edge_types or ["calls", "imports", "inherits"]
        symbol_map = self._symbol_map(seed_symbols)
        paths: list[TracePath] = []
        for seed in seed_symbols:
            seed_id = str(seed.metadata.get("symbol_id") or "")
            if not seed_id:
                continue
            queue: deque[tuple[list[str], float, list[str]]] = deque([([seed_id], 0.0, [])])
            while queue and len(paths) < max_paths * 3:
                current_path, score, notes = queue.popleft()
                # A path at the depth limit is final: record it without querying the graph.
                if len(current_path) >= max_depth:
                    paths.append(TracePath(symbol_ids=current_path, score=score, notes=notes))
                    continue
                src_symbol_id = current_path[-1]
                out_edges = self._graph.get_out_edges(rag_session_id, [src_symbol_id], edges_filter, limit_per_src=4)
                if not out_edges:
                    paths.append(TracePath(symbol_ids=current_path, score=score, notes=notes))
                    continue
                for edge in out_edges:
                    metadata = edge.metadata
                    dst_symbol_id = str(metadata.get("dst_symbol_id") or "")
                    next_notes = list(notes)
                    next_score = score + self._edge_score(edge, symbol_map.get(src_symbol_id))
                    if not dst_symbol_id:
                        # The edge has no target id: try to resolve its textual reference.
                        dst_ref = str(metadata.get("dst_ref") or "")
                        resolved = None
                        if dst_ref:
                            package_hint = self._package_hint(symbol_map.get(src_symbol_id))
                            resolved = self._graph.resolve_symbol_by_ref(rag_session_id, dst_ref, package_hint=package_hint)
                        if resolved is not None:
                            dst_symbol_id = str(resolved.metadata.get("symbol_id") or "")
                            if dst_symbol_id:
                                symbol_map[dst_symbol_id] = resolved
                            next_score += 2.0
                            next_notes.append(f"resolved:{dst_ref}")
                    if not dst_symbol_id or dst_symbol_id in current_path:
                        # Unresolvable target or a cycle: the current path ends here.
                        paths.append(TracePath(symbol_ids=current_path, score=next_score, notes=next_notes))
                        continue
                    if dst_symbol_id not in symbol_map:
                        symbols = self._graph.get_symbols_by_ids(rag_session_id, [dst_symbol_id])
                        if symbols:
                            symbol_map[dst_symbol_id] = symbols[0]
                    queue.append((current_path + [dst_symbol_id], next_score, next_notes))

        ranked = self._unique_paths(paths)
        ranked.sort(key=lambda item: item.score, reverse=True)
        best = ranked[:max_paths]
        if best:
            return best
        # Fallback: a single-symbol path for the first seed, provided it has a usable id.
        for seed in seed_symbols[:1]:
            fallback_id = str(seed.metadata.get("symbol_id") or "")
            if fallback_id:
                return [TracePath(symbol_ids=[fallback_id], score=0.0)]
        return []

    def _edge_score(self, edge: LayeredRetrievalItem, source_symbol: LayeredRetrievalItem | None) -> float:
        """Score one edge.

        Starts at 1.0, adds 2.0 for a resolved edge and 1.0 when the edge lies
        in the same source file as ``source_symbol``, and subtracts 3.0 when the
        edge's source contains ``tests/``.
        """
        score = 1.0
        if str(edge.metadata.get("resolution") or "") == "resolved":
            score += 2.0
        source_path = source_symbol.source if source_symbol is not None else ""
        if source_path and edge.source == source_path:
            score += 1.0
        if "tests/" in edge.source:
            score -= 3.0
        return score

    def _package_hint(self, symbol: LayeredRetrievalItem | None) -> str | None:
        """Return the dotted parent of the symbol's ``package_or_module``.

        A name without dots is returned as is; ``None`` is returned when the
        symbol or its package name is missing.
        """
        if symbol is None:
            return None
        package = str(symbol.metadata.get("package_or_module") or "")
        if not package:
            return None
        return ".".join(package.split(".")[:-1]) or package

    def _symbol_map(self, items: list[LayeredRetrievalItem]) -> dict[str, LayeredRetrievalItem]:
        """Index items by ``symbol_id``, skipping those without one; later items win."""
        result: dict[str, LayeredRetrievalItem] = {}
        for item in items:
            symbol_id = str(item.metadata.get("symbol_id") or "")
            if symbol_id:
                result[symbol_id] = item
        return result

    def _unique_paths(self, items: list[TracePath]) -> list[TracePath]:
        """Keep the first path for each distinct sequence of non-empty symbol ids."""
        result: list[TracePath] = []
        seen: set[tuple[str, ...]] = set()
        for item in items:
            key = tuple(symbol_id for symbol_id in item.symbol_ids if symbol_id)
            if not key or key in seen:
                continue
            seen.add(key)
            result.append(item)
        return result
|
||||
Reference in New Issue
Block a user