Фиксация изменений

This commit is contained in:
2026-03-05 11:03:17 +03:00
parent 1ef0b4d68c
commit 417b8b6f72
261 changed files with 8215 additions and 332 deletions

View File

@@ -0,0 +1,36 @@
from __future__ import annotations
from importlib import import_module
# Public names re-exported by this package. Each one is resolved lazily via
# the module-level __getattr__ below (PEP 562), so importing the package does
# not pull in the heavy submodules until a symbol is actually accessed.
__all__ = [
    "CodeExcerpt",
    "CodeExplainRetrieverV2",
    "CodeGraphRepository",
    "EvidenceItem",
    "ExplainIntent",
    "ExplainIntentBuilder",
    "ExplainPack",
    "LayeredRetrievalGateway",
    "PromptBudgeter",
    "TracePath",
]
def __getattr__(attr: str):
    """Lazily resolve a public package name to its defining submodule (PEP 562).

    Raises AttributeError for names outside the mapping, which keeps normal
    attribute-error semantics for typos.
    """
    locations = {
        "CodeExcerpt": "app.modules.rag.explain.models",
        "EvidenceItem": "app.modules.rag.explain.models",
        "ExplainIntent": "app.modules.rag.explain.models",
        "ExplainPack": "app.modules.rag.explain.models",
        "TracePath": "app.modules.rag.explain.models",
        "ExplainIntentBuilder": "app.modules.rag.explain.intent_builder",
        "PromptBudgeter": "app.modules.rag.explain.budgeter",
        "LayeredRetrievalGateway": "app.modules.rag.explain.layered_gateway",
        "CodeGraphRepository": "app.modules.rag.explain.graph_repository",
        "CodeExplainRetrieverV2": "app.modules.rag.explain.retriever_v2",
    }
    try:
        target = locations[attr]
    except KeyError:
        raise AttributeError(attr) from None
    return getattr(import_module(target), attr)

View File

@@ -0,0 +1,62 @@
from __future__ import annotations
import json
from app.modules.rag.explain.models import ExplainPack
class PromptBudgeter:
    """Serialises an ExplainPack into a size-capped JSON prompt payload."""

    def __init__(
        self,
        *,
        max_paths: int = 3,
        max_symbols: int = 25,
        max_excerpts: int = 40,
        max_chars: int = 30000,
    ) -> None:
        # Hard caps applied while serialising the pack.
        self._max_paths = max_paths
        self._max_symbols = max_symbols
        self._max_excerpts = max_excerpts
        self._max_chars = max_chars

    def build_prompt_input(self, question: str, pack: ExplainPack) -> str:
        """Return a pretty-printed JSON string with budgets enforced.

        Excerpts are restricted to symbols appearing on the kept trace paths
        (when any), capped in count, and truncated to the character budget.
        """
        allowed_symbols = self._collect_symbol_ids(pack)
        payload = {
            "question": question,
            "intent": pack.intent.model_dump(mode="json"),
            "selected_entrypoints": [item.model_dump(mode="json") for item in pack.selected_entrypoints[:5]],
            "seed_symbols": [item.model_dump(mode="json") for item in pack.seed_symbols[: self._max_symbols]],
            "trace_paths": [path.model_dump(mode="json") for path in pack.trace_paths[: self._max_paths]],
            "evidence_index": {key: value.model_dump(mode="json") for key, value in pack.evidence_index.items()},
            "code_excerpts": self._budget_excerpts(pack, allowed_symbols),
            "missing": pack.missing,
            "conflicts": pack.conflicts,
        }
        return json.dumps(payload, ensure_ascii=False, indent=2)

    def _collect_symbol_ids(self, pack: ExplainPack) -> list[str]:
        # Symbols referenced by the kept trace paths, deduplicated and capped.
        collected: list[str] = []
        for trace in pack.trace_paths[: self._max_paths]:
            for sid in trace.symbol_ids:
                if sid and sid not in collected and len(collected) < self._max_symbols:
                    collected.append(sid)
        return collected

    def _budget_excerpts(self, pack: ExplainPack, allowed_symbols: list[str]) -> list[dict]:
        # Keep excerpts that match the traced symbols (or all, when no trace
        # symbols exist), truncating the last one to fit the char budget.
        selected: list[dict] = []
        used_chars = 0
        for excerpt in pack.code_excerpts:
            if allowed_symbols and excerpt.symbol_id and excerpt.symbol_id not in allowed_symbols:
                continue
            snippet = excerpt.content.strip()
            budget_left = self._max_chars - used_chars
            if budget_left <= 0 or len(selected) >= self._max_excerpts:
                break
            if len(snippet) > budget_left:
                snippet = snippet[:budget_left].rstrip() + "...[truncated]"
            selected.append(
                {
                    "evidence_id": excerpt.evidence_id,
                    "title": excerpt.title,
                    "path": excerpt.path,
                    "start_line": excerpt.start_line,
                    "end_line": excerpt.end_line,
                    "focus": excerpt.focus,
                    "content": snippet,
                }
            )
            used_chars += len(snippet)
        return selected

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from app.modules.rag.explain.models import CodeExcerpt, LayeredRetrievalItem
class ExcerptPlanner:
    """Builds prompt excerpts for a chunk: a full overview plus, when an
    'interesting' line is found, a small zoomed-in focus window."""

    # Tokens marking lines worth zooming into (error handling, DB, I/O).
    _FOCUS_TOKENS = ("raise", "except", "db", "select", "insert", "update", "delete", "http", "publish", "emit")

    def plan(self, chunk: LayeredRetrievalItem, *, evidence_id: str, symbol_id: str | None) -> list[CodeExcerpt]:
        """Return one or two excerpts for *chunk*; [] when it has no location."""
        where = chunk.location
        if where is None:
            return []
        planned = [
            CodeExcerpt(
                evidence_id=evidence_id,
                symbol_id=symbol_id,
                title=chunk.title,
                path=where.path,
                start_line=where.start_line,
                end_line=where.end_line,
                content=chunk.content.strip(),
                focus="overview",
            )
        ]
        zoom = self._focus_excerpt(chunk, evidence_id=evidence_id, symbol_id=symbol_id)
        if zoom is not None:
            planned.append(zoom)
        return planned

    def _focus_excerpt(
        self,
        chunk: LayeredRetrievalItem,
        *,
        evidence_id: str,
        symbol_id: str | None,
    ) -> CodeExcerpt | None:
        """Return a +/-2-line window around the first interesting line, if any."""
        where = chunk.location
        if where is None:
            return None
        source_lines = chunk.content.splitlines()
        for row, raw_line in enumerate(source_lines):
            haystack = raw_line.lower()
            if not any(token in haystack for token in self._FOCUS_TOKENS):
                continue
            window_start = max(0, row - 2)
            window_end = min(len(source_lines), row + 3)
            # A window spanning the whole chunk adds nothing over the overview.
            if window_end - window_start >= len(source_lines):
                return None
            base_line = where.start_line or 1
            return CodeExcerpt(
                evidence_id=evidence_id,
                symbol_id=symbol_id,
                title=f"{chunk.title}:focus",
                path=where.path,
                start_line=base_line + window_start,
                end_line=base_line + window_end - 1,
                content="\n".join(source_lines[window_start:window_end]).strip(),
                focus="focus",
            )
        return None

View File

