Фикс состояния
This commit is contained in:
3
tests/pipeline_setup_v3/core/__init__.py
Normal file
3
tests/pipeline_setup_v3/core/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from tests.pipeline_setup_v3.core.runner import V3Runner
|
||||
|
||||
__all__ = ["V3Runner"]
|
||||
134
tests/pipeline_setup_v3/core/artifacts.py
Normal file
134
tests/pipeline_setup_v3/core/artifacts.py
Normal file
@@ -0,0 +1,134 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from tests.pipeline_setup_v3.core.models import V3CaseResult
|
||||
|
||||
|
||||
class ArtifactWriter:
    """Persists per-case JSON/Markdown artifacts plus the run summary."""

    def __init__(self, root: Path, run_name: str, started_at: datetime) -> None:
        """Create (if needed) a timestamped run directory under ``root/run_name``."""
        timestamp = started_at.strftime("%Y%m%d_%H%M%S")
        self.run_dir = root / run_name / timestamp
        self.run_dir.mkdir(parents=True, exist_ok=True)

    def write_case(self, result: V3CaseResult) -> None:
        """Write ``<source-stem>_<case-id>.json`` and ``.md`` for one case result."""
        basename = f"{result.case.source_file.stem}_{result.case.case_id}"
        json_text = json.dumps(self._case_payload(result), ensure_ascii=False, indent=2)
        (self.run_dir / f"{basename}.json").write_text(json_text, encoding="utf-8")
        md_text = "\n".join(self._case_markdown(result))
        (self.run_dir / f"{basename}.md").write_text(md_text, encoding="utf-8")

    def _case_payload(self, result: V3CaseResult) -> dict:
        """Flatten a case result into a JSON-serializable mapping."""
        return {
            "case_id": result.case.case_id,
            "source_file": result.case.source_file.as_posix(),
            "runner": result.case.runner,
            "mode": result.case.mode,
            "query": result.case.query,
            "actual": result.actual,
            "passed": result.passed,
            "mismatches": result.mismatches,
            "details": result.details,
        }

    def _case_markdown(self, result: V3CaseResult) -> list[str]:
        """Build the Markdown report lines for one case."""

        def dump(value: object) -> str:
            return json.dumps(value, ensure_ascii=False, indent=2)

        mismatch_lines = [f"- {item}" for item in result.mismatches] or ["- none"]
        return [
            f"# {result.case.case_id}",
            "",
            f"- source_file: {result.case.source_file.as_posix()}",
            f"- runner: {result.case.runner}",
            f"- mode: {result.case.mode}",
            f"- passed: {result.passed}",
            "",
            "## Query",
            result.case.query,
            "",
            "## Actual",
            dump(result.actual),
            "",
            "## Steps",
            dump(result.details.get("steps") or []),
            "",
            "## Diagnostics",
            dump(result.details.get("diagnostics") or {}),
            "",
            "## Mismatches",
            *mismatch_lines,
        ]

    def write_summary(self, results: list[V3CaseResult]) -> Path:
        """Render the run summary Markdown and return its path."""
        summary_path = self.run_dir / "summary.md"
        summary_path.write_text(SummaryComposer().compose(results), encoding="utf-8")
        return summary_path
