78 lines
2.8 KiB
Python
78 lines
2.8 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
|
|
from app.modules.rag.intent_router_v2.models import IntentRouterResult
|
|
|
|
|
|
def assert_intent(out: IntentRouterResult, expected: str) -> None:
    """Check that the router result carries the expected intent.

    Args:
        out: Router result under test.
        expected: Value that ``out.intent`` must equal.

    Raises:
        AssertionError: If ``out.intent`` differs from ``expected``.
    """
    actual_intent = out.intent
    assert actual_intent == expected
|
|
|
|
|
|
def assert_domains(out: IntentRouterResult, expected: list[str]) -> None:
    """Check that the retrieval spec lists exactly the expected domains.

    The comparison is a plain ``==`` against ``expected``, so for lists both
    the members and their order must match.

    Args:
        out: Router result under test.
        expected: Domains that ``out.retrieval_spec.domains`` must equal.

    Raises:
        AssertionError: If the domains differ from ``expected``.
    """
    actual_domains = out.retrieval_spec.domains
    assert actual_domains == expected
|
|
|
|
|
|
def assert_has_file_path(out: IntentRouterResult, path: str) -> None:
    """Check that the query plan contains a ``FILE_PATH`` anchor equal to ``path``.

    Only anchors whose ``type`` is ``"FILE_PATH"`` are considered; an anchor of
    another type with the same value does not satisfy the check.

    Args:
        out: Router result under test.
        path: Anchor value that must be present.

    Raises:
        AssertionError: If no ``FILE_PATH`` anchor has ``path`` as its value.
            The message lists the ``FILE_PATH`` values that were present.
    """
    file_paths = [anchor.value for anchor in out.query_plan.anchors if anchor.type == "FILE_PATH"]
    # Explicit message: a bare failure here would not show which paths were
    # actually extracted, which is the first thing needed when debugging.
    assert path in file_paths, (
        f"no FILE_PATH anchor with value {path!r}; FILE_PATH anchors present: {file_paths!r}"
    )
|
|
|
|
|
|
def assert_path_scope(out: IntentRouterResult, file_path: str, dir_path: str | None = None) -> None:
    """Check that the path scope filter includes the given file (and directory).

    The scope is read from ``out.retrieval_spec.filters.path_scope``; a missing
    attribute or a falsy value is treated as an empty scope.

    Args:
        out: Router result under test.
        file_path: Entry that must be present in the scope.
        dir_path: Optional second entry that must also be present when given.

    Raises:
        AssertionError: If ``file_path`` is not in the scope, or if
            ``dir_path`` is given and is not in the scope.
    """
    raw_scope = getattr(out.retrieval_spec.filters, "path_scope", [])
    scope = list(raw_scope) if raw_scope else []
    assert file_path in scope
    if dir_path is None:
        return
    assert dir_path in scope
|
|
|
|
|
|
def assert_file_only_scope(out: IntentRouterResult, file_path: str) -> None:
    """Check that the path scope filter contains the given file and nothing else.

    The scope is read from ``out.retrieval_spec.filters.path_scope``; a missing
    attribute or a falsy value is treated as an empty scope.

    Args:
        out: Router result under test.
        file_path: The single entry the scope must consist of.

    Raises:
        AssertionError: If the scope is not exactly ``[file_path]``.
    """
    filters = out.retrieval_spec.filters
    scope = [*(getattr(filters, "path_scope", []) or [])]
    expected_scope = [file_path]
    assert scope == expected_scope
|
|
|
|
|
|
def assert_spans_valid(out: IntentRouterResult) -> None:
    """Check every anchor's span against the raw query text.

    Rules enforced per anchor in ``out.query_plan.anchors``:

    * an anchor whose ``source`` is ``"conversation_state"`` must have
      ``span is None``;
    * every other anchor must have a span with
      ``0 <= start < end <= len(out.query_plan.raw)``.

    Args:
        out: Router result under test.

    Raises:
        AssertionError: On the first anchor that violates a rule. The message
            identifies the anchor by position and includes its repr.
    """
    raw_len = len(out.query_plan.raw)
    for index, anchor in enumerate(out.query_plan.anchors):
        if anchor.source == "conversation_state":
            assert anchor.span is None, (
                f"anchor #{index} comes from conversation_state and must not have a span: {anchor!r}"
            )
            continue
        span = anchor.span
        assert span is not None, f"anchor #{index} is missing a span: {anchor!r}"
        # Strict `start < end` also rejects empty spans.
        assert 0 <= span.start < span.end <= raw_len, (
            f"anchor #{index} has span [{span.start}, {span.end}) that is empty or "
            f"outside the raw query of length {raw_len}: {anchor!r}"
        )
