Skip to content

Commit 60e53e5

Browse files
committed
add copy grants
1 parent 91cd584 commit 60e53e5

2 files changed

Lines changed: 76 additions & 7 deletions

File tree

snowflake_utils/models/table.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,12 @@ def get_create_temporary_external_stage(
9999
def get_create_table_statement(
100100
self,
101101
full_refresh: bool = False,
102+
copy_grants: bool = True,
102103
) -> str:
103104
logging.debug(f"Creating table: {self.fqn}")
105+
copy_grants_clause = " COPY GRANTS" if copy_grants and full_refresh else ""
104106
if self.table_structure:
105-
return f"{'CREATE OR REPLACE TABLE' if full_refresh else 'CREATE TABLE IF NOT EXISTS'} {self.fqn} ({self.table_structure.parsed_columns})"
107+
return f"{'CREATE OR REPLACE TABLE' if full_refresh else 'CREATE TABLE IF NOT EXISTS'} {self.fqn}{copy_grants_clause} ({self.table_structure.parsed_columns})"
106108
else:
107109
template = """ARRAY_AGG(
108110
OBJECT_CONSTRUCT(
@@ -129,7 +131,7 @@ def get_create_table_statement(
129131

130132
stage_query = f"LOCATION => '@{self.stage}'"
131133
return f"""
132-
{"CREATE OR REPLACE TABLE" if full_refresh else "CREATE TABLE IF NOT EXISTS"} {self.fqn}
134+
{"CREATE OR REPLACE TABLE" if full_refresh else "CREATE TABLE IF NOT EXISTS"} {self.fqn}{copy_grants_clause}
133135
USING TEMPLATE (
134136
SELECT {template}
135137
FROM TABLE(
@@ -151,7 +153,9 @@ def bulk_insert(
151153
cursor = connection.cursor()
152154
_execute_statement = partial(execute_statement, cursor)
153155
_execute_statement(self.get_create_schema_statement())
154-
_execute_statement(self.get_create_table_statement(full_refresh))
156+
_execute_statement(
157+
self.get_create_table_statement(full_refresh, copy_grants=True)
158+
)
155159
for k in records:
156160
cols = ", ".join([k for k in records[k].keys()])
157161
vals = ", ".join([_type_cast(v) for v in records[k].values()])
@@ -173,14 +177,15 @@ def _copy(
173177
sync_tags: bool = False,
174178
stage: str | None = None,
175179
create_table: bool = True,
180+
copy_grants: bool = True,
176181
) -> None:
177182
with connect() as connection:
178183
cursor = connection.cursor()
179184
execute = self.setup_connection(
180185
path, storage_integration, cursor, file_format, stage
181186
)
182187
if create_table:
183-
self.create_table(full_refresh, execute)
188+
self.create_table(full_refresh, execute, copy_grants)
184189

185190
if sync_tags and self.table_structure:
186191
self.sync_tags(cursor)
@@ -202,6 +207,7 @@ def copy_into(
202207
stage: str | None = None,
203208
files: list[str] | None = None,
204209
create_table: bool = True,
210+
copy_grants: bool = True,
205211
) -> None:
206212
col_str = f"({', '.join(target_columns)})" if target_columns else ""
207213
files_clause = ""
@@ -229,6 +235,7 @@ def copy_into(
229235
sync_tags,
230236
stage,
231237
create_table,
238+
copy_grants,
232239
)
233240
with connect() as connection:
234241
cursor = connection.cursor()
@@ -249,10 +256,13 @@ def copy_into(
249256
sync_tags,
250257
stage,
251258
create_table,
259+
copy_grants,
252260
)
253261

254-
def create_table(self, full_refresh: bool, execute_statement: callable) -> None:
255-
execute_statement(self.get_create_table_statement(full_refresh))
262+
def create_table(
263+
self, full_refresh: bool, execute_statement: callable, copy_grants: bool = True
264+
) -> None:
265+
execute_statement(self.get_create_table_statement(full_refresh, copy_grants))
256266

257267
def setup_file_format(
258268
self,
@@ -311,7 +321,9 @@ def _merge(
311321

312322
with connect() as connection:
313323
cursor = connection.cursor()
314-
cursor.execute(self.get_create_table_statement(full_refresh=False))
324+
cursor.execute(
325+
self.get_create_table_statement(full_refresh=False, copy_grants=True)
326+
)
315327
old_columns = {x.name: x.data_type for x in self.get_columns(cursor)}
316328
new_columns = temp_table.get_columns(cursor)
317329

@@ -338,6 +350,7 @@ def merge(
338350
match_by_column_name: MatchByColumnName = MatchByColumnName.CASE_INSENSITIVE,
339351
qualify: bool = False,
340352
files: list[str] | None = None,
353+
copy_grants: bool = True,
341354
) -> None:
342355
def copy_callable(table: Table, sync_tags: bool) -> None:
343356
return table.copy_into(
@@ -347,6 +360,7 @@ def copy_callable(table: Table, sync_tags: bool) -> None:
347360
match_by_column_name=match_by_column_name,
348361
sync_tags=sync_tags,
349362
files=files,
363+
copy_grants=copy_grants,
350364
)
351365

352366
return self._merge(copy_callable, primary_keys, replication_keys, qualify)
@@ -567,6 +581,7 @@ def copy_custom(
567581
stage: str | None = None,
568582
files: list[str] | None = None,
569583
create_table: bool = True,
584+
copy_grants: bool = True,
570585
) -> None:
571586
column_names = ", ".join(column_definitions.keys())
572587
definitions = ", ".join(column_definitions.values())
@@ -593,6 +608,7 @@ def copy_custom(
593608
sync_tags,
594609
stage,
595610
create_table,
611+
copy_grants,
596612
)
597613

598614
def merge_custom(
@@ -606,6 +622,7 @@ def merge_custom(
606622
qualify: bool = False,
607623
files: list[str] | None = None,
608624
create_table: bool = True,
625+
copy_grants: bool = True,
609626
) -> None:
610627
def copy_callable(table: Table, sync_tags: bool) -> None:
611628
return table.copy_custom(
@@ -617,6 +634,7 @@ def copy_callable(table: Table, sync_tags: bool) -> None:
617634
sync_tags=sync_tags,
618635
files=files,
619636
create_table=create_table,
637+
copy_grants=copy_grants,
620638
)
621639

622640
return self._merge(copy_callable, primary_keys, replication_keys, qualify)

tests/test_models.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,40 @@ def test_create_or_replace_table(mock_connect):
9797
assert result == f"Table {test_table.name} successfully created."
9898

9999

100+
@patch("snowflake_utils.settings.connect")
101+
def test_create_or_replace_table_with_copy_grants(mock_connect):
102+
"""Test that COPY GRANTS clause is included when copy_grants=True and full_refresh=True."""
103+
mock_cursor = make_mock_cursor(
104+
fetchall_return=[(f"Table {test_table.name} successfully created.",)]
105+
)
106+
mock_conn = make_mock_conn(cursor=mock_cursor)
107+
mock_connect.return_value = mock_conn
108+
statement = test_table.get_create_table_statement(
109+
full_refresh=True, copy_grants=True
110+
)
111+
result = mock_cursor.execute(statement).fetchall()[0][0]
112+
assert result == f"Table {test_table.name} successfully created."
113+
# Verify that COPY GRANTS is included in the statement
114+
assert "COPY GRANTS" in statement
115+
116+
117+
@patch("snowflake_utils.settings.connect")
118+
def test_create_or_replace_table_without_copy_grants(mock_connect):
119+
"""Test that COPY GRANTS clause is not included when copy_grants=False and full_refresh=True."""
120+
mock_cursor = make_mock_cursor(
121+
fetchall_return=[(f"Table {test_table.name} successfully created.",)]
122+
)
123+
mock_conn = make_mock_conn(cursor=mock_cursor)
124+
mock_connect.return_value = mock_conn
125+
statement = test_table.get_create_table_statement(
126+
full_refresh=True, copy_grants=False
127+
)
128+
result = mock_cursor.execute(statement).fetchall()[0][0]
129+
assert result == f"Table {test_table.name} successfully created."
130+
# Verify that COPY GRANTS is not included in the statement
131+
assert "COPY GRANTS" not in statement
132+
133+
100134
@patch("snowflake_utils.settings.connect")
101135
def test_create_table_if_not_exists(mock_connect):
102136
mock_cursor = make_mock_cursor(fetchall_return=[("statement succeeded: PYTEST",)])
@@ -109,6 +143,23 @@ def test_create_table_if_not_exists(mock_connect):
109143
)
110144

111145

146+
@patch("snowflake_utils.settings.connect")
147+
def test_create_table_if_not_exists_copy_grants_ignored(mock_connect):
148+
"""Test that COPY GRANTS clause is not included when full_refresh=False, even if copy_grants=True."""
149+
mock_cursor = make_mock_cursor(fetchall_return=[("statement succeeded: PYTEST",)])
150+
mock_conn = make_mock_conn(cursor=mock_cursor)
151+
mock_connect.return_value = mock_conn
152+
statement = test_table.get_create_table_statement(
153+
full_refresh=False, copy_grants=True
154+
)
155+
result = mock_cursor.execute(statement).fetchall()[0][0]
156+
assert ("statement succeeded" in result and test_table.name in result) or (
157+
f"Table {test_table.name} successfully created."
158+
)
159+
# Verify that COPY GRANTS is not included in the statement (only applies to CREATE OR REPLACE)
160+
assert "COPY GRANTS" not in statement
161+
162+
112163
@patch("snowflake_utils.settings.connect")
113164
def test_temporary_external_stage_creation(mock_connect):
114165
mock_cursor = make_mock_cursor(

0 commit comments

Comments
 (0)