From 385826eceacb46d2a763499d4ea1b372d9bd9859 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 9 Apr 2026 23:58:40 +0000 Subject: [PATCH 1/6] Add comprehensive tests for 11 untested modules covering exception, constants, logging, config_utils, metric_config, metric_result, search_parameter, search_point, search_sample, data registry, and telemetry utils Agent-Logs-Url: https://github.com/microsoft/Olive/sessions/656d3838-fc06-4d40-8a11-b75da7f23724 Co-authored-by: xiaoyu-work <85524621+xiaoyu-work@users.noreply.github.com> --- test/common/test_config_utils.py | 363 +++++++++++++++++++++++++ test/data_container/test_registry.py | 106 ++++++++ test/evaluator/test_metric_config.py | 137 ++++++++++ test/evaluator/test_metric_result.py | 108 ++++++++ test/search/test_search_parameter.py | 211 ++++++++++++++ test/search/test_search_point.py | 83 ++++++ test/search/test_search_sample.py | 77 ++++++ test/telemetry/__init__.py | 0 test/telemetry/test_telemetry_utils.py | 107 ++++++++ test/test_constants.py | 152 +++++++++++ test/test_exception.py | 60 ++++ test/test_logging.py | 127 +++++++++ 12 files changed, 1531 insertions(+) create mode 100644 test/common/test_config_utils.py create mode 100644 test/data_container/test_registry.py create mode 100644 test/evaluator/test_metric_config.py create mode 100644 test/evaluator/test_metric_result.py create mode 100644 test/search/test_search_parameter.py create mode 100644 test/search/test_search_point.py create mode 100644 test/search/test_search_sample.py create mode 100644 test/telemetry/__init__.py create mode 100644 test/telemetry/test_telemetry_utils.py create mode 100644 test/test_constants.py create mode 100644 test/test_exception.py create mode 100644 test/test_logging.py diff --git a/test/common/test_config_utils.py b/test/common/test_config_utils.py new file mode 100644 index 0000000000..e053fc7452 --- /dev/null +++ b/test/common/test_config_utils.py @@ -0,0 +1,363 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import json +from pathlib import Path + +import pytest +from pydantic import Field + +from olive.common.config_utils import ( + ConfigBase, + ConfigDictBase, + ConfigListBase, + ConfigParam, + NestedConfig, + ParamCategory, + config_json_dumps, + config_json_loads, + convert_configs_to_dicts, + create_config_class, + load_config_file, + serialize_function, + serialize_object, + serialize_to_json, + validate_config, + validate_enum, + validate_lowercase, +) + + +class TestSerializeFunction: + def test_serialize_function_returns_dict(self): + def my_func(x): + return x + + result = serialize_function(my_func) + assert result["olive_parameter_type"] == "Function" + assert result["name"] == "my_func" + assert "signature" in result + assert "sourcecode_hash" in result + + +class TestSerializeObject: + def test_serialize_object_returns_dict(self): + obj = {"key": "value"} + result = serialize_object(obj) + assert result["olive_parameter_type"] == "Object" + assert result["type"] == "dict" + assert "hash" in result + + +class TestConfigJsonDumps: + def test_basic_dict(self): + data = {"key": "value", "num": 42} + result = config_json_dumps(data) + parsed = json.loads(result) + assert parsed == data + + def test_path_serialization_absolute(self): + data = {"path": Path("/some/path")} + result = config_json_dumps(data, make_absolute=True) + parsed = json.loads(result) + assert isinstance(parsed["path"], str) + + def test_path_serialization_relative(self): + data = {"path": Path("relative/path")} + result = config_json_dumps(data, make_absolute=False) + parsed = json.loads(result) + assert parsed["path"] == "relative/path" + + def test_function_serialization(self): + def sample_func(): + pass + + data = {"func": sample_func} + result = config_json_dumps(data) + parsed = json.loads(result) + assert parsed["func"]["olive_parameter_type"] == "Function" + + +class TestConfigJsonLoads: + def test_basic_json(self): + data = '{"key": "value"}' + result = config_json_loads(data) + assert result == {"key": "value"} + + def test_function_object_raises_error(self): + data = json.dumps({"olive_parameter_type": "Function", "name": "my_func"}) + with pytest.raises(ValueError, match="Cannot load"): + config_json_loads(data) + + def test_custom_object_hook(self): + data = '{"key": "value"}' + result = config_json_loads(data, object_hook=lambda obj: obj) + assert result == {"key": "value"} + + +class TestSerializeToJson: + def test_dict_input(self): + data = {"key": "value"} + result = serialize_to_json(data) + assert result == data + + def test_config_base_input(self): + config = ConfigBase() + result = serialize_to_json(config) + assert isinstance(result, dict) + + def test_check_object_with_function_raises(self): + def my_func(): + pass + + with pytest.raises(ValueError, match="Cannot serialize"): + serialize_to_json({"func": my_func}, check_object=True) + + +class TestLoadConfigFile: + def test_load_json_file(self, tmp_path): + config_file = tmp_path / "config.json" + config_file.write_text('{"key": "value"}') + result = load_config_file(config_file) + assert result == {"key": "value"} + + def test_load_yaml_file(self, tmp_path): + config_file = tmp_path / "config.yaml" + config_file.write_text("key: value\n") + result = load_config_file(config_file) + assert result == {"key": "value"} + + def test_load_yml_file(self, tmp_path): + config_file = tmp_path / "config.yml" + config_file.write_text("key: value\n") + result = 
load_config_file(config_file) + assert result == {"key": "value"} + + def test_unsupported_file_type_raises(self, tmp_path): + config_file = tmp_path / "config.txt" + config_file.write_text("key=value") + with pytest.raises(ValueError, match="Unsupported file type"): + load_config_file(config_file) + + +class TestConfigBase: + def test_to_json(self): + config = ConfigBase() + result = config.to_json() + assert isinstance(result, dict) + + def test_from_json(self): + config = ConfigBase.from_json({}) + assert isinstance(config, ConfigBase) + + def test_parse_file_or_obj_dict(self): + config = ConfigBase.parse_file_or_obj({}) + assert isinstance(config, ConfigBase) + + def test_parse_file_or_obj_json_file(self, tmp_path): + config_file = tmp_path / "config.json" + config_file.write_text("{}") + config = ConfigBase.parse_file_or_obj(config_file) + assert isinstance(config, ConfigBase) + + +class TestConfigListBase: + def test_iter(self): + config = ConfigListBase.model_validate([1, 2, 3]) + assert list(config) == [1, 2, 3] + + def test_getitem(self): + config = ConfigListBase.model_validate([10, 20, 30]) + assert config[0] == 10 + assert config[2] == 30 + + def test_len(self): + config = ConfigListBase.model_validate([1, 2, 3]) + assert len(config) == 3 + + def test_to_json(self): + config = ConfigListBase.model_validate([1, 2, 3]) + result = config.to_json() + assert isinstance(result, list) + + +class TestConfigDictBase: + def test_iter(self): + config = ConfigDictBase.model_validate({"a": 1, "b": 2}) + assert set(config) == {"a", "b"} + + def test_keys(self): + config = ConfigDictBase.model_validate({"a": 1, "b": 2}) + assert set(config.keys()) == {"a", "b"} + + def test_values(self): + config = ConfigDictBase.model_validate({"a": 1, "b": 2}) + assert set(config.values()) == {1, 2} + + def test_items(self): + config = ConfigDictBase.model_validate({"a": 1, "b": 2}) + assert dict(config.items()) == {"a": 1, "b": 2} + + def test_getitem(self): + config = ConfigDictBase.model_validate({"key": "value"}) + assert config["key"] == "value" + + def test_len(self): + config = ConfigDictBase.model_validate({"a": 1, "b": 2}) + assert len(config) == 2 + + def test_len_empty(self): + config = ConfigDictBase.model_validate({}) + assert len(config) == 0 + + +class TestNestedConfig: + def test_gather_nested_field_basic(self): + class MyConfig(NestedConfig): + type: str + config: dict = Field(default_factory=dict) + + c = MyConfig(type="test", key1="val1") + assert c.config == {"key1": "val1"} + + def test_gather_nested_field_none_values(self): + class MyConfig(NestedConfig): + type: str = "default" + config: dict = Field(default_factory=dict) + + c = MyConfig.model_validate(None) + assert c.type == "default" + + def test_gather_nested_field_explicit_config(self): + class MyConfig(NestedConfig): + type: str + config: dict = Field(default_factory=dict) + + c = MyConfig(type="test", config={"inner_key": "inner_val"}) + assert c.config == {"inner_key": "inner_val"} + + +class TestCaseInsensitiveEnum: + def test_case_insensitive_creation(self): + assert ParamCategory("none") == ParamCategory.NONE + assert ParamCategory("NONE") == ParamCategory.NONE + assert ParamCategory("None") == ParamCategory.NONE + + def test_invalid_value_returns_none(self): + result = ParamCategory._missing_("nonexistent") + assert result is None + + +class TestConfigParam: + def test_config_param_defaults(self): + param = ConfigParam(type_=str) + assert param.required is False + assert param.default_value is None + assert param.category 
== ParamCategory.NONE + + def test_config_param_required(self): + param = ConfigParam(type_=str, required=True) + assert param.required is True + + def test_config_param_repr(self): + param = ConfigParam(type_=str, required=True, description="A test param") + repr_str = repr(param) + assert "required=True" in repr_str + assert "description=" in repr_str + + +class TestValidateEnum: + def test_valid_enum_value(self): + result = validate_enum(ParamCategory, "none") + assert result == ParamCategory.NONE + + def test_invalid_enum_value_raises(self): + with pytest.raises(ValueError, match="Invalid value"): + validate_enum(ParamCategory, "invalid_value") + + +class TestValidateLowercase: + def test_string_lowercased(self): + assert validate_lowercase("HELLO") == "hello" + + def test_non_string_unchanged(self): + assert validate_lowercase(42) == 42 + assert validate_lowercase(None) is None + + +class TestCreateConfigClass: + def test_create_basic_config_class(self): + config = { + "name": ConfigParam(type_=str, required=True), + "value": ConfigParam(type_=int, default_value=10), + } + cls = create_config_class("TestConfig", config) + instance = cls(name="test") + assert instance.name == "test" + assert instance.value == 10 + + def test_create_config_class_with_optional_field(self): + config = { + "name": ConfigParam(type_=str, default_value=None), + } + cls = create_config_class("OptionalConfig", config) + instance = cls() + assert instance.name is None + + +class TestValidateConfig: + def test_validate_dict_config(self): + config = {"name": "test"} + + class MyConfig(ConfigBase): + name: str + + result = validate_config(config, MyConfig) + assert isinstance(result, MyConfig) + assert result.name == "test" + + def test_validate_none_config(self): + class MyConfig(ConfigBase): + name: str = "default" + + result = validate_config(None, MyConfig) + assert result.name == "default" + + def test_validate_config_instance(self): + class MyConfig(ConfigBase): + name: str + + config = MyConfig(name="test") + result = validate_config(config, MyConfig) + assert result.name == "test" + + def test_validate_config_wrong_class_raises(self): + class MyConfig(ConfigBase): + name: str + + class OtherConfig(ConfigBase): + value: int + + config = OtherConfig(value=42) + with pytest.raises(ValueError, match="Invalid config class"): + validate_config(config, MyConfig) + + +class TestConvertConfigsToDicts: + def test_config_base_to_dict(self): + config = ConfigBase() + result = convert_configs_to_dicts(config) + assert isinstance(result, dict) + + def test_nested_dict_conversion(self): + result = convert_configs_to_dicts({"key": "value"}) + assert result == {"key": "value"} + + def test_list_conversion(self): + result = convert_configs_to_dicts(["a", "b"]) + assert result == ["a", "b"] + + def test_plain_value_passthrough(self): + assert convert_configs_to_dicts(42) == 42 + assert convert_configs_to_dicts("hello") == "hello" diff --git a/test/data_container/test_registry.py b/test/data_container/test_registry.py new file mode 100644 index 0000000000..a4cb4de4cd --- /dev/null +++ b/test/data_container/test_registry.py @@ -0,0 +1,106 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +from olive.data.constants import ( + DataComponentType, + DefaultDataContainer, +) +from olive.data.registry import Registry + + +class TestRegistryRegister: + def test_register_dataset_component(self): + @Registry.register(DataComponentType.LOAD_DATASET, name="test_dataset_reg") + def my_dataset(): + return "dataset" + + result = Registry.get_load_dataset_component("test_dataset_reg") + assert result is my_dataset + + def test_register_pre_process_component(self): + @Registry.register_pre_process(name="test_pre_process_reg") + def my_pre_process(data): + return data + + result = Registry.get_pre_process_component("test_pre_process_reg") + assert result is my_pre_process + + def test_register_post_process_component(self): + @Registry.register_post_process(name="test_post_process_reg") + def my_post_process(data): + return data + + result = Registry.get_post_process_component("test_post_process_reg") + assert result is my_post_process + + def test_register_dataloader_component(self): + @Registry.register_dataloader(name="test_dataloader_reg") + def my_dataloader(data): + return data + + result = Registry.get_dataloader_component("test_dataloader_reg") + assert result is my_dataloader + + def test_register_case_insensitive(self): + @Registry.register(DataComponentType.LOAD_DATASET, name="CaseSensitiveTest_Reg") + def my_func(): + pass + + result = Registry.get_load_dataset_component("casesensitivetest_reg") + assert result is my_func + + def test_register_uses_class_name_when_no_name(self): + @Registry.register(DataComponentType.LOAD_DATASET) + def unique_named_test_func_reg(): + pass + + result = Registry.get_load_dataset_component("unique_named_test_func_reg") + assert result is unique_named_test_func_reg + + +class TestRegistryGet: + def test_get_component(self): + @Registry.register(DataComponentType.LOAD_DATASET, name="test_get_comp_reg") + def my_func(): + pass + + result = Registry.get_component(DataComponentType.LOAD_DATASET.value, "test_get_comp_reg") + assert result is my_func + + def test_get_by_subtype(self): + @Registry.register(DataComponentType.LOAD_DATASET, name="test_get_subtype_reg") + def my_func(): + pass + + result = Registry.get(DataComponentType.LOAD_DATASET.value, "test_get_subtype_reg") + assert result is my_func + + +class TestRegistryDefaultComponents: + def test_get_default_load_dataset(self): + result = Registry.get_default_load_dataset_component() + assert result is not None + + def test_get_default_pre_process(self): + result = Registry.get_default_pre_process_component() + assert result is not None + + def test_get_default_post_process(self): + result = Registry.get_default_post_process_component() + assert result is not None + + def test_get_default_dataloader(self): + result = Registry.get_default_dataloader_component() + assert result is not None + + +class TestRegistryContainer: + def test_get_container_default(self): + result = Registry.get_container(None) + assert result is not None + + def test_get_container_by_name(self): + result = Registry.get_container(DefaultDataContainer.DATA_CONTAINER.value) + assert result is not None diff --git a/test/evaluator/test_metric_config.py b/test/evaluator/test_metric_config.py new file mode 100644 index 0000000000..3b3f1f364c --- /dev/null +++ b/test/evaluator/test_metric_config.py @@ -0,0 +1,137 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import pytest +from pydantic import ValidationError + +from olive.evaluator.metric_config import ( + LatencyMetricConfig, + MetricGoal, + SizeOnDiskMetricConfig, + ThroughputMetricConfig, + get_user_config_class, +) + + +class TestLatencyMetricConfig: + def test_defaults(self): + config = LatencyMetricConfig() + assert config.warmup_num == 10 + assert config.repeat_test_num == 20 + assert config.sleep_num == 0 + + def test_custom_values(self): + config = LatencyMetricConfig(warmup_num=5, repeat_test_num=100, sleep_num=2) + assert config.warmup_num == 5 + assert config.repeat_test_num == 100 + assert config.sleep_num == 2 + + +class TestThroughputMetricConfig: + def test_defaults(self): + config = ThroughputMetricConfig() + assert config.warmup_num == 10 + assert config.repeat_test_num == 20 + assert config.sleep_num == 0 + + def test_custom_values(self): + config = ThroughputMetricConfig(warmup_num=3, repeat_test_num=50, sleep_num=1) + assert config.warmup_num == 3 + assert config.repeat_test_num == 50 + assert config.sleep_num == 1 + + +class TestSizeOnDiskMetricConfig: + def test_creation(self): + config = SizeOnDiskMetricConfig() + assert isinstance(config, SizeOnDiskMetricConfig) + + +class TestMetricGoal: + def test_threshold_type(self): + goal = MetricGoal(type="threshold", value=0.9) + assert goal.type == "threshold" + assert goal.value == 0.9 + + def test_min_improvement_type(self): + goal = MetricGoal(type="min-improvement", value=0.05) + assert goal.type == "min-improvement" + assert goal.value == 0.05 + + def test_max_degradation_type(self): + goal = MetricGoal(type="max-degradation", value=0.1) + assert goal.type == "max-degradation" + assert goal.value == 0.1 + + def test_percent_min_improvement_type(self): + goal = MetricGoal(type="percent-min-improvement", value=5.0) + assert goal.type == "percent-min-improvement" + + def test_percent_max_degradation_type(self): + goal = MetricGoal(type="percent-max-degradation", value=10.0) + assert goal.type == "percent-max-degradation" + + def test_invalid_type_raises(self): + with pytest.raises(ValidationError, match="Metric goal type must be one of"): + MetricGoal(type="invalid_type", value=0.5) + + def test_negative_value_for_min_improvement_raises(self): + with pytest.raises(ValidationError, match="Value must be nonnegative"): + MetricGoal(type="min-improvement", value=-0.5) + + def test_negative_value_for_max_degradation_raises(self): + with pytest.raises(ValidationError, match="Value must be nonnegative"): + MetricGoal(type="max-degradation", value=-0.1) + + def test_negative_value_for_percent_min_improvement_raises(self): + with pytest.raises(ValidationError, match="Value must be nonnegative"): + MetricGoal(type="percent-min-improvement", value=-5.0) + + def test_negative_value_for_percent_max_degradation_raises(self): + with pytest.raises(ValidationError, match="Value must be nonnegative"): + MetricGoal(type="percent-max-degradation", value=-10.0) + + def test_threshold_allows_negative_value(self): + goal = MetricGoal(type="threshold", value=-1.0) + assert goal.value == -1.0 + + def test_has_regression_goal_min_improvement(self): + goal = MetricGoal(type="min-improvement", value=0.05) + assert goal.has_regression_goal() is False + + def test_has_regression_goal_percent_min_improvement(self): + goal = MetricGoal(type="percent-min-improvement", value=5.0) + assert goal.has_regression_goal() is False + + def 
test_has_regression_goal_max_degradation_positive(self):
+        goal = MetricGoal(type="max-degradation", value=0.1)
+        assert goal.has_regression_goal() is True
+
+    def test_has_regression_goal_max_degradation_zero(self):
+        goal = MetricGoal(type="max-degradation", value=0.0)
+        assert goal.has_regression_goal() is False
+
+    def test_has_regression_goal_percent_max_degradation_positive(self):
+        goal = MetricGoal(type="percent-max-degradation", value=10.0)
+        assert goal.has_regression_goal() is True
+
+    def test_has_regression_goal_threshold(self):
+        goal = MetricGoal(type="threshold", value=0.9)
+        assert goal.has_regression_goal() is False
+
+
+class TestGetUserConfigClass:
+    def test_custom_metric_type(self):
+        cls = get_user_config_class("custom")
+        instance = cls()
+        assert hasattr(instance, "user_script")
+        assert hasattr(instance, "evaluate_func")
+
+    def test_non_custom_metric_type(self):
+        cls = get_user_config_class("latency")
+        instance = cls()
+        assert hasattr(instance, "user_script")
+        # Non-custom metric types still get the common user config
+        assert hasattr(instance, "inference_settings")
diff --git a/test/evaluator/test_metric_result.py b/test/evaluator/test_metric_result.py
new file mode 100644
index 0000000000..89949476c7
--- /dev/null
+++ b/test/evaluator/test_metric_result.py
@@ -0,0 +1,108 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import json
+
+from olive.evaluator.metric_result import (
+    MetricResult,
+    SubMetricResult,
+    flatten_metric_result,
+    flatten_metric_sub_type,
+    joint_metric_key,
+)
+
+
+class TestSubMetricResult:
+    def test_creation(self):
+        result = SubMetricResult(value=0.95, priority=1, higher_is_better=True)
+        assert result.value == 0.95
+        assert result.priority == 1
+        assert result.higher_is_better is True
+
+    def test_integer_value(self):
+        result = SubMetricResult(value=100, priority=2, higher_is_better=False)
+        assert result.value == 100
+
+    def test_float_value(self):
+        result = SubMetricResult(value=0.001, priority=0, higher_is_better=True)
+        assert result.value == 0.001
+
+
+class TestMetricResult:
+    def _make_result(self):
+        return MetricResult.model_validate(
+            {
+                "accuracy-top1": SubMetricResult(value=0.95, priority=1, higher_is_better=True),
+                "latency-avg": SubMetricResult(value=10.5, priority=2, higher_is_better=False),
+                "latency-p99": SubMetricResult(value=20.0, priority=3, higher_is_better=False),
+            }
+        )
+
+    def test_get_value(self):
+        result = self._make_result()
+        assert result.get_value("accuracy", "top1") == 0.95
+        assert result.get_value("latency", "avg") == 10.5
+
+    def test_get_all_sub_type_metric_value(self):
+        result = self._make_result()
+        latency_values = result.get_all_sub_type_metric_value("latency")
+        assert latency_values == {"avg": 10.5, "p99": 20.0}
+
+    def test_get_all_sub_type_single_metric(self):
+        result = self._make_result()
+        accuracy_values = result.get_all_sub_type_metric_value("accuracy")
+        assert accuracy_values == {"top1": 0.95}
+
+    def test_str_representation(self):
+        result = self._make_result()
+        result_str = str(result)
+        parsed = json.loads(result_str)
+        assert parsed["accuracy-top1"] == 0.95
+        assert parsed["latency-avg"] == 10.5
+
+    def test_len(self):
+        result = self._make_result()
+        assert len(result) == 3
+
+    def test_getitem(self):
+        result = self._make_result()
+        assert result["accuracy-top1"].value == 
0.95 + + def test_delimiter(self): + assert MetricResult.delimiter == "-" + + +class TestJointMetricKey: + def test_basic(self): + assert joint_metric_key("accuracy", "top1") == "accuracy-top1" + + def test_with_special_names(self): + assert joint_metric_key("latency", "p99") == "latency-p99" + + +class TestFlattenMetricSubType: + def test_flatten(self): + metric_dict = { + "accuracy": {"top1": {"value": 0.95, "priority": 1, "higher_is_better": True}}, + "latency": {"avg": {"value": 10.5, "priority": 2, "higher_is_better": False}}, + } + result = flatten_metric_sub_type(metric_dict) + assert "accuracy-top1" in result + assert "latency-avg" in result + + +class TestFlattenMetricResult: + def test_flatten_to_metric_result(self): + dict_results = { + "accuracy": { + "top1": {"value": 0.95, "priority": 1, "higher_is_better": True}, + }, + "latency": { + "avg": {"value": 10.5, "priority": 2, "higher_is_better": False}, + }, + } + result = flatten_metric_result(dict_results) + assert isinstance(result, MetricResult) + assert result.get_value("accuracy", "top1") == 0.95 + assert result.get_value("latency", "avg") == 10.5 diff --git a/test/search/test_search_parameter.py b/test/search/test_search_parameter.py new file mode 100644 index 0000000000..e3c05231d4 --- /dev/null +++ b/test/search/test_search_parameter.py @@ -0,0 +1,211 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import pytest + +from olive.search.search_parameter import ( + Boolean, + Categorical, + Conditional, + ConditionalDefault, + SpecialParamValue, + json_to_search_parameter, +) + + +class TestSpecialParamValue: + def test_ignored_value(self): + assert SpecialParamValue.IGNORED == "OLIVE_IGNORED_PARAM_VALUE" + + def test_invalid_value(self): + assert SpecialParamValue.INVALID == "OLIVE_INVALID_PARAM_VALUE" + + +class TestCategorical: + def test_int_support(self): + cat = Categorical([1, 2, 3]) + assert cat.get_support() == [1, 2, 3] + + def test_string_support(self): + cat = Categorical(["a", "b", "c"]) + assert cat.get_support() == ["a", "b", "c"] + + def test_float_support(self): + cat = Categorical([0.1, 0.5, 1.0]) + assert cat.get_support() == [0.1, 0.5, 1.0] + + def test_bool_support(self): + cat = Categorical([True, False]) + assert cat.get_support() == [True, False] + + def test_repr(self): + cat = Categorical([1, 2, 3]) + assert repr(cat) == "Categorical([1, 2, 3])" + + def test_to_json(self): + cat = Categorical([1, 2, 3]) + result = cat.to_json() + assert result["olive_parameter_type"] == "SearchParameter" + assert result["type"] == "Categorical" + assert result["support"] == [1, 2, 3] + + +class TestBoolean: + def test_support(self): + b = Boolean() + assert b.get_support() == [True, False] + + def test_is_categorical(self): + assert issubclass(Boolean, Categorical) + + +class TestConditional: + def test_single_parent(self): + cond = Conditional( + parents=("parent1",), + support={ + ("value1",): Categorical([1, 2, 3]), + ("value2",): Categorical([4, 5, 6]), + }, + default=Categorical([7, 8, 9]), + ) + assert cond.get_support_with_args({"parent1": "value1"}) == [1, 2, 3] + assert cond.get_support_with_args({"parent1": "value2"}) == [4, 5, 6] + assert cond.get_support_with_args({"parent1": "unknown"}) == [7, 8, 9] + + def test_multi_parent(self): + cond = Conditional( + parents=("parent1", "parent2"), + support={ 
+ ("v1", "v2"): Categorical([10, 20]), + }, + ) + assert cond.get_support_with_args({"parent1": "v1", "parent2": "v2"}) == [10, 20] + + def test_default_is_invalid(self): + cond = Conditional( + parents=("p",), + support={("a",): Categorical([1])}, + ) + # Default is invalid choice + assert cond.get_support_with_args({"p": "missing"}) == [SpecialParamValue.INVALID] + + def test_get_invalid_choice(self): + result = Conditional.get_invalid_choice() + assert isinstance(result, Categorical) + assert result.get_support() == [SpecialParamValue.INVALID] + + def test_get_ignored_choice(self): + result = Conditional.get_ignored_choice() + assert isinstance(result, Categorical) + assert result.get_support() == [SpecialParamValue.IGNORED] + + def test_repr(self): + cond = Conditional( + parents=("p",), + support={("a",): Categorical([1])}, + ) + result = repr(cond) + assert "Conditional" in result + assert "parents" in result + + def test_to_json(self): + cond = Conditional( + parents=("p",), + support={("a",): Categorical([1, 2])}, + default=Categorical([3]), + ) + result = cond.to_json() + assert result["olive_parameter_type"] == "SearchParameter" + assert result["type"] == "Conditional" + assert result["parents"] == ("p",) + + def test_condition_single_parent(self): + cond = Conditional( + parents=("p",), + support={ + ("a",): Categorical([1, 2]), + ("b",): Categorical([3, 4]), + }, + ) + result = cond.condition({"p": "a"}) + assert isinstance(result, Categorical) + assert result.get_support() == [1, 2] + + def test_condition_returns_default_when_no_match(self): + cond = Conditional( + parents=("p1", "p2"), + support={("a", "b"): Categorical([1])}, + default=Categorical([99]), + ) + result = cond.condition({"p1": "missing"}) + assert isinstance(result, Categorical) + assert result.get_support() == [99] + + +class TestConditionalDefault: + def test_basic(self): + cd = ConditionalDefault( + parents=("p",), + support={("a",): 1, ("b",): 2}, + default=3, + ) + assert cd.get_support_with_args({"p": "a"}) == 1 + assert cd.get_support_with_args({"p": "b"}) == 2 + assert cd.get_support_with_args({"p": "c"}) == 3 + + def test_default_invalid(self): + cd = ConditionalDefault( + parents=("p",), + support={("a",): 1}, + ) + assert cd.get_support_with_args({"p": "missing"}) == SpecialParamValue.INVALID + + def test_condition(self): + cd = ConditionalDefault( + parents=("p",), + support={("a",): 10, ("b",): 20}, + default=30, + ) + assert cd.condition({"p": "a"}) == 10 + assert cd.condition({"p": "b"}) == 20 + assert cd.condition({"p": "c"}) == 30 + + def test_get_invalid_choice(self): + assert ConditionalDefault.get_invalid_choice() == SpecialParamValue.INVALID + + def test_get_ignored_choice(self): + assert ConditionalDefault.get_ignored_choice() == SpecialParamValue.IGNORED + + def test_to_json(self): + cd = ConditionalDefault( + parents=("p",), + support={("a",): 1}, + default=2, + ) + result = cd.to_json() + assert result["type"] == "ConditionalDefault" + + def test_repr(self): + cd = ConditionalDefault( + parents=("p",), + support={("a",): 1}, + default=2, + ) + result = repr(cd) + assert "ConditionalDefault" in result + + +class TestJsonToSearchParameter: + def test_categorical(self): + json_data = {"olive_parameter_type": "SearchParameter", "type": "Categorical", "support": [1, 2, 3]} + result = json_to_search_parameter(json_data) + assert isinstance(result, Categorical) + assert result.get_support() == [1, 2, 3] + + def test_unknown_type_raises(self): + json_data = {"olive_parameter_type": 
"SearchParameter", "type": "Unknown"} + with pytest.raises(ValueError, match="Unknown search parameter type"): + json_to_search_parameter(json_data) diff --git a/test/search/test_search_point.py b/test/search/test_search_point.py new file mode 100644 index 0000000000..043bac234f --- /dev/null +++ b/test/search/test_search_point.py @@ -0,0 +1,83 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from collections import OrderedDict + +from olive.search.search_parameter import SpecialParamValue +from olive.search.search_point import SearchPoint + + +class TestSearchPoint: + def _make_point(self, index=0, values=None): + if values is None: + values = OrderedDict( + { + "pass1": ( + 0, + OrderedDict( + { + "param1": (0, "a"), + "param2": (1, 10), + } + ), + ) + } + ) + return SearchPoint(index=index, values=values) + + def test_creation(self): + point = self._make_point() + assert point.index == 0 + + def test_repr(self): + point = self._make_point() + result = repr(point) + assert "SearchPoint" in result + assert "0" in result + + def test_equality_same(self): + point1 = self._make_point() + point2 = self._make_point() + assert point1 == point2 + + def test_equality_different_index(self): + point1 = self._make_point(index=0) + point2 = self._make_point(index=1) + assert point1 != point2 + + def test_equality_different_type(self): + point = self._make_point() + assert point != "not a search point" + + def test_is_valid_true(self): + point = self._make_point() + assert point.is_valid() is True + + def test_is_valid_false_with_invalid(self): + # is_valid checks for OrderedDict values recursively, and checks + # non-OrderedDict values against SpecialParamValue.INVALID + values = OrderedDict( + { + "pass1": OrderedDict( + { + "param1": SpecialParamValue.INVALID, + } + ) + } + ) + point = SearchPoint(index=0, values=values) + assert point.is_valid() is False + + def test_to_json(self): + point = self._make_point(index=5) + result = point.to_json() + assert result["index"] == 5 + assert "values" in result + + def test_from_json_roundtrip(self): + point = self._make_point(index=3) + json_data = point.to_json() + restored = SearchPoint.from_json(json_data) + assert restored.index == 3 + assert restored == point diff --git a/test/search/test_search_sample.py b/test/search/test_search_sample.py new file mode 100644 index 0000000000..c233abb687 --- /dev/null +++ b/test/search/test_search_sample.py @@ -0,0 +1,77 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +from collections import OrderedDict + +from olive.search.search_parameter import SpecialParamValue +from olive.search.search_point import SearchPoint +from olive.search.search_sample import SearchSample + + +class TestSearchSample: + def _make_sample(self, invalid_param=False, ignored_param=False): + param_value = "a" + if invalid_param: + param_value = SpecialParamValue.INVALID + elif ignored_param: + param_value = SpecialParamValue.IGNORED + + values = OrderedDict( + { + "pass1": ( + 0, + OrderedDict( + { + "param1": (0, param_value), + "param2": (1, 10), + } + ), + ) + } + ) + point = SearchPoint(index=0, values=values) + return SearchSample(search_point=point, model_ids=["model_0"]) + + def test_creation(self): + sample = self._make_sample() + assert sample.model_ids == ["model_0"] + assert sample.search_point.index == 0 + + def test_passes_configs_valid(self): + sample = self._make_sample() + configs = sample.passes_configs + assert configs is not None + assert "pass1" in configs + assert configs["pass1"]["params"]["param1"] == "a" + assert configs["pass1"]["params"]["param2"] == 10 + + def test_passes_configs_with_invalid_returns_none(self): + sample = self._make_sample(invalid_param=True) + assert sample.passes_configs is None + + def test_passes_configs_with_ignored_excludes_param(self): + sample = self._make_sample(ignored_param=True) + configs = sample.passes_configs + assert configs is not None + assert "param1" not in configs["pass1"]["params"] + assert configs["pass1"]["params"]["param2"] == 10 + + def test_to_json(self): + sample = self._make_sample() + result = sample.to_json() + assert "search_point" in result + assert "model_ids" in result + assert result["model_ids"] == ["model_0"] + + def test_from_json_roundtrip(self): + sample = self._make_sample() + json_data = sample.to_json() + restored = SearchSample.from_json(json_data) + assert restored.model_ids == sample.model_ids + assert restored.search_point.index == sample.search_point.index + + def test_repr(self): + sample = self._make_sample() + result = repr(sample) + assert "SearchSample" in result diff --git a/test/telemetry/__init__.py b/test/telemetry/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/telemetry/test_telemetry_utils.py b/test/telemetry/test_telemetry_utils.py new file mode 100644 index 0000000000..4a412dd200 --- /dev/null +++ b/test/telemetry/test_telemetry_utils.py @@ -0,0 +1,107 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import base64 +import os +from pathlib import Path +from unittest.mock import patch + +import pytest + +from olive.telemetry.utils import ( + _decode_cache_line, + _encode_cache_line, + _format_exception_message, + _resolve_home_dir, + get_telemetry_base_dir, +) + + +class TestResolveHomeDir: + def test_returns_path(self): + result = _resolve_home_dir() + assert isinstance(result, Path) + + def test_with_home_env_set(self): + with patch.dict(os.environ, {"HOME": "/tmp/test_home"}): + result = _resolve_home_dir() + assert isinstance(result, Path) + + def test_without_home_env(self): + with patch.dict(os.environ, {}, clear=True): + result = _resolve_home_dir() + assert isinstance(result, Path) + + +class TestGetTelemetryBaseDir: + def test_returns_path(self): + # Clear the lru_cache before test + get_telemetry_base_dir.cache_clear() + result = get_telemetry_base_dir() + assert isinstance(result, Path) + + def test_path_contains_onnxruntime(self): + get_telemetry_base_dir.cache_clear() + result = get_telemetry_base_dir() + assert ".onnxruntime" in str(result) + + +class TestFormatExceptionMessage: + def test_basic_exception(self): + try: + 1 / 0 # noqa: B018 + except ZeroDivisionError as ex: + result = _format_exception_message(ex, ex.__traceback__) + assert "ZeroDivisionError" in result + + def test_exception_without_traceback(self): + ex = ValueError("test error") + result = _format_exception_message(ex) + assert "test error" in result + + +class TestEncodeCacheLine: + def test_encode_basic_string(self): + result = _encode_cache_line("hello") + expected = base64.b64encode(b"hello").decode("ascii") + assert result == expected + + def test_encode_empty_string(self): + result = _encode_cache_line("") + expected = base64.b64encode(b"").decode("ascii") + assert result == expected + + def test_encode_unicode_string(self): + result = _encode_cache_line("hello 世界") + decoded = base64.b64decode(result).decode("utf-8") + assert decoded == "hello 世界" + + +class TestDecodeCacheLine: + def test_decode_basic_string(self): + encoded = base64.b64encode(b"hello").decode("ascii") + result = _decode_cache_line(encoded) + assert result == "hello" + + def test_decode_empty_string(self): + encoded = base64.b64encode(b"").decode("ascii") + result = _decode_cache_line(encoded) + assert result == "" + + +class TestEncodeDecodeRoundtrip: + @pytest.mark.parametrize( + "text", + [ + "simple text", + "path/to/file.json", + '{"key": "value"}', + "special chars: !@#$%^&*()", + "", + ], + ) + def test_roundtrip(self, text): + encoded = _encode_cache_line(text) + decoded = _decode_cache_line(encoded) + assert decoded == text diff --git a/test/test_constants.py b/test/test_constants.py new file mode 100644 index 0000000000..d7d06fa596 --- /dev/null +++ b/test/test_constants.py @@ -0,0 +1,152 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import pytest + +from olive.constants import ( + MSFT_DOMAIN, + AccuracyLevel, + DatasetRequirement, + DiffusersComponent, + DiffusersModelVariant, + Framework, + ModelFileFormat, + OpType, + Precision, + PrecisionBits, + QuantAlgorithm, + QuantEncoding, + precision_bits_from_precision, +) + + +class TestFramework: + def test_framework_values(self): + assert Framework.ONNX == "ONNX" + assert Framework.PYTORCH == "PyTorch" + assert Framework.OPENVINO == "OpenVINO" + + def test_framework_str(self): + assert str(Framework.ONNX) == "ONNX" + + def test_framework_all_members(self): + expected = {"ONNX", "PYTORCH", "QAIRT", "QNN", "TENSORFLOW", "OPENVINO"} + assert set(Framework.__members__.keys()) == expected + + +class TestModelFileFormat: + def test_model_file_format_values(self): + assert ModelFileFormat.ONNX == "ONNX" + assert ModelFileFormat.PYTORCH_STATE_DICT == "PyTorch.StateDict" + assert ModelFileFormat.COMPOSITE_MODEL == "Composite" + + def test_model_file_format_str(self): + assert str(ModelFileFormat.OPENVINO_IR) == "OpenVINO.IR" + + +class TestPrecision: + def test_precision_values(self): + assert Precision.INT4 == "int4" + assert Precision.FP16 == "fp16" + assert Precision.BF16 == "bf16" + + def test_precision_all_members(self): + expected_count = 14 + assert len(Precision) == expected_count + + +class TestPrecisionBits: + def test_precision_bits_values(self): + assert PrecisionBits.BITS2 == 2 + assert PrecisionBits.BITS4 == 4 + assert PrecisionBits.BITS8 == 8 + assert PrecisionBits.BITS16 == 16 + assert PrecisionBits.BITS32 == 32 + + def test_precision_bits_is_int(self): + assert isinstance(PrecisionBits.BITS4.value, int) + + +class TestQuantAlgorithm: + def test_quant_algorithm_case_insensitive(self): + assert QuantAlgorithm("awq") == QuantAlgorithm.AWQ + assert QuantAlgorithm("AWQ") == QuantAlgorithm.AWQ + assert QuantAlgorithm("Awq") == QuantAlgorithm.AWQ + + def test_quant_algorithm_values(self): + assert QuantAlgorithm.GPTQ == "gptq" + assert QuantAlgorithm.RTN == "rtn" + + def test_quant_algorithm_all_members(self): + expected = {"AWQ", "GPTQ", "HQQ", "RTN", "SPINQUANT", "QUAROT", "LPBQ", "SEQMSE", "ADAROUND"} + assert set(QuantAlgorithm.__members__.keys()) == expected + + +class TestQuantEncoding: + def test_quant_encoding_values(self): + assert QuantEncoding.QDQ == "qdq" + assert QuantEncoding.QOP == "qop" + + +class TestDatasetRequirement: + def test_dataset_requirement_values(self): + assert DatasetRequirement.REQUIRED == "dataset_required" + assert DatasetRequirement.OPTIONAL == "dataset_optional" + assert DatasetRequirement.NOT_REQUIRED == "dataset_not_required" + + +class TestOpType: + def test_op_type_values(self): + assert OpType.MatMul == "MatMul" + assert OpType.Add == "Add" + assert OpType.Custom == "custom" + + +class TestAccuracyLevel: + def test_accuracy_level_values(self): + assert AccuracyLevel.unset == 0 + assert AccuracyLevel.fp32 == 1 + assert AccuracyLevel.fp16 == 2 + assert AccuracyLevel.int8 == 4 + + +class TestDiffusersModelVariant: + def test_diffusers_variant_values(self): + assert DiffusersModelVariant.AUTO == "auto" + assert DiffusersModelVariant.SD == "sd" + assert DiffusersModelVariant.FLUX == "flux" + + +class TestDiffusersComponent: + def test_diffusers_component_values(self): + assert DiffusersComponent.TEXT_ENCODER == "text_encoder" + assert DiffusersComponent.UNET == "unet" + assert DiffusersComponent.VAE_DECODER == "vae_decoder" + + +class 
TestPrecisionBitsFromPrecision: + @pytest.mark.parametrize( + ("precision", "expected"), + [ + (Precision.INT4, PrecisionBits.BITS4), + (Precision.INT8, PrecisionBits.BITS8), + (Precision.INT16, PrecisionBits.BITS16), + (Precision.INT32, PrecisionBits.BITS32), + (Precision.UINT4, PrecisionBits.BITS4), + (Precision.UINT8, PrecisionBits.BITS8), + (Precision.UINT16, PrecisionBits.BITS16), + (Precision.UINT32, PrecisionBits.BITS32), + ], + ) + def test_precision_to_bits_mapping(self, precision, expected): + assert precision_bits_from_precision(precision) == expected + + @pytest.mark.parametrize("precision", [Precision.FP16, Precision.FP32, Precision.BF16, Precision.NF4]) + def test_precision_without_bits_mapping_returns_none(self, precision): + assert precision_bits_from_precision(precision) is None + + +class TestMsftDomain: + def test_msft_domain_value(self): + assert MSFT_DOMAIN == "com.microsoft" diff --git a/test/test_exception.py b/test/test_exception.py new file mode 100644 index 0000000000..088714eaa5 --- /dev/null +++ b/test/test_exception.py @@ -0,0 +1,60 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import pytest + +from olive.exception import EXCEPTIONS_TO_RAISE, OliveError, OliveEvaluationError, OlivePassError + + +class TestOliveError: + def test_olive_error_is_exception(self): + assert issubclass(OliveError, Exception) + + def test_olive_error_can_be_raised(self): + with pytest.raises(OliveError, match="test error"): + raise OliveError("test error") + + def test_olive_error_empty_message(self): + with pytest.raises(OliveError): + raise OliveError + + +class TestOlivePassError: + def test_olive_pass_error_inherits_olive_error(self): + assert issubclass(OlivePassError, OliveError) + + def test_olive_pass_error_can_be_raised(self): + with pytest.raises(OlivePassError, match="pass failed"): + raise OlivePassError("pass failed") + + def test_olive_pass_error_caught_as_olive_error(self): + with pytest.raises(OliveError): + raise OlivePassError("pass failed") + + +class TestOliveEvaluationError: + def test_olive_evaluation_error_inherits_olive_error(self): + assert issubclass(OliveEvaluationError, OliveError) + + def test_olive_evaluation_error_can_be_raised(self): + with pytest.raises(OliveEvaluationError, match="evaluation failed"): + raise OliveEvaluationError("evaluation failed") + + def test_olive_evaluation_error_caught_as_olive_error(self): + with pytest.raises(OliveError): + raise OliveEvaluationError("evaluation failed") + + +class TestExceptionsToRaise: + def test_exceptions_to_raise_is_tuple(self): + assert isinstance(EXCEPTIONS_TO_RAISE, tuple) + + def test_exceptions_to_raise_contains_expected_types(self): + expected = {AssertionError, AttributeError, ImportError, TypeError, ValueError} + assert set(EXCEPTIONS_TO_RAISE) == expected + + @pytest.mark.parametrize("exc_type", EXCEPTIONS_TO_RAISE) + def test_each_exception_is_catchable(self, exc_type): + with pytest.raises(exc_type): + raise exc_type("test") diff --git a/test/test_logging.py b/test/test_logging.py new file mode 100644 index 0000000000..0894a07dc1 --- /dev/null +++ b/test/test_logging.py @@ -0,0 +1,127 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import logging +from pathlib import Path +from unittest.mock import patch + +import pytest + +from olive.logging import ( + enable_filelog, + get_logger_level, + get_olive_logger, + get_verbosity, + set_default_logger_severity, + set_verbosity, + set_verbosity_critical, + set_verbosity_debug, + set_verbosity_error, + set_verbosity_from_env, + set_verbosity_info, + set_verbosity_warning, +) + + +class TestGetOliveLogger: + def test_returns_olive_logger(self): + logger = get_olive_logger() + assert logger.name == "olive" + + def test_returns_same_logger_instance(self): + logger1 = get_olive_logger() + logger2 = get_olive_logger() + assert logger1 is logger2 + + +class TestSetVerbosity: + def test_set_verbosity_info(self): + set_verbosity_info() + assert get_olive_logger().level == logging.INFO + + def test_set_verbosity_warning(self): + set_verbosity_warning() + assert get_olive_logger().level == logging.WARNING + + def test_set_verbosity_debug(self): + set_verbosity_debug() + assert get_olive_logger().level == logging.DEBUG + + def test_set_verbosity_error(self): + set_verbosity_error() + assert get_olive_logger().level == logging.ERROR + + def test_set_verbosity_critical(self): + set_verbosity_critical() + assert get_olive_logger().level == logging.CRITICAL + + def test_set_verbosity_custom_level(self): + set_verbosity(logging.WARNING) + assert get_olive_logger().level == logging.WARNING + + +class TestSetVerbosityFromEnv: + def test_set_verbosity_from_env_default(self): + with patch.dict("os.environ", {}, clear=True): + set_verbosity_from_env() + + def test_set_verbosity_from_env_custom(self): + with patch.dict("os.environ", {"OLIVE_LOG_LEVEL": "DEBUG"}): + set_verbosity_from_env() + assert get_olive_logger().level == logging.DEBUG + + +class TestGetVerbosity: + def test_get_verbosity_returns_int(self): + set_verbosity_info() + level = get_verbosity() + assert isinstance(level, int) + assert level == logging.INFO + + +class TestGetLoggerLevel: + @pytest.mark.parametrize( + ("level_int", "expected"), + [ + (0, logging.DEBUG), + (1, logging.INFO), + (2, logging.WARNING), + (3, logging.ERROR), + (4, logging.CRITICAL), + ], + ) + def test_valid_levels(self, level_int, expected): + assert get_logger_level(level_int) == expected + + @pytest.mark.parametrize("invalid_level", [-1, 5, 10, 100]) + def test_invalid_levels_raise_value_error(self, invalid_level): + with pytest.raises(ValueError, match="Invalid level"): + get_logger_level(invalid_level) + + +class TestSetDefaultLoggerSeverity: + @pytest.mark.parametrize("level", [0, 1, 2, 3, 4]) + def test_set_default_logger_severity(self, level): + set_default_logger_severity(level) + expected = get_logger_level(level) + assert get_olive_logger().level == expected + + +class TestEnableFilelog: + def test_enable_filelog_creates_handler(self, tmp_path): + workflow_id = "test_workflow" + enable_filelog(1, str(tmp_path), workflow_id) + + logger = get_olive_logger() + log_file_path = tmp_path / f"{workflow_id}.log" + + # Check that a file handler was added + file_handlers = [h for h in logger.handlers if isinstance(h, logging.FileHandler)] + assert len(file_handlers) > 0 + + # Clean up: remove the handler we added + for h in file_handlers: + if Path(h.baseFilename) == log_file_path.resolve(): + logger.removeHandler(h) + h.close() From fc72c734cea535fea493da00a19445de892e0c4b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> 
Date: Fri, 10 Apr 2026 00:00:28 +0000 Subject: [PATCH 2/6] Improve comment clarity in test_search_point.py per code review Agent-Logs-Url: https://github.com/microsoft/Olive/sessions/656d3838-fc06-4d40-8a11-b75da7f23724 Co-authored-by: xiaoyu-work <85524621+xiaoyu-work@users.noreply.github.com> --- test/search/test_search_point.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/search/test_search_point.py b/test/search/test_search_point.py index 043bac234f..60a66e6bc9 100644 --- a/test/search/test_search_point.py +++ b/test/search/test_search_point.py @@ -55,8 +55,7 @@ def test_is_valid_true(self): assert point.is_valid() is True def test_is_valid_false_with_invalid(self): - # is_valid checks for OrderedDict values recursively, and checks - # non-OrderedDict values against SpecialParamValue.INVALID + # Verify that a SearchPoint with an invalid parameter value is detected as invalid values = OrderedDict( { "pass1": OrderedDict( From 142c767ee25c689a08f079e32b301db4d918d54f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 10 Apr 2026 03:30:13 +0000 Subject: [PATCH 3/6] Add # setup / # execute / # assert section comments to all new test methods Agent-Logs-Url: https://github.com/microsoft/Olive/sessions/6dac7c0c-a4f5-4649-9a9a-1d9237f27a3d Co-authored-by: xiaoyu-work <85524621+xiaoyu-work@users.noreply.github.com> --- test/common/test_config_utils.py | 311 +++++++++++++++++++++++-- test/data_container/test_registry.py | 50 ++++ test/evaluator/test_metric_config.py | 129 +++++++++- test/evaluator/test_metric_result.py | 91 +++++++- test/search/test_search_parameter.py | 208 +++++++++++++++-- test/search/test_search_point.py | 61 ++++- test/search/test_search_sample.py | 38 ++- test/telemetry/test_telemetry_utils.py | 84 ++++++- test/test_constants.py | 212 ++++++++++++++--- test/test_exception.py | 61 ++++- test/test_logging.py | 81 ++++++- 11 files changed, 1204 insertions(+), 122 deletions(-) diff --git a/test/common/test_config_utils.py b/test/common/test_config_utils.py index e053fc7452..cb0ac3010b 100644 --- a/test/common/test_config_utils.py +++ b/test/common/test_config_utils.py @@ -31,10 +31,14 @@ class TestSerializeFunction: def test_serialize_function_returns_dict(self): + # setup def my_func(x): return x + # execute result = serialize_function(my_func) + + # assert assert result["olive_parameter_type"] == "Function" assert result["name"] == "my_func" assert "signature" in result @@ -43,8 +47,13 @@ def my_func(x): class TestSerializeObject: def test_serialize_object_returns_dict(self): + # setup obj = {"key": "value"} + + # execute result = serialize_object(obj) + + # assert assert result["olive_parameter_type"] == "Object" assert result["type"] == "dict" assert "hash" in result @@ -52,287 +61,521 @@ def test_serialize_object_returns_dict(self): class TestConfigJsonDumps: def test_basic_dict(self): + # setup data = {"key": "value", "num": 42} + + # execute result = config_json_dumps(data) parsed = json.loads(result) + + # assert assert parsed == data def test_path_serialization_absolute(self): + # setup data = {"path": Path("/some/path")} + + # execute result = config_json_dumps(data, make_absolute=True) parsed = json.loads(result) + + # assert assert isinstance(parsed["path"], str) def test_path_serialization_relative(self): + # setup data = {"path": Path("relative/path")} + + # execute result = config_json_dumps(data, make_absolute=False) parsed = json.loads(result) + + # assert assert 
parsed["path"] == "relative/path" def test_function_serialization(self): + # setup def sample_func(): pass data = {"func": sample_func} + + # execute result = config_json_dumps(data) parsed = json.loads(result) + + # assert assert parsed["func"]["olive_parameter_type"] == "Function" class TestConfigJsonLoads: def test_basic_json(self): + # setup data = '{"key": "value"}' + + # execute result = config_json_loads(data) + + # assert assert result == {"key": "value"} def test_function_object_raises_error(self): + # setup data = json.dumps({"olive_parameter_type": "Function", "name": "my_func"}) + + # execute & assert with pytest.raises(ValueError, match="Cannot load"): config_json_loads(data) def test_custom_object_hook(self): + # setup data = '{"key": "value"}' + + # execute result = config_json_loads(data, object_hook=lambda obj: obj) + + # assert assert result == {"key": "value"} class TestSerializeToJson: def test_dict_input(self): + # setup data = {"key": "value"} + + # execute result = serialize_to_json(data) + + # assert assert result == data def test_config_base_input(self): + # setup config = ConfigBase() + + # execute result = serialize_to_json(config) + + # assert assert isinstance(result, dict) def test_check_object_with_function_raises(self): + # setup def my_func(): pass + # execute & assert with pytest.raises(ValueError, match="Cannot serialize"): serialize_to_json({"func": my_func}, check_object=True) class TestLoadConfigFile: def test_load_json_file(self, tmp_path): + # setup config_file = tmp_path / "config.json" config_file.write_text('{"key": "value"}') + + # execute result = load_config_file(config_file) + + # assert assert result == {"key": "value"} def test_load_yaml_file(self, tmp_path): + # setup config_file = tmp_path / "config.yaml" config_file.write_text("key: value\n") + + # execute result = load_config_file(config_file) + + # assert assert result == {"key": "value"} def test_load_yml_file(self, tmp_path): + # setup config_file = tmp_path / "config.yml" config_file.write_text("key: value\n") + + # execute result = load_config_file(config_file) + + # assert assert result == {"key": "value"} def test_unsupported_file_type_raises(self, tmp_path): + # setup config_file = tmp_path / "config.txt" config_file.write_text("key=value") + + # execute & assert with pytest.raises(ValueError, match="Unsupported file type"): load_config_file(config_file) class TestConfigBase: def test_to_json(self): + # setup config = ConfigBase() + + # execute result = config.to_json() + + # assert assert isinstance(result, dict) def test_from_json(self): + # setup + + # execute config = ConfigBase.from_json({}) + + # assert assert isinstance(config, ConfigBase) def test_parse_file_or_obj_dict(self): + # setup + + # execute config = ConfigBase.parse_file_or_obj({}) + + # assert assert isinstance(config, ConfigBase) def test_parse_file_or_obj_json_file(self, tmp_path): + # setup config_file = tmp_path / "config.json" config_file.write_text("{}") + + # execute config = ConfigBase.parse_file_or_obj(config_file) + + # assert assert isinstance(config, ConfigBase) class TestConfigListBase: def test_iter(self): + # setup config = ConfigListBase.model_validate([1, 2, 3]) - assert list(config) == [1, 2, 3] + + # execute + result = list(config) + + # assert + assert result == [1, 2, 3] def test_getitem(self): + # setup config = ConfigListBase.model_validate([10, 20, 30]) - assert config[0] == 10 - assert config[2] == 30 + + # execute + first = config[0] + last = config[2] + + # assert + assert first == 10 
+ assert last == 30 def test_len(self): + # setup config = ConfigListBase.model_validate([1, 2, 3]) - assert len(config) == 3 + + # execute + result = len(config) + + # assert + assert result == 3 def test_to_json(self): + # setup config = ConfigListBase.model_validate([1, 2, 3]) + + # execute result = config.to_json() + + # assert assert isinstance(result, list) class TestConfigDictBase: def test_iter(self): + # setup config = ConfigDictBase.model_validate({"a": 1, "b": 2}) - assert set(config) == {"a", "b"} + + # execute + result = set(config) + + # assert + assert result == {"a", "b"} def test_keys(self): + # setup config = ConfigDictBase.model_validate({"a": 1, "b": 2}) - assert set(config.keys()) == {"a", "b"} + + # execute + result = set(config.keys()) + + # assert + assert result == {"a", "b"} def test_values(self): + # setup config = ConfigDictBase.model_validate({"a": 1, "b": 2}) - assert set(config.values()) == {1, 2} + + # execute + result = set(config.values()) + + # assert + assert result == {1, 2} def test_items(self): + # setup config = ConfigDictBase.model_validate({"a": 1, "b": 2}) - assert dict(config.items()) == {"a": 1, "b": 2} + + # execute + result = dict(config.items()) + + # assert + assert result == {"a": 1, "b": 2} def test_getitem(self): + # setup config = ConfigDictBase.model_validate({"key": "value"}) - assert config["key"] == "value" + + # execute + result = config["key"] + + # assert + assert result == "value" def test_len(self): + # setup config = ConfigDictBase.model_validate({"a": 1, "b": 2}) - assert len(config) == 2 + + # execute + result = len(config) + + # assert + assert result == 2 def test_len_empty(self): + # setup config = ConfigDictBase.model_validate({}) - assert len(config) == 0 + + # execute + result = len(config) + + # assert + assert result == 0 class TestNestedConfig: def test_gather_nested_field_basic(self): + # setup class MyConfig(NestedConfig): type: str config: dict = Field(default_factory=dict) + # execute c = MyConfig(type="test", key1="val1") + + # assert assert c.config == {"key1": "val1"} def test_gather_nested_field_none_values(self): + # setup class MyConfig(NestedConfig): type: str = "default" config: dict = Field(default_factory=dict) + # execute c = MyConfig.model_validate(None) + + # assert assert c.type == "default" def test_gather_nested_field_explicit_config(self): + # setup class MyConfig(NestedConfig): type: str config: dict = Field(default_factory=dict) + # execute c = MyConfig(type="test", config={"inner_key": "inner_val"}) + + # assert assert c.config == {"inner_key": "inner_val"} class TestCaseInsensitiveEnum: def test_case_insensitive_creation(self): - assert ParamCategory("none") == ParamCategory.NONE - assert ParamCategory("NONE") == ParamCategory.NONE - assert ParamCategory("None") == ParamCategory.NONE + # setup + + # execute + lower = ParamCategory("none") + upper = ParamCategory("NONE") + mixed = ParamCategory("None") + + # assert + assert lower == ParamCategory.NONE + assert upper == ParamCategory.NONE + assert mixed == ParamCategory.NONE def test_invalid_value_returns_none(self): + # setup + + # execute result = ParamCategory._missing_("nonexistent") + + # assert assert result is None class TestConfigParam: def test_config_param_defaults(self): + # setup + + # execute param = ConfigParam(type_=str) + + # assert assert param.required is False assert param.default_value is None assert param.category == ParamCategory.NONE def test_config_param_required(self): + # setup + + # execute param = 
ConfigParam(type_=str, required=True) + + # assert assert param.required is True def test_config_param_repr(self): + # setup param = ConfigParam(type_=str, required=True, description="A test param") + + # execute repr_str = repr(param) + + # assert assert "required=True" in repr_str assert "description=" in repr_str class TestValidateEnum: def test_valid_enum_value(self): + # setup + + # execute result = validate_enum(ParamCategory, "none") + + # assert assert result == ParamCategory.NONE def test_invalid_enum_value_raises(self): + # setup + + # execute & assert with pytest.raises(ValueError, match="Invalid value"): validate_enum(ParamCategory, "invalid_value") class TestValidateLowercase: def test_string_lowercased(self): - assert validate_lowercase("HELLO") == "hello" + # setup + + # execute + result = validate_lowercase("HELLO") + + # assert + assert result == "hello" def test_non_string_unchanged(self): - assert validate_lowercase(42) == 42 - assert validate_lowercase(None) is None + # setup + + # execute + result_int = validate_lowercase(42) + result_none = validate_lowercase(None) + + # assert + assert result_int == 42 + assert result_none is None class TestCreateConfigClass: def test_create_basic_config_class(self): + # setup config = { "name": ConfigParam(type_=str, required=True), "value": ConfigParam(type_=int, default_value=10), } + + # execute cls = create_config_class("TestConfig", config) instance = cls(name="test") + + # assert assert instance.name == "test" assert instance.value == 10 def test_create_config_class_with_optional_field(self): + # setup config = { "name": ConfigParam(type_=str, default_value=None), } + + # execute cls = create_config_class("OptionalConfig", config) instance = cls() + + # assert assert instance.name is None class TestValidateConfig: def test_validate_dict_config(self): + # setup config = {"name": "test"} class MyConfig(ConfigBase): name: str + # execute result = validate_config(config, MyConfig) + + # assert assert isinstance(result, MyConfig) assert result.name == "test" def test_validate_none_config(self): + # setup class MyConfig(ConfigBase): name: str = "default" + # execute result = validate_config(None, MyConfig) + + # assert assert result.name == "default" def test_validate_config_instance(self): + # setup class MyConfig(ConfigBase): name: str config = MyConfig(name="test") + + # execute result = validate_config(config, MyConfig) + + # assert assert result.name == "test" def test_validate_config_wrong_class_raises(self): + # setup class MyConfig(ConfigBase): name: str @@ -340,24 +583,50 @@ class OtherConfig(ConfigBase): value: int config = OtherConfig(value=42) + + # execute & assert with pytest.raises(ValueError, match="Invalid config class"): validate_config(config, MyConfig) class TestConvertConfigsToDicts: def test_config_base_to_dict(self): + # setup config = ConfigBase() + + # execute result = convert_configs_to_dicts(config) + + # assert assert isinstance(result, dict) def test_nested_dict_conversion(self): - result = convert_configs_to_dicts({"key": "value"}) + # setup + data = {"key": "value"} + + # execute + result = convert_configs_to_dicts(data) + + # assert assert result == {"key": "value"} def test_list_conversion(self): - result = convert_configs_to_dicts(["a", "b"]) + # setup + data = ["a", "b"] + + # execute + result = convert_configs_to_dicts(data) + + # assert assert result == ["a", "b"] def test_plain_value_passthrough(self): - assert convert_configs_to_dicts(42) == 42 - assert convert_configs_to_dicts("hello") == 
"hello" + # setup + + # execute + result_int = convert_configs_to_dicts(42) + result_str = convert_configs_to_dicts("hello") + + # assert + assert result_int == 42 + assert result_str == "hello" diff --git a/test/data_container/test_registry.py b/test/data_container/test_registry.py index a4cb4de4cd..c55e8243ce 100644 --- a/test/data_container/test_registry.py +++ b/test/data_container/test_registry.py @@ -12,95 +12,145 @@ class TestRegistryRegister: def test_register_dataset_component(self): + # setup & execute @Registry.register(DataComponentType.LOAD_DATASET, name="test_dataset_reg") def my_dataset(): return "dataset" + # assert result = Registry.get_load_dataset_component("test_dataset_reg") assert result is my_dataset def test_register_pre_process_component(self): + # setup & execute @Registry.register_pre_process(name="test_pre_process_reg") def my_pre_process(data): return data + # assert result = Registry.get_pre_process_component("test_pre_process_reg") assert result is my_pre_process def test_register_post_process_component(self): + # setup & execute @Registry.register_post_process(name="test_post_process_reg") def my_post_process(data): return data + # assert result = Registry.get_post_process_component("test_post_process_reg") assert result is my_post_process def test_register_dataloader_component(self): + # setup & execute @Registry.register_dataloader(name="test_dataloader_reg") def my_dataloader(data): return data + # assert result = Registry.get_dataloader_component("test_dataloader_reg") assert result is my_dataloader def test_register_case_insensitive(self): + # setup & execute @Registry.register(DataComponentType.LOAD_DATASET, name="CaseSensitiveTest_Reg") def my_func(): pass + # assert result = Registry.get_load_dataset_component("casesensitivetest_reg") assert result is my_func def test_register_uses_class_name_when_no_name(self): + # setup & execute @Registry.register(DataComponentType.LOAD_DATASET) def unique_named_test_func_reg(): pass + # assert result = Registry.get_load_dataset_component("unique_named_test_func_reg") assert result is unique_named_test_func_reg class TestRegistryGet: def test_get_component(self): + # setup @Registry.register(DataComponentType.LOAD_DATASET, name="test_get_comp_reg") def my_func(): pass + # execute result = Registry.get_component(DataComponentType.LOAD_DATASET.value, "test_get_comp_reg") + + # assert assert result is my_func def test_get_by_subtype(self): + # setup @Registry.register(DataComponentType.LOAD_DATASET, name="test_get_subtype_reg") def my_func(): pass + # execute result = Registry.get(DataComponentType.LOAD_DATASET.value, "test_get_subtype_reg") + + # assert assert result is my_func class TestRegistryDefaultComponents: def test_get_default_load_dataset(self): + # setup + + # execute result = Registry.get_default_load_dataset_component() + + # assert assert result is not None def test_get_default_pre_process(self): + # setup + + # execute result = Registry.get_default_pre_process_component() + + # assert assert result is not None def test_get_default_post_process(self): + # setup + + # execute result = Registry.get_default_post_process_component() + + # assert assert result is not None def test_get_default_dataloader(self): + # setup + + # execute result = Registry.get_default_dataloader_component() + + # assert assert result is not None class TestRegistryContainer: def test_get_container_default(self): + # setup + + # execute result = Registry.get_container(None) + + # assert assert result is not None def 
test_get_container_by_name(self): + # setup + + # execute result = Registry.get_container(DefaultDataContainer.DATA_CONTAINER.value) + + # assert assert result is not None diff --git a/test/evaluator/test_metric_config.py b/test/evaluator/test_metric_config.py index 3b3f1f364c..381958a57c 100644 --- a/test/evaluator/test_metric_config.py +++ b/test/evaluator/test_metric_config.py @@ -17,13 +17,23 @@ class TestLatencyMetricConfig: def test_defaults(self): + # setup + + # execute config = LatencyMetricConfig() + + # assert assert config.warmup_num == 10 assert config.repeat_test_num == 20 assert config.sleep_num == 0 def test_custom_values(self): + # setup + + # execute config = LatencyMetricConfig(warmup_num=5, repeat_test_num=100, sleep_num=2) + + # assert assert config.warmup_num == 5 assert config.repeat_test_num == 100 assert config.sleep_num == 2 @@ -31,13 +41,23 @@ def test_custom_values(self): class TestThroughputMetricConfig: def test_defaults(self): + # setup + + # execute config = ThroughputMetricConfig() + + # assert assert config.warmup_num == 10 assert config.repeat_test_num == 20 assert config.sleep_num == 0 def test_custom_values(self): + # setup + + # execute config = ThroughputMetricConfig(warmup_num=3, repeat_test_num=50, sleep_num=1) + + # assert assert config.warmup_num == 3 assert config.repeat_test_num == 50 assert config.sleep_num == 1 @@ -45,93 +65,188 @@ def test_custom_values(self): class TestSizeOnDiskMetricConfig: def test_creation(self): + # setup + + # execute config = SizeOnDiskMetricConfig() + + # assert assert isinstance(config, SizeOnDiskMetricConfig) class TestMetricGoal: def test_threshold_type(self): + # setup + + # execute goal = MetricGoal(type="threshold", value=0.9) + + # assert assert goal.type == "threshold" assert goal.value == 0.9 def test_min_improvement_type(self): + # setup + + # execute goal = MetricGoal(type="min-improvement", value=0.05) + + # assert assert goal.type == "min-improvement" assert goal.value == 0.05 def test_max_degradation_type(self): + # setup + + # execute goal = MetricGoal(type="max-degradation", value=0.1) + + # assert assert goal.type == "max-degradation" assert goal.value == 0.1 def test_percent_min_improvement_type(self): + # setup + + # execute goal = MetricGoal(type="percent-min-improvement", value=5.0) + + # assert assert goal.type == "percent-min-improvement" def test_percent_max_degradation_type(self): + # setup + + # execute goal = MetricGoal(type="percent-max-degradation", value=10.0) + + # assert assert goal.type == "percent-max-degradation" def test_invalid_type_raises(self): + # setup + + # execute & assert with pytest.raises(ValidationError, match="Metric goal type must be one of"): MetricGoal(type="invalid_type", value=0.5) def test_negative_value_for_min_improvement_raises(self): + # setup + + # execute & assert with pytest.raises(ValidationError, match="Value must be nonnegative"): MetricGoal(type="min-improvement", value=-0.5) def test_negative_value_for_max_degradation_raises(self): + # setup + + # execute & assert with pytest.raises(ValidationError, match="Value must be nonnegative"): MetricGoal(type="max-degradation", value=-0.1) def test_negative_value_for_percent_min_improvement_raises(self): + # setup + + # execute & assert with pytest.raises(ValidationError, match="Value must be nonnegative"): MetricGoal(type="percent-min-improvement", value=-5.0) def test_negative_value_for_percent_max_degradation_raises(self): + # setup + + # execute & assert with pytest.raises(ValidationError, match="Value must 
             MetricGoal(type="percent-max-degradation", value=-10.0)

     def test_threshold_allows_negative_value(self):
+        # setup
+
+        # execute
         goal = MetricGoal(type="threshold", value=-1.0)
+
+        # assert
         assert goal.value == -1.0

     def test_has_regression_goal_min_improvement(self):
+        # setup
         goal = MetricGoal(type="min-improvement", value=0.05)
-        assert goal.has_regression_goal() is False
+
+        # execute
+        result = goal.has_regression_goal()
+
+        # assert
+        assert result is False

     def test_has_regression_goal_percent_min_improvement(self):
+        # setup
         goal = MetricGoal(type="percent-min-improvement", value=5.0)
-        assert goal.has_regression_goal() is False
+
+        # execute
+        result = goal.has_regression_goal()
+
+        # assert
+        assert result is False

     def test_has_regression_goal_max_degradation_positive(self):
+        # setup
         goal = MetricGoal(type="max-degradation", value=0.1)
-        assert goal.has_regression_goal() is True
+
+        # execute
+        result = goal.has_regression_goal()
+
+        # assert
+        assert result is True

     def test_has_regression_goal_max_degradation_zero(self):
+        # setup
         goal = MetricGoal(type="max-degradation", value=0.0)
-        assert goal.has_regression_goal() is False
+
+        # execute
+        result = goal.has_regression_goal()
+
+        # assert
+        assert result is False

     def test_has_regression_goal_percent_max_degradation_positive(self):
+        # setup
         goal = MetricGoal(type="percent-max-degradation", value=10.0)
-        assert goal.has_regression_goal() is True
+
+        # execute
+        result = goal.has_regression_goal()
+
+        # assert
+        assert result is True

     def test_has_regression_goal_threshold(self):
+        # setup
         goal = MetricGoal(type="threshold", value=0.9)
-        assert goal.has_regression_goal() is False
+
+        # execute
+        result = goal.has_regression_goal()
+
+        # assert
+        assert result is False


 class TestGetUserConfigClass:
     def test_custom_metric_type(self):
+        # setup
+
+        # execute
         cls = get_user_config_class("custom")
         instance = cls()
+
+        # assert
         assert hasattr(instance, "user_script")
         assert hasattr(instance, "evaluate_func")

     def test_unknown_metric_type(self):
+        # setup
+
+        # execute
         cls = get_user_config_class("latency")
         instance = cls()
+
+        # assert
         assert hasattr(instance, "user_script")
-        # Unknown metric types still get common config
         assert hasattr(instance, "inference_settings")
diff --git a/test/evaluator/test_metric_result.py b/test/evaluator/test_metric_result.py
index 89949476c7..396936fbf6 100644
--- a/test/evaluator/test_metric_result.py
+++ b/test/evaluator/test_metric_result.py
@@ -15,17 +15,32 @@

 class TestSubMetricResult:
     def test_creation(self):
+        # setup
+
+        # execute
         result = SubMetricResult(value=0.95, priority=1, higher_is_better=True)
+
+        # assert
         assert result.value == 0.95
         assert result.priority == 1
         assert result.higher_is_better is True

     def test_integer_value(self):
+        # setup
+
+        # execute
         result = SubMetricResult(value=100, priority=2, higher_is_better=False)
+
+        # assert
         assert result.value == 100

     def test_float_value(self):
+        # setup
+
+        # execute
         result = SubMetricResult(value=0.001, priority=0, higher_is_better=True)
+
+        # assert
         assert result.value == 0.001
@@ -40,60 +55,118 @@ def _make_result(self):
         )

     def test_get_value(self):
+        # setup
         result = self._make_result()
-        assert result.get_value("accuracy", "top1") == 0.95
-        assert result.get_value("latency", "avg") == 10.5
+
+        # execute
+        accuracy = result.get_value("accuracy", "top1")
+        latency = result.get_value("latency", "avg")
+
+        # assert
+        assert accuracy == 0.95
+        assert latency == 10.5

     def test_get_all_sub_type_metric_value(self):
+        # setup
         result = self._make_result()
+
+        # execute
         latency_values = result.get_all_sub_type_metric_value("latency")
+
+        # assert
         assert latency_values == {"avg": 10.5, "p99": 20.0}

     def test_get_all_sub_type_single_metric(self):
+        # setup
         result = self._make_result()
+
+        # execute
         accuracy_values = result.get_all_sub_type_metric_value("accuracy")
+
+        # assert
         assert accuracy_values == {"top1": 0.95}

     def test_str_representation(self):
+        # setup
         result = self._make_result()
+
+        # execute
         result_str = str(result)
         parsed = json.loads(result_str)
+
+        # assert
         assert parsed["accuracy-top1"] == 0.95
         assert parsed["latency-avg"] == 10.5

     def test_len(self):
+        # setup
         result = self._make_result()
-        assert len(result) == 3
+
+        # execute
+        length = len(result)
+
+        # assert
+        assert length == 3

     def test_getitem(self):
+        # setup
         result = self._make_result()
-        assert result["accuracy-top1"].value == 0.95
+
+        # execute
+        item = result["accuracy-top1"]
+
+        # assert
+        assert item.value == 0.95

     def test_delimiter(self):
-        assert MetricResult.delimiter == "-"
+        # setup
+
+        # execute
+        delimiter = MetricResult.delimiter
+
+        # assert
+        assert delimiter == "-"


 class TestJointMetricKey:
     def test_basic(self):
-        assert joint_metric_key("accuracy", "top1") == "accuracy-top1"
+        # setup
+
+        # execute
+        result = joint_metric_key("accuracy", "top1")
+
+        # assert
+        assert result == "accuracy-top1"

     def test_with_special_names(self):
-        assert joint_metric_key("latency", "p99") == "latency-p99"
+        # setup
+
+        # execute
+        result = joint_metric_key("latency", "p99")
+
+        # assert
+        assert result == "latency-p99"


 class TestFlattenMetricSubType:
     def test_flatten(self):
+        # setup
         metric_dict = {
             "accuracy": {"top1": {"value": 0.95, "priority": 1, "higher_is_better": True}},
             "latency": {"avg": {"value": 10.5, "priority": 2, "higher_is_better": False}},
         }
+
+        # execute
         result = flatten_metric_sub_type(metric_dict)
+
+        # assert
         assert "accuracy-top1" in result
         assert "latency-avg" in result


 class TestFlattenMetricResult:
     def test_flatten_to_metric_result(self):
+        # setup
         dict_results = {
             "accuracy": {
                 "top1": {"value": 0.95, "priority": 1, "higher_is_better": True},
@@ -102,7 +175,11 @@ def test_flatten_to_metric_result(self):
                 "avg": {"value": 10.5, "priority": 2, "higher_is_better": False},
             },
         }
+
+        # execute
         result = flatten_metric_result(dict_results)
+
+        # assert
         assert isinstance(result, MetricResult)
         assert result.get_value("accuracy", "top1") == 0.95
         assert result.get_value("latency", "avg") == 10.5
diff --git a/test/search/test_search_parameter.py b/test/search/test_search_parameter.py
index e3c05231d4..a493ea54ec 100644
--- a/test/search/test_search_parameter.py
+++ b/test/search/test_search_parameter.py
@@ -17,36 +17,83 @@

 class TestSpecialParamValue:
     def test_ignored_value(self):
-        assert SpecialParamValue.IGNORED == "OLIVE_IGNORED_PARAM_VALUE"
+        # setup
+
+        # execute
+        result = SpecialParamValue.IGNORED
+
+        # assert
+        assert result == "OLIVE_IGNORED_PARAM_VALUE"

     def test_invalid_value(self):
-        assert SpecialParamValue.INVALID == "OLIVE_INVALID_PARAM_VALUE"
+        # setup
+
+        # execute
+        result = SpecialParamValue.INVALID
+
+        # assert
+        assert result == "OLIVE_INVALID_PARAM_VALUE"


 class TestCategorical:
     def test_int_support(self):
+        # setup
         cat = Categorical([1, 2, 3])
-        assert cat.get_support() == [1, 2, 3]
+
+        # execute
+        result = cat.get_support()
+
+        # assert
+        assert result == [1, 2, 3]

     def test_string_support(self):
+        # setup
         cat = Categorical(["a", "b", "c"])
-        assert cat.get_support() == ["a", "b", "c"]
+
+        # execute
+        result = cat.get_support()
+
+        # assert
+        assert result == ["a", "b", "c"]

     def test_float_support(self):
+        # setup
         cat = Categorical([0.1, 0.5, 1.0])
-        assert cat.get_support() == [0.1, 0.5, 1.0]
+
+        # execute
+        result = cat.get_support()
+
+        # assert
+        assert result == [0.1, 0.5, 1.0]

     def test_bool_support(self):
+        # setup
         cat = Categorical([True, False])
-        assert cat.get_support() == [True, False]
+
+        # execute
+        result = cat.get_support()
+
+        # assert
+        assert result == [True, False]

     def test_repr(self):
+        # setup
         cat = Categorical([1, 2, 3])
-        assert repr(cat) == "Categorical([1, 2, 3])"
+
+        # execute
+        result = repr(cat)
+
+        # assert
+        assert result == "Categorical([1, 2, 3])"

     def test_to_json(self):
+        # setup
         cat = Categorical([1, 2, 3])
+
+        # execute
         result = cat.to_json()
+
+        # assert
         assert result["olive_parameter_type"] == "SearchParameter"
         assert result["type"] == "Categorical"
         assert result["support"] == [1, 2, 3]
@@ -54,15 +101,28 @@ def test_to_json(self):

 class TestBoolean:
     def test_support(self):
+        # setup
         b = Boolean()
-        assert b.get_support() == [True, False]
+
+        # execute
+        result = b.get_support()
+
+        # assert
+        assert result == [True, False]

     def test_is_categorical(self):
-        assert issubclass(Boolean, Categorical)
+        # setup
+
+        # execute
+        result = issubclass(Boolean, Categorical)
+
+        # assert
+        assert result


 class TestConditional:
     def test_single_parent(self):
+        # setup
         cond = Conditional(
             parents=("parent1",),
             support={
@@ -71,58 +131,97 @@ def test_single_parent(self):
             },
             default=Categorical([7, 8, 9]),
         )
-        assert cond.get_support_with_args({"parent1": "value1"}) == [1, 2, 3]
-        assert cond.get_support_with_args({"parent1": "value2"}) == [4, 5, 6]
-        assert cond.get_support_with_args({"parent1": "unknown"}) == [7, 8, 9]
+
+        # execute
+        result_v1 = cond.get_support_with_args({"parent1": "value1"})
+        result_v2 = cond.get_support_with_args({"parent1": "value2"})
+        result_unknown = cond.get_support_with_args({"parent1": "unknown"})
+
+        # assert
+        assert result_v1 == [1, 2, 3]
+        assert result_v2 == [4, 5, 6]
+        assert result_unknown == [7, 8, 9]

     def test_multi_parent(self):
+        # setup
         cond = Conditional(
             parents=("parent1", "parent2"),
             support={
                 ("v1", "v2"): Categorical([10, 20]),
             },
         )
-        assert cond.get_support_with_args({"parent1": "v1", "parent2": "v2"}) == [10, 20]
+
+        # execute
+        result = cond.get_support_with_args({"parent1": "v1", "parent2": "v2"})
+
+        # assert
+        assert result == [10, 20]

     def test_default_is_invalid(self):
+        # setup
         cond = Conditional(
             parents=("p",),
             support={("a",): Categorical([1])},
         )
-        # Default is invalid choice
-        assert cond.get_support_with_args({"p": "missing"}) == [SpecialParamValue.INVALID]
+
+        # execute
+        result = cond.get_support_with_args({"p": "missing"})
+
+        # assert
+        assert result == [SpecialParamValue.INVALID]

     def test_get_invalid_choice(self):
+        # setup
+
+        # execute
         result = Conditional.get_invalid_choice()
+
+        # assert
         assert isinstance(result, Categorical)
         assert result.get_support() == [SpecialParamValue.INVALID]

     def test_get_ignored_choice(self):
+        # setup
+
+        # execute
         result = Conditional.get_ignored_choice()
+
+        # assert
         assert isinstance(result, Categorical)
         assert result.get_support() == [SpecialParamValue.IGNORED]

     def test_repr(self):
+        # setup
         cond = Conditional(
             parents=("p",),
             support={("a",): Categorical([1])},
         )
+
+        # execute
         result = repr(cond)
+
+        # assert
         assert "Conditional" in result
         assert "parents" in result

     def test_to_json(self):
+        # setup
         cond = Conditional(
             parents=("p",),
             support={("a",): Categorical([1, 2])},
             default=Categorical([3]),
         )
+
+        # execute
         result = cond.to_json()
+
+        # assert
         assert result["olive_parameter_type"] == "SearchParameter"
         assert result["type"] == "Conditional"
         assert result["parents"] == ("p",)

     def test_condition_single_parent(self):
+        # setup
         cond = Conditional(
             parents=("p",),
             support={
@@ -130,82 +229,143 @@ def test_condition_single_parent(self):
                 ("b",): Categorical([3, 4]),
             },
         )
+
+        # execute
         result = cond.condition({"p": "a"})
+
+        # assert
         assert isinstance(result, Categorical)
         assert result.get_support() == [1, 2]

     def test_condition_returns_default_when_no_match(self):
+        # setup
         cond = Conditional(
             parents=("p1", "p2"),
             support={("a", "b"): Categorical([1])},
             default=Categorical([99]),
         )
+
+        # execute
         result = cond.condition({"p1": "missing"})
+
+        # assert
         assert isinstance(result, Categorical)
         assert result.get_support() == [99]


 class TestConditionalDefault:
     def test_basic(self):
+        # setup
         cd = ConditionalDefault(
             parents=("p",),
             support={("a",): 1, ("b",): 2},
             default=3,
         )
-        assert cd.get_support_with_args({"p": "a"}) == 1
-        assert cd.get_support_with_args({"p": "b"}) == 2
-        assert cd.get_support_with_args({"p": "c"}) == 3
+
+        # execute
+        result_a = cd.get_support_with_args({"p": "a"})
+        result_b = cd.get_support_with_args({"p": "b"})
+        result_c = cd.get_support_with_args({"p": "c"})
+
+        # assert
+        assert result_a == 1
+        assert result_b == 2
+        assert result_c == 3

     def test_default_invalid(self):
+        # setup
         cd = ConditionalDefault(
             parents=("p",),
             support={("a",): 1},
         )
-        assert cd.get_support_with_args({"p": "missing"}) == SpecialParamValue.INVALID
+
+        # execute
+        result = cd.get_support_with_args({"p": "missing"})
+
+        # assert
+        assert result == SpecialParamValue.INVALID

     def test_condition(self):
+        # setup
         cd = ConditionalDefault(
             parents=("p",),
             support={("a",): 10, ("b",): 20},
             default=30,
         )
-        assert cd.condition({"p": "a"}) == 10
-        assert cd.condition({"p": "b"}) == 20
-        assert cd.condition({"p": "c"}) == 30
+
+        # execute
+        result_a = cd.condition({"p": "a"})
+        result_b = cd.condition({"p": "b"})
+        result_c = cd.condition({"p": "c"})
+
+        # assert
+        assert result_a == 10
+        assert result_b == 20
+        assert result_c == 30

     def test_get_invalid_choice(self):
-        assert ConditionalDefault.get_invalid_choice() == SpecialParamValue.INVALID
+        # setup
+
+        # execute
+        result = ConditionalDefault.get_invalid_choice()
+
+        # assert
+        assert result == SpecialParamValue.INVALID

     def test_get_ignored_choice(self):
-        assert ConditionalDefault.get_ignored_choice() == SpecialParamValue.IGNORED
+        # setup
+
+        # execute
+        result = ConditionalDefault.get_ignored_choice()
+
+        # assert
+        assert result == SpecialParamValue.IGNORED

     def test_to_json(self):
+        # setup
         cd = ConditionalDefault(
             parents=("p",),
             support={("a",): 1},
             default=2,
         )
+
+        # execute
         result = cd.to_json()
+
+        # assert
         assert result["type"] == "ConditionalDefault"

     def test_repr(self):
+        # setup
         cd = ConditionalDefault(
             parents=("p",),
             support={("a",): 1},
             default=2,
         )
+
+        # execute
         result = repr(cd)
+
+        # assert
         assert "ConditionalDefault" in result


 class TestJsonToSearchParameter:
     def test_categorical(self):
+        # setup
         json_data = {"olive_parameter_type": "SearchParameter", "type": "Categorical", "support": [1, 2, 3]}
+
+        # execute
         result = json_to_search_parameter(json_data)
+
+        # assert
         assert isinstance(result, Categorical)
         assert result.get_support() == [1, 2, 3]

     def test_unknown_type_raises(self):
+        # setup
         json_data = {"olive_parameter_type": "SearchParameter", "type": "Unknown"}
+
+        # execute & assert
         with pytest.raises(ValueError, match="Unknown search parameter type"):
             json_to_search_parameter(json_data)
diff --git a/test/search/test_search_point.py b/test/search/test_search_point.py
index 60a66e6bc9..3a736e3d91 100644
--- a/test/search/test_search_point.py
+++ b/test/search/test_search_point.py
@@ -27,35 +27,69 @@ def _make_point(self, index=0, values=None):
         return SearchPoint(index=index, values=values)

     def test_creation(self):
+        # setup
+
+        # execute
         point = self._make_point()
+
+        # assert
         assert point.index == 0

     def test_repr(self):
+        # setup
         point = self._make_point()
+
+        # execute
         result = repr(point)
+
+        # assert
         assert "SearchPoint" in result
         assert "0" in result

     def test_equality_same(self):
+        # setup
         point1 = self._make_point()
         point2 = self._make_point()
-        assert point1 == point2
+
+        # execute
+        result = point1 == point2
+
+        # assert
+        assert result

     def test_equality_different_index(self):
+        # setup
         point1 = self._make_point(index=0)
         point2 = self._make_point(index=1)
-        assert point1 != point2
+
+        # execute
+        result = point1 == point2
+
+        # assert
+        assert not result

     def test_equality_different_type(self):
+        # setup
         point = self._make_point()
-        assert point != "not a search point"
+
+        # execute
+        result = point == "not a search point"
+
+        # assert
+        assert not result

     def test_is_valid_true(self):
+        # setup
         point = self._make_point()
-        assert point.is_valid() is True
+
+        # execute
+        result = point.is_valid()
+
+        # assert
+        assert result is True

     def test_is_valid_false_with_invalid(self):
-        # Verify that a SearchPoint with an invalid parameter value is detected as invalid
+        # setup
         values = OrderedDict(
             {
                 "pass1": OrderedDict(
@@ -66,17 +100,32 @@ def test_is_valid_false_with_invalid(self):
             }
         )
         point = SearchPoint(index=0, values=values)
-        assert point.is_valid() is False
+
+        # execute
+        result = point.is_valid()
+
+        # assert
+        assert result is False

     def test_to_json(self):
+        # setup
         point = self._make_point(index=5)
+
+        # execute
         result = point.to_json()
+
+        # assert
         assert result["index"] == 5
         assert "values" in result

     def test_from_json_roundtrip(self):
+        # setup
         point = self._make_point(index=3)
         json_data = point.to_json()
+
+        # execute
         restored = SearchPoint.from_json(json_data)
+
+        # assert
         assert restored.index == 3
         assert restored == point
diff --git a/test/search/test_search_sample.py b/test/search/test_search_sample.py
index c233abb687..5c5a002fd5 100644
--- a/test/search/test_search_sample.py
+++ b/test/search/test_search_sample.py
@@ -34,44 +34,80 @@ def _make_sample(self, invalid_param=False, ignored_param=False):
         return SearchSample(search_point=point, model_ids=["model_0"])

     def test_creation(self):
+        # setup
+
+        # execute
         sample = self._make_sample()
+
+        # assert
         assert sample.model_ids == ["model_0"]
         assert sample.search_point.index == 0

     def test_passes_configs_valid(self):
+        # setup
         sample = self._make_sample()
+
+        # execute
         configs = sample.passes_configs
+
+        # assert
         assert configs is not None
         assert "pass1" in configs
         assert configs["pass1"]["params"]["param1"] == "a"
         assert configs["pass1"]["params"]["param2"] == 10

     def test_passes_configs_with_invalid_returns_none(self):
+        # setup
         sample = self._make_sample(invalid_param=True)
-        assert sample.passes_configs is None
+
+        # execute
+        configs = sample.passes_configs
+
+        # assert
+        assert configs is None

     def test_passes_configs_with_ignored_excludes_param(self):
+        # setup
         sample = self._make_sample(ignored_param=True)
+
+        # execute
         configs = sample.passes_configs
+
+        # assert
         assert configs is not None
         assert "param1" not in configs["pass1"]["params"]
         assert configs["pass1"]["params"]["param2"] == 10

     def test_to_json(self):
+        # setup
         sample = self._make_sample()
+
+        # execute
         result = sample.to_json()
+
+        # assert
         assert "search_point" in result
         assert "model_ids" in result
         assert result["model_ids"] == ["model_0"]

     def test_from_json_roundtrip(self):
+        # setup
         sample = self._make_sample()
         json_data = sample.to_json()
+
+        # execute
         restored = SearchSample.from_json(json_data)
+
+        # assert
         assert restored.model_ids == sample.model_ids
         assert restored.search_point.index == sample.search_point.index

     def test_repr(self):
+        # setup
         sample = self._make_sample()
+
+        # execute
         result = repr(sample)
+
+        # assert
         assert "SearchSample" in result
diff --git a/test/telemetry/test_telemetry_utils.py b/test/telemetry/test_telemetry_utils.py
index 4a412dd200..9aacc8e033 100644
--- a/test/telemetry/test_telemetry_utils.py
+++ b/test/telemetry/test_telemetry_utils.py
@@ -20,73 +20,134 @@

 class TestResolveHomeDir:
     def test_returns_path(self):
+        # setup
+
+        # execute
         result = _resolve_home_dir()
+
+        # assert
         assert isinstance(result, Path)

     def test_with_home_env_set(self):
+        # setup
+
+        # execute
         with patch.dict(os.environ, {"HOME": "/tmp/test_home"}):
             result = _resolve_home_dir()
-            assert isinstance(result, Path)
+
+        # assert
+        assert isinstance(result, Path)

     def test_without_home_env(self):
+        # setup
+
+        # execute
         with patch.dict(os.environ, {}, clear=True):
             result = _resolve_home_dir()
-            assert isinstance(result, Path)
+
+        # assert
+        assert isinstance(result, Path)


 class TestGetTelemetryBaseDir:
     def test_returns_path(self):
-        # Clear the lru_cache before test
+        # setup
         get_telemetry_base_dir.cache_clear()
+
+        # execute
         result = get_telemetry_base_dir()
+
+        # assert
         assert isinstance(result, Path)

     def test_path_contains_onnxruntime(self):
+        # setup
         get_telemetry_base_dir.cache_clear()
+
+        # execute
         result = get_telemetry_base_dir()
+
+        # assert
         assert ".onnxruntime" in str(result)


 class TestFormatExceptionMessage:
     def test_basic_exception(self):
+        # setup
         try:
             1 / 0  # noqa: B018
         except ZeroDivisionError as ex:
-            result = _format_exception_message(ex, ex.__traceback__)
-            assert "ZeroDivisionError" in result
+            exception = ex
+            traceback = ex.__traceback__
+
+        # execute
+        result = _format_exception_message(exception, traceback)
+
+        # assert
+        assert "ZeroDivisionError" in result

     def test_exception_without_traceback(self):
+        # setup
         ex = ValueError("test error")
+
+        # execute
         result = _format_exception_message(ex)
+
+        # assert
         assert "test error" in result


 class TestEncodeCacheLine:
     def test_encode_basic_string(self):
-        result = _encode_cache_line("hello")
+        # setup
         expected = base64.b64encode(b"hello").decode("ascii")
+
+        # execute
+        result = _encode_cache_line("hello")
+
+        # assert
         assert result == expected

     def test_encode_empty_string(self):
-        result = _encode_cache_line("")
+        # setup
         expected = base64.b64encode(b"").decode("ascii")
+
+        # execute
+        result = _encode_cache_line("")
+
+        # assert
         assert result == expected

     def test_encode_unicode_string(self):
-        result = _encode_cache_line("hello 世界")
+        # setup
+
+        # execute
+        result = _encode_cache_line("hello \u4e16\u754c")
         decoded = base64.b64decode(result).decode("utf-8")
-        assert decoded == "hello 世界"
+
+        # assert
+        assert decoded == "hello \u4e16\u754c"


 class TestDecodeCacheLine:
     def test_decode_basic_string(self):
+        # setup
         encoded = base64.b64encode(b"hello").decode("ascii")
+
+        # execute
         result = _decode_cache_line(encoded)
+
+        # assert
         assert result == "hello"
"hello" def test_decode_empty_string(self): + # setup encoded = base64.b64encode(b"").decode("ascii") + + # execute result = _decode_cache_line(encoded) + + # assert assert result == "" @@ -102,6 +163,11 @@ class TestEncodeDecodeRoundtrip: ], ) def test_roundtrip(self, text): + # setup + + # execute encoded = _encode_cache_line(text) decoded = _decode_cache_line(encoded) + + # assert assert decoded == text diff --git a/test/test_constants.py b/test/test_constants.py index d7d06fa596..e512739eb0 100644 --- a/test/test_constants.py +++ b/test/test_constants.py @@ -23,41 +23,92 @@ class TestFramework: def test_framework_values(self): - assert Framework.ONNX == "ONNX" - assert Framework.PYTORCH == "PyTorch" - assert Framework.OPENVINO == "OpenVINO" + # setup + + # execute + onnx = Framework.ONNX + pytorch = Framework.PYTORCH + openvino = Framework.OPENVINO + + # assert + assert onnx == "ONNX" + assert pytorch == "PyTorch" + assert openvino == "OpenVINO" def test_framework_str(self): - assert str(Framework.ONNX) == "ONNX" + # setup + + # execute + result = str(Framework.ONNX) + + # assert + assert result == "ONNX" def test_framework_all_members(self): + # setup expected = {"ONNX", "PYTORCH", "QAIRT", "QNN", "TENSORFLOW", "OPENVINO"} - assert set(Framework.__members__.keys()) == expected + + # execute + result = set(Framework.__members__.keys()) + + # assert + assert result == expected class TestModelFileFormat: def test_model_file_format_values(self): - assert ModelFileFormat.ONNX == "ONNX" - assert ModelFileFormat.PYTORCH_STATE_DICT == "PyTorch.StateDict" - assert ModelFileFormat.COMPOSITE_MODEL == "Composite" + # setup + + # execute + onnx = ModelFileFormat.ONNX + state_dict = ModelFileFormat.PYTORCH_STATE_DICT + composite = ModelFileFormat.COMPOSITE_MODEL + + # assert + assert onnx == "ONNX" + assert state_dict == "PyTorch.StateDict" + assert composite == "Composite" def test_model_file_format_str(self): - assert str(ModelFileFormat.OPENVINO_IR) == "OpenVINO.IR" + # setup + + # execute + result = str(ModelFileFormat.OPENVINO_IR) + + # assert + assert result == "OpenVINO.IR" class TestPrecision: def test_precision_values(self): - assert Precision.INT4 == "int4" - assert Precision.FP16 == "fp16" - assert Precision.BF16 == "bf16" + # setup + + # execute + int4 = Precision.INT4 + fp16 = Precision.FP16 + bf16 = Precision.BF16 + + # assert + assert int4 == "int4" + assert fp16 == "fp16" + assert bf16 == "bf16" def test_precision_all_members(self): + # setup expected_count = 14 - assert len(Precision) == expected_count + + # execute + result = len(Precision) + + # assert + assert result == expected_count class TestPrecisionBits: def test_precision_bits_values(self): + # setup + + # execute & assert assert PrecisionBits.BITS2 == 2 assert PrecisionBits.BITS4 == 4 assert PrecisionBits.BITS8 == 8 @@ -65,46 +116,99 @@ def test_precision_bits_values(self): assert PrecisionBits.BITS32 == 32 def test_precision_bits_is_int(self): - assert isinstance(PrecisionBits.BITS4.value, int) + # setup + + # execute + result = isinstance(PrecisionBits.BITS4.value, int) + + # assert + assert result class TestQuantAlgorithm: def test_quant_algorithm_case_insensitive(self): - assert QuantAlgorithm("awq") == QuantAlgorithm.AWQ - assert QuantAlgorithm("AWQ") == QuantAlgorithm.AWQ - assert QuantAlgorithm("Awq") == QuantAlgorithm.AWQ + # setup + + # execute + lower = QuantAlgorithm("awq") + upper = QuantAlgorithm("AWQ") + mixed = QuantAlgorithm("Awq") + + # assert + assert lower == QuantAlgorithm.AWQ + assert upper == 
QuantAlgorithm.AWQ + assert mixed == QuantAlgorithm.AWQ def test_quant_algorithm_values(self): - assert QuantAlgorithm.GPTQ == "gptq" - assert QuantAlgorithm.RTN == "rtn" + # setup + + # execute + gptq = QuantAlgorithm.GPTQ + rtn = QuantAlgorithm.RTN + + # assert + assert gptq == "gptq" + assert rtn == "rtn" def test_quant_algorithm_all_members(self): + # setup expected = {"AWQ", "GPTQ", "HQQ", "RTN", "SPINQUANT", "QUAROT", "LPBQ", "SEQMSE", "ADAROUND"} - assert set(QuantAlgorithm.__members__.keys()) == expected + + # execute + result = set(QuantAlgorithm.__members__.keys()) + + # assert + assert result == expected class TestQuantEncoding: def test_quant_encoding_values(self): - assert QuantEncoding.QDQ == "qdq" - assert QuantEncoding.QOP == "qop" + # setup + + # execute + qdq = QuantEncoding.QDQ + qop = QuantEncoding.QOP + + # assert + assert qdq == "qdq" + assert qop == "qop" class TestDatasetRequirement: def test_dataset_requirement_values(self): - assert DatasetRequirement.REQUIRED == "dataset_required" - assert DatasetRequirement.OPTIONAL == "dataset_optional" - assert DatasetRequirement.NOT_REQUIRED == "dataset_not_required" + # setup + + # execute + required = DatasetRequirement.REQUIRED + optional = DatasetRequirement.OPTIONAL + not_required = DatasetRequirement.NOT_REQUIRED + + # assert + assert required == "dataset_required" + assert optional == "dataset_optional" + assert not_required == "dataset_not_required" class TestOpType: def test_op_type_values(self): - assert OpType.MatMul == "MatMul" - assert OpType.Add == "Add" - assert OpType.Custom == "custom" + # setup + + # execute + matmul = OpType.MatMul + add = OpType.Add + custom = OpType.Custom + + # assert + assert matmul == "MatMul" + assert add == "Add" + assert custom == "custom" class TestAccuracyLevel: def test_accuracy_level_values(self): + # setup + + # execute & assert assert AccuracyLevel.unset == 0 assert AccuracyLevel.fp32 == 1 assert AccuracyLevel.fp16 == 2 @@ -113,16 +217,32 @@ def test_accuracy_level_values(self): class TestDiffusersModelVariant: def test_diffusers_variant_values(self): - assert DiffusersModelVariant.AUTO == "auto" - assert DiffusersModelVariant.SD == "sd" - assert DiffusersModelVariant.FLUX == "flux" + # setup + + # execute + auto = DiffusersModelVariant.AUTO + sd = DiffusersModelVariant.SD + flux = DiffusersModelVariant.FLUX + + # assert + assert auto == "auto" + assert sd == "sd" + assert flux == "flux" class TestDiffusersComponent: def test_diffusers_component_values(self): - assert DiffusersComponent.TEXT_ENCODER == "text_encoder" - assert DiffusersComponent.UNET == "unet" - assert DiffusersComponent.VAE_DECODER == "vae_decoder" + # setup + + # execute + text_encoder = DiffusersComponent.TEXT_ENCODER + unet = DiffusersComponent.UNET + vae_decoder = DiffusersComponent.VAE_DECODER + + # assert + assert text_encoder == "text_encoder" + assert unet == "unet" + assert vae_decoder == "vae_decoder" class TestPrecisionBitsFromPrecision: @@ -140,13 +260,31 @@ class TestPrecisionBitsFromPrecision: ], ) def test_precision_to_bits_mapping(self, precision, expected): - assert precision_bits_from_precision(precision) == expected + # setup + + # execute + result = precision_bits_from_precision(precision) + + # assert + assert result == expected @pytest.mark.parametrize("precision", [Precision.FP16, Precision.FP32, Precision.BF16, Precision.NF4]) def test_precision_without_bits_mapping_returns_none(self, precision): - assert precision_bits_from_precision(precision) is None + # setup + + # execute + result 
+
+        # assert
+        assert result is None


 class TestMsftDomain:
     def test_msft_domain_value(self):
-        assert MSFT_DOMAIN == "com.microsoft"
+        # setup
+
+        # execute
+        result = MSFT_DOMAIN
+
+        # assert
+        assert result == "com.microsoft"
diff --git a/test/test_exception.py b/test/test_exception.py
index 088714eaa5..4abcf21d0d 100644
--- a/test/test_exception.py
+++ b/test/test_exception.py
@@ -9,52 +9,103 @@

 class TestOliveError:
     def test_olive_error_is_exception(self):
-        assert issubclass(OliveError, Exception)
+        # setup
+
+        # execute
+        result = issubclass(OliveError, Exception)
+
+        # assert
+        assert result

     def test_olive_error_can_be_raised(self):
+        # setup
+
+        # execute & assert
         with pytest.raises(OliveError, match="test error"):
             raise OliveError("test error")

     def test_olive_error_empty_message(self):
+        # setup
+
+        # execute & assert
         with pytest.raises(OliveError):
             raise OliveError


 class TestOlivePassError:
     def test_olive_pass_error_inherits_olive_error(self):
-        assert issubclass(OlivePassError, OliveError)
+        # setup
+
+        # execute
+        result = issubclass(OlivePassError, OliveError)
+
+        # assert
+        assert result

     def test_olive_pass_error_can_be_raised(self):
+        # setup
+
+        # execute & assert
         with pytest.raises(OlivePassError, match="pass failed"):
             raise OlivePassError("pass failed")

     def test_olive_pass_error_caught_as_olive_error(self):
+        # setup
+
+        # execute & assert
         with pytest.raises(OliveError):
             raise OlivePassError("pass failed")


 class TestOliveEvaluationError:
     def test_olive_evaluation_error_inherits_olive_error(self):
-        assert issubclass(OliveEvaluationError, OliveError)
+        # setup
+
+        # execute
+        result = issubclass(OliveEvaluationError, OliveError)
+
+        # assert
+        assert result

     def test_olive_evaluation_error_can_be_raised(self):
+        # setup
+
+        # execute & assert
         with pytest.raises(OliveEvaluationError, match="evaluation failed"):
             raise OliveEvaluationError("evaluation failed")

     def test_olive_evaluation_error_caught_as_olive_error(self):
+        # setup
+
+        # execute & assert
         with pytest.raises(OliveError):
             raise OliveEvaluationError("evaluation failed")


 class TestExceptionsToRaise:
     def test_exceptions_to_raise_is_tuple(self):
-        assert isinstance(EXCEPTIONS_TO_RAISE, tuple)
+        # setup
+
+        # execute
+        result = isinstance(EXCEPTIONS_TO_RAISE, tuple)
+
+        # assert
+        assert result

     def test_exceptions_to_raise_contains_expected_types(self):
+        # setup
         expected = {AssertionError, AttributeError, ImportError, TypeError, ValueError}
-        assert set(EXCEPTIONS_TO_RAISE) == expected
+
+        # execute
+        result = set(EXCEPTIONS_TO_RAISE)
+
+        # assert
+        assert result == expected

     @pytest.mark.parametrize("exc_type", EXCEPTIONS_TO_RAISE)
     def test_each_exception_is_catchable(self, exc_type):
+        # setup
+
+        # execute & assert
         with pytest.raises(exc_type):
             raise exc_type("test")
diff --git a/test/test_logging.py b/test/test_logging.py
index 0894a07dc1..0a9a9f34ff 100644
--- a/test/test_logging.py
+++ b/test/test_logging.py
@@ -26,56 +26,111 @@

 class TestGetOliveLogger:
     def test_returns_olive_logger(self):
+        # setup
+
+        # execute
         logger = get_olive_logger()
+
+        # assert
         assert logger.name == "olive"

     def test_returns_same_logger_instance(self):
+        # setup
+
+        # execute
         logger1 = get_olive_logger()
         logger2 = get_olive_logger()
+
+        # assert
         assert logger1 is logger2


 class TestSetVerbosity:
     def test_set_verbosity_info(self):
+        # setup
+
+        # execute
         set_verbosity_info()
+
+        # assert
         assert get_olive_logger().level == logging.INFO

     def test_set_verbosity_warning(self):
+        # setup
+
+        # execute
         set_verbosity_warning()
+
+        # assert
         assert get_olive_logger().level == logging.WARNING

     def test_set_verbosity_debug(self):
+        # setup
+
+        # execute
         set_verbosity_debug()
+
+        # assert
         assert get_olive_logger().level == logging.DEBUG

     def test_set_verbosity_error(self):
+        # setup
+
+        # execute
         set_verbosity_error()
+
+        # assert
         assert get_olive_logger().level == logging.ERROR

     def test_set_verbosity_critical(self):
+        # setup
+
+        # execute
         set_verbosity_critical()
+
+        # assert
         assert get_olive_logger().level == logging.CRITICAL

     def test_set_verbosity_custom_level(self):
+        # setup
+
+        # execute
         set_verbosity(logging.WARNING)
+
+        # assert
         assert get_olive_logger().level == logging.WARNING


 class TestSetVerbosityFromEnv:
     def test_set_verbosity_from_env_default(self):
+        # setup
+
+        # execute
         with patch.dict("os.environ", {}, clear=True):
             set_verbosity_from_env()

+        # assert (no exception raised)
+
     def test_set_verbosity_from_env_custom(self):
+        # setup
+
+        # execute
         with patch.dict("os.environ", {"OLIVE_LOG_LEVEL": "DEBUG"}):
             set_verbosity_from_env()
+
+        # assert
         assert get_olive_logger().level == logging.DEBUG


 class TestGetVerbosity:
     def test_get_verbosity_returns_int(self):
+        # setup
         set_verbosity_info()
+
+        # execute
         level = get_verbosity()
+
+        # assert
         assert isinstance(level, int)
         assert level == logging.INFO
@@ -92,10 +147,19 @@ class TestGetLoggerLevel:
         ],
     )
     def test_valid_levels(self, level_int, expected):
-        assert get_logger_level(level_int) == expected
+        # setup
+
+        # execute
+        result = get_logger_level(level_int)
+
+        # assert
+        assert result == expected

     @pytest.mark.parametrize("invalid_level", [-1, 5, 10, 100])
     def test_invalid_levels_raise_value_error(self, invalid_level):
+        # setup
+
+        # execute & assert
         with pytest.raises(ValueError, match="Invalid level"):
             get_logger_level(invalid_level)
@@ -103,24 +167,31 @@ def test_invalid_levels_raise_value_error(self):

 class TestSetDefaultLoggerSeverity:
     @pytest.mark.parametrize("level", [0, 1, 2, 3, 4])
     def test_set_default_logger_severity(self, level):
-        set_default_logger_severity(level)
+        # setup
         expected = get_logger_level(level)
+
+        # execute
+        set_default_logger_severity(level)
+
+        # assert
         assert get_olive_logger().level == expected


 class TestEnableFilelog:
     def test_enable_filelog_creates_handler(self, tmp_path):
+        # setup
         workflow_id = "test_workflow"
+
+        # execute
         enable_filelog(1, str(tmp_path), workflow_id)

+        # assert
         logger = get_olive_logger()
         log_file_path = tmp_path / f"{workflow_id}.log"
-
-        # Check that a file handler was added
         file_handlers = [h for h in logger.handlers if isinstance(h, logging.FileHandler)]
         assert len(file_handlers) > 0

-        # Clean up: remove the handler we added
+        # cleanup
         for h in file_handlers:
             if Path(h.baseFilename) == log_file_path.resolve():
                 logger.removeHandler(h)

From abdc8adf92872194d840414a9d4ad566245c7796 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Fri, 10 Apr 2026 03:38:41 +0000
Subject: [PATCH 4/6] Remove empty # setup comments from test methods where no
 setup code exists
Agent-Logs-Url: https://github.com/microsoft/Olive/sessions/3dc88a98-09ac-43c9-86b5-b2eddc3e34ae
Co-authored-by: xiaoyu-work <85524621+xiaoyu-work@users.noreply.github.com>
---
 test/common/test_config_utils.py       | 22 ----------------
 test/data_container/test_registry.py   | 12 ---------
 test/evaluator/test_metric_config.py   | 36 --------------------------
 test/evaluator/test_metric_result.py   | 12 ---------
 test/search/test_search_parameter.py   | 14 ----------
 test/search/test_search_point.py       |  2 --
 test/search/test_search_sample.py      |  2 --
 test/telemetry/test_telemetry_utils.py | 10 -------
 test/test_constants.py                 | 36 --------------------------
 test/test_exception.py                 | 22 ----------------
 test/test_logging.py                   | 24 -----------------
 11 files changed, 192 deletions(-)

diff --git a/test/common/test_config_utils.py b/test/common/test_config_utils.py
index cb0ac3010b..16f09b60ba 100644
--- a/test/common/test_config_utils.py
+++ b/test/common/test_config_utils.py
@@ -225,8 +225,6 @@ def test_to_json(self):
         assert isinstance(result, dict)

     def test_from_json(self):
-        # setup
-
         # execute
         config = ConfigBase.from_json({})

@@ -234,8 +232,6 @@ def test_from_json(self):
         assert isinstance(config, ConfigBase)

     def test_parse_file_or_obj_dict(self):
-        # setup
-
         # execute
         config = ConfigBase.parse_file_or_obj({})

@@ -410,8 +406,6 @@ class MyConfig(NestedConfig):

 class TestCaseInsensitiveEnum:
     def test_case_insensitive_creation(self):
-        # setup
-
         # execute
         lower = ParamCategory("none")
         upper = ParamCategory("NONE")
@@ -423,8 +417,6 @@ def test_case_insensitive_creation(self):
         assert mixed == ParamCategory.NONE

     def test_invalid_value_returns_none(self):
-        # setup
-
         # execute
         result = ParamCategory._missing_("nonexistent")

@@ -434,8 +426,6 @@ def test_invalid_value_returns_none(self):

 class TestConfigParam:
     def test_config_param_defaults(self):
-        # setup
-
         # execute
         param = ConfigParam(type_=str)

@@ -445,8 +435,6 @@ def test_config_param_defaults(self):
         assert param.category == ParamCategory.NONE

     def test_config_param_required(self):
-        # setup
-
         # execute
         param = ConfigParam(type_=str, required=True)

@@ -467,8 +455,6 @@ def test_config_param_repr(self):

 class TestValidateEnum:
     def test_valid_enum_value(self):
-        # setup
-
         # execute
         result = validate_enum(ParamCategory, "none")

@@ -476,8 +462,6 @@ def test_valid_enum_value(self):
         assert result == ParamCategory.NONE

     def test_invalid_enum_value_raises(self):
-        # setup
-
         # execute & assert
         with pytest.raises(ValueError, match="Invalid value"):
             validate_enum(ParamCategory, "invalid_value")
@@ -485,8 +469,6 @@ def test_invalid_enum_value_raises(self):

 class TestValidateLowercase:
     def test_string_lowercased(self):
-        # setup
-
         # execute
         result = validate_lowercase("HELLO")

@@ -494,8 +476,6 @@ def test_string_lowercased(self):
         assert result == "hello"

     def test_non_string_unchanged(self):
-        # setup
-
         # execute
         result_int = validate_lowercase(42)
         result_none = validate_lowercase(None)
@@ -621,8 +601,6 @@ def test_list_conversion(self):
         assert result == ["a", "b"]

     def test_plain_value_passthrough(self):
-        # setup
-
         # execute
         result_int = convert_configs_to_dicts(42)
         result_str = convert_configs_to_dicts("hello")
diff --git a/test/data_container/test_registry.py b/test/data_container/test_registry.py
index c55e8243ce..fa255593d1 100644
--- a/test/data_container/test_registry.py
+++ b/test/data_container/test_registry.py
@@ -100,8 +100,6 @@ def my_func():

 class TestRegistryDefaultComponents:
     def test_get_default_load_dataset(self):
-        # setup
-
         # execute
         result = Registry.get_default_load_dataset_component()

@@ -109,8 +107,6 @@ def test_get_default_load_dataset(self):
         assert result is not None

     def test_get_default_pre_process(self):
-        # setup
-
         # execute
         result = Registry.get_default_pre_process_component()

@@ -118,8 +114,6 @@ def test_get_default_pre_process(self):
         assert result is not None

     def test_get_default_post_process(self):
-        # setup
-
         # execute
         result = Registry.get_default_post_process_component()

@@ -127,8 +121,6 @@ def test_get_default_post_process(self):
         assert result is not None

     def test_get_default_dataloader(self):
-        # setup
-
         # execute
         result = Registry.get_default_dataloader_component()

@@ -138,8 +130,6 @@ def test_get_default_dataloader(self):

 class TestRegistryContainer:
     def test_get_container_default(self):
-        # setup
-
         # execute
         result = Registry.get_container(None)

@@ -147,8 +137,6 @@ def test_get_container_default(self):
         assert result is not None

     def test_get_container_by_name(self):
-        # setup
-
         # execute
         result = Registry.get_container(DefaultDataContainer.DATA_CONTAINER.value)
diff --git a/test/evaluator/test_metric_config.py b/test/evaluator/test_metric_config.py
index 381958a57c..4145ae2868 100644
--- a/test/evaluator/test_metric_config.py
+++ b/test/evaluator/test_metric_config.py
@@ -17,8 +17,6 @@ class TestLatencyMetricConfig:
     def test_defaults(self):
-        # setup
-
         # execute
         config = LatencyMetricConfig()

@@ -28,8 +26,6 @@ def test_defaults(self):
         assert config.sleep_num == 0

     def test_custom_values(self):
-        # setup
-
         # execute
         config = LatencyMetricConfig(warmup_num=5, repeat_test_num=100, sleep_num=2)

@@ -41,8 +37,6 @@ def test_custom_values(self):

 class TestThroughputMetricConfig:
     def test_defaults(self):
-        # setup
-
         # execute
         config = ThroughputMetricConfig()

@@ -52,8 +46,6 @@ def test_defaults(self):
         assert config.sleep_num == 0

     def test_custom_values(self):
-        # setup
-
         # execute
         config = ThroughputMetricConfig(warmup_num=3, repeat_test_num=50, sleep_num=1)

@@ -65,8 +57,6 @@ def test_custom_values(self):

 class TestSizeOnDiskMetricConfig:
     def test_creation(self):
-        # setup
-
         # execute
         config = SizeOnDiskMetricConfig()

@@ -76,8 +66,6 @@ def test_creation(self):

 class TestMetricGoal:
     def test_threshold_type(self):
-        # setup
-
         # execute
         goal = MetricGoal(type="threshold", value=0.9)

@@ -86,8 +74,6 @@ def test_threshold_type(self):
         assert goal.value == 0.9

     def test_min_improvement_type(self):
-        # setup
-
         # execute
         goal = MetricGoal(type="min-improvement", value=0.05)

@@ -96,8 +82,6 @@ def test_min_improvement_type(self):
         assert goal.value == 0.05

     def test_max_degradation_type(self):
-        # setup
-
         # execute
         goal = MetricGoal(type="max-degradation", value=0.1)

@@ -106,8 +90,6 @@ def test_max_degradation_type(self):
         assert goal.value == 0.1

     def test_percent_min_improvement_type(self):
-        # setup
-
         # execute
         goal = MetricGoal(type="percent-min-improvement", value=5.0)

@@ -115,8 +97,6 @@ def test_percent_min_improvement_type(self):
         assert goal.type == "percent-min-improvement"

     def test_percent_max_degradation_type(self):
-        # setup
-
         # execute
         goal = MetricGoal(type="percent-max-degradation", value=10.0)

@@ -124,43 +104,31 @@ def test_percent_max_degradation_type(self):
         assert goal.type == "percent-max-degradation"

     def test_invalid_type_raises(self):
-        # setup
-
         # execute & assert
         with pytest.raises(ValidationError, match="Metric goal type must be one of"):
             MetricGoal(type="invalid_type", value=0.5)

     def test_negative_value_for_min_improvement_raises(self):
-        # setup
-
         # execute & assert
         with pytest.raises(ValidationError, match="Value must be nonnegative"):
             MetricGoal(type="min-improvement", value=-0.5)

     def test_negative_value_for_max_degradation_raises(self):
-        # setup
-
         # execute & assert
         with pytest.raises(ValidationError, match="Value must be nonnegative"):
             MetricGoal(type="max-degradation", value=-0.1)

     def test_negative_value_for_percent_min_improvement_raises(self):
-        # setup
-
         # execute & assert
         with pytest.raises(ValidationError, match="Value must be nonnegative"):
             MetricGoal(type="percent-min-improvement", value=-5.0)

     def test_negative_value_for_percent_max_degradation_raises(self):
-        # setup
-
         # execute & assert
         with pytest.raises(ValidationError, match="Value must be nonnegative"):
             MetricGoal(type="percent-max-degradation", value=-10.0)

     def test_threshold_allows_negative_value(self):
-        # setup
-
         # execute
         goal = MetricGoal(type="threshold", value=-1.0)

@@ -230,8 +198,6 @@ def test_has_regression_goal_threshold(self):

 class TestGetUserConfigClass:
     def test_custom_metric_type(self):
-        # setup
-
         # execute
         cls = get_user_config_class("custom")
         instance = cls()
@@ -241,8 +207,6 @@ def test_custom_metric_type(self):
         assert hasattr(instance, "evaluate_func")

     def test_unknown_metric_type(self):
-        # setup
-
         # execute
         cls = get_user_config_class("latency")
         instance = cls()
diff --git a/test/evaluator/test_metric_result.py b/test/evaluator/test_metric_result.py
index 396936fbf6..8583ea79c5 100644
--- a/test/evaluator/test_metric_result.py
+++ b/test/evaluator/test_metric_result.py
@@ -15,8 +15,6 @@ class TestSubMetricResult:
     def test_creation(self):
-        # setup
-
         # execute
         result = SubMetricResult(value=0.95, priority=1, higher_is_better=True)

@@ -26,8 +24,6 @@ def test_creation(self):
         assert result.higher_is_better is True

     def test_integer_value(self):
-        # setup
-
         # execute
         result = SubMetricResult(value=100, priority=2, higher_is_better=False)

@@ -35,8 +31,6 @@ def test_integer_value(self):
         assert result.value == 100

     def test_float_value(self):
-        # setup
-
         # execute
         result = SubMetricResult(value=0.001, priority=0, higher_is_better=True)

@@ -119,8 +113,6 @@ def test_getitem(self):
         assert item.value == 0.95

     def test_delimiter(self):
-        # setup
-
         # execute
         delimiter = MetricResult.delimiter

@@ -130,8 +122,6 @@ def test_delimiter(self):

 class TestJointMetricKey:
     def test_basic(self):
-        # setup
-
         # execute
         result = joint_metric_key("accuracy", "top1")

@@ -139,8 +129,6 @@ def test_basic(self):
         assert result == "accuracy-top1"

     def test_with_special_names(self):
-        # setup
-
         # execute
         result = joint_metric_key("latency", "p99")
diff --git a/test/search/test_search_parameter.py b/test/search/test_search_parameter.py
index a493ea54ec..dcb4cb72ea 100644
--- a/test/search/test_search_parameter.py
+++ b/test/search/test_search_parameter.py
@@ -17,8 +17,6 @@ class TestSpecialParamValue:
     def test_ignored_value(self):
-        # setup
-
         # execute
         result = SpecialParamValue.IGNORED

@@ -26,8 +24,6 @@ def test_ignored_value(self):
         assert result == "OLIVE_IGNORED_PARAM_VALUE"

     def test_invalid_value(self):
-        # setup
-
         # execute
         result = SpecialParamValue.INVALID

@@ -111,8 +107,6 @@ def test_support(self):
         assert result == [True, False]

     def test_is_categorical(self):
-        # setup
-
         # execute
         result = issubclass(Boolean, Categorical)

@@ -171,8 +165,6 @@ def test_default_is_invalid(self):
         assert result == [SpecialParamValue.INVALID]

     def test_get_invalid_choice(self):
-        # setup
-
         # execute
         result = Conditional.get_invalid_choice()

@@ -181,8 +173,6 @@ def test_get_invalid_choice(self):
         assert result.get_support() == [SpecialParamValue.INVALID]

     def test_get_ignored_choice(self):
-        # setup
-
         # execute
         result = Conditional.get_ignored_choice()

@@ -304,8 +294,6 @@ def test_condition(self):
         assert result_c == 30

     def test_get_invalid_choice(self):
-        # setup
-
         # execute
         result = ConditionalDefault.get_invalid_choice()

@@ -313,8 +301,6 @@ def test_get_invalid_choice(self):
         assert result == SpecialParamValue.INVALID

     def test_get_ignored_choice(self):
-        # setup
-
         # execute
         result = ConditionalDefault.get_ignored_choice()
diff --git a/test/search/test_search_point.py b/test/search/test_search_point.py
index 3a736e3d91..23f252751a 100644
--- a/test/search/test_search_point.py
+++ b/test/search/test_search_point.py
@@ -27,8 +27,6 @@ def _make_point(self, index=0, values=None):
         return SearchPoint(index=index, values=values)

     def test_creation(self):
-        # setup
-
         # execute
         point = self._make_point()
diff --git a/test/search/test_search_sample.py b/test/search/test_search_sample.py
index 5c5a002fd5..3dc4c6345e 100644
--- a/test/search/test_search_sample.py
+++ b/test/search/test_search_sample.py
@@ -34,8 +34,6 @@ def _make_sample(self, invalid_param=False, ignored_param=False):
         return SearchSample(search_point=point, model_ids=["model_0"])

     def test_creation(self):
-        # setup
-
         # execute
         sample = self._make_sample()
diff --git a/test/telemetry/test_telemetry_utils.py b/test/telemetry/test_telemetry_utils.py
index 9aacc8e033..35ec93db2b 100644
--- a/test/telemetry/test_telemetry_utils.py
+++ b/test/telemetry/test_telemetry_utils.py
@@ -20,8 +20,6 @@ class TestResolveHomeDir:
     def test_returns_path(self):
-        # setup
-
         # execute
         result = _resolve_home_dir()

@@ -29,8 +27,6 @@ def test_returns_path(self):
         assert isinstance(result, Path)

     def test_with_home_env_set(self):
-        # setup
-
         # execute
         with patch.dict(os.environ, {"HOME": "/tmp/test_home"}):
             result = _resolve_home_dir()
@@ -39,8 +35,6 @@ def test_with_home_env_set(self):
         assert isinstance(result, Path)

     def test_without_home_env(self):
-        # setup
-
         # execute
         with patch.dict(os.environ, {}, clear=True):
             result = _resolve_home_dir()
@@ -119,8 +113,6 @@ def test_encode_empty_string(self):
         assert result == expected

     def test_encode_unicode_string(self):
-        # setup
-
         # execute
         result = _encode_cache_line("hello \u4e16\u754c")
         decoded = base64.b64decode(result).decode("utf-8")
@@ -163,8 +155,6 @@ class TestEncodeDecodeRoundtrip:
         ],
     )
     def test_roundtrip(self, text):
-        # setup
-
         # execute
         encoded = _encode_cache_line(text)
         decoded = _decode_cache_line(encoded)
diff --git a/test/test_constants.py b/test/test_constants.py
index e512739eb0..30f4c44c86 100644
--- a/test/test_constants.py
+++ b/test/test_constants.py
@@ -23,8 +23,6 @@ class TestFramework:
     def test_framework_values(self):
-        # setup
-
         # execute
         onnx = Framework.ONNX
         pytorch = Framework.PYTORCH
@@ -36,8 +34,6 @@ def test_framework_values(self):
         assert openvino == "OpenVINO"

     def test_framework_str(self):
-        # setup
-
         # execute
         result = str(Framework.ONNX)

@@ -57,8 +53,6 @@ def test_framework_all_members(self):

 class TestModelFileFormat:
     def test_model_file_format_values(self):
-        # setup
-
         # execute
         onnx = ModelFileFormat.ONNX
         state_dict = ModelFileFormat.PYTORCH_STATE_DICT
@@ -70,8 +64,6 @@ def test_model_file_format_values(self):
         assert composite == "Composite"

     def test_model_file_format_str(self):
-        # setup
-
         # execute
         result = str(ModelFileFormat.OPENVINO_IR)

@@ -81,8 +73,6 @@ def test_model_file_format_str(self):

 class TestPrecision:
     def test_precision_values(self):
-        # setup
-
         # execute
         int4 = Precision.INT4
         fp16 = Precision.FP16
@@ -106,8 +96,6 @@ def test_precision_all_members(self):

 class TestPrecisionBits:
     def test_precision_bits_values(self):
-        # setup
-
         # execute & assert
         assert PrecisionBits.BITS2 == 2
         assert PrecisionBits.BITS4 == 4
@@ -116,8 +104,6 @@ def test_precision_bits_values(self):

     def test_precision_bits_is_int(self):
-        # setup
-
         # execute
         result = isinstance(PrecisionBits.BITS4.value, int)

@@ -127,8 +113,6 @@ def test_precision_bits_is_int(self):
diff --git a/test/test_constants.py b/test/test_constants.py
index e512739eb0..30f4c44c86 100644
--- a/test/test_constants.py
+++ b/test/test_constants.py
@@ -23,8 +23,6 @@
 
 class TestFramework:
     def test_framework_values(self):
-        # setup
-
         # execute
         onnx = Framework.ONNX
         pytorch = Framework.PYTORCH
@@ -36,8 +34,6 @@ def test_framework_values(self):
         assert openvino == "OpenVINO"
 
     def test_framework_str(self):
-        # setup
-
         # execute
         result = str(Framework.ONNX)
 
@@ -57,8 +53,6 @@ def test_framework_all_members(self):
 
 class TestModelFileFormat:
     def test_model_file_format_values(self):
-        # setup
-
         # execute
         onnx = ModelFileFormat.ONNX
         state_dict = ModelFileFormat.PYTORCH_STATE_DICT
@@ -70,8 +64,6 @@ def test_model_file_format_values(self):
         assert composite == "Composite"
 
     def test_model_file_format_str(self):
-        # setup
-
         # execute
         result = str(ModelFileFormat.OPENVINO_IR)
 
@@ -81,8 +73,6 @@ def test_model_file_format_str(self):
 
 class TestPrecision:
     def test_precision_values(self):
-        # setup
-
         # execute
         int4 = Precision.INT4
         fp16 = Precision.FP16
@@ -106,8 +96,6 @@ def test_precision_all_members(self):
 
 class TestPrecisionBits:
     def test_precision_bits_values(self):
-        # setup
-
         # execute & assert
         assert PrecisionBits.BITS2 == 2
         assert PrecisionBits.BITS4 == 4
@@ -116,8 +104,6 @@ def test_precision_bits_values(self):
         assert PrecisionBits.BITS32 == 32
 
     def test_precision_bits_is_int(self):
-        # setup
-
         # execute
         result = isinstance(PrecisionBits.BITS4.value, int)
 
@@ -127,8 +113,6 @@ def test_precision_bits_is_int(self):
 
 class TestQuantAlgorithm:
     def test_quant_algorithm_case_insensitive(self):
-        # setup
-
         # execute
         lower = QuantAlgorithm("awq")
         upper = QuantAlgorithm("AWQ")
@@ -140,8 +124,6 @@ def test_quant_algorithm_case_insensitive(self):
         assert mixed == QuantAlgorithm.AWQ
 
     def test_quant_algorithm_values(self):
-        # setup
-
         # execute
         gptq = QuantAlgorithm.GPTQ
         rtn = QuantAlgorithm.RTN
@@ -163,8 +145,6 @@ def test_quant_algorithm_all_members(self):
 
 class TestQuantEncoding:
     def test_quant_encoding_values(self):
-        # setup
-
         # execute
         qdq = QuantEncoding.QDQ
         qop = QuantEncoding.QOP
@@ -176,8 +156,6 @@ def test_quant_encoding_values(self):
 
 class TestDatasetRequirement:
     def test_dataset_requirement_values(self):
-        # setup
-
         # execute
         required = DatasetRequirement.REQUIRED
         optional = DatasetRequirement.OPTIONAL
@@ -191,8 +169,6 @@ def test_dataset_requirement_values(self):
 
 class TestOpType:
     def test_op_type_values(self):
-        # setup
-
         # execute
         matmul = OpType.MatMul
         add = OpType.Add
@@ -206,8 +182,6 @@ def test_op_type_values(self):
 
 class TestAccuracyLevel:
     def test_accuracy_level_values(self):
-        # setup
-
         # execute & assert
         assert AccuracyLevel.unset == 0
         assert AccuracyLevel.fp32 == 1
@@ -217,8 +191,6 @@ def test_accuracy_level_values(self):
 
 class TestDiffusersModelVariant:
     def test_diffusers_variant_values(self):
-        # setup
-
         # execute
         auto = DiffusersModelVariant.AUTO
         sd = DiffusersModelVariant.SD
@@ -232,8 +204,6 @@ def test_diffusers_variant_values(self):
 
 class TestDiffusersComponent:
     def test_diffusers_component_values(self):
-        # setup
-
         # execute
         text_encoder = DiffusersComponent.TEXT_ENCODER
         unet = DiffusersComponent.UNET
@@ -260,8 +230,6 @@ class TestPrecisionBitsFromPrecision:
         ],
     )
     def test_precision_to_bits_mapping(self, precision, expected):
-        # setup
-
         # execute
         result = precision_bits_from_precision(precision)
 
@@ -270,8 +238,6 @@ def test_precision_to_bits_mapping(self, precision, expected):
 
     @pytest.mark.parametrize("precision", [Precision.FP16, Precision.FP32, Precision.BF16, Precision.NF4])
     def test_precision_without_bits_mapping_returns_none(self, precision):
-        # setup
-
         # execute
         result = precision_bits_from_precision(precision)
 
@@ -281,8 +247,6 @@ def test_precision_without_bits_mapping_returns_none(self, precision):
 
 class TestMsftDomain:
     def test_msft_domain_value(self):
-        # setup
-
         # execute
         result = MSFT_DOMAIN
 
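
The parametrized cases above only reveal part of the contract: integer
precisions map to a PrecisionBits member, while FP16/FP32/BF16 and NF4
return None. A hypothetical sketch of that shape follows; the real function
is precision_bits_from_precision in olive.constants, and the dict below is
an assumption, intentionally partial.

    from olive.constants import Precision, PrecisionBits

    _BITS_SKETCH = {
        Precision.INT4: PrecisionBits.BITS4,
        Precision.INT8: PrecisionBits.BITS8,
        # ...remaining integer precisions elided in this sketch
    }

    def precision_bits_from_precision_sketch(precision):
        # Unmapped precisions (floats, NF4) fall through to None.
        return _BITS_SKETCH.get(precision)
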
diff --git a/test/test_exception.py b/test/test_exception.py
index 4abcf21d0d..6afd936a0d 100644
--- a/test/test_exception.py
+++ b/test/test_exception.py
@@ -9,8 +9,6 @@
 
 class TestOliveError:
     def test_olive_error_is_exception(self):
-        # setup
-
         # execute
         result = issubclass(OliveError, Exception)
 
@@ -18,15 +16,11 @@ def test_olive_error_is_exception(self):
         assert result
 
     def test_olive_error_can_be_raised(self):
-        # setup
-
         # execute & assert
         with pytest.raises(OliveError, match="test error"):
             raise OliveError("test error")
 
     def test_olive_error_empty_message(self):
-        # setup
-
         # execute & assert
         with pytest.raises(OliveError):
             raise OliveError
@@ -34,8 +28,6 @@
 
 class TestOlivePassError:
     def test_olive_pass_error_inherits_olive_error(self):
-        # setup
-
         # execute
         result = issubclass(OlivePassError, OliveError)
 
@@ -43,15 +35,11 @@ def test_olive_pass_error_inherits_olive_error(self):
         assert result
 
     def test_olive_pass_error_can_be_raised(self):
-        # setup
-
         # execute & assert
         with pytest.raises(OlivePassError, match="pass failed"):
             raise OlivePassError("pass failed")
 
     def test_olive_pass_error_caught_as_olive_error(self):
-        # setup
-
         # execute & assert
         with pytest.raises(OliveError):
             raise OlivePassError("pass failed")
@@ -59,8 +47,6 @@
 
 class TestOliveEvaluationError:
     def test_olive_evaluation_error_inherits_olive_error(self):
-        # setup
-
         # execute
         result = issubclass(OliveEvaluationError, OliveError)
 
@@ -68,15 +54,11 @@ def test_olive_evaluation_error_inherits_olive_error(self):
         assert result
 
     def test_olive_evaluation_error_can_be_raised(self):
-        # setup
-
         # execute & assert
         with pytest.raises(OliveEvaluationError, match="evaluation failed"):
             raise OliveEvaluationError("evaluation failed")
 
     def test_olive_evaluation_error_caught_as_olive_error(self):
-        # setup
-
         # execute & assert
         with pytest.raises(OliveError):
             raise OliveEvaluationError("evaluation failed")
@@ -84,8 +66,6 @@
 
 class TestExceptionsToRaise:
     def test_exceptions_to_raise_is_tuple(self):
-        # setup
-
         # execute
         result = isinstance(EXCEPTIONS_TO_RAISE, tuple)
 
@@ -104,8 +84,6 @@ def test_exceptions_to_raise_contains_expected_types(self):
 
     @pytest.mark.parametrize("exc_type", EXCEPTIONS_TO_RAISE)
     def test_each_exception_is_catchable(self, exc_type):
-        # setup
-
         # execute & assert
         with pytest.raises(exc_type):
             raise exc_type("test")
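
EXCEPTIONS_TO_RAISE being a plain tuple of exception types is what enables
the catch-and-reraise pattern sketched below. The harness function is
hypothetical; only the imported names come from these tests, assuming the
same olive.exception module path the test file uses.

    import logging

    from olive.exception import EXCEPTIONS_TO_RAISE

    logger = logging.getLogger(__name__)

    def run_step_sketch(step):
        # Hypothetical caller: fatal error types propagate, others are logged.
        try:
            return step()
        except EXCEPTIONS_TO_RAISE:
            raise
        except Exception as ex:  # noqa: BLE001
            logger.warning("step failed, continuing: %s", ex)
            return None
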
diff --git a/test/test_logging.py b/test/test_logging.py
index 0a9a9f34ff..1cedd594a4 100644
--- a/test/test_logging.py
+++ b/test/test_logging.py
@@ -26,8 +26,6 @@
 
 class TestGetOliveLogger:
     def test_returns_olive_logger(self):
-        # setup
-
         # execute
         logger = get_olive_logger()
 
@@ -35,8 +33,6 @@ def test_returns_olive_logger(self):
         assert logger.name == "olive"
 
     def test_returns_same_logger_instance(self):
-        # setup
-
         # execute
         logger1 = get_olive_logger()
         logger2 = get_olive_logger()
@@ -47,8 +43,6 @@
 
 class TestSetVerbosity:
     def test_set_verbosity_info(self):
-        # setup
-
         # execute
         set_verbosity_info()
 
@@ -56,8 +50,6 @@ def test_set_verbosity_info(self):
         assert get_olive_logger().level == logging.INFO
 
     def test_set_verbosity_warning(self):
-        # setup
-
         # execute
         set_verbosity_warning()
 
@@ -65,8 +57,6 @@ def test_set_verbosity_warning(self):
         assert get_olive_logger().level == logging.WARNING
 
     def test_set_verbosity_debug(self):
-        # setup
-
         # execute
         set_verbosity_debug()
 
@@ -74,8 +64,6 @@ def test_set_verbosity_debug(self):
         assert get_olive_logger().level == logging.DEBUG
 
     def test_set_verbosity_error(self):
-        # setup
-
         # execute
         set_verbosity_error()
 
@@ -83,8 +71,6 @@ def test_set_verbosity_error(self):
         assert get_olive_logger().level == logging.ERROR
 
     def test_set_verbosity_critical(self):
-        # setup
-
         # execute
         set_verbosity_critical()
 
@@ -92,8 +78,6 @@ def test_set_verbosity_critical(self):
         assert get_olive_logger().level == logging.CRITICAL
 
     def test_set_verbosity_custom_level(self):
-        # setup
-
         # execute
         set_verbosity(logging.WARNING)
 
@@ -103,8 +87,6 @@
 
 class TestSetVerbosityFromEnv:
     def test_set_verbosity_from_env_default(self):
-        # setup
-
         # execute
         with patch.dict("os.environ", {}, clear=True):
             set_verbosity_from_env()
@@ -112,8 +94,6 @@ def test_set_verbosity_from_env_default(self):
         # assert (no exception raised)
 
     def test_set_verbosity_from_env_custom(self):
-        # setup
-
         # execute
         with patch.dict("os.environ", {"OLIVE_LOG_LEVEL": "DEBUG"}):
             set_verbosity_from_env()
@@ -147,8 +127,6 @@ class TestGetLoggerLevel:
         ],
     )
     def test_valid_levels(self, level_int, expected):
-        # setup
-
         # execute
         result = get_logger_level(level_int)
 
@@ -157,8 +135,6 @@ def test_valid_levels(self, level_int, expected):
 
     @pytest.mark.parametrize("invalid_level", [-1, 5, 10, 100])
     def test_invalid_levels_raise_value_error(self, invalid_level):
-        # setup
-
         # execute & assert
         with pytest.raises(ValueError, match="Invalid level"):
             get_logger_level(invalid_level)

From 5b0187ae1ab7354877d75df776a057700f9f0668 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Fri, 10 Apr 2026 03:47:13 +0000
Subject: [PATCH 5/6] Remove nonsensical enum/constant value tests from
 test_constants.py
Agent-Logs-Url: https://github.com/microsoft/Olive/sessions/4d88a644-eb92-4192-8494-b25f650bc10d
Co-authored-by: xiaoyu-work <85524621+xiaoyu-work@users.noreply.github.com>
---
 test/test_constants.py | 164 -----------------------------------------
 1 file changed, 164 deletions(-)

diff --git a/test/test_constants.py b/test/test_constants.py
index 30f4c44c86..77283c2089 100644
--- a/test/test_constants.py
+++ b/test/test_constants.py
@@ -5,41 +5,15 @@
 import pytest
 
 from olive.constants import (
-    MSFT_DOMAIN,
-    AccuracyLevel,
-    DatasetRequirement,
-    DiffusersComponent,
-    DiffusersModelVariant,
     Framework,
-    ModelFileFormat,
-    OpType,
     Precision,
     PrecisionBits,
     QuantAlgorithm,
-    QuantEncoding,
     precision_bits_from_precision,
 )
 
 
 class TestFramework:
-    def test_framework_values(self):
-        # execute
-        onnx = Framework.ONNX
-        pytorch = Framework.PYTORCH
-        openvino = Framework.OPENVINO
-
-        # assert
-        assert onnx == "ONNX"
-        assert pytorch == "PyTorch"
-        assert openvino == "OpenVINO"
-
-    def test_framework_str(self):
-        # execute
-        result = str(Framework.ONNX)
-
-        # assert
-        assert result == "ONNX"
-
     def test_framework_all_members(self):
         # setup
         expected = {"ONNX", "PYTORCH", "QAIRT", "QNN", "TENSORFLOW", "OPENVINO"}
@@ -51,38 +25,7 @@ def test_framework_all_members(self):
         assert result == expected
 
 
-class TestModelFileFormat:
-    def test_model_file_format_values(self):
-        # execute
-        onnx = ModelFileFormat.ONNX
-        state_dict = ModelFileFormat.PYTORCH_STATE_DICT
-        composite = ModelFileFormat.COMPOSITE_MODEL
-
-        # assert
-        assert onnx == "ONNX"
-        assert state_dict == "PyTorch.StateDict"
-        assert composite == "Composite"
-
-    def test_model_file_format_str(self):
-        # execute
-        result = str(ModelFileFormat.OPENVINO_IR)
-
-        # assert
-        assert result == "OpenVINO.IR"
-
-
 class TestPrecision:
-    def test_precision_values(self):
-        # execute
-        int4 = Precision.INT4
-        fp16 = Precision.FP16
-        bf16 = Precision.BF16
-
-        # assert
-        assert int4 == "int4"
-        assert fp16 == "fp16"
-        assert bf16 == "bf16"
-
     def test_precision_all_members(self):
         # setup
         expected_count = 14
@@ -94,23 +37,6 @@ def test_precision_all_members(self):
         assert result == expected_count
 
 
-class TestPrecisionBits:
-    def test_precision_bits_values(self):
-        # execute & assert
-        assert PrecisionBits.BITS2 == 2
-        assert PrecisionBits.BITS4 == 4
-        assert PrecisionBits.BITS8 == 8
-        assert PrecisionBits.BITS16 == 16
-        assert PrecisionBits.BITS32 == 32
-
-    def test_precision_bits_is_int(self):
-        # execute
-        result = isinstance(PrecisionBits.BITS4.value, int)
-
-        # assert
-        assert result
-
-
 class TestQuantAlgorithm:
     def test_quant_algorithm_case_insensitive(self):
         # execute
@@ -123,15 +49,6 @@ def test_quant_algorithm_case_insensitive(self):
         assert upper == QuantAlgorithm.AWQ
         assert mixed == QuantAlgorithm.AWQ
 
-    def test_quant_algorithm_values(self):
-        # execute
-        gptq = QuantAlgorithm.GPTQ
-        rtn = QuantAlgorithm.RTN
-
-        # assert
-        assert gptq == "gptq"
-        assert rtn == "rtn"
-
     def test_quant_algorithm_all_members(self):
         # setup
         expected = {"AWQ", "GPTQ", "HQQ", "RTN", "SPINQUANT", "QUAROT", "LPBQ", "SEQMSE", "ADAROUND"}
@@ -143,78 +60,6 @@ def test_quant_algorithm_all_members(self):
         assert result == expected
 
 
-class TestQuantEncoding:
-    def test_quant_encoding_values(self):
-        # execute
-        qdq = QuantEncoding.QDQ
-        qop = QuantEncoding.QOP
-
-        # assert
-        assert qdq == "qdq"
-        assert qop == "qop"
-
-
-class TestDatasetRequirement:
-    def test_dataset_requirement_values(self):
-        # execute
-        required = DatasetRequirement.REQUIRED
-        optional = DatasetRequirement.OPTIONAL
-        not_required = DatasetRequirement.NOT_REQUIRED
-
-        # assert
-        assert required == "dataset_required"
-        assert optional == "dataset_optional"
-        assert not_required == "dataset_not_required"
-
-
-class TestOpType:
-    def test_op_type_values(self):
-        # execute
-        matmul = OpType.MatMul
-        add = OpType.Add
-        custom = OpType.Custom
-
-        # assert
-        assert matmul == "MatMul"
-        assert add == "Add"
-        assert custom == "custom"
-
-
-class TestAccuracyLevel:
-    def test_accuracy_level_values(self):
-        # execute & assert
-        assert AccuracyLevel.unset == 0
-        assert AccuracyLevel.fp32 == 1
-        assert AccuracyLevel.fp16 == 2
-        assert AccuracyLevel.int8 == 4
-
-
-class TestDiffusersModelVariant:
-    def test_diffusers_variant_values(self):
-        # execute
-        auto = DiffusersModelVariant.AUTO
-        sd = DiffusersModelVariant.SD
-        flux = DiffusersModelVariant.FLUX
-
-        # assert
-        assert auto == "auto"
-        assert sd == "sd"
-        assert flux == "flux"
-
-
-class TestDiffusersComponent:
-    def test_diffusers_component_values(self):
-        # execute
-        text_encoder = DiffusersComponent.TEXT_ENCODER
-        unet = DiffusersComponent.UNET
-        vae_decoder = DiffusersComponent.VAE_DECODER
-
-        # assert
-        assert text_encoder == "text_encoder"
-        assert unet == "unet"
-        assert vae_decoder == "vae_decoder"
-
-
 class TestPrecisionBitsFromPrecision:
     @pytest.mark.parametrize(
         ("precision", "expected"),
@@ -243,12 +88,3 @@ def test_precision_without_bits_mapping_returns_none(self, precision):
 
         # assert
         assert result is None
-
-
-class TestMsftDomain:
-    def test_msft_domain_value(self):
-        # execute
-        result = MSFT_DOMAIN
-
-        # assert
-        assert result == "com.microsoft"
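
The dropped tests restated each enum literal one-for-one, so they could only
fail together with the definition they copied, while the retained membership
and case-insensitivity tests still catch accidental additions, removals, and
renames. A small illustration of the distinction, using a hypothetical enum:

    from enum import Enum

    class Color(str, Enum):  # hypothetical, for illustration only
        RED = "red"
        BLUE = "blue"

    # Low value: restates the definition literal-for-literal.
    assert Color.RED == "red"

    # Higher value: pins membership, so adding or removing a member fails loudly.
    assert {m.name for m in Color} == {"RED", "BLUE"}
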
== "gptq" - assert rtn == "rtn" - def test_quant_algorithm_all_members(self): # setup expected = {"AWQ", "GPTQ", "HQQ", "RTN", "SPINQUANT", "QUAROT", "LPBQ", "SEQMSE", "ADAROUND"} @@ -143,78 +60,6 @@ def test_quant_algorithm_all_members(self): assert result == expected -class TestQuantEncoding: - def test_quant_encoding_values(self): - # execute - qdq = QuantEncoding.QDQ - qop = QuantEncoding.QOP - - # assert - assert qdq == "qdq" - assert qop == "qop" - - -class TestDatasetRequirement: - def test_dataset_requirement_values(self): - # execute - required = DatasetRequirement.REQUIRED - optional = DatasetRequirement.OPTIONAL - not_required = DatasetRequirement.NOT_REQUIRED - - # assert - assert required == "dataset_required" - assert optional == "dataset_optional" - assert not_required == "dataset_not_required" - - -class TestOpType: - def test_op_type_values(self): - # execute - matmul = OpType.MatMul - add = OpType.Add - custom = OpType.Custom - - # assert - assert matmul == "MatMul" - assert add == "Add" - assert custom == "custom" - - -class TestAccuracyLevel: - def test_accuracy_level_values(self): - # execute & assert - assert AccuracyLevel.unset == 0 - assert AccuracyLevel.fp32 == 1 - assert AccuracyLevel.fp16 == 2 - assert AccuracyLevel.int8 == 4 - - -class TestDiffusersModelVariant: - def test_diffusers_variant_values(self): - # execute - auto = DiffusersModelVariant.AUTO - sd = DiffusersModelVariant.SD - flux = DiffusersModelVariant.FLUX - - # assert - assert auto == "auto" - assert sd == "sd" - assert flux == "flux" - - -class TestDiffusersComponent: - def test_diffusers_component_values(self): - # execute - text_encoder = DiffusersComponent.TEXT_ENCODER - unet = DiffusersComponent.UNET - vae_decoder = DiffusersComponent.VAE_DECODER - - # assert - assert text_encoder == "text_encoder" - assert unet == "unet" - assert vae_decoder == "vae_decoder" - - class TestPrecisionBitsFromPrecision: @pytest.mark.parametrize( ("precision", "expected"), @@ -243,12 +88,3 @@ def test_precision_without_bits_mapping_returns_none(self, precision): # assert assert result is None - - -class TestMsftDomain: - def test_msft_domain_value(self): - # execute - result = MSFT_DOMAIN - - # assert - assert result == "com.microsoft" From cb03be75acdee397ac526546560882b4574ca1e3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 10 Apr 2026 04:07:51 +0000 Subject: [PATCH 6/6] Address PR review comments: fix CodeQL issues, add logger fixture, improve assertions, rename test Agent-Logs-Url: https://github.com/microsoft/Olive/sessions/6e93cade-0aec-4b0e-b700-2063624d8f7d Co-authored-by: xiaoyu-work <85524621+xiaoyu-work@users.noreply.github.com> --- test/evaluator/test_metric_config.py | 2 +- test/telemetry/test_telemetry_utils.py | 2 ++ test/test_logging.py | 41 ++++++++++++++++++-------- 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/test/evaluator/test_metric_config.py b/test/evaluator/test_metric_config.py index 4145ae2868..f88dbe3de4 100644 --- a/test/evaluator/test_metric_config.py +++ b/test/evaluator/test_metric_config.py @@ -206,7 +206,7 @@ def test_custom_metric_type(self): assert hasattr(instance, "user_script") assert hasattr(instance, "evaluate_func") - def test_unknown_metric_type(self): + def test_non_custom_metric_type_includes_common_fields(self): # execute cls = get_user_config_class("latency") instance = cls() diff --git a/test/telemetry/test_telemetry_utils.py b/test/telemetry/test_telemetry_utils.py index 
diff --git a/test/telemetry/test_telemetry_utils.py b/test/telemetry/test_telemetry_utils.py
index 35ec93db2b..97029abd33 100644
--- a/test/telemetry/test_telemetry_utils.py
+++ b/test/telemetry/test_telemetry_utils.py
@@ -68,6 +68,8 @@ def test_path_contains_onnxruntime(self):
 class TestFormatExceptionMessage:
     def test_basic_exception(self):
         # setup
+        exception = None
+        traceback = None
         try:
             1 / 0  # noqa: B018
         except ZeroDivisionError as ex:
diff --git a/test/test_logging.py b/test/test_logging.py
index 1cedd594a4..99a23ed89a 100644
--- a/test/test_logging.py
+++ b/test/test_logging.py
@@ -24,6 +24,15 @@
 )
 
 
+@pytest.fixture(autouse=True)
+def _restore_logger_level():
+    """Save and restore the olive logger level between tests."""
+    logger = get_olive_logger()
+    original_level = logger.level
+    yield
+    logger.setLevel(original_level)
+
+
 class TestGetOliveLogger:
     def test_returns_olive_logger(self):
         # execute
@@ -91,7 +100,8 @@ def test_set_verbosity_from_env_default(self):
         with patch.dict("os.environ", {}, clear=True):
             set_verbosity_from_env()
 
-        # assert (no exception raised)
+        # assert
+        assert get_olive_logger().level == logging.INFO
 
     def test_set_verbosity_from_env_custom(self):
         # execute
         with patch.dict("os.environ", {"OLIVE_LOG_LEVEL": "DEBUG"}):
             set_verbosity_from_env()
@@ -157,18 +167,25 @@ class TestEnableFilelog:
     def test_enable_filelog_creates_handler(self, tmp_path):
         # setup
         workflow_id = "test_workflow"
-
-        # execute
-        enable_filelog(1, str(tmp_path), workflow_id)
-
-        # assert
         logger = get_olive_logger()
-        log_file_path = tmp_path / f"{workflow_id}.log"
-        file_handlers = [h for h in logger.handlers if isinstance(h, logging.FileHandler)]
-        assert len(file_handlers) > 0
+        original_handler_ids = {id(h) for h in logger.handlers}
+        log_file_path = (tmp_path / f"{workflow_id}.log").resolve()
 
-        # cleanup
-        for h in file_handlers:
-            if Path(h.baseFilename) == log_file_path.resolve():
+        try:
+            # execute
+            enable_filelog(1, str(tmp_path), workflow_id)
+
+            # assert
+            new_handlers = [h for h in logger.handlers if id(h) not in original_handler_ids]
+            matching_handlers = [
+                h
+                for h in new_handlers
+                if isinstance(h, logging.FileHandler) and Path(h.baseFilename).resolve() == log_file_path
+            ]
+            assert matching_handlers, f"Expected a FileHandler for {log_file_path}, but none was added."
+            assert log_file_path.exists()
+        finally:
+            # cleanup
+            for h in [h for h in logger.handlers if id(h) not in original_handler_ids]:
                 logger.removeHandler(h)
                 h.close()
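
The rewritten test tracks handlers by id() so it only asserts on, and cleans
up, handlers that enable_filelog itself added, and the try/finally guarantees
removal even when an assertion fails. For context, a minimal sketch of the
behavior the assertions presume; this is not Olive's actual implementation,
and the real function may also map its verbosity argument through
get_logger_level().

    import logging
    from pathlib import Path

    def enable_filelog_sketch(verbosity: int, log_dir: str, workflow_id: str) -> None:
        # Presumed shape: create <log_dir>/<workflow_id>.log and attach the
        # resulting FileHandler to the "olive" logger.
        handler = logging.FileHandler(Path(log_dir) / f"{workflow_id}.log")
        logging.getLogger("olive").addHandler(handler)
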