|
||||
|
||||
|
||||
class SummaryComposer:
    """Renders the whole-run Markdown summary: results table, failures, LLM answers."""

    def compose(self, results: list[V3CaseResult]) -> str:
        """Return the full summary document as one Markdown string."""
        ok_count = sum(1 for item in results if item.passed)
        doc = [
            "# pipeline_setup_v3 summary",
            "",
            f"Passed: {ok_count}/{len(results)}",
            "",
            "| File | Case | Mode | Query | Actual sub-intent | RAG layers | Pass |",
            "|------|------|------|-------|-------------------|------------|------|",
        ]
        doc += self._rows(results)
        failed = [item for item in results if not item.passed]
        if failed:
            doc += ["", "## Failures"]
            doc += [f"- **{item.case.case_id}**: {'; '.join(item.mismatches)}" for item in failed]
        doc += self._llm_section(results)
        return "\n".join(doc)

    def _rows(self, results: list[V3CaseResult]) -> list[str]:
        """One Markdown table row per case result."""
        return [
            "| "
            + " | ".join(
                (
                    item.case.source_file.name,
                    item.case.case_id,
                    item.case.mode,
                    self._cell(item.case.query),
                    item.actual.get("sub_intent") or "—",
                    self._layer_text(item.details),
                    "✓" if item.passed else "✗",
                )
            )
            + " |"
            for item in results
        ]

    def _layer_text(self, details: dict) -> str:
        """Summarize RAG rows as sorted 'layer:count' pairs, or an em dash when absent."""
        tally: dict[str, int] = {}
        for entry in details.get("rag_rows") or []:
            name = str(entry.get("layer") or "").strip()
            if name:
                tally[name] = tally.get(name, 0) + 1
        if not tally:
            return "—"
        joined = ", ".join(f"{layer}:{count}" for layer, count in sorted(tally.items()))
        return self._cell(joined, limit=120)

    def _cell(self, text: str, limit: int = 140) -> str:
        """Collapse whitespace, escape pipes, and truncate to *limit* chars with an ellipsis."""
        flattened = " ".join(str(text).split()).replace("|", "\\|")
        return flattened if len(flattened) <= limit else flattened[: limit - 1].rstrip() + "…"

    def _llm_section(self, results: list[V3CaseResult]) -> list[str]:
        """Optional '## LLM Answers' section; empty when no case produced an answer."""
        answered = [item for item in results if str(item.actual.get("llm_answer") or "").strip()]
        if not answered:
            return []
        section = ["", "## LLM Answers"]
        for item in answered:
            section.append(f"- **{item.case.case_id}**")
            section.append(f" Query: {self._cell(item.case.query, limit=400)}")
            section.extend(self._quote_block(self._snippet(str(item.actual.get("llm_answer") or ""))))
        return section

    def _snippet(self, text: str, limit: int = 880) -> str:
        """Whitespace-collapsed prefix of *text*, at most *limit* chars."""
        flattened = " ".join(text.split())
        return flattened if len(flattened) <= limit else flattened[: limit - 1].rstrip() + "…"

    def _quote_block(self, text: str) -> list[str]:
        """Render *text* as a Markdown blockquote, one '>' line per input line."""
        stripped = text.strip()
        if not stripped:
            return [" > —"]
        return [f" > {self._escape_markdown(part)}" for part in stripped.splitlines()]

    def _escape_markdown(self, text: str) -> str:
        """Backslash-escape Markdown control characters (backslash itself first)."""
        specials = ("\\", "`", "*", "_", "{", "}", "[", "]", "(", ")", "#", "+", "-", "!", "|")
        escaped = text
        for token in specials:
            escaped = escaped.replace(token, f"\\{token}")
        return escaped
