From 75cd90db2ddb2cfa962f72a76f5f905156678517 Mon Sep 17 00:00:00 2001 From: h3n4l Date: Thu, 22 Jan 2026 18:24:16 +0800 Subject: [PATCH] feat: add WithMaxRows option to cap find() and countDocuments() results Add functional options pattern to Execute() with WithMaxRows(n) that caps find() and countDocuments() results at the driver level. - Add ExecuteOption type and WithMaxRows() function - Apply min(queryLimit, maxRows) in executeFind and executeCountDocuments - Aggregate operations intentionally not affected - Add comprehensive tests for all limit scenarios Co-Authored-By: Claude Opus 4.5 --- README.md | 19 +++++ client.go | 25 ++++++- collection_test.go | 129 ++++++++++++++++++++++++++++++++ executor.go | 4 +- internal/executor/collection.go | 35 +++++++-- internal/executor/executor.go | 6 +- 6 files changed, 205 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 482a85c..8db623f 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,25 @@ func main() { } ``` +## Execute Options + +The `Execute` method accepts optional configuration: + +### WithMaxRows + +Limit the maximum number of rows returned by `find()` and `countDocuments()` operations. This is useful to prevent excessive memory usage or network traffic from unbounded queries. + +```go +// Cap results at 1000 rows +result, err := gc.Execute(ctx, "mydb", `db.users.find()`, gomongo.WithMaxRows(1000)) +``` + +**Behavior:** +- If the query includes `.limit(N)`, the effective limit is `min(N, maxRows)` +- Query limit 50 + MaxRows 1000 → returns up to 50 rows +- Query limit 5000 + MaxRows 1000 → returns up to 1000 rows +- `aggregate()` operations are not affected (use `$limit` stage instead) + ## Output Format Results are returned in Extended JSON (Relaxed) format: diff --git a/client.go b/client.go index ef97172..bda4599 100644 --- a/client.go +++ b/client.go @@ -23,8 +23,29 @@ type Result struct { Statement string } +// executeConfig holds configuration for Execute. 
+type executeConfig struct { + maxRows *int64 +} + +// ExecuteOption configures Execute behavior. +type ExecuteOption func(*executeConfig) + +// WithMaxRows limits the maximum number of rows returned by find() and +// countDocuments() operations. If the query includes .limit(N), the effective +// limit is min(N, maxRows). Aggregate operations are not affected. +func WithMaxRows(n int64) ExecuteOption { + return func(c *executeConfig) { + c.maxRows = &n + } +} + // Execute parses and executes a MongoDB shell statement. // Returns results as Extended JSON (Relaxed) strings. -func (c *Client) Execute(ctx context.Context, database, statement string) (*Result, error) { - return execute(ctx, c.client, database, statement) +func (c *Client) Execute(ctx context.Context, database, statement string, opts ...ExecuteOption) (*Result, error) { + cfg := &executeConfig{} + for _, opt := range opts { + opt(cfg) + } + return execute(ctx, c.client, database, statement, cfg.maxRows) } diff --git a/collection_test.go b/collection_test.go index 0cb13ac..37ef0f2 100644 --- a/collection_test.go +++ b/collection_test.go @@ -2153,3 +2153,132 @@ func TestCursorMinMaxCombined(t *testing.T) { require.NoError(t, err) require.Equal(t, 2, result.RowCount) } + +func TestWithMaxRowsCapsResults(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_maxrows_cap" + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + // Insert 20 documents + collection := client.Database(dbName).Collection("items") + docs := make([]any, 20) + for i := 0; i < 20; i++ { + docs[i] = bson.M{"index": i} + } + _, err := collection.InsertMany(ctx, docs) + require.NoError(t, err) + + gc := gomongo.NewClient(client) + + // Without MaxRows - returns all 20 + result, err := gc.Execute(ctx, dbName, "db.items.find()") + require.NoError(t, err) + require.Equal(t, 20, result.RowCount) + + // With MaxRows(10) - caps at 10 + result, err = gc.Execute(ctx, dbName, "db.items.find()", 
gomongo.WithMaxRows(10)) + require.NoError(t, err) + require.Equal(t, 10, result.RowCount) +} + +func TestWithMaxRowsQueryLimitTakesPrecedence(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_maxrows_query_limit" + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + // Insert 20 documents + collection := client.Database(dbName).Collection("items") + docs := make([]any, 20) + for i := 0; i < 20; i++ { + docs[i] = bson.M{"index": i} + } + _, err := collection.InsertMany(ctx, docs) + require.NoError(t, err) + + gc := gomongo.NewClient(client) + + // Query limit(5) is smaller than MaxRows(100) - should return 5 + result, err := gc.Execute(ctx, dbName, "db.items.find().limit(5)", gomongo.WithMaxRows(100)) + require.NoError(t, err) + require.Equal(t, 5, result.RowCount) +} + +func TestWithMaxRowsTakesPrecedenceOverLargerLimit(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_maxrows_precedence" + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + // Insert 20 documents + collection := client.Database(dbName).Collection("items") + docs := make([]any, 20) + for i := 0; i < 20; i++ { + docs[i] = bson.M{"index": i} + } + _, err := collection.InsertMany(ctx, docs) + require.NoError(t, err) + + gc := gomongo.NewClient(client) + + // Query limit(100) is larger than MaxRows(5) - should return 5 + result, err := gc.Execute(ctx, dbName, "db.items.find().limit(100)", gomongo.WithMaxRows(5)) + require.NoError(t, err) + require.Equal(t, 5, result.RowCount) +} + +func TestExecuteBackwardCompatibility(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_backward_compat" + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + collection := client.Database(dbName).Collection("items") + _, err := collection.InsertMany(ctx, []any{ + bson.M{"name": "a"}, + bson.M{"name": "b"}, + bson.M{"name": "c"}, + }) + require.NoError(t, err) + + 
gc := gomongo.NewClient(client) + + // Execute without options should work (backward compatible) + result, err := gc.Execute(ctx, dbName, "db.items.find()") + require.NoError(t, err) + require.Equal(t, 3, result.RowCount) +} + +func TestCountDocumentsWithMaxRows(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_count_maxrows" + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + // Insert 100 documents + collection := client.Database(dbName).Collection("items") + docs := make([]any, 100) + for i := 0; i < 100; i++ { + docs[i] = bson.M{"index": i} + } + _, err := collection.InsertMany(ctx, docs) + require.NoError(t, err) + + gc := gomongo.NewClient(client) + + // Without MaxRows - counts all 100 + result, err := gc.Execute(ctx, dbName, "db.items.countDocuments()") + require.NoError(t, err) + require.Equal(t, "100", result.Rows[0]) + + // With MaxRows(50) - counts up to 50 + result, err = gc.Execute(ctx, dbName, "db.items.countDocuments()", gomongo.WithMaxRows(50)) + require.NoError(t, err) + require.Equal(t, "50", result.Rows[0]) +} diff --git a/executor.go b/executor.go index 4751014..a78b50c 100644 --- a/executor.go +++ b/executor.go @@ -9,7 +9,7 @@ import ( ) // execute parses and executes a MongoDB shell statement. 
// computeEffectiveLimit combines the limit written in the query (opLimit) with
// the driver-level cap (maxRows) and returns the limit to hand to the driver.
// It returns nil when no limit should be set.
//
// MongoDB gives some limit values a special meaning, so a plain min() is not
// sufficient:
//   - 0 means "no limit": a query limit of 0 must not override the cap, and a
//     cap of 0 must not erase a query limit.
//   - a negative query limit requests at most |limit| documents in a single
//     batch, so it is its magnitude that has to respect the cap.
//
// A nil or non-positive maxRows is treated as "no cap" and opLimit is returned
// unchanged.
func computeEffectiveLimit(opLimit, maxRows *int64) *int64 {
	if maxRows == nil || *maxRows <= 0 {
		return opLimit
	}
	if opLimit == nil {
		return maxRows
	}

	limit, rowCap := *opLimit, *maxRows
	switch {
	case limit == 0, limit > rowCap:
		// Unbounded or too large: the cap wins.
		return maxRows
	case limit < -rowCap:
		// Keep the single-batch request but shrink it to the cap. Comparing
		// against -rowCap (rather than negating limit) avoids overflow for
		// the minimum int64 value.
		clamped := -rowCap
		return &clamped
	default:
		return opLimit
	}
}
-func executeFind(ctx context.Context, client *mongo.Client, database string, op *translator.Operation) (*Result, error) { +func executeFind(ctx context.Context, client *mongo.Client, database string, op *translator.Operation, maxRows *int64) (*Result, error) { collection := client.Database(database).Collection(op.Collection) filter := op.Filter @@ -25,8 +44,10 @@ func executeFind(ctx context.Context, client *mongo.Client, database string, op if op.Sort != nil { opts.SetSort(op.Sort) } - if op.Limit != nil { - opts.SetLimit(*op.Limit) + // Compute effective limit: min(op.Limit, maxRows) + effectiveLimit := computeEffectiveLimit(op.Limit, maxRows) + if effectiveLimit != nil { + opts.SetLimit(*effectiveLimit) } if op.Skip != nil { opts.SetSkip(*op.Skip) @@ -230,7 +251,7 @@ func executeGetIndexes(ctx context.Context, client *mongo.Client, database strin } // executeCountDocuments executes a db.collection.countDocuments() command. -func executeCountDocuments(ctx context.Context, client *mongo.Client, database string, op *translator.Operation) (*Result, error) { +func executeCountDocuments(ctx context.Context, client *mongo.Client, database string, op *translator.Operation, maxRows *int64) (*Result, error) { collection := client.Database(database).Collection(op.Collection) filter := op.Filter @@ -242,8 +263,10 @@ func executeCountDocuments(ctx context.Context, client *mongo.Client, database s if op.Hint != nil { opts.SetHint(op.Hint) } - if op.Limit != nil { - opts.SetLimit(*op.Limit) + // Compute effective limit: min(op.Limit, maxRows) + effectiveLimit := computeEffectiveLimit(op.Limit, maxRows) + if effectiveLimit != nil { + opts.SetLimit(*effectiveLimit) } if op.Skip != nil { opts.SetSkip(*op.Skip) diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 6a1d139..3cc4a35 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -16,10 +16,10 @@ type Result struct { } // Execute executes a parsed operation against 
MongoDB. -func Execute(ctx context.Context, client *mongo.Client, database string, op *translator.Operation, statement string) (*Result, error) { +func Execute(ctx context.Context, client *mongo.Client, database string, op *translator.Operation, statement string, maxRows *int64) (*Result, error) { switch op.OpType { case translator.OpFind: - return executeFind(ctx, client, database, op) + return executeFind(ctx, client, database, op, maxRows) case translator.OpFindOne: return executeFindOne(ctx, client, database, op) case translator.OpAggregate: @@ -35,7 +35,7 @@ func Execute(ctx context.Context, client *mongo.Client, database string, op *tra case translator.OpGetIndexes: return executeGetIndexes(ctx, client, database, op) case translator.OpCountDocuments: - return executeCountDocuments(ctx, client, database, op) + return executeCountDocuments(ctx, client, database, op, maxRows) case translator.OpEstimatedDocumentCount: return executeEstimatedDocumentCount(ctx, client, database, op) case translator.OpDistinct: