Фиксация изменений
This commit is contained in:
36
app/modules/rag/explain/__init__.py
Normal file
36
app/modules/rag/explain/__init__.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib import import_module
|
||||
|
||||
__all__ = [
|
||||
"CodeExcerpt",
|
||||
"CodeExplainRetrieverV2",
|
||||
"CodeGraphRepository",
|
||||
"EvidenceItem",
|
||||
"ExplainIntent",
|
||||
"ExplainIntentBuilder",
|
||||
"ExplainPack",
|
||||
"LayeredRetrievalGateway",
|
||||
"PromptBudgeter",
|
||||
"TracePath",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
module_map = {
|
||||
"CodeExcerpt": "app.modules.rag.explain.models",
|
||||
"EvidenceItem": "app.modules.rag.explain.models",
|
||||
"ExplainIntent": "app.modules.rag.explain.models",
|
||||
"ExplainPack": "app.modules.rag.explain.models",
|
||||
"TracePath": "app.modules.rag.explain.models",
|
||||
"ExplainIntentBuilder": "app.modules.rag.explain.intent_builder",
|
||||
"PromptBudgeter": "app.modules.rag.explain.budgeter",
|
||||
"LayeredRetrievalGateway": "app.modules.rag.explain.layered_gateway",
|
||||
"CodeGraphRepository": "app.modules.rag.explain.graph_repository",
|
||||
"CodeExplainRetrieverV2": "app.modules.rag.explain.retriever_v2",
|
||||
}
|
||||
module_name = module_map.get(name)
|
||||
if module_name is None:
|
||||
raise AttributeError(name)
|
||||
module = import_module(module_name)
|
||||
return getattr(module, name)
|
||||
BIN
app/modules/rag/explain/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
app/modules/rag/explain/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/modules/rag/explain/__pycache__/budgeter.cpython-312.pyc
Normal file
BIN
app/modules/rag/explain/__pycache__/budgeter.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
app/modules/rag/explain/__pycache__/models.cpython-312.pyc
Normal file
BIN
app/modules/rag/explain/__pycache__/models.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/modules/rag/explain/__pycache__/retriever_v2.cpython-312.pyc
Normal file
BIN
app/modules/rag/explain/__pycache__/retriever_v2.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
62
app/modules/rag/explain/budgeter.py
Normal file
62
app/modules/rag/explain/budgeter.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from app.modules.rag.explain.models import ExplainPack
|
||||
|
||||
|
||||
class PromptBudgeter:
    """Packs an ExplainPack into a size-bounded, pretty-printed JSON prompt.

    Caps the number of trace paths, symbols and code excerpts, and truncates
    excerpt bodies so their combined text stays within ``max_chars``.
    """

    def __init__(
        self,
        *,
        max_paths: int = 3,
        max_symbols: int = 25,
        max_excerpts: int = 40,
        max_chars: int = 30000,
    ) -> None:
        self._max_paths = max_paths
        self._max_symbols = max_symbols
        self._max_excerpts = max_excerpts
        self._max_chars = max_chars

    def build_prompt_input(self, question: str, pack: ExplainPack) -> str:
        """Return the JSON payload combining *question* with the bounded pack."""
        allowed_symbols = self._symbol_allowlist(pack)
        payload = {
            "question": question,
            "intent": pack.intent.model_dump(mode="json"),
            "selected_entrypoints": [item.model_dump(mode="json") for item in pack.selected_entrypoints[:5]],
            "seed_symbols": [item.model_dump(mode="json") for item in pack.seed_symbols[: self._max_symbols]],
            "trace_paths": [path.model_dump(mode="json") for path in pack.trace_paths[: self._max_paths]],
            "evidence_index": {key: value.model_dump(mode="json") for key, value in pack.evidence_index.items()},
            "code_excerpts": self._bounded_excerpts(pack, allowed_symbols),
            "missing": pack.missing,
            "conflicts": pack.conflicts,
        }
        return json.dumps(payload, ensure_ascii=False, indent=2)

    def _symbol_allowlist(self, pack: ExplainPack) -> list[str]:
        """Collect unique symbol ids (first-seen order) from the kept paths."""
        allowed: list[str] = []
        for trace in pack.trace_paths[: self._max_paths]:
            for sid in trace.symbol_ids:
                if sid and sid not in allowed and len(allowed) < self._max_symbols:
                    allowed.append(sid)
        return allowed

    def _bounded_excerpts(self, pack: ExplainPack, allowed: list[str]) -> list[dict]:
        """Serialize excerpts, filtered by the allowlist and capped by budget."""
        serialized: list[dict] = []
        used_chars = 0
        for excerpt in pack.code_excerpts:
            # With a non-empty allowlist, drop excerpts tied to other symbols;
            # excerpts without a symbol_id always pass the filter.
            if allowed and excerpt.symbol_id and excerpt.symbol_id not in allowed:
                continue
            body = excerpt.content.strip()
            budget = self._max_chars - used_chars
            if budget <= 0 or len(serialized) >= self._max_excerpts:
                break
            if len(body) > budget:
                body = body[:budget].rstrip() + "...[truncated]"
            serialized.append(
                {
                    "evidence_id": excerpt.evidence_id,
                    "title": excerpt.title,
                    "path": excerpt.path,
                    "start_line": excerpt.start_line,
                    "end_line": excerpt.end_line,
                    "focus": excerpt.focus,
                    "content": body,
                }
            )
            used_chars += len(body)
        return serialized
|
||||
59
app/modules/rag/explain/excerpt_planner.py
Normal file
59
app/modules/rag/explain/excerpt_planner.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.modules.rag.explain.models import CodeExcerpt, LayeredRetrievalItem
|
||||
|
||||
|
||||
class ExcerptPlanner:
    """Derives code excerpts from a retrieved chunk: a full overview plus,
    when available, a small window zoomed on the first "interesting" line."""

    # Lowercased substrings that mark lines worth zooming into.
    _FOCUS_TOKENS = ("raise", "except", "db", "select", "insert", "update", "delete", "http", "publish", "emit")

    def plan(self, chunk: LayeredRetrievalItem, *, evidence_id: str, symbol_id: str | None) -> list[CodeExcerpt]:
        """Return excerpts for *chunk*; empty when the chunk has no location."""
        where = chunk.location
        if where is None:
            return []
        planned = [
            CodeExcerpt(
                evidence_id=evidence_id,
                symbol_id=symbol_id,
                title=chunk.title,
                path=where.path,
                start_line=where.start_line,
                end_line=where.end_line,
                content=chunk.content.strip(),
                focus="overview",
            )
        ]
        zoomed = self._focus_excerpt(chunk, evidence_id=evidence_id, symbol_id=symbol_id)
        if zoomed is not None:
            planned.append(zoomed)
        return planned

    def _focus_excerpt(
        self,
        chunk: LayeredRetrievalItem,
        *,
        evidence_id: str,
        symbol_id: str | None,
    ) -> CodeExcerpt | None:
        """Return a +/-2-line window around the first focus-token line.

        None when the chunk has no location, no line matches a token, or the
        window would span the whole chunk (no value over the overview).
        """
        where = chunk.location
        if where is None:
            return None
        source_lines = chunk.content.splitlines()
        for line_no, raw_line in enumerate(source_lines):
            if not any(token in raw_line.lower() for token in self._FOCUS_TOKENS):
                continue
            window_start = max(0, line_no - 2)
            window_end = min(len(source_lines), line_no + 3)
            if window_end - window_start >= len(source_lines):
                return None
            base_line = where.start_line or 1  # chunk spans are 1-based; default when absent
            return CodeExcerpt(
                evidence_id=evidence_id,
                symbol_id=symbol_id,
                title=f"{chunk.title}:focus",
                path=where.path,
                start_line=base_line + window_start,
                end_line=base_line + window_end - 1,
                content="\n".join(source_lines[window_start:window_end]).strip(),
                focus="focus",
            )
        return None
|
||||
216
app/modules/rag/explain/graph_repository.py
Normal file
216
app/modules/rag/explain/graph_repository.py
Normal file
@@ -0,0 +1,216 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
|
||||
from app.modules.shared.db import get_engine
|
||||
|
||||
|
||||
class CodeGraphRepository:
    """Read-side access to the code-graph layers stored in ``rag_chunks``.

    All queries are scoped to one RAG session and pass values via bound
    parameters. Out-edge and in-edge retrieval share one implementation
    (previously two near-identical copies) parameterized by the metadata key.
    """

    def get_out_edges(
        self,
        rag_session_id: str,
        src_symbol_ids: list[str],
        edge_types: list[str],
        limit_per_src: int,
    ) -> list[LayeredRetrievalItem]:
        """Outgoing dependency edges of *src_symbol_ids*, capped per source."""
        return self._get_edges(
            rag_session_id, src_symbol_ids, edge_types, limit_per_src, group_key="src_symbol_id"
        )

    def get_in_edges(
        self,
        rag_session_id: str,
        dst_symbol_ids: list[str],
        edge_types: list[str],
        limit_per_dst: int,
    ) -> list[LayeredRetrievalItem]:
        """Incoming dependency edges of *dst_symbol_ids*, capped per destination."""
        return self._get_edges(
            rag_session_id, dst_symbol_ids, edge_types, limit_per_dst, group_key="dst_symbol_id"
        )

    def _get_edges(
        self,
        rag_session_id: str,
        symbol_ids: list[str],
        edge_types: list[str],
        limit_per_symbol: int,
        *,
        group_key: str,
    ) -> list[LayeredRetrievalItem]:
        """Shared edge query for both directions.

        *group_key* is an internal constant ("src_symbol_id"/"dst_symbol_id"),
        never user input, so interpolating it into the SQL text is safe; all
        values still travel through bound parameters.
        """
        if not symbol_ids:
            return []
        sql = f"""
            SELECT path, content, layer, title, metadata_json, span_start, span_end
            FROM rag_chunks
            WHERE rag_session_id = :sid
              AND layer = 'C2_DEPENDENCY_GRAPH'
              AND CAST(metadata_json AS jsonb)->>'{group_key}' = ANY(:symbol_ids)
              AND CAST(metadata_json AS jsonb)->>'edge_type' = ANY(:edge_types)
            ORDER BY path, span_start
        """
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(sql),
                {"sid": rag_session_id, "symbol_ids": symbol_ids, "edge_types": edge_types},
            ).mappings().fetchall()
        # Cap results per grouping symbol while preserving query order.
        per_symbol_count: dict[str, int] = {}
        items: list[LayeredRetrievalItem] = []
        for row in rows:
            metadata = self._loads(row.get("metadata_json"))
            group_value = str(metadata.get(group_key) or "")
            per_symbol_count[group_value] = per_symbol_count.get(group_value, 0) + 1
            if per_symbol_count[group_value] > limit_per_symbol:
                continue
            items.append(self._to_item(row, metadata))
        return items

    def resolve_symbol_by_ref(
        self,
        rag_session_id: str,
        dst_ref: str,
        package_hint: str | None = None,
    ) -> LayeredRetrievalItem | None:
        """Best-effort resolution of a textual reference to a catalog symbol.

        Candidates match by exact qname/title or qname suffix, then are scored;
        *package_hint*, when given, boosts candidates from that package or path.
        Returns None when *dst_ref* is blank or nothing matches.
        """
        ref = (dst_ref or "").strip()
        if not ref:
            return None
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(
                    """
                    SELECT path, content, layer, title, metadata_json, span_start, span_end, qname
                    FROM rag_chunks
                    WHERE rag_session_id = :sid
                      AND layer = 'C1_SYMBOL_CATALOG'
                      AND (qname = :ref OR title = :ref OR qname LIKE :tail)
                    ORDER BY path
                    LIMIT 12
                    """
                ),
                {"sid": rag_session_id, "ref": ref, "tail": f"%{ref}"},
            ).mappings().fetchall()
        best: LayeredRetrievalItem | None = None
        best_score = -1
        for row in rows:
            metadata = self._loads(row.get("metadata_json"))
            package = str(metadata.get("package_or_module") or "")
            score = 0
            if str(row.get("qname") or "") == ref:
                score += 3
            if str(row.get("title") or "") == ref:
                score += 2
            if package_hint and package.startswith(package_hint):
                score += 3
            if package_hint and package_hint in str(row.get("path") or ""):
                score += 1
            if score > best_score:
                best = self._to_item(row, metadata)
                best_score = score
        return best

    def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]) -> list[LayeredRetrievalItem]:
        """Fetch catalog entries for the given symbol ids (empty in, empty out)."""
        if not symbol_ids:
            return []
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(
                    """
                    SELECT path, content, layer, title, metadata_json, span_start, span_end
                    FROM rag_chunks
                    WHERE rag_session_id = :sid
                      AND layer = 'C1_SYMBOL_CATALOG'
                      AND symbol_id = ANY(:symbol_ids)
                    ORDER BY path, span_start
                    """
                ),
                {"sid": rag_session_id, "symbol_ids": symbol_ids},
            ).mappings().fetchall()
        return [self._to_item(row, self._loads(row.get("metadata_json"))) for row in rows]

    def get_chunks_by_symbol_ids(
        self,
        rag_session_id: str,
        symbol_ids: list[str],
        prefer_chunk_type: str = "symbol_block",
    ) -> list[LayeredRetrievalItem]:
        """Return the best source chunk for each locatable symbol id."""
        symbols = self.get_symbols_by_ids(rag_session_id, symbol_ids)
        chunks: list[LayeredRetrievalItem] = []
        for symbol in symbols:
            if symbol.location is None:
                continue
            chunk = self._chunk_for_symbol(rag_session_id, symbol, prefer_chunk_type=prefer_chunk_type)
            if chunk is not None:
                chunks.append(chunk)
        return chunks

    def _chunk_for_symbol(
        self,
        rag_session_id: str,
        symbol: LayeredRetrievalItem,
        *,
        prefer_chunk_type: str,
    ) -> LayeredRetrievalItem | None:
        """Pick the C0 source chunk overlapping the symbol's line span.

        Prefers chunks of *prefer_chunk_type*, then the one whose start is
        closest to the symbol's start line.
        """
        location = symbol.location
        if location is None:
            return None
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(
                    """
                    SELECT path, content, layer, title, metadata_json, span_start, span_end
                    FROM rag_chunks
                    WHERE rag_session_id = :sid
                      AND layer = 'C0_SOURCE_CHUNKS'
                      AND path = :path
                      AND COALESCE(span_start, 0) <= :end_line
                      AND COALESCE(span_end, 999999) >= :start_line
                    ORDER BY
                        CASE WHEN CAST(metadata_json AS jsonb)->>'chunk_type' = :prefer_chunk_type THEN 0 ELSE 1 END,
                        ABS(COALESCE(span_start, 0) - :start_line)
                    LIMIT 1
                    """
                ),
                {
                    "sid": rag_session_id,
                    "path": location.path,
                    "start_line": location.start_line or 0,
                    "end_line": location.end_line or 999999,
                    "prefer_chunk_type": prefer_chunk_type,
                },
            ).mappings().fetchall()
        if not rows:
            return None
        row = rows[0]
        return self._to_item(row, self._loads(row.get("metadata_json")))

    def _to_item(self, row, metadata: dict) -> LayeredRetrievalItem:
        """Map a DB row + parsed metadata to a LayeredRetrievalItem."""
        return LayeredRetrievalItem(
            source=str(row.get("path") or ""),
            content=str(row.get("content") or ""),
            layer=str(row.get("layer") or ""),
            title=str(row.get("title") or ""),
            metadata=metadata,
            location=CodeLocation(
                path=str(row.get("path") or ""),
                start_line=row.get("span_start"),
                end_line=row.get("span_end"),
            ),
        )

    def _loads(self, value) -> dict:
        """Parse a JSON metadata column; empty/None becomes {}.

        NOTE(review): malformed JSON raises here (json.JSONDecodeError), same
        as the original behavior — confirm rag_chunks.metadata_json is always
        valid JSON or guard at write time.
        """
        if not value:
            return {}
        return json.loads(str(value))
