Skip to content

Commit 96be689

Browse files
authored
ast: Add support for CREATE TYPE as ENUM (#388)
MySQL and SQLite do not support CREATE TYPE ... ENUM, so this change mainly involves porting the existing enum support from the dinosql / catalog package to the sql/catalog package
1 parent 59bc936 commit 96be689

File tree

7 files changed

+220
-7
lines changed

7 files changed

+220
-7
lines changed

internal/compiler/compile.go

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"io"
88
"os"
9+
"regexp"
910
"sort"
1011
"strings"
1112

@@ -23,6 +24,35 @@ type Parser interface {
2324
Parse(io.Reader) ([]ast.Statement, error)
2425
}
2526

27+
// copied over from gen.go
28+
func structName(name string) string {
29+
out := ""
30+
for _, p := range strings.Split(name, "_") {
31+
if p == "id" {
32+
out += "ID"
33+
} else {
34+
out += strings.Title(p)
35+
}
36+
}
37+
return out
38+
}
39+
40+
var identPattern = regexp.MustCompile("[^a-zA-Z0-9_]+")
41+
42+
func enumValueName(value string) string {
43+
name := ""
44+
id := strings.Replace(value, "-", "_", -1)
45+
id = strings.Replace(id, ":", "_", -1)
46+
id = strings.Replace(id, "/", "_", -1)
47+
id = identPattern.ReplaceAllString(id, "")
48+
for _, part := range strings.Split(id, "_") {
49+
name += strings.Title(part)
50+
}
51+
return name
52+
}
53+
54+
// end copypasta
55+
2656
func Run(conf config.SQL, combo config.CombinedSettings) (*Result, error) {
2757
var p Parser
2858

@@ -53,25 +83,52 @@ func Run(conf config.SQL, combo config.CombinedSettings) (*Result, error) {
5383
}
5484

5585
var structs []dinosql.GoStruct
86+
var enums []dinosql.GoEnum
5687
for _, schema := range c.Schemas {
5788
for _, table := range schema.Tables {
5889
s := dinosql.GoStruct{
59-
Table: pg.FQN{Schema: table.Rel.Schema, Rel: table.Rel.Name},
90+
Table: pg.FQN{Schema: schema.Name, Rel: table.Rel.Name},
6091
Name: strings.Title(table.Rel.Name),
6192
}
6293
for _, col := range table.Columns {
6394
s.Fields = append(s.Fields, dinosql.GoField{
64-
Name: strings.Title(col.Name),
95+
Name: structName(col.Name),
6596
Type: "string",
6697
Tags: map[string]string{"json:": col.Name},
6798
})
6899
}
69100
structs = append(structs, s)
70101
}
102+
for _, typ := range schema.Types {
103+
switch t := typ.(type) {
104+
case catalog.Enum:
105+
var name string
106+
// TODO: This name should be public, not main
107+
if schema.Name == "main" {
108+
name = t.Name
109+
} else {
110+
name = schema.Name + "_" + t.Name
111+
}
112+
e := dinosql.GoEnum{
113+
Name: structName(name),
114+
}
115+
for _, v := range t.Vals {
116+
e.Constants = append(e.Constants, dinosql.GoConstant{
117+
Name: e.Name + enumValueName(v),
118+
Value: v,
119+
Type: e.Name,
120+
})
121+
}
122+
enums = append(enums, e)
123+
}
124+
}
71125
}
126+
72127
if len(structs) > 0 {
73128
sort.Slice(structs, func(i, j int) bool { return structs[i].Name < structs[j].Name })
74129
}
75-
76-
return &Result{structs: structs}, nil
130+
if len(enums) > 0 {
131+
sort.Slice(enums, func(i, j int) bool { return enums[i].Name < enums[j].Name })
132+
}
133+
return &Result{structs: structs, enums: enums}, nil
77134
}

internal/compiler/result.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
)
77

88
type Result struct {
9+
enums []dinosql.GoEnum
910
structs []dinosql.GoStruct
1011
queries []dinosql.GoQuery
1112
}
@@ -19,5 +20,5 @@ func (r *Result) GoQueries(settings config.CombinedSettings) []dinosql.GoQuery {
1920
}
2021

2122
func (r *Result) Enums(settings config.CombinedSettings) []dinosql.GoEnum {
22-
return nil
23+
return r.enums
2324
}

internal/endtoend/testdata/experimental_elephant/go/models.go

Lines changed: 23 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/experimental_elephant/query.sql

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ CREATE TABLE bar (
66
baz text NOT NULL
77
);
88

9+
CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy');
10+
911
SELECT bar FROM foo;
1012

1113
DROP TABLE bar;

internal/postgresql/parse.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,30 @@ func stringSlice(list nodes.List) []string {
2222
return items
2323
}
2424