@@ -0,0 +1,216 @@
from __future__ import annotations
import json
from sqlalchemy import text
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
from app.modules.shared.db import get_engine
class CodeGraphRepository:
    """Read-only lookups against the layered code graph in ``rag_chunks``.

    Layers touched here:
      * ``C1_SYMBOL_CATALOG``   — one row per declared symbol,
      * ``C2_DEPENDENCY_GRAPH`` — one row per edge between symbols,
      * ``C0_SOURCE_CHUNKS``    — raw source excerpts with line spans.
    Every query is scoped to a single ``rag_session_id``.
    """

    def get_out_edges(
        self,
        rag_session_id: str,
        src_symbol_ids: list[str],
        edge_types: list[str],
        limit_per_src: int,
    ) -> list[LayeredRetrievalItem]:
        """Return outgoing edges of the given types for the source symbols.

        Rows arrive ordered by (path, span_start); after the first
        ``limit_per_src`` rows for a given source symbol the rest are skipped.
        """
        if not src_symbol_ids:
            return []
        sql = """
        SELECT path, content, layer, title, metadata_json, span_start, span_end
        FROM rag_chunks
        WHERE rag_session_id = :sid
          AND layer = 'C2_DEPENDENCY_GRAPH'
          AND CAST(metadata_json AS jsonb)->>'src_symbol_id' = ANY(:src_ids)
          AND CAST(metadata_json AS jsonb)->>'edge_type' = ANY(:edge_types)
        ORDER BY path, span_start
        """
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(sql),
                {"sid": rag_session_id, "src_ids": src_symbol_ids, "edge_types": edge_types},
            ).mappings().fetchall()
        # Running per-source count used to enforce limit_per_src.
        grouped: dict[str, int] = {}
        items: list[LayeredRetrievalItem] = []
        for row in rows:
            metadata = self._loads(row.get("metadata_json"))
            src_symbol_id = str(metadata.get("src_symbol_id") or "")
            grouped[src_symbol_id] = grouped.get(src_symbol_id, 0) + 1
            if grouped[src_symbol_id] > limit_per_src:
                continue
            items.append(self._to_item(row, metadata))
        return items

    def get_in_edges(
        self,
        rag_session_id: str,
        dst_symbol_ids: list[str],
        edge_types: list[str],
        limit_per_dst: int,
    ) -> list[LayeredRetrievalItem]:
        """Return incoming edges of the given types for the destination symbols.

        Mirror of :meth:`get_out_edges`, capped per destination symbol.
        """
        if not dst_symbol_ids:
            return []
        sql = """
        SELECT path, content, layer, title, metadata_json, span_start, span_end
        FROM rag_chunks
        WHERE rag_session_id = :sid
          AND layer = 'C2_DEPENDENCY_GRAPH'
          AND CAST(metadata_json AS jsonb)->>'dst_symbol_id' = ANY(:dst_ids)
          AND CAST(metadata_json AS jsonb)->>'edge_type' = ANY(:edge_types)
        ORDER BY path, span_start
        """
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(sql),
                {"sid": rag_session_id, "dst_ids": dst_symbol_ids, "edge_types": edge_types},
            ).mappings().fetchall()
        # Running per-destination count used to enforce limit_per_dst.
        grouped: dict[str, int] = {}
        items: list[LayeredRetrievalItem] = []
        for row in rows:
            metadata = self._loads(row.get("metadata_json"))
            dst_symbol_id = str(metadata.get("dst_symbol_id") or "")
            grouped[dst_symbol_id] = grouped.get(dst_symbol_id, 0) + 1
            if grouped[dst_symbol_id] > limit_per_dst:
                continue
            items.append(self._to_item(row, metadata))
        return items

    def resolve_symbol_by_ref(
        self,
        rag_session_id: str,
        dst_ref: str,
        package_hint: str | None = None,
    ) -> LayeredRetrievalItem | None:
        """Resolve a textual reference to the best-matching catalog symbol.

        Candidates match by exact qname, exact title, or qname suffix; each is
        scored (exact qname +3, exact title +2, package-prefix match +3, path
        containing the hint +1) and the highest-scoring row wins. Returns None
        for an empty reference or when no candidate exists.
        """
        ref = (dst_ref or "").strip()
        if not ref:
            return None
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(
                    """
                    SELECT path, content, layer, title, metadata_json, span_start, span_end, qname
                    FROM rag_chunks
                    WHERE rag_session_id = :sid
                      AND layer = 'C1_SYMBOL_CATALOG'
                      AND (qname = :ref OR title = :ref OR qname LIKE :tail)
                    ORDER BY path
                    LIMIT 12
                    """
                ),
                {"sid": rag_session_id, "ref": ref, "tail": f"%{ref}"},
            ).mappings().fetchall()
        best: LayeredRetrievalItem | None = None
        best_score = -1
        for row in rows:
            metadata = self._loads(row.get("metadata_json"))
            package = str(metadata.get("package_or_module") or "")
            score = 0
            if str(row.get("qname") or "") == ref:
                score += 3
            if str(row.get("title") or "") == ref:
                score += 2
            if package_hint and package.startswith(package_hint):
                score += 3
            if package_hint and package_hint in str(row.get("path") or ""):
                score += 1
            if score > best_score:
                best = self._to_item(row, metadata)
                best_score = score
        return best

    def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]) -> list[LayeredRetrievalItem]:
        """Fetch catalog rows for the given symbol ids (ordered by path/span)."""
        if not symbol_ids:
            return []
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(
                    """
                    SELECT path, content, layer, title, metadata_json, span_start, span_end
                    FROM rag_chunks
                    WHERE rag_session_id = :sid
                      AND layer = 'C1_SYMBOL_CATALOG'
                      AND symbol_id = ANY(:symbol_ids)
                    ORDER BY path, span_start
                    """
                ),
                {"sid": rag_session_id, "symbol_ids": symbol_ids},
            ).mappings().fetchall()
        return [self._to_item(row, self._loads(row.get("metadata_json"))) for row in rows]

    def get_chunks_by_symbol_ids(
        self,
        rag_session_id: str,
        symbol_ids: list[str],
        prefer_chunk_type: str = "symbol_block",
    ) -> list[LayeredRetrievalItem]:
        """Map symbol ids to their best-matching C0 source chunks.

        Symbols lacking a location, or with no overlapping source chunk, are
        silently dropped from the result.
        """
        symbols = self.get_symbols_by_ids(rag_session_id, symbol_ids)
        chunks: list[LayeredRetrievalItem] = []
        for symbol in symbols:
            location = symbol.location
            if location is None:
                continue
            chunk = self._chunk_for_symbol(rag_session_id, symbol, prefer_chunk_type=prefer_chunk_type)
            if chunk is not None:
                chunks.append(chunk)
        return chunks

    def _chunk_for_symbol(
        self,
        rag_session_id: str,
        symbol: LayeredRetrievalItem,
        *,
        prefer_chunk_type: str,
    ) -> LayeredRetrievalItem | None:
        """Pick one C0 chunk overlapping the symbol's line span.

        Preference order: chunks whose metadata chunk_type equals
        *prefer_chunk_type* first, then smallest distance between chunk start
        and symbol start.
        """
        location = symbol.location
        if location is None:
            return None
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(
                    """
                    SELECT path, content, layer, title, metadata_json, span_start, span_end
                    FROM rag_chunks
                    WHERE rag_session_id = :sid
                      AND layer = 'C0_SOURCE_CHUNKS'
                      AND path = :path
                      AND COALESCE(span_start, 0) <= :end_line
                      AND COALESCE(span_end, 999999) >= :start_line
                    ORDER BY
                      CASE WHEN CAST(metadata_json AS jsonb)->>'chunk_type' = :prefer_chunk_type THEN 0 ELSE 1 END,
                      ABS(COALESCE(span_start, 0) - :start_line)
                    LIMIT 1
                    """
                ),
                {
                    "sid": rag_session_id,
                    "path": location.path,
                    "start_line": location.start_line or 0,
                    "end_line": location.end_line or 999999,
                    "prefer_chunk_type": prefer_chunk_type,
                },
            ).mappings().fetchall()
        if not rows:
            return None
        row = rows[0]
        return self._to_item(row, self._loads(row.get("metadata_json")))

    def _to_item(self, row, metadata: dict) -> LayeredRetrievalItem:
        """Convert a DB row mapping plus parsed metadata into a retrieval item."""
        return LayeredRetrievalItem(
            source=str(row.get("path") or ""),
            content=str(row.get("content") or ""),
            layer=str(row.get("layer") or ""),
            title=str(row.get("title") or ""),
            metadata=metadata,
            location=CodeLocation(
                path=str(row.get("path") or ""),
                start_line=row.get("span_start"),
                end_line=row.get("span_end"),
            ),
        )

    def _loads(self, value) -> dict:
        """Parse a metadata_json cell; NULL/empty becomes {}.

        NOTE(review): assumes the column holds JSON text — if the driver ever
        returns an already-parsed dict, json.loads(str(...)) would fail;
        confirm the column type/driver behavior.
        """
        if not value:
            return {}
        return json.loads(str(value))

