|
147 | 147 | from typing_extensions import Literal |
148 | 148 |
|
149 | 149 | from sqlmesh.core.engine_adapter._typing import ( |
150 | | - DF, |
151 | 150 | BigframeSession, |
| 151 | + DF, |
152 | 152 | PySparkDataFrame, |
153 | 153 | PySparkSession, |
154 | 154 | SnowparkSession, |
@@ -403,6 +403,7 @@ def __init__( |
403 | 403 | self._model_test_metadata_path_index: t.Dict[Path, t.List[ModelTestMetadata]] = {} |
404 | 404 | self._model_test_metadata_fully_qualified_name_index: t.Dict[str, ModelTestMetadata] = {} |
405 | 405 | self._models_with_tests: t.Set[str] = set() |
| 406 | + |
406 | 407 | self._macros: UniqueKeyDict[str, ExecutableOrMacro] = UniqueKeyDict("macros") |
407 | 408 | self._metrics: UniqueKeyDict[str, Metric] = UniqueKeyDict("metrics") |
408 | 409 | self._jinja_macros = JinjaMacroRegistry() |
@@ -656,6 +657,7 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]: |
656 | 657 | self._requirements.update(project.requirements) |
657 | 658 | self._excluded_requirements.update(project.excluded_requirements) |
658 | 659 | self._environment_statements.extend(project.environment_statements) |
| 660 | + |
659 | 661 | self._model_test_metadata.extend(project.model_test_metadata) |
660 | 662 | for metadata in project.model_test_metadata: |
661 | 663 | if metadata.path not in self._model_test_metadata_path_index: |
@@ -2243,9 +2245,7 @@ def test( |
2243 | 2245 |
|
2244 | 2246 | pd.set_option("display.max_columns", None) |
2245 | 2247 |
|
2246 | | - test_meta = self._select_tests( |
2247 | | - test_meta=self._model_test_metadata, tests=tests, patterns=match_patterns |
2248 | | - ) |
| 2248 | + test_meta = self.select_tests(tests=tests, patterns=match_patterns) |
2249 | 2249 |
|
2250 | 2250 | result = run_tests( |
2251 | 2251 | model_test_metadata=test_meta, |
@@ -2807,33 +2807,6 @@ def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter: |
2807 | 2807 | raise SQLMeshError(f"Gateway '{gateway}' not found in the available engine adapters.") |
2808 | 2808 | return self.engine_adapter |
2809 | 2809 |
|
2810 | | - def _select_tests( |
2811 | | - self, |
2812 | | - test_meta: t.List[ModelTestMetadata], |
2813 | | - tests: t.Optional[t.List[str]] = None, |
2814 | | - patterns: t.Optional[t.List[str]] = None, |
2815 | | - ) -> t.List[ModelTestMetadata]: |
2816 | | - """Filter pre-loaded test metadata based on tests and patterns.""" |
2817 | | - |
2818 | | - if tests: |
2819 | | - filtered_tests = [] |
2820 | | - for test in tests: |
2821 | | - if "::" in test: |
2822 | | - if test in self._model_test_metadata_fully_qualified_name_index: |
2823 | | - filtered_tests.append( |
2824 | | - self._model_test_metadata_fully_qualified_name_index[test] |
2825 | | - ) |
2826 | | - else: |
2827 | | - test_path = Path(test) |
2828 | | - if test_path in self._model_test_metadata_path_index: |
2829 | | - filtered_tests.extend(self._model_test_metadata_path_index[test_path]) |
2830 | | - test_meta = filtered_tests |
2831 | | - |
2832 | | - if patterns: |
2833 | | - test_meta = filter_tests_by_patterns(test_meta, patterns) |
2834 | | - |
2835 | | - return test_meta |
2836 | | - |
2837 | 2810 | def _snapshots( |
2838 | 2811 | self, models_override: t.Optional[UniqueKeyDict[str, Model]] = None |
2839 | 2812 | ) -> t.Dict[str, Snapshot]: |
@@ -3245,18 +3218,34 @@ def lint_models( |
3245 | 3218 |
|
3246 | 3219 | return all_violations |
3247 | 3220 |
|
3248 | | - def load_model_tests( |
3249 | | - self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None |
| 3221 | + def select_tests( |
| 3222 | + self, |
| 3223 | + tests: t.Optional[t.List[str]] = None, |
| 3224 | + patterns: t.Optional[t.List[str]] = None, |
3250 | 3225 | ) -> t.List[ModelTestMetadata]: |
3251 | | - # If a set of specific test path(s) are provided, we can use a single loader |
3252 | | - # since it's not required to walk every tests/ folder in each repo |
3253 | | - loaders = [self._loaders[0]] if tests else self._loaders |
| 3226 | + """Filter pre-loaded test metadata based on tests and patterns.""" |
| 3227 | + |
| 3228 | + test_meta = self._model_test_metadata |
| 3229 | + |
| 3230 | + if tests: |
| 3231 | + filtered_tests = [] |
| 3232 | + for test in tests: |
| 3233 | + if "::" in test: |
| 3234 | + if test in self._model_test_metadata_fully_qualified_name_index: |
| 3235 | + filtered_tests.append( |
| 3236 | + self._model_test_metadata_fully_qualified_name_index[test] |
| 3237 | + ) |
| 3238 | + else: |
| 3239 | + test_path = Path(test) |
| 3240 | + if test_path in self._model_test_metadata_path_index: |
| 3241 | + filtered_tests.extend(self._model_test_metadata_path_index[test_path]) |
| 3242 | + |
| 3243 | + test_meta = filtered_tests |
3254 | 3244 |
|
3255 | | - model_tests = [] |
3256 | | - for loader in loaders: |
3257 | | - model_tests.extend(loader.load_model_tests(tests=tests, patterns=patterns)) |
| 3245 | + if patterns: |
| 3246 | + test_meta = filter_tests_by_patterns(test_meta, patterns) |
3258 | 3247 |
|
3259 | | - return model_tests |
| 3248 | + return test_meta |
3260 | 3249 |
|
3261 | 3250 |
|
3262 | 3251 | class Context(GenericContext[Config]): |
|
0 commit comments