|
|
|
|
|
|
def assert_test_policy(out: IntentRouterResult, expected: str) -> None:
    """Check the ``test_policy`` value on the retrieval filters.

    The value is read from ``out.retrieval_spec.filters.test_policy``; a
    missing attribute is treated as ``None``.

    Args:
        out: Router result under test.
        expected: Value the test policy must equal.

    Raises:
        AssertionError: If the test policy differs from ``expected``.
    """
    filters = out.retrieval_spec.filters
    policy = getattr(filters, "test_policy", None)
    assert policy == expected
|
|
|
|
|
|
def assert_sub_intent(out: IntentRouterResult, expected: str) -> None:
    """Check that the query plan carries the expected sub-intent.

    Args:
        out: Router result under test.
        expected: Value that ``out.query_plan.sub_intent`` must equal.

    Raises:
        AssertionError: If the sub-intent differs from ``expected``.
    """
    actual_sub_intent = out.query_plan.sub_intent
    assert actual_sub_intent == expected
|
|
|
|
|
|
def assert_no_symbol_keyword(out: IntentRouterResult, forbidden: set[str] | None = None) -> None:
    """Check that no ``SYMBOL`` anchor is a forbidden keyword.

    Both anchor values and forbidden tokens are lower-cased before comparison,
    so the check is case-insensitive.

    Args:
        out: Router result under test.
        forbidden: Tokens that must not appear as ``SYMBOL`` anchor values.
            ``None`` selects the default set ``{"def", "class", "return",
            "import", "from"}``. An explicitly empty set forbids nothing.

    Raises:
        AssertionError: If any ``SYMBOL`` anchor value matches a forbidden
            token. The message lists the offending values.
    """
    # Test `is None`, not truthiness: an empty set passed on purpose must not be
    # silently replaced by the defaults.
    denied = {"def", "class", "return", "import", "from"} if forbidden is None else forbidden
    denied_lower = {token.lower() for token in denied}
    symbols = {anchor.value.lower() for anchor in out.query_plan.anchors if anchor.type == "SYMBOL"}
    leaked = symbols & denied_lower
    assert not leaked, f"SYMBOL anchors must not be forbidden keywords; found: {sorted(leaked)!r}"
|
|
|
|
|
|
def assert_domain_layer_prefixes(out: IntentRouterResult) -> None:
    """Check that layer id prefixes are consistent with the selected domains.

    The prefix of a layer is the first character of its ``layer_id``; layers
    with an empty or falsy ``layer_id`` are ignored. Allowed prefixes:

    * domains equal to ``["CODE"]`` -> only ``"C"``;
    * domains equal to ``["DOCS"]`` -> only ``"D"``;
    * anything else -> ``"C"`` and/or ``"D"``.

    Args:
        out: Router result under test.

    Raises:
        AssertionError: If any layer prefix falls outside the allowed set.
    """
    spec = out.retrieval_spec

    seen_prefixes = set()
    for layer in spec.layer_queries:
        layer_id = layer.layer_id
        if layer_id:
            seen_prefixes.add(layer_id[0])

    domains = spec.domains
    if domains == ["CODE"]:
        allowed = {"C"}
    elif domains == ["DOCS"]:
        allowed = {"D"}
    else:
        allowed = {"C", "D"}

    assert seen_prefixes <= allowed
|
|
|
|
|
|
def assert_no_symbol_leakage_from_paths(out: IntentRouterResult) -> None:
    """Check that no ``SYMBOL`` anchor merely repeats a piece of a file path.

    Each ``FILE_PATH`` anchor value is lower-cased and split on runs of ``/``
    and ``.``; the non-empty pieces form the set of path tokens. Every
    ``SYMBOL`` anchor value, lower-cased, must be absent from that set. When
    there are no ``FILE_PATH`` anchors the check passes trivially.

    Args:
        out: Router result under test.

    Raises:
        AssertionError: On the first ``SYMBOL`` anchor whose lower-cased value
            equals a path token.
    """
    plan = out.query_plan
    path_values = [item.value for item in plan.anchors if item.type == "FILE_PATH"]
    if path_values:
        path_tokens: set[str] = {
            piece
            for path_value in path_values
            for piece in re.split(r"[/.]+", path_value.lower())
            if piece
        }
        for candidate in plan.anchors:
            if candidate.type != "SYMBOL":
                continue
            assert candidate.value.lower() not in path_tokens
|