diff --git a/plugins/external/opa/README.md b/plugins/external/opa/README.md index 6c89d9960..28e32e686 100644 --- a/plugins/external/opa/README.md +++ b/plugins/external/opa/README.md @@ -275,6 +275,21 @@ make test The`make test` command executes a complete testing workflow: it launches an OPA server using the policy file located at ./opaserver/rego/policy.rego (as specified by `POLICY_PATH`), runs all test cases against this server, and automatically terminates the OPA server process once testing finishes. +## Error Handling Verification + +1. `OPA_SERVER_NONE_RESPONSE` = "OPA server returned an empty response" +2. `OPA_SERVER_ERROR` = "Error while communicating with the OPA server" +3. `OPA_SERVER_UNCONFIGURED_ENDPOINT` = "Policy endpoint not configured on the OPA server" +4. `UNSPECIFIED_REQUIRED_PARAMS` = "Required parameters missing: policy config, payload, or hook type" +5. `UNSUPPORTED_HOOK_TYPE` = "Unsupported hook type (only tool, prompt, and resource are supported)" +6. `INVALID_POLICY_ENDPOINT` = "Policy endpoint must be curated with the supported hooktypes" +7. `UNSPECIFIED_POLICY_MODALITY` = "Unspecified policy modality. Picking up default modality: text" +8. `UNSUPPORTED_POLICY_MODALITY` = "Unsupported policy modality (Supports text, image and resource)" +9. `UNSPECIFIED_POLICY_PACKAGE_NAME` = "Unspecified policy package name" + +If OPA plugin encounters any of the errors above, it raises a PluginError. +The file `test_errors.py` includes unit tests to verify that these errors are correctly raised under the corresponding conditions. +Run it using `make test` ## License diff --git a/plugins/external/opa/opapluginfilter/plugin.py b/plugins/external/opa/opapluginfilter/plugin.py index 408153bed..c59fc5733 100644 --- a/plugins/external/opa/opapluginfilter/plugin.py +++ b/plugins/external/opa/opapluginfilter/plugin.py @@ -14,7 +14,6 @@ from urllib.parse import urlparse # Third-Party -from opapluginfilter.schema import BaseOPAInputKeys, OPAConfig, OPAInput import requests # First-Party @@ -22,30 +21,36 @@ Plugin, PluginConfig, PluginContext, + PluginError, + PluginErrorModel, PluginViolation, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, PromptPrehookResult, + PromptHookType, ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchPayload, ResourcePreFetchResult, + ResourceHookType, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokePayload, ToolPreInvokeResult, + ToolHookType, ) from mcpgateway.plugins.framework.models import AppliedTo from mcpgateway.services.logging_service import LoggingService +from opapluginfilter.schema import BaseOPAInputKeys, OPAConfig, OPAInput # Initialize logging service first logging_service = LoggingService() logger = logging_service.get_logger(__name__) -class OPACodes(str, Enum): - """OPACodes implementation.""" +class OPAPluginCodes(str, Enum): + """OPAPluginCodes implementation.""" ALLOW_CODE = "ALLOW" DENIAL_CODE = "DENY" @@ -53,13 +58,27 @@ class OPACodes(str, Enum): REQUIRES_HUMAN_APPROVAL_CODE = "REQUIRES_APPROVAL" -class OPAResponseTemplates(str, Enum): - """OPAResponseTemplates implementation.""" +class OPAPluginResponseTemplates(str, Enum): + """OPAPluginResponseTemplates implementation.""" OPA_REASON = "OPA policy denied for {hook_type}" OPA_DESC = "{hook_type} not allowed" +class OPAPluginErrorCodes(str, Enum): + """OPA plugin error codes or reasons when raising plugin error""" + + OPA_SERVER_NONE_RESPONSE = "OPA server returned an empty response" + OPA_SERVER_ERROR = "Error while communicating with the OPA server" + OPA_SERVER_UNCONFIGURED_ENDPOINT = "Policy endpoint not configured on the OPA server" + UNSPECIFIED_REQUIRED_PARAMS = "Required parameters missing: policy config, payload, or hook type" + UNSUPPORTED_HOOK_TYPE = "Unsupported hook type (only tool, prompt, and resource are supported)" + INVALID_POLICY_ENDPOINT = "Policy endpoint must be curated with the supported hooktypes" + UNSPECIFIED_POLICY_MODALITY = "Unspecified policy modality. Picking up default modality: text" + UNSUPPORTED_POLICY_MODALITY = "Unsupported policy modality (Supports text, image and resource)" + UNSPECIFIED_POLICY_PACKAGE_NAME = "Unspecified policy package name" + + HookPayload: TypeAlias = ToolPreInvokePayload | ToolPostInvokePayload | PromptPosthookPayload | PromptPrehookPayload | ResourcePreFetchPayload | ResourcePostFetchPayload @@ -75,6 +94,7 @@ def __init__(self, config: PluginConfig): super().__init__(config) self.opa_config = OPAConfig.model_validate(self._config.config) self.opa_context_key = "opa_policy_context" + logger.info(f"OPAPluginFilter initialised with configuraiton {self.opa_config}") def _get_nested_value(self, data, key_string, default=None): """ @@ -128,8 +148,12 @@ def _key(k: str, m: str) -> str: payload = {"input": {m: self._get_nested_value(input.model_dump()["input"], _key(k, m)) for k, m in policy_input_data_map.items()}} if policy_input_data_map else input.model_dump() logger.info(f"OPA url {url}, OPA payload {payload}") - rsp = requests.post(url, json=payload) - logger.info(f"OPA connection response '{rsp}'") + try: + rsp = requests.post(url, json=payload) + logger.info(f"OPA connection response '{rsp}'") + except Exception as e: + logger.error(f"{OPAPluginErrorCodes.OPA_SERVER_ERROR.value}") + raise PluginError(PluginErrorModel(message=OPAPluginErrorCodes.OPA_SERVER_ERROR.value, plugin_name="OPAPluginFilter", details={"reason": str(e)})) if rsp.status_code == 200: json_response = rsp.json() decision = json_response.get("result", None) @@ -142,10 +166,12 @@ def _key(k: str, m: str) -> str: logger.debug(f"OPA decision {allow}") return allow, json_response else: - logger.debug(f"OPA sent a none response {json_response}") + logger.error(f"{OPAPluginErrorCodes.OPA_SERVER_NONE_RESPONSE.value} : {json_response}") + raise PluginError(PluginErrorModel(message=OPAPluginErrorCodes.OPA_SERVER_NONE_RESPONSE.value, plugin_name="OPAPluginFilter", details={"reason": json_response})) + else: - logger.debug(f"OPA error: {rsp}") - return True, None + logger.error(f"{OPAPluginErrorCodes.OPA_SERVER_ERROR.value}: {rsp}") + raise PluginError(PluginErrorModel(message=OPAPluginErrorCodes.OPA_SERVER_ERROR.value, plugin_name="OPAPluginFilter", details={"reason": rsp})) def _preprocess_opa(self, policy_apply_config: AppliedTo = None, payload: HookPayload = None, context: PluginContext = None, hook_type: str = None) -> dict: """Function to preprocess input for OPA server based on the type of hook it's invoked on. @@ -160,11 +186,11 @@ def _preprocess_opa(self, policy_apply_config: AppliedTo = None, payload: HookPa dict: if a valid policy_apply_config, payload and hook_type, otherwise returns dictionary with none values """ - result = {"opa_server_url": None, "policy_context": None, "policy_input_data_map": None, "policy_modality": None} + result = {"opa_server_url": None, "policy_context": None, "policy_input_data_map": None, "policy_modality": None, "policy_apply": None} if not (policy_apply_config and payload and hook_type): - logger.error(f"Unspecified required: {policy_apply_config} and payload: {payload} and hook_type: {hook_type}") - return result + logger.error(f"{OPAPluginErrorCodes.UNSPECIFIED_REQUIRED_PARAMS.value} {policy_apply_config} and payload: {payload} and hook_type: {hook_type}") + raise PluginError(PluginErrorModel(message=OPAPluginErrorCodes.UNSPECIFIED_REQUIRED_PARAMS.value, plugin_name="OPAPluginFilter")) input_context = [] policy_context = {} @@ -173,6 +199,7 @@ def _preprocess_opa(self, policy_apply_config: AppliedTo = None, payload: HookPa policy_input_data_map = {} policy_modality = None hook_name = None + policy_apply = False if policy_apply_config: if "tool" in hook_type and policy_apply_config.tools: @@ -182,8 +209,7 @@ def _preprocess_opa(self, policy_apply_config: AppliedTo = None, payload: HookPa elif "resource" in hook_type and policy_apply_config.resources: hook_info = policy_apply_config.resources else: - logger.error("The hooks should belong to either of the following: tool, prompts and resources") - return result + raise PluginError(PluginErrorModel(message=OPAPluginErrorCodes.UNSUPPORTED_HOOK_TYPE.value, plugin_name="OPAPluginFilter")) for hook in hook_info: if "tool" in hook_type: @@ -191,35 +217,53 @@ def _preprocess_opa(self, policy_apply_config: AppliedTo = None, payload: HookPa payload_name = payload.name elif "prompt" in hook_type: hook_name = hook.prompt_name - payload_name = payload.name + payload_name = payload.prompt_id elif "resource" in hook_type: hook_name = hook.resource_uri payload_name = payload.uri else: - logger.error("The hooks should belong to either of the following: tool, prompts and resources") - return result + logger.error(f"{OPAPluginErrorCodes.UNSUPPORTED_HOOK_TYPE.value}: {hook}") + raise PluginError(PluginErrorModel(message=OPAPluginErrorCodes.UNSUPPORTED_HOOK_TYPE.value, plugin_name="OPAPluginFilter")) if payload_name == hook_name or hook_name in payload_name: + policy_apply = True if hook.context: input_context = [ctx.rsplit(".", 1)[-1] for ctx in hook.context] if self.opa_context_key in context.global_context.state: policy_context = {k: context.global_context.state[self.opa_context_key][k] for k in input_context} if hook.extensions: - policy = hook.extensions.get("policy") + policy = hook.extensions.get("policy", None) + if not policy: + raise PluginError(PluginErrorModel(message=OPAPluginErrorCodes.UNSPECIFIED_POLICY_PACKAGE_NAME.value, plugin_name="OPAPluginFilter")) policy_endpoints = hook.extensions.get("policy_endpoints", []) policy_input_data_map = hook.extensions.get("policy_input_data_map", {}) - policy_modality = hook.extensions.get("policy_modality", ["text"]) + if "policy_modality" not in hook.extensions: + logger.error(f"{OPAPluginErrorCodes.UNSPECIFIED_POLICY_MODALITY.value}") + policy_modality = hook.extensions.get("policy_modality", ["text"]) + else: + policy_modality = hook.extensions.get("policy_modality", ["text"]) + all_hook_types = [hook.value for hook in ToolHookType] + [hook.value for hook in PromptHookType] + [hook.value for hook in ResourceHookType] + all_hook_flag = 0 + for hook in all_hook_types: + for endpoint in policy_endpoints: + if hook in endpoint: + all_hook_flag += 1 + if len(policy_endpoints) != all_hook_flag: + if "allow" not in policy_endpoints: + raise PluginError( + PluginErrorModel(message=OPAPluginErrorCodes.INVALID_POLICY_ENDPOINT, plugin_name="OPAPluginFilter", details={"reason": f"Supported hook type: {all_hook_types}"}) + ) if policy_endpoints: policy_endpoint = next((endpoint for endpoint in policy_endpoints if hook_type in endpoint), "allow") - - if not policy_endpoint: - logger.debug(f"Unconfigured endpoint for policy {hook_type} {hook_name} invocation") - return result + else: + logger.error(f"{OPAPluginErrorCodes.OPA_SERVER_UNCONFIGURED_ENDPOINT.value} {hook_type} {hook_name} invocation") + raise PluginError(PluginErrorModel(message=OPAPluginErrorCodes.OPA_SERVER_UNCONFIGURED_ENDPOINT.value, plugin_name="OPAPluginFilter")) result["policy_context"] = policy_context result["opa_server_url"] = "{opa_url}{policy}/{policy_endpoint}".format(opa_url=self.opa_config.opa_base_url, policy=policy, policy_endpoint=policy_endpoint) result["policy_input_data_map"] = policy_input_data_map result["policy_modality"] = policy_modality + result["policy_apply"] = policy_apply return result def _extract_payload_key(self, content: Any = None, key: str = None, result: dict[str, list] = None) -> None: @@ -235,15 +279,22 @@ def _extract_payload_key(self, content: Any = None, key: str = None, result: dic for element in content: if isinstance(element, dict) and key in element: self._extract_payload_key(element, key, result) + else: + logger.error(f"{OPAPluginErrorCodes.UNSUPPORTED_POLICY_MODALITY.value}: {type(content)}") + raise PluginError(PluginErrorModel(message=OPAPluginErrorCodes.UNSUPPORTED_POLICY_MODALITY.value, plugin_name="OPAPluginFilter")) elif isinstance(content, dict): if key in content or hasattr(content, key): result[key].append(content[key]) + else: + logger.error(f"{OPAPluginErrorCodes.UNSUPPORTED_POLICY_MODALITY.value}: {type(content)}") + raise PluginError(PluginErrorModel(message=OPAPluginErrorCodes.UNSUPPORTED_POLICY_MODALITY.value, plugin_name="OPAPluginFilter")) elif isinstance(content, str): result[key].append(content) elif hasattr(content, key): result[key].append(getattr(content, key)) else: - logger.error(f"Can't handle content of {type(content)}") + logger.error(f"{OPAPluginErrorCodes.UNSUPPORTED_POLICY_MODALITY.value}: {type(content)}") + raise PluginError(PluginErrorModel(message=OPAPluginErrorCodes.UNSUPPORTED_POLICY_MODALITY.value, plugin_name="OPAPluginFilter")) async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: """OPA Plugin hook run before a prompt is fetched. This hook takes in payload and context and further evaluates rego @@ -257,8 +308,8 @@ async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginC The result of the plugin's analysis, including whether prompt input could proceed further. """ - hook_type = "prompt_pre_fetch" - logger.info(f"Processing {hook_type} for '{payload.name}' with {len(payload.args) if payload.args else 0} arguments") + hook_type = PromptHookType.PROMPT_PRE_FETCH.value + logger.info(f"Processing {hook_type} for '{payload.prompt_id}' with {len(payload.args) if payload.args else 0} arguments") logger.info(f"Processing context {context}") if not payload.args: @@ -267,16 +318,16 @@ async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginC policy_apply_config = self._config.applied_to if policy_apply_config and policy_apply_config.prompts: opa_pre_prompt_input = self._preprocess_opa(policy_apply_config, payload, context, hook_type) - if not all(v is None for v in opa_pre_prompt_input.values()): - opa_input = BaseOPAInputKeys(kind="post_tool", user="none", payload=payload.model_dump(), context=opa_pre_prompt_input["policy_context"], request_ip="none", headers={}, mode="input") + if opa_pre_prompt_input["policy_apply"]: + opa_input = BaseOPAInputKeys(kind=hook_type, user="none", payload=payload.model_dump(), context=opa_pre_prompt_input["policy_context"], request_ip="none", headers={}, mode="input") decision, decision_context = self._evaluate_opa_policy( url=opa_pre_prompt_input["opa_server_url"], input=OPAInput(input=opa_input), policy_input_data_map=opa_pre_prompt_input["policy_input_data_map"] ) if not decision: violation = PluginViolation( - reason=OPAResponseTemplates.OPA_REASON.format(hook_type=hook_type), - description=OPAResponseTemplates.OPA_DESC.format(hook_type=hook_type), - code=OPACodes.DENIAL_CODE, + reason=OPAPluginResponseTemplates.OPA_REASON.format(hook_type=hook_type), + description=OPAPluginResponseTemplates.OPA_DESC.format(hook_type=hook_type), + code=OPAPluginCodes.DENIAL_CODE, details=decision_context, ) return PromptPrehookResult(modified_payload=payload, violation=violation, continue_processing=False) @@ -294,7 +345,7 @@ async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: Plugi The result of the plugin's analysis, including whether prompt result could proceed further. """ - hook_type = "prompt_post_fetch" + hook_type = PromptHookType.PROMPT_POST_FETCH.value logger.info(f"Processing {hook_type} for '{payload.result}'") logger.info(f"Processing context {context}") @@ -304,14 +355,13 @@ async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: Plugi policy_apply_config = self._config.applied_to if policy_apply_config and policy_apply_config.prompts: opa_post_prompt_input = self._preprocess_opa(policy_apply_config, payload, context, hook_type) - policy_modality = opa_post_prompt_input.get("policy_modality") if opa_post_prompt_input else None - if opa_post_prompt_input and policy_modality: - result = dict.fromkeys(policy_modality, []) + if opa_post_prompt_input["policy_apply"]: + result = dict.fromkeys(opa_post_prompt_input["policy_modality"], []) if hasattr(payload.result, "messages") and isinstance(payload.result.messages, list): for message in payload.result.messages: if hasattr(message, "content"): - for key in policy_modality: + for key in opa_post_prompt_input["policy_modality"]: self._extract_payload_key(message.content, key, result) opa_input = BaseOPAInputKeys(kind=hook_type, user="none", payload=result, context=opa_post_prompt_input["policy_context"], request_ip="none", headers={}, mode="output") @@ -320,9 +370,9 @@ async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: Plugi ) if not decision: violation = PluginViolation( - reason=OPAResponseTemplates.OPA_REASON.format(hook_type=hook_type), - description=OPAResponseTemplates.OPA_DESC.format(hook_type=hook_type), - code=OPACodes.DENIAL_CODE, + reason=OPAPluginResponseTemplates.OPA_REASON.format(hook_type=hook_type), + description=OPAPluginResponseTemplates.OPA_DESC.format(hook_type=hook_type), + code=OPAPluginCodes.DENIAL_CODE, details=decision_context, ) return PromptPosthookResult(modified_payload=payload, violation=violation, continue_processing=False) @@ -340,7 +390,7 @@ async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginCo The result of the plugin's analysis, including whether the tool can proceed. """ - hook_type = "tool_pre_invoke" + hook_type = ToolHookType.TOOL_PRE_INVOKE.value logger.info(f"Processing {hook_type} for '{payload.name}' with {len(payload.args) if payload.args else 0} arguments") logger.info(f"Processing context {context}") @@ -348,18 +398,19 @@ async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginCo return ToolPreInvokeResult() policy_apply_config = self._config.applied_to + logger.info(f"policy_apply_config {policy_apply_config}") if policy_apply_config and policy_apply_config.tools: opa_pre_tool_input = self._preprocess_opa(policy_apply_config, payload, context, hook_type) - if opa_pre_tool_input: + if opa_pre_tool_input["policy_apply"]: opa_input = BaseOPAInputKeys(kind=hook_type, user="none", payload=payload.model_dump(), context=opa_pre_tool_input["policy_context"], request_ip="none", headers={}, mode="input") decision, decision_context = self._evaluate_opa_policy( url=opa_pre_tool_input["opa_server_url"], input=OPAInput(input=opa_input), policy_input_data_map=opa_pre_tool_input["policy_input_data_map"] ) if not decision: violation = PluginViolation( - reason=OPAResponseTemplates.OPA_REASON.format(hook_type=hook_type), - description=OPAResponseTemplates.OPA_DESC.format(hook_type=hook_type), - code=OPACodes.DENIAL_CODE, + reason=OPAPluginResponseTemplates.OPA_REASON.format(hook_type=hook_type), + description=OPAPluginResponseTemplates.OPA_DESC.format(hook_type=hook_type), + code=OPAPluginCodes.DENIAL_CODE, details=decision_context, ) return ToolPreInvokeResult(modified_payload=payload, violation=violation, continue_processing=False) @@ -377,22 +428,22 @@ async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: Plugin The result of the plugin's analysis, including whether the tool result should proceed. """ - hook_type = "tool_post_invoke" + hook_type = ToolHookType.TOOL_POST_INVOKE.value logger.info(f"Processing {hook_type} for '{payload.result}' with {len(payload.result) if payload.result else 0}") logger.info(f"Processing context {context}") if not payload.result: return ToolPostInvokeResult() policy_apply_config = self._config.applied_to + if policy_apply_config and policy_apply_config.tools: opa_post_tool_input = self._preprocess_opa(policy_apply_config, payload, context, hook_type) - policy_modality = opa_post_tool_input.get("policy_modality") if opa_post_tool_input else None - if opa_post_tool_input and policy_modality: - result = dict.fromkeys(policy_modality, []) + if opa_post_tool_input["policy_apply"]: + result = dict.fromkeys(opa_post_tool_input["policy_modality"], []) if isinstance(payload.result, dict): content = payload.result["content"] if "content" in payload.result else payload.result - for key in policy_modality: + for key in opa_post_tool_input["policy_modality"]: self._extract_payload_key(content, key, result) opa_input = BaseOPAInputKeys(kind=hook_type, user="none", payload=result, context=opa_post_tool_input["policy_context"], request_ip="none", headers={}, mode="output") @@ -401,9 +452,9 @@ async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: Plugin ) if not decision: violation = PluginViolation( - reason=OPAResponseTemplates.OPA_REASON.format(hook_type=hook_type), - description=OPAResponseTemplates.OPA_DESC.format(hook_type=hook_type), - code=OPACodes.DENIAL_CODE, + reason=OPAPluginResponseTemplates.OPA_REASON.format(hook_type=hook_type), + description=OPAPluginResponseTemplates.OPA_DESC.format(hook_type=hook_type), + code=OPAPluginCodes.DENIAL_CODE, details=decision_context, ) return ToolPostInvokeResult(modified_payload=payload, violation=violation, continue_processing=False) @@ -424,7 +475,7 @@ async def resource_pre_fetch(self, payload: ResourcePreFetchPayload, context: Pl if not payload.uri: return ResourcePreFetchResult() - hook_type = "resource_pre_fetch" + hook_type = ResourceHookType.RESOURCE_PRE_FETCH.value logger.info(f"Processing {hook_type} for '{payload.uri}'") logger.info(f"Processing context {context}") @@ -442,16 +493,16 @@ async def resource_pre_fetch(self, payload: ResourcePreFetchPayload, context: Pl policy_apply_config = self._config.applied_to if policy_apply_config and policy_apply_config.resources: opa_pre_resource_input = self._preprocess_opa(policy_apply_config, payload, context, hook_type) - if not all(v is None for v in opa_pre_resource_input.values()): + if opa_pre_resource_input["policy_apply"]: opa_input = BaseOPAInputKeys(kind=hook_type, user="none", payload=payload.model_dump(), context=opa_pre_resource_input["policy_context"], request_ip="none", headers={}, mode="input") decision, decision_context = self._evaluate_opa_policy( url=opa_pre_resource_input["opa_server_url"], input=OPAInput(input=opa_input), policy_input_data_map=opa_pre_resource_input["policy_input_data_map"] ) if not decision: violation = PluginViolation( - reason=OPAResponseTemplates.OPA_REASON.format(hook_type=hook_type), - description=OPAResponseTemplates.OPA_DESC.format(hook_type=hook_type), - code=OPACodes.DENIAL_CODE, + reason=OPAPluginResponseTemplates.OPA_REASON.format(hook_type=hook_type), + description=OPAPluginResponseTemplates.OPA_DESC.format(hook_type=hook_type), + code=OPAPluginCodes.DENIAL_CODE, details=decision_context, ) return ResourcePreFetchResult(modified_payload=payload, violation=violation, continue_processing=False) @@ -472,17 +523,16 @@ async def resource_post_fetch(self, payload: ResourcePostFetchPayload, context: if not payload.content or not payload.uri: return ResourcePostFetchResult() - hook_type = "resource_post_fetch" + hook_type = ResourceHookType.RESOURCE_POST_FETCH.value logger.info(f"Processing {hook_type} for '{payload.content}' and uri {payload.uri}") logger.info(f"Processing context {context}") policy_apply_config = self._config.applied_to if policy_apply_config and policy_apply_config.resources: opa_post_resource_input = self._preprocess_opa(policy_apply_config, payload, context, hook_type) - policy_modality = opa_post_resource_input.get("policy_modality") if opa_post_resource_input else None - if not all(v is None for v in opa_post_resource_input.values()) and policy_modality: - result = dict.fromkeys(policy_modality, []) - for key in policy_modality: + if opa_post_resource_input["policy_apply"]: + result = dict.fromkeys(opa_post_resource_input["policy_modality"], []) + for key in opa_post_resource_input["policy_modality"]: if hasattr(payload.content, key): self._extract_payload_key(payload.content, key, result) @@ -492,9 +542,9 @@ async def resource_post_fetch(self, payload: ResourcePostFetchPayload, context: ) if not decision: violation = PluginViolation( - reason=OPAResponseTemplates.OPA_REASON.format(hook_type=hook_type), - description=OPAResponseTemplates.OPA_DESC.format(hook_type=hook_type), - code=OPACodes.DENIAL_CODE, + reason=OPAPluginResponseTemplates.OPA_REASON.format(hook_type=hook_type), + description=OPAPluginResponseTemplates.OPA_DESC.format(hook_type=hook_type), + code=OPAPluginCodes.DENIAL_CODE, details=decision_context, ) return ResourcePostFetchResult(modified_payload=payload, violation=violation, continue_processing=False) diff --git a/plugins/external/opa/tests/test_all.py b/plugins/external/opa/tests/test_all.py index 71cbca5c8..912ee02bd 100644 --- a/plugins/external/opa/tests/test_all.py +++ b/plugins/external/opa/tests/test_all.py @@ -8,11 +8,13 @@ import pytest # First-Party -from mcpgateway.common.models import Message, ResourceContent, Role, TextContent +from mcpgateway.common.models import Message, ResourceContent, Role, TextContent, PromptResult from mcpgateway.plugins.framework import ( GlobalContext, PluginManager, - PluginResult, + ToolHookType, + PromptHookType, + ResourceHookType, PromptPosthookPayload, PromptPrehookPayload, ResourcePostFetchPayload, @@ -43,9 +45,9 @@ async def test_prompt_pre_hook(plugin_manager: PluginManager): plugin_manager: The plugin manager instance. """ # Customize payload for testing - payload = PromptPrehookPayload(name="test_prompt", args={"arg0": "This is an argument"}) + payload = PromptPrehookPayload(prompt_id="test_prompt", args={"arg0": "This is an argument"}) global_context = GlobalContext(request_id="1") - result, _ = await plugin_manager.prompt_pre_fetch(payload, global_context) + result, _ = await plugin_manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context) # Assert expected behaviors assert result.continue_processing @@ -59,10 +61,10 @@ async def test_prompt_post_hook(plugin_manager: PluginManager): """ # Customize payload for testing message = Message(content=TextContent(type="text", text="prompt"), role=Role.USER) - prompt_result = PluginResult(messages=[message]) - payload = PromptPosthookPayload(name="test_prompt", result=prompt_result) + prompt_result = PromptResult(messages=[message]) + payload = PromptPosthookPayload(prompt_id="test_prompt", result=prompt_result) global_context = GlobalContext(request_id="1") - result, _ = await plugin_manager.prompt_post_fetch(payload, global_context) + result, _ = await plugin_manager.invoke_hook(PromptHookType.PROMPT_POST_FETCH, payload, global_context) # Assert expected behaviors assert result.continue_processing @@ -77,7 +79,7 @@ async def test_tool_pre_hook(plugin_manager: PluginManager): # Customize payload for testing payload = ToolPreInvokePayload(name="test_prompt", args={"arg0": "This is an argument"}) global_context = GlobalContext(request_id="1") - result, _ = await plugin_manager.tool_pre_invoke(payload, global_context) + result, _ = await plugin_manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, payload, global_context) # Assert expected behaviors assert result.continue_processing @@ -92,7 +94,7 @@ async def test_tool_post_hook(plugin_manager: PluginManager): # Customize payload for testing payload = ToolPostInvokePayload(name="test_tool", result={"output0": "output value"}) global_context = GlobalContext(request_id="1") - result, _ = await plugin_manager.tool_post_invoke(payload, global_context) + result, _ = await plugin_manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, payload, global_context) # Assert expected behaviors assert result.continue_processing @@ -107,7 +109,7 @@ async def test_resource_pre_hook(plugin_manager: PluginManager): # Customize payload for testing payload = ResourcePreFetchPayload(uri="https://test_resource.com", metadata={}) global_context = GlobalContext(request_id="1", server_id="2") - result, _ = await plugin_manager.tool_post_invoke(payload, global_context) + result, _ = await plugin_manager.invoke_hook(ResourceHookType.RESOURCE_PRE_FETCH, payload, global_context) # Assert expected behaviors assert result.continue_processing @@ -124,9 +126,10 @@ async def test_resource_post_hook(plugin_manager: PluginManager): type="resource", uri="test://resource", text="test://test_resource.com", + id="1" ) payload = ResourcePostFetchPayload(uri="https://example.com", content=content) global_context = GlobalContext(request_id="1", server_id="2") - result, _ = await plugin_manager.resource_post_fetch(payload, global_context) + result, _ = await plugin_manager.invoke_hook(ResourceHookType.RESOURCE_POST_FETCH, payload, global_context) # Assert expected behaviors assert result.continue_processing diff --git a/plugins/external/opa/tests/test_errors.py b/plugins/external/opa/tests/test_errors.py new file mode 100644 index 000000000..de1cb6c84 --- /dev/null +++ b/plugins/external/opa/tests/test_errors.py @@ -0,0 +1,205 @@ +# -*- coding: utf-8 -*- +"""Test cases for OPA plugin + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Shriti Priya + +This module contains test cases for running opa plugin. Here, the OPA server is scoped under session fixture, +and started once, and further used by all test cases for policy evaluations. +""" + +# Standard + +# Third-Party +import pytest + +# First-Party +from mcpgateway.plugins.framework import ( + GlobalContext, + PluginConfig, + PluginContext, + ToolPostInvokePayload, + ToolPreInvokePayload, + PluginError +) + +from mcpgateway.services.logging_service import LoggingService +from opapluginfilter.plugin import OPAPluginFilter, OPAPluginErrorCodes + +logging_service = LoggingService() +logger = logging_service.get_logger(__name__) + + +@pytest.mark.asyncio +# Check for OPA Server returning none response +async def test_error_opa_server_error(): + """Test that validates opa plugin applied on pre tool invocation is working successfully. Evaluates for both malign and benign cases""" + config = { + "tools": [ + { + "tool_name": "fast-time-git-status", + "extensions": { + "policy": "example", + "policy_endpoints": [ + "allow_tool_pre_invoke", + ], + "policy_modality": ["text"], + }, + } + ] + } + + incorrect_opa_url = "http://127.0.0.1:3000/v1/data/" + config = PluginConfig(name="test", kind="opapluginfilter.OPAPluginFilter", hooks=["tool_pre_invoke"], config={"opa_base_url": incorrect_opa_url}, applied_to=config) + plugin = OPAPluginFilter(config) + payload = ToolPreInvokePayload(name="fast-time-git-status", args={"repo_path": "/path/IBM"}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) + try: + await plugin.tool_pre_invoke(payload, context) + except PluginError as e: + assert e.error.message == OPAPluginErrorCodes.OPA_SERVER_ERROR.value + + +@pytest.mark.asyncio +# Test for when opaplugin is configured with invalid endpoint +async def test_error_opa_server_invalid_endpoint(): + """Test that validates opa plugin applied on pre tool invocation is working successfully. Evaluates for both malign and benign cases""" + config = { + "tools": [ + { + "tool_name": "fast-time-git-status", + "extensions": { + "policy": "example", + "policy_endpoints": [ + "allow_x_invoke", + ], + "policy_modality": ["text"], + }, + } + ] + } + config = PluginConfig(name="test", kind="opapluginfilter.OPAPluginFilter", hooks=["tool_pre_invoke"], config={"opa_base_url": "http://127.0.0.1:8181/v1/data/"}, applied_to=config) + plugin = OPAPluginFilter(config) + + # Benign payload (allowed by OPA (rego) policy) + payload = ToolPreInvokePayload(name="fast-time-git-status", args={"repo_path": "/path/IBM"}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) + try: + await plugin.tool_pre_invoke(payload, context) + except PluginError as e: + assert e.error.message == OPAPluginErrorCodes.INVALID_POLICY_ENDPOINT.value + + +@pytest.mark.asyncio +# Test for when opaplugin opa server sends none response +async def test_error_opa_server_none_response(): + """Test that validates opa plugin applied on pre tool invocation is working successfully. Evaluates for both malign and benign cases""" + config = { + "tools": [ + { + "tool_name": "fast-time-git-status", + "extensions": { + "policy": "example1", + "policy_endpoints": [ + "allow_tool_pre_invoke", + ], + "policy_modality": ["text"], + }, + } + ] + } + config = PluginConfig(name="test", kind="opapluginfilter.OPAPluginFilter", hooks=["tool_pre_invoke"], config={"opa_base_url": "http://127.0.0.1:8181/v1/data/"}, applied_to=config) + plugin = OPAPluginFilter(config) + + # Benign payload (allowed by OPA (rego) policy) + payload = ToolPreInvokePayload(name="fast-time-git-status", args={"repo_path": "/path/IBM"}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) + try: + await plugin.tool_pre_invoke(payload, context) + except PluginError as e: + assert e.error.message == OPAPluginErrorCodes.OPA_SERVER_NONE_RESPONSE.value + + +@pytest.mark.asyncio +# Test for when opaplugin is configured with no policy endpoint +async def test_error_opa_server_unconfigured_endpoint(): + """Test that validates opa plugin applied on pre tool invocation is working successfully. Evaluates for both malign and benign cases""" + config = { + "tools": [ + { + "tool_name": "fast-time-git-status", + "extensions": { + "policy": "example", + "policy_modality": ["text"], + }, + } + ] + } + config = PluginConfig(name="test", kind="opapluginfilter.OPAPluginFilter", hooks=["tool_pre_invoke"], config={"opa_base_url": "http://127.0.0.1:8181/v1/data/"}, applied_to=config) + plugin = OPAPluginFilter(config) + + # Benign payload (allowed by OPA (rego) policy) + payload = ToolPreInvokePayload(name="fast-time-git-status", args={"repo_path": "/path/IBM"}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) + try: + await plugin.tool_pre_invoke(payload, context) + except PluginError as e: + assert e.error.message == OPAPluginErrorCodes.OPA_SERVER_UNCONFIGURED_ENDPOINT.value + + +@pytest.mark.asyncio +# Test for when opaplugin if not supported policy modality location has been used in configuration +async def test_error_opa_server_unsupported_modality(): + """Test that validates opa plugin applied on pre tool invocation is working successfully. Evaluates for both malign and benign cases""" + config = { + "tools": [ + { + "tool_name": "fast-time-git-status", + "extensions": { + "policy": "example", + "policy_endpoints": [ + "allow_tool_post_invoke", + ], + "policy_modality": ["location"], + }, + } + ] + } + config = PluginConfig(name="test", kind="opapluginfilter.OPAPluginFilter", hooks=["tool_pre_invoke"], config={"opa_base_url": "http://127.0.0.1:8181/v1/data/"}, applied_to=config) + plugin = OPAPluginFilter(config) + + # Benign payload (allowed by OPA (rego) policy) + payload = ToolPostInvokePayload(name="fast-time-git-status", result={"text": "IBM@example.com"}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) + try: + await plugin.tool_post_invoke(payload, context) + except PluginError as e: + assert e.error.message == OPAPluginErrorCodes.UNSUPPORTED_POLICY_MODALITY.value + + +@pytest.mark.asyncio +# Test for when opaplugin has not been configured with a policy modality. The expected behavior is to pick up default policy modality as text +async def test_error_opa_server_unspecified_policy_modality(): + """Test that validates opa plugin applied on pre tool invocation is working successfully. Evaluates for both malign and benign cases""" + config = { + "tools": [ + { + "tool_name": "fast-time-git-status", + "extensions": { + "policy": "example", + "policy_endpoints": [ + "allow_tool_post_invoke", + ], + }, + } + ] + } + config = PluginConfig(name="test", kind="opapluginfilter.OPAPluginFilter", hooks=["tool_pre_invoke"], config={"opa_base_url": "http://127.0.0.1:8181/v1/data/"}, applied_to=config) + plugin = OPAPluginFilter(config) + + # Benign payload (allowed by OPA (rego) policy) + payload = ToolPostInvokePayload(name="fast-time-git-status", result={"text": "IBM@example.com"}) + context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) + result = await plugin.tool_post_invoke(payload, context) + assert not result.continue_processing diff --git a/plugins/external/opa/tests/test_opapluginfilter.py b/plugins/external/opa/tests/test_opapluginfilter.py index 075d9e54f..309c28732 100644 --- a/plugins/external/opa/tests/test_opapluginfilter.py +++ b/plugins/external/opa/tests/test_opapluginfilter.py @@ -12,16 +12,14 @@ # Standard # Third-Party -from opapluginfilter.plugin import OPAPluginFilter import pytest # First-Party -from mcpgateway.common.models import Message, ResourceContent, Role, TextContent +from mcpgateway.common.models import Message, ResourceContent, Role, TextContent, PromptResult from mcpgateway.plugins.framework import ( GlobalContext, PluginConfig, PluginContext, - PluginResult, PromptPosthookPayload, PromptPrehookPayload, ResourcePostFetchPayload, @@ -29,7 +27,9 @@ ToolPostInvokePayload, ToolPreInvokePayload, ) + from mcpgateway.services.logging_service import LoggingService +from opapluginfilter.plugin import OPAPluginFilter logging_service = LoggingService() logger = logging_service.get_logger(__name__) @@ -125,13 +125,13 @@ async def test_pre_prompt_fetch_opapluginfilter(): plugin = OPAPluginFilter(config) # Benign payload (allowed by OPA (rego) policy) - payload = PromptPrehookPayload(name="test_prompt", args={"text": "You are curseword"}) + payload = PromptPrehookPayload(prompt_id="test_prompt", args={"text": "You are curseword"}) context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) result = await plugin.prompt_pre_fetch(payload, context) assert result.continue_processing # Malign payload (denied by OPA (rego) policy) - payload = PromptPrehookPayload(name="test_prompt", args={"text": "You are curseword1"}) + payload = PromptPrehookPayload(prompt_id="test_prompt", args={"text": "You are curseword1"}) context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) result = await plugin.prompt_pre_fetch(payload, context) assert not result.continue_processing @@ -160,16 +160,16 @@ async def test_post_prompt_fetch_opapluginfilter(): # Benign payload (allowed by OPA (rego) policy) message = Message(content=TextContent(type="text", text="abc"), role=Role.USER) - prompt_result = PluginResult(messages=[message]) - payload = PromptPosthookPayload(name="test_prompt", result=prompt_result) + prompt_result = PromptResult(messages=[message]) + payload = PromptPosthookPayload(prompt_id="test_prompt", result=prompt_result) context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) result = await plugin.prompt_post_fetch(payload, context) assert result.continue_processing # Malign payload (denied by OPA (rego) policy) message = Message(content=TextContent(type="text", text="abc@example.com"), role=Role.USER) - prompt_result = PluginResult(messages=[message]) - payload = PromptPosthookPayload(name="test_prompt", result=prompt_result) + prompt_result = PromptResult(messages=[message]) + payload = PromptPosthookPayload(prompt_id="test_prompt", result=prompt_result) context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) result = await plugin.prompt_post_fetch(payload, context) assert not result.continue_processing @@ -235,6 +235,7 @@ async def test_post_resource_fetch_opapluginfilter(): type="resource", uri="test://abc", text="abc", + id="1" ) payload = ResourcePostFetchPayload(uri="https://example.com/docs", content=content) context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) @@ -246,6 +247,7 @@ async def test_post_resource_fetch_opapluginfilter(): type="resource", uri="test://large", text="test://abc@example.com", + id="1" ) payload = ResourcePostFetchPayload(uri="https://example.com", content=content) context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2"))