Skip to content

Commit 67aaf00

Browse files
committed
Add limit iterator and context plumbing for --limit flag
1 parent 45afed3 commit 67aaf00

File tree

4 files changed

+229
-0
lines changed

4 files changed

+229
-0
lines changed

libs/cmdio/limit.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package cmdio
2+
3+
import (
4+
"context"
5+
6+
"github.com/databricks/databricks-sdk-go/listing"
7+
)
8+
9+
type limitKey struct{}
10+
11+
// WithLimit stores the limit in the context.
12+
func WithLimit(ctx context.Context, n int) context.Context {
13+
return context.WithValue(ctx, limitKey{}, n)
14+
}
15+
16+
// GetLimit retrieves the limit from context. Returns 0 if not set.
17+
func GetLimit(ctx context.Context) int {
18+
v, ok := ctx.Value(limitKey{}).(int)
19+
if !ok {
20+
return 0
21+
}
22+
return v
23+
}
24+
25+
type limitIterator[T any] struct {
26+
inner listing.Iterator[T]
27+
remaining int
28+
}
29+
30+
func (l *limitIterator[T]) HasNext(ctx context.Context) bool {
31+
return l.remaining > 0 && l.inner.HasNext(ctx)
32+
}
33+
34+
func (l *limitIterator[T]) Next(ctx context.Context) (T, error) {
35+
v, err := l.inner.Next(ctx)
36+
if err != nil {
37+
return v, err
38+
}
39+
l.remaining--
40+
return v, nil
41+
}
42+
43+
// ApplyLimit wraps a listing.Iterator to yield at most the limit from context.
44+
// It returns the iterator unchanged if the limit is not positive.
45+
func ApplyLimit[T any](ctx context.Context, i listing.Iterator[T]) listing.Iterator[T] {
46+
if limit := GetLimit(ctx); limit > 0 {
47+
return &limitIterator[T]{inner: i, remaining: limit}
48+
}
49+
return i
50+
}

libs/cmdio/limit_test.go

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
package cmdio_test
2+
3+
import (
4+
"context"
5+
"errors"
6+
"testing"
7+
8+
"github.com/databricks/cli/libs/cmdio"
9+
"github.com/databricks/databricks-sdk-go/listing"
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
type sliceIterator[T any] struct {
15+
items []T
16+
}
17+
18+
func (s *sliceIterator[T]) HasNext(_ context.Context) bool {
19+
return len(s.items) > 0
20+
}
21+
22+
func (s *sliceIterator[T]) Next(_ context.Context) (T, error) {
23+
if len(s.items) == 0 {
24+
var zero T
25+
return zero, errors.New("no more items")
26+
}
27+
item := s.items[0]
28+
s.items = s.items[1:]
29+
return item, nil
30+
}
31+
32+
func drain[T any](ctx context.Context, iter listing.Iterator[T]) ([]T, error) {
33+
var result []T
34+
for iter.HasNext(ctx) {
35+
v, err := iter.Next(ctx)
36+
if err != nil {
37+
return result, err
38+
}
39+
result = append(result, v)
40+
}
41+
return result, nil
42+
}
43+
44+
type errorIterator[T any] struct {
45+
items []T
46+
failAt int
47+
callCount int
48+
}
49+
50+
func (e *errorIterator[T]) HasNext(_ context.Context) bool {
51+
return e.callCount <= e.failAt && e.callCount < len(e.items)
52+
}
53+
54+
func (e *errorIterator[T]) Next(_ context.Context) (T, error) {
55+
idx := e.callCount
56+
e.callCount++
57+
if idx == e.failAt {
58+
var zero T
59+
return zero, errors.New("fetch error")
60+
}
61+
return e.items[idx], nil
62+
}
63+
64+
func TestWithLimitRoundTrip(t *testing.T) {
65+
ctx := cmdio.WithLimit(t.Context(), 42)
66+
assert.Equal(t, 42, cmdio.GetLimit(ctx))
67+
}
68+
69+
func TestGetLimitReturnsZeroWhenNotSet(t *testing.T) {
70+
assert.Equal(t, 0, cmdio.GetLimit(t.Context()))
71+
}
72+
73+
func TestApplyLimit(t *testing.T) {
74+
tests := []struct {
75+
name string
76+
limit int
77+
setLimit bool
78+
items []int
79+
want []int
80+
}{
81+
{
82+
name: "caps results",
83+
limit: 5,
84+
setLimit: true,
85+
items: []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
86+
want: []int{1, 2, 3, 4, 5},
87+
},
88+
{
89+
name: "no-op when unset",
90+
items: []int{1, 2, 3},
91+
want: []int{1, 2, 3},
92+
},
93+
{
94+
name: "greater than total",
95+
limit: 10,
96+
setLimit: true,
97+
items: []int{1, 2, 3},
98+
want: []int{1, 2, 3},
99+
},
100+
{
101+
name: "one",
102+
limit: 1,
103+
setLimit: true,
104+
items: []int{1, 2, 3},
105+
want: []int{1},
106+
},
107+
{
108+
name: "zero",
109+
limit: 0,
110+
setLimit: true,
111+
items: []int{1, 2, 3},
112+
want: []int{1, 2, 3},
113+
},
114+
{
115+
name: "negative",
116+
limit: -1,
117+
setLimit: true,
118+
items: []int{1, 2, 3},
119+
want: []int{1, 2, 3},
120+
},
121+
}
122+
123+
for _, tt := range tests {
124+
t.Run(tt.name, func(t *testing.T) {
125+
ctx := t.Context()
126+
if tt.setLimit {
127+
ctx = cmdio.WithLimit(ctx, tt.limit)
128+
}
129+
130+
iter := cmdio.ApplyLimit(ctx, &sliceIterator[int]{items: tt.items})
131+
132+
result, err := drain(t.Context(), iter)
133+
require.NoError(t, err)
134+
assert.Equal(t, tt.want, result)
135+
})
136+
}
137+
}
138+
139+
func TestApplyLimitPreservesErrors(t *testing.T) {
140+
ctx := cmdio.WithLimit(t.Context(), 5)
141+
iter := cmdio.ApplyLimit(ctx, &errorIterator[int]{items: []int{1, 2, 3}, failAt: 2})
142+
143+
result, err := drain(t.Context(), iter)
144+
assert.ErrorContains(t, err, "fetch error")
145+
assert.Equal(t, []int{1, 2}, result)
146+
}

libs/cmdio/render.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ func Render(ctx context.Context, v any) error {
264264
}
265265

266266
func RenderIterator[T any](ctx context.Context, i listing.Iterator[T]) error {
267+
i = ApplyLimit(ctx, i)
267268
c := fromContext(ctx)
268269
return renderWithTemplate(ctx, newIteratorRenderer(i), c.outputFormat, c.out, c.headerTemplate, c.template)
269270
}
@@ -277,11 +278,13 @@ func RenderWithTemplate(ctx context.Context, v any, headerTemplate, template str
277278
}
278279

279280
func RenderIteratorWithTemplate[T any](ctx context.Context, i listing.Iterator[T], headerTemplate, template string) error {
281+
i = ApplyLimit(ctx, i)
280282
c := fromContext(ctx)
281283
return renderWithTemplate(ctx, newIteratorRenderer(i), c.outputFormat, c.out, headerTemplate, template)
282284
}
283285

284286
func RenderIteratorJson[T any](ctx context.Context, i listing.Iterator[T]) error {
287+
i = ApplyLimit(ctx, i)
285288
c := fromContext(ctx)
286289
return renderWithTemplate(ctx, newIteratorRenderer(i), c.outputFormat, c.out, c.headerTemplate, c.template)
287290
}

libs/cmdio/render_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,36 @@ var testCases = []testCase{
167167
},
168168
}
169169

170+
func TestRenderIteratorWithLimit(t *testing.T) {
171+
output := &bytes.Buffer{}
172+
ctx := t.Context()
173+
cmdIO := NewIO(ctx, flags.OutputText, nil, output, output,
174+
"id\tname",
175+
"{{range .}}{{.WorkspaceId}}\t{{.WorkspaceName}}\n{{end}}")
176+
ctx = InContext(ctx, cmdIO)
177+
ctx = WithLimit(ctx, 3)
178+
179+
err := RenderIterator(ctx, makeIterator(10))
180+
assert.NoError(t, err)
181+
assert.Equal(t, "id name\n"+makeBigOutput(3), output.String())
182+
}
183+
184+
func TestRenderIteratorWithLimitJSON(t *testing.T) {
185+
output := &bytes.Buffer{}
186+
ctx := t.Context()
187+
cmdIO := NewIO(ctx, flags.OutputJSON, nil, output, output, "", "")
188+
ctx = InContext(ctx, cmdIO)
189+
ctx = WithLimit(ctx, 2)
190+
191+
err := RenderIterator(ctx, makeIterator(10))
192+
assert.NoError(t, err)
193+
194+
var items []provisioning.Workspace
195+
err = json.Unmarshal(output.Bytes(), &items)
196+
assert.NoError(t, err)
197+
assert.Len(t, items, 2)
198+
}
199+
170200
func TestRender(t *testing.T) {
171201
for _, c := range testCases {
172202
t.Run(c.name, func(t *testing.T) {

0 commit comments

Comments
 (0)