11import pytest
22from sqlmesh .core .context import Context
3- from sqlmesh .lsp .context import LSPContext
3+ from sqlmesh .lsp .context import LSPContext , ModelTarget , AuditTarget
44from sqlmesh .lsp .reference import get_model_definitions_for_a_path
55
66
@@ -9,11 +9,16 @@ def test_reference() -> None:
99 context = Context (paths = ["examples/sushi" ])
1010 lsp_context = LSPContext (context )
1111
12+ # Find model URIs
1213 active_customers_uri = next (
13- uri for uri , models in lsp_context .map .items () if "sushi.active_customers" in models
14+ uri
15+ for uri , info in lsp_context .map .items ()
16+ if isinstance (info , ModelTarget ) and "sushi.active_customers" in info .names
1417 )
1518 sushi_customers_uri = next (
16- uri for uri , models in lsp_context .map .items () if "sushi.customers" in models
19+ uri
20+ for uri , info in lsp_context .map .items ()
21+ if isinstance (info , ModelTarget ) and "sushi.customers" in info .names
1722 )
1823
1924 references = get_model_definitions_for_a_path (lsp_context , active_customers_uri )
@@ -35,7 +40,9 @@ def test_reference_with_alias() -> None:
3540 lsp_context = LSPContext (context )
3641
3742 waiter_revenue_by_day_uri = next (
38- uri for uri , models in lsp_context .map .items () if "sushi.waiter_revenue_by_day" in models
43+ uri
44+ for uri , info in lsp_context .map .items ()
45+ if isinstance (info , ModelTarget ) and "sushi.waiter_revenue_by_day" in info .names
3946 )
4047
4148 references = get_model_definitions_for_a_path (lsp_context , waiter_revenue_by_day_uri )
@@ -52,6 +59,37 @@ def test_reference_with_alias() -> None:
5259 assert get_string_from_range (read_file , references [2 ].range ) == "sushi.items"
5360
5461
62+ @pytest .mark .fast
63+ def test_standalone_audit_reference () -> None :
64+ context = Context (paths = ["examples/sushi" ])
65+ lsp_context = LSPContext (context )
66+
67+ # Find the standalone audit URI
68+ audit_uri = next (
69+ uri
70+ for uri , info in lsp_context .map .items ()
71+ if isinstance (info , AuditTarget ) and info .name == "assert_item_price_above_zero"
72+ )
73+
74+ # Find the items model URI
75+ items_uri = next (
76+ uri
77+ for uri , info in lsp_context .map .items ()
78+ if isinstance (info , ModelTarget ) and "sushi.items" in info .names
79+ )
80+
81+ references = get_model_definitions_for_a_path (lsp_context , audit_uri )
82+
83+ assert len (references ) == 1
84+ assert references [0 ].uri == items_uri
85+
86+ # Check that the reference in the correct range is sushi.items
87+ path = audit_uri .removeprefix ("file://" )
88+ read_file = open (path , "r" ).readlines ()
89+ referenced_text = get_string_from_range (read_file , references [0 ].range )
90+ assert referenced_text == "sushi.items"
91+
92+
5593def get_string_from_range (file_lines , range_obj ) -> str :
5694 start_line = range_obj .start .line
5795 end_line = range_obj .end .line
0 commit comments