Skip to content

Commit 7177fd7

Browse files
db.In, db.NotIn iterating slices (#21)
1 parent 7c55126 commit 7177fd7

3 files changed

Lines changed: 116 additions & 8 deletions

File tree

Makefile

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,13 @@ test: test-clean
4242
test-all: test-clean
4343
GOGC=off go test $(TEST_FLAGS) $(MOD_VENDOR) -run=$(TEST) ./...
4444

45+
test-all-tparse: test-clean
46+
GOGC=off go test $(TEST_FLAGS) $(MOD_VENDOR) -run=$(TEST) ./... -json | tparse --follow
47+
4548
test-with-reset: db-reset test-all
4649

50+
test-with-reset-tparse: db-reset test-all-tparse
51+
4752
test-clean:
4853
GOGC=off go clean -testcache
4954

db/cond.go

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package db
22

33
import (
44
"fmt"
5+
"reflect"
56
"strings"
67

78
"github.com/Masterminds/squirrel"
@@ -80,7 +81,6 @@ func (n *binaryExprNode) ToSql() (string, []interface{}, error) {
8081
func compileNodes(nodes []squirrel.Sqlizer) (q string, args []interface{}, err error) {
8182
for i, node := range nodes {
8283
qn, argsn, err := node.ToSql()
83-
8484
if err != nil {
8585
return "", nil, fmt.Errorf("error compiling node %d: %w", i, err)
8686
}
@@ -203,12 +203,12 @@ func NotILike(v interface{}) squirrel.Sqlizer {
203203

204204
// In represents an IN operator. The value must be variadic.
205205
func In[T interface{}](v ...T) squirrel.Sqlizer {
206-
return Func[T]("IN", v...)
206+
return Func("IN", v...)
207207
}
208208

209209
// NotIn represents a NOT IN operator. The value must be variadic.
210210
func NotIn[T interface{}](v ...T) squirrel.Sqlizer {
211-
return Func[T]("NOT IN", v...)
211+
return Func("NOT IN", v...)
212212
}
213213

214214
// Raw represents a raw SQL expression.
@@ -226,15 +226,58 @@ func Func[T interface{}](name string, params ...T) squirrel.Sqlizer {
226226
}
227227

228228
places := make([]string, len(params))
229-
args := make([]interface{}, 0, len(params))
230229

230+
// iterating through slices
231+
if reflect.TypeOf(params[0]).Kind() == reflect.Slice {
232+
elements := 0
233+
for _, subSlice := range params {
234+
v := reflect.ValueOf(subSlice)
235+
elements += v.Len()
236+
}
237+
238+
args := make([]interface{}, 0, elements)
239+
240+
for i, subSlice := range params {
241+
subSliceVal := reflect.ValueOf(subSlice)
242+
subPlaces := make([]string, subSliceVal.Len())
243+
244+
for j := 0; j < subSliceVal.Len(); j++ {
245+
val := subSliceVal.Index(j).Interface()
246+
if sqlizer, ok := interface{}(val).(squirrel.Sqlizer); ok {
247+
paramSQL, paramArgs, err := sqlizer.ToSql()
248+
if err != nil {
249+
return "", nil, fmt.Errorf("%s: error compiling argument %d: %w", name, i, err)
250+
}
251+
252+
subPlaces[j] = paramSQL
253+
args = append(args, paramArgs...)
254+
} else if reflect.TypeOf(val).Kind() == reflect.Slice {
255+
v := reflect.ValueOf(val)
256+
for k := 0; k < v.Len(); k++ {
257+
subPlaces[j] = paramPlaceholder
258+
args = append(args, v.Index(k).Interface())
259+
}
260+
} else {
261+
subPlaces[j] = paramPlaceholder
262+
args = append(args, val)
263+
}
264+
}
265+
266+
places[i] = "(" + strings.Join(subPlaces, ",") + ")"
267+
}
268+
269+
return name + " (" + strings.Join(places, ",") + ")", args, nil
270+
}
271+
272+
args := make([]interface{}, 0, len(params))
231273
for i, param := range params {
232274
if sqlizer, ok := interface{}(param).(squirrel.Sqlizer); ok {
233-
paramSql, paramArgs, err := sqlizer.ToSql()
275+
paramSQL, paramArgs, err := sqlizer.ToSql()
234276
if err != nil {
235277
return "", nil, fmt.Errorf("%s: error compiling argument %d: %w", name, i, err)
236278
}
237-
places[i] = paramSql
279+
280+
places[i] = paramSQL
238281
args = append(args, paramArgs...)
239282
} else {
240283
places[i] = paramPlaceholder

db/cond_test.go

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ import (
44
"testing"
55

66
sq "github.com/Masterminds/squirrel"
7-
"github.com/goware/pgkit/v2/db"
87
"github.com/stretchr/testify/assert"
98
"github.com/stretchr/testify/require"
9+
10+
"github.com/goware/pgkit/v2/db"
1011
)
1112

1213
func TestCond(t *testing.T) {
13-
1414
t.Run("equal to", func(t *testing.T) {
1515
cond := db.Cond{"one": 1}
1616
s, args, err := cond.ToSql()
@@ -19,6 +19,14 @@ func TestCond(t *testing.T) {
1919
assert.Equal(t, "one = ?", s)
2020
})
2121

22+
t.Run("equal to with multiple parameters", func(t *testing.T) {
23+
cond := db.And{db.Cond{"one": 1}, db.Cond{"two": 2}}
24+
s, args, err := cond.ToSql()
25+
require.NoError(t, err)
26+
assert.Equal(t, []interface{}{1, 2}, args)
27+
assert.Equal(t, "(one = ? AND two = ?)", s)
28+
})
29+
2230
t.Run("equal to (inverted)", func(t *testing.T) {
2331
cond := db.Cond{1: "one"}
2432
s, args, err := cond.ToSql()
@@ -64,6 +72,16 @@ func TestCond(t *testing.T) {
6472
})
6573

6674
t.Run("IN with slice", func(t *testing.T) {
75+
sl1 := []int{1, 2, 3}
76+
cond := db.Cond{"list": db.In(sl1...)}
77+
s, args, err := cond.ToSql()
78+
require.NoError(t, err)
79+
80+
assert.Equal(t, []interface{}{1, 2, 3}, args)
81+
assert.Equal(t, "list IN (?, ?, ?)", s)
82+
})
83+
84+
t.Run("IN with slice variadic", func(t *testing.T) {
6785
cond := db.Cond{"list": db.In(1, 2, 3)}
6886
s, args, err := cond.ToSql()
6987
require.NoError(t, err)
@@ -72,6 +90,48 @@ func TestCond(t *testing.T) {
7290
assert.Equal(t, "list IN (?, ?, ?)", s)
7391
})
7492

93+
t.Run("multiple IN with slice", func(t *testing.T) {
94+
sl1 := []int{1, 2, 3}
95+
sl2 := []int{4, 5, 6}
96+
cond := db.Cond{"list": db.In([]interface{}{sl1, sl2}...)}
97+
s, args, err := cond.ToSql()
98+
require.NoError(t, err)
99+
100+
assert.Equal(t, []interface{}{1, 2, 3, 4, 5, 6}, args)
101+
assert.Equal(t, "list IN ((?,?,?),(?,?,?))", s)
102+
})
103+
104+
t.Run("multiple IN with slice AND where ID", func(t *testing.T) {
105+
cond := db.And{db.Cond{"list": db.In([][]string{{"1", "2", "3"}, {"3", "4", "5"}}...)}, db.Cond{"id": 1}}
106+
s, args, err := cond.ToSql()
107+
require.NoError(t, err)
108+
109+
assert.Equal(t, []interface{}{"1", "2", "3", "3", "4", "5", 1}, args)
110+
assert.Equal(t, "(list IN ((?,?,?),(?,?,?)) AND id = ?)", s)
111+
})
112+
113+
t.Run("multiple IN with struct", func(t *testing.T) {
114+
randomStruct := []struct {
115+
Id uint64
116+
Name string
117+
}{
118+
{Id: 1, Name: "Lukas"},
119+
{Id: 2, Name: "David"},
120+
}
121+
122+
data := [][]interface{}{}
123+
for _, s := range randomStruct {
124+
data = append(data, []interface{}{s.Id, s.Name})
125+
}
126+
127+
cond := db.Cond{"list": db.In(data...)}
128+
s, args, err := cond.ToSql()
129+
require.NoError(t, err)
130+
131+
assert.Equal(t, []interface{}{uint64(1), "Lukas", uint64(2), "David"}, args)
132+
assert.Equal(t, "list IN ((?,?),(?,?))", s)
133+
})
134+
75135
t.Run("NOT IN", func(t *testing.T) {
76136
cond := db.Cond{"list": db.NotIn("Czech Republic", "Slovakia")}
77137
s, args, err := cond.ToSql()

0 commit comments

Comments
 (0)