25+
func parseTypeName(node nodes.Node) (*ast.TypeName, error) {
26+
switch n := node.(type) {
27+
28+
case nodes.List:
29+
parts := stringSlice(n)
30+
switch len(parts) {
31+
case 1:
32+
return &ast.TypeName{
33+
Name: parts[0],
34+
}, nil
35+
case 2:
36+
return &ast.TypeName{
37+
Schema: parts[0],
38+
Name: parts[1],
39+
}, nil
40+
default:
41+
return nil, fmt.Errorf("invalid type name: %s", join(n, "."))
42+
}
43+
44+
default:
45+
return nil, fmt.Errorf("unexpected node type: %T", n)
46+
}
47+
}
48+
2549
func parseTableName(node nodes.Node) (*ast.TableName, error) {
2650
switch n := node.(type) {
2751

@@ -180,6 +204,25 @@ func translate(node nodes.Node) (ast.Node, error) {
180204
}
181205
return create, nil
182206

207+
case nodes.CreateEnumStmt:
208+
name, err := parseTypeName(n.TypeName)
209+
if err != nil {
210+
return nil, err
211+
}
212+
stmt := &ast.CreateEnumStmt{
213+
TypeName: name,
214+
Vals: &ast.List{},
215+
}
216+
for _, val := range n.Vals.Items {
217+
switch v := val.(type) {
218+
case nodes.String:
219+
stmt.Vals.Items = append(stmt.Vals.Items, &ast.String{
220+
Str: v.Str,
221+
})
222+
}
223+
}
224+
return stmt, nil
225+
183226
case nodes.DropStmt:
184227
drop := &ast.DropTableStmt{
185228
IfExists: n.MissingOk,

internal/sql/ast/ast.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,15 @@ func (n *AlterTableCmd) Pos() int {
5757
return 0
5858
}
5959

60+
type CreateEnumStmt struct {
61+
TypeName *TypeName
62+
Vals *List
63+
}
64+
65+
func (n *CreateEnumStmt) Pos() int {
66+
return 0
67+
}
68+
6069
type CreateTableStmt struct {
6170
IfNotExists bool
6271
Name *TableName
@@ -88,7 +97,8 @@ func (n *ColumnDef) Pos() int {
8897
}
8998

9099
type TypeName struct {
91-
Name string
100+
Schema string
101+
Name string
92102
}
93103

94104
func (n *TypeName) Pos() int {
@@ -127,3 +137,11 @@ type ColumnRef struct {
127137
func (n *ColumnRef) Pos() int {
128138
return 0
129139
}
140+
141+
type String struct {
142+
Str string
143+
}
144+
145+
func (n *String) Pos() int {
146+
return 0
147+
}

internal/sql/catalog/catalog.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ func Build(stmts []ast.Statement) (*Catalog, error) {
2121
switch n := stmts[i].Raw.Stmt.(type) {
2222
case *ast.AlterTableStmt:
2323
err = c.alterTable(n)
24+
case *ast.CreateEnumStmt:
25+
err = c.createEnum(n)
2426
case *ast.CreateTableStmt:
2527
err = c.createTable(n)
2628
case *ast.DropTableStmt:
@@ -33,8 +35,19 @@ func Build(stmts []ast.Statement) (*Catalog, error) {
3335
return c, nil
3436
}
3537

38+
func stringSlice(list *ast.List) []string {
39+
items := []string{}
40+
for _, item := range list.Items {
41+
if n, ok := item.(*ast.String); ok {
42+
items = append(items, n.Str)
43+
}
44+
}
45+
return items
46+
}
47+
3648
// TODO: This need to be rich error types
3749
var ErrRelationNotFound = errors.New("relation not found")
50+
var ErrRelationAlreadyExists = errors.New("relation already exists")
3851
var ErrSchemaNotFound = errors.New("schema not found")
3952
var ErrColumnNotFound = errors.New("column not found")
4053
var ErrColumnExists = errors.New("column already exists")
@@ -159,6 +172,37 @@ func (c *Catalog) alterTable(stmt *ast.AlterTableStmt) error {
159172
return nil
160173
}
161174

175+
func (c *Catalog) createEnum(stmt *ast.CreateEnumStmt) error {
176+
ns := stmt.TypeName.Schema
177+
if ns == "" {
178+
ns = c.DefaultSchema
179+
}
180+
schema, err := c.getSchema(ns)
181+
if err != nil {
182+
return err
183+
}
184+
// Because tables have associated data types, the type name must also
185+
// be distinct from the name of any existing table in the same
186+
// schema.
187+
// https://www.postgresql.org/docs/current/sql-createtype.html
188+
tbl := &ast.TableName{
189+
Name: stmt.TypeName.Name,
190+
}
191+
if _, _, err := schema.getTable(tbl); err == nil {
192+
// return wrap(pg.ErrorRelationAlreadyExists(fqn.Rel), raw.StmtLocation)
193+
return ErrRelationAlreadyExists
194+
}
195+
if _, err := schema.getType(stmt.TypeName); err == nil {
196+
// return wrap(pg.ErrorTypeAlreadyExists(fqn.Rel), raw.StmtLocation)
197+
return ErrRelationAlreadyExists
198+
}
199+
schema.Types = append(schema.Types, Enum{
200+
Name: stmt.TypeName.Name,
201+
Vals: stringSlice(stmt.Vals),
202+
})
203+
return nil
204+
}
205+
162206
func (c *Catalog) createTable(stmt *ast.CreateTableStmt) error {
163207
ns := stmt.Name.Schema
164208
if ns == "" {
@@ -223,9 +267,22 @@ type Catalog struct {
223267
type Schema struct {
224268
Name string
225269
Tables []*Table
270+
Types []Type
226271
Comment string
227272
}
228273

274+
func (s *Schema) getType(rel *ast.TypeName) (Type, error) {
275+
for i := range s.Types {
276+
switch typ := s.Types[i].(type) {
277+
case Enum:
278+
if typ.Name == rel.Name {
279+
return s.Types[i], nil
280+
}
281+
}
282+
}
283+
return nil, ErrRelationNotFound
284+
}
285+
229286
func (s *Schema) getTable(rel *ast.TableName) (*Table, int, error) {
230287
for i := range s.Tables {
231288
if s.Tables[i].Rel.Name == rel.Name {
@@ -247,3 +304,16 @@ type Column struct {
247304
Type ast.TypeName
248305
IsNotNull bool
249306
}
307+
308+
type Type interface {
309+
isType()
310+
}
311+
312+
type Enum struct {
313+
Name string
314+
Vals []string
315+
Comment string
316+
}
317+
318+
func (e Enum) isType() {
319+
}

0 commit comments

Comments
 (0)