Skip to content

Commit 2595aca

Browse files
authored
feat: add position details to linter (#4663)
1 parent bc8d3d7 commit 2595aca

File tree

8 files changed

+293
-98
lines changed

8 files changed

+293
-98
lines changed

sqlmesh/core/linter/definition.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections.abc import Iterator, Iterable, Set, Mapping, Callable
99
from functools import reduce
1010
from sqlmesh.core.model import Model
11-
from sqlmesh.core.linter.rule import Rule, RuleViolation
11+
from sqlmesh.core.linter.rule import Rule, RuleViolation, Range
1212
from sqlmesh.core.console import LinterConsole, get_console
1313

1414
if t.TYPE_CHECKING:
@@ -74,6 +74,7 @@ def lint_model(
7474
violation_msg=violation.violation_msg,
7575
model=model,
7676
violation_type="error",
77+
violation_range=violation.violation_range,
7778
)
7879
for violation in error_violations
7980
] + [
@@ -82,6 +83,7 @@ def lint_model(
8283
violation_msg=violation.violation_msg,
8384
model=model,
8485
violation_type="warning",
86+
violation_range=violation.violation_range,
8587
)
8688
for violation in warn_violations
8789
]
@@ -149,7 +151,8 @@ def __init__(
149151
violation_msg: str,
150152
model: Model,
151153
violation_type: t.Literal["error", "warning"],
154+
violation_range: t.Optional[Range] = None,
152155
) -> None:
153-
super().__init__(rule, violation_msg)
156+
super().__init__(rule, violation_msg, violation_range)
154157
self.model = model
155158
self.violation_type = violation_type

sqlmesh/core/linter/helpers.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from pathlib import Path
2+
3+
from sqlmesh.core.linter.rule import Position, Range
4+
from sqlmesh.utils.pydantic import PydanticModel
5+
import typing as t
6+
7+
8+
class TokenPositionDetails(PydanticModel):
9+
"""
10+
Details about a token's position in the source code in the structure provided by SQLGlot.
11+
12+
Attributes:
13+
line (int): The line that the token ends on.
14+
col (int): The column that the token ends on.
15+
start (int): The start index of the token.
16+
end (int): The ending index of the token.
17+
"""
18+
19+
line: int
20+
col: int
21+
start: int
22+
end: int
23+
24+
@staticmethod
25+
def from_meta(meta: t.Dict[str, int]) -> "TokenPositionDetails":
26+
return TokenPositionDetails(
27+
line=meta["line"],
28+
col=meta["col"],
29+
start=meta["start"],
30+
end=meta["end"],
31+
)
32+
33+
def to_range(self, read_file: t.Optional[t.List[str]]) -> Range:
34+
"""
35+
Convert a TokenPositionDetails object to a Range object.
36+
37+
In the circumstances where the token's start and end positions are the same,
38+
there is no need for a read_file parameter, as the range can be derived from the token's
39+
line and column. This is an optimization to avoid unnecessary file reads and should
40+
only be used when the token represents a single character or position in the file.
41+
42+
If the token's start and end positions are different, the read_file parameter is required.
43+
44+
:param read_file: List of lines from the file. Optional
45+
:return: A Range object representing the token's position
46+
"""
47+
if self.start == self.end:
48+
# If the start and end positions are the same, we can create a range directly
49+
return Range(
50+
start=Position(line=self.line - 1, character=self.col - 1),
51+
end=Position(line=self.line - 1, character=self.col),
52+
)
53+
54+
if read_file is None:
55+
raise ValueError("read_file must be provided when start and end positions differ.")
56+
57+
# Convert from 1-indexed to 0-indexed for line only
58+
end_line_0 = self.line - 1
59+
end_col_0 = self.col
60+
61+
# Find the start line and column by counting backwards from the end position
62+
start_pos = self.start
63+
end_pos = self.end
64+
65+
# Initialize with the end position
66+
start_line_0 = end_line_0
67+
start_col_0 = end_col_0 - (end_pos - start_pos + 1)
68+
69+
# If start_col_0 is negative, we need to go back to previous lines
70+
while start_col_0 < 0 and start_line_0 > 0:
71+
start_line_0 -= 1
72+
start_col_0 += len(read_file[start_line_0])
73+
# Account for newline character
74+
if start_col_0 >= 0:
75+
break
76+
start_col_0 += 1 # For the newline character
77+
78+
# Ensure we don't have negative values
79+
start_col_0 = max(0, start_col_0)
80+
return Range(
81+
start=Position(line=start_line_0, character=start_col_0),
82+
end=Position(line=end_line_0, character=end_col_0),
83+
)
84+
85+
86+
def read_range_from_file(file: Path, text_range: Range) -> str:
87+
"""
88+
Read the file and return the content within the specified range.
89+
90+
Args:
91+
file: Path to the file to read
92+
text_range: The range of text to extract
93+
94+
Returns:
95+
The content within the specified range
96+
"""
97+
with file.open("r", encoding="utf-8") as f:
98+
lines = f.readlines()
99+
100+
# Ensure the range is within bounds
101+
start_line = max(0, text_range.start.line)
102+
end_line = min(len(lines), text_range.end.line + 1)
103+
104+
if start_line >= end_line:
105+
return ""
106+
107+
# Extract the relevant portions of each line
108+
result = []
109+
for i in range(start_line, end_line):
110+
line = lines[i]
111+
start_char = text_range.start.character if i == text_range.start.line else 0
112+
end_char = text_range.end.character if i == text_range.end.line else len(line)
113+
result.append(line[start_char:end_char])
114+
115+
return "".join(result)

sqlmesh/core/linter/rule.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import abc
4+
from dataclasses import dataclass
45

56
from sqlmesh.core.model import Model
67