|
||||
102
app/modules/rag/explain/intent_builder.py
Normal file
102
app/modules/rag/explain/intent_builder.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from app.modules.rag.explain.models import ExplainHints, ExplainIntent
|
||||
from app.modules.rag.retrieval.query_terms import extract_query_terms
|
||||
|
||||
|
||||
class ExplainIntentBuilder:
    """Parses a free-form user question into a structured ExplainIntent:
    extracted hints (paths/symbols/routes/commands), keywords, depth and
    likely entry-point kinds. Handles Russian and English phrasing."""

    _ROUTE_RE = re.compile(r"(/[A-Za-z0-9_./{}:-]+)")
    _FILE_RE = re.compile(r"([A-Za-z0-9_./-]+\.py)")
    _SYMBOL_RE = re.compile(r"\b([A-Z][A-Za-z0-9_]*\.[A-Za-z_][A-Za-z0-9_]*|[A-Z][A-Za-z0-9_]{2,}|[a-z_][A-Za-z0-9_]{2,})\b")
    _COMMAND_RE = re.compile(r"`([A-Za-z0-9:_-]+)`")
    _TEST_KEYWORDS = (
        "тест",
        "tests",
        "test ",
        "unit-test",
        "unit test",
        "юнит-тест",
        "pytest",
        "spec",
        "как покрыто тестами",
        "как проверяется",
        "how is it tested",
        "how it's tested",
    )

    def build(self, user_query: str) -> ExplainIntent:
        """Build the full intent object for *user_query*."""
        compact = " ".join((user_query or "").split())
        folded = compact.lower()
        extracted = ExplainHints(
            paths=self._dedupe(self._FILE_RE.findall(compact)),
            symbols=self._symbols(compact),
            endpoints=self._dedupe(self._ROUTE_RE.findall(compact)),
            commands=self._commands(compact, folded),
        )
        return ExplainIntent(
            raw_query=user_query,
            normalized_query=compact,
            keywords=self._keywords(compact)[:12],
            hints=extracted,
            include_tests=self._include_tests(folded),
            expected_entry_types=self._entry_types(folded, extracted),
            depth=self._depth(folded),
        )

    def _keywords(self, text: str) -> list[str]:
        """Base query terms, augmented with extracted symbols and routes."""
        merged = extract_query_terms(text)
        for extra in self._symbols(text) + self._ROUTE_RE.findall(text):
            if extra not in merged:
                merged.append(extra)
        return self._dedupe(merged)

    def _symbols(self, text: str) -> list[str]:
        """Identifier-like tokens, dropping short ones and .py file names."""
        stripped = (match.strip() for match in self._SYMBOL_RE.findall(text))
        return self._dedupe([tok for tok in stripped if len(tok) >= 3 and not tok.endswith(".py")])

    def _commands(self, text: str, lowered: str) -> list[str]:
        """Backtick-quoted commands plus words following "command"/"cli"."""
        found = list(self._COMMAND_RE.findall(text))
        padded = f" {lowered} "
        if " command " in padded:
            found.extend(re.findall(r"command\s+([A-Za-z0-9:_-]+)", lowered))
        if " cli " in padded:
            found.extend(re.findall(r"cli\s+([A-Za-z0-9:_-]+)", lowered))
        return self._dedupe(found)

    def _entry_types(self, lowered: str, hints: ExplainHints) -> list[str]:
        """Guess whether the question targets HTTP, CLI, or (unknown) both."""
        if hints.endpoints or any(marker in lowered for marker in ("endpoint", "route", "handler", "http", "api")):
            return ["http"]
        if hints.commands or any(marker in lowered for marker in ("cli", "command", "click", "typer")):
            return ["cli"]
        return ["http", "cli"]

    def _depth(self, lowered: str) -> str:
        """Requested detail level: "deep", "high" (overview) or "medium"."""
        if any(marker in lowered for marker in ("deep", "подроб", "деталь", "full flow", "trace")):
            return "deep"
        if any(marker in lowered for marker in ("high level", "overview", "кратко", "summary")):
            return "high"
        return "medium"

    def _include_tests(self, lowered: str) -> bool:
        """True when the question asks about test coverage/verification."""
        padded = f" {lowered} "
        return any(keyword in padded for keyword in self._TEST_KEYWORDS)

    def _dedupe(self, values: list[str]) -> list[str]:
        """Strip each value, drop empties, keep first occurrence order."""
        cleaned = (value.strip() for value in values)
        return list(dict.fromkeys(token for token in cleaned if token))
