from __future__ import annotations

import re
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Used only in annotations, which `from __future__ import annotations`
    # keeps unevaluated, so the router models need not be imported at runtime.
    from app.modules.rag.intent_router_v2.models import IntentRouterResult

# Assertion helpers for intent-router results.
#
# Every assert carries an explicit message: pytest only rewrites asserts in
# test modules / conftest / registered plugins, so a bare assert in a plain
# helper module would fail with an empty AssertionError.

# Symbol values rejected by assert_no_symbol_keyword when no set is supplied.
_DEFAULT_FORBIDDEN_SYMBOLS = frozenset({"def", "class", "return", "import", "from"})

# Path values are tokenised on runs of "/" and ".".
_PATH_SEPARATORS = re.compile(r"[/.]+")


def _path_scope(out: IntentRouterResult) -> list[str]:
    """Return the filters' ``path_scope`` as a list; missing or falsy becomes ``[]``."""
    return list(getattr(out.retrieval_spec.filters, "path_scope", []) or [])


def assert_intent(out: IntentRouterResult, expected: str) -> None:
    """Assert that ``out.intent`` equals ``expected``.

    Raises:
        AssertionError: If the intent differs.
    """
    assert out.intent == expected, f"intent: expected {expected!r}, got {out.intent!r}"


def assert_domains(out: IntentRouterResult, expected: list[str]) -> None:
    """Assert that ``out.retrieval_spec.domains`` equals ``expected`` (order-sensitive).

    Raises:
        AssertionError: If the domains differ.
    """
    actual = out.retrieval_spec.domains
    assert actual == expected, f"domains: expected {expected!r}, got {actual!r}"


def assert_has_file_path(out: IntentRouterResult, path: str) -> None:
    """Assert that some anchor has type ``"FILE_PATH"`` and value ``path``.

    Raises:
        AssertionError: If no such anchor exists in ``out.query_plan.anchors``.
    """
    found = [
        anchor.value
        for anchor in out.query_plan.anchors
        if anchor.type == "FILE_PATH"
    ]
    assert path in found, f"no FILE_PATH anchor {path!r}; FILE_PATH anchors: {found!r}"


def assert_path_scope(
    out: IntentRouterResult, file_path: str, dir_path: str | None = None
) -> None:
    """Assert that the path scope contains ``file_path`` and, if given, ``dir_path``.

    Raises:
        AssertionError: If either required entry is absent from the scope.
    """
    scope = _path_scope(out)
    assert file_path in scope, f"{file_path!r} missing from path_scope {scope!r}"
    if dir_path is not None:
        assert dir_path in scope, f"{dir_path!r} missing from path_scope {scope!r}"


def assert_file_only_scope(out: IntentRouterResult, file_path: str) -> None:
    """Assert that the path scope is exactly ``[file_path]``.

    Raises:
        AssertionError: If the scope has any other content.
    """
    scope = _path_scope(out)
    assert scope == [file_path], f"path_scope: expected {[file_path]!r}, got {scope!r}"


def assert_spans_valid(out: IntentRouterResult) -> None:
    """Assert that every anchor's span is consistent with the raw query.

    Anchors whose ``source`` is ``"conversation_state"`` must have no span.
    All other anchors must have a span with
    ``0 <= start < end <= len(out.query_plan.raw)``.

    Raises:
        AssertionError: On the first anchor violating these rules.
    """
    raw_len = len(out.query_plan.raw)
    for anchor in out.query_plan.anchors:
        if anchor.source == "conversation_state":
            assert anchor.span is None, (
                f"conversation_state anchor must not carry a span: {anchor!r}"
            )
            continue
        assert anchor.span is not None, f"anchor is missing a span: {anchor!r}"
        assert 0 <= anchor.span.start < anchor.span.end <= raw_len, (
            f"span out of bounds for raw length {raw_len}: {anchor!r}"
        )


def assert_test_policy(out: IntentRouterResult, expected: str) -> None:
    """Assert that the filters' ``test_policy`` equals ``expected``.

    A missing ``test_policy`` attribute is treated as ``None``.

    Raises:
        AssertionError: If the policy differs.
    """
    actual = getattr(out.retrieval_spec.filters, "test_policy", None)
    assert actual == expected, f"test_policy: expected {expected!r}, got {actual!r}"


def assert_sub_intent(out: IntentRouterResult, expected: str) -> None:
    """Assert that ``out.query_plan.sub_intent`` equals ``expected``.

    Raises:
        AssertionError: If the sub-intent differs.
    """
    actual = out.query_plan.sub_intent
    assert actual == expected, f"sub_intent: expected {expected!r}, got {actual!r}"


def assert_no_symbol_keyword(
    out: IntentRouterResult, forbidden: set[str] | None = None
) -> None:
    """Assert that no ``SYMBOL`` anchor value is a forbidden word (case-insensitive).

    Args:
        out: Router result whose anchors are inspected.
        forbidden: Words to reject. ``None`` or an empty set selects the
            default set ``def``, ``class``, ``return``, ``import``, ``from``.

    Raises:
        AssertionError: If any SYMBOL anchor matches a forbidden word.
    """
    # `or` (not `is None`) on purpose: an empty set also selects the defaults.
    denied = {token.lower() for token in (forbidden or _DEFAULT_FORBIDDEN_SYMBOLS)}
    symbols = {
        anchor.value.lower()
        for anchor in out.query_plan.anchors
        if anchor.type == "SYMBOL"
    }
    offending = symbols & denied
    assert not offending, f"forbidden words used as SYMBOL anchors: {offending!r}"


def assert_domain_layer_prefixes(out: IntentRouterResult) -> None:
    """Assert that layer-id first characters are allowed for the result's domains.

    Domains ``["CODE"]`` allow only ``"C"``, ``["DOCS"]`` only ``"D"``, and
    anything else allows both. Layers with an empty ``layer_id`` are ignored.

    Raises:
        AssertionError: If a layer id starts with a disallowed character.
    """
    prefixes = {
        layer.layer_id[0]
        for layer in out.retrieval_spec.layer_queries
        if layer.layer_id
    }
    domains = out.retrieval_spec.domains
    if domains == ["CODE"]:
        allowed = {"C"}
    elif domains == ["DOCS"]:
        allowed = {"D"}
    else:
        allowed = {"C", "D"}
    unexpected = prefixes - allowed
    assert not unexpected, (
        f"layer prefixes {unexpected!r} not allowed for domains {domains!r}"
    )


def assert_no_symbol_leakage_from_paths(out: IntentRouterResult) -> None:
    """Assert that no ``SYMBOL`` anchor equals a component of a ``FILE_PATH`` anchor.

    File-path values are lower-cased and split on runs of ``/`` and ``.``;
    each SYMBOL value is compared lower-cased against the resulting tokens.
    With no FILE_PATH anchors the check passes trivially.

    Raises:
        AssertionError: On the first SYMBOL anchor matching a path component.
    """
    anchors = out.query_plan.anchors
    file_values = [anchor.value for anchor in anchors if anchor.type == "FILE_PATH"]
    if not file_values:
        return
    parts: set[str] = set()
    for value in file_values:
        # Leading/trailing separators yield empty tokens; drop them.
        parts.update(token for token in _PATH_SEPARATORS.split(value.lower()) if token)
    for anchor in anchors:
        if anchor.type == "SYMBOL":
            assert anchor.value.lower() not in parts, (
                f"SYMBOL anchor {anchor.value!r} duplicates a component of "
                f"file paths {file_values!r}"
            )