Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 29 additions & 40 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@
from typing_extensions import Literal

from sqlmesh.core.engine_adapter._typing import (
DF,
BigframeSession,
DF,
PySparkDataFrame,
PySparkSession,
SnowparkSession,
Expand Down Expand Up @@ -403,6 +403,7 @@ def __init__(
self._model_test_metadata_path_index: t.Dict[Path, t.List[ModelTestMetadata]] = {}
self._model_test_metadata_fully_qualified_name_index: t.Dict[str, ModelTestMetadata] = {}
self._models_with_tests: t.Set[str] = set()

self._macros: UniqueKeyDict[str, ExecutableOrMacro] = UniqueKeyDict("macros")
self._metrics: UniqueKeyDict[str, Metric] = UniqueKeyDict("metrics")
self._jinja_macros = JinjaMacroRegistry()
Expand Down Expand Up @@ -656,6 +657,7 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
self._requirements.update(project.requirements)
self._excluded_requirements.update(project.excluded_requirements)
self._environment_statements.extend(project.environment_statements)

self._model_test_metadata.extend(project.model_test_metadata)
for metadata in project.model_test_metadata:
if metadata.path not in self._model_test_metadata_path_index:
Expand Down Expand Up @@ -2243,9 +2245,7 @@ def test(

pd.set_option("display.max_columns", None)

test_meta = self._select_tests(
test_meta=self._model_test_metadata, tests=tests, patterns=match_patterns
)
test_meta = self.select_tests(tests=tests, patterns=match_patterns)

result = run_tests(
model_test_metadata=test_meta,
Expand Down Expand Up @@ -2807,33 +2807,6 @@ def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
raise SQLMeshError(f"Gateway '{gateway}' not found in the available engine adapters.")
return self.engine_adapter

def _select_tests(
self,
test_meta: t.List[ModelTestMetadata],
tests: t.Optional[t.List[str]] = None,
patterns: t.Optional[t.List[str]] = None,
) -> t.List[ModelTestMetadata]:
"""Filter pre-loaded test metadata based on tests and patterns."""

if tests:
filtered_tests = []
for test in tests:
if "::" in test:
if test in self._model_test_metadata_fully_qualified_name_index:
filtered_tests.append(
self._model_test_metadata_fully_qualified_name_index[test]
)
else:
test_path = Path(test)
if test_path in self._model_test_metadata_path_index:
filtered_tests.extend(self._model_test_metadata_path_index[test_path])
test_meta = filtered_tests

if patterns:
test_meta = filter_tests_by_patterns(test_meta, patterns)

return test_meta

def _snapshots(
self, models_override: t.Optional[UniqueKeyDict[str, Model]] = None
) -> t.Dict[str, Snapshot]:
Expand Down Expand Up @@ -3245,18 +3218,34 @@ def lint_models(

return all_violations

def load_model_tests(
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
def select_tests(
self,
tests: t.Optional[t.List[str]] = None,
patterns: t.Optional[t.List[str]] = None,
) -> t.List[ModelTestMetadata]:
# If a set of specific test path(s) are provided, we can use a single loader
# since it's not required to walk every tests/ folder in each repo
loaders = [self._loaders[0]] if tests else self._loaders
"""Filter pre-loaded test metadata based on tests and patterns."""

test_meta = self._model_test_metadata

if tests:
filtered_tests = []
for test in tests:
if "::" in test:
if test in self._model_test_metadata_fully_qualified_name_index:
filtered_tests.append(
self._model_test_metadata_fully_qualified_name_index[test]
)
else:
test_path = Path(test)
if test_path in self._model_test_metadata_path_index:
filtered_tests.extend(self._model_test_metadata_path_index[test_path])

test_meta = filtered_tests

model_tests = []
for loader in loaders:
model_tests.extend(loader.load_model_tests(tests=tests, patterns=patterns))
if patterns:
test_meta = filter_tests_by_patterns(test_meta, patterns)

return model_tests
return test_meta


class Context(GenericContext[Config]):
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/linter/rules/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def check_model(self, model: Model) -> t.Optional[RuleViolation]:


class NoMissingUnitTest(Rule):
"""All models must have a unit test found in the test/ directory yaml files"""
"""All models must have a unit test found in the tests/ directory yaml files"""

def check_model(self, model: Model) -> t.Optional[RuleViolation]:
# External models cannot have unit tests
Expand Down
43 changes: 13 additions & 30 deletions sqlmesh/core/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from sqlmesh.core.model import model as model_registry
from sqlmesh.core.model.common import make_python_env
from sqlmesh.core.signal import signal
from sqlmesh.core.test import ModelTestMetadata, filter_tests_by_patterns
from sqlmesh.core.test import ModelTestMetadata
from sqlmesh.utils import UniqueKeyDict, sys_path
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor
Expand Down Expand Up @@ -427,9 +427,7 @@ def _load_linting_rules(self) -> RuleSet:
"""Loads user linting rules"""
return RuleSet()

def load_model_tests(
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
) -> t.List[ModelTestMetadata]:
def load_model_tests(self) -> t.List[ModelTestMetadata]:
"""Loads YAML-based model tests"""
return []

Expand Down Expand Up @@ -868,38 +866,23 @@ def _load_model_test_file(self, path: Path) -> dict[str, ModelTestMetadata]:

return model_test_metadata

def load_model_tests(
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
) -> t.List[ModelTestMetadata]:
def load_model_tests(self) -> t.List[ModelTestMetadata]:
"""Loads YAML-based model tests"""
test_meta_list: t.List[ModelTestMetadata] = []

if tests:
for test in tests:
filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "")
search_path = Path(self.config_path) / c.TESTS

test_meta = self._load_model_test_file(Path(filename))
if test_name:
test_meta_list.append(test_meta[test_name])
else:
test_meta_list.extend(test_meta.values())
else:
search_path = Path(self.config_path) / c.TESTS

for yaml_file in itertools.chain(
search_path.glob("**/test*.yaml"),
search_path.glob("**/test*.yml"),
for yaml_file in itertools.chain(
search_path.glob("**/test*.yaml"),
search_path.glob("**/test*.yml"),
):
if any(
yaml_file.match(ignore_pattern)
for ignore_pattern in self.config.ignore_patterns or []
):
if any(
yaml_file.match(ignore_pattern)
for ignore_pattern in self.config.ignore_patterns or []
):
continue

test_meta_list.extend(self._load_model_test_file(yaml_file).values())
continue

if patterns:
test_meta_list = filter_tests_by_patterns(test_meta_list, patterns)
test_meta_list.extend(self._load_model_test_file(yaml_file).values())

return test_meta_list

Expand Down
6 changes: 4 additions & 2 deletions sqlmesh/lsp/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, context: Context) -> None:

def list_workspace_tests(self) -> t.List[TestEntry]:
"""List all tests in the workspace."""
tests = self.context.load_model_tests()
tests = self.context.select_tests()

# Use a set to ensure unique URIs
unique_test_uris = {URI.from_path(test.path).value for test in tests}
Expand All @@ -81,7 +81,9 @@ def list_workspace_tests(self) -> t.List[TestEntry]:
test_ranges = get_test_ranges(URI(uri).to_path())
if uri not in test_uris:
test_uris[uri] = {}

test_uris[uri].update(test_ranges)

return [
TestEntry(
name=test.test_name,
Expand All @@ -100,7 +102,7 @@ def get_document_tests(self, uri: URI) -> t.List[TestEntry]:
Returns:
List of TestEntry objects for the specified document.
"""
tests = self.context.load_model_tests(tests=[str(uri.to_path())])
tests = self.context.select_tests(tests=[str(uri.to_path())])
test_ranges = get_test_ranges(uri.to_path())
return [
TestEntry(
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def test(self, context: Context, line: str, test_def_raw: t.Optional[str] = None
if not args.test_name and not args.ls:
raise MagicError("Must provide either test name or `--ls` to list tests")

test_meta = context.load_model_tests()
test_meta = context.select_tests()

tests: t.Dict[str, t.Dict[str, ModelTestMetadata]] = defaultdict(dict)
for model_test_metadata in test_meta:
Expand Down
Loading