|
||||
289
app/modules/rag/explain/layered_gateway.py
Normal file
289
app/modules/rag/explain/layered_gateway.py
Normal file
@@ -0,0 +1,289 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
|
||||
from app.modules.rag.retrieval.test_filter import build_test_filters, debug_disable_test_filter
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.modules.rag.persistence.repository import RagRepository
|
||||
from app.modules.rag_session.embedding.gigachat_embedder import GigaChatEmbedder
|
||||
|
||||
|
||||
@dataclass(slots=True)
class LayerRetrievalResult:
    """Outcome of one retrieval call: items plus machine-readable failure notes."""

    # Retrieved chunks for the requested layer (possibly empty).
    items: list[LayeredRetrievalItem]
    # "missing evidence" markers, e.g. "layer:<L> retrieval_failed:<error>".
    missing: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class LayeredRetrievalGateway:
    """Retrieval facade over RagRepository with test-file filtering.

    Both retrieval entry points never raise: failures are logged and surfaced
    through ``LayerRetrievalResult.missing``. When a query fails while a test
    filter was active, it is retried exactly once without the filter.
    """

    def __init__(self, repository: RagRepository, embedder: GigaChatEmbedder) -> None:
        self._repository = repository
        self._embedder = embedder

    def retrieve_layer(
        self,
        rag_session_id: str,
        query: str,
        layer: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        prefer_non_tests: bool = False,
        include_spans: bool = False,
    ) -> LayerRetrievalResult:
        """Embed *query* and retrieve chunks from a single *layer*.

        An embedding failure is terminal (no retry is possible without the
        vector); a repository failure with an active test filter triggers one
        retry without the filter.
        """
        # A debug env flag can globally disable test exclusion.
        effective_exclude_tests = exclude_tests and not debug_disable_test_filter()
        filter_args = self._filter_args(effective_exclude_tests)
        query_embedding: list[float] | None = None
        try:
            query_embedding = self._embedder.embed([query])[0]
            rows = self._repository.retrieve(
                rag_session_id,
                query_embedding,
                query_text=query,
                limit=limit,
                layers=[layer],
                path_prefixes=path_prefixes,
                exclude_path_prefixes=filter_args["exclude_path_prefixes"],
                exclude_like_patterns=filter_args["exclude_like_patterns"],
                prefer_non_tests=prefer_non_tests or not effective_exclude_tests,
            )
            return self._success_result(
                rows,
                rag_session_id=rag_session_id,
                label="layered retrieval",
                include_spans=include_spans,
                layer=layer,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
            )
        except Exception as exc:
            # query_embedding still None => the embedder itself failed: no retry.
            if query_embedding is None:
                self._log_failure(
                    label="layered retrieval",
                    rag_session_id=rag_session_id,
                    layer=layer,
                    exclude_tests=effective_exclude_tests,
                    path_prefixes=path_prefixes,
                    exc=exc,
                )
                return LayerRetrievalResult(items=[], missing=[self._failure_missing(f"layer:{layer} retrieval_failed", exc)])
            retry_result = self._retry_without_test_filter(
                operation=lambda: self._repository.retrieve(
                    rag_session_id,
                    query_embedding,
                    query_text=query,
                    limit=limit,
                    layers=[layer],
                    path_prefixes=path_prefixes,
                    exclude_path_prefixes=None,
                    exclude_like_patterns=None,
                    prefer_non_tests=True,
                ),
                label="layered retrieval",
                rag_session_id=rag_session_id,
                include_spans=include_spans,
                layer=layer,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
                exc=exc,
                missing_prefix=f"layer:{layer} retrieval_failed",
            )
            if retry_result is not None:
                return retry_result
            return LayerRetrievalResult(items=[], missing=[self._failure_missing(f"layer:{layer} retrieval_failed", exc)])

    def retrieve_lexical_code(
        self,
        rag_session_id: str,
        query: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        include_spans: bool = False,
    ) -> LayerRetrievalResult:
        """Lexical (non-embedding) retrieval over source-code chunks.

        Same failure semantics as ``retrieve_layer``: log, optionally retry
        once without the test filter, report problems via ``missing``.
        """
        effective_exclude_tests = exclude_tests and not debug_disable_test_filter()
        filter_args = self._filter_args(effective_exclude_tests)
        try:
            rows = self._repository.retrieve_lexical_code(
                rag_session_id,
                query_text=query,
                limit=limit,
                path_prefixes=path_prefixes,
                exclude_path_prefixes=filter_args["exclude_path_prefixes"],
                exclude_like_patterns=filter_args["exclude_like_patterns"],
                prefer_non_tests=not effective_exclude_tests,
            )
            return self._success_result(
                rows,
                rag_session_id=rag_session_id,
                label="lexical retrieval",
                include_spans=include_spans,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
            )
        except Exception as exc:
            retry_result = self._retry_without_test_filter(
                operation=lambda: self._repository.retrieve_lexical_code(
                    rag_session_id,
                    query_text=query,
                    limit=limit,
                    path_prefixes=path_prefixes,
                    exclude_path_prefixes=None,
                    exclude_like_patterns=None,
                    prefer_non_tests=True,
                ),
                label="lexical retrieval",
                rag_session_id=rag_session_id,
                include_spans=include_spans,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
                exc=exc,
                missing_prefix="layer:C0 lexical_retrieval_failed",
            )
            if retry_result is not None:
                return retry_result
            return LayerRetrievalResult(items=[], missing=[self._failure_missing("layer:C0 lexical_retrieval_failed", exc)])

    def _retry_without_test_filter(
        self,
        *,
        operation: Callable[[], list[dict]],
        label: str,
        rag_session_id: str,
        include_spans: bool,
        exclude_tests: bool,
        path_prefixes: list[str] | None,
        exc: Exception,
        missing_prefix: str,
        layer: str | None = None,
    ) -> LayerRetrievalResult | None:
        """Run *operation* once with the test filter dropped.

        Returns None when no retry applies (filter wasn't active) or the retry
        itself failed; otherwise a result whose ``missing`` records the retry.
        """
        # No filter was active, so a retry would be identical: just log.
        if not exclude_tests:
            self._log_failure(
                label=label,
                rag_session_id=rag_session_id,
                layer=layer,
                exclude_tests=exclude_tests,
                path_prefixes=path_prefixes,
                exc=exc,
            )
            return None
        self._log_failure(
            label=label,
            rag_session_id=rag_session_id,
            layer=layer,
            exclude_tests=exclude_tests,
            path_prefixes=path_prefixes,
            exc=exc,
            retried_without_test_filter=True,
        )
        try:
            rows = operation()
        except Exception as retry_exc:
            self._log_failure(
                label=f"{label} retry",
                rag_session_id=rag_session_id,
                layer=layer,
                exclude_tests=False,
                path_prefixes=path_prefixes,
                exc=retry_exc,
            )
            return None
        result = self._success_result(
            rows,
            rag_session_id=rag_session_id,
            label=f"{label} retry",
            include_spans=include_spans,
            layer=layer,
            exclude_tests=False,
            path_prefixes=path_prefixes,
        )
        # Record that results came from the unfiltered retry.
        result.missing.append(f"{missing_prefix}:retried_without_test_filter")
        return result

    def _success_result(
        self,
        rows: list[dict],
        *,
        rag_session_id: str,
        label: str,
        include_spans: bool,
        exclude_tests: bool,
        path_prefixes: list[str] | None,
        layer: str | None = None,
    ) -> LayerRetrievalResult:
        """Convert rows to items and log a summary of the retrieval."""
        items = [self._to_item(row, include_spans=include_spans) for row in rows]
        # NOTE(review): success path logs at WARNING level — presumably for
        # operational visibility; confirm this is intentional vs. INFO.
        LOGGER.warning(
            "%s: rag_session_id=%s layer=%s exclude_tests=%s path_prefixes=%s returned_count=%s top_paths=%s",
            label,
            rag_session_id,
            layer,
            exclude_tests,
            path_prefixes or [],
            len(items),
            [item.source for item in items[:3]],
        )
        return LayerRetrievalResult(items=items)

    def _log_failure(
        self,
        *,
        label: str,
        rag_session_id: str,
        exclude_tests: bool,
        path_prefixes: list[str] | None,
        exc: Exception,
        layer: str | None = None,
        retried_without_test_filter: bool = False,
    ) -> None:
        """Log a retrieval failure with full context and traceback."""
        LOGGER.warning(
            "%s failed: rag_session_id=%s layer=%s exclude_tests=%s path_prefixes=%s retried_without_test_filter=%s error=%s",
            label,
            rag_session_id,
            layer,
            exclude_tests,
            path_prefixes or [],
            retried_without_test_filter,
            self._exception_summary(exc),
            exc_info=True,
        )

    def _filter_args(self, exclude_tests: bool) -> dict[str, list[str] | None]:
        """Repository filter kwargs; both values None when filtering is off."""
        test_filters = build_test_filters() if exclude_tests else None
        return {
            "exclude_path_prefixes": test_filters.exclude_path_prefixes if test_filters else None,
            "exclude_like_patterns": test_filters.exclude_like_patterns if test_filters else None,
        }

    def _failure_missing(self, prefix: str, exc: Exception) -> str:
        """Compose a ``missing`` marker string from a prefix and an exception."""
        return f"{prefix}:{self._exception_summary(exc)}"

    def _exception_summary(self, exc: Exception) -> str:
        """One-line "Type:message" summary, whitespace-collapsed, max ~180 chars."""
        message = " ".join(str(exc).split())
        if len(message) > 180:
            message = message[:177] + "..."
        return f"{type(exc).__name__}:{message or 'no_message'}"

    def _to_item(self, row: dict, *, include_spans: bool) -> LayeredRetrievalItem:
        """Map a repository row to a LayeredRetrievalItem.

        ``location`` is only populated when *include_spans* is requested.
        """
        location = None
        if include_spans:
            location = CodeLocation(
                path=str(row.get("path") or ""),
                start_line=row.get("span_start"),
                end_line=row.get("span_end"),
            )
        return LayeredRetrievalItem(
            source=str(row.get("path") or ""),
            content=str(row.get("content") or ""),
            layer=str(row.get("layer") or ""),
            title=str(row.get("title") or ""),
            metadata=dict(row.get("metadata", {}) or {}),
            score=row.get("distance"),
            location=location,
        )