View File

@@ -0,0 +1,102 @@
from __future__ import annotations
import re
from app.modules.rag.explain.models import ExplainHints, ExplainIntent
from app.modules.rag.retrieval.query_terms import extract_query_terms
class ExplainIntentBuilder:
    """Turns a free-form user question into a structured ExplainIntent."""

    # Artifact extractors. The pattern texts are observable behavior: keep
    # them byte-for-byte when editing.
    _ROUTE_RE = re.compile(r"(/[A-Za-z0-9_./{}:-]+)")
    _FILE_RE = re.compile(r"([A-Za-z0-9_./-]+\.py)")
    _SYMBOL_RE = re.compile(r"\b([A-Z][A-Za-z0-9_]*\.[A-Za-z_][A-Za-z0-9_]*|[A-Z][A-Za-z0-9_]{2,}|[a-z_][A-Za-z0-9_]{2,})\b")
    _COMMAND_RE = re.compile(r"`([A-Za-z0-9:_-]+)`")
    # Substrings (Russian and English) that signal a test-coverage question.
    _TEST_KEYWORDS = (
        "тест",
        "tests",
        "test ",
        "unit-test",
        "unit test",
        "юнит-тест",
        "pytest",
        "spec",
        "как покрыто тестами",
        "как проверяется",
        "how is it tested",
        "how it's tested",
    )

    def build(self, user_query: str) -> ExplainIntent:
        """Normalise the query and extract hints, keywords, depth and flags."""
        compact = " ".join((user_query or "").split())
        folded = compact.lower()
        terms = self._keywords(compact)
        hints = ExplainHints(
            paths=self._dedupe(self._FILE_RE.findall(compact)),
            symbols=self._symbols(compact),
            endpoints=self._dedupe(self._ROUTE_RE.findall(compact)),
            commands=self._commands(compact, folded),
        )
        return ExplainIntent(
            raw_query=user_query,
            normalized_query=compact,
            keywords=terms[:12],
            hints=hints,
            include_tests=self._include_tests(folded),
            expected_entry_types=self._entry_types(folded, hints),
            depth=self._depth(folded),
        )

    def _keywords(self, text: str) -> list[str]:
        # Base terms from the shared extractor, enriched with symbol-like
        # tokens and route fragments found in the raw text.
        found = extract_query_terms(text)
        for token in self._symbols(text) + self._ROUTE_RE.findall(text):
            if token not in found:
                found.append(token)
        return self._dedupe(found)

    def _symbols(self, text: str) -> list[str]:
        # Symbol-like tokens, at least 3 chars, excluding file names.
        candidates = [match.strip() for match in self._SYMBOL_RE.findall(text)]
        kept = [token for token in candidates if len(token) >= 3 and not token.endswith(".py")]
        return self._dedupe(kept)

    def _commands(self, text: str, lowered: str) -> list[str]:
        # Backtick-quoted commands plus words following "command"/"cli".
        found = list(self._COMMAND_RE.findall(text))
        padded = f" {lowered} "
        if " command " in padded:
            found += re.findall(r"command\s+([A-Za-z0-9:_-]+)", lowered)
        if " cli " in padded:
            found += re.findall(r"cli\s+([A-Za-z0-9:_-]+)", lowered)
        return self._dedupe(found)

    def _entry_types(self, lowered: str, hints: ExplainHints) -> list[str]:
        # HTTP evidence wins over CLI; with neither, search both.
        http_markers = ("endpoint", "route", "handler", "http", "api")
        cli_markers = ("cli", "command", "click", "typer")
        if hints.endpoints or any(marker in lowered for marker in http_markers):
            return ["http"]
        if hints.commands or any(marker in lowered for marker in cli_markers):
            return ["cli"]
        return ["http", "cli"]

    def _depth(self, lowered: str) -> str:
        # Depth markers in Russian and English; default is "medium".
        deep_markers = ("deep", "подроб", "деталь", "full flow", "trace")
        high_markers = ("high level", "overview", "кратко", "summary")
        if any(marker in lowered for marker in deep_markers):
            return "deep"
        if any(marker in lowered for marker in high_markers):
            return "high"
        return "medium"

    def _include_tests(self, lowered: str) -> bool:
        # Padding with spaces lets "test " match at the string end too.
        padded = f" {lowered} "
        return any(marker in padded for marker in self._TEST_KEYWORDS)

    def _dedupe(self, values: list[str]) -> list[str]:
        # Order-preserving dedupe of stripped, non-empty strings.
        unique: list[str] = []
        for raw in values:
            cleaned = raw.strip()
            if cleaned and cleaned not in unique:
                unique.append(cleaned)
        return unique

View File

@@ -0,0 +1,289 @@
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
from app.modules.rag.retrieval.test_filter import build_test_filters, debug_disable_test_filter
LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING:
from app.modules.rag.persistence.repository import RagRepository
from app.modules.rag_session.embedding.gigachat_embedder import GigaChatEmbedder
@dataclass(slots=True)
class LayerRetrievalResult:
    """Outcome of one layer retrieval: the found items plus gap diagnostics."""

    # Retrieved chunks, already converted to LayeredRetrievalItem.
    items: list[LayeredRetrievalItem]
    # Markers describing what could not be retrieved and why.
    missing: list[str] = field(default_factory=list)
