2323from sqlmesh .core .engine_adapter .shared import CatalogSupport
2424from sqlmesh .core .engine_adapter import EngineAdapter
2525from sqlmesh .utils .errors import ConfigError
26- from sqlmesh .utils .pydantic import (
27- field_validator ,
28- model_validator ,
29- model_validator_v1_args ,
30- field_validator_v1_args ,
31- )
26+ from sqlmesh .utils .pydantic import ValidationInfo , field_validator , model_validator
3227from sqlmesh .utils .aws import validate_s3_uri
3328
29+ if t .TYPE_CHECKING :
30+ from sqlmesh .core ._typing import Self
31+
3432logger = logging .getLogger (__name__ )
3533
3634RECOMMENDED_STATE_SYNC_ENGINES = {"postgres" , "gcp_postgres" , "mysql" , "mssql" }
@@ -163,19 +161,20 @@ class BaseDuckDBConnectionConfig(ConnectionConfig):
163161 _data_file_to_adapter : t .ClassVar [t .Dict [str , EngineAdapter ]] = {}
164162
165163 @model_validator (mode = "before" )
166- @ model_validator_v1_args
167- def _validate_database_catalogs (
168- cls , values : t . Dict [ str , t . Optional [ str ]]
169- ) -> t . Dict [ str , t . Optional [ str ]]:
170- if db_path := values .get ("database" ) and values .get ("catalogs" ):
164+ def _validate_database_catalogs ( cls , data : t . Any ) -> t . Any :
165+ if not isinstance ( data , dict ):
166+ return data
167+
168+ if db_path := data .get ("database" ) and data .get ("catalogs" ):
171169 raise ConfigError (
172170 "Cannot specify both `database` and `catalogs`. Define all your catalogs in `catalogs` and have the first entry be the default catalog"
173171 )
174172 if isinstance (db_path , str ) and db_path .startswith ("md:" ):
175173 raise ConfigError (
176174 "Please use connection type 'motherduck' without the `md:` prefix if you want to use a MotherDuck database as the single `database`."
177175 )
178- return values
176+
177+ return data
179178
180179 @property
181180 def _engine_adapter (self ) -> t .Type [EngineAdapter ]:
@@ -430,29 +429,29 @@ class SnowflakeConnectionConfig(ConnectionConfig):
430429 _concurrent_tasks_validator = concurrent_tasks_validator
431430
432431 @model_validator (mode = "before" )
433- @model_validator_v1_args
434- def _validate_authenticator (
435- cls , values : t .Dict [str , t .Optional [str ]]
436- ) -> t .Dict [str , t .Optional [str ]]:
437- from snowflake .connector .network import (
438- DEFAULT_AUTHENTICATOR ,
439- OAUTH_AUTHENTICATOR ,
440- )
432+ def _validate_authenticator (cls , data : t .Any ) -> t .Any :
433+ if not isinstance (data , dict ):
434+ return data
441435
442- auth = values .get ("authenticator" )
436+ from snowflake .connector .network import DEFAULT_AUTHENTICATOR , OAUTH_AUTHENTICATOR
437+
438+ auth = data .get ("authenticator" )
443439 auth = auth .upper () if auth else DEFAULT_AUTHENTICATOR
444- user = values .get ("user" )
445- password = values .get ("password" )
446- values ["private_key" ] = cls ._get_private_key (values , auth ) # type: ignore
440+ user = data .get ("user" )
441+ password = data .get ("password" )
442+ data ["private_key" ] = cls ._get_private_key (data , auth ) # type: ignore
443+
447444 if (
448445 auth == DEFAULT_AUTHENTICATOR
449- and not values .get ("private_key" )
446+ and not data .get ("private_key" )
450447 and (not user or not password )
451448 ):
452449 raise ConfigError ("User and password must be provided if using default authentication" )
453- if auth == OAUTH_AUTHENTICATOR and not values .get ("token" ):
450+
451+ if auth == OAUTH_AUTHENTICATOR and not data .get ("token" ):
454452 raise ConfigError ("Token must be provided if using oauth authentication" )
455- return values
453+
454+ return data
456455
457456 @classmethod
458457 def _get_private_key (cls , values : t .Dict [str , t .Optional [str ]], auth : str ) -> t .Optional [bytes ]:
@@ -621,26 +620,28 @@ class DatabricksConnectionConfig(ConnectionConfig):
621620 _http_headers_validator = http_headers_validator
622621
623622 @model_validator (mode = "before" )
624- @model_validator_v1_args
625- def _databricks_connect_validator (cls , values : t .Dict [str , t .Any ]) -> t .Dict [str , t .Any ]:
623+ def _databricks_connect_validator (cls , data : t .Any ) -> t .Any :
624+ if not isinstance (data , dict ):
625+ return data
626+
626627 from sqlmesh .core .engine_adapter .databricks import DatabricksEngineAdapter
627628
628629 if DatabricksEngineAdapter .can_access_spark_session (
629- bool (values .get ("disable_spark_session" ))
630+ bool (data .get ("disable_spark_session" ))
630631 ):
631- return values
632+ return data
632633
633- databricks_connect_use_serverless = values .get ("databricks_connect_use_serverless" )
634+ databricks_connect_use_serverless = data .get ("databricks_connect_use_serverless" )
634635 server_hostname , http_path , access_token , auth_type = (
635- values .get ("server_hostname" ),
636- values .get ("http_path" ),
637- values .get ("access_token" ),
638- values .get ("auth_type" ),
636+ data .get ("server_hostname" ),
637+ data .get ("http_path" ),
638+ data .get ("access_token" ),
639+ data .get ("auth_type" ),
639640 )
640641
641642 if databricks_connect_use_serverless :
642- values ["force_databricks_connect" ] = True
643- values ["disable_databricks_connect" ] = False
643+ data ["force_databricks_connect" ] = True
644+ data ["disable_databricks_connect" ] = False
644645
645646 if (not server_hostname or not http_path or not access_token ) and (
646647 not databricks_connect_use_serverless and not auth_type
@@ -651,35 +652,35 @@ def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str
651652 if (
652653 databricks_connect_use_serverless
653654 and not server_hostname
654- and not values .get ("databricks_connect_server_hostname" )
655+ and not data .get ("databricks_connect_server_hostname" )
655656 ):
656657 raise ValueError (
657658 "`server_hostname` or `databricks_connect_server_hostname` is required when `databricks_connect_use_serverless` is set"
658659 )
659660 if DatabricksEngineAdapter .can_access_databricks_connect (
660- bool (values .get ("disable_databricks_connect" ))
661+ bool (data .get ("disable_databricks_connect" ))
661662 ):
662- if not values .get ("databricks_connect_access_token" ):
663- values ["databricks_connect_access_token" ] = access_token
664- if not values .get ("databricks_connect_server_hostname" ):
665- values ["databricks_connect_server_hostname" ] = f"https://{ server_hostname } "
663+ if not data .get ("databricks_connect_access_token" ):
664+ data ["databricks_connect_access_token" ] = access_token
665+ if not data .get ("databricks_connect_server_hostname" ):
666+ data ["databricks_connect_server_hostname" ] = f"https://{ server_hostname } "
666667 if not databricks_connect_use_serverless :
667- if not values .get ("databricks_connect_cluster_id" ):
668+ if not data .get ("databricks_connect_cluster_id" ):
668669 if t .TYPE_CHECKING :
669670 assert http_path is not None
670- values ["databricks_connect_cluster_id" ] = http_path .split ("/" )[- 1 ]
671+ data ["databricks_connect_cluster_id" ] = http_path .split ("/" )[- 1 ]
671672
672673 if auth_type :
673674 from databricks .sql .auth .auth import AuthType
674675
675- all_values = [m .value for m in AuthType ]
676- if auth_type not in all_values :
676+ all_data = [m .value for m in AuthType ]
677+ if auth_type not in all_data :
677678 raise ValueError (
678- f"`auth_type` { auth_type } does not match a valid option: { all_values } "
679+ f"`auth_type` { auth_type } does not match a valid option: { all_data } "
679680 )
680681
681- client_id = values .get ("oauth_client_id" )
682- client_secret = values .get ("oauth_client_secret" )
682+ client_id = data .get ("oauth_client_id" )
683+ client_secret = data .get ("oauth_client_secret" )
683684
684685 if client_secret and not client_id :
685686 raise ValueError (
@@ -689,7 +690,7 @@ def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str
689690 if not http_path :
690691 raise ValueError ("`http_path` is still required when using `auth_type`" )
691692
692- return values
693+ return data
693694
694695 @property
695696 def _connection_kwargs_keys (self ) -> t .Set [str ]:
@@ -866,26 +867,24 @@ class BigQueryConnectionConfig(ConnectionConfig):
866867 type_ : t .Literal ["bigquery" ] = Field (alias = "type" , default = "bigquery" )
867868
868869 @field_validator ("execution_project" )
869- @field_validator_v1_args
870870 def validate_execution_project (
871871 cls ,
872872 v : t .Optional [str ],
873- values : t . Dict [ str , t . Any ] ,
873+ info : ValidationInfo ,
874874 ) -> t .Optional [str ]:
875- if v and not values .get ("project" ):
875+ if v and not info . data .get ("project" ):
876876 raise ConfigError (
877877 "If the `execution_project` field is specified, you must also specify the `project` field to provide a default object location."
878878 )
879879 return v
880880
881881 @field_validator ("quota_project" )
882- @field_validator_v1_args
883882 def validate_quota_project (
884883 cls ,
885884 v : t .Optional [str ],
886- values : t . Dict [ str , t . Any ] ,
885+ info : ValidationInfo ,
887886 ) -> t .Optional [str ]:
888- if v and not values .get ("project" ):
887+ if v and not info . data .get ("project" ):
889888 raise ConfigError (
890889 "If the `quota_project` field is specified, you must also specify the `project` field to provide a default object location."
891890 )
@@ -998,12 +997,13 @@ class GCPPostgresConnectionConfig(ConnectionConfig):
998997 pre_ping : bool = True
999998
1000999 @model_validator (mode = "before" )
1001- @model_validator_v1_args
1002- def _validate_auth_method (
1003- cls , values : t .Dict [str , t .Optional [str ]]
1004- ) -> t .Dict [str , t .Optional [str ]]:
1005- password = values .get ("password" )
1006- enable_iam_auth = values .get ("enable_iam_auth" )
1000+ def _validate_auth_method (cls , data : t .Any ) -> t .Any :
1001+ if not isinstance (data , dict ):
1002+ return data
1003+
1004+ password = data .get ("password" )
1005+ enable_iam_auth = data .get ("enable_iam_auth" )
1006+
10071007 if password and enable_iam_auth :
10081008 raise ConfigError (
10091009 "Invalid GCP Postgres connection configuration - both password and"
@@ -1016,7 +1016,8 @@ def _validate_auth_method(
10161016 " for a postgres user account or enable_iam_auth set to 'True'"
10171017 " for an IAM user account."
10181018 )
1019- return values
1019+
1020+ return data
10201021
10211022 @property
10221023 def _connection_kwargs_keys (self ) -> t .Set [str ]:
@@ -1437,40 +1438,37 @@ class TrinoConnectionConfig(ConnectionConfig):
14371438 type_ : t .Literal ["trino" ] = Field (alias = "type" , default = "trino" )
14381439
14391440 @model_validator (mode = "after" )
1440- @model_validator_v1_args
1441- def _root_validator (cls , values : t .Dict [str , t .Any ]) -> t .Dict [str , t .Any ]:
1442- port = values .get ("port" )
1443- if (
1444- values ["http_scheme" ] == "http"
1445- and not values ["method" ].is_no_auth
1446- and not values ["method" ].is_basic
1447- ):
1441+ def _root_validator (self ) -> Self :
1442+ port = self .port
1443+ if self .http_scheme == "http" and not self .method .is_no_auth and not self .method .is_basic :
14481444 raise ConfigError ("HTTP scheme can only be used with no-auth or basic method" )
1445+
14491446 if port is None :
1450- values ["port" ] = 80 if values ["http_scheme" ] == "http" else 443
1451- if (values ["method" ].is_ldap or values ["method" ].is_basic ) and (
1452- not values ["password" ] or not values ["user" ]
1453- ):
1447+ self .port = 80 if self .http_scheme == "http" else 443
1448+
1449+ if (self .method .is_ldap or self .method .is_basic ) and (not self .password or not self .user ):
14541450 raise ConfigError (
1455- f"Username and Password must be provided if using { values [ ' method' ] .value } authentication"
1451+ f"Username and Password must be provided if using { self . method .value } authentication"
14561452 )
1457- if values ["method" ].is_kerberos and (
1458- not values ["principal" ] or not values ["keytab" ] or not values ["krb5_config" ]
1453+
1454+ if self .method .is_kerberos and (
1455+ not self .principal or not self .keytab or not self .krb5_config
14591456 ):
14601457 raise ConfigError (
14611458 "Kerberos requires the following fields: principal, keytab, and krb5_config"
14621459 )
1463- if values ["method" ].is_jwt and not values ["jwt_token" ]:
1460+
1461+ if self .method .is_jwt and not self .jwt_token :
14641462 raise ConfigError ("JWT requires `jwt_token` to be set" )
1465- if values ["method" ].is_certificate and (
1466- not values ["cert" ]
1467- or not values ["client_certificate" ]
1468- or not values ["client_private_key" ]
1463+
1464+ if self .method .is_certificate and (
1465+ not self .cert or not self .client_certificate or not self .client_private_key
14691466 ):
14701467 raise ConfigError (
14711468 "Certificate requires the following fields: cert, client_certificate, and client_private_key"
14721469 )
1473- return values
1470+
1471+ return self
14741472
14751473 @property
14761474 def _connection_kwargs_keys (self ) -> t .Set [str ]:
@@ -1677,26 +1675,23 @@ class AthenaConnectionConfig(ConnectionConfig):
16771675 type_ : t .Literal ["athena" ] = Field (alias = "type" , default = "athena" )
16781676
16791677 @model_validator (mode = "after" )
1680- @model_validator_v1_args
1681- def _root_validator (cls , values : t .Dict [str , t .Any ]) -> t .Dict [str , t .Any ]:
1682- work_group = values .get ("work_group" )
1683- s3_staging_dir = values .get ("s3_staging_dir" )
1684- s3_warehouse_location = values .get ("s3_warehouse_location" )
1678+ def _root_validator (self ) -> Self :
1679+ work_group = self .work_group
1680+ s3_staging_dir = self .s3_staging_dir
1681+ s3_warehouse_location = self .s3_warehouse_location
16851682
16861683 if not work_group and not s3_staging_dir :
16871684 raise ConfigError ("At least one of work_group or s3_staging_dir must be set" )
16881685
16891686 if s3_staging_dir :
1690- values ["s3_staging_dir" ] = validate_s3_uri (
1691- s3_staging_dir , base = True , error_type = ConfigError
1692- )
1687+ self .s3_staging_dir = validate_s3_uri (s3_staging_dir , base = True , error_type = ConfigError )
16931688
16941689 if s3_warehouse_location :
1695- values [ " s3_warehouse_location" ] = validate_s3_uri (
1690+ self . s3_warehouse_location = validate_s3_uri (
16961691 s3_warehouse_location , base = True , error_type = ConfigError
16971692 )
16981693
1699- return values
1694+ return self
17001695
17011696 @property
17021697 def _connection_kwargs_keys (self ) -> t .Set [str ]:
0 commit comments