Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 33acff9

Browse files
author
Matthias Ekundayo
committed
enhanced support for snowflake ssh key and presto cert
1 parent e357473 commit 33acff9

File tree

4 files changed

+73
-10
lines changed

4 files changed

+73
-10
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,6 @@ benchmark_*.png
141141

142142
# IntelliJ
143143
.idea
144+
145+
# VS Code
146+
.vscode

data_diff/databases/presto.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,28 @@ class Presto(Database):
3535
}
3636
ROUNDS_ON_PREC_LOSS = True
3737

38-
def __init__(self, **kw):
38+
def __init__(self, host, port, user, password, *, catalog, schema=None, **kw):
3939
prestodb = import_presto()
40+
self.args = dict(
41+
host=host, port=port, user=user, catalog=catalog, schema=schema, **kw
42+
) # include port if specified
43+
44+
if (
45+
"cert" in self.args
46+
): # cert is used after connecting to verify the session; it is not a valid connect() keyword, so remove it from the connection params
47+
self.args.pop("cert")
48+
49+
if "auth" in kw and kw.get("auth") == "basic": # if auth=basic, add basic authenticator for Presto
50+
self.args["auth"] = prestodb.auth.BasicAuthentication(user, password)
51+
52+
if schema: # if schema was specified in URI, override default
53+
self.default_schema = schema
54+
self._conn = prestodb.dbapi.connect(**self.args)
55+
# self._conn = prestodb.dbapi.connect(**kw)
56+
57+
if "cert" in kw: # if a certificate was specified in URI, verify session with cert
58+
self._conn._http_session.verify = kw.get("cert")
4059

41-
self._conn = prestodb.dbapi.connect(**kw)
4260

4361
def quote(self, s: str):
4462
return f'"{s}"'

data_diff/databases/snowflake.py

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
@import_helper("snowflake")
88
def import_snowflake():
99
import snowflake.connector
10-
11-
return snowflake
10+
from cryptography.hazmat.primitives import serialization
11+
from cryptography.hazmat.backends import default_backend
12+
13+
return snowflake, serialization, default_backend
1214

1315

1416
class Snowflake(Database):
@@ -25,8 +27,19 @@ class Snowflake(Database):
2527
}
2628
ROUNDS_ON_PREC_LOSS = False
2729

28-
def __init__(self, *, schema: str, **kw):
29-
snowflake = import_snowflake()
30+
def __init__(self,
31+
account: str,
32+
user: str,
33+
*,
34+
warehouse: str,
35+
schema: str,
36+
database: str,
37+
role: str = None,
38+
_port: int = None,
39+
password: str = None, # default to None in case ssh key is used
40+
**kw,
41+
):
42+
snowflake, serialization, default_backend = import_snowflake()
3043
logging.getLogger("snowflake.connector").setLevel(logging.WARNING)
3144

3245
# Got an error: snowflake.connector.network.RetryRequest: could not find io module state (interpreter shutdown?)
@@ -35,10 +48,38 @@ def __init__(self, *, schema: str, **kw):
3548
logging.getLogger("snowflake.connector.network").disabled = True
3649

3750
assert '"' not in schema, "Schema name should not contain quotes!"
38-
self._conn = snowflake.connector.connect(
39-
schema=f'"{schema}"',
40-
**kw,
41-
)
51+
if (
52+
not password and "key" in kw
53+
): # if private keys are used instead of password for Snowflake connection, read in key from path specified and pass as "private_key" to connector.
54+
with open(kw.get("key"), "rb") as key:
55+
p_key = serialization.load_pem_private_key(key.read(), password=None, backend=default_backend())
56+
57+
pkb = p_key.private_bytes(
58+
encoding=serialization.Encoding.DER,
59+
format=serialization.PrivateFormat.PKCS8,
60+
encryption_algorithm=serialization.NoEncryption(),
61+
)
62+
self._conn = snowflake.connector.connect(
63+
user=user,
64+
private_key=pkb, # replaces password
65+
account=account,
66+
role=role,
67+
database=database,
68+
warehouse=warehouse,
69+
schema=f'"{schema}"',
70+
**kw,
71+
)
72+
else: # otherwise use password for connection
73+
self._conn = snowflake.connector.connect(
74+
user=user,
75+
password=password,
76+
account=account,
77+
role=role,
78+
database=database,
79+
warehouse=warehouse,
80+
schema=f'"{schema}"',
81+
**kw,
82+
)
4283

4384
self.default_schema = schema
4485

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ sphinx_markdown_tables
44
sphinx-copybutton
55
sphinx-rtd-theme
66
recommonmark
7+
cryptography
78

89
# Requirements. TODO Use poetry instead of this redundant list
910
data_diff

0 commit comments

Comments
 (0)