|
||||
91
app/modules/rag/explain/models.py
Normal file
91
app/modules/rag/explain/models.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ExplainHints(BaseModel):
    """Concrete anchors extracted verbatim from the user's question."""

    model_config = ConfigDict(extra="forbid")

    # Python file paths mentioned in the question (e.g. "app/foo.py").
    paths: list[str] = Field(default_factory=list)
    # Identifier-like tokens (e.g. "FooBar.run").
    symbols: list[str] = Field(default_factory=list)
    # URL routes (e.g. "/api/items/{id}").
    endpoints: list[str] = Field(default_factory=list)
    # CLI command names (backtick-quoted or following "command"/"cli").
    commands: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ExplainIntent(BaseModel):
    """Structured interpretation of a user's "explain this code" question."""

    model_config = ConfigDict(extra="forbid")

    # Original question text, verbatim.
    raw_query: str
    # raw_query with runs of whitespace collapsed to single spaces.
    normalized_query: str
    # Search terms (base terms plus extracted symbols/routes).
    keywords: list[str] = Field(default_factory=list)
    # Concrete anchors found in the question.
    hints: ExplainHints = Field(default_factory=ExplainHints)
    # True when the question asks about test coverage.
    include_tests: bool = False
    # Likely entry-point kinds; both values when undetermined.
    expected_entry_types: list[Literal["http", "cli"]] = Field(default_factory=list)
    # Requested level of detail for the explanation.
    depth: Literal["high", "medium", "deep"] = "medium"