class LayeredRetrievalGateway:
    """Wraps repository retrieval with test filtering and retry-on-failure.

    Both public entry points first query with the configured test-file
    filters; if retrieval raises while filters were active, they retry once
    without filters and record the degradation in ``missing``.
    """

    def __init__(self, repository: RagRepository, embedder: GigaChatEmbedder) -> None:
        self._repository = repository
        self._embedder = embedder

    def retrieve_layer(
        self,
        rag_session_id: str,
        query: str,
        layer: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        prefer_non_tests: bool = False,
        include_spans: bool = False,
    ) -> LayerRetrievalResult:
        """Embed *query* and retrieve up to *limit* chunks from one layer.

        Returns an empty result (with a ``missing`` marker) when embedding
        itself fails; retries without test filters when retrieval fails.
        """
        effective_exclude_tests = exclude_tests and not debug_disable_test_filter()
        filter_args = self._filter_args(effective_exclude_tests)
        query_embedding: list[float] | None = None
        try:
            query_embedding = self._embedder.embed([query])[0]
            rows = self._repository.retrieve(
                rag_session_id,
                query_embedding,
                query_text=query,
                limit=limit,
                layers=[layer],
                path_prefixes=path_prefixes,
                exclude_path_prefixes=filter_args["exclude_path_prefixes"],
                exclude_like_patterns=filter_args["exclude_like_patterns"],
                prefer_non_tests=prefer_non_tests or not effective_exclude_tests,
            )
            return self._success_result(
                rows,
                rag_session_id=rag_session_id,
                label="layered retrieval",
                include_spans=include_spans,
                layer=layer,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
            )
        except Exception as exc:
            # query_embedding is still None only when embedding failed, so
            # there is nothing to retry the retrieval with.
            if query_embedding is None:
                self._log_failure(
                    label="layered retrieval",
                    rag_session_id=rag_session_id,
                    layer=layer,
                    exclude_tests=effective_exclude_tests,
                    path_prefixes=path_prefixes,
                    exc=exc,
                )
                return LayerRetrievalResult(items=[], missing=[self._failure_missing(f"layer:{layer} retrieval_failed", exc)])
            retry_result = self._retry_without_test_filter(
                operation=lambda: self._repository.retrieve(
                    rag_session_id,
                    query_embedding,
                    query_text=query,
                    limit=limit,
                    layers=[layer],
                    path_prefixes=path_prefixes,
                    exclude_path_prefixes=None,
                    exclude_like_patterns=None,
                    prefer_non_tests=True,
                ),
                label="layered retrieval",
                rag_session_id=rag_session_id,
                include_spans=include_spans,
                layer=layer,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
                exc=exc,
                missing_prefix=f"layer:{layer} retrieval_failed",
            )
            if retry_result is not None:
                return retry_result
            return LayerRetrievalResult(items=[], missing=[self._failure_missing(f"layer:{layer} retrieval_failed", exc)])

    def retrieve_lexical_code(
        self,
        rag_session_id: str,
        query: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        include_spans: bool = False,
    ) -> LayerRetrievalResult:
        """Lexical (non-embedding) retrieval of code chunks for *query*.

        Same retry-without-test-filters behavior as :meth:`retrieve_layer`.
        """
        effective_exclude_tests = exclude_tests and not debug_disable_test_filter()
        filter_args = self._filter_args(effective_exclude_tests)
        try:
            rows = self._repository.retrieve_lexical_code(
                rag_session_id,
                query_text=query,
                limit=limit,
                path_prefixes=path_prefixes,
                exclude_path_prefixes=filter_args["exclude_path_prefixes"],
                exclude_like_patterns=filter_args["exclude_like_patterns"],
                prefer_non_tests=not effective_exclude_tests,
            )
            return self._success_result(
                rows,
                rag_session_id=rag_session_id,
                label="lexical retrieval",
                include_spans=include_spans,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
            )
        except Exception as exc:
            retry_result = self._retry_without_test_filter(
                operation=lambda: self._repository.retrieve_lexical_code(
                    rag_session_id,
                    query_text=query,
                    limit=limit,
                    path_prefixes=path_prefixes,
                    exclude_path_prefixes=None,
                    exclude_like_patterns=None,
                    prefer_non_tests=True,
                ),
                label="lexical retrieval",
                rag_session_id=rag_session_id,
                include_spans=include_spans,
                exclude_tests=effective_exclude_tests,
                path_prefixes=path_prefixes,
                exc=exc,
                missing_prefix="layer:C0 lexical_retrieval_failed",
            )
            if retry_result is not None:
                return retry_result
            return LayerRetrievalResult(items=[], missing=[self._failure_missing("layer:C0 lexical_retrieval_failed", exc)])

    def _retry_without_test_filter(
        self,
        *,
        operation: Callable[[], list[dict]],
        label: str,
        rag_session_id: str,
        include_spans: bool,
        exclude_tests: bool,
        path_prefixes: list[str] | None,
        exc: Exception,
        missing_prefix: str,
        layer: str | None = None,
    ) -> LayerRetrievalResult | None:
        """Re-run *operation* with test filters dropped; None when it fails.

        Only retries when test exclusion was actually active; the result, if
        any, carries a ``...retried_without_test_filter`` missing marker.
        """
        if not exclude_tests:
            # Filters were not the culprit — log and let the caller fall back.
            self._log_failure(
                label=label,
                rag_session_id=rag_session_id,
                layer=layer,
                exclude_tests=exclude_tests,
                path_prefixes=path_prefixes,
                exc=exc,
            )
            return None
        self._log_failure(
            label=label,
            rag_session_id=rag_session_id,
            layer=layer,
            exclude_tests=exclude_tests,
            path_prefixes=path_prefixes,
            exc=exc,
            retried_without_test_filter=True,
        )
        try:
            rows = operation()
        except Exception as retry_exc:
            self._log_failure(
                label=f"{label} retry",
                rag_session_id=rag_session_id,
                layer=layer,
                exclude_tests=False,
                path_prefixes=path_prefixes,
                exc=retry_exc,
            )
            return None
        result = self._success_result(
            rows,
            rag_session_id=rag_session_id,
            label=f"{label} retry",
            include_spans=include_spans,
            layer=layer,
            exclude_tests=False,
            path_prefixes=path_prefixes,
        )
        result.missing.append(f"{missing_prefix}:retried_without_test_filter")
        return result

    def _success_result(
        self,
        rows: list[dict],
        *,
        rag_session_id: str,
        label: str,
        include_spans: bool,
        exclude_tests: bool,
        path_prefixes: list[str] | None,
        layer: str | None = None,
    ) -> LayerRetrievalResult:
        """Convert raw rows to items and log the retrieval summary.

        NOTE(review): success is logged at WARNING level — presumably for
        temporary diagnostics; consider INFO/DEBUG once stabilised.
        """
        items = [self._to_item(row, include_spans=include_spans) for row in rows]
        LOGGER.warning(
            "%s: rag_session_id=%s layer=%s exclude_tests=%s path_prefixes=%s returned_count=%s top_paths=%s",
            label,
            rag_session_id,
            layer,
            exclude_tests,
            path_prefixes or [],
            len(items),
            [item.source for item in items[:3]],
        )
        return LayerRetrievalResult(items=items)

    def _log_failure(
        self,
        *,
        label: str,
        rag_session_id: str,
        exclude_tests: bool,
        path_prefixes: list[str] | None,
        exc: Exception,
        layer: str | None = None,
        retried_without_test_filter: bool = False,
    ) -> None:
        """Log a retrieval failure with full context and traceback."""
        LOGGER.warning(
            "%s failed: rag_session_id=%s layer=%s exclude_tests=%s path_prefixes=%s retried_without_test_filter=%s error=%s",
            label,
            rag_session_id,
            layer,
            exclude_tests,
            path_prefixes or [],
            retried_without_test_filter,
            self._exception_summary(exc),
            exc_info=True,
        )

    def _filter_args(self, exclude_tests: bool) -> dict[str, list[str] | None]:
        """Build repository filter kwargs; both None when filtering is off."""
        test_filters = build_test_filters() if exclude_tests else None
        return {
            "exclude_path_prefixes": test_filters.exclude_path_prefixes if test_filters else None,
            "exclude_like_patterns": test_filters.exclude_like_patterns if test_filters else None,
        }

    def _failure_missing(self, prefix: str, exc: Exception) -> str:
        """Format a ``missing`` marker for a failed retrieval."""
        return f"{prefix}:{self._exception_summary(exc)}"

    def _exception_summary(self, exc: Exception) -> str:
        """One-line, length-capped 'Type:message' description of *exc*."""
        message = " ".join(str(exc).split())
        if len(message) > 180:
            message = message[:177] + "..."
        return f"{type(exc).__name__}:{message or 'no_message'}"

    def _to_item(self, row: dict, *, include_spans: bool) -> LayeredRetrievalItem:
        """Convert a raw repository row into a LayeredRetrievalItem.

        The location is only attached when the caller asked for spans.
        """
        location = None
        if include_spans:
            location = CodeLocation(
                path=str(row.get("path") or ""),
                start_line=row.get("span_start"),
                end_line=row.get("span_end"),
            )
        return LayeredRetrievalItem(
            source=str(row.get("path") or ""),
            content=str(row.get("content") or ""),
            layer=str(row.get("layer") or ""),
            title=str(row.get("title") or ""),
            metadata=dict(row.get("metadata", {}) or {}),
            score=row.get("distance"),
            location=location,
        )

View File

@@ -0,0 +1,91 @@
from __future__ import annotations
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
class ExplainHints(BaseModel):
    """Concrete artifacts extracted from the user query to steer retrieval."""

    model_config = ConfigDict(extra="forbid")
    # File paths mentioned in the query (e.g. "app/api/users.py").
    paths: list[str] = Field(default_factory=list)
    # Symbol-like tokens (class/function names, dotted attributes).
    symbols: list[str] = Field(default_factory=list)
    # URL route fragments (e.g. "/users/{id}").
    endpoints: list[str] = Field(default_factory=list)
    # CLI command names referenced in the query.
    commands: list[str] = Field(default_factory=list)
class ExplainIntent(BaseModel):
    """Structured interpretation of a user's "explain this code" question."""

    model_config = ConfigDict(extra="forbid")
    # Query exactly as the user typed it.
    raw_query: str
    # Whitespace-normalised form of the query.
    normalized_query: str
    # Search terms extracted from the query (capped by the builder).
    keywords: list[str] = Field(default_factory=list)
    # Concrete artifacts (paths/symbols/endpoints/commands) found in the query.
    hints: ExplainHints = Field(default_factory=ExplainHints)
    # Whether the user is asking about test coverage.
    include_tests: bool = False
    # Entrypoint kinds to prioritise during retrieval.
    expected_entry_types: list[Literal["http", "cli"]] = Field(default_factory=list)
    # Requested level of detail for the explanation.
    depth: Literal["high", "medium", "deep"] = "medium"
class CodeLocation(BaseModel):
    """A file path with an optional inclusive line span."""

    model_config = ConfigDict(extra="forbid")
    path: str
    start_line: int | None = None
    end_line: int | None = None
class LayeredRetrievalItem(BaseModel):
    """One retrieved chunk from any layer of the RAG index."""

    model_config = ConfigDict(extra="forbid")
    # Origin path of the chunk.
    source: str
    # Chunk text (source code, symbol record, edge record, ...).
    content: str
    # Index layer the chunk came from (C0/C1/C2/C3 identifiers).
    layer: str
    title: str
    # Layer-specific metadata parsed from metadata_json.
    metadata: dict[str, Any] = Field(default_factory=dict)
    # Retrieval distance/score when available.
    score: float | None = None
    # Line-span location when the retrieval included spans.
    location: CodeLocation | None = None
class TracePath(BaseModel):
    """An ordered chain of symbol ids traced through the dependency graph."""

    model_config = ConfigDict(extra="forbid")
    symbol_ids: list[str] = Field(default_factory=list)
    # Relevance score of the whole path.
    score: float = 0.0
    # Entrypoint the path starts from, when known.
    entrypoint_id: str | None = None
    notes: list[str] = Field(default_factory=list)
class EvidenceItem(BaseModel):
    """A citable piece of evidence referenced from the explain pack."""

    model_config = ConfigDict(extra="forbid")
    evidence_id: str
    kind: Literal["entrypoint", "symbol", "edge", "excerpt"]
    summary: str
    location: CodeLocation | None = None
    # Symbol ids this evidence supports (may contain empty strings).
    supports: list[str] = Field(default_factory=list)
class CodeExcerpt(BaseModel):
    """A source-code snippet selected to ground the explanation."""

    model_config = ConfigDict(extra="forbid")
    # Evidence record this excerpt is attached to.
    evidence_id: str
    symbol_id: str | None = None
    title: str
    path: str
    start_line: int | None = None
    end_line: int | None = None
    content: str
    # Label describing why the excerpt was taken ("overview", "focus",
    # "lexical", or a "test:"-prefixed variant).
    focus: str = "overview"
class ExplainPack(BaseModel):
    """Everything gathered for one explain request, ready for prompt building."""

    model_config = ConfigDict(extra="forbid")
    # The parsed user question.
    intent: ExplainIntent
    # Entrypoints (HTTP routes / CLI commands) judged relevant.
    selected_entrypoints: list[LayeredRetrievalItem] = Field(default_factory=list)
    # Symbols used as starting points for graph traversal.
    seed_symbols: list[LayeredRetrievalItem] = Field(default_factory=list)
    # Dependency-graph paths traced from the seeds.
    trace_paths: list[TracePath] = Field(default_factory=list)
    # evidence_id -> EvidenceItem for everything cited above.
    evidence_index: dict[str, EvidenceItem] = Field(default_factory=dict)
    # Source excerpts selected to ground the answer.
    code_excerpts: list[CodeExcerpt] = Field(default_factory=list)
    # Markers describing retrieval gaps/failures.
    missing: list[str] = Field(default_factory=list)
    # Contradictory-evidence notes (currently never populated by the retriever).
    conflicts: list[str] = Field(default_factory=list)

View File

