57 lines
2.9 KiB
Python
57 lines
2.9 KiB
Python
from __future__ import annotations
|
|
|
|
from tests.pipeline_setup_v4.executors.process_v2_full_chain_executor import ProcessV2FullChainExecutor
|
|
from tests.pipeline_setup_v4.executors.process_v2_retrieval_policy_executor import ProcessV2RetrievalPolicyExecutor
|
|
from tests.pipeline_setup_v4.executors.process_v2_router_plus_policy_executor import ProcessV2RouterPlusPolicyExecutor
|
|
from tests.pipeline_setup_v4.executors.process_v2_router_plus_policy_rag_executor import (
|
|
ProcessV2RouterPlusPolicyRagExecutor,
|
|
)
|
|
from tests.pipeline_setup_v4.executors.process_v2_router_executor import ProcessV2IntentRouterExecutor
|
|
|
|
|
|
class ExecutorRegistry:
    """Route a component name to its executor, creating executors on demand.

    Each executor is constructed with no arguments the first time its
    component is requested; the same instance is reused on later calls.
    """

    def __init__(self) -> None:
        """Start with every executor slot empty; nothing is constructed yet."""
        self._router_executor: ProcessV2IntentRouterExecutor | None = None
        self._policy_executor: ProcessV2RetrievalPolicyExecutor | None = None
        self._router_plus_policy_executor: ProcessV2RouterPlusPolicyExecutor | None = None
        self._router_plus_policy_rag_executor: ProcessV2RouterPlusPolicyRagExecutor | None = None
        self._full_chain_executor: ProcessV2FullChainExecutor | None = None

    def execute(self, component: str, case) -> object:
        """Run ``case`` through the executor registered for ``component``.

        Args:
            component: Name of the component to execute.
            case: Passed unchanged to the selected executor's ``execute``.

        Returns:
            Whatever the selected executor's ``execute`` returns.

        Raises:
            ValueError: If ``component`` is not one of the supported names.
        """
        # Ordered (name, accessor) pairs. Names are compared with ``==`` one
        # after another; only the matching accessor is ever called, so
        # executors for other components are not constructed.
        routes = (
            ("process_v2_intent_router", self._router),
            ("process_v2_retrieval_policy_resolver", self._policy),
            ("process_v2_router_plus_retrieval_policy", self._router_plus_policy),
            ("process_v2_router_plus_retrieval_policy_rag", self._router_plus_policy_rag),
            ("process_v2_full_chain", self._full_chain),
        )
        for name, accessor in routes:
            if component == name:
                return accessor().execute(case)
        raise ValueError(f"Unsupported component: {component}")

    def _router(self) -> ProcessV2IntentRouterExecutor:
        """Return the intent-router executor, constructing it on first use."""
        executor = self._router_executor
        if executor is None:
            executor = ProcessV2IntentRouterExecutor()
            self._router_executor = executor
        return executor

    def _policy(self) -> ProcessV2RetrievalPolicyExecutor:
        """Return the retrieval-policy executor, constructing it on first use."""
        executor = self._policy_executor
        if executor is None:
            executor = ProcessV2RetrievalPolicyExecutor()
            self._policy_executor = executor
        return executor

    def _router_plus_policy(self) -> ProcessV2RouterPlusPolicyExecutor:
        """Return the router-plus-policy executor, constructing it on first use."""
        executor = self._router_plus_policy_executor
        if executor is None:
            executor = ProcessV2RouterPlusPolicyExecutor()
            self._router_plus_policy_executor = executor
        return executor

    def _router_plus_policy_rag(self) -> ProcessV2RouterPlusPolicyRagExecutor:
        """Return the router-plus-policy-RAG executor, constructing it on first use."""
        executor = self._router_plus_policy_rag_executor
        if executor is None:
            executor = ProcessV2RouterPlusPolicyRagExecutor()
            self._router_plus_policy_rag_executor = executor
        return executor

    def _full_chain(self) -> ProcessV2FullChainExecutor:
        """Return the full-chain executor, constructing it on first use."""
        executor = self._full_chain_executor
        if executor is None:
            executor = ProcessV2FullChainExecutor()
            self._full_chain_executor = executor
        return executor
|