|
||||
|
||||
|
||||
class CodeLocation(BaseModel):
    """A file path with an optional line span."""

    model_config = ConfigDict(extra="forbid")

    path: str
    # Span bounds; None when the producing layer has no span information.
    start_line: int | None = None
    end_line: int | None = None
|
||||
|
||||
|
||||
class LayeredRetrievalItem(BaseModel):
    """One retrieved chunk row from any retrieval layer."""

    model_config = ConfigDict(extra="forbid")

    # Chunk path (mirrors location.path when a location is attached).
    source: str
    content: str
    # Layer tag, e.g. "C0_SOURCE_CHUNKS", "C1_SYMBOL_CATALOG".
    layer: str
    title: str
    metadata: dict[str, Any] = Field(default_factory=dict)
    # Retrieval distance when available (populated from the "distance" column).
    score: float | None = None
    location: CodeLocation | None = None
|
||||
|
||||
|
||||
class TracePath(BaseModel):
    """A scored sequence of symbol ids forming one trace through the code."""

    model_config = ConfigDict(extra="forbid")

    symbol_ids: list[str] = Field(default_factory=list)
    score: float = 0.0
    # Id of the entrypoint this path starts from, when known.
    entrypoint_id: str | None = None
    notes: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class EvidenceItem(BaseModel):
    """A citable piece of evidence collected during retrieval."""

    model_config = ConfigDict(extra="forbid")

    evidence_id: str
    kind: Literal["entrypoint", "symbol", "edge", "excerpt"]
    summary: str
    location: CodeLocation | None = None
    # Ids/claims this evidence supports.
    supports: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class CodeExcerpt(BaseModel):
    """A snippet of source text tied to an evidence item."""

    model_config = ConfigDict(extra="forbid")

    evidence_id: str
    symbol_id: str | None = None
    title: str
    path: str
    start_line: int | None = None
    end_line: int | None = None
    content: str
    # "overview" for the whole chunk, "focus" for a zoomed-in window.
    focus: str = "overview"
|
||||
|
||||
|
||||
class ExplainPack(BaseModel):
    """Full evidence bundle for one explain request, consumed by PromptBudgeter."""

    model_config = ConfigDict(extra="forbid")

    intent: ExplainIntent
    selected_entrypoints: list[LayeredRetrievalItem] = Field(default_factory=list)
    seed_symbols: list[LayeredRetrievalItem] = Field(default_factory=list)
    trace_paths: list[TracePath] = Field(default_factory=list)
    # evidence_id -> evidence record.
    evidence_index: dict[str, EvidenceItem] = Field(default_factory=dict)
    code_excerpts: list[CodeExcerpt] = Field(default_factory=list)
    # Markers for evidence that could not be gathered (e.g. failed layers).
    missing: list[str] = Field(default_factory=list)
    conflicts: list[str] = Field(default_factory=list)
