Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
610 changes: 610 additions & 0 deletions test/common/test_config_utils.py

Large diffs are not rendered by default.

144 changes: 144 additions & 0 deletions test/data_container/test_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from olive.data.constants import (
DataComponentType,
DefaultDataContainer,
)
from olive.data.registry import Registry


class TestRegistryRegister:
def test_register_dataset_component(self):
# setup & execute
@Registry.register(DataComponentType.LOAD_DATASET, name="test_dataset_reg")
def my_dataset():
return "dataset"

# assert
result = Registry.get_load_dataset_component("test_dataset_reg")
assert result is my_dataset

def test_register_pre_process_component(self):
# setup & execute
@Registry.register_pre_process(name="test_pre_process_reg")
def my_pre_process(data):
return data

# assert
result = Registry.get_pre_process_component("test_pre_process_reg")
assert result is my_pre_process

def test_register_post_process_component(self):
# setup & execute
@Registry.register_post_process(name="test_post_process_reg")
def my_post_process(data):
return data

# assert
result = Registry.get_post_process_component("test_post_process_reg")
assert result is my_post_process

def test_register_dataloader_component(self):
# setup & execute
@Registry.register_dataloader(name="test_dataloader_reg")
def my_dataloader(data):
return data

# assert
result = Registry.get_dataloader_component("test_dataloader_reg")
assert result is my_dataloader

def test_register_case_insensitive(self):
# setup & execute
@Registry.register(DataComponentType.LOAD_DATASET, name="CaseSensitiveTest_Reg")
def my_func():
pass

# assert
result = Registry.get_load_dataset_component("casesensitivetest_reg")
assert result is my_func

def test_register_uses_class_name_when_no_name(self):
# setup & execute
@Registry.register(DataComponentType.LOAD_DATASET)
def unique_named_test_func_reg():
pass

# assert
result = Registry.get_load_dataset_component("unique_named_test_func_reg")
assert result is unique_named_test_func_reg


class TestRegistryGet:
def test_get_component(self):
# setup
@Registry.register(DataComponentType.LOAD_DATASET, name="test_get_comp_reg")
def my_func():
pass

# execute
result = Registry.get_component(DataComponentType.LOAD_DATASET.value, "test_get_comp_reg")

# assert
assert result is my_func

def test_get_by_subtype(self):
# setup
@Registry.register(DataComponentType.LOAD_DATASET, name="test_get_subtype_reg")
def my_func():
pass

# execute
result = Registry.get(DataComponentType.LOAD_DATASET.value, "test_get_subtype_reg")

# assert
assert result is my_func


class TestRegistryDefaultComponents:
def test_get_default_load_dataset(self):
# execute
result = Registry.get_default_load_dataset_component()

# assert
assert result is not None

def test_get_default_pre_process(self):
# execute
result = Registry.get_default_pre_process_component()

# assert
assert result is not None

def test_get_default_post_process(self):
# execute
result = Registry.get_default_post_process_component()

# assert
assert result is not None

def test_get_default_dataloader(self):
# execute
result = Registry.get_default_dataloader_component()

# assert
assert result is not None


class TestRegistryContainer:
def test_get_container_default(self):
# execute
result = Registry.get_container(None)

# assert
assert result is not None

def test_get_container_by_name(self):
# execute
result = Registry.get_container(DefaultDataContainer.DATA_CONTAINER.value)

# assert
assert result is not None
216 changes: 216 additions & 0 deletions test/evaluator/test_metric_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import pytest
from pydantic import ValidationError

from olive.evaluator.metric_config import (
LatencyMetricConfig,
MetricGoal,
SizeOnDiskMetricConfig,
ThroughputMetricConfig,
get_user_config_class,
)


class TestLatencyMetricConfig:
def test_defaults(self):
# execute
config = LatencyMetricConfig()

# assert
assert config.warmup_num == 10
assert config.repeat_test_num == 20
assert config.sleep_num == 0

def test_custom_values(self):
# execute
config = LatencyMetricConfig(warmup_num=5, repeat_test_num=100, sleep_num=2)

# assert
assert config.warmup_num == 5
assert config.repeat_test_num == 100
assert config.sleep_num == 2


class TestThroughputMetricConfig:
def test_defaults(self):
# execute
config = ThroughputMetricConfig()