@@ -22,6 +23,22 @@ class RuleLocation(PydanticModel):
2223
start_line: t.Optional[int] = None
2324

2425

26+
@dataclass(frozen=True)
27+
class Position:
28+
"""The position of a rule violation in a file, the position follows the LSP standard."""
29+
30+
line: int
31+
character: int
32+
33+
34+
@dataclass(frozen=True)
35+
class Range:
36+
"""The range of a rule violation in a file. The range follows the LSP standard."""
37+
38+
start: Position
39+
end: Position
40+
41+
2542
class _Rule(abc.ABCMeta):
2643
def __new__(cls: Type[_Rule], clsname: str, bases: t.Tuple, attrs: t.Dict) -> _Rule:
2744
attrs["name"] = clsname.lower()
@@ -45,9 +62,15 @@ def summary(self) -> str:
4562
"""A summary of what this rule checks for."""
4663
return self.__doc__ or ""
4764

48-
def violation(self, violation_msg: t.Optional[str] = None) -> RuleViolation:
65+
def violation(
66+
self,
67+
violation_msg: t.Optional[str] = None,
68+
violation_range: t.Optional[Range] = None,
69+
) -> RuleViolation:
4970
"""Create a RuleViolation instance for this rule"""
50-
return RuleViolation(rule=self, violation_msg=violation_msg or self.summary)
71+
return RuleViolation(
72+
rule=self, violation_msg=violation_msg or self.summary, violation_range=violation_range
73+
)
5174

5275
def get_definition_location(self) -> RuleLocation:
5376
"""Return the file path and position information for this rule.
@@ -79,9 +102,12 @@ def __repr__(self) -> str:
79102

80103

81104
class RuleViolation:
82-
def __init__(self, rule: Rule, violation_msg: str) -> None:
105+
def __init__(
106+
self, rule: Rule, violation_msg: str, violation_range: t.Optional[Range] = None
107+
) -> None:
83108
self.rule = rule
84109
self.violation_msg = violation_msg
110+
self.violation_range = violation_range
85111

86112
def __repr__(self) -> str:
87113
return f"{self.rule.name}: {self.violation_msg}"

sqlmesh/core/linter/rules/builtin.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55
import typing as t
66

7+
from sqlglot.expressions import Star
78
from sqlglot.helper import subclasses
89

9-
from sqlmesh.core.linter.rule import Rule, RuleViolation
10+
from sqlmesh.core.linter.helpers import TokenPositionDetails
11+
from sqlmesh.core.linter.rule import Rule, RuleViolation, Range
1012
from sqlmesh.core.linter.definition import RuleSet
1113
from sqlmesh.core.model import Model, SqlModel
1214

@@ -15,10 +17,25 @@ class NoSelectStar(Rule):
1517
"""Query should not contain SELECT * on its outer most projections, even if it can be expanded."""
1618

1719
def check_model(self, model: Model) -> t.Optional[RuleViolation]:
20+
# Only applies to SQL models, as other model types do not have a query.
1821
if not isinstance(model, SqlModel):
1922
return None
20-
21-
return self.violation() if model.query.is_star else None
23+
if model.query.is_star:
24+
violation_range = self._get_range(model)
25+
return self.violation(violation_range=violation_range)
26+
return None
27+
28+
def _get_range(self, model: SqlModel) -> t.Optional[Range]:
29+
"""Get the range of the violation if available."""
30+
try:
31+
if len(model.query.expressions) == 1 and isinstance(model.query.expressions[0], Star):
32+
return TokenPositionDetails.from_meta(model.query.expressions[0].meta).to_range(
33+
None
34+
)
35+
except Exception:
36+
pass
37+
38+
return None
2239

2340

2441
class InvalidSelectStarExpansion(Rule):

sqlmesh/lsp/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from sqlmesh.core.model.definition import SqlModel
77
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
8-
from sqlmesh.lsp.custom import RenderModelEntry, ModelForRendering
8+
from sqlmesh.lsp.custom import ModelForRendering
99
from sqlmesh.lsp.custom import AllModelsResponse, RenderModelEntry
1010
from sqlmesh.lsp.uri import URI
1111

sqlmesh/lsp/main.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -655,8 +655,24 @@ def _diagnostic_to_lsp_diagnostic(
655655
) -> t.Optional[types.Diagnostic]:
656656
if diagnostic.model._path is None:
657657
return None
658-
with open(diagnostic.model._path, "r", encoding="utf-8") as file:
659-
lines = file.readlines()
658+
if not diagnostic.violation_range:
659+
with open(diagnostic.model._path, "r", encoding="utf-8") as file:
660+
lines = file.readlines()
661+
range = types.Range(
662+
start=types.Position(line=0, character=0),
663+
end=types.Position(line=len(lines) - 1, character=len(lines[-1])),
664+
)
665+
else:
666+
range = types.Range(
667+
start=types.Position(
668+
line=diagnostic.violation_range.start.line,
669+
character=diagnostic.violation_range.start.character,
670+
),
671+
end=types.Position(
672+
line=diagnostic.violation_range.end.line,
673+
character=diagnostic.violation_range.end.character,
674+
),
675+
)
660676

661677
# Get rule definition location for diagnostics link
662678
rule_location = diagnostic.rule.get_definition_location()
@@ -665,10 +681,7 @@ def _diagnostic_to_lsp_diagnostic(
665681

666682
# Use URI format to create a link for "related information"
667683
return types.Diagnostic(
668-
range=types.Range(
669-
start=types.Position(line=0, character=0),
670-
end=types.Position(line=len(lines), character=len(lines[-1])),
671-
),
684+
range=range,
672685
message=diagnostic.violation_msg,
673686
severity=types.DiagnosticSeverity.Error
674687
if diagnostic.violation_type == "error"

0 commit comments

Comments
 (0)