1111from sqlmesh .core .linter .rule import Rule , RuleViolation
1212from sqlmesh .core .console import LinterConsole , get_console
1313
14+ if t .TYPE_CHECKING :
15+ from sqlmesh .core .context import GenericContext
16+
1417
1518def select_rules (all_rules : RuleSet , rule_names : t .Set [str ]) -> RuleSet :
1619 if "all" in rule_names :
@@ -52,7 +55,7 @@ def from_rules(cls, all_rules: RuleSet, config: LinterConfig) -> Linter:
5255 return Linter (config .enabled , all_rules , rules , warn_rules )
5356
5457 def lint_model (
55- self , model : Model , console : LinterConsole = get_console ()
58+ self , model : Model , context : GenericContext , console : LinterConsole = get_console ()
5659 ) -> t .Tuple [bool , t .List [AnnotatedRuleViolation ]]:
5760 if not self .enabled :
5861 return False , []
@@ -62,8 +65,8 @@ def lint_model(
6265 rules = self .rules .difference (ignored_rules )
6366 warn_rules = self .warn_rules .difference (ignored_rules )
6467
65- error_violations = rules .check_model (model )
66- warn_violations = warn_rules .check_model (model )
68+ error_violations = rules .check_model (model , context )
69+ warn_violations = warn_rules .check_model (model , context )
6770
6871 all_violations : t .List [AnnotatedRuleViolation ] = [
6972 AnnotatedRuleViolation (
@@ -96,11 +99,11 @@ class RuleSet(Mapping[str, type[Rule]]):
9699 def __init__ (self , rules : Iterable [type [Rule ]] = ()) -> None :
97100 self ._underlying = {rule .name : rule for rule in rules }
98101
99- def check_model (self , model : Model ) -> t .List [RuleViolation ]:
102+ def check_model (self , model : Model , context : GenericContext ) -> t .List [RuleViolation ]:
100103 violations = []
101104
102105 for rule in self ._underlying .values ():
103- violation = rule ().check_model (model )
106+ violation = rule (context ).check_model (model )
104107
105108 if violation :
106109 violations .append (violation )
0 commit comments