diff --git a/internal/database/postgres/client.go b/internal/database/postgres/client.go new file mode 100644 index 0000000..1e15080 --- /dev/null +++ b/internal/database/postgres/client.go @@ -0,0 +1,360 @@ +package database + +/** + * This file was auto-generated by Mango SQL : https://github.com/kefniark/mangosql + * Do not make direct changes to the file. + */ + +import ( + "context" + "errors" + "fmt" + + "log" + "time" + + squirrel "github.com/Masterminds/squirrel" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/lib/pq" +) + +type SelectBuilder = squirrel.SelectBuilder +type WhereCondition = func(query SelectBuilder) SelectBuilder + +var placeholder = squirrel.Dollar + +type DBClient struct { + ctx *DBContext + // User Custom SQL Queries + // + // Usage: + // entities, err := db.Queries.MySqlRequests() + Queries *CustomQueries +} + +// Create a new instance of MangoSql +func New(db DBPgx) *DBClient { + return newClient(&DBContext{ + db: db, + tx: nil, + }) +} + +func newClient(ctx *DBContext) *DBClient { + return &DBClient{ + ctx: ctx, + // Models + // Custom Queries + Queries: &CustomQueries{ctx: ctx}} +} + +type FilterGenericField[T any] struct { + table string + field string +} + +// Only include Records with a field contains in a set of values +func (f FilterGenericField[T]) In(args ...T) WhereCondition { + sql := fmt.Sprintf("%s.%s = ANY(?)", f.table, f.field) + return func(cond SelectBuilder) SelectBuilder { + return cond.Where(sql, pq.Array(args)) + } +} + +// Exclude Records with a field not contains in a set of values +func (f FilterGenericField[T]) NotIn(args ...T) WhereCondition { + sql := fmt.Sprintf("%s.%s != ANY(?)", f.table, f.field) + return func(cond SelectBuilder) SelectBuilder { + return cond.Where(sql, pq.Array(args)) + } +} + +// Only include Records with a field specific value +func (f FilterGenericField[T]) Equal(arg T) WhereCondition { + sql := fmt.Sprintf("%s.%s = ?", f.table, f.field) + return func(cond SelectBuilder) SelectBuilder { + return cond.Where(sql, arg) + } +} + +// Exclude Records with a field specific value +func (f FilterGenericField[T]) NotEqual(arg T) WhereCondition { + sql := fmt.Sprintf("%s.%s != ?", f.table, f.field) + return func(cond SelectBuilder) SelectBuilder { + return cond.Where(sql, arg) + } +} + +// Only include Records with a field has undefined value +func (f FilterGenericField[T]) IsNull() WhereCondition { + sql := fmt.Sprintf("%s.%s IS NULL", f.table, f.field) + return func(cond SelectBuilder) SelectBuilder { + return cond.Where(sql) + } +} + +// Only include Records with a field has defined values +func (f FilterGenericField[T]) IsNotNull() WhereCondition { + sql := fmt.Sprintf("%s.%s IS NOT NULL", f.table, f.field) + return func(cond SelectBuilder) SelectBuilder { + return cond.Where(sql) + } +} + +// Sort Records in ASC order +func (f FilterGenericField[T]) OrderAsc() WhereCondition { + sql := fmt.Sprintf("%s.%s ASC", f.table, f.field) + return func(cond SelectBuilder) SelectBuilder { + return cond.OrderBy(sql) + } +} + +// Sort Records in DESC order +func (f FilterGenericField[T]) OrderDesc() WhereCondition { + sql := fmt.Sprintf("%s.%s DESC", f.table, f.field) + return func(cond SelectBuilder) SelectBuilder { + return cond.OrderBy(sql) + } +} + +type DBPgx interface { + Begin(ctx context.Context) (pgx.Tx, error) + Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) + Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) + QueryRow(ctx context.Context, sql string, args ...any) pgx.Row +} + +// Create a new Sql transaction. +// If any error or panic occurs inside, the transaction is automatically rollback +// +// Usage: +// +// err := db.Transaction(func(tx *DBClient) error { +// // ... can use tx. like db. +// }) +func (db DBClient) Transaction(transaction func(dbClient *DBClient) error) (e error) { + if db.ctx.tx != nil { + return errors.New("nested transaction is not supported") + } + + tx, err := db.ctx.db.Begin(context.Background()) + if err != nil { + return err + } + + defer func() { + if p := recover(); p != nil { + e = errors.Join(errors.New(p.(string)), tx.Rollback(context.Background())) + } + }() + + client := newClient(&DBContext{ + db: db.ctx.db, + tx: tx, + }) + if err = transaction(client); err != nil { + return errors.Join(err, tx.Rollback(context.Background())) + } + + return tx.Commit(context.Background()) +} + +// Split slice into chunks of the given size +func chunkBy[T any](items []T, chunkSize int) (chunks [][]T) { + for chunkSize < len(items) { + items, chunks = items[chunkSize:], append(chunks, items[0:chunkSize:chunkSize]) + } + return append(chunks, items) +} + +// Execute a Custom SQL query and get one row result. +// +// Usage: +// +// data MyResult +// res, err := db.QueryRow(ctx, &data, sql, args...) +func QueryRow(ctx *DBContext, data interface{}, sql string, args ...interface{}) error { + db := ctx.db + if ctx.tx != nil { + db = ctx.tx + } + + return db.QueryRow(context.Background(), sql, args...).Scan(data) +} + +// Execute a Custom SQL query and get one row result. +// +// Usage: +// +// res, err := db.QueryOne[MyResult](ctx, sql, args...) +func QueryOne[T any](ctx *DBContext, sql string, args ...interface{}) (*T, error) { + db := ctx.db + if ctx.tx != nil { + db = ctx.tx + } + + rows, err := db.Query(context.Background(), sql, args...) + if err != nil { + return nil, err + } + defer rows.Close() + return pgx.CollectOneRow(rows, pgx.RowToAddrOfStructByName[T]) +} + +// Execute a Custom SQL query and get many rows result. +// +// Usage: +// +// res, err := db.QueryMany[MyResult](ctx, sql, args...) +func QueryMany[T any](ctx *DBContext, sql string, args ...interface{}) ([]T, error) { + db := ctx.db + if ctx.tx != nil { + db = ctx.tx + } + + rows, err := db.Query(context.Background(), sql, args...) + if err != nil { + return nil, err + } + defer rows.Close() + return pgx.CollectRows(rows, pgx.RowToStructByName[T]) +} + +// Execute a Custom SQL query without result. +// +// Usage: +// +// err := db.Exec(ctx, sql, args...) +func Exec(ctx *DBContext, sql string, args ...interface{}) error { + db := ctx.db + if ctx.tx != nil { + db = ctx.tx + } + + _, err := db.Exec(context.Background(), sql, args...) + return err +} + +func first[T any](items []T, err error) (*T, error) { + if err != nil { + return nil, err + } + + if len(items) == 0 { + return nil, errors.New("first element not found") + } + + return &items[0], nil +} + +func limitFirst(cond SelectBuilder) SelectBuilder { + return cond.Offset(0).Limit(1) +} + +type DBContext struct { + db DBPgx + tx pgx.Tx +} + +type LogLevel int + +const ( + Debug LogLevel = iota + 1 + Info + Warn + Error +) + +func (ctx DBContext) logQuery(level LogLevel, msg string, err error, duration time.Duration, sql string, args ...interface{}) { + if err != nil { + log.Println("[\033[31mERROR\033[0m]\033[31m", msg, "\033[0m", "\n | \033[32mError:\033[0m", err, "\n | \033[32mArgs:\033[0m", args, "\n | \033[32mSQL:\033[0m", sql) + return + } + + if level < Warn && duration > time.Millisecond*500 { + level = Warn + msg = "[SLOW QUERY] " + msg + } + + switch level { + case Debug: + log.Println("[\033[35mDEBUG\033[0m]\033[33m", msg, "\033[0m", duration, "\n | \033[32mArgs:\033[0m", args, "\n | \033[32mSQL:\033[0m", sql) + case Info: + log.Println("[\033[34mINFO\033[0m]\033[33m", msg, "\033[0m", duration, "\n | \033[32mArgs:\033[0m", args, "\n | \033[32mSQL:\033[0m", sql) + case Warn: + log.Println("[\033[33mWARN\033[0m]\033[33m", msg, "\033[0m", duration, "\n | \033[32mArgs:\033[0m", args, "\n | \033[32mSQL:\033[0m", sql) + case Error: + log.Println("[\033[31mERROR\033[0m]\033[33m", msg, "\033[0m", duration, "\n | \033[32mArgs:\033[0m", args, "\n | \033[32mSQL:\033[0m", sql) + } +} + +type CustomQueries struct { + ctx *DBContext +} + +// Find TableNames records based on the provided conditions +// +// Usage: +// +// entities, err := db.Queries.TableNames( +// // ... can use filters here (cf db.TableNames.Query.*) +// ) +func (q *CustomQueries) TableNames(filters ...WhereCondition) (requestData []TableNamesModel, requestErr error) { + query := squirrel.Select("table_name") + query = query.From("information_schema.columns").PlaceholderFormat(placeholder) + query = query.Where("table_schema = 'public'") + for _, filter := range filters { + query = filter(query) + } + + sql, args, err := query.ToSql() + if err != nil { + return nil, err + } + start := time.Now() + defer func() { + q.ctx.logQuery(Debug, "DB.Queries.TableNames", requestErr, time.Since(start), sql, args) + }() + + return QueryMany[TableNamesModel](q.ctx, sql, args...) +} + +type TableNamesModel struct { + TableName *string `json:"table_name" db:"table_name"` +} + +// Find TableColumns records based on the provided conditions +// +// Usage: +// +// entities, err := db.Queries.TableColumns( +// // ... can use filters here (cf db.TableColumns.Query.*) +// ) +func (q *CustomQueries) TableColumns(filters ...WhereCondition) (requestData []TableColumnsModel, requestErr error) { + query := squirrel.Select("column_name, column_default, is_nullable, data_type, udt_name") + query = query.From("information_schema.columns").PlaceholderFormat(placeholder) + query = query.Where("table_schema = 'public'") + for _, filter := range filters { + query = filter(query) + } + + sql, args, err := query.ToSql() + if err != nil { + return nil, err + } + start := time.Now() + defer func() { + q.ctx.logQuery(Debug, "DB.Queries.TableColumns", requestErr, time.Since(start), sql, args) + }() + + return QueryMany[TableColumnsModel](q.ctx, sql, args...) +} + +type TableColumnsModel struct { + ColumnName *string `json:"column_name" db:"column_name"` + ColumnDefault *string `json:"column_default" db:"column_default"` + IsNullable *string `json:"is_nullable" db:"is_nullable"` + DataType *string `json:"data_type" db:"data_type"` + UdtName *string `json:"udt_name" db:"udt_name"` +} diff --git a/internal/database/postgres/introspect.go b/internal/database/postgres/introspect.go new file mode 100644 index 0000000..d31df19 --- /dev/null +++ b/internal/database/postgres/introspect.go @@ -0,0 +1,57 @@ +package database + +import ( + "context" + + "github.com/jackc/pgx/v5" + "github.com/kefniark/mango-sql/internal/core" +) + +func Parse(url string) (*core.SQLSchema, error) { + database, err := pgx.Connect(context.Background(), url) + if err != nil { + return nil, err + } + + db := New(database) + + tableNames, err := db.Queries.TableNames() + if err != nil { + return nil, err + } + + schema := &core.SQLSchema{ + Tables: make(map[string]*core.SQLTable), + } + + for i, tableName := range tableNames { + table := &core.SQLTable{ + Name: *tableName.TableName, + Columns: map[string]*core.SQLColumn{}, + Order: i, + } + + cols, err := db.Queries.TableColumns( + func(query SelectBuilder) SelectBuilder { + return query.Where("table_name = ?", tableName.TableName) + }, + ) + if err != nil { + return nil, err + } + + for _, col := range cols { + table.Columns[*col.ColumnName] = &core.SQLColumn{ + Name: *col.ColumnName, + HasDefault: col.ColumnDefault != nil, + Type: "string", + TypeSql: *col.UdtName, + Nullable: *col.IsNullable == "true", + } + } + + schema.Tables[*tableName.TableName] = table + } + + return schema, nil +} diff --git a/internal/database/postgres/queries.sql b/internal/database/postgres/queries.sql new file mode 100644 index 0000000..5ec222d --- /dev/null +++ b/internal/database/postgres/queries.sql @@ -0,0 +1,9 @@ +-- queryMany: TableNames +SELECT DISTINCT table_name +FROM information_schema.columns +WHERE table_schema = 'public'; + +-- queryMany: TableColumns +SELECT column_name, column_default, is_nullable, data_type, udt_name +FROM information_schema.columns +WHERE table_schema = 'public'; diff --git a/internal/database/templates/logger_none.tmpl b/internal/database/postgres/schema.sql similarity index 100% rename from internal/database/templates/logger_none.tmpl rename to internal/database/postgres/schema.sql diff --git a/internal/generator.go b/internal/generator.go index 10e6489..b634c99 100644 --- a/internal/generator.go +++ b/internal/generator.go @@ -5,13 +5,13 @@ import ( "io" "github.com/kefniark/mango-sql/internal/core" - "github.com/kefniark/mango-sql/internal/database" + "github.com/kefniark/mango-sql/internal/generator" ) func Generate(schema *core.SQLSchema, contents io.Writer, pkg string, driver string, logger string) error { switch driver { case "sqlite", "pq", "pgx": - return database.Generate(schema, contents, pkg, driver, logger) + return generator.Generate(schema, contents, pkg, driver, logger) } return fmt.Errorf("driver %s not supported", driver) diff --git a/internal/database/database.go b/internal/generator/database.go similarity index 99% rename from internal/database/database.go rename to internal/generator/database.go index f87313e..06fb634 100644 --- a/internal/database/database.go +++ b/internal/generator/database.go @@ -1,4 +1,4 @@ -package database +package generator import ( "bufio" @@ -32,6 +32,8 @@ var customTypes = []FieldInitializer{ } func Generate(schema *core.SQLSchema, contents io.Writer, pkg string, driver string, logger string) error { + fmt.Println(schema) + templateType := "" switch driver { case "pgx": @@ -273,6 +275,7 @@ func toPostgresTable(table *core.SQLTable, driver string) *PostgresTable { func toPostgresColumn(column *core.SQLColumn) *PostgresColumn { val := getColumnType(column) + fmt.Println(column, val) json := column.As if json == "" { @@ -398,7 +401,12 @@ func getPrimaryFields(table *core.SQLTable) []string { return table.Indexes[0].Columns } - return []string{} + cols := []string{} + for _, col := range table.Columns { + cols = append(cols, col.Name) + } + + return []string{cols[0]} } func getAutogeneratedFields() []string { diff --git a/internal/database/filters.go b/internal/generator/filters.go similarity index 99% rename from internal/database/filters.go rename to internal/generator/filters.go index 5f01f6b..062cc83 100644 --- a/internal/database/filters.go +++ b/internal/generator/filters.go @@ -1,4 +1,4 @@ -package database +package generator import ( "slices" diff --git a/internal/database/schema.go b/internal/generator/schema.go similarity index 98% rename from internal/database/schema.go rename to internal/generator/schema.go index c3bb0cc..8ee41d0 100644 --- a/internal/database/schema.go +++ b/internal/generator/schema.go @@ -1,4 +1,4 @@ -package database +package generator import "github.com/kefniark/mango-sql/internal/core" diff --git a/internal/database/templates/custom.tmpl b/internal/generator/templates/custom.tmpl similarity index 100% rename from internal/database/templates/custom.tmpl rename to internal/generator/templates/custom.tmpl diff --git a/internal/database/templates/factory_pgx.tmpl b/internal/generator/templates/factory_pgx.tmpl similarity index 100% rename from internal/database/templates/factory_pgx.tmpl rename to internal/generator/templates/factory_pgx.tmpl diff --git a/internal/database/templates/factory_pq.tmpl b/internal/generator/templates/factory_pq.tmpl similarity index 100% rename from internal/database/templates/factory_pq.tmpl rename to internal/generator/templates/factory_pq.tmpl diff --git a/internal/database/templates/header_pgx.tmpl b/internal/generator/templates/header_pgx.tmpl similarity index 100% rename from internal/database/templates/header_pgx.tmpl rename to internal/generator/templates/header_pgx.tmpl diff --git a/internal/database/templates/header_pq.tmpl b/internal/generator/templates/header_pq.tmpl similarity index 100% rename from internal/database/templates/header_pq.tmpl rename to internal/generator/templates/header_pq.tmpl diff --git a/internal/database/templates/logger_console.tmpl b/internal/generator/templates/logger_console.tmpl similarity index 100% rename from internal/database/templates/logger_console.tmpl rename to internal/generator/templates/logger_console.tmpl diff --git a/internal/database/templates/logger_logrus.tmpl b/internal/generator/templates/logger_logrus.tmpl similarity index 100% rename from internal/database/templates/logger_logrus.tmpl rename to internal/generator/templates/logger_logrus.tmpl diff --git a/internal/generator/templates/logger_none.tmpl b/internal/generator/templates/logger_none.tmpl new file mode 100644 index 0000000..e69de29 diff --git a/internal/database/templates/logger_zap.tmpl b/internal/generator/templates/logger_zap.tmpl similarity index 100% rename from internal/database/templates/logger_zap.tmpl rename to internal/generator/templates/logger_zap.tmpl diff --git a/internal/database/templates/logger_zerolog.tmpl b/internal/generator/templates/logger_zerolog.tmpl similarity index 100% rename from internal/database/templates/logger_zerolog.tmpl rename to internal/generator/templates/logger_zerolog.tmpl diff --git a/internal/database/templates/model.tmpl b/internal/generator/templates/model.tmpl similarity index 100% rename from internal/database/templates/model.tmpl rename to internal/generator/templates/model.tmpl diff --git a/internal/database/templates/queries.tmpl b/internal/generator/templates/queries.tmpl similarity index 100% rename from internal/database/templates/queries.tmpl rename to internal/generator/templates/queries.tmpl diff --git a/internal/database/types.go b/internal/generator/types.go similarity index 98% rename from internal/database/types.go rename to internal/generator/types.go index e78d0de..529473d 100644 --- a/internal/database/types.go +++ b/internal/generator/types.go @@ -1,4 +1,4 @@ -package database +package generator import ( "fmt" diff --git a/justfile b/justfile index 0da3e62..34663f6 100644 --- a/justfile +++ b/justfile @@ -10,7 +10,10 @@ lint: docs: npm run docs:dev -generate: +generate-internal: + go run ./cmd/mangosql/ --output ./internal/database/postgres/client.go --logger console ./internal/database/postgres/ + +generate: # tests go run ./cmd/mangosql/ --output ./tests/postgres/auto-increment/client.go --package autoincrement --logger console ./tests/postgres/auto-increment/schema.sql go run ./cmd/mangosql/ --output ./tests/postgres/composite/client.go --package composite --logger console ./tests/postgres/composite/schema.sql diff --git a/tests/postgres/introspection/client_test.go b/tests/postgres/introspection/client_test.go new file mode 100644 index 0000000..e8b55a5 --- /dev/null +++ b/tests/postgres/introspection/client_test.go @@ -0,0 +1,53 @@ +package introspection + +import ( + "bufio" + "bytes" + "embed" + "fmt" + "os" + "testing" + + introspect "github.com/kefniark/mango-sql/internal/database/postgres" + "github.com/kefniark/mango-sql/internal/generator" + "github.com/kefniark/mango-sql/tests/helpers" + "github.com/stretchr/testify/assert" +) + +//go:embed *.sql +var sqlFS embed.FS + +func newTestDB(t *testing.T) string { + data, err := sqlFS.ReadFile("seed.sql") + if err != nil { + panic(err) + } + + config := helpers.NewDBConfigWith(t, data, "postgres.introspection") + + return config.URL() +} + +func TestIntrospection(t *testing.T) { + url := newTestDB(t) + + schema, err := introspect.Parse(url) + assert.NoError(t, err) + + fmt.Println("Got", schema) + + var b bytes.Buffer + contents := bufio.NewWriter(&b) + + err = generator.Generate(schema, contents, "introspection", "pgx", "none") + assert.NoError(t, err) + assert.NoError(t, contents.Flush()) + + f, err := os.Create("client.go") + assert.NoError(t, err) + + defer f.Close() + + _, err = f.Write(b.Bytes()) + assert.NoError(t, err) +} diff --git a/tests/postgres/introspection/seed.sql b/tests/postgres/introspection/seed.sql new file mode 100644 index 0000000..f6919c4 --- /dev/null +++ b/tests/postgres/introspection/seed.sql @@ -0,0 +1,53 @@ +/* + * Create Tables + */ +create table if not exists rbac_permissions ( + id serial primary key, + lft integer not null, + rght integer not null, + title text not null, + description text not null +); +create index on rbac_permissions (lft); +create index on rbac_permissions (rght); +create index on rbac_permissions (title); + +create table if not exists rbac_rolepermissions ( + role_id integer not null, + permission_id integer not null, + assignment_date timestamptz not null, + primary key (role_id, permission_id) +); + +create table if not exists rbac_roles ( + id serial primary key, + lft integer not null, + rght integer not null, + title varchar not null, + description text not null +); +create index on rbac_roles (lft); +create index on rbac_roles (rght); +create index on rbac_roles (title); + +create table if not exists rbac_userroles ( + user_id integer not null, + role_id integer not null, + assignment_date timestamptz not null, + primary key (user_id, role_id) +); + +/* + * Insert Initial Table Data + */ +insert into rbac_permissions (id, lft, rght, title, description) +values (1, 0, 1, 'root', 'root'); + +insert into rbac_rolepermissions (role_id, permission_id, assignment_date) +values (1, 1, current_timestamp); + +insert into rbac_roles (id, lft, rght, title, description) +values (1, 0, 1, 'root', 'root'); + +insert into rbac_userroles (user_id, Role_id, assignment_date) +values (1, 1, current_timestamp); \ No newline at end of file