|
||||
111
tests/pipeline_setup_v3/core/case_loader.py
Normal file
111
tests/pipeline_setup_v3/core/case_loader.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from tests.pipeline_setup_v3.core.models import (
|
||||
CaseExpectations,
|
||||
CaseInput,
|
||||
LlmExpectation,
|
||||
PipelineExpectation,
|
||||
RetrievalExpectation,
|
||||
RouterExpectation,
|
||||
V3Case,
|
||||
)
|
||||
|
||||
|
||||
class CaseDirectoryLoader:
    """Loads V3 cases from a single YAML file or from every ``*.yaml`` under a directory."""

    def load(self, cases_dir: Path) -> list[V3Case]:
        """Load all cases from *cases_dir* (a file or a directory tree).

        Raises ValueError when a directory contains no YAML files.
        """
        if cases_dir.is_file():
            return self._load_file(cases_dir)
        yaml_files = sorted(candidate for candidate in cases_dir.rglob("*.yaml") if candidate.is_file())
        if not yaml_files:
            raise ValueError(f"No YAML case files found in: {cases_dir}")
        return [case for yaml_file in yaml_files for case in self._load_file(yaml_file)]

    def _load_file(self, path: Path) -> list[V3Case]:
        """Parse one YAML file into cases; file-level ``defaults`` merge into each case."""
        document = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
        if not isinstance(document, dict):
            raise ValueError(f"Invalid case file {path}: expected mapping")
        shared_defaults = document.get("defaults") or {}
        entries = document.get("cases") or []
        if not isinstance(entries, list):
            raise ValueError(f"Invalid case file {path}: `cases` must be a list")
        return [self._to_case(path, entry, shared_defaults) for entry in entries]

    def _to_case(self, path: Path, raw: object, defaults: dict) -> V3Case:
        """Build a ``V3Case`` from one raw mapping, applying file-level *defaults*."""
        if not isinstance(raw, dict):
            raise ValueError(f"Invalid case in {path}: expected object")
        # Case-level input keys override the defaults' input keys.
        input_payload = {**dict(defaults.get("input") or {}), **dict(raw.get("input") or {})}
        runner_name = self._normalize_runner(str(raw.get("runner") or defaults.get("runner") or "agent_runtime").strip())
        mode_name = str(raw.get("mode") or defaults.get("mode") or "").strip()
        identifier = str(raw.get("id") or "").strip()
        question = str(raw.get("query") or "").strip()
        if not identifier or not question or not runner_name or not mode_name:
            raise ValueError(f"Invalid case in {path}: `id`, `query`, `runner`, `mode` are required")
        return V3Case(
            case_id=identifier,
            runner=runner_name,  # type: ignore[arg-type]
            mode=mode_name,  # type: ignore[arg-type]
            query=question,
            source_file=path,
            input=self._to_input(path.parent, input_payload),
            expectations=self._to_expectations(dict(raw.get("expected") or {})),
            notes=str(raw.get("notes") or ""),
            tags=tuple(str(tag).strip() for tag in raw.get("tags", []) if str(tag).strip()),
        )

    def _to_input(self, base_dir: Path, raw: dict) -> CaseInput:
        """Resolve the case input; relative repo paths are anchored at *base_dir*."""
        repo_value = str(raw.get("repo_path") or "").strip()
        resolved: Path | None = None
        if repo_value:
            resolved = Path(repo_value).expanduser()
            if not resolved.is_absolute():
                resolved = (base_dir / resolved).resolve()
        return CaseInput(
            repo_path=resolved,
            project_id=str(raw.get("project_id") or "").strip() or None,
            rag_session_id=str(raw.get("rag_session_id") or "").strip() or None,
        )

    def _to_expectations(self, raw: dict) -> CaseExpectations:
        """Translate the raw ``expected`` mapping into typed expectation objects."""
        router_raw = dict(raw.get("router") or {})
        retrieval_raw = dict(raw.get("retrieval") or {})
        llm_raw = dict(raw.get("llm") or {})
        pipeline_raw = dict(raw.get("pipeline") or {})

        def text_or_none(value: object) -> str | None:
            # Blank / missing values mean "do not check this field".
            return str(value or "").strip() or None

        def str_tuple(value: object) -> tuple[str, ...]:
            return tuple(str(item) for item in (value or []))

        return CaseExpectations(
            router=RouterExpectation(
                intent=text_or_none(router_raw.get("intent")),
                sub_intent=text_or_none(router_raw.get("sub_intent")),
                graph_id=text_or_none(router_raw.get("graph_id")),
                conversation_mode=text_or_none(router_raw.get("conversation_mode")),
            ),
            retrieval=RetrievalExpectation(
                non_empty=retrieval_raw.get("non_empty"),
                min_rows=int(retrieval_raw["min_rows"]) if retrieval_raw.get("min_rows") is not None else None,
                direct_symbol_test_hits_max=(
                    int(retrieval_raw["direct_symbol_test_hits_max"])
                    if retrieval_raw.get("direct_symbol_test_hits_max") is not None
                    else None
                ),
                path_scope_contains=str_tuple(retrieval_raw.get("path_scope_contains")),
                symbol_candidates_contain=str_tuple(retrieval_raw.get("symbol_candidates_contain")),
                layers_include=str_tuple(retrieval_raw.get("layers_include")),
            ),
            llm=LlmExpectation(
                non_empty=llm_raw.get("non_empty"),
                contains_all=str_tuple(llm_raw.get("contains_all")),
                excludes=str_tuple(llm_raw.get("excludes")),
            ),
            pipeline=PipelineExpectation(answer_mode=text_or_none(pipeline_raw.get("answer_mode"))),
        )

    def _normalize_runner(self, value: str) -> str:
        """Collapse legacy runner aliases onto the canonical 'agent_runtime'."""
        aliases = {"agent_runtime", "runtime", "code_qa_eval"}
        return "agent_runtime" if value in aliases else value
