1111import pandas as pd
1212import pytest
1313from sqlglot import exp , parse_one
14+ from sqlglot .optimizer .normalize_identifiers import normalize_identifiers
1415
1516from sqlmesh import Config , Context , EngineAdapter
1617from sqlmesh .cli .example_project import init_example_project
@@ -51,6 +52,7 @@ def __init__(
5152 self .gateway = gateway
5253 self ._columns_to_types = columns_to_types
5354 self .test_id = random_id (short = True )
55+ self ._context = None
5456
5557 @property
5658 def columns_to_types (self ):
@@ -411,11 +413,14 @@ def create_context(
411413 self ._context = Context (paths = "." , config = config , gateway = self .gateway )
412414 return self ._context
413415
414- def cleanup (self , ctx : Context ):
415- schemas = []
416- for _ , model in ctx .models .items ():
417- schemas .append (model .schema_name )
418- schemas .append (model .physical_schema )
416+ def cleanup (self , ctx : t .Optional [Context ] = None ):
417+ schemas = [self .schema (TEST_SCHEMA )]
418+
419+ ctx = ctx or self ._context
420+ if ctx and ctx .models :
421+ for _ , model in ctx .models .items ():
422+ schemas .append (model .schema_name )
423+ schemas .append (model .physical_schema )
419424
420425 for schema_name in set (schemas ):
421426 self .engine_adapter .drop_schema (
@@ -662,6 +667,14 @@ def ctx(engine_adapter, test_type, mark_gateway):
662667 return TestContext (test_type , engine_adapter , gateway )
663668
664669
@pytest.fixture(autouse=True)
def cleanup(ctx: TestContext):
    """Autouse teardown fixture: after every test runs, ask the shared
    TestContext to drop whatever schemas the test left behind."""
    yield  # control passes to the test body here

    # ctx may be falsy if the context fixture produced nothing; guard before cleanup.
    if ctx:
        ctx.cleanup()
677+
665678def test_catalog_operations (ctx : TestContext ):
666679 if (
667680 ctx .engine_adapter .CATALOG_SUPPORT .is_unsupported
@@ -691,11 +704,11 @@ def test_catalog_operations(ctx: TestContext):
691704 ctx .engine_adapter .execute (f'CREATE DATABASE IF NOT EXISTS "{ catalog_name } "' )
692705 except Exception :
693706 pass
694- current_catalog = ctx .engine_adapter .get_current_catalog ()
707+ current_catalog = ctx .engine_adapter .get_current_catalog (). lower ()
695708 ctx .engine_adapter .set_current_catalog (catalog_name )
696- assert ctx .engine_adapter .get_current_catalog () == catalog_name
709+ assert ctx .engine_adapter .get_current_catalog (). lower () == catalog_name
697710 ctx .engine_adapter .set_current_catalog (current_catalog )
698- assert ctx .engine_adapter .get_current_catalog () == current_catalog
711+ assert ctx .engine_adapter .get_current_catalog (). lower () == current_catalog
699712
700713
701714def test_drop_schema_catalog (ctx : TestContext , caplog ):
@@ -782,21 +795,14 @@ def test_temp_table(ctx: TestContext):
782795 )
783796 table = ctx .table ("example" )
784797
785- # The snowflake adapter persists the DataFrame to an intermediate table because we use the `write_pandas()` function from the Snowflake python library
786- # Other adapters just use SQLGlot to convert the dataframe directly into a SELECT query
787- expected_tables = 2 if ctx .dialect == "snowflake" and ctx .test_type == "df" else 1
788798 with ctx .engine_adapter .temp_table (ctx .input_data (input_data ), table .sql ()) as table_name :
789799 results = ctx .get_metadata_results ()
790800 assert len (results .views ) == 0
791- assert len (results .tables ) == expected_tables
801+ assert len (results .tables ) == 1
792802 assert len (results .non_temp_tables ) == 0
793803 assert len (results .materialized_views ) == 0
794804 ctx .compare_with_current (table_name , input_data )
795805
796- if ctx .dialect == "snowflake" :
797- # force the next query to create a new connection to prove temp tables have been dropped
798- ctx .engine_adapter ._connection_pool .close ()
799-
800806 results = ctx .get_metadata_results ()
801807 assert len (results .views ) == len (results .tables ) == len (results .non_temp_tables ) == 0
802808
@@ -1735,6 +1741,14 @@ def test_sushi(mark_gateway: t.Tuple[str, str], ctx: TestContext):
17351741 personal_paths = [pathlib .Path ("~/.sqlmesh/config.yaml" ).expanduser ()],
17361742 )
17371743 _ , gateway = mark_gateway
1744+
1745+ # clear cache from prior runs
1746+ cache_dir = pathlib .Path ("./examples/sushi/.cache" )
1747+ if cache_dir .exists ():
1748+ import shutil
1749+
1750+ shutil .rmtree (cache_dir )
1751+
17381752 context = Context (paths = "./examples/sushi" , config = config , gateway = gateway )
17391753
17401754 # clean up any leftover schemas from previous runs (requires context)
@@ -1769,7 +1783,7 @@ def test_sushi(mark_gateway: t.Tuple[str, str], ctx: TestContext):
17691783
17701784 context ._models .update ({cust_rev_by_day_key : cust_rev_by_day_model_tbl_props })
17711785
1772- context .plan (
1786+ plan : Plan = context .plan (
17731787 environment = "test_prod" ,
17741788 start = start ,
17751789 end = end ,
@@ -1785,6 +1799,7 @@ def test_sushi(mark_gateway: t.Tuple[str, str], ctx: TestContext):
17851799 yesterday (),
17861800 env_name = "test_prod" ,
17871801 dialect = ctx .dialect ,
1802+ environment_naming_info = plan .environment_naming_info ,
17881803 )
17891804
17901805 # Ensure table and column comments were correctly registered with engine
@@ -1977,10 +1992,13 @@ def validate_no_comments(
19771992 # confirm physical temp table comments are not registered
19781993 validate_no_comments ("sqlmesh__sushi" , table_name_suffix = "__temp" , check_temp_tables = True )
19791994 # confirm view layer comments are not registered in non-PROD environment
1980- validate_no_comments ("sushi__test_prod" , is_physical_layer = False )
1995+ env_name = "test_prod"
1996+ if plan .environment_naming_info and plan .environment_naming_info .normalize_name :
1997+ env_name = normalize_identifiers (env_name , dialect = ctx .dialect ).name
1998+ validate_no_comments (f"sushi__{ env_name } " , is_physical_layer = False )
19811999
19822000 # Ensure that the plan has been applied successfully.
1983- no_change_plan = context .plan (
2001+ no_change_plan : Plan = context .plan (
19842002 environment = "test_dev" ,
19852003 start = start ,
19862004 end = end ,
@@ -2000,6 +2018,7 @@ def validate_no_comments(
20002018 yesterday (),
20012019 env_name = "test_dev" ,
20022020 dialect = ctx .dialect ,
2021+ environment_naming_info = no_change_plan .environment_naming_info ,
20032022 )
20042023
20052024 # confirm view layer comments are registered in PROD
@@ -2051,7 +2070,7 @@ def test_init_project(ctx: TestContext, mark_gateway: t.Tuple[str, str], tmp_pat
20512070 assert len (physical_layer_results .tables ) == len (physical_layer_results .non_temp_tables ) == 6
20522071
20532072 # make and validate unmodified dev environment
2054- no_change_plan = context .plan (
2073+ no_change_plan : Plan = context .plan (
20552074 environment = "test_dev" ,
20562075 skip_tests = True ,
20572076 no_prompts = True ,
@@ -2062,7 +2081,12 @@ def test_init_project(ctx: TestContext, mark_gateway: t.Tuple[str, str], tmp_pat
20622081
20632082 context .apply (no_change_plan )
20642083
2065- dev_schema_results = ctx .get_metadata_results ("sqlmesh_example__test_dev" )
2084+ environment = no_change_plan .environment
2085+ first_snapshot = no_change_plan .environment .snapshots [0 ]
2086+ schema_name = first_snapshot .qualified_view_name .schema_for_environment (
2087+ environment , dialect = ctx .dialect
2088+ )
2089+ dev_schema_results = ctx .get_metadata_results (schema_name )
20662090 assert sorted (dev_schema_results .views ) == [
20672091 "full_model" ,
20682092 "incremental_model" ,
@@ -2234,6 +2258,7 @@ def _mutate_config(current_gateway_name: str, config: Config):
22342258 connection .concurrent_tasks = 1
22352259
22362260 context = ctx .create_context (_mutate_config )
2261+ assert context .default_dialect == "duckdb"
22372262
22382263 schema = ctx .schema (TEST_SCHEMA )
22392264 seed_query = ctx .input_data (
@@ -2278,13 +2303,13 @@ def _mutate_config(current_gateway_name: str, config: Config):
22782303 try :
22792304 context .plan (auto_apply = True , no_prompts = True )
22802305
2281- results = ctx .get_metadata_results (schema )
2306+ test_model = context .get_model (f"{ schema } .test_model" )
2307+ normalized_schema_name = test_model .fully_qualified_table .db
2308+ results = ctx .get_metadata_results (normalized_schema_name )
22822309 assert "test_model" in results .views
22832310
22842311 actual_df = (
2285- ctx .get_current_data (f"{ schema } .test_model" )
2286- .sort_values (by = "event_date" )
2287- .reset_index (drop = True )
2312+ ctx .get_current_data (test_model .fqn ).sort_values (by = "event_date" ).reset_index (drop = True )
22882313 )
22892314 actual_df ["event_date" ] = actual_df ["event_date" ].astype (str )
22902315 assert actual_df .count ()[0 ] == 3
0 commit comments