Skip to content

Commit 27c29b4

Browse files
authored
Feat: Introduce support for MSSQL
1 parent cab40ae commit 27c29b4

File tree

7 files changed

+900
-1
lines changed

7 files changed

+900
-1
lines changed

setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ ignore_missing_imports = True
6969
[mypy-mysql.*]
7070
ignore_missing_imports = True
7171

72+
[mypy-pymssql.*]
73+
ignore_missing_imports = True
74+
7275
[mypy-psycopg2.*]
7376
ignore_missing_imports = True
7477

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@
104104
"langchain",
105105
"openai",
106106
],
107+
"mssql": [
108+
"pymssql",
109+
],
107110
"mysql": [
108111
"mysql-connector-python",
109112
],

sqlmesh/core/config/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
DatabricksConnectionConfig,
77
DuckDBConnectionConfig,
88
GCPPostgresConnectionConfig,
9+
MSSQLConnectionConfig,
910
MySQLConnectionConfig,
1011
PostgresConnectionConfig,
1112
RedshiftConnectionConfig,

sqlmesh/core/config/connection.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,54 @@ def _connection_factory(self) -> t.Callable:
673673
return connect
674674

675675

676+
class MSSQLConnectionConfig(_ConnectionConfig):
    """Connection configuration for Microsoft SQL Server.

    Every key listed in `_connection_kwargs_keys` is forwarded as a keyword
    argument to `pymssql.connect` (see `_connection_factory`), so the field
    defaults below mirror pymssql's own connect defaults — confirm against the
    installed pymssql version.
    """

    # Hostname or IP of the SQL Server instance.
    host: str
    # Login credentials.
    user: str
    password: str
    # Initial database for the session ("" lets the server pick its default).
    database: t.Optional[str] = ""
    # Query timeout in seconds — semantics per pymssql.connect (presumably 0
    # means no timeout; verify against pymssql docs).
    timeout: t.Optional[int] = 0
    # Timeout in seconds for establishing the connection itself.
    login_timeout: t.Optional[int] = 60
    # Character set used for the connection.
    charset: t.Optional[str] = "UTF-8"
    # If True, pymssql returns rows as dicts rather than tuples.
    as_dict: t.Optional[bool] = False
    # Application name reported to the server (shows up in server-side logs).
    appname: t.Optional[str] = None
    # TCP port; 1433 is the SQL Server default.
    port: t.Optional[int] = 1433
    # Passed through to pymssql's conn_properties: SQL run on connect.
    conn_properties: t.Optional[t.Union[t.Iterable[str], str]] = None
    # Whether the connection starts in autocommit mode.
    autocommit: t.Optional[bool] = False
    # TDS protocol version override (None = pymssql default).
    tds_version: t.Optional[str] = None

    # Not a pymssql kwarg — SQLMesh-level concurrency setting.
    concurrent_tasks: int = 4

    # Discriminator used to select this config from the connection-type union;
    # serialized/parsed under the alias "type".
    type_: Literal["mssql"] = Field(alias="type", default="mssql")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        # The subset of fields handed to pymssql.connect as **kwargs.
        return {
            "host",
            "user",
            "password",
            "database",
            "timeout",
            "login_timeout",
            "charset",
            "as_dict",
            "appname",
            "port",
            "conn_properties",
            "autocommit",
            "tds_version",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.MSSQLEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        # Imported lazily so the "mssql" extra is only required when used.
        import pymssql

        return pymssql.connect
722+
723+
676724
class SparkConnectionConfig(_ConnectionConfig):
677725
"""
678726
Vanilla Spark Connection Configuration. Use `DatabricksConnectionConfig` for Databricks.
@@ -727,6 +775,7 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
727775
GCPPostgresConnectionConfig,
728776
DatabricksConnectionConfig,
729777
DuckDBConnectionConfig,
778+
MSSQLConnectionConfig,
730779
MySQLConnectionConfig,
731780
PostgresConnectionConfig,
732781
RedshiftConnectionConfig,

sqlmesh/core/engine_adapter/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sqlmesh.core.engine_adapter.bigquery import BigQueryEngineAdapter
1010
from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter
1111
from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter
12+
from sqlmesh.core.engine_adapter.mssql import MSSQLEngineAdapter
1213
from sqlmesh.core.engine_adapter.mysql import MySQLEngineAdapter
1314
from sqlmesh.core.engine_adapter.postgres import PostgresEngineAdapter
1415
from sqlmesh.core.engine_adapter.redshift import RedshiftEngineAdapter
@@ -25,7 +26,7 @@
2526
"redshift": RedshiftEngineAdapter,
2627
"postgres": PostgresEngineAdapter,
2728
"mysql": MySQLEngineAdapter,
28-
"mssql": EngineAdapterWithIndexSupport,
29+
"mssql": MSSQLEngineAdapter,
2930
}
3031

3132
DIALECT_ALIASES = {
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
"""Contains MSSQLEngineAdapter."""
2+
3+
4+
from __future__ import annotations
5+
6+
import contextlib
7+
import typing as t
8+
9+
import pandas as pd
10+
from sqlglot import exp
11+
12+
from sqlmesh.core.engine_adapter.base import EngineAdapterWithIndexSupport
13+
from sqlmesh.core.engine_adapter.mixins import (
14+
LogicalReplaceQueryMixin,
15+
PandasNativeFetchDFSupportMixin,
16+
)
17+
from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType
18+
from sqlmesh.utils.errors import SQLMeshError
19+
20+
if t.TYPE_CHECKING:
21+
import pymssql
22+
23+
from sqlmesh.core._typing import TableName
24+
from sqlmesh.core.engine_adapter._typing import Query, QueryOrDF
25+
26+
27+
class MSSQLEngineAdapter(
28+
EngineAdapterWithIndexSupport,
29+
LogicalReplaceQueryMixin,
30+
PandasNativeFetchDFSupportMixin,
31+
):
32+
"""Implementation of EngineAdapterWithIndexSupport for MsSql compatibility.
33+
34+
Args:
35+
connection_factory: a callable which produces a new Database API-compliant
36+
connection on every call.
37+
dialect: The dialect with which this adapter is associated.
38+
multithreaded: Indicates whether this adapter will be used by more than one thread.
39+
"""
40+
41+
DIALECT: str = "tsql"
42+
43+
def table_exists(self, table_name: TableName) -> bool:
44+
"""
45+
Similar to Postgres, MsSql doesn't support describe so I'm using what
46+
is used there and what the redshift cursor does to check if a table
47+
exists. We don't use this directly in order for this to work as a base
48+
class for other postgres.
49+
50+
Reference: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/cursor.py#L528-L553
51+
"""
52+
table = exp.to_table(table_name)
53+
54+
catalog_name = table.args.get("catalog") or "master"
55+
sql = (
56+
exp.select("1")
57+
.from_(f"{catalog_name}.information_schema.tables")
58+
.where(f"table_name = '{table.alias_or_name}'")
59+
)
60+
database_name = table.args.get("db")
61+
if database_name:
62+
sql = sql.where(f"table_schema = '{database_name}'")
63+
64+
self.execute(sql)
65+
66+
result = self.cursor.fetchone()
67+
68+
return result[0] == 1 if result else False
69+
70+
@property
71+
def connection(self) -> pymssql.Connection:
72+
return self.cursor.connection
73+
74+
@contextlib.contextmanager
75+
def __try_load_pandas_to_temp_table(
76+
self,
77+
reference_table_name: TableName,
78+
query_or_df: QueryOrDF,
79+
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
80+
) -> t.Generator[Query, None, None]:
81+
reference_table = exp.to_table(reference_table_name)
82+
df = self.try_get_pandas_df(query_or_df)
83+
if df is None:
84+
yield t.cast("Query", query_or_df)
85+
return
86+
if columns_to_types is None:
87+
raise SQLMeshError("columns_to_types must be provided when using Pandas DataFrames")
88+
if reference_table.db is None:
89+
raise SQLMeshError("table must be qualified when using Pandas DataFrames")
90+
with self.temp_table(query_or_df, reference_table) as temp_table:
91+
rows: t.List[t.Iterable[t.Any]] = list(df.itertuples(False, None))
92+
93+
conn = self._connection_pool.get()
94+
conn.bulk_copy(temp_table.name, rows)
95+
96+
yield exp.select(*columns_to_types).from_(temp_table)
97+
98+
def _insert_overwrite_by_condition(
99+
self,
100+
table_name: TableName,
101+
query_or_df: QueryOrDF,
102+
where: t.Optional[exp.Condition] = None,
103+
columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
104+
) -> None:
105+
"""
106+
SQL Server does not directly support `INSERT OVERWRITE` but it does
107+
support `MERGE` with a `False` condition and delete that mimics an
108+
`INSERT OVERWRITE`. Based on documentation this should have the same
109+
runtime performance as `INSERT OVERWRITE`.
110+
111+
If a Pandas DataFrame is provided, it will be loaded into a temporary
112+
table and then merged with the target table. This temporary table is
113+
deleted after the merge is complete or after it's expiration time has
114+
passed.
115+
"""
116+
with self.__try_load_pandas_to_temp_table(
117+
table_name,
118+
query_or_df,
119+
columns_to_types,
120+
) as source_table:
121+
query = self._add_where_to_query(source_table, where)
122+
123+
columns = [
124+
exp.to_column(col)
125+
for col in (columns_to_types or [col.alias_or_name for col in query.expressions])
126+
]
127+
when_not_matched_by_source = exp.When(
128+
matched=False,
129+
source=True,
130+
condition=where,
131+
then=exp.Delete(),
132+
)
133+
when_not_matched_by_target = exp.When(
134+
matched=False,
135+
source=False,
136+
then=exp.Insert(
137+
this=exp.Tuple(expressions=columns),
138+
expression=exp.Tuple(expressions=columns),
139+
),
140+
)
141+
self._merge(
142+
target_table=table_name,
143+
source_table=query,
144+
on=exp.condition("1=2"),
145+
match_expressions=[when_not_matched_by_source, when_not_matched_by_target],
146+
)
147+
148+
def _get_data_objects(
149+
self,
150+
schema_name: str,
151+
catalog_name: t.Optional[str] = None,
152+
) -> t.List[DataObject]:
153+
"""
154+
Returns all the data objects that exist in the given schema and catalog.
155+
"""
156+
catalog_name = f"[{catalog_name}]" if catalog_name else "master"
157+
query = f"""
158+
SELECT
159+
'{catalog_name}' AS catalog_name,
160+
TABLE_NAME AS name,
161+
TABLE_SCHEMA AS schema_name,
162+
'TABLE' AS type
163+
FROM {catalog_name}.INFORMATION_SCHEMA.TABLES
164+
WHERE TABLE_SCHEMA LIKE '%{schema_name}%'
165+
UNION ALL
166+
SELECT
167+
'{catalog_name}' AS catalog_name,
168+
TABLE_NAME AS name,
169+
TABLE_SCHEMA AS schema_name,
170+
'VIEW' AS type
171+
FROM {catalog_name}.INFORMATION_SCHEMA.VIEWS
172+
WHERE TABLE_SCHEMA LIKE '%{schema_name}%'
173+
"""
174+
dataframe: pd.DataFrame = self.fetchdf(query)
175+
return [
176+
DataObject(
177+
catalog=row.catalog_name, # type: ignore
178+
schema=row.schema_name, # type: ignore
179+
name=row.name, # type: ignore
180+
type=DataObjectType.from_str(row.type), # type: ignore
181+
)
182+
for row in dataframe.itertuples()
183+
]

0 commit comments

Comments
 (0)