|
||||
82
tests/pipeline_setup_v3/core/models.py
Normal file
82
tests/pipeline_setup_v3/core/models.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
|
||||
# Closed vocabularies for the case schema; the loader coerces raw YAML strings into these.
RunnerKind = Literal["agent_runtime"]
ModeKind = Literal["router_only", "router_rag", "full_chain"]


@dataclass(slots=True, frozen=True)
class CaseInput:
    """Where a case's data comes from: a repository to index and/or an existing RAG session."""

    repo_path: Path | None = None  # repository to index; may be relative to the case file's directory
    project_id: str | None = None
    rag_session_id: str | None = None  # pre-built session id; when set, indexing is skipped


@dataclass(slots=True, frozen=True)
class RouterExpectation:
    """Expected router decisions; ``None`` fields are not checked by the validator."""

    intent: str | None = None
    sub_intent: str | None = None
    graph_id: str | None = None
    conversation_mode: str | None = None


@dataclass(slots=True, frozen=True)
class RetrievalExpectation:
    """Expected retrieval properties; ``None``/empty fields are not checked."""

    non_empty: bool | None = None  # True: require rows; False: require none; None: skip
    min_rows: int | None = None
    direct_symbol_test_hits_max: int | None = None  # cap on test rows directly naming an expected symbol
    path_scope_contains: tuple[str, ...] = ()
    symbol_candidates_contain: tuple[str, ...] = ()
    layers_include: tuple[str, ...] = ()


@dataclass(slots=True, frozen=True)
class LlmExpectation:
    """Expected properties of the final LLM answer text (checked case-insensitively)."""

    non_empty: bool | None = None
    contains_all: tuple[str, ...] = ()
    excludes: tuple[str, ...] = ()


@dataclass(slots=True, frozen=True)
class PipelineExpectation:
    """Expected pipeline-level outputs (currently only the answer mode)."""

    answer_mode: str | None = None


@dataclass(slots=True, frozen=True)
class CaseExpectations:
    """All expectation groups for one case.

    The instance defaults below are shared across cases; that is safe because
    every expectation dataclass is frozen (immutable).
    """

    router: RouterExpectation = RouterExpectation()
    retrieval: RetrievalExpectation = RetrievalExpectation()
    llm: LlmExpectation = LlmExpectation()
    pipeline: PipelineExpectation = PipelineExpectation()


@dataclass(slots=True, frozen=True)
class V3Case:
    """One executable test case parsed from YAML."""

    case_id: str
    runner: RunnerKind
    mode: ModeKind
    query: str
    source_file: Path  # YAML file the case came from (used for artifact naming)
    input: CaseInput = CaseInput()
    expectations: CaseExpectations = CaseExpectations()
    notes: str = ""
    tags: tuple[str, ...] = ()


@dataclass(slots=True, frozen=True)
class ExecutionPayload:
    """Raw execution output: ``actual`` (fields the validator checks) plus free-form ``details``."""

    actual: dict
    details: dict


@dataclass(slots=True)
class V3CaseResult:
    """A case together with its execution output and validation verdict."""

    case: V3Case
    actual: dict
    details: dict
    passed: bool  # True iff ``mismatches`` is empty
    mismatches: list[str] = field(default_factory=list)
