Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions datafaker/interactive/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class DbCmd(ABC, cmd.Cmd):
"Error: '{0}' is not the name of a table"
" in this database or a column in this table"
)
ERROR_FAILED_SQL = 'SQL query "{query}" caused exception {exc}'
ERROR_FAILED_DISPLAY = "Error: Failed to display: {}"
ROW_COUNT_MSG = "Total row count: {}"

@abstractmethod
Expand Down Expand Up @@ -325,7 +327,7 @@ def do_counts(self, _arg: str) -> None:
return
table_name = self.table_name()
nonnull_columns = self.get_nonnull_columns(table_name)
colcounts = [f", COUNT({nnc}) AS {nnc}" for nnc in nonnull_columns]
colcounts = [f', COUNT("{nnc}") AS "{nnc}"' for nnc in nonnull_columns]
with self.sync_engine.connect() as connection:
result = (
connection.execute(
Expand Down Expand Up @@ -353,19 +355,24 @@ def do_counts(self, _arg: str) -> None:
def do_select(self, arg: str) -> None:
    """
    Run a select query over the database and show the first 50 results.

    :param arg: The text that follows ``SELECT`` in the query; it is
        appended to ``"SELECT "`` unchanged.

    A ``sqlalchemy.exc.DatabaseError`` raised while executing the query and
    a ``ValueError`` raised while displaying the result are reported via
    ``self.print`` rather than propagated.
    """
    max_select_rows = 50
    query = "SELECT " + arg
    with self.sync_engine.connect() as connection:
        try:
            # Execute exactly once; the query text is kept in a variable so
            # that the error message can quote it.
            result = connection.execute(sqlalchemy.text(query))
        except sqlalchemy.exc.DatabaseError as exc:
            # ERROR_FAILED_SQL uses the named fields {query} and {exc},
            # which positional arguments cannot fill.
            # NOTE(review): assumes self.print forwards keyword arguments
            # to str.format -- confirm against DbCmd.print.
            self.print(self.ERROR_FAILED_SQL, query=query, exc=exc)
            return
        row_count = result.rowcount
        self.print(self.ROW_COUNT_MSG, row_count)
        if max_select_rows < row_count:
            self.print("Showing the first {} rows", max_select_rows)
        fields = list(result.keys())
        rows = result.fetchmany(max_select_rows)
        try:
            self.print_table(fields, rows)
        except ValueError as exc:
            # Report the display failure instead of letting it escape.
            self.print(self.ERROR_FAILED_DISPLAY, exc)

def do_peek(self, arg: str) -> None:
"""
Expand All @@ -383,9 +390,9 @@ def do_peek(self, arg: str) -> None:
col_names = arg.split()
if not col_names:
col_names = self._get_column_names()
nonnulls = [cn + " IS NOT NULL" for cn in col_names]
nonnulls = [f'"{cn}" IS NOT NULL' for cn in col_names]
with self.sync_engine.connect() as connection:
cols = ",".join(col_names)
cols = ", ".join(f'"{cn}"' for cn in col_names)
where = "WHERE" if nonnulls else ""
nonnull = " OR ".join(nonnulls)
query = sqlalchemy.text(
Expand All @@ -395,7 +402,7 @@ def do_peek(self, arg: str) -> None:
try:
result = connection.execute(query)
except sqlalchemy.exc.SQLAlchemyError as exc:
self.print(f'SQL query "{query}" caused exception {exc}')
self.print(self.ERROR_FAILED_SQL, exc, query)
return
self.print_table(list(result.keys()), result.fetchmany(max_peek_rows))

Expand Down
4 changes: 3 additions & 1 deletion datafaker/interactive/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,11 @@ def __init__(
src_schema: str | None,
metadata: MetaData,
config: MutableMapping[str, Any],
*args: Any,
**kwargs: Any,
) -> None:
"""Initialise a TableCmd."""
super().__init__(src_dsn, src_schema, metadata, config)
super().__init__(src_dsn, src_schema, metadata, config, *args, **kwargs)
self.set_prompt()

@property
Expand Down
19 changes: 19 additions & 0 deletions tests/examples/tricky.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
-- Test fixture: a small database whose "names" table has columns called
-- "offset" and "count" (names that clash with SQL keywords/functions and
-- therefore must be quoted) next to ordinary column names.
-- The \connect line below is a psql meta-command, so run this with psql.

-- DROP DATABASE IF EXISTS tricky WITH (FORCE);
CREATE DATABASE tricky WITH TEMPLATE template0 ENCODING = 'UTF8' LOCALE = 'en_US.utf8';
ALTER DATABASE tricky OWNER TO postgres;

\connect tricky

CREATE TABLE public.names (
    id INTEGER NOT NULL,
    "offset" INTEGER,
    "count" INTEGER NOT NULL,
    sensible TEXT
);

ALTER TABLE ONLY public.names ADD CONSTRAINT names_pkey PRIMARY KEY (id);

ALTER TABLE public.names OWNER TO postgres;

-- Column lists are spelled out so the rows do not depend on column order.
-- The second row leaves "offset" NULL.
INSERT INTO public.names (id, "offset", "count", sensible) VALUES (1, 10, 5, 'reasonable');
INSERT INTO public.names (id, "offset", "count", sensible) VALUES (2, NULL, 6, 'clear-headed');
96 changes: 96 additions & 0 deletions tests/test_interactive_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sqlalchemy import select

from datafaker.interactive import TableCmd
from datafaker.serialize_metadata import dict_to_metadata
from tests.utils import RequiresDBTestCase, TestDbCmdMixin


Expand Down Expand Up @@ -396,3 +397,98 @@ def test_sanity_checks_errors_only(self) -> None:
{},
),
)


class TrickyTests(ConfigureTablesTests):
    """Testing configure-tables with the tricky.sql database."""

    dump_file_path = "tricky.sql"
    database_name = "tricky"
    schema_name = "public"

    @staticmethod
    def _message_formats(tc: TestTableCmd) -> str:
        """
        Join the first element of every captured message with "/".

        Each entry of ``tc.messages`` is unpacked as a 3-tuple; only the
        first element (the message text) is used, so assertions can search
        all captured messages at once.
        """
        return "/".join(m for (m, _a, _kw) in tc.messages)

    def do_and_test_peek_tricky(self, tc: TestTableCmd) -> None:
        """Peek the "names" table and check the output."""
        tc.reset()
        tc.do_peek("")
        self.assertSetEqual(set(tc.headings), {"id", "offset", "count", "sensible"})
        self.assertSetEqual(
            set(tc.rows), {(1, 10, 5, "reasonable"), (2, None, 6, "clear-headed")}
        )

    def test_peek_with_tricky_names(self) -> None:
        """
        Peek with column names that are function names (#66).
        """
        with self._get_cmd({}) as tc:
            tc.do_next("names")
            self.do_and_test_peek_tricky(tc)

    def test_count_with_tricky_names(self) -> None:
        """
        Count with column names that are function names (#66).
        """
        with self._get_cmd({}) as tc:
            tc.do_next("names")
            self.do_and_test_peek_tricky(tc)
            tc.do_counts("")
            self.assertSequenceEqual(tc.rows, [["offset", 1], ["sensible", 0]])

    def test_incorrect_orm_yaml_columns(self) -> None:
        """
        Peek with incorrect columns in orm.yaml (#70).
        """
        # This metadata names a column ("nonexistent") that tricky.sql does
        # not create; the test expects do_peek to report an SQL error
        # message for it.
        self.metadata = dict_to_metadata(
            {
                "tables": {
                    "names": {
                        "columns": {
                            "id": {
                                "primary": True,
                                "nullable": False,
                                "type": "INTEGER",
                            },
                            "sensible": {
                                "primary": False,
                                "nullable": False,
                                "type": "TEXT",
                            },
                            "nonexistent": {
                                "primary": False,
                                "nullable": False,
                                "type": "TEXT",
                            },
                        }
                    }
                }
            }
        )
        with self._get_cmd({}) as tc:
            tc.reset()
            tc.do_peek("")
            self.assertIn("SQL query", self._message_formats(tc))

    def test_repeated_field_does_not_throw_exception(self) -> None:
        """
        Select with repeated fields (#70).
        """
        # print_tables=True is passed so that the parent's print_table is
        # also invoked; presumably that is where the repeated heading fails.
        with TestTableCmd(
            src_dsn=self.dsn,
            src_schema=self.schema_name,
            metadata=self.metadata,
            config={},
            print_tables=True,
        ) as tc:
            tc.reset()
            tc.do_select('sensible AS same, "offset" AS same FROM names')
            self.assertIn("Failed to display", self._message_formats(tc))

    def test_sql_error_does_not_throw_exception(self) -> None:
        """
        Select with a SQL error.
        """
        with self._get_cmd({}) as tc:
            tc.reset()
            tc.do_select("+++")
            self.assertIn("SQL query", self._message_formats(tc))
5 changes: 4 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,10 @@ def generate_data(
class TestDbCmdMixin(DbCmd):
"""A mixin for capturing output from interactive commands."""

def __init__(self, *args: Any, print_tables: bool = False, **kwargs: Any) -> None:
    """
    Initialize a TestDbCmdMixin.

    :param args: Positional arguments passed on to the parent initializer.
    :param print_tables: Stored as ``self._print_tables``; when true,
        ``print_table`` calls the parent implementation after capturing
        the headings and rows.
    :param kwargs: Keyword arguments passed on to the parent initializer.
    """
    super().__init__(*args, **kwargs)
    self._print_tables = print_tables
    self.reset()

def reset(self) -> None:
Expand All @@ -316,6 +317,8 @@ def print_table(
"""Capture the printed table."""
self.headings = headings
self.rows = rows
if self._print_tables:
super().print_table(headings, rows)

def print_table_by_columns(self, columns: Mapping[str, Sequence[str]]) -> None:
"""Capture the printed table."""
Expand Down