# assert
assert config.warmup_num == 10
assert config.repeat_test_num == 20
assert config.sleep_num == 0

def test_custom_values(self):
# execute
config = ThroughputMetricConfig(warmup_num=3, repeat_test_num=50, sleep_num=1)

# assert
assert config.warmup_num == 3
assert config.repeat_test_num == 50
assert config.sleep_num == 1


class TestSizeOnDiskMetricConfig:
def test_creation(self):
# execute
config = SizeOnDiskMetricConfig()

# assert
assert isinstance(config, SizeOnDiskMetricConfig)


class TestMetricGoal:
def test_threshold_type(self):
# execute
goal = MetricGoal(type="threshold", value=0.9)

# assert
assert goal.type == "threshold"
assert goal.value == 0.9

def test_min_improvement_type(self):
# execute
goal = MetricGoal(type="min-improvement", value=0.05)

# assert
assert goal.type == "min-improvement"
assert goal.value == 0.05

def test_max_degradation_type(self):
# execute
goal = MetricGoal(type="max-degradation", value=0.1)

# assert
assert goal.type == "max-degradation"
assert goal.value == 0.1

def test_percent_min_improvement_type(self):
# execute
goal = MetricGoal(type="percent-min-improvement", value=5.0)

# assert
assert goal.type == "percent-min-improvement"

def test_percent_max_degradation_type(self):
# execute
goal = MetricGoal(type="percent-max-degradation", value=10.0)

# assert
assert goal.type == "percent-max-degradation"

def test_invalid_type_raises(self):
# execute & assert
with pytest.raises(ValidationError, match="Metric goal type must be one of"):
MetricGoal(type="invalid_type", value=0.5)

def test_negative_value_for_min_improvement_raises(self):
# execute & assert
with pytest.raises(ValidationError, match="Value must be nonnegative"):
MetricGoal(type="min-improvement", value=-0.5)

def test_negative_value_for_max_degradation_raises(self):
# execute & assert
with pytest.raises(ValidationError, match="Value must be nonnegative"):
MetricGoal(type="max-degradation", value=-0.1)

def test_negative_value_for_percent_min_improvement_raises(self):
# execute & assert
with pytest.raises(ValidationError, match="Value must be nonnegative"):
MetricGoal(type="percent-min-improvement", value=-5.0)

def test_negative_value_for_percent_max_degradation_raises(self):
# execute & assert
with pytest.raises(ValidationError, match="Value must be nonnegative"):
MetricGoal(type="percent-max-degradation", value=-10.0)

def test_threshold_allows_negative_value(self):
# execute
goal = MetricGoal(type="threshold", value=-1.0)

# assert
assert goal.value == -1.0

def test_has_regression_goal_min_improvement(self):
# setup
goal = MetricGoal(type="min-improvement", value=0.05)

# execute
result = goal.has_regression_goal()

# assert
assert result is False

def test_has_regression_goal_percent_min_improvement(self):
# setup
goal = MetricGoal(type="percent-min-improvement", value=5.0)

# execute
result = goal.has_regression_goal()

# assert
assert result is False

def test_has_regression_goal_max_degradation_positive(self):
# setup
goal = MetricGoal(type="max-degradation", value=0.1)

# execute
result = goal.has_regression_goal()

# assert
assert result is True

def test_has_regression_goal_max_degradation_zero(self):
# setup
goal = MetricGoal(type="max-degradation", value=0.0)

# execute
result = goal.has_regression_goal()

# assert
assert result is False

def test_has_regression_goal_percent_max_degradation_positive(self):
# setup
goal = MetricGoal(type="percent-max-degradation", value=10.0)

# execute
result = goal.has_regression_goal()

# assert
assert result is True

def test_has_regression_goal_threshold(self):
# setup
goal = MetricGoal(type="threshold", value=0.9)

# execute
result = goal.has_regression_goal()

# assert
assert result is False


class TestGetUserConfigClass:
def test_custom_metric_type(self):
# execute
cls = get_user_config_class("custom")
instance = cls()

# assert
assert hasattr(instance, "user_script")
assert hasattr(instance, "evaluate_func")

def test_non_custom_metric_type_includes_common_fields(self):
# execute
cls = get_user_config_class("latency")
instance = cls()

# assert
assert hasattr(instance, "user_script")
assert hasattr(instance, "inference_settings")
Loading
Loading