Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Version history
===============

**4.0.1**

- Fix enum column definitions to explicitly include schema and name if reflected
via SQLAlchemy's Metadata (pr by @sheinbergon)

**4.0.0**

- **BACKWARD INCOMPATIBLE** API changes (for those who customize code generation by
Expand Down
18 changes: 16 additions & 2 deletions src/sqlacodegen/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,14 @@ def render_column_type(self, column: Column[Any]) -> str:
):
# Import SQLAlchemy Enum (will be handled in collect_imports)
self.add_import(Enum)
return f"Enum({enum_class_name}, values_callable=lambda cls: [member.value for member in cls])"
extra_kwargs = ""
if column_type.name is not None:
extra_kwargs += f", name={column_type.name!r}"

if column_type.schema is not None:
extra_kwargs += f", schema={column_type.schema!r}"

return f"Enum({enum_class_name}, values_callable=lambda cls: [member.value for member in cls]{extra_kwargs})"
Comment thread
sheinbergon marked this conversation as resolved.

args = []
kwargs: dict[str, Any] = {}
Expand All @@ -562,7 +569,14 @@ def render_column_type(self, column: Column[Any]) -> str:
):
self.add_import(ARRAY)
self.add_import(Enum)
rendered_enum = f"Enum({enum_class_name}, values_callable=lambda cls: [member.value for member in cls])"
extra_kwargs = ""
if column_type.item_type.name is not None:
extra_kwargs += f", name={column_type.item_type.name!r}"

if column_type.item_type.schema is not None:
extra_kwargs += f", schema={column_type.item_type.schema!r}"
Comment thread
sheinbergon marked this conversation as resolved.

rendered_enum = f"Enum({enum_class_name}, values_callable=lambda cls: [member.value for member in cls]{extra_kwargs})"
if column_type.dimensions is not None:
kwargs["dimensions"] = repr(column_type.dimensions)

Expand Down
94 changes: 87 additions & 7 deletions tests/test_generator_declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -2508,14 +2508,14 @@ class Accounts(Base):
__tablename__ = 'accounts'

id: Mapped[int] = mapped_column(Integer, primary_key=True)
status: Mapped[StatusEnum] = mapped_column(Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls]), nullable=False)
status: Mapped[StatusEnum] = mapped_column(Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls], name='status_enum'), nullable=False)


class Users(Base):
__tablename__ = 'users'

id: Mapped[int] = mapped_column(Integer, primary_key=True)
status: Mapped[StatusEnum] = mapped_column(Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls]), nullable=False)
status: Mapped[StatusEnum] = mapped_column(Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls], name='status_enum'), nullable=False)
""",
)

Expand Down Expand Up @@ -2851,7 +2851,7 @@ class Users(Base):
__tablename__ = 'users'

id: Mapped[int] = mapped_column(Integer, primary_key=True)
roles: Mapped[list[RoleEnum]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])), nullable=False)
roles: Mapped[list[RoleEnum]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls], name='role_enum')), nullable=False)
""",
)

Expand Down Expand Up @@ -2927,7 +2927,7 @@ class Users(Base):
__tablename__ = 'users'

id: Mapped[int] = mapped_column(Integer, primary_key=True)
roles: Mapped[Optional[list[RoleEnum]]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])))
roles: Mapped[Optional[list[RoleEnum]]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls], name='role_enum')))
""",
)

Expand Down Expand Up @@ -2965,7 +2965,7 @@ class Items(Base):
__tablename__ = 'items'

id: Mapped[int] = mapped_column(Integer, primary_key=True)
tag_matrix: Mapped[list[list[TagEnum]]] = mapped_column(ARRAY(Enum(TagEnum, values_callable=lambda cls: [member.value for member in cls]), dimensions=2), nullable=False)
tag_matrix: Mapped[list[list[TagEnum]]] = mapped_column(ARRAY(Enum(TagEnum, values_callable=lambda cls: [member.value for member in cls], name='tag_enum'), dimensions=2), nullable=False)
""",
)

Expand Down Expand Up @@ -3043,7 +3043,87 @@ class Users(Base):
__tablename__ = 'users'

id: Mapped[int] = mapped_column(Integer, primary_key=True)
primary_role: Mapped[RoleEnum] = mapped_column(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls]), nullable=False)
all_roles: Mapped[list[RoleEnum]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])), nullable=False)
primary_role: Mapped[RoleEnum] = mapped_column(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls], name='role_enum'), nullable=False)
all_roles: Mapped[list[RoleEnum]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls], name='role_enum')), nullable=False)
""",
)


def test_enum_named_with_schema(generator: CodeGenerator) -> None:
Table(
"my_table",
generator.metadata,
Column("id", INTEGER, primary_key=True),
Column(
"status",
SAEnum("active", "inactive", name="status_enum", schema="custom_schema"),
nullable=False,
),
schema="custom_schema",
)

validate_code(
generator.generate(),
"""\
import enum

from sqlalchemy import Enum, Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
pass


class StatusEnum(str, enum.Enum):
ACTIVE = 'active'
INACTIVE = 'inactive'


class MyTable(Base):
__tablename__ = 'my_table'
__table_args__ = {'schema': 'custom_schema'}

id: Mapped[int] = mapped_column(Integer, primary_key=True)
status: Mapped[StatusEnum] = mapped_column(Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls], name='status_enum', schema='custom_schema'), nullable=False)
""",
)


def test_array_enum_named_with_schema(generator: CodeGenerator) -> None:
Table(
"my_table",
generator.metadata,
Column("id", INTEGER, primary_key=True),
Column(
"tags",
ARRAY(SAEnum("a", "b", name="tag_enum", schema="custom_schema")),
nullable=False,
),
schema="custom_schema",
)

validate_code(
generator.generate(),
"""\
import enum

from sqlalchemy import ARRAY, Enum, Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
pass


class TagEnum(str, enum.Enum):
A = 'a'
B = 'b'


class MyTable(Base):
__tablename__ = 'my_table'
__table_args__ = {'schema': 'custom_schema'}

id: Mapped[int] = mapped_column(Integer, primary_key=True)
tags: Mapped[list[TagEnum]] = mapped_column(ARRAY(Enum(TagEnum, values_callable=lambda cls: [member.value for member in cls], name='tag_enum', schema='custom_schema')), nullable=False)
""",
)
39 changes: 38 additions & 1 deletion tests/test_generator_sqlmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
from _pytest.fixtures import FixtureRequest
from sqlalchemy import Enum as SAEnum
from sqlalchemy import Uuid
from sqlalchemy.engine import Engine
from sqlalchemy.schema import (
Expand All @@ -13,7 +14,7 @@
Table,
UniqueConstraint,
)
from sqlalchemy.types import INTEGER, VARCHAR
from sqlalchemy.types import ARRAY, INTEGER, VARCHAR

from sqlacodegen.generators import CodeGenerator, SQLModelGenerator

Expand Down Expand Up @@ -329,3 +330,39 @@ class Accounts(SQLModel, table=True):
status: AccountsStatus = Field(sa_column=Column('status', Enum(AccountsStatus, values_callable=lambda cls: [member.value for member in cls]), nullable=False))
""",
)


def test_array_enum_named_with_schema(generator: CodeGenerator) -> None:
Table(
"my_table",
generator.metadata,
Column("id", INTEGER, primary_key=True),
Column(
"tags",
ARRAY(SAEnum("a", "b", name="tag_enum", schema="custom_schema")),
nullable=False,
),
schema="custom_schema",
)

validate_code(
generator.generate(),
"""\
import enum

from sqlalchemy import ARRAY, Column, Enum, Integer
from sqlmodel import Field, SQLModel

class TagEnum(str, enum.Enum):
A = 'a'
B = 'b'


class MyTable(SQLModel, table=True):
__tablename__ = 'my_table'
__table_args__ = {'schema': 'custom_schema'}

id: int = Field(sa_column=Column('id', Integer, primary_key=True))
tags: list[TagEnum] = Field(sa_column=Column('tags', ARRAY(Enum(TagEnum, values_callable=lambda cls: [member.value for member in cls], name='tag_enum', schema='custom_schema')), nullable=False))
""",
)
86 changes: 80 additions & 6 deletions tests/test_generator_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class Blah(str, enum.Enum):

t_simple_items = Table(
'simple_items', metadata,
Column('enum', Enum(Blah, values_callable=lambda cls: [member.value for member in cls])),
Column('enum', Enum(Blah, values_callable=lambda cls: [member.value for member in cls], name='blah', schema='someschema')),
Column('bool', Boolean),
Column('vector', VECTOR(3)),
Column('number', Numeric(10, asdecimal=False)),
Expand Down Expand Up @@ -309,13 +309,13 @@ class StatusEnum(str, enum.Enum):
t_accounts = Table(
'accounts', metadata,
Column('id', Integer, primary_key=True),
Column('status', Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls]))
Column('status', Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls], name='status_enum'))
)

t_users = Table(
'users', metadata,
Column('id', Integer, primary_key=True),
Column('status', Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls]))
Column('status', Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls], name='status_enum'))
)
""",
)
Expand Down Expand Up @@ -348,7 +348,7 @@ class RoleEnum(str, enum.Enum):
t_users = Table(
'users', metadata,
Column('id', Integer, primary_key=True),
Column('roles', ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])))
Column('roles', ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls], name='role_enum')))
)
""",
)
Expand Down Expand Up @@ -386,13 +386,87 @@ class RoleEnum(str, enum.Enum):
t_groups = Table(
'groups', metadata,
Column('id', Integer, primary_key=True),
Column('allowed_roles', ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])))
Column('allowed_roles', ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls], name='role_enum')))
)

t_users = Table(
'users', metadata,
Column('id', Integer, primary_key=True),
Column('roles', ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])))
Column('roles', ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls], name='role_enum')))
)
""",
)


def test_enum_named_with_schema(generator: CodeGenerator) -> None:
Table(
"my_table",
generator.metadata,
Column("id", INTEGER, primary_key=True),
Column(
"status",
SAEnum("active", "inactive", name="status_enum", schema="custom_schema"),
),
schema="custom_schema",
)

validate_code(
generator.generate(),
"""\
import enum

from sqlalchemy import Column, Enum, Integer, MetaData, Table

metadata = MetaData()


class StatusEnum(str, enum.Enum):
ACTIVE = 'active'
INACTIVE = 'inactive'


t_my_table = Table(
'my_table', metadata,
Column('id', Integer, primary_key=True),
Column('status', Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls], name='status_enum', schema='custom_schema')),
schema='custom_schema'
)
""",
)


def test_array_enum_named_with_schema(generator: CodeGenerator) -> None:
Table(
"my_table",
generator.metadata,
Column("id", INTEGER, primary_key=True),
Column(
"tags",
ARRAY(SAEnum("a", "b", name="tag_enum", schema="custom_schema")),
),
schema="custom_schema",
)

validate_code(
generator.generate(),
"""\
import enum

from sqlalchemy import ARRAY, Column, Enum, Integer, MetaData, Table

metadata = MetaData()


class TagEnum(str, enum.Enum):
A = 'a'
B = 'b'


t_my_table = Table(
'my_table', metadata,
Column('id', Integer, primary_key=True),
Column('tags', ARRAY(Enum(TagEnum, values_callable=lambda cls: [member.value for member in cls], name='tag_enum', schema='custom_schema'))),
schema='custom_schema'
)
""",
)
Expand Down