@@ -0,0 +1,328 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from app.modules.rag.contracts.enums import RagLayer
from app.modules.rag.explain.intent_builder import ExplainIntentBuilder
from app.modules.rag.explain.layered_gateway import LayerRetrievalResult, LayeredRetrievalGateway
from app.modules.rag.explain.models import CodeExcerpt, EvidenceItem, ExplainPack, LayeredRetrievalItem
from app.modules.rag.explain.source_excerpt_fetcher import SourceExcerptFetcher
from app.modules.rag.explain.trace_builder import TraceBuilder
from app.modules.rag.retrieval.test_filter import exclude_tests_default, is_test_path
LOGGER = logging.getLogger(__name__)
_MIN_EXCERPTS = 2
if TYPE_CHECKING:
from app.modules.rag.explain.graph_repository import CodeGraphRepository
from app.modules.rag.explain.models import ExplainIntent
class CodeExplainRetrieverV2:
    """Orchestrates layered retrieval into an ExplainPack for code questions.

    Pipeline: parse intent -> pick entrypoints (C3) -> pick seed symbols (C1)
    -> trace graph paths -> fetch source excerpts, with a lexical fallback and
    an optional test-code fallback when too few excerpts are found.
    """

    def __init__(
        self,
        gateway: LayeredRetrievalGateway,
        graph_repository: CodeGraphRepository,
        intent_builder: ExplainIntentBuilder | None = None,
        trace_builder: TraceBuilder | None = None,
        excerpt_fetcher: SourceExcerptFetcher | None = None,
    ) -> None:
        self._gateway = gateway
        self._graph = graph_repository
        self._intent_builder = intent_builder or ExplainIntentBuilder()
        self._trace_builder = trace_builder or TraceBuilder(graph_repository)
        self._excerpt_fetcher = excerpt_fetcher or SourceExcerptFetcher(graph_repository)

    def build_pack(
        self,
        rag_session_id: str,
        user_query: str,
        *,
        file_candidates: list[dict] | None = None,
    ) -> ExplainPack:
        """Build the full evidence pack for *user_query*.

        When test exclusion leaves fewer than _MIN_EXCERPTS excerpts, a
        second lexical pass including test files tops the pack up.
        """
        intent = self._intent_builder.build(user_query)
        path_prefixes = _path_prefixes(intent, file_candidates or [])
        exclude_tests = exclude_tests_default() and not intent.include_tests
        pack = self._run_pass(rag_session_id, intent, path_prefixes, exclude_tests=exclude_tests)
        if exclude_tests and len(pack.code_excerpts) < _MIN_EXCERPTS:
            self._merge_test_fallback(pack, rag_session_id, intent, path_prefixes)
        self._log_pack(rag_session_id, pack)
        return pack

    def _run_pass(
        self,
        rag_session_id: str,
        intent: ExplainIntent,
        path_prefixes: list[str],
        *,
        exclude_tests: bool,
    ) -> ExplainPack:
        """Execute one full retrieval pass and assemble the pack.

        Each empty stage adds a marker to ``missing``; excerpt retrieval
        falls back to lexical search when trace-based fetching yields nothing.
        """
        missing: list[str] = []
        entrypoints_result = self._entrypoints(rag_session_id, intent, path_prefixes, exclude_tests=exclude_tests)
        missing.extend(entrypoints_result.missing)
        selected_entrypoints = self._filter_entrypoints(intent, entrypoints_result.items)
        if not selected_entrypoints:
            missing.append("layer:C3 empty")
        seed_result = self._seed_symbols(rag_session_id, intent, path_prefixes, selected_entrypoints, exclude_tests=exclude_tests)
        missing.extend(seed_result.missing)
        seed_symbols = seed_result.items
        if not seed_symbols:
            missing.append("layer:C1 empty")
        # Traversal depth scales with the requested explanation depth.
        depth = 4 if intent.depth == "deep" else 3 if intent.depth == "medium" else 2
        trace_paths = self._trace_builder.build_paths(rag_session_id, seed_symbols, max_depth=depth) if seed_symbols else []
        excerpts, excerpt_evidence = self._excerpt_fetcher.fetch(rag_session_id, trace_paths) if trace_paths else ([], {})
        if not excerpts:
            # Lexical fallback when graph tracing produced no excerpts.
            lexical_result = self._gateway.retrieve_lexical_code(
                rag_session_id,
                intent.normalized_query,
                limit=6,
                path_prefixes=path_prefixes or None,
                exclude_tests=exclude_tests,
                include_spans=True,
            )
            missing.extend(lexical_result.missing)
            excerpts, excerpt_evidence = _lexical_excerpts(lexical_result.items)
            if not excerpts:
                missing.append("layer:C0 empty")
        evidence_index = _evidence_index(selected_entrypoints, seed_symbols)
        evidence_index.update(excerpt_evidence)
        missing.extend(_missing(selected_entrypoints, seed_symbols, trace_paths, excerpts))
        return ExplainPack(
            intent=intent,
            selected_entrypoints=selected_entrypoints,
            seed_symbols=seed_symbols,
            trace_paths=trace_paths,
            evidence_index=evidence_index,
            code_excerpts=excerpts,
            missing=_cleanup_missing(_dedupe(missing), has_excerpts=bool(excerpts)),
            conflicts=[],
        )

    def _merge_test_fallback(
        self,
        pack: ExplainPack,
        rag_session_id: str,
        intent: ExplainIntent,
        path_prefixes: list[str],
    ) -> None:
        """Top the pack up with test-inclusive lexical excerpts, in place.

        Duplicate excerpts (same path/span/content) are skipped; the pack's
        missing list is merged with the fallback's diagnostics.
        """
        lexical_result = self._gateway.retrieve_lexical_code(
            rag_session_id,
            intent.normalized_query,
            limit=6,
            path_prefixes=path_prefixes or None,
            exclude_tests=False,
            include_spans=True,
        )
        # Continue evidence numbering after the already-indexed excerpts.
        excerpt_offset = len([key for key in pack.evidence_index if key.startswith("excerpt_")])
        excerpts, evidence = _lexical_excerpts(
            lexical_result.items,
            start_index=excerpt_offset,
            is_test_fallback=True,
        )
        if not excerpts:
            pack.missing = _dedupe(pack.missing + lexical_result.missing)
            return
        seen = {(item.path, item.start_line, item.end_line, item.content) for item in pack.code_excerpts}
        for excerpt in excerpts:
            key = (excerpt.path, excerpt.start_line, excerpt.end_line, excerpt.content)
            if key in seen:
                continue
            pack.code_excerpts.append(excerpt)
            seen.add(key)
        pack.evidence_index.update(evidence)
        pack.missing = _cleanup_missing(_dedupe(pack.missing + lexical_result.missing), has_excerpts=bool(pack.code_excerpts))

    def _entrypoints(
        self,
        rag_session_id: str,
        intent: ExplainIntent,
        path_prefixes: list[str],
        *,
        exclude_tests: bool,
    ) -> LayerRetrievalResult:
        """Retrieve candidate entrypoints from the C3 layer."""
        return self._gateway.retrieve_layer(
            rag_session_id,
            intent.normalized_query,
            RagLayer.CODE_ENTRYPOINTS,
            limit=6,
            path_prefixes=path_prefixes or None,
            exclude_tests=exclude_tests,
            prefer_non_tests=True,
            include_spans=True,
        )

    def _filter_entrypoints(self, intent: ExplainIntent, items: list[LayeredRetrievalItem]) -> list[LayeredRetrievalItem]:
        """Keep up to 3 entrypoints, preferring the intent's expected types.

        Falls back to the unfiltered top 3 when no item matches the types.
        """
        if not intent.expected_entry_types:
            return items[:3]
        filtered = [item for item in items if str(item.metadata.get("entry_type") or "") in intent.expected_entry_types]
        return filtered[:3] or items[:3]

    def _seed_symbols(
        self,
        rag_session_id: str,
        intent: ExplainIntent,
        path_prefixes: list[str],
        entrypoints: list[LayeredRetrievalItem],
        *,
        exclude_tests: bool,
    ) -> LayerRetrievalResult:
        """Collect up to 8 unique seed symbols for graph traversal.

        Entrypoint handler symbols come first, then C1 catalog hits for the
        query; items without a symbol_id are dropped.
        """
        symbol_result = self._gateway.retrieve_layer(
            rag_session_id,
            intent.normalized_query,
            RagLayer.CODE_SYMBOL_CATALOG,
            limit=12,
            path_prefixes=path_prefixes or None,
            exclude_tests=exclude_tests,
            prefer_non_tests=True,
            include_spans=True,
        )
        handlers: list[LayeredRetrievalItem] = []
        handler_ids = [str(item.metadata.get("handler_symbol_id") or "") for item in entrypoints]
        if handler_ids:
            handlers = self._graph.get_symbols_by_ids(rag_session_id, [item for item in handler_ids if item])
        seeds: list[LayeredRetrievalItem] = []
        seen: set[str] = set()
        for item in handlers + symbol_result.items:
            symbol_id = str(item.metadata.get("symbol_id") or "")
            if not symbol_id or symbol_id in seen:
                continue
            seen.add(symbol_id)
            seeds.append(item)
            if len(seeds) >= 8:
                break
        return LayerRetrievalResult(items=seeds, missing=list(symbol_result.missing))

    def _log_pack(self, rag_session_id: str, pack: ExplainPack) -> None:
        """Log a one-line summary of the built pack.

        NOTE(review): logged at WARNING level — presumably for temporary
        visibility; consider INFO once stabilised.
        """
        prod_excerpt_count = len([excerpt for excerpt in pack.code_excerpts if not _is_test_excerpt(excerpt)])
        test_excerpt_count = len(pack.code_excerpts) - prod_excerpt_count
        LOGGER.warning(
            "code explain pack: rag_session_id=%s entrypoints=%s seeds=%s paths=%s excerpts=%s prod_excerpt_count=%s test_excerpt_count=%s missing=%s",
            rag_session_id,
            len(pack.selected_entrypoints),
            len(pack.seed_symbols),
            len(pack.trace_paths),
            len(pack.code_excerpts),
            prod_excerpt_count,
            test_excerpt_count,
            pack.missing,
        )