|
||||
52
tests/pipeline_setup_v3/core/runner.py
Normal file
52
tests/pipeline_setup_v3/core/runner.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from tests.pipeline_setup_v3.core.artifacts import ArtifactWriter
|
||||
from tests.pipeline_setup_v3.core.case_loader import CaseDirectoryLoader
|
||||
from tests.pipeline_setup_v3.core.models import V3CaseResult
|
||||
from tests.pipeline_setup_v3.core.session_provider import RagSessionProvider
|
||||
from tests.pipeline_setup_v3.core.validators import CaseValidator
|
||||
|
||||
|
||||
class V3Runner:
    """Drives a full run: load cases, execute them, validate, and write artifacts."""

    def __init__(self, cases_dir: Path, results_dir: Path, run_name: str) -> None:
        self._cases_dir = cases_dir
        self._validator = CaseValidator()
        self._sessions = RagSessionProvider()
        # The runtime adapter is created lazily so importing the runner stays cheap.
        self._agent_runtime = None
        self._writer = ArtifactWriter(results_dir, run_name=run_name, started_at=datetime.now())

    @property
    def run_dir(self) -> Path:
        """Directory this run's artifacts are written to."""
        return self._writer.run_dir

    def run(self) -> tuple[list[V3CaseResult], Path]:
        """Execute every case and return ``(results, summary_path)``."""
        outcome: list[V3CaseResult] = []
        for case in CaseDirectoryLoader().load(self._cases_dir):
            session_id = self._sessions.resolve(case.input)
            execution = self._execute(case, session_id)
            problems = self._validator.validate(case, execution.actual, execution.details)
            case_result = V3CaseResult(
                case=case,
                actual=execution.actual,
                details=execution.details,
                passed=not problems,
                mismatches=problems,
            )
            self._writer.write_case(case_result)
            outcome.append(case_result)
        return outcome, self._writer.write_summary(outcome)

    def _execute(self, case, rag_session_id):
        """Dispatch the case to the adapter for its runner kind."""
        if case.runner == "agent_runtime":
            return self._agent_runtime_adapter().execute(case, rag_session_id)
        raise ValueError(f"Unsupported runner: {case.runner}")

    def _agent_runtime_adapter(self):
        """Create the AgentRuntimeAdapter on first use and reuse it afterwards."""
        if self._agent_runtime is None:
            from tests.pipeline_setup_v3.runtime.agent_runtime_adapter import AgentRuntimeAdapter

            self._agent_runtime = AgentRuntimeAdapter()
        return self._agent_runtime
|
||||
32
tests/pipeline_setup_v3/core/session_provider.py
Normal file
32
tests/pipeline_setup_v3/core/session_provider.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from tests.pipeline_setup_v3.core.models import CaseInput
|
||||
|
||||
|
||||
class RagSessionProvider:
    """Resolves a RAG session id for a case, indexing each repository at most once."""

    def __init__(self) -> None:
        self._repository = None  # lazily created repository, shared across indexers
        self._cache: dict[tuple[Path, str | None], str] = {}

    def resolve(self, case_input: CaseInput) -> str | None:
        """Return an explicit session id, a cached one, or index the repo to build one.

        Returns ``None`` when the case provides neither a session id nor a repo path.
        """
        explicit = case_input.rag_session_id
        if explicit:
            return explicit
        repo = case_input.repo_path
        if not repo:
            return None
        cache_key = (repo, case_input.project_id)
        cached = self._cache.get(cache_key)
        if cached is None:
            cached = self._build_indexer().index_repo(repo, project_id=case_input.project_id)
            self._cache[cache_key] = cached
        return cached

    def _build_indexer(self):
        """Build a RagSessionIndexer backed by a lazily constructed, reused repository."""
        from app.modules.rag.persistence.repository import RagRepository
        from tests.pipeline_setup.utils.rag_indexer import RagSessionIndexer

        if self._repository is None:
            self._repository = RagRepository()
        return RagSessionIndexer(self._repository)
