Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
Version history
===============

**Unreleased**

- Fixed rendering of inherited keyword arguments for dialect-specific types that use
``**kwargs`` in their initializers (such as MySQL ``CHAR`` with ``collation``) while
preserving existing ``*args`` rendering behavior (PR by @hyoj0942)

**4.0.1**

- Fix enum column definitions to explicitly include schema and name if reflected
Expand Down
102 changes: 84 additions & 18 deletions src/sqlacodegen/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,77 @@ def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> s
else:
return render_callable("mapped_column", *args, kwargs=kwargs)

def _render_column_type_value(self, value: Any) -> str:
if isinstance(value, (JSONB, JSON)):
# Remove astext_type if it's the default
if isinstance(value.astext_type, Text) and value.astext_type.length is None:
value.astext_type = None # type: ignore[assignment]
else:
self.add_import(Text)

if isinstance(value, TextClause):
self.add_literal_import("sqlalchemy", "text")
return render_callable("text", repr(value.text))

return repr(value)

def _collect_inherited_init_kwargs(
self,
column_type: Any,
init_sig: inspect.Signature,
seen_param_names: set[str],
missing: object,
) -> dict[str, str]:
has_var_keyword = any(
param.kind is Parameter.VAR_KEYWORD
for param in init_sig.parameters.values()
)
has_var_positional = any(
param.kind is Parameter.VAR_POSITIONAL
for param in init_sig.parameters.values()
)
if not has_var_keyword or has_var_positional:
return {}

inherited_kwargs: dict[str, str] = {}
for supercls in column_type.__class__.__mro__[1:]:
if supercls is object:
break

try:
super_sig = inspect.signature(supercls.__init__)
except (TypeError, ValueError):
continue

for super_param in list(super_sig.parameters.values())[1:]:
if super_param.name.startswith("_"):
continue

if super_param.kind in (
Parameter.POSITIONAL_ONLY,
Parameter.VAR_POSITIONAL,
Parameter.VAR_KEYWORD,
):
continue

if super_param.name in seen_param_names:
continue

seen_param_names.add(super_param.name)
value = getattr(column_type, super_param.name, missing)
if value is missing:
continue

default = super_param.default
if default is not Parameter.empty and value == default:
continue

inherited_kwargs[super_param.name] = self._render_column_type_value(
value
)

return inherited_kwargs

def render_column_type(self, column: Column[Any]) -> str:
column_type = column.type
# Check if this is an enum column with a Python enum class
Expand Down Expand Up @@ -586,6 +657,8 @@ def render_column_type(self, column: Column[Any]) -> str:
defaults = {param.name: param.default for param in sig.parameters.values()}
Comment thread
sheinbergon marked this conversation as resolved.
missing = object()
use_kwargs = False
seen_param_names: set[str] = set()

for param in list(sig.parameters.values())[1:]:
# Remove annoyances like _warn_on_bytestring
if param.name.startswith("_"):
Expand All @@ -594,32 +667,25 @@ def render_column_type(self, column: Column[Any]) -> str:
use_kwargs = True
continue

seen_param_names.add(param.name)
value = getattr(column_type, param.name, missing)

if isinstance(value, (JSONB, JSON)):
# Remove astext_type if it's the default
if (
isinstance(value.astext_type, Text)
and value.astext_type.length is None
):
value.astext_type = None # type: ignore[assignment]
else:
self.add_import(Text)

default = defaults.get(param.name, missing)
if isinstance(value, TextClause):
self.add_literal_import("sqlalchemy", "text")
rendered_value = render_callable("text", repr(value.text))
else:
rendered_value = repr(value)

if value is missing or value == default:
use_kwargs = True
elif use_kwargs:
continue

rendered_value = self._render_column_type_value(value)
if use_kwargs:
kwargs[param.name] = rendered_value
else:
args.append(rendered_value)

kwargs.update(
self._collect_inherited_init_kwargs(
column_type, sig, seen_param_names, missing
)
)

vararg = next(
(
param.name
Expand Down
34 changes: 34 additions & 0 deletions tests/test_generator_declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,40 @@ class Num2(Base):
)


@pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"])
@pytest.mark.parametrize("generator", [["keep_dialect_types"]], indirect=True)
def test_keep_dialect_types_keeps_mysql_char_collation(
generator: CodeGenerator,
) -> None:
from sqlalchemy.dialects.mysql import CHAR as MYSQL_CHAR
from sqlalchemy.dialects.mysql import INTEGER as MYSQL_INTEGER

Table(
"result_logs",
generator.metadata,
Column("id", MYSQL_INTEGER, primary_key=True),
Column("result_code", MYSQL_CHAR(1, collation="utf8mb3_bin"), nullable=False),
)

validate_code(
generator.generate(),
"""\
from sqlalchemy.dialects.mysql import CHAR, INTEGER
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
pass


class ResultLogs(Base):
__tablename__ = 'result_logs'

id: Mapped[int] = mapped_column(INTEGER, primary_key=True)
result_code: Mapped[str] = mapped_column(CHAR(1, collation='utf8mb3_bin'), nullable=False)
""",
)


def test_onetomany(generator: CodeGenerator) -> None:
Table(
"simple_items",
Expand Down
28 changes: 28 additions & 0 deletions tests/test_generator_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,34 @@ def test_mysql_column_types(generator: CodeGenerator) -> None:
)


@pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"])
@pytest.mark.parametrize("generator", [["keep_dialect_types"]], indirect=True)
def test_mysql_char_collation_keep_dialect_types(generator: CodeGenerator) -> None:
Table(
"simple_items",
generator.metadata,
Column("id", mysql.INTEGER, primary_key=True),
Column("result_code", mysql.CHAR(1, collation="utf8mb3_bin"), nullable=False),
)

validate_code(
generator.generate(),
"""\
from sqlalchemy import Column, MetaData, Table
from sqlalchemy.dialects.mysql import CHAR, INTEGER

metadata = MetaData()


t_simple_items = Table(
'simple_items', metadata,
Column('id', INTEGER, primary_key=True),
Column('result_code', CHAR(1, collation='utf8mb3_bin'), nullable=False)
)
""",
)


def test_constraints(generator: CodeGenerator) -> None:
Table(
"simple_items",
Expand Down