def _evidence_index(
    entrypoints: list[LayeredRetrievalItem],
    seed_symbols: list[LayeredRetrievalItem],
) -> dict[str, EvidenceItem]:
    """Build evidence records (entrypoint_N / symbol_N) for the pack index."""
    index: dict[str, EvidenceItem] = {}

    def _register(prefix: str, kind: str, items: list[LayeredRetrievalItem], support_key: str) -> None:
        # One record per item, numbered from 1 within its prefix.
        for position, item in enumerate(items, start=1):
            evidence_id = f"{prefix}_{position}"
            index[evidence_id] = EvidenceItem(
                evidence_id=evidence_id,
                kind=kind,
                summary=item.title,
                location=item.location,
                supports=[str(item.metadata.get(support_key) or "")],
            )

    _register("entrypoint", "entrypoint", entrypoints, "handler_symbol_id")
    _register("symbol", "symbol", seed_symbols, "symbol_id")
    return index
def _missing(
    entrypoints: list[LayeredRetrievalItem],
    seed_symbols: list[LayeredRetrievalItem],
    trace_paths,
    excerpts,
) -> list[str]:
    """Name each empty pack section so the caller can report retrieval gaps."""
    sections = (
        ("entrypoints", entrypoints),
        ("seed_symbols", seed_symbols),
        ("trace_paths", trace_paths),
        ("code_excerpts", excerpts),
    )
    return [label for label, content in sections if not content]
def _lexical_excerpts(
    items: list[LayeredRetrievalItem],
    *,
    start_index: int = 0,
    is_test_fallback: bool = False,
) -> tuple[list[CodeExcerpt], dict[str, EvidenceItem]]:
    """Convert lexical retrieval hits into code excerpts plus evidence records.

    Args:
        items: Lexical retrieval results to convert.
        start_index: Offset for generated ``excerpt_N`` evidence ids so they
            do not collide with ids minted earlier in the pipeline.
        is_test_fallback: Kept for interface compatibility. The branch that
            read it re-assigned ``focus`` to the value it already had, so it
            never affected output.

    Returns:
        The excerpts and a mapping of evidence id to its evidence item.
    """
    excerpts: list[CodeExcerpt] = []
    evidence_index: dict[str, EvidenceItem] = {}
    for item in items:
        evidence_id = f"excerpt_{start_index + len(evidence_index) + 1}"
        location = item.location
        evidence_index[evidence_id] = EvidenceItem(
            evidence_id=evidence_id,
            kind="excerpt",
            summary=item.title or item.source,
            location=location,
            supports=[],
        )
        # NOTE(review): removed a no-op `elif is_test_fallback: focus = "lexical"`
        # branch (focus was already "lexical"). Confirm whether "test:lexical"
        # was intended for fallback items before restoring any special-casing.
        focus = "test:lexical" if _item_is_test(item) else "lexical"
        excerpts.append(
            CodeExcerpt(
                evidence_id=evidence_id,
                symbol_id=str(item.metadata.get("symbol_id") or "") or None,
                title=item.title or item.source,
                path=item.source,
                start_line=location.start_line if location else None,
                end_line=location.end_line if location else None,
                content=item.content,
                focus=focus,
            )
        )
    return excerpts, evidence_index
def _item_is_test(item: LayeredRetrievalItem) -> bool:
return bool(item.metadata.get("is_test")) or is_test_path(item.source)
def _is_test_excerpt(excerpt: CodeExcerpt) -> bool:
return excerpt.focus.startswith("test:") or is_test_path(excerpt.path)
def _path_prefixes(intent: ExplainIntent, file_candidates: list[dict]) -> list[str]:
values: list[str] = []
for path in intent.hints.paths:
prefix = path.rsplit("/", 1)[0] if "/" in path else path
if prefix and prefix not in values:
values.append(prefix)
for item in file_candidates[:6]:
path = str(item.get("path") or "")
prefix = path.rsplit("/", 1)[0] if "/" in path else ""
if prefix and prefix not in values:
values.append(prefix)
return values
def _cleanup_missing(values: list[str], *, has_excerpts: bool) -> list[str]:
if not has_excerpts:
return values
return [value for value in values if value not in {"code_excerpts", "layer:C0 empty"}]
def _dedupe(values: list[str]) -> list[str]:
result: list[str] = []
for value in values:
item = value.strip()
if item and item not in result:
result.append(item)
return result

View File

@@ -0,0 +1,53 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from app.modules.rag.explain.excerpt_planner import ExcerptPlanner
from app.modules.rag.explain.models import CodeExcerpt, EvidenceItem, TracePath
from app.modules.rag.retrieval.test_filter import is_test_path
if TYPE_CHECKING:
from app.modules.rag.explain.graph_repository import CodeGraphRepository
class SourceExcerptFetcher:
    """Fetches and plans source-code excerpts for symbols appearing on trace paths."""

    def __init__(self, graph_repository: CodeGraphRepository, planner: ExcerptPlanner | None = None) -> None:
        self._graph = graph_repository
        self._planner = planner or ExcerptPlanner()

    def fetch(
        self,
        rag_session_id: str,
        trace_paths: list[TracePath],
        *,
        max_excerpts: int = 40,
    ) -> tuple[list[CodeExcerpt], dict[str, EvidenceItem]]:
        """Resolve trace-path symbols to chunks and plan excerpts, capped at *max_excerpts*.

        Returns the planned excerpts plus an evidence record per fetched chunk.
        Excerpts from chunks flagged or located as test code get a ``test:``
        focus prefix.
        """
        chunks = self._graph.get_chunks_by_symbol_ids(
            rag_session_id, self._ordered_symbol_ids(trace_paths)
        )
        excerpts: list[CodeExcerpt] = []
        evidence_index: dict[str, EvidenceItem] = {}
        for chunk in chunks:
            chunk_symbol_id = str(chunk.metadata.get("symbol_id") or "")
            evidence_id = f"excerpt_{len(evidence_index) + 1}"
            chunk_location = chunk.location
            evidence_index[evidence_id] = EvidenceItem(
                evidence_id=evidence_id,
                kind="excerpt",
                summary=chunk.title,
                location=chunk_location,
                supports=[chunk_symbol_id] if chunk_symbol_id else [],
            )
            chunk_path = chunk_location.path if chunk_location else chunk.source
            from_test = bool(chunk.metadata.get("is_test")) or is_test_path(chunk_path)
            for planned in self._planner.plan(chunk, evidence_id=evidence_id, symbol_id=chunk_symbol_id):
                if len(excerpts) >= max_excerpts:
                    break
                if from_test and not planned.focus.startswith("test:"):
                    planned.focus = f"test:{planned.focus}"
                excerpts.append(planned)
            # Cap check stays at the bottom so the current chunk's evidence
            # record is always created before the loop stops.
            if len(excerpts) >= max_excerpts:
                break
        return excerpts, evidence_index

    def _ordered_symbol_ids(self, trace_paths: list[TracePath]) -> list[str]:
        """Unique, non-empty symbol ids in first-seen order across all paths."""
        ordered: list[str] = []
        seen: set[str] = set()
        for trace in trace_paths:
            for symbol_id in trace.symbol_ids:
                if symbol_id and symbol_id not in seen:
                    seen.add(symbol_id)
                    ordered.append(symbol_id)
        return ordered

