Skip to content

Commit 88bcedc

Browse files
committed
Chore: clean up unit test selection
1 parent 98998d4 commit 88bcedc

File tree

5 files changed

+48
-74
lines changed

5 files changed

+48
-74
lines changed

sqlmesh/core/context.py

Lines changed: 29 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@
147147
from typing_extensions import Literal
148148

149149
from sqlmesh.core.engine_adapter._typing import (
150-
DF,
151150
BigframeSession,
151+
DF,
152152
PySparkDataFrame,
153153
PySparkSession,
154154
SnowparkSession,
@@ -403,6 +403,7 @@ def __init__(
403403
self._model_test_metadata_path_index: t.Dict[Path, t.List[ModelTestMetadata]] = {}
404404
self._model_test_metadata_fully_qualified_name_index: t.Dict[str, ModelTestMetadata] = {}
405405
self._models_with_tests: t.Set[str] = set()
406+
406407
self._macros: UniqueKeyDict[str, ExecutableOrMacro] = UniqueKeyDict("macros")
407408
self._metrics: UniqueKeyDict[str, Metric] = UniqueKeyDict("metrics")
408409
self._jinja_macros = JinjaMacroRegistry()
@@ -656,6 +657,7 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
656657
self._requirements.update(project.requirements)
657658
self._excluded_requirements.update(project.excluded_requirements)
658659
self._environment_statements.extend(project.environment_statements)
660+
659661
self._model_test_metadata.extend(project.model_test_metadata)
660662
for metadata in project.model_test_metadata:
661663
if metadata.path not in self._model_test_metadata_path_index:
@@ -2243,9 +2245,7 @@ def test(
22432245

22442246
pd.set_option("display.max_columns", None)
22452247

2246-
test_meta = self._select_tests(
2247-
test_meta=self._model_test_metadata, tests=tests, patterns=match_patterns
2248-
)
2248+
test_meta = self.select_tests(tests=tests, patterns=match_patterns)
22492249

22502250
result = run_tests(
22512251
model_test_metadata=test_meta,
@@ -2807,33 +2807,6 @@ def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
28072807
raise SQLMeshError(f"Gateway '{gateway}' not found in the available engine adapters.")
28082808
return self.engine_adapter
28092809

2810-
def _select_tests(
2811-
self,
2812-
test_meta: t.List[ModelTestMetadata],
2813-
tests: t.Optional[t.List[str]] = None,
2814-
patterns: t.Optional[t.List[str]] = None,
2815-
) -> t.List[ModelTestMetadata]:
2816-
"""Filter pre-loaded test metadata based on tests and patterns."""
2817-
2818-
if tests:
2819-
filtered_tests = []
2820-
for test in tests:
2821-
if "::" in test:
2822-
if test in self._model_test_metadata_fully_qualified_name_index:
2823-
filtered_tests.append(
2824-
self._model_test_metadata_fully_qualified_name_index[test]
2825-
)
2826-
else:
2827-
test_path = Path(test)
2828-
if test_path in self._model_test_metadata_path_index:
2829-
filtered_tests.extend(self._model_test_metadata_path_index[test_path])
2830-
test_meta = filtered_tests
2831-
2832-
if patterns:
2833-
test_meta = filter_tests_by_patterns(test_meta, patterns)
2834-
2835-
return test_meta
2836-
28372810
def _snapshots(
28382811
self, models_override: t.Optional[UniqueKeyDict[str, Model]] = None
28392812
) -> t.Dict[str, Snapshot]:
@@ -3245,18 +3218,34 @@ def lint_models(
32453218

32463219
return all_violations
32473220

3248-
def load_model_tests(
3249-
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
3221+
def select_tests(
3222+
self,
3223+
tests: t.Optional[t.List[str]] = None,
3224+
patterns: t.Optional[t.List[str]] = None,
32503225
) -> t.List[ModelTestMetadata]:
3251-
# If a set of specific test path(s) are provided, we can use a single loader
3252-
# since it's not required to walk every tests/ folder in each repo
3253-
loaders = [self._loaders[0]] if tests else self._loaders
3226+
"""Filter pre-loaded test metadata based on tests and patterns."""
3227+
3228+
test_meta = self._model_test_metadata
3229+
3230+
if tests:
3231+
filtered_tests = []
3232+
for test in tests:
3233+
if "::" in test:
3234+
if test in self._model_test_metadata_fully_qualified_name_index:
3235+
filtered_tests.append(
3236+
self._model_test_metadata_fully_qualified_name_index[test]
3237+
)
3238+
else:
3239+
test_path = Path(test)
3240+
if test_path in self._model_test_metadata_path_index:
3241+
filtered_tests.extend(self._model_test_metadata_path_index[test_path])
3242+
3243+
test_meta = filtered_tests
32543244

3255-
model_tests = []
3256-
for loader in loaders:
3257-
model_tests.extend(loader.load_model_tests(tests=tests, patterns=patterns))
3245+
if patterns:
3246+
test_meta = filter_tests_by_patterns(test_meta, patterns)
32583247

3259-
return model_tests
3248+
return test_meta
32603249

32613250

32623251
class Context(GenericContext[Config]):

sqlmesh/core/linter/rules/builtin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def check_model(self, model: Model) -> t.Optional[RuleViolation]:
130130

131131

132132
class NoMissingUnitTest(Rule):
133-
"""All models must have a unit test found in the test/ directory yaml files"""
133+
"""All models must have a unit test found in the tests/ directory yaml files"""
134134

135135
def check_model(self, model: Model) -> t.Optional[RuleViolation]:
136136
# External models cannot have unit tests

sqlmesh/core/loader.py

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from sqlmesh.core.model import model as model_registry
3636
from sqlmesh.core.model.common import make_python_env
3737
from sqlmesh.core.signal import signal
38-
from sqlmesh.core.test import ModelTestMetadata, filter_tests_by_patterns
38+
from sqlmesh.core.test import ModelTestMetadata
3939
from sqlmesh.utils import UniqueKeyDict, sys_path
4040
from sqlmesh.utils.errors import ConfigError
4141
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor
@@ -427,9 +427,7 @@ def _load_linting_rules(self) -> RuleSet:
427427
"""Loads user linting rules"""
428428
return RuleSet()
429429

430-
def load_model_tests(
431-
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
432-
) -> t.List[ModelTestMetadata]:
430+
def load_model_tests(self) -> t.List[ModelTestMetadata]:
433431
"""Loads YAML-based model tests"""
434432
return []
435433

@@ -868,38 +866,23 @@ def _load_model_test_file(self, path: Path) -> dict[str, ModelTestMetadata]:
868866

869867
return model_test_metadata
870868

871-
def load_model_tests(
872-
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
873-
) -> t.List[ModelTestMetadata]:
869+
def load_model_tests(self) -> t.List[ModelTestMetadata]:
874870
"""Loads YAML-based model tests"""
875871
test_meta_list: t.List[ModelTestMetadata] = []
876872

877-
if tests:
878-
for test in tests:
879-
filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "")
873+
search_path = Path(self.config_path) / c.TESTS
880874

881-
test_meta = self._load_model_test_file(Path(filename))
882-
if test_name:
883-
test_meta_list.append(test_meta[test_name])
884-
else:
885-
test_meta_list.extend(test_meta.values())
886-
else:
887-
search_path = Path(self.config_path) / c.TESTS
888-
889-
for yaml_file in itertools.chain(
890-
search_path.glob("**/test*.yaml"),
891-
search_path.glob("**/test*.yml"),
875+
for yaml_file in itertools.chain(
876+
search_path.glob("**/test*.yaml"),
877+
search_path.glob("**/test*.yml"),
878+
):
879+
if any(
880+
yaml_file.match(ignore_pattern)
881+
for ignore_pattern in self.config.ignore_patterns or []
892882
):
893-
if any(
894-
yaml_file.match(ignore_pattern)
895-
for ignore_pattern in self.config.ignore_patterns or []
896-
):
897-
continue
898-
899-
test_meta_list.extend(self._load_model_test_file(yaml_file).values())
883+
continue
900884

901-
if patterns:
902-
test_meta_list = filter_tests_by_patterns(test_meta_list, patterns)
885+
test_meta_list.extend(self._load_model_test_file(yaml_file).values())
903886

904887
return test_meta_list
905888

sqlmesh/lsp/context.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def __init__(self, context: Context) -> None:
7272

7373
def list_workspace_tests(self) -> t.List[TestEntry]:
7474
"""List all tests in the workspace."""
75-
tests = self.context.load_model_tests()
75+
tests = self.context.select_tests()
7676

7777
# Use a set to ensure unique URIs
7878
unique_test_uris = {URI.from_path(test.path).value for test in tests}
@@ -81,7 +81,9 @@ def list_workspace_tests(self) -> t.List[TestEntry]:
8181
test_ranges = get_test_ranges(URI(uri).to_path())
8282
if uri not in test_uris:
8383
test_uris[uri] = {}
84+
8485
test_uris[uri].update(test_ranges)
86+
8587
return [
8688
TestEntry(
8789
name=test.test_name,
@@ -100,7 +102,7 @@ def get_document_tests(self, uri: URI) -> t.List[TestEntry]:
100102
Returns:
101103
List of TestEntry objects for the specified document.
102104
"""
103-
tests = self.context.load_model_tests(tests=[str(uri.to_path())])
105+
tests = self.context.select_tests(tests=[str(uri.to_path())])
104106
test_ranges = get_test_ranges(uri.to_path())
105107
return [
106108
TestEntry(

sqlmesh/magics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def test(self, context: Context, line: str, test_def_raw: t.Optional[str] = None
337337
if not args.test_name and not args.ls:
338338
raise MagicError("Must provide either test name or `--ls` to list tests")
339339

340-
test_meta = context.load_model_tests()
340+
test_meta = context.select_tests()
341341

342342
tests: t.Dict[str, t.Dict[str, ModelTestMetadata]] = defaultdict(dict)
343343
for model_test_metadata in test_meta:

0 commit comments

Comments
 (0)