|
||||
68
tests/pipeline_setup_v3/core/test_validators.py
Normal file
68
tests/pipeline_setup_v3/core/test_validators.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from tests.pipeline_setup_v3.core.models import CaseExpectations, LlmExpectation, RetrievalExpectation, RouterExpectation, V3Case
|
||||
from tests.pipeline_setup_v3.core.validators import CaseValidator
|
||||
|
||||
|
||||
def test_direct_symbol_test_hits_allows_related_context_when_symbol_not_in_rows() -> None:
    """A RAG row from a test file that never mentions the expected symbol must not
    count as a direct symbol-linked hit, so ``direct_symbol_test_hits_max=0``
    still validates cleanly.
    """
    # FIND_TESTS case: the symbol must appear among candidates, but no test row
    # may directly reference it.
    case = V3Case(
        case_id="negative-find-tests",
        runner="agent_runtime",
        mode="full_chain",
        query="Где тесты для WorkflowRuntimeFactory?",
        source_file=Path("cases.yaml"),
        expectations=CaseExpectations(
            router=RouterExpectation(intent="CODE_QA", sub_intent="FIND_TESTS"),
            retrieval=RetrievalExpectation(
                direct_symbol_test_hits_max=0,
                symbol_candidates_contain=("WorkflowRuntimeFactory",),
            ),
        ),
    )
    actual = {
        "intent": "CODE_QA",
        "sub_intent": "FIND_TESTS",
        "rag_count": 2,
        "symbol_candidates": ("WorkflowRuntimeFactory",),
        "layers": ("C2_DEPENDENCY_GRAPH",),
    }
    # One test row whose searchable fields mention only an unrelated symbol.
    details = {
        "rag_rows": [
            {
                "path": "tests/test_business_control_actions.py",
                "title": "test_worker_wakes_up_with_configured_interval:dataflow_slice",
                "content": "ScenarioWorker.name",
                "metadata": {"is_test": True, "src_qname": "test_worker_wakes_up_with_configured_interval"},
            }
        ]
    }

    # No mismatches expected: the test row is related context, not a direct hit.
    assert CaseValidator().validate(case, actual, details) == []
|
||||
|
||||
|
||||
def test_llm_contains_and_excludes_checks() -> None:
    """LLM expectations pass when the answer is non-empty, contains every required
    phrase (case-insensitively), and contains none of the excluded phrases.
    """
    case = V3Case(
        case_id="llm-quality",
        runner="agent_runtime",
        mode="full_chain",
        query="Где health endpoint?",
        source_file=Path("cases.yaml"),
        expectations=CaseExpectations(
            router=RouterExpectation(intent="CODE_QA", sub_intent="FIND_ENTRYPOINTS"),
            llm=LlmExpectation(
                non_empty=True,
                contains_all=("GET /health",),
                excludes=("нет явных неподтвержденных кандидатов",),
            ),
        ),
    )
    # The answer satisfies contains_all (case-insensitive) and avoids the excluded phrase.
    actual = {
        "intent": "CODE_QA",
        "sub_intent": "FIND_ENTRYPOINTS",
        "rag_count": 3,
        "llm_answer": "GET /health объявлен в HttpControlAppFactory.",
    }

    assert CaseValidator().validate(case, actual, {}) == []