|
||||
328
app/modules/rag/explain/retriever_v2.py
Normal file
328
app/modules/rag/explain/retriever_v2.py
Normal file
@@ -0,0 +1,328 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.modules.rag.contracts.enums import RagLayer
|
||||
from app.modules.rag.explain.intent_builder import ExplainIntentBuilder
|
||||
from app.modules.rag.explain.layered_gateway import LayerRetrievalResult, LayeredRetrievalGateway
|
||||
from app.modules.rag.explain.models import CodeExcerpt, EvidenceItem, ExplainPack, LayeredRetrievalItem
|
||||
from app.modules.rag.explain.source_excerpt_fetcher import SourceExcerptFetcher
|
||||
from app.modules.rag.explain.trace_builder import TraceBuilder
|
||||
from app.modules.rag.retrieval.test_filter import exclude_tests_default, is_test_path
|
||||
|
||||
# Module-level logger for the explain retriever pipeline.
LOGGER = logging.getLogger(__name__)
# Minimum excerpt count below which the test-inclusive lexical fallback pass is run.
_MIN_EXCERPTS = 2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.modules.rag.explain.graph_repository import CodeGraphRepository
|
||||
from app.modules.rag.explain.models import ExplainIntent
|
||||
|
||||
|
||||
class CodeExplainRetrieverV2:
    """Builds an ExplainPack for an "explain this code" query.

    Pipeline: C3 entrypoints -> C1 seed symbols -> graph trace paths ->
    source excerpts, with a lexical code search as fallback when the graph
    yields no excerpts, and an optional test-inclusive second pass when too
    few excerpts were found with tests excluded.
    """

    def __init__(
        self,
        gateway: LayeredRetrievalGateway,
        graph_repository: CodeGraphRepository,
        intent_builder: ExplainIntentBuilder | None = None,
        trace_builder: TraceBuilder | None = None,
        excerpt_fetcher: SourceExcerptFetcher | None = None,
    ) -> None:
        """Store collaborators; optional ones default to graph-backed implementations."""
        self._gateway = gateway
        self._graph = graph_repository
        self._intent_builder = intent_builder or ExplainIntentBuilder()
        self._trace_builder = trace_builder or TraceBuilder(graph_repository)
        self._excerpt_fetcher = excerpt_fetcher or SourceExcerptFetcher(graph_repository)

    def build_pack(
        self,
        rag_session_id: str,
        user_query: str,
        *,
        file_candidates: list[dict] | None = None,
    ) -> ExplainPack:
        """Build and return the ExplainPack for *user_query*.

        ``file_candidates`` are dicts with at least a ``path`` key; their
        directory prefixes (plus intent path hints) narrow retrieval. When
        tests were excluded and the first pass produced fewer than
        ``_MIN_EXCERPTS`` excerpts, a test-inclusive lexical pass is merged in.
        """
        intent = self._intent_builder.build(user_query)
        path_prefixes = _path_prefixes(intent, file_candidates or [])
        # Tests are excluded only when both the global default and the intent agree.
        exclude_tests = exclude_tests_default() and not intent.include_tests
        pack = self._run_pass(rag_session_id, intent, path_prefixes, exclude_tests=exclude_tests)
        if exclude_tests and len(pack.code_excerpts) < _MIN_EXCERPTS:
            self._merge_test_fallback(pack, rag_session_id, intent, path_prefixes)
        self._log_pack(rag_session_id, pack)
        return pack

    def _run_pass(
        self,
        rag_session_id: str,
        intent: ExplainIntent,
        path_prefixes: list[str],
        *,
        exclude_tests: bool,
    ) -> ExplainPack:
        """Run one full retrieval pass and assemble the resulting ExplainPack."""
        missing: list[str] = []
        entrypoints_result = self._entrypoints(rag_session_id, intent, path_prefixes, exclude_tests=exclude_tests)
        missing.extend(entrypoints_result.missing)
        selected_entrypoints = self._filter_entrypoints(intent, entrypoints_result.items)
        if not selected_entrypoints:
            missing.append("layer:C3 empty")
        seed_result = self._seed_symbols(rag_session_id, intent, path_prefixes, selected_entrypoints, exclude_tests=exclude_tests)
        missing.extend(seed_result.missing)
        seed_symbols = seed_result.items
        if not seed_symbols:
            missing.append("layer:C1 empty")
        # Deeper intents trace further through the graph (deep=4, medium=3, else 2 hops).
        depth = 4 if intent.depth == "deep" else 3 if intent.depth == "medium" else 2
        trace_paths = self._trace_builder.build_paths(rag_session_id, seed_symbols, max_depth=depth) if seed_symbols else []
        excerpts, excerpt_evidence = self._excerpt_fetcher.fetch(rag_session_id, trace_paths) if trace_paths else ([], {})
        if not excerpts:
            # Graph route produced nothing usable: fall back to lexical code search (layer C0).
            lexical_result = self._gateway.retrieve_lexical_code(
                rag_session_id,
                intent.normalized_query,
                limit=6,
                path_prefixes=path_prefixes or None,
                exclude_tests=exclude_tests,
                include_spans=True,
            )
            missing.extend(lexical_result.missing)
            excerpts, excerpt_evidence = _lexical_excerpts(lexical_result.items)
            if not excerpts:
                missing.append("layer:C0 empty")
        evidence_index = _evidence_index(selected_entrypoints, seed_symbols)
        evidence_index.update(excerpt_evidence)
        missing.extend(_missing(selected_entrypoints, seed_symbols, trace_paths, excerpts))
        return ExplainPack(
            intent=intent,
            selected_entrypoints=selected_entrypoints,
            seed_symbols=seed_symbols,
            trace_paths=trace_paths,
            evidence_index=evidence_index,
            code_excerpts=excerpts,
            missing=_cleanup_missing(_dedupe(missing), has_excerpts=bool(excerpts)),
            conflicts=[],
        )

    def _merge_test_fallback(
        self,
        pack: ExplainPack,
        rag_session_id: str,
        intent: ExplainIntent,
        path_prefixes: list[str],
    ) -> None:
        """Mutate *pack* in place, merging excerpts from a test-inclusive lexical pass.

        Evidence ids continue numbering after the pack's existing ``excerpt_*``
        entries; duplicate excerpts (same path/span/content) are skipped.
        """
        lexical_result = self._gateway.retrieve_lexical_code(
            rag_session_id,
            intent.normalized_query,
            limit=6,
            path_prefixes=path_prefixes or None,
            exclude_tests=False,
            include_spans=True,
        )
        # Continue the excerpt_N numbering so merged evidence ids do not collide.
        excerpt_offset = len([key for key in pack.evidence_index if key.startswith("excerpt_")])
        excerpts, evidence = _lexical_excerpts(
            lexical_result.items,
            start_index=excerpt_offset,
            is_test_fallback=True,
        )
        if not excerpts:
            pack.missing = _dedupe(pack.missing + lexical_result.missing)
            return
        # Dedupe against excerpts already in the pack by exact path/span/content.
        seen = {(item.path, item.start_line, item.end_line, item.content) for item in pack.code_excerpts}
        for excerpt in excerpts:
            key = (excerpt.path, excerpt.start_line, excerpt.end_line, excerpt.content)
            if key in seen:
                continue
            pack.code_excerpts.append(excerpt)
            seen.add(key)
        pack.evidence_index.update(evidence)
        pack.missing = _cleanup_missing(_dedupe(pack.missing + lexical_result.missing), has_excerpts=bool(pack.code_excerpts))

    def _entrypoints(
        self,
        rag_session_id: str,
        intent: ExplainIntent,
        path_prefixes: list[str],
        *,
        exclude_tests: bool,
    ) -> LayerRetrievalResult:
        """Retrieve up to 6 candidate entrypoints from layer C3."""
        return self._gateway.retrieve_layer(
            rag_session_id,
            intent.normalized_query,
            RagLayer.CODE_ENTRYPOINTS,
            limit=6,
            path_prefixes=path_prefixes or None,
            exclude_tests=exclude_tests,
            prefer_non_tests=True,
            include_spans=True,
        )

    def _filter_entrypoints(self, intent: ExplainIntent, items: list[LayeredRetrievalItem]) -> list[LayeredRetrievalItem]:
        """Keep up to 3 entrypoints matching the intent's expected entry types.

        Falls back to the top 3 unfiltered items when the type filter matches nothing.
        """
        if not intent.expected_entry_types:
            return items[:3]
        filtered = [item for item in items if str(item.metadata.get("entry_type") or "") in intent.expected_entry_types]
        return filtered[:3] or items[:3]

    def _seed_symbols(
        self,
        rag_session_id: str,
        intent: ExplainIntent,
        path_prefixes: list[str],
        entrypoints: list[LayeredRetrievalItem],
        *,
        exclude_tests: bool,
    ) -> LayerRetrievalResult:
        """Retrieve C1 symbol-catalog seeds, prepending handlers of the selected entrypoints.

        Handler symbols referenced by entrypoint metadata are fetched from the
        graph and take priority over catalog hits; duplicates (by symbol_id)
        are dropped and the result is capped at 8 seeds.
        """
        symbol_result = self._gateway.retrieve_layer(
            rag_session_id,
            intent.normalized_query,
            RagLayer.CODE_SYMBOL_CATALOG,
            limit=12,
            path_prefixes=path_prefixes or None,
            exclude_tests=exclude_tests,
            prefer_non_tests=True,
            include_spans=True,
        )
        handlers: list[LayeredRetrievalItem] = []
        handler_ids = [str(item.metadata.get("handler_symbol_id") or "") for item in entrypoints]
        if handler_ids:
            handlers = self._graph.get_symbols_by_ids(rag_session_id, [item for item in handler_ids if item])
        seeds: list[LayeredRetrievalItem] = []
        seen: set[str] = set()
        for item in handlers + symbol_result.items:
            symbol_id = str(item.metadata.get("symbol_id") or "")
            if not symbol_id or symbol_id in seen:
                continue
            seen.add(symbol_id)
            seeds.append(item)
            if len(seeds) >= 8:
                break
        return LayerRetrievalResult(items=seeds, missing=list(symbol_result.missing))

    def _log_pack(self, rag_session_id: str, pack: ExplainPack) -> None:
        """Log a one-line summary of the pack's contents.

        NOTE(review): emitted at WARNING level, presumably so the summary
        surfaces under default log configs — confirm this is intentional.
        """
        prod_excerpt_count = len([excerpt for excerpt in pack.code_excerpts if not _is_test_excerpt(excerpt)])
        test_excerpt_count = len(pack.code_excerpts) - prod_excerpt_count
        LOGGER.warning(
            "code explain pack: rag_session_id=%s entrypoints=%s seeds=%s paths=%s excerpts=%s prod_excerpt_count=%s test_excerpt_count=%s missing=%s",
            rag_session_id,
            len(pack.selected_entrypoints),
            len(pack.seed_symbols),
            len(pack.trace_paths),
            len(pack.code_excerpts),
            prod_excerpt_count,
            test_excerpt_count,
            pack.missing,
        )