View File

@@ -0,0 +1,102 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from app.modules.rag.explain.models import LayeredRetrievalItem, TracePath
if TYPE_CHECKING:
from app.modules.rag.explain.graph_repository import CodeGraphRepository
class TraceBuilder:
    """Builds scored trace paths through the code graph from seed symbols.

    Performs a breadth-first expansion over outgoing edges (calls, imports,
    inherits by default), scoring each path and keeping the highest-scoring
    unique paths.
    """

    def __init__(self, graph_repository: CodeGraphRepository) -> None:
        self._graph = graph_repository

    def build_paths(
        self,
        rag_session_id: str,
        seed_symbols: list[LayeredRetrievalItem],
        *,
        max_depth: int,
        max_paths: int = 3,
        edge_types: list[str] | None = None,
    ) -> list[TracePath]:
        """Expand each seed symbol into trace paths through the code graph.

        Args:
            rag_session_id: Session scope used for all graph lookups.
            seed_symbols: Retrieval items whose metadata carries a ``symbol_id``;
                items without one are skipped.
            max_depth: Maximum number of symbols per path before it is finalized.
            max_paths: Number of top-scoring unique paths to return.
            edge_types: Edge kinds to follow; defaults to calls/imports/inherits.

        Returns:
            Up to ``max_paths`` unique paths sorted by descending score; when
            nothing could be built, a single zero-score path holding just the
            first seed's symbol id (or nothing when there are no seeds).
        """
        edges_filter = edge_types or ["calls", "imports", "inherits"]
        symbol_map = self._symbol_map(seed_symbols)
        paths: list[TracePath] = []
        for seed in seed_symbols:
            seed_id = str(seed.metadata.get("symbol_id") or "")
            if not seed_id:
                continue
            # BFS queue entries: (path of symbol ids, accumulated score, notes).
            queue: list[tuple[list[str], float, list[str]]] = [([seed_id], 0.0, [])]
            # Collect up to 3x the requested number of paths before dedupe/ranking.
            while queue and len(paths) < max_paths * 3:
                current_path, score, notes = queue.pop(0)
                src_symbol_id = current_path[-1]
                out_edges = self._graph.get_out_edges(rag_session_id, [src_symbol_id], edges_filter, limit_per_src=4)
                # Dead end or depth limit reached: finalize the path as-is.
                if not out_edges or len(current_path) >= max_depth:
                    paths.append(TracePath(symbol_ids=current_path, score=score, notes=notes))
                    continue
                for edge in out_edges:
                    metadata = edge.metadata
                    dst_symbol_id = str(metadata.get("dst_symbol_id") or "")
                    next_notes = list(notes)
                    next_score = score + self._edge_score(edge, symbol_map.get(src_symbol_id))
                    if not dst_symbol_id:
                        # Edge target unresolved: try to resolve it by textual
                        # reference, hinting with the source symbol's package.
                        dst_ref = str(metadata.get("dst_ref") or "")
                        package_hint = self._package_hint(symbol_map.get(src_symbol_id))
                        resolved = self._graph.resolve_symbol_by_ref(rag_session_id, dst_ref, package_hint=package_hint)
                        if resolved is not None:
                            dst_symbol_id = str(resolved.metadata.get("symbol_id") or "")
                            symbol_map[dst_symbol_id] = resolved
                            # Successful resolution is rewarded and recorded.
                            next_score += 2.0
                            next_notes.append(f"resolved:{dst_ref}")
                    # Unresolvable target or a cycle: finalize instead of expanding.
                    if not dst_symbol_id or dst_symbol_id in current_path:
                        paths.append(TracePath(symbol_ids=current_path, score=next_score, notes=next_notes))
                        continue
                    # Cache the destination symbol so later edge scoring and
                    # package hints can use it.
                    if dst_symbol_id not in symbol_map:
                        symbols = self._graph.get_symbols_by_ids(rag_session_id, [dst_symbol_id])
                        if symbols:
                            symbol_map[dst_symbol_id] = symbols[0]
                    queue.append((current_path + [dst_symbol_id], next_score, next_notes))
        unique = self._unique_paths(paths)
        unique.sort(key=lambda item: item.score, reverse=True)
        # Fallback: emit one zero-score path from the first seed when expansion
        # produced nothing (empty list if there were no seeds at all).
        return unique[:max_paths] or [TracePath(symbol_ids=[seed.metadata.get("symbol_id", "")], score=0.0) for seed in seed_symbols[:1]]

    def _edge_score(self, edge: LayeredRetrievalItem, source_symbol: LayeredRetrievalItem | None) -> float:
        """Score one edge: resolved edges and same-file hops rank higher, test code lower."""
        metadata = edge.metadata
        score = 1.0
        if str(metadata.get("resolution") or "") == "resolved":
            score += 2.0
        source_path = source_symbol.source if source_symbol is not None else ""
        if source_path and edge.source == source_path:
            score += 1.0
        if "tests/" in edge.source or "/tests/" in edge.source:
            score -= 3.0
        return score

    def _package_hint(self, symbol: LayeredRetrievalItem | None) -> str | None:
        """Return the parent package of the symbol's module, or the module itself when top-level."""
        if symbol is None:
            return None
        package = str(symbol.metadata.get("package_or_module") or "")
        if not package:
            return None
        return ".".join(package.split(".")[:-1]) or package

    def _symbol_map(self, items: list[LayeredRetrievalItem]) -> dict[str, LayeredRetrievalItem]:
        """Index retrieval items by their ``symbol_id`` metadata, skipping blanks."""
        result: dict[str, LayeredRetrievalItem] = {}
        for item in items:
            symbol_id = str(item.metadata.get("symbol_id") or "")
            if symbol_id:
                result[symbol_id] = item
        return result

    def _unique_paths(self, items: list[TracePath]) -> list[TracePath]:
        """De-duplicate paths by their non-empty symbol-id sequence, preserving order."""
        result: list[TracePath] = []
        seen: set[tuple[str, ...]] = set()
        for item in items:
            key = tuple(symbol_id for symbol_id in item.symbol_ids if symbol_id)
            if not key or key in seen:
                continue
            seen.add(key)
            result.append(item)
        return result