|
||||
98
tests/pipeline_setup_v3/core/validators.py
Normal file
98
tests/pipeline_setup_v3/core/validators.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from tests.pipeline_setup_v3.core.models import V3Case
|
||||
|
||||
|
||||
class CaseValidator:
    """Compares a case's expectations against the actual execution output."""

    def validate(self, case: V3Case, actual: dict, details: dict | None = None) -> list[str]:
        """Return human-readable mismatch strings; an empty list means the case passed."""
        found: list[str] = []
        expectations = case.expectations
        scalar_checks = (
            (expectations.router.intent, actual.get("intent"), "intent"),
            (expectations.router.sub_intent, actual.get("sub_intent"), "sub_intent"),
            (expectations.router.graph_id, actual.get("graph_id"), "graph_id"),
            (expectations.router.conversation_mode, actual.get("conversation_mode"), "conversation_mode"),
            (expectations.pipeline.answer_mode, actual.get("answer_mode"), "answer_mode"),
        )
        for expected, got, label in scalar_checks:
            self._check(expected, got, label, found)
        self._check_retrieval(case, actual, details or {}, found)
        self._check_llm(case, actual, found)
        return found

    def _check(self, expected, actual, label: str, mismatches: list[str]) -> None:
        """Append a mismatch unless the values agree, the expectation is unset, or waived."""
        # Deliberate waiver: an expected "degraded" answer mode also accepts "not_found".
        if label == "answer_mode" and expected == "degraded" and actual == "not_found":
            return
        if expected is not None and expected != actual:
            mismatches.append(f"{label}: expected {expected}, got {actual}")

    def _check_retrieval(self, case: V3Case, actual: dict, details: dict, mismatches: list[str]) -> None:
        """Validate row counts, direct symbol hits, path scope, candidates, and layers."""
        expectation = case.expectations.retrieval
        row_count = int(actual.get("rag_count") or 0)
        if expectation.non_empty is True and row_count == 0:
            mismatches.append("retrieval: expected non-empty rag, got 0 rows")
        if expectation.non_empty is False and row_count != 0:
            mismatches.append(f"retrieval: expected empty rag, got {row_count} rows")
        if expectation.min_rows is not None and row_count < expectation.min_rows:
            mismatches.append(f"retrieval: expected at least {expectation.min_rows} rows, got {row_count}")
        cap = expectation.direct_symbol_test_hits_max
        if cap is not None:
            observed = self._direct_symbol_test_hits(case, details)
            if observed > cap:
                mismatches.append(
                    f"retrieval: expected direct symbol-linked test hits <= {cap}, got {observed}"
                )
        scope = actual.get("path_scope") or ()
        for fragment in expectation.path_scope_contains:
            if not any(fragment in entry for entry in scope):
                mismatches.append(f"path_scope should contain '{fragment}', got {list(scope)}")
        candidates = actual.get("symbol_candidates") or ()
        for symbol in expectation.symbol_candidates_contain:
            if symbol not in candidates:
                mismatches.append(f"symbol_candidates should contain '{symbol}', got {list(candidates)}")
        layer_values = actual.get("layers") or ()
        for layer in expectation.layers_include:
            if layer not in layer_values:
                mismatches.append(f"layers should include '{layer}', got {list(layer_values)}")

    def _check_llm(self, case: V3Case, actual: dict, mismatches: list[str]) -> None:
        """Validate emptiness plus case-insensitive contains/excludes on the LLM answer."""
        expectation = case.expectations.llm
        answer = str(actual.get("llm_answer") or "").strip()
        if expectation.non_empty and not answer:
            mismatches.append("llm: expected non-empty answer")
        if expectation.non_empty is False and answer:
            mismatches.append("llm: expected empty answer")
        haystack = answer.lower()
        mismatches.extend(
            f"llm: expected answer to contain '{needle}'"
            for needle in expectation.contains_all
            if needle.lower() not in haystack
        )
        mismatches.extend(
            f"llm: expected answer to exclude '{needle}'"
            for needle in expectation.excludes
            if needle.lower() in haystack
        )

    def _direct_symbol_test_hits(self, case: V3Case, details: dict) -> int:
        """Count RAG test rows whose searchable text mentions an expected symbol."""
        wanted = tuple(
            name.strip().lower()
            for name in case.expectations.retrieval.symbol_candidates_contain
            if name.strip()
        )
        if not wanted:
            return 0
        return sum(
            1
            for row in details.get("rag_rows") or []
            if self._is_test_row(row)
            and any(symbol in text for symbol in wanted for text in self._row_search_texts(row))
        )

    def _is_test_row(self, row: dict) -> bool:
        """Heuristic: an explicit metadata flag or a conventional test path/file name."""
        meta = dict(row.get("metadata") or {})
        if bool(meta.get("is_test")):
            return True
        path = str(row.get("path") or "").lower()
        # Leading "/" makes "tests/..." match the "/tests/" probe as well.
        if "/tests/" in f"/{path}" or path.startswith("tests/"):
            return True
        return "/test_" in path or path.endswith("_test.py")

    def _row_search_texts(self, row: dict) -> tuple[str, ...]:
        """Lower-cased, non-empty text fields of a row that may name the target symbol."""
        meta = dict(row.get("metadata") or {})
        raw_values = (
            str(row.get("title") or ""),
            str(row.get("content") or ""),
            str(meta.get("qname") or ""),
            str(meta.get("symbol_name") or ""),
            str(meta.get("src_qname") or ""),
            str(meta.get("dst_ref") or ""),
        )
        return tuple(text.lower() for text in raw_values if text)
|
||||
Reference in New Issue
Block a user