|
||||
|
||||
|
||||
def _evidence_index(
|
||||
entrypoints: list[LayeredRetrievalItem],
|
||||
seed_symbols: list[LayeredRetrievalItem],
|
||||
) -> dict[str, EvidenceItem]:
|
||||
result: dict[str, EvidenceItem] = {}
|
||||
for index, item in enumerate(entrypoints, start=1):
|
||||
evidence_id = f"entrypoint_{index}"
|
||||
result[evidence_id] = EvidenceItem(
|
||||
evidence_id=evidence_id,
|
||||
kind="entrypoint",
|
||||
summary=item.title,
|
||||
location=item.location,
|
||||
supports=[str(item.metadata.get("handler_symbol_id") or "")],
|
||||
)
|
||||
for index, item in enumerate(seed_symbols, start=1):
|
||||
evidence_id = f"symbol_{index}"
|
||||
result[evidence_id] = EvidenceItem(
|
||||
evidence_id=evidence_id,
|
||||
kind="symbol",
|
||||
summary=item.title,
|
||||
location=item.location,
|
||||
supports=[str(item.metadata.get("symbol_id") or "")],
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _missing(
|
||||
entrypoints: list[LayeredRetrievalItem],
|
||||
seed_symbols: list[LayeredRetrievalItem],
|
||||
trace_paths,
|
||||
excerpts,
|
||||
) -> list[str]:
|
||||
missing: list[str] = []
|
||||
if not entrypoints:
|
||||
missing.append("entrypoints")
|
||||
if not seed_symbols:
|
||||
missing.append("seed_symbols")
|
||||
if not trace_paths:
|
||||
missing.append("trace_paths")
|
||||
if not excerpts:
|
||||
missing.append("code_excerpts")
|
||||
return missing
|
||||
|
||||
|
||||
def _lexical_excerpts(
|
||||
items: list[LayeredRetrievalItem],
|
||||
*,
|
||||
start_index: int = 0,
|
||||
is_test_fallback: bool = False,
|
||||
) -> tuple[list[CodeExcerpt], dict[str, EvidenceItem]]:
|
||||
excerpts: list[CodeExcerpt] = []
|
||||
evidence_index: dict[str, EvidenceItem] = {}
|
||||
for item in items:
|
||||
evidence_id = f"excerpt_{start_index + len(evidence_index) + 1}"
|
||||
location = item.location
|
||||
evidence_index[evidence_id] = EvidenceItem(
|
||||
evidence_id=evidence_id,
|
||||
kind="excerpt",
|
||||
summary=item.title or item.source,
|
||||
location=location,
|
||||
supports=[],
|
||||
)
|
||||
focus = "lexical"
|
||||
if _item_is_test(item):
|
||||
focus = "test:lexical"
|
||||
elif is_test_fallback:
|
||||
focus = "lexical"
|
||||
excerpts.append(
|
||||
CodeExcerpt(
|
||||
evidence_id=evidence_id,
|
||||
symbol_id=str(item.metadata.get("symbol_id") or "") or None,
|
||||
title=item.title or item.source,
|
||||
path=item.source,
|
||||
start_line=location.start_line if location else None,
|
||||
end_line=location.end_line if location else None,
|
||||
content=item.content,
|
||||
focus=focus,
|
||||
)
|
||||
)
|
||||
return excerpts, evidence_index
|
||||
|
||||
|
||||
def _item_is_test(item: LayeredRetrievalItem) -> bool:
|
||||
return bool(item.metadata.get("is_test")) or is_test_path(item.source)
|
||||
|
||||
|
||||
def _is_test_excerpt(excerpt: CodeExcerpt) -> bool:
|
||||
return excerpt.focus.startswith("test:") or is_test_path(excerpt.path)
|
||||
|
||||
|
||||
def _path_prefixes(intent: ExplainIntent, file_candidates: list[dict]) -> list[str]:
|
||||
values: list[str] = []
|
||||
for path in intent.hints.paths:
|
||||
prefix = path.rsplit("/", 1)[0] if "/" in path else path
|
||||
if prefix and prefix not in values:
|
||||
values.append(prefix)
|
||||
for item in file_candidates[:6]:
|
||||
path = str(item.get("path") or "")
|
||||
prefix = path.rsplit("/", 1)[0] if "/" in path else ""
|
||||
if prefix and prefix not in values:
|
||||
values.append(prefix)
|
||||
return values
|
||||
|
||||
|
||||
def _cleanup_missing(values: list[str], *, has_excerpts: bool) -> list[str]:
|
||||
if not has_excerpts:
|
||||
return values
|
||||
return [value for value in values if value not in {"code_excerpts", "layer:C0 empty"}]
|
||||
|
||||
|
||||
def _dedupe(values: list[str]) -> list[str]:
|
||||
result: list[str] = []
|
||||
for value in values:
|
||||
item = value.strip()
|
||||
if item and item not in result:
|
||||
result.append(item)
|
||||
return result
|
||||
53
app/modules/rag/explain/source_excerpt_fetcher.py
Normal file
53
app/modules/rag/explain/source_excerpt_fetcher.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.modules.rag.explain.excerpt_planner import ExcerptPlanner
|
||||
from app.modules.rag.explain.models import CodeExcerpt, EvidenceItem, TracePath
|
||||
from app.modules.rag.retrieval.test_filter import is_test_path
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.modules.rag.explain.graph_repository import CodeGraphRepository
|
||||
|
||||
|
||||
class SourceExcerptFetcher:
    """Turns trace paths into source-code excerpts using chunks stored in the code graph."""

    def __init__(self, graph_repository: CodeGraphRepository, planner: ExcerptPlanner | None = None) -> None:
        """Store the graph repository and the excerpt planner (defaulting to a fresh one)."""
        self._graph = graph_repository
        self._planner = planner or ExcerptPlanner()

    def fetch(
        self,
        rag_session_id: str,
        trace_paths: list[TracePath],
        *,
        max_excerpts: int = 40,
    ) -> tuple[list[CodeExcerpt], dict[str, EvidenceItem]]:
        """Fetch planned excerpts for every symbol appearing on *trace_paths*.

        Returns the excerpts (hard-capped at *max_excerpts*) and an evidence
        index keyed ``excerpt_N`` with one entry per graph chunk. Note an
        evidence entry is recorded even when the planner yields no excerpt
        for that chunk.
        """
        # De-duplicate symbol ids while preserving trace-path order.
        ordered_symbol_ids: list[str] = []
        for path in trace_paths:
            for symbol_id in path.symbol_ids:
                if symbol_id and symbol_id not in ordered_symbol_ids:
                    ordered_symbol_ids.append(symbol_id)
        chunks = self._graph.get_chunks_by_symbol_ids(rag_session_id, ordered_symbol_ids)
        excerpts: list[CodeExcerpt] = []
        evidence_index: dict[str, EvidenceItem] = {}
        for chunk in chunks:
            symbol_id = str(chunk.metadata.get("symbol_id") or "")
            evidence_id = f"excerpt_{len(evidence_index) + 1}"
            location = chunk.location
            evidence_index[evidence_id] = EvidenceItem(
                evidence_id=evidence_id,
                kind="excerpt",
                summary=chunk.title,
                location=location,
                supports=[symbol_id] if symbol_id else [],
            )
            is_test_chunk = bool(chunk.metadata.get("is_test")) or is_test_path(location.path if location else chunk.source)
            for excerpt in self._planner.plan(chunk, evidence_id=evidence_id, symbol_id=symbol_id):
                if len(excerpts) >= max_excerpts:
                    break
                # Mark excerpts originating from test files so consumers can tell them apart.
                if is_test_chunk and not excerpt.focus.startswith("test:"):
                    excerpt.focus = f"test:{excerpt.focus}"
                excerpts.append(excerpt)
            # Stop iterating chunks entirely once the excerpt cap is reached.
            if len(excerpts) >= max_excerpts:
                break
        return excerpts, evidence_index
|
||||
102
app/modules/rag/explain/trace_builder.py
Normal file
102
app/modules/rag/explain/trace_builder.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.modules.rag.explain.models import LayeredRetrievalItem, TracePath
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.modules.rag.explain.graph_repository import CodeGraphRepository
|
||||
|
||||
|
||||
class TraceBuilder:
    """Traces call/import/inherit paths through the code graph from seed symbols."""

    def __init__(self, graph_repository: CodeGraphRepository) -> None:
        self._graph = graph_repository

    def build_paths(
        self,
        rag_session_id: str,
        seed_symbols: list[LayeredRetrievalItem],
        *,
        max_depth: int,
        max_paths: int = 3,
        edge_types: list[str] | None = None,
    ) -> list[TracePath]:
        """Breadth-first expand each seed along graph edges into scored TracePaths.

        A path terminates when its node has no outgoing edges, it reaches
        *max_depth*, a destination cannot be resolved, or a cycle is detected.
        Collected paths are de-duplicated, sorted by score (descending), and
        the top *max_paths* returned; when nothing was collected, a trivial
        single-node path for the first seed is returned instead.
        """
        edges_filter = edge_types or ["calls", "imports", "inherits"]
        symbol_map = self._symbol_map(seed_symbols)
        paths: list[TracePath] = []
        for seed in seed_symbols:
            seed_id = str(seed.metadata.get("symbol_id") or "")
            if not seed_id:
                continue
            # Queue entries: (path of symbol ids, accumulated score, notes).
            # NOTE(review): list.pop(0) is O(n); a collections.deque would be
            # cheaper, though the 3*max_paths bound keeps the queue small.
            queue: list[tuple[list[str], float, list[str]]] = [([seed_id], 0.0, [])]
            while queue and len(paths) < max_paths * 3:
                current_path, score, notes = queue.pop(0)
                src_symbol_id = current_path[-1]
                out_edges = self._graph.get_out_edges(rag_session_id, [src_symbol_id], edges_filter, limit_per_src=4)
                if not out_edges or len(current_path) >= max_depth:
                    # Dead end or depth limit: finalize this path.
                    paths.append(TracePath(symbol_ids=current_path, score=score, notes=notes))
                    continue
                for edge in out_edges:
                    metadata = edge.metadata
                    dst_symbol_id = str(metadata.get("dst_symbol_id") or "")
                    next_notes = list(notes)
                    next_score = score + self._edge_score(edge, symbol_map.get(src_symbol_id))
                    if not dst_symbol_id:
                        # Unresolved edge: try resolving its textual ref against the graph.
                        dst_ref = str(metadata.get("dst_ref") or "")
                        package_hint = self._package_hint(symbol_map.get(src_symbol_id))
                        resolved = self._graph.resolve_symbol_by_ref(rag_session_id, dst_ref, package_hint=package_hint)
                        if resolved is not None:
                            dst_symbol_id = str(resolved.metadata.get("symbol_id") or "")
                            symbol_map[dst_symbol_id] = resolved
                            # Successful resolution is rewarded and recorded in the notes.
                            next_score += 2.0
                            next_notes.append(f"resolved:{dst_ref}")
                    if not dst_symbol_id or dst_symbol_id in current_path:
                        # Unresolvable destination or cycle: finalize without extending.
                        paths.append(TracePath(symbol_ids=current_path, score=next_score, notes=next_notes))
                        continue
                    if dst_symbol_id not in symbol_map:
                        symbols = self._graph.get_symbols_by_ids(rag_session_id, [dst_symbol_id])
                        if symbols:
                            symbol_map[dst_symbol_id] = symbols[0]
                    queue.append((current_path + [dst_symbol_id], next_score, next_notes))
        unique = self._unique_paths(paths)
        unique.sort(key=lambda item: item.score, reverse=True)
        return unique[:max_paths] or [TracePath(symbol_ids=[seed.metadata.get("symbol_id", "")], score=0.0) for seed in seed_symbols[:1]]

    def _edge_score(self, edge: LayeredRetrievalItem, source_symbol: LayeredRetrievalItem | None) -> float:
        """Score one edge: +2 if resolved, +1 for same-file hops, -3 for test files."""
        metadata = edge.metadata
        score = 1.0
        if str(metadata.get("resolution") or "") == "resolved":
            score += 2.0
        source_path = source_symbol.source if source_symbol is not None else ""
        if source_path and edge.source == source_path:
            score += 1.0
        if "tests/" in edge.source or "/tests/" in edge.source:
            score -= 3.0
        return score

    def _package_hint(self, symbol: LayeredRetrievalItem | None) -> str | None:
        """Return the parent package of the symbol's module, or None when unknown."""
        if symbol is None:
            return None
        package = str(symbol.metadata.get("package_or_module") or "")
        if not package:
            return None
        # Drop the last dotted segment; a single-segment package falls back to itself.
        return ".".join(package.split(".")[:-1]) or package

    def _symbol_map(self, items: list[LayeredRetrievalItem]) -> dict[str, LayeredRetrievalItem]:
        """Index retrieval items by their symbol_id metadata (items without one are skipped)."""
        result: dict[str, LayeredRetrievalItem] = {}
        for item in items:
            symbol_id = str(item.metadata.get("symbol_id") or "")
            if symbol_id:
                result[symbol_id] = item
        return result

    def _unique_paths(self, items: list[TracePath]) -> list[TracePath]:
        """De-duplicate paths by their non-empty symbol-id tuple, keeping first occurrence."""
        result: list[TracePath] = []
        seen: set[tuple[str, ...]] = set()
        for item in items:
            key = tuple(symbol_id for symbol_id in item.symbol_ids if symbol_id)
            if not key or key in seen:
                continue
            seen.add(key)
            result.append(item)
        return result
|
||||
Reference in New Issue
Block a user