Skip to content

Commit 2b5bb71

Browse files
refactor: simplify bulk_insert validation and reuse existing execution logic
- Replace sqlparse with simple string checks in _validate_bulk_insert_query
- Add extra_params parameter to _build_fb_numeric_query_params for extensibility
- Refactor _executemany_bulk_insert to preprocess and delegate to existing methods
- Use TimeoutController for consistent timeout handling
- Parametrize integration tests to avoid duplication

Addresses PR feedback from ptiurin on #463

Co-Authored-By: petro.tiurin@firebolt.io <petro.tiurin@firebolt.io>
1 parent 3e898e2 commit 2b5bb71

5 files changed

Lines changed: 96 additions & 134 deletions

File tree

src/firebolt/async_db/cursor.py

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import json
43
import logging
54
import time
65
import warnings
@@ -462,20 +461,16 @@ async def executemany(
462461

463462
def _validate_bulk_insert_query(self, query: str) -> None:
464463
"""Validate that query is an INSERT statement for bulk_insert."""
465-
statements = parse_sql(query)
466-
if not statements:
467-
raise ProgrammingError("bulk_insert requires a valid INSERT statement")
464+
query_normalized = query.lstrip().lower()
468465

469-
if len(statements) > 1:
466+
if not query_normalized.startswith("insert"):
470467
raise ProgrammingError(
471-
"bulk_insert does not support multi-statement queries"
468+
"bulk_insert is only supported for INSERT statements"
472469
)
473470

474-
statement_type = statements[0].get_type()
475-
if statement_type != "INSERT":
471+
if ";" in query.strip().rstrip(";"):
476472
raise ProgrammingError(
477-
f"bulk_insert is only supported for INSERT statements. "
478-
f"Got {statement_type} statement"
473+
"bulk_insert does not support multi-statement queries"
479474
)
480475

481476
async def _executemany_bulk_insert(
@@ -497,37 +492,32 @@ async def _executemany_bulk_insert(
497492
except ValueError:
498493
raise ProgrammingError(f"Unsupported paramstyle: {paramstyle}")
499494

495+
concatenated_query = "; ".join([query] * len(parameters_seq))
496+
500497
await self._close_rowset_and_reset()
501498
self._row_set = InMemoryAsyncRowSet()
502499

503500
try:
504501
if parameter_style == ParameterStyle.FB_NUMERIC:
505-
concatenated_query = "; ".join([query] * len(parameters_seq))
506-
507502
flattened_params: List[ParameterType] = []
508503
for param_set in parameters_seq:
509504
flattened_params.extend(param_set)
510505

511-
query_parameters = [
512-
{
513-
"name": f"${i+1}",
514-
"value": self._formatter.convert_parameter_for_serialization(
515-
value
516-
),
517-
}
518-
for i, value in enumerate(flattened_params)
519-
]
520-
521-
query_params: Dict[str, Any] = {
522-
"output_format": self._get_output_format(False),
523-
"merge_prepared_statement_batches": "true",
524-
}
525-
if query_parameters:
526-
query_params["query_parameters"] = json.dumps(query_parameters)
527-
528506
Cursor._log_query(concatenated_query)
507+
timeout_controller = TimeoutController(timeout_seconds)
508+
timeout_controller.raise_if_timeout()
509+
510+
query_params = self._build_fb_numeric_query_params(
511+
[flattened_params],
512+
streaming=False,
513+
async_execution=False,
514+
extra_params={"merge_prepared_statement_batches": "true"},
515+
)
516+
529517
resp = await self._api_request(
530-
concatenated_query, query_params, timeout=timeout_seconds
518+
concatenated_query,
519+
query_params,
520+
timeout=timeout_controller.remaining(),
531521
)
532522
await self._raise_if_error(resp)
533523
await self._parse_response_headers(resp.headers)

src/firebolt/common/cursor/base_cursor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def _build_fb_numeric_query_params(
243243
parameters: Sequence[Sequence[ParameterType]],
244244
streaming: bool,
245245
async_execution: bool,
246+
extra_params: Optional[Dict[str, Any]] = None,
246247
) -> Dict[str, Any]:
247248
"""
248249
Build query parameters dictionary for fb_numeric paramstyle.
@@ -252,6 +253,7 @@ def _build_fb_numeric_query_params(
252253
only the first parameter sequence is used.
253254
streaming: Whether streaming is enabled
254255
async_execution: Whether async execution is enabled
256+
extra_params: Optional additional query parameters to include
255257
256258
Returns:
257259
Dictionary of query parameters to send with the request
@@ -272,6 +274,8 @@ def _build_fb_numeric_query_params(
272274
query_params["query_parameters"] = json.dumps(query_parameters)
273275
if async_execution:
274276
query_params["async"] = True
277+
if extra_params:
278+
query_params.update(extra_params)
275279
return query_params
276280

277281
@property

src/firebolt/db/cursor.py

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import json
43
import logging
54
import time
65
from abc import ABCMeta, abstractmethod
@@ -464,20 +463,16 @@ def executemany(
464463

465464
def _validate_bulk_insert_query(self, query: str) -> None:
466465
"""Validate that query is an INSERT statement for bulk_insert."""
467-
statements = parse_sql(query)
468-
if not statements:
469-
raise ProgrammingError("bulk_insert requires a valid INSERT statement")
466+
query_normalized = query.lstrip().lower()
470467

471-
if len(statements) > 1:
468+
if not query_normalized.startswith("insert"):
472469
raise ProgrammingError(
473-
"bulk_insert does not support multi-statement queries"
470+
"bulk_insert is only supported for INSERT statements"
474471
)
475472

476-
statement_type = statements[0].get_type()
477-
if statement_type != "INSERT":
473+
if ";" in query.strip().rstrip(";"):
478474
raise ProgrammingError(
479-
f"bulk_insert is only supported for INSERT statements. "
480-
f"Got {statement_type} statement"
475+
"bulk_insert does not support multi-statement queries"
481476
)
482477

483478
def _executemany_bulk_insert(
@@ -499,37 +494,32 @@ def _executemany_bulk_insert(
499494
except ValueError:
500495
raise ProgrammingError(f"Unsupported paramstyle: {paramstyle}")
501496

497+
concatenated_query = "; ".join([query] * len(parameters_seq))
498+
502499
self._close_rowset_and_reset()
503500
self._row_set = InMemoryRowSet()
504501

505502
try:
506503
if parameter_style == ParameterStyle.FB_NUMERIC:
507-
concatenated_query = "; ".join([query] * len(parameters_seq))
508-
509504
flattened_params: List[ParameterType] = []
510505
for param_set in parameters_seq:
511506
flattened_params.extend(param_set)
512507

513-
query_parameters = [
514-
{
515-
"name": f"${i+1}",
516-
"value": self._formatter.convert_parameter_for_serialization(
517-
value
518-
),
519-
}
520-
for i, value in enumerate(flattened_params)
521-
]
522-
523-
query_params: Dict[str, Any] = {
524-
"output_format": self._get_output_format(False),
525-
"merge_prepared_statement_batches": "true",
526-
}
527-
if query_parameters:
528-
query_params["query_parameters"] = json.dumps(query_parameters)
529-
530508
Cursor._log_query(concatenated_query)
509+
timeout_controller = TimeoutController(timeout_seconds)
510+
timeout_controller.raise_if_timeout()
511+
512+
query_params = self._build_fb_numeric_query_params(
513+
[flattened_params],
514+
streaming=False,
515+
async_execution=False,
516+
extra_params={"merge_prepared_statement_batches": "true"},
517+
)
518+
531519
resp = self._api_request(
532-
concatenated_query, query_params, timeout=timeout_seconds
520+
concatenated_query,
521+
query_params,
522+
timeout=timeout_controller.remaining(),
533523
)
534524
self._raise_if_error(resp)
535525
self._parse_response_headers(resp.headers)

tests/integration/dbapi/async/V2/test_queries_async.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -282,56 +282,45 @@ async def test_parameterized_query_with_special_chars(connection: Connection) ->
282282
], "Invalid data in table after parameterized insert"
283283

284284

285-
async def test_executemany_bulk_insert(connection: Connection) -> None:
285+
@mark.parametrize("paramstyle", ["qmark", "fb_numeric"])
286+
async def test_executemany_bulk_insert(connection: Connection, paramstyle: str) -> None:
286287
"""executemany with bulk_insert=True inserts data correctly."""
287-
async with connection.cursor() as c:
288-
await c.execute('DROP TABLE IF EXISTS "test_bulk_insert_async"')
289-
await c.execute(
290-
'CREATE FACT TABLE "test_bulk_insert_async"(id int, name string) primary index id'
291-
)
288+
import firebolt.async_db as db_module
292289

293-
import firebolt.async_db as db_module
290+
original_paramstyle = db_module.paramstyle
294291

295-
original_paramstyle = db_module.paramstyle
296-
db_module.paramstyle = "qmark"
292+
try:
293+
db_module.paramstyle = paramstyle
297294

298-
try:
299-
await c.executemany(
300-
'INSERT INTO "test_bulk_insert_async" VALUES (?, ?)',
301-
[(1, "alice"), (2, "bob"), (3, "charlie")],
302-
bulk_insert=True,
295+
async with connection.cursor() as c:
296+
await c.execute('DROP TABLE IF EXISTS "test_bulk_insert_async"')
297+
await c.execute(
298+
'CREATE FACT TABLE "test_bulk_insert_async"(id int, name string) primary index id'
303299
)
304300

301+
if paramstyle == "qmark":
302+
await c.executemany(
303+
'INSERT INTO "test_bulk_insert_async" VALUES (?, ?)',
304+
[(1, "alice"), (2, "bob"), (3, "charlie")],
305+
bulk_insert=True,
306+
)
307+
else:
308+
await c.executemany(
309+
'INSERT INTO "test_bulk_insert_async" VALUES ($1, $2)',
310+
[(1, "alice"), (2, "bob"), (3, "charlie")],
311+
bulk_insert=True,
312+
)
313+
305314
await c.execute('SELECT * FROM "test_bulk_insert_async" ORDER BY id')
306315
data = await c.fetchall()
307316
assert len(data) == 3
308317
assert data[0] == [1, "alice"]
309318
assert data[1] == [2, "bob"]
310319
assert data[2] == [3, "charlie"]
311-
finally:
312-
db_module.paramstyle = original_paramstyle
313-
314-
await c.execute('DELETE FROM "test_bulk_insert_async"')
315-
316-
db_module.paramstyle = "fb_numeric"
317-
318-
try:
319-
await c.executemany(
320-
'INSERT INTO "test_bulk_insert_async" VALUES ($1, $2)',
321-
[(4, "david"), (5, "eve"), (6, "frank")],
322-
bulk_insert=True,
323-
)
324-
325-
await c.execute('SELECT * FROM "test_bulk_insert_async" ORDER BY id')
326-
data = await c.fetchall()
327-
assert len(data) == 3
328-
assert data[0] == [4, "david"]
329-
assert data[1] == [5, "eve"]
330-
assert data[2] == [6, "frank"]
331-
finally:
332-
db_module.paramstyle = original_paramstyle
333320

334-
await c.execute('DROP TABLE "test_bulk_insert_async"')
321+
await c.execute('DROP TABLE "test_bulk_insert_async"')
322+
finally:
323+
db_module.paramstyle = original_paramstyle
335324

336325

337326
async def test_multi_statement_query(connection: Connection) -> None:

tests/integration/dbapi/sync/V2/test_queries.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -283,56 +283,45 @@ def test_empty_query(c: Cursor, query: str, params: tuple) -> None:
283283
)
284284

285285

286-
def test_executemany_bulk_insert(connection: Connection) -> None:
286+
@mark.parametrize("paramstyle", ["qmark", "fb_numeric"])
287+
def test_executemany_bulk_insert(connection: Connection, paramstyle: str) -> None:
287288
"""executemany with bulk_insert=True inserts data correctly."""
288-
with connection.cursor() as c:
289-
c.execute('DROP TABLE IF EXISTS "test_bulk_insert"')
290-
c.execute(
291-
'CREATE FACT TABLE "test_bulk_insert"(id int, name string) primary index id'
292-
)
289+
import firebolt.db as db_module
293290

294-
import firebolt.db as db_module
291+
original_paramstyle = db_module.paramstyle
295292

296-
original_paramstyle = db_module.paramstyle
297-
db_module.paramstyle = "qmark"
293+
try:
294+
db_module.paramstyle = paramstyle
298295

299-
try:
300-
c.executemany(
301-
'INSERT INTO "test_bulk_insert" VALUES (?, ?)',
302-
[(1, "alice"), (2, "bob"), (3, "charlie")],
303-
bulk_insert=True,
296+
with connection.cursor() as c:
297+
c.execute('DROP TABLE IF EXISTS "test_bulk_insert"')
298+
c.execute(
299+
'CREATE FACT TABLE "test_bulk_insert"(id int, name string) primary index id'
304300
)
305301

302+
if paramstyle == "qmark":
303+
c.executemany(
304+
'INSERT INTO "test_bulk_insert" VALUES (?, ?)',
305+
[(1, "alice"), (2, "bob"), (3, "charlie")],
306+
bulk_insert=True,
307+
)
308+
else:
309+
c.executemany(
310+
'INSERT INTO "test_bulk_insert" VALUES ($1, $2)',
311+
[(1, "alice"), (2, "bob"), (3, "charlie")],
312+
bulk_insert=True,
313+
)
314+
306315
c.execute('SELECT * FROM "test_bulk_insert" ORDER BY id')
307316
data = c.fetchall()
308317
assert len(data) == 3
309318
assert data[0] == [1, "alice"]
310319
assert data[1] == [2, "bob"]
311320
assert data[2] == [3, "charlie"]
312-
finally:
313-
db_module.paramstyle = original_paramstyle
314-
315-
c.execute('DELETE FROM "test_bulk_insert"')
316-
317-
db_module.paramstyle = "fb_numeric"
318-
319-
try:
320-
c.executemany(
321-
'INSERT INTO "test_bulk_insert" VALUES ($1, $2)',
322-
[(4, "david"), (5, "eve"), (6, "frank")],
323-
bulk_insert=True,
324-
)
325-
326-
c.execute('SELECT * FROM "test_bulk_insert" ORDER BY id')
327-
data = c.fetchall()
328-
assert len(data) == 3
329-
assert data[0] == [4, "david"]
330-
assert data[1] == [5, "eve"]
331-
assert data[2] == [6, "frank"]
332-
finally:
333-
db_module.paramstyle = original_paramstyle
334321

335-
c.execute('DROP TABLE "test_bulk_insert"')
322+
c.execute('DROP TABLE "test_bulk_insert"')
323+
finally:
324+
db_module.paramstyle = original_paramstyle
336325

337326

338327
def test_multi_statement_query(connection: Connection) -> None:

0 commit comments

Comments
 (0)