diff --git a/test/common/test_config_utils.py b/test/common/test_config_utils.py new file mode 100644 index 0000000000..16f09b60ba --- /dev/null +++ b/test/common/test_config_utils.py @@ -0,0 +1,610 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import json +from pathlib import Path + +import pytest +from pydantic import Field + +from olive.common.config_utils import ( + ConfigBase, + ConfigDictBase, + ConfigListBase, + ConfigParam, + NestedConfig, + ParamCategory, + config_json_dumps, + config_json_loads, + convert_configs_to_dicts, + create_config_class, + load_config_file, + serialize_function, + serialize_object, + serialize_to_json, + validate_config, + validate_enum, + validate_lowercase, +) + + +class TestSerializeFunction: + def test_serialize_function_returns_dict(self): + # setup + def my_func(x): + return x + + # execute + result = serialize_function(my_func) + + # assert + assert result["olive_parameter_type"] == "Function" + assert result["name"] == "my_func" + assert "signature" in result + assert "sourcecode_hash" in result + + +class TestSerializeObject: + def test_serialize_object_returns_dict(self): + # setup + obj = {"key": "value"} + + # execute + result = serialize_object(obj) + + # assert + assert result["olive_parameter_type"] == "Object" + assert result["type"] == "dict" + assert "hash" in result + + +class TestConfigJsonDumps: + def test_basic_dict(self): + # setup + data = {"key": "value", "num": 42} + + # execute + result = config_json_dumps(data) + parsed = json.loads(result) + + # assert + assert parsed == data + + def test_path_serialization_absolute(self): + # setup + data = {"path": Path("/some/path")} + + # execute + result = config_json_dumps(data, make_absolute=True) + parsed = json.loads(result) + + # assert + assert isinstance(parsed["path"], str) + + def test_path_serialization_relative(self): + # setup + data = {"path": Path("relative/path")} + + # execute + result = config_json_dumps(data, make_absolute=False) + parsed = json.loads(result) + + # assert + assert parsed["path"] == "relative/path" + + def test_function_serialization(self): + # setup + def sample_func(): + pass + + data = {"func": sample_func} + + # execute + result = config_json_dumps(data) + parsed = json.loads(result) + + # assert + assert parsed["func"]["olive_parameter_type"] == "Function" + + +class TestConfigJsonLoads: + def test_basic_json(self): + # setup + data = '{"key": "value"}' + + # execute + result = config_json_loads(data) + + # assert + assert result == {"key": "value"} + + def test_function_object_raises_error(self): + # setup + data = json.dumps({"olive_parameter_type": "Function", "name": "my_func"}) + + # execute & assert + with pytest.raises(ValueError, match="Cannot load"): + config_json_loads(data) + + def test_custom_object_hook(self): + # setup + data = '{"key": "value"}' + + # execute + result = config_json_loads(data, object_hook=lambda obj: obj) + + # assert + assert result == {"key": "value"} + + +class TestSerializeToJson: + def test_dict_input(self): + # setup + data = {"key": "value"} + + # execute + result = serialize_to_json(data) + + # assert + assert result == data + + def test_config_base_input(self): + # setup + config = ConfigBase() + + # execute + result = serialize_to_json(config) + + # assert + assert isinstance(result, dict) + + def 
test_check_object_with_function_raises(self): + # setup + def my_func(): + pass + + # execute & assert + with pytest.raises(ValueError, match="Cannot serialize"): + serialize_to_json({"func": my_func}, check_object=True) + + +class TestLoadConfigFile: + def test_load_json_file(self, tmp_path): + # setup + config_file = tmp_path / "config.json" + config_file.write_text('{"key": "value"}') + + # execute + result = load_config_file(config_file) + + # assert + assert result == {"key": "value"} + + def test_load_yaml_file(self, tmp_path): + # setup + config_file = tmp_path / "config.yaml" + config_file.write_text("key: value\n") + + # execute + result = load_config_file(config_file) + + # assert + assert result == {"key": "value"} + + def test_load_yml_file(self, tmp_path): + # setup + config_file = tmp_path / "config.yml" + config_file.write_text("key: value\n") + + # execute + result = load_config_file(config_file) + + # assert + assert result == {"key": "value"} + + def test_unsupported_file_type_raises(self, tmp_path): + # setup + config_file = tmp_path / "config.txt" + config_file.write_text("key=value") + + # execute & assert + with pytest.raises(ValueError, match="Unsupported file type"): + load_config_file(config_file) + + +class TestConfigBase: + def test_to_json(self): + # setup + config = ConfigBase() + + # execute + result = config.to_json() + + # assert + assert isinstance(result, dict) + + def test_from_json(self): + # execute + config = ConfigBase.from_json({}) + + # assert + assert isinstance(config, ConfigBase) + + def test_parse_file_or_obj_dict(self): + # execute + config = ConfigBase.parse_file_or_obj({}) + + # assert + assert isinstance(config, ConfigBase) + + def test_parse_file_or_obj_json_file(self, tmp_path): + # setup + config_file = tmp_path / "config.json" + config_file.write_text("{}") + + # execute + config = ConfigBase.parse_file_or_obj(config_file) + + # assert + assert isinstance(config, ConfigBase) + + +class TestConfigListBase: + def test_iter(self): + # setup + config = ConfigListBase.model_validate([1, 2, 3]) + + # execute + result = list(config) + + # assert + assert result == [1, 2, 3] + + def test_getitem(self): + # setup + config = ConfigListBase.model_validate([10, 20, 30]) + + # execute + first = config[0] + last = config[2] + + # assert + assert first == 10 + assert last == 30 + + def test_len(self): + # setup + config = ConfigListBase.model_validate([1, 2, 3]) + + # execute + result = len(config) + + # assert + assert result == 3 + + def test_to_json(self): + # setup + config = ConfigListBase.model_validate([1, 2, 3]) + + # execute + result = config.to_json() + + # assert + assert isinstance(result, list) + + +class TestConfigDictBase: + def test_iter(self): + # setup + config = ConfigDictBase.model_validate({"a": 1, "b": 2}) + + # execute + result = set(config) + + # assert + assert result == {"a", "b"} + + def test_keys(self): + # setup + config = ConfigDictBase.model_validate({"a": 1, "b": 2}) + + # execute + result = set(config.keys()) + + # assert + assert result == {"a", "b"} + + def test_values(self): + # setup + config = ConfigDictBase.model_validate({"a": 1, "b": 2}) + + # execute + result = set(config.values()) + + # assert + assert result == {1, 2} + + def test_items(self): + # setup + config = ConfigDictBase.model_validate({"a": 1, "b": 2}) + + # execute + result = dict(config.items()) + + # assert + assert result == {"a": 1, "b": 2} + + def test_getitem(self): + # setup + config = ConfigDictBase.model_validate({"key": "value"}) + + # 
execute + result = config["key"] + + # assert + assert result == "value" + + def test_len(self): + # setup + config = ConfigDictBase.model_validate({"a": 1, "b": 2}) + + # execute + result = len(config) + + # assert + assert result == 2 + + def test_len_empty(self): + # setup + config = ConfigDictBase.model_validate({}) + + # execute + result = len(config) + + # assert + assert result == 0 + + +class TestNestedConfig: + def test_gather_nested_field_basic(self): + # setup + class MyConfig(NestedConfig): + type: str + config: dict = Field(default_factory=dict) + + # execute + c = MyConfig(type="test", key1="val1") + + # assert + assert c.config == {"key1": "val1"} + + def test_gather_nested_field_none_values(self): + # setup + class MyConfig(NestedConfig): + type: str = "default" + config: dict = Field(default_factory=dict) + + # execute + c = MyConfig.model_validate(None) + + # assert + assert c.type == "default" + + def test_gather_nested_field_explicit_config(self): + # setup + class MyConfig(NestedConfig): + type: str + config: dict = Field(default_factory=dict) + + # execute + c = MyConfig(type="test", config={"inner_key": "inner_val"}) + + # assert + assert c.config == {"inner_key": "inner_val"} + + +class TestCaseInsensitiveEnum: + def test_case_insensitive_creation(self): + # execute + lower = ParamCategory("none") + upper = ParamCategory("NONE") + mixed = ParamCategory("None") + + # assert + assert lower == ParamCategory.NONE + assert upper == ParamCategory.NONE + assert mixed == ParamCategory.NONE + + def test_invalid_value_returns_none(self): + # execute + result = ParamCategory._missing_("nonexistent") + + # assert + assert result is None + + +class TestConfigParam: + def test_config_param_defaults(self): + # execute + param = ConfigParam(type_=str) + + # assert + assert param.required is False + assert param.default_value is None + assert param.category == ParamCategory.NONE + + def test_config_param_required(self): + # execute + param = ConfigParam(type_=str, required=True) + + # assert + assert param.required is True + + def test_config_param_repr(self): + # setup + param = ConfigParam(type_=str, required=True, description="A test param") + + # execute + repr_str = repr(param) + + # assert + assert "required=True" in repr_str + assert "description=" in repr_str + + +class TestValidateEnum: + def test_valid_enum_value(self): + # execute + result = validate_enum(ParamCategory, "none") + + # assert + assert result == ParamCategory.NONE + + def test_invalid_enum_value_raises(self): + # execute & assert + with pytest.raises(ValueError, match="Invalid value"): + validate_enum(ParamCategory, "invalid_value") + + +class TestValidateLowercase: + def test_string_lowercased(self): + # execute + result = validate_lowercase("HELLO") + + # assert + assert result == "hello" + + def test_non_string_unchanged(self): + # execute + result_int = validate_lowercase(42) + result_none = validate_lowercase(None) + + # assert + assert result_int == 42 + assert result_none is None + + +class TestCreateConfigClass: + def test_create_basic_config_class(self): + # setup + config = { + "name": ConfigParam(type_=str, required=True), + "value": ConfigParam(type_=int, default_value=10), + } + + # execute + cls = create_config_class("TestConfig", config) + instance = cls(name="test") + + # assert + assert instance.name == "test" + assert instance.value == 10 + + def test_create_config_class_with_optional_field(self): + # setup + config = { + "name": ConfigParam(type_=str, default_value=None), + } + + # execute 
+ cls = create_config_class("OptionalConfig", config) + instance = cls() + + # assert + assert instance.name is None + + +class TestValidateConfig: + def test_validate_dict_config(self): + # setup + config = {"name": "test"} + + class MyConfig(ConfigBase): + name: str + + # execute + result = validate_config(config, MyConfig) + + # assert + assert isinstance(result, MyConfig) + assert result.name == "test" + + def test_validate_none_config(self): + # setup + class MyConfig(ConfigBase): + name: str = "default" + + # execute + result = validate_config(None, MyConfig) + + # assert + assert result.name == "default" + + def test_validate_config_instance(self): + # setup + class MyConfig(ConfigBase): + name: str + + config = MyConfig(name="test") + + # execute + result = validate_config(config, MyConfig) + + # assert + assert result.name == "test" + + def test_validate_config_wrong_class_raises(self): + # setup + class MyConfig(ConfigBase): + name: str + + class OtherConfig(ConfigBase): + value: int + + config = OtherConfig(value=42) + + # execute & assert + with pytest.raises(ValueError, match="Invalid config class"): + validate_config(config, MyConfig) + + +class TestConvertConfigsToDicts: + def test_config_base_to_dict(self): + # setup + config = ConfigBase() + + # execute + result = convert_configs_to_dicts(config) + + # assert + assert isinstance(result, dict) + + def test_nested_dict_conversion(self): + # setup + data = {"key": "value"} + + # execute + result = convert_configs_to_dicts(data) + + # assert + assert result == {"key": "value"} + + def test_list_conversion(self): + # setup + data = ["a", "b"] + + # execute + result = convert_configs_to_dicts(data) + + # assert + assert result == ["a", "b"] + + def test_plain_value_passthrough(self): + # execute + result_int = convert_configs_to_dicts(42) + result_str = convert_configs_to_dicts("hello") + + # assert + assert result_int == 42 + assert result_str == "hello" diff --git a/test/data_container/test_registry.py b/test/data_container/test_registry.py new file mode 100644 index 0000000000..fa255593d1 --- /dev/null +++ b/test/data_container/test_registry.py @@ -0,0 +1,144 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
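+# Unit tests for the olive.data.registry Registry: component registration, lookup, and defaults.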
+# -------------------------------------------------------------------------- + +from olive.data.constants import ( + DataComponentType, + DefaultDataContainer, +) +from olive.data.registry import Registry + + +class TestRegistryRegister: + def test_register_dataset_component(self): + # setup & execute + @Registry.register(DataComponentType.LOAD_DATASET, name="test_dataset_reg") + def my_dataset(): + return "dataset" + + # assert + result = Registry.get_load_dataset_component("test_dataset_reg") + assert result is my_dataset + + def test_register_pre_process_component(self): + # setup & execute + @Registry.register_pre_process(name="test_pre_process_reg") + def my_pre_process(data): + return data + + # assert + result = Registry.get_pre_process_component("test_pre_process_reg") + assert result is my_pre_process + + def test_register_post_process_component(self): + # setup & execute + @Registry.register_post_process(name="test_post_process_reg") + def my_post_process(data): + return data + + # assert + result = Registry.get_post_process_component("test_post_process_reg") + assert result is my_post_process + + def test_register_dataloader_component(self): + # setup & execute + @Registry.register_dataloader(name="test_dataloader_reg") + def my_dataloader(data): + return data + + # assert + result = Registry.get_dataloader_component("test_dataloader_reg") + assert result is my_dataloader + + def test_register_case_insensitive(self): + # setup & execute + @Registry.register(DataComponentType.LOAD_DATASET, name="CaseSensitiveTest_Reg") + def my_func(): + pass + + # assert + result = Registry.get_load_dataset_component("casesensitivetest_reg") + assert result is my_func + + def test_register_uses_class_name_when_no_name(self): + # setup & execute + @Registry.register(DataComponentType.LOAD_DATASET) + def unique_named_test_func_reg(): + pass + + # assert + result = Registry.get_load_dataset_component("unique_named_test_func_reg") + assert result is unique_named_test_func_reg + + +class TestRegistryGet: + def test_get_component(self): + # setup + @Registry.register(DataComponentType.LOAD_DATASET, name="test_get_comp_reg") + def my_func(): + pass + + # execute + result = Registry.get_component(DataComponentType.LOAD_DATASET.value, "test_get_comp_reg") + + # assert + assert result is my_func + + def test_get_by_subtype(self): + # setup + @Registry.register(DataComponentType.LOAD_DATASET, name="test_get_subtype_reg") + def my_func(): + pass + + # execute + result = Registry.get(DataComponentType.LOAD_DATASET.value, "test_get_subtype_reg") + + # assert + assert result is my_func + + +class TestRegistryDefaultComponents: + def test_get_default_load_dataset(self): + # execute + result = Registry.get_default_load_dataset_component() + + # assert + assert result is not None + + def test_get_default_pre_process(self): + # execute + result = Registry.get_default_pre_process_component() + + # assert + assert result is not None + + def test_get_default_post_process(self): + # execute + result = Registry.get_default_post_process_component() + + # assert + assert result is not None + + def test_get_default_dataloader(self): + # execute + result = Registry.get_default_dataloader_component() + + # assert + assert result is not None + + +class TestRegistryContainer: + def test_get_container_default(self): + # execute + result = Registry.get_container(None) + + # assert + assert result is not None + + def test_get_container_by_name(self): + # execute + result = 
Registry.get_container(DefaultDataContainer.DATA_CONTAINER.value) + + # assert + assert result is not None diff --git a/test/evaluator/test_metric_config.py b/test/evaluator/test_metric_config.py new file mode 100644 index 0000000000..f88dbe3de4 --- /dev/null +++ b/test/evaluator/test_metric_config.py @@ -0,0 +1,216 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import pytest +from pydantic import ValidationError + +from olive.evaluator.metric_config import ( + LatencyMetricConfig, + MetricGoal, + SizeOnDiskMetricConfig, + ThroughputMetricConfig, + get_user_config_class, +) + + +class TestLatencyMetricConfig: + def test_defaults(self): + # execute + config = LatencyMetricConfig() + + # assert + assert config.warmup_num == 10 + assert config.repeat_test_num == 20 + assert config.sleep_num == 0 + + def test_custom_values(self): + # execute + config = LatencyMetricConfig(warmup_num=5, repeat_test_num=100, sleep_num=2) + + # assert + assert config.warmup_num == 5 + assert config.repeat_test_num == 100 + assert config.sleep_num == 2 + + +class TestThroughputMetricConfig: + def test_defaults(self): + # execute + config = ThroughputMetricConfig() + + # assert + assert config.warmup_num == 10 + assert config.repeat_test_num == 20 + assert config.sleep_num == 0 + + def test_custom_values(self): + # execute + config = ThroughputMetricConfig(warmup_num=3, repeat_test_num=50, sleep_num=1) + + # assert + assert config.warmup_num == 3 + assert config.repeat_test_num == 50 + assert config.sleep_num == 1 + + +class TestSizeOnDiskMetricConfig: + def test_creation(self): + # execute + config = SizeOnDiskMetricConfig() + + # assert + assert isinstance(config, SizeOnDiskMetricConfig) + + +class TestMetricGoal: + def test_threshold_type(self): + # execute + goal = MetricGoal(type="threshold", value=0.9) + + # assert + assert goal.type == "threshold" + assert goal.value == 0.9 + + def test_min_improvement_type(self): + # execute + goal = MetricGoal(type="min-improvement", value=0.05) + + # assert + assert goal.type == "min-improvement" + assert goal.value == 0.05 + + def test_max_degradation_type(self): + # execute + goal = MetricGoal(type="max-degradation", value=0.1) + + # assert + assert goal.type == "max-degradation" + assert goal.value == 0.1 + + def test_percent_min_improvement_type(self): + # execute + goal = MetricGoal(type="percent-min-improvement", value=5.0) + + # assert + assert goal.type == "percent-min-improvement" + + def test_percent_max_degradation_type(self): + # execute + goal = MetricGoal(type="percent-max-degradation", value=10.0) + + # assert + assert goal.type == "percent-max-degradation" + + def test_invalid_type_raises(self): + # execute & assert + with pytest.raises(ValidationError, match="Metric goal type must be one of"): + MetricGoal(type="invalid_type", value=0.5) + + def test_negative_value_for_min_improvement_raises(self): + # execute & assert + with pytest.raises(ValidationError, match="Value must be nonnegative"): + MetricGoal(type="min-improvement", value=-0.5) + + def test_negative_value_for_max_degradation_raises(self): + # execute & assert + with pytest.raises(ValidationError, match="Value must be nonnegative"): + MetricGoal(type="max-degradation", value=-0.1) + + def test_negative_value_for_percent_min_improvement_raises(self): + # execute & assert + with 
pytest.raises(ValidationError, match="Value must be nonnegative"): + MetricGoal(type="percent-min-improvement", value=-5.0) + + def test_negative_value_for_percent_max_degradation_raises(self): + # execute & assert + with pytest.raises(ValidationError, match="Value must be nonnegative"): + MetricGoal(type="percent-max-degradation", value=-10.0) + + def test_threshold_allows_negative_value(self): + # execute + goal = MetricGoal(type="threshold", value=-1.0) + + # assert + assert goal.value == -1.0 + + def test_has_regression_goal_min_improvement(self): + # setup + goal = MetricGoal(type="min-improvement", value=0.05) + + # execute + result = goal.has_regression_goal() + + # assert + assert result is False + + def test_has_regression_goal_percent_min_improvement(self): + # setup + goal = MetricGoal(type="percent-min-improvement", value=5.0) + + # execute + result = goal.has_regression_goal() + + # assert + assert result is False + + def test_has_regression_goal_max_degradation_positive(self): + # setup + goal = MetricGoal(type="max-degradation", value=0.1) + + # execute + result = goal.has_regression_goal() + + # assert + assert result is True + + def test_has_regression_goal_max_degradation_zero(self): + # setup + goal = MetricGoal(type="max-degradation", value=0.0) + + # execute + result = goal.has_regression_goal() + + # assert + assert result is False + + def test_has_regression_goal_percent_max_degradation_positive(self): + # setup + goal = MetricGoal(type="percent-max-degradation", value=10.0) + + # execute + result = goal.has_regression_goal() + + # assert + assert result is True + + def test_has_regression_goal_threshold(self): + # setup + goal = MetricGoal(type="threshold", value=0.9) + + # execute + result = goal.has_regression_goal() + + # assert + assert result is False + + +class TestGetUserConfigClass: + def test_custom_metric_type(self): + # execute + cls = get_user_config_class("custom") + instance = cls() + + # assert + assert hasattr(instance, "user_script") + assert hasattr(instance, "evaluate_func") + + def test_non_custom_metric_type_includes_common_fields(self): + # execute + cls = get_user_config_class("latency") + instance = cls() + + # assert + assert hasattr(instance, "user_script") + assert hasattr(instance, "inference_settings") diff --git a/test/evaluator/test_metric_result.py b/test/evaluator/test_metric_result.py new file mode 100644 index 0000000000..8583ea79c5 --- /dev/null +++ b/test/evaluator/test_metric_result.py @@ -0,0 +1,173 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
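+# Unit tests for olive.evaluator.metric_result: SubMetricResult, MetricResult, and flattening helpers.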
+# -------------------------------------------------------------------------- +import json + +from olive.evaluator.metric_result import ( + MetricResult, + SubMetricResult, + flatten_metric_result, + flatten_metric_sub_type, + joint_metric_key, +) + + +class TestSubMetricResult: + def test_creation(self): + # execute + result = SubMetricResult(value=0.95, priority=1, higher_is_better=True) + + # assert + assert result.value == 0.95 + assert result.priority == 1 + assert result.higher_is_better is True + + def test_integer_value(self): + # execute + result = SubMetricResult(value=100, priority=2, higher_is_better=False) + + # assert + assert result.value == 100 + + def test_float_value(self): + # execute + result = SubMetricResult(value=0.001, priority=0, higher_is_better=True) + + # assert + assert result.value == 0.001 + + +class TestMetricResult: + def _make_result(self): + return MetricResult.model_validate( + { + "accuracy-top1": SubMetricResult(value=0.95, priority=1, higher_is_better=True), + "latency-avg": SubMetricResult(value=10.5, priority=2, higher_is_better=False), + "latency-p99": SubMetricResult(value=20.0, priority=3, higher_is_better=False), + } + ) + + def test_get_value(self): + # setup + result = self._make_result() + + # execute + accuracy = result.get_value("accuracy", "top1") + latency = result.get_value("latency", "avg") + + # assert + assert accuracy == 0.95 + assert latency == 10.5 + + def test_get_all_sub_type_metric_value(self): + # setup + result = self._make_result() + + # execute + latency_values = result.get_all_sub_type_metric_value("latency") + + # assert + assert latency_values == {"avg": 10.5, "p99": 20.0} + + def test_get_all_sub_type_single_metric(self): + # setup + result = self._make_result() + + # execute + accuracy_values = result.get_all_sub_type_metric_value("accuracy") + + # assert + assert accuracy_values == {"top1": 0.95} + + def test_str_representation(self): + # setup + result = self._make_result() + + # execute + result_str = str(result) + parsed = json.loads(result_str) + + # assert + assert parsed["accuracy-top1"] == 0.95 + assert parsed["latency-avg"] == 10.5 + + def test_len(self): + # setup + result = self._make_result() + + # execute + length = len(result) + + # assert + assert length == 3 + + def test_getitem(self): + # setup + result = self._make_result() + + # execute + item = result["accuracy-top1"] + + # assert + assert item.value == 0.95 + + def test_delimiter(self): + # execute + delimiter = MetricResult.delimiter + + # assert + assert delimiter == "-" + + +class TestJointMetricKey: + def test_basic(self): + # execute + result = joint_metric_key("accuracy", "top1") + + # assert + assert result == "accuracy-top1" + + def test_with_special_names(self): + # execute + result = joint_metric_key("latency", "p99") + + # assert + assert result == "latency-p99" + + +class TestFlattenMetricSubType: + def test_flatten(self): + # setup + metric_dict = { + "accuracy": {"top1": {"value": 0.95, "priority": 1, "higher_is_better": True}}, + "latency": {"avg": {"value": 10.5, "priority": 2, "higher_is_better": False}}, + } + + # execute + result = flatten_metric_sub_type(metric_dict) + + # assert + assert "accuracy-top1" in result + assert "latency-avg" in result + + +class TestFlattenMetricResult: + def test_flatten_to_metric_result(self): + # setup + dict_results = { + "accuracy": { + "top1": {"value": 0.95, "priority": 1, "higher_is_better": True}, + }, + "latency": { + "avg": {"value": 10.5, "priority": 2, "higher_is_better": False}, + }, + } 
+ + # execute + result = flatten_metric_result(dict_results) + + # assert + assert isinstance(result, MetricResult) + assert result.get_value("accuracy", "top1") == 0.95 + assert result.get_value("latency", "avg") == 10.5 diff --git a/test/search/test_search_parameter.py b/test/search/test_search_parameter.py new file mode 100644 index 0000000000..dcb4cb72ea --- /dev/null +++ b/test/search/test_search_parameter.py @@ -0,0 +1,357 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import pytest + +from olive.search.search_parameter import ( + Boolean, + Categorical, + Conditional, + ConditionalDefault, + SpecialParamValue, + json_to_search_parameter, +) + + +class TestSpecialParamValue: + def test_ignored_value(self): + # execute + result = SpecialParamValue.IGNORED + + # assert + assert result == "OLIVE_IGNORED_PARAM_VALUE" + + def test_invalid_value(self): + # execute + result = SpecialParamValue.INVALID + + # assert + assert result == "OLIVE_INVALID_PARAM_VALUE" + + +class TestCategorical: + def test_int_support(self): + # setup + cat = Categorical([1, 2, 3]) + + # execute + result = cat.get_support() + + # assert + assert result == [1, 2, 3] + + def test_string_support(self): + # setup + cat = Categorical(["a", "b", "c"]) + + # execute + result = cat.get_support() + + # assert + assert result == ["a", "b", "c"] + + def test_float_support(self): + # setup + cat = Categorical([0.1, 0.5, 1.0]) + + # execute + result = cat.get_support() + + # assert + assert result == [0.1, 0.5, 1.0] + + def test_bool_support(self): + # setup + cat = Categorical([True, False]) + + # execute + result = cat.get_support() + + # assert + assert result == [True, False] + + def test_repr(self): + # setup + cat = Categorical([1, 2, 3]) + + # execute + result = repr(cat) + + # assert + assert result == "Categorical([1, 2, 3])" + + def test_to_json(self): + # setup + cat = Categorical([1, 2, 3]) + + # execute + result = cat.to_json() + + # assert + assert result["olive_parameter_type"] == "SearchParameter" + assert result["type"] == "Categorical" + assert result["support"] == [1, 2, 3] + + +class TestBoolean: + def test_support(self): + # setup + b = Boolean() + + # execute + result = b.get_support() + + # assert + assert result == [True, False] + + def test_is_categorical(self): + # execute + result = issubclass(Boolean, Categorical) + + # assert + assert result + + +class TestConditional: + def test_single_parent(self): + # setup + cond = Conditional( + parents=("parent1",), + support={ + ("value1",): Categorical([1, 2, 3]), + ("value2",): Categorical([4, 5, 6]), + }, + default=Categorical([7, 8, 9]), + ) + + # execute + result_v1 = cond.get_support_with_args({"parent1": "value1"}) + result_v2 = cond.get_support_with_args({"parent1": "value2"}) + result_unknown = cond.get_support_with_args({"parent1": "unknown"}) + + # assert + assert result_v1 == [1, 2, 3] + assert result_v2 == [4, 5, 6] + assert result_unknown == [7, 8, 9] + + def test_multi_parent(self): + # setup + cond = Conditional( + parents=("parent1", "parent2"), + support={ + ("v1", "v2"): Categorical([10, 20]), + }, + ) + + # execute + result = cond.get_support_with_args({"parent1": "v1", "parent2": "v2"}) + + # assert + assert result == [10, 20] + + def test_default_is_invalid(self): + # setup + cond = Conditional( + parents=("p",), + support={("a",): 
Categorical([1])}, + ) + + # execute + result = cond.get_support_with_args({"p": "missing"}) + + # assert + assert result == [SpecialParamValue.INVALID] + + def test_get_invalid_choice(self): + # execute + result = Conditional.get_invalid_choice() + + # assert + assert isinstance(result, Categorical) + assert result.get_support() == [SpecialParamValue.INVALID] + + def test_get_ignored_choice(self): + # execute + result = Conditional.get_ignored_choice() + + # assert + assert isinstance(result, Categorical) + assert result.get_support() == [SpecialParamValue.IGNORED] + + def test_repr(self): + # setup + cond = Conditional( + parents=("p",), + support={("a",): Categorical([1])}, + ) + + # execute + result = repr(cond) + + # assert + assert "Conditional" in result + assert "parents" in result + + def test_to_json(self): + # setup + cond = Conditional( + parents=("p",), + support={("a",): Categorical([1, 2])}, + default=Categorical([3]), + ) + + # execute + result = cond.to_json() + + # assert + assert result["olive_parameter_type"] == "SearchParameter" + assert result["type"] == "Conditional" + assert result["parents"] == ("p",) + + def test_condition_single_parent(self): + # setup + cond = Conditional( + parents=("p",), + support={ + ("a",): Categorical([1, 2]), + ("b",): Categorical([3, 4]), + }, + ) + + # execute + result = cond.condition({"p": "a"}) + + # assert + assert isinstance(result, Categorical) + assert result.get_support() == [1, 2] + + def test_condition_returns_default_when_no_match(self): + # setup + cond = Conditional( + parents=("p1", "p2"), + support={("a", "b"): Categorical([1])}, + default=Categorical([99]), + ) + + # execute + result = cond.condition({"p1": "missing"}) + + # assert + assert isinstance(result, Categorical) + assert result.get_support() == [99] + + +class TestConditionalDefault: + def test_basic(self): + # setup + cd = ConditionalDefault( + parents=("p",), + support={("a",): 1, ("b",): 2}, + default=3, + ) + + # execute + result_a = cd.get_support_with_args({"p": "a"}) + result_b = cd.get_support_with_args({"p": "b"}) + result_c = cd.get_support_with_args({"p": "c"}) + + # assert + assert result_a == 1 + assert result_b == 2 + assert result_c == 3 + + def test_default_invalid(self): + # setup + cd = ConditionalDefault( + parents=("p",), + support={("a",): 1}, + ) + + # execute + result = cd.get_support_with_args({"p": "missing"}) + + # assert + assert result == SpecialParamValue.INVALID + + def test_condition(self): + # setup + cd = ConditionalDefault( + parents=("p",), + support={("a",): 10, ("b",): 20}, + default=30, + ) + + # execute + result_a = cd.condition({"p": "a"}) + result_b = cd.condition({"p": "b"}) + result_c = cd.condition({"p": "c"}) + + # assert + assert result_a == 10 + assert result_b == 20 + assert result_c == 30 + + def test_get_invalid_choice(self): + # execute + result = ConditionalDefault.get_invalid_choice() + + # assert + assert result == SpecialParamValue.INVALID + + def test_get_ignored_choice(self): + # execute + result = ConditionalDefault.get_ignored_choice() + + # assert + assert result == SpecialParamValue.IGNORED + + def test_to_json(self): + # setup + cd = ConditionalDefault( + parents=("p",), + support={("a",): 1}, + default=2, + ) + + # execute + result = cd.to_json() + + # assert + assert result["type"] == "ConditionalDefault" + + def test_repr(self): + # setup + cd = ConditionalDefault( + parents=("p",), + support={("a",): 1}, + default=2, + ) + + # execute + result = repr(cd) + + # assert + assert "ConditionalDefault" 
in result
+
+
+class TestJsonToSearchParameter:
+    def test_categorical(self):
+        # setup
+        json_data = {"olive_parameter_type": "SearchParameter", "type": "Categorical", "support": [1, 2, 3]}
+
+        # execute
+        result = json_to_search_parameter(json_data)
+
+        # assert
+        assert isinstance(result, Categorical)
+        assert result.get_support() == [1, 2, 3]
+
+    def test_unknown_type_raises(self):
+        # setup
+        json_data = {"olive_parameter_type": "SearchParameter", "type": "Unknown"}
+
+        # execute & assert
+        with pytest.raises(ValueError, match="Unknown search parameter type"):
+            json_to_search_parameter(json_data)
diff --git a/test/search/test_search_point.py b/test/search/test_search_point.py
new file mode 100644
index 0000000000..23f252751a
--- /dev/null
+++ b/test/search/test_search_point.py
@@ -0,0 +1,129 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from collections import OrderedDict
+
+from olive.search.search_parameter import SpecialParamValue
+from olive.search.search_point import SearchPoint
+
+
+class TestSearchPoint:
+    def _make_point(self, index=0, values=None):
+        if values is None:
+            values = OrderedDict(
+                {
+                    "pass1": (
+                        0,
+                        OrderedDict(
+                            {
+                                "param1": (0, "a"),
+                                "param2": (1, 10),
+                            }
+                        ),
+                    )
+                }
+            )
+        return SearchPoint(index=index, values=values)
+
+    def test_creation(self):
+        # execute
+        point = self._make_point()
+
+        # assert
+        assert point.index == 0
+
+    def test_repr(self):
+        # setup
+        point = self._make_point()
+
+        # execute
+        result = repr(point)
+
+        # assert
+        assert "SearchPoint" in result
+        assert "0" in result
+
+    def test_equality_same(self):
+        # setup
+        point1 = self._make_point()
+        point2 = self._make_point()
+
+        # execute
+        result = point1 == point2
+
+        # assert
+        assert result
+
+    def test_equality_different_index(self):
+        # setup
+        point1 = self._make_point(index=0)
+        point2 = self._make_point(index=1)
+
+        # execute
+        result = point1 == point2
+
+        # assert
+        assert not result
+
+    def test_equality_different_type(self):
+        # setup
+        point = self._make_point()
+
+        # execute
+        result = point == "not a search point"
+
+        # assert
+        assert not result
+
+    def test_is_valid_true(self):
+        # setup
+        point = self._make_point()
+
+        # execute
+        result = point.is_valid()
+
+        # assert
+        assert result is True
+
+    def test_is_valid_false_with_invalid(self):
+        # setup: values must use the same (index, OrderedDict) nesting as _make_point
+        values = OrderedDict(
+            {
+                "pass1": (
+                    0,
+                    OrderedDict(
+                        {
+                            "param1": (0, SpecialParamValue.INVALID),
+                        }
+                    ),
+                )
+            }
+        )
+        point = SearchPoint(index=0, values=values)
+
+        # execute
+        result = point.is_valid()
+
+        # assert
+        assert result is False
+
+    def test_to_json(self):
+        # setup
+        point = self._make_point(index=5)
+
+        # execute
+        result = point.to_json()
+
+        # assert
+        assert result["index"] == 5
+        assert "values" in result
+
+    def test_from_json_roundtrip(self):
+        # setup
+        point = self._make_point(index=3)
+        json_data = point.to_json()
+
+        # execute
+        restored = SearchPoint.from_json(json_data)
+
+        # assert
+        assert restored.index == 3
+        assert restored == point
diff --git a/test/search/test_search_sample.py b/test/search/test_search_sample.py
new file mode 100644
index 0000000000..3dc4c6345e
--- /dev/null
+++ b/test/search/test_search_sample.py
@@ -0,0 +1,111 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from collections import OrderedDict + +from olive.search.search_parameter import SpecialParamValue +from olive.search.search_point import SearchPoint +from olive.search.search_sample import SearchSample + + +class TestSearchSample: + def _make_sample(self, invalid_param=False, ignored_param=False): + param_value = "a" + if invalid_param: + param_value = SpecialParamValue.INVALID + elif ignored_param: + param_value = SpecialParamValue.IGNORED + + values = OrderedDict( + { + "pass1": ( + 0, + OrderedDict( + { + "param1": (0, param_value), + "param2": (1, 10), + } + ), + ) + } + ) + point = SearchPoint(index=0, values=values) + return SearchSample(search_point=point, model_ids=["model_0"]) + + def test_creation(self): + # execute + sample = self._make_sample() + + # assert + assert sample.model_ids == ["model_0"] + assert sample.search_point.index == 0 + + def test_passes_configs_valid(self): + # setup + sample = self._make_sample() + + # execute + configs = sample.passes_configs + + # assert + assert configs is not None + assert "pass1" in configs + assert configs["pass1"]["params"]["param1"] == "a" + assert configs["pass1"]["params"]["param2"] == 10 + + def test_passes_configs_with_invalid_returns_none(self): + # setup + sample = self._make_sample(invalid_param=True) + + # execute + configs = sample.passes_configs + + # assert + assert configs is None + + def test_passes_configs_with_ignored_excludes_param(self): + # setup + sample = self._make_sample(ignored_param=True) + + # execute + configs = sample.passes_configs + + # assert + assert configs is not None + assert "param1" not in configs["pass1"]["params"] + assert configs["pass1"]["params"]["param2"] == 10 + + def test_to_json(self): + # setup + sample = self._make_sample() + + # execute + result = sample.to_json() + + # assert + assert "search_point" in result + assert "model_ids" in result + assert result["model_ids"] == ["model_0"] + + def test_from_json_roundtrip(self): + # setup + sample = self._make_sample() + json_data = sample.to_json() + + # execute + restored = SearchSample.from_json(json_data) + + # assert + assert restored.model_ids == sample.model_ids + assert restored.search_point.index == sample.search_point.index + + def test_repr(self): + # setup + sample = self._make_sample() + + # execute + result = repr(sample) + + # assert + assert "SearchSample" in result diff --git a/test/telemetry/__init__.py b/test/telemetry/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/telemetry/test_telemetry_utils.py b/test/telemetry/test_telemetry_utils.py new file mode 100644 index 0000000000..97029abd33 --- /dev/null +++ b/test/telemetry/test_telemetry_utils.py @@ -0,0 +1,165 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
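+# Unit tests for olive.telemetry.utils: home-dir resolution, exception formatting, and cache-line encoding.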
+# -------------------------------------------------------------------------- +import base64 +import os +from pathlib import Path +from unittest.mock import patch + +import pytest + +from olive.telemetry.utils import ( + _decode_cache_line, + _encode_cache_line, + _format_exception_message, + _resolve_home_dir, + get_telemetry_base_dir, +) + + +class TestResolveHomeDir: + def test_returns_path(self): + # execute + result = _resolve_home_dir() + + # assert + assert isinstance(result, Path) + + def test_with_home_env_set(self): + # execute + with patch.dict(os.environ, {"HOME": "/tmp/test_home"}): + result = _resolve_home_dir() + + # assert + assert isinstance(result, Path) + + def test_without_home_env(self): + # execute + with patch.dict(os.environ, {}, clear=True): + result = _resolve_home_dir() + + # assert + assert isinstance(result, Path) + + +class TestGetTelemetryBaseDir: + def test_returns_path(self): + # setup + get_telemetry_base_dir.cache_clear() + + # execute + result = get_telemetry_base_dir() + + # assert + assert isinstance(result, Path) + + def test_path_contains_onnxruntime(self): + # setup + get_telemetry_base_dir.cache_clear() + + # execute + result = get_telemetry_base_dir() + + # assert + assert ".onnxruntime" in str(result) + + +class TestFormatExceptionMessage: + def test_basic_exception(self): + # setup + exception = None + traceback = None + try: + 1 / 0 # noqa: B018 + except ZeroDivisionError as ex: + exception = ex + traceback = ex.__traceback__ + + # execute + result = _format_exception_message(exception, traceback) + + # assert + assert "ZeroDivisionError" in result + + def test_exception_without_traceback(self): + # setup + ex = ValueError("test error") + + # execute + result = _format_exception_message(ex) + + # assert + assert "test error" in result + + +class TestEncodeCacheLine: + def test_encode_basic_string(self): + # setup + expected = base64.b64encode(b"hello").decode("ascii") + + # execute + result = _encode_cache_line("hello") + + # assert + assert result == expected + + def test_encode_empty_string(self): + # setup + expected = base64.b64encode(b"").decode("ascii") + + # execute + result = _encode_cache_line("") + + # assert + assert result == expected + + def test_encode_unicode_string(self): + # execute + result = _encode_cache_line("hello \u4e16\u754c") + decoded = base64.b64decode(result).decode("utf-8") + + # assert + assert decoded == "hello \u4e16\u754c" + + +class TestDecodeCacheLine: + def test_decode_basic_string(self): + # setup + encoded = base64.b64encode(b"hello").decode("ascii") + + # execute + result = _decode_cache_line(encoded) + + # assert + assert result == "hello" + + def test_decode_empty_string(self): + # setup + encoded = base64.b64encode(b"").decode("ascii") + + # execute + result = _decode_cache_line(encoded) + + # assert + assert result == "" + + +class TestEncodeDecodeRoundtrip: + @pytest.mark.parametrize( + "text", + [ + "simple text", + "path/to/file.json", + '{"key": "value"}', + "special chars: !@#$%^&*()", + "", + ], + ) + def test_roundtrip(self, text): + # execute + encoded = _encode_cache_line(text) + decoded = _decode_cache_line(encoded) + + # assert + assert decoded == text diff --git a/test/test_constants.py b/test/test_constants.py new file mode 100644 index 0000000000..77283c2089 --- /dev/null +++ b/test/test_constants.py @@ -0,0 +1,90 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
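+# Unit tests for olive.constants: Framework, Precision, and QuantAlgorithm enums.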
+# -------------------------------------------------------------------------- +import pytest + +from olive.constants import ( + Framework, + Precision, + PrecisionBits, + QuantAlgorithm, + precision_bits_from_precision, +) + + +class TestFramework: + def test_framework_all_members(self): + # setup + expected = {"ONNX", "PYTORCH", "QAIRT", "QNN", "TENSORFLOW", "OPENVINO"} + + # execute + result = set(Framework.__members__.keys()) + + # assert + assert result == expected + + +class TestPrecision: + def test_precision_all_members(self): + # setup + expected_count = 14 + + # execute + result = len(Precision) + + # assert + assert result == expected_count + + +class TestQuantAlgorithm: + def test_quant_algorithm_case_insensitive(self): + # execute + lower = QuantAlgorithm("awq") + upper = QuantAlgorithm("AWQ") + mixed = QuantAlgorithm("Awq") + + # assert + assert lower == QuantAlgorithm.AWQ + assert upper == QuantAlgorithm.AWQ + assert mixed == QuantAlgorithm.AWQ + + def test_quant_algorithm_all_members(self): + # setup + expected = {"AWQ", "GPTQ", "HQQ", "RTN", "SPINQUANT", "QUAROT", "LPBQ", "SEQMSE", "ADAROUND"} + + # execute + result = set(QuantAlgorithm.__members__.keys()) + + # assert + assert result == expected + + +class TestPrecisionBitsFromPrecision: + @pytest.mark.parametrize( + ("precision", "expected"), + [ + (Precision.INT4, PrecisionBits.BITS4), + (Precision.INT8, PrecisionBits.BITS8), + (Precision.INT16, PrecisionBits.BITS16), + (Precision.INT32, PrecisionBits.BITS32), + (Precision.UINT4, PrecisionBits.BITS4), + (Precision.UINT8, PrecisionBits.BITS8), + (Precision.UINT16, PrecisionBits.BITS16), + (Precision.UINT32, PrecisionBits.BITS32), + ], + ) + def test_precision_to_bits_mapping(self, precision, expected): + # execute + result = precision_bits_from_precision(precision) + + # assert + assert result == expected + + @pytest.mark.parametrize("precision", [Precision.FP16, Precision.FP32, Precision.BF16, Precision.NF4]) + def test_precision_without_bits_mapping_returns_none(self, precision): + # execute + result = precision_bits_from_precision(precision) + + # assert + assert result is None diff --git a/test/test_exception.py b/test/test_exception.py new file mode 100644 index 0000000000..6afd936a0d --- /dev/null +++ b/test/test_exception.py @@ -0,0 +1,89 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
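+# Unit tests for olive.exception: the OliveError hierarchy and EXCEPTIONS_TO_RAISE.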
+# -------------------------------------------------------------------------- +import pytest + +from olive.exception import EXCEPTIONS_TO_RAISE, OliveError, OliveEvaluationError, OlivePassError + + +class TestOliveError: + def test_olive_error_is_exception(self): + # execute + result = issubclass(OliveError, Exception) + + # assert + assert result + + def test_olive_error_can_be_raised(self): + # execute & assert + with pytest.raises(OliveError, match="test error"): + raise OliveError("test error") + + def test_olive_error_empty_message(self): + # execute & assert + with pytest.raises(OliveError): + raise OliveError + + +class TestOlivePassError: + def test_olive_pass_error_inherits_olive_error(self): + # execute + result = issubclass(OlivePassError, OliveError) + + # assert + assert result + + def test_olive_pass_error_can_be_raised(self): + # execute & assert + with pytest.raises(OlivePassError, match="pass failed"): + raise OlivePassError("pass failed") + + def test_olive_pass_error_caught_as_olive_error(self): + # execute & assert + with pytest.raises(OliveError): + raise OlivePassError("pass failed") + + +class TestOliveEvaluationError: + def test_olive_evaluation_error_inherits_olive_error(self): + # execute + result = issubclass(OliveEvaluationError, OliveError) + + # assert + assert result + + def test_olive_evaluation_error_can_be_raised(self): + # execute & assert + with pytest.raises(OliveEvaluationError, match="evaluation failed"): + raise OliveEvaluationError("evaluation failed") + + def test_olive_evaluation_error_caught_as_olive_error(self): + # execute & assert + with pytest.raises(OliveError): + raise OliveEvaluationError("evaluation failed") + + +class TestExceptionsToRaise: + def test_exceptions_to_raise_is_tuple(self): + # execute + result = isinstance(EXCEPTIONS_TO_RAISE, tuple) + + # assert + assert result + + def test_exceptions_to_raise_contains_expected_types(self): + # setup + expected = {AssertionError, AttributeError, ImportError, TypeError, ValueError} + + # execute + result = set(EXCEPTIONS_TO_RAISE) + + # assert + assert result == expected + + @pytest.mark.parametrize("exc_type", EXCEPTIONS_TO_RAISE) + def test_each_exception_is_catchable(self, exc_type): + # execute & assert + with pytest.raises(exc_type): + raise exc_type("test") diff --git a/test/test_logging.py b/test/test_logging.py new file mode 100644 index 0000000000..99a23ed89a --- /dev/null +++ b/test/test_logging.py @@ -0,0 +1,191 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
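+# Unit tests for olive.logging: verbosity setters, level mapping, and file logging.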
+# -------------------------------------------------------------------------- +import logging +from pathlib import Path +from unittest.mock import patch + +import pytest + +from olive.logging import ( + enable_filelog, + get_logger_level, + get_olive_logger, + get_verbosity, + set_default_logger_severity, + set_verbosity, + set_verbosity_critical, + set_verbosity_debug, + set_verbosity_error, + set_verbosity_from_env, + set_verbosity_info, + set_verbosity_warning, +) + + +@pytest.fixture(autouse=True) +def _restore_logger_level(): + """Save and restore the olive logger level between tests.""" + logger = get_olive_logger() + original_level = logger.level + yield + logger.setLevel(original_level) + + +class TestGetOliveLogger: + def test_returns_olive_logger(self): + # execute + logger = get_olive_logger() + + # assert + assert logger.name == "olive" + + def test_returns_same_logger_instance(self): + # execute + logger1 = get_olive_logger() + logger2 = get_olive_logger() + + # assert + assert logger1 is logger2 + + +class TestSetVerbosity: + def test_set_verbosity_info(self): + # execute + set_verbosity_info() + + # assert + assert get_olive_logger().level == logging.INFO + + def test_set_verbosity_warning(self): + # execute + set_verbosity_warning() + + # assert + assert get_olive_logger().level == logging.WARNING + + def test_set_verbosity_debug(self): + # execute + set_verbosity_debug() + + # assert + assert get_olive_logger().level == logging.DEBUG + + def test_set_verbosity_error(self): + # execute + set_verbosity_error() + + # assert + assert get_olive_logger().level == logging.ERROR + + def test_set_verbosity_critical(self): + # execute + set_verbosity_critical() + + # assert + assert get_olive_logger().level == logging.CRITICAL + + def test_set_verbosity_custom_level(self): + # execute + set_verbosity(logging.WARNING) + + # assert + assert get_olive_logger().level == logging.WARNING + + +class TestSetVerbosityFromEnv: + def test_set_verbosity_from_env_default(self): + # execute + with patch.dict("os.environ", {}, clear=True): + set_verbosity_from_env() + + # assert + assert get_olive_logger().level == logging.INFO + + def test_set_verbosity_from_env_custom(self): + # execute + with patch.dict("os.environ", {"OLIVE_LOG_LEVEL": "DEBUG"}): + set_verbosity_from_env() + + # assert + assert get_olive_logger().level == logging.DEBUG + + +class TestGetVerbosity: + def test_get_verbosity_returns_int(self): + # setup + set_verbosity_info() + + # execute + level = get_verbosity() + + # assert + assert isinstance(level, int) + assert level == logging.INFO + + +class TestGetLoggerLevel: + @pytest.mark.parametrize( + ("level_int", "expected"), + [ + (0, logging.DEBUG), + (1, logging.INFO), + (2, logging.WARNING), + (3, logging.ERROR), + (4, logging.CRITICAL), + ], + ) + def test_valid_levels(self, level_int, expected): + # execute + result = get_logger_level(level_int) + + # assert + assert result == expected + + @pytest.mark.parametrize("invalid_level", [-1, 5, 10, 100]) + def test_invalid_levels_raise_value_error(self, invalid_level): + # execute & assert + with pytest.raises(ValueError, match="Invalid level"): + get_logger_level(invalid_level) + + +class TestSetDefaultLoggerSeverity: + @pytest.mark.parametrize("level", [0, 1, 2, 3, 4]) + def test_set_default_logger_severity(self, level): + # setup + expected = get_logger_level(level) + + # execute + set_default_logger_severity(level) + + # assert + assert get_olive_logger().level == expected + + +class TestEnableFilelog: + def 
test_enable_filelog_creates_handler(self, tmp_path):
+        # setup
+        workflow_id = "test_workflow"
+        logger = get_olive_logger()
+        original_handler_ids = {id(h) for h in logger.handlers}
+        log_file_path = (tmp_path / f"{workflow_id}.log").resolve()
+
+        try:
+            # execute
+            enable_filelog(1, str(tmp_path), workflow_id)
+
+            # assert
+            new_handlers = [h for h in logger.handlers if id(h) not in original_handler_ids]
+            matching_handlers = [
+                h
+                for h in new_handlers
+                if isinstance(h, logging.FileHandler) and Path(h.baseFilename).resolve() == log_file_path
+            ]
+            assert matching_handlers, f"Expected a FileHandler for {log_file_path}, but none was added."
+            assert log_file_path.exists()
+        finally:
+            # cleanup
+            for h in [h for h in logger.handlers if id(h) not in original_handler_ids]:
+                logger.removeHandler(h)
+                h.close()