Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions beanquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,21 +764,24 @@ def _insert(self, node: ast.Insert):
impl = getattr(table, 'insert', None)
if impl is None:
raise CompilationError(f'table "{node.table.name}" does not support insertion', node.table)
if len(node.values) != len(node.columns):
raise CompilationError(
f'column names and values mismatch: '
f'expected {len(node.columns)} but {len(node.values)} values were supplied', node)
values = [EvalConstant(None)] * len(table.columns)
columns = {name: i for i, name in enumerate(table.columns.keys())}
for column, value in zip(node.columns, node.values):
index = columns.get(column.name)
if index is None:
raise CompilationError(f'column "{column.name}" not found in table "{node.table.name}"', column)
expr = self._compile(value)
if not expr.dtype == table.columns.get(column.name).dtype:
raise CompilationError(f'expression has wrong type for column "{column.name}"', value)
values[index] = expr
return EvalInsert(table, values)
rows = []
for row in node.values:
if len(row) != len(node.columns):
raise CompilationError(
f'column names and values mismatch: '
f'expected {len(node.columns)} but {len(row)} values were supplied', node)
values = [EvalConstant(None)] * len(table.columns)
for column, value in zip(node.columns, row):
index = columns.get(column.name)
if index is None:
raise CompilationError(f'column "{column.name}" not found in table "{node.table.name}"', column)
expr = self._compile(value)
if not expr.dtype == table.columns.get(column.name).dtype:
raise CompilationError(f'expression has wrong type for column "{column.name}"', value)
values[index] = expr
rows.append(values)
return EvalInsert(table, rows)


def transform_journal(journal):
Expand Down
2 changes: 1 addition & 1 deletion beanquery/parser/bql.ebnf
Original file line number Diff line number Diff line change
Expand Up @@ -397,5 +397,5 @@ create_table::CreateTable
insert::Insert
= 'INSERT' 'INTO' ~ table:table
['(' columns:','.{column} ')']
'VALUES' '(' values:','.{expression} ')'
'VALUES' ','.{ '(' values+:','.{expression}+ ')' }
;
23 changes: 18 additions & 5 deletions beanquery/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,17 +1256,30 @@ def block1():
self._token(')')
self._define(['columns'], [])
self._token('VALUES')
self._token('(')

def sep2():
self._token(',')

def block3():
self._expression_()
self._token('(')

def sep4():
self._token(',')

def block5():
self._expression_()
self._positive_gather(block5, sep4)
self.add_last_node_to_name('values')
self._token(')')
self._define(
[],
['values'],
)
self._gather(block3, sep2)
self.name_last_node('values')
self._token(')')
self._define(['columns', 'table', 'values'], [])
self._define(
['columns', 'table'],
['values'],
)


def main(filename, **kwargs):
Expand Down
8 changes: 4 additions & 4 deletions beanquery/query_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import operator

from decimal import Decimal
from typing import List
from typing import List, Sequence

from dateutil.relativedelta import relativedelta

Expand Down Expand Up @@ -697,9 +697,9 @@ def __call__(self):
@dataclasses.dataclass
class EvalInsert:
table: tables.Table
values: list[EvalNode]
rows: Sequence[Sequence[EvalNode]]

def __call__(self):
values = tuple(value(None) for value in self.values)
self.table.insert(values)
for row in self.rows:
self.table.insert(tuple(value(None) for value in row))
return (), []
6 changes: 6 additions & 0 deletions beanquery/query_execute_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1818,6 +1818,12 @@ def test_insert_placeholders(self):
self.assertEqual(self.conn.tables['abcd'].data[0], values)
self.assertEqual(curs.fetchall(), [])

def test_insert_many(self):
curs = self.conn.execute('''INSERT INTO abcd (a) VALUES (1), (2), (3), (4)''')
values = [row[0] for row in self.conn.tables['abcd'].data]
self.assertEqual(values, [1, 2, 3, 4])
self.assertEqual(curs.fetchall(), [])


class TestCSVTable(unittest.TestCase):

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ ignore = [
'PLW2901',
'RUF012',
'RUF023', # unsorted-dunder-slots
'RUF059', # unused-unpacked-variable
'UP007',
'UP032',
]
Expand Down