Skip to content

Commit c5007db

Browse files
agam1092005samriddhi99
authored andcommitted
Remove required_ from core & added tests to ensure both are working
1 parent b9e6ad8 commit c5007db

2 files changed

Lines changed: 59 additions & 5 deletions

File tree

src/tirith/core/core.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,14 +225,14 @@ def start_policy_evaluation_from_dict(policy_dict: Dict, input_dict: Dict) -> Di
225225
policy_meta = policy_dict.get("meta")
226226
eval_objects = policy_dict.get("evaluators")
227227
final_evaluation_policy_string = policy_dict.get("eval_expression")
228-
provider_module = policy_meta.get("required_provider", "core")
229-
# TODO: Write functionality for dynamically importing evaluators from other modules.
228+
229+
provider_module = policy_meta.get("provider", policy_meta.get("required_provider", "core"))
230+
230231
eval_results = []
231232
eval_results_obj = {}
232233
for eval_obj in eval_objects:
233234
eval_id = eval_obj.get("id")
234235
eval_description = eval_obj.get("description")
235-
logger.debug(f"Processing evaluator '{eval_id}'")
236236
eval_result = generate_evaluator_result(eval_obj, input_dict, provider_module)
237237
eval_result["id"] = eval_id
238238
eval_result["description"] = eval_description
@@ -241,10 +241,11 @@ def start_policy_evaluation_from_dict(policy_dict: Dict, input_dict: Dict) -> Di
241241
final_evaluation_result, errors = final_evaluator(final_evaluation_policy_string, eval_results_obj)
242242

243243
final_output = {
244-
"meta": {"version": policy_meta.get("version"), "required_provider": provider_module},
244+
"meta": {"version": policy_meta.get("version"), "provider": provider_module},
245245
"final_result": final_evaluation_result,
246246
"evaluators": eval_results,
247247
"errors": errors,
248248
"eval_expression": final_evaluation_policy_string,
249249
}
250250
return final_output
251+

tests/core/test_core.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pytest import mark
33

44
from tirith.core.core import final_evaluator
5-
5+
from tirith.core.core import start_policy_evaluation_from_dict
66

77
@mark.passing
88
def test_final_evaluator_skipped_check_should_be_removed():
@@ -38,3 +38,56 @@ def test_final_evaluator_malicious_eval_should_err():
3838
"!skipped_check && passing_check || [].__class__.__base__", dict(skipped_check=None, passing_check=True)
3939
)
4040
assert actual_result == (False, ["The following symbols are not allowed: __class__, __base__"])
41+
42+
43+
@mark.passing
44+
def test_start_policy_evaluation_with_required_provider():
45+
policy_dict = {
46+
"meta": {"version": "1.0", "required_provider": "legacy_provider"},
47+
"evaluators": [],
48+
"eval_expression": "True",
49+
}
50+
input_dict = {}
51+
52+
result = start_policy_evaluation_from_dict(policy_dict, input_dict)
53+
54+
assert result["meta"]["provider"] == "legacy_provider"
55+
56+
@mark.passing
57+
def test_start_policy_evaluation_with_provider():
58+
policy_dict = {
59+
"meta": {"version": "1.0", "provider": "new_provider"},
60+
"evaluators": [],
61+
"eval_expression": "True",
62+
}
63+
input_dict = {}
64+
65+
result = start_policy_evaluation_from_dict(policy_dict, input_dict)
66+
67+
assert result["meta"]["provider"] == "new_provider"
68+
69+
@mark.passing
70+
def test_start_policy_evaluation_with_both_providers():
71+
policy_dict = {
72+
"meta": {"version": "1.0", "provider": "new_provider", "required_provider": "legacy_provider"},
73+
"evaluators": [],
74+
"eval_expression": "True",
75+
}
76+
input_dict = {}
77+
78+
result = start_policy_evaluation_from_dict(policy_dict, input_dict)
79+
80+
assert result["meta"]["provider"] == "new_provider"
81+
82+
@mark.passing
83+
def test_start_policy_evaluation_with_neither_provider():
84+
policy_dict = {
85+
"meta": {"version": "1.0"},
86+
"evaluators": [],
87+
"eval_expression": "True",
88+
}
89+
input_dict = {}
90+
91+
result = start_policy_evaluation_from_dict(policy_dict, input_dict)
92+
93+
assert result["meta"]["provider"] == "core"

0 commit comments

Comments
 (0)