diff --git a/sqlalchemy_firebird/base.py b/sqlalchemy_firebird/base.py index c2467b3..2cbab1a 100644 --- a/sqlalchemy_firebird/base.py +++ b/sqlalchemy_firebird/base.py @@ -189,6 +189,18 @@ def get_column_specification(self, column, **kwargs): has_identity = column.identity is not None + type_compiler_instance = ( + self.dialect.type_compiler_instance + if self.dialect.using_sqlalchemy2 + else self.dialect.type_compiler + ) + + compiled_type = type_compiler_instance.process( + column.type, + type_expression=column, + identifier_preparer=self.preparer, + ) + if ( column.primary_key and column is column.table._autoincrement_column @@ -202,19 +214,9 @@ def get_column_specification(self, column, **kwargs): ) and self.dialect.supports_identity_columns ): - colspec += " INTEGER GENERATED BY DEFAULT AS IDENTITY" + colspec += " %s GENERATED BY DEFAULT AS IDENTITY" % compiled_type else: - type_compiler_instance = ( - self.dialect.type_compiler_instance - if self.dialect.using_sqlalchemy2 - else self.dialect.type_compiler - ) - - colspec += " " + type_compiler_instance.process( - column.type, - type_expression=column, - identifier_preparer=self.preparer, - ) + colspec += " " + compiled_type default_ = self.get_column_default_string(column) if default_ is not None: colspec += " DEFAULT " + default_ diff --git a/test/test_compiler.py b/test/test_compiler.py index 2d34b28..6c46f63 100644 --- a/test/test_compiler.py +++ b/test/test_compiler.py @@ -7,8 +7,10 @@ from sqlalchemy import func from sqlalchemy import insert from sqlalchemy import Index +from sqlalchemy import BigInteger from sqlalchemy import Integer from sqlalchemy import MetaData +from sqlalchemy import SmallInteger from sqlalchemy import schema from sqlalchemy import select from sqlalchemy import String @@ -442,6 +444,29 @@ def test_column_identity(self, pk): % (", PRIMARY KEY (y)" if pk else ""), ) + @testing.combinations( + (Integer, "INTEGER"), + (BigInteger, "BIGINT"), + (SmallInteger, "SMALLINT"), + argnames="type_,expected_sql_type", + ) + def test_autoincrement_primary_key_uses_column_type( + self, type_, expected_sql_type + ): + # Regression test for issue #88: the autoincrement primary key + # must use the column's declared type instead of hardcoded INTEGER. + m = MetaData() + t = Table( + "t", + m, + Column("id", type_, primary_key=True, autoincrement=True), + ) + self.assert_compile( + schema.CreateTable(t), + "CREATE TABLE t (id %s GENERATED BY DEFAULT AS IDENTITY, " + "PRIMARY KEY (id))" % expected_sql_type, + ) + def test_column_identity_null(self): # all other tests are in test_identity_column.py m = MetaData()