diff --git a/lib/parallel.go b/lib/parallel.go new file mode 100644 index 0000000..87ea4bd --- /dev/null +++ b/lib/parallel.go @@ -0,0 +1,787 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License 2.0; +// you may not use this file except in compliance with the Elastic License 2.0. + +package lib + +import ( + "fmt" + "runtime" + "slices" + "sync" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/operators" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" + "github.com/google/cel-go/interpreter" + "github.com/google/cel-go/parser" +) + +// Parallel returns a cel.EnvOption that configures the parallel macro for +// concurrent application of an expression to every element of a list or map. +// +// # Call forms +// +// parallel mirrors the call forms of the built-in map macro: +// +// // (1) One-variable — apply expr to every element. +// <list>.parallel(<iterVar>, <expr>) -> list +// <map>.parallel(<iterVar>, <expr>) -> list +// +// // (2a) Predicate filter — apply expr only where pred is true. +// <list>.parallel(<iterVar>, <pred>, <expr>) -> list +// <map>.parallel(<iterVar>, <pred>, <expr>) -> list +// +// // (2b) Two-variable — bind index/key and value simultaneously. +// <list>.parallel(<iterVar>, <iterVar2>, <expr>) -> list +// <map>.parallel(<iterVar>, <iterVar2>, <expr>) -> list +// +// Forms (2a) and (2b) are distinguished by their second argument: if it is a +// plain identifier it is treated as iterVar2 (two-variable form); otherwise it +// is treated as a boolean predicate (filter form). This is the same rule used +// by map. +// +// For list ranges in the two-variable form, iterVar is the zero-based element +// index and iterVar2 is the element value. For map ranges, iterVar is the key +// and iterVar2 is the corresponding value. 
+// +// # Concurrency +// +// perCall controls the maximum number of goroutines used simultaneously by a +// single .parallel() call. If perCall ≤ 0 it defaults to runtime.GOMAXPROCS(0). +// When perCall == 1 the macro degenerates to the sequential map macro: no +// goroutines are spawned and there is no overhead. +// +// globalCap bounds the total number of concurrent leaf body evaluations across +// all .parallel() calls in the program. If globalCap ≤ 0, no global limit is +// applied. globalCap < perCall is valid: outer goroutines driving nested +// .parallel() calls do useful dispatch work, not just leaf evaluations. +// +// # Concurrency safety +// +// Body and predicate expressions must be safe for concurrent evaluation. +// Custom functions called from inside a parallel body must not rely on shared +// mutable state unless they provide their own synchronisation. +// +// # Error handling +// +// If any element's expression returns an error, the first error in iteration +// order is returned and no result list is produced, consistent with how CEL's +// built-in map macro propagates errors. When one element errors, remaining +// evaluations still run to completion — there is no early termination. +// +// # Notes +// +// A single Parallel lib instance must not be shared across concurrent +// cel.Env constructions. One instance per pipeline configuration is the +// expected pattern, consistent with all other mito libs. +// +// # Examples +// +// // Fetch every URL concurrently. +// urls.parallel(u, get(u)) +// +// // Fetch only HTTPS URLs concurrently. +// urls.parallel(u, u.startsWith("https://"), get(u)) +// +// // Two-variable list form: enrich each item with its position. +// items.parallel(i, item, item.with({"seq": i})) +// +// // Two-variable map form: key and value both directly in scope. 
+// endpoints.parallel(svc, url, get(url).Body.decode_json().with({"svc":svc})) +func Parallel(perCall, globalCap int) cel.EnvOption { + var globalSem chan struct{} + if globalCap > 0 { + globalSem = make(chan struct{}, globalCap) + } + return cel.Lib(¶llelLib{perCall: perCall, globalSem: globalSem}) +} + +// parallelLib implements cel.Library. +// +// # Design: three sentinel functions +// +// The cel-go interpreter plans a comprehension as an unexported *evalFold +// whose sub-expressions (iterRange, step, etc.) are not accessible via any +// public interface. Rather than using an oracle strategy (which requires +// synthetic collections and range-variable shadowing with a known limitation +// for complex range expressions), we emit additional sentinel function calls +// inside the AST so that the planned sub-Interpretables are recoverable via +// InterpretableCall.Args(): +// +// @parallel(rangeExpr, fold) +// fold step contains: @parallel_body(transformExpr) +// fold step also contains (predicate form only): @parallel_pred(predExpr) +// +// Three decorators intercept these sentinel calls by function name: +// - @parallel_body: stores Args()[0] (transform) in bodyImpls +// - @parallel_pred: stores Args()[0] (predicate) in predImpls +// - @parallel: retrieves iterRange from Args()[0], body from bodyImpls, +// pred from predImpls, and builds a parallelFold +// +// parallelFold.Eval calls body.Eval(elementActivation) directly — no oracle, +// no synthetic collection, no range-variable shadowing. +// +// # Why we never call parser.MakeMap inside emitSentinel +// +// parser.MakeMap uses parser.AccumulatorName ("@result") as the accumulator +// variable in the comprehension it emits. When that comprehension is embedded +// as an argument to the @parallel sentinel call, the type-checker sees "@result" +// as a free variable reference and rejects it. 
We therefore always emit our +// own comprehensions using parallelAccuVar ("@parallel_result"), which we +// declare explicitly as a cel.Variable in CompileOptions. +type parallelLib struct { + perCall int + globalSem chan struct{} // nil when no global limit + + mu sync.Mutex + + // bodies maps each @parallel sentinel call's AST node ID to the per-call + // metadata recorded at macro expansion time. + bodies map[int64]parallelBody + + // bodyImpls maps each @parallel sentinel call's AST node ID to the planned + // transform Interpretable, stored by the @parallel_body decorator. + bodyImpls map[int64]interpreter.Interpretable + + // predImpls maps each @parallel sentinel call's AST node ID to the planned + // predicate Interpretable, stored by the @parallel_pred decorator. + // Only populated for the predicate filter form. + predImpls map[int64]interpreter.Interpretable +} + +// parallelBody carries compile-time metadata for one parallel call site. +type parallelBody struct { + iterVar string // first iteration variable + iterVar2 string // second iteration variable; non-empty for two-variable form + hasPred bool // true for the (iterVar, pred, expr) form + nestedParallel bool // true when the body expression contains a nested @parallel + bodyCallID int64 // node ID of the @parallel_body sentinel call + predCallID int64 // node ID of the @parallel_pred sentinel call; zero if !hasPred +} + +// CompileOptions registers the three sentinel function declarations, the +// parallelAccuVar variable declaration, and both macro overloads. +// +// The parallelAccuVar declaration is required because the inner comprehensions +// reference it as an accumulator, and the type-checker treats accumulator +// references as variable lookups when the comprehension is nested inside a +// function call argument. 
+func (l *parallelLib) CompileOptions() []cel.EnvOption { + return []cel.EnvOption{ + // Declare parallelAccuVar so the type-checker accepts its reference + // inside the comprehension step and result expressions. + cel.Variable(parallelAccuVar, cel.ListType(cel.DynType)), + // @parallel(dyn, dyn) -> list(dyn) + cel.Function(parallelSentinel, + cel.Overload(parallelSentinel+"_impl", + []*cel.Type{cel.DynType, cel.DynType}, + cel.ListType(cel.DynType), + ), + ), + // @parallel_body(dyn) -> dyn + cel.Function(parallelBodySentinel, + cel.Overload(parallelBodySentinel+"_impl", + []*cel.Type{cel.DynType}, + cel.DynType, + ), + ), + // @parallel_pred(dyn) -> dyn + cel.Function(parallelPredSentinel, + cel.Overload(parallelPredSentinel+"_impl", + []*cel.Type{cel.DynType}, + cel.DynType, + ), + ), + cel.Macros( + // (1) range.parallel(v, expr) + cel.ReceiverMacro("parallel", 2, l.makeParallel2), + // (2a/b) range.parallel(v, pred_or_v2, expr) + cel.ReceiverMacro("parallel", 3, l.makeParallel3), + ), + } +} + +// ProgramOptions installs the three decorators. The body and predicate +// decorators must run before the parallel decorator so that by the time +// parallelDecorator fires, all Interpretables are already stored. +func (l *parallelLib) ProgramOptions() []cel.ProgramOption { + perCall := l.perCall + if perCall <= 0 { + perCall = runtime.GOMAXPROCS(0) + } + return []cel.ProgramOption{ + cel.CustomDecorator(l.parallelBodyDecorator()), + cel.CustomDecorator(l.parallelPredDecorator()), + cel.CustomDecorator(l.parallelDecorator(perCall)), + } +} + +// makeParallel2 handles: range.parallel(iterVar, expr) +// +// When perCall == 1 it degenerates to map(iterVar, expr) by emitting a plain +// comprehension that the standard interpreter executes sequentially. 
+func (l *parallelLib) makeParallel2(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { + iterVar, err := requireIdent(mef, args[0]) + if err != nil { + return nil, err + } + if l.perCall < 2 { + return l.makeSeqComprehension(mef, target, iterVar, nil, args[1]), nil + } + bodyCall := mef.NewCall(parallelBodySentinel, args[1]) + pb := parallelBody{ + iterVar: iterVar, + nestedParallel: containsParallel(args[1]), + bodyCallID: bodyCall.ID(), + } + return l.emitSentinel(mef, target, pb, func(rangeExpr ast.Expr) ast.Expr { + accu := mef.NewIdent(parallelAccuVar) + return mef.NewComprehension( + rangeExpr, iterVar, parallelAccuVar, + mef.NewList(), + mef.NewLiteral(types.True), + mef.NewCall(operators.Add, accu, mef.NewList(bodyCall)), + mef.NewIdent(parallelAccuVar), + ) + }) +} + +// makeParallel3 handles: +// +// (2a) range.parallel(iterVar, pred, expr) - predicate filter form +// (2b) range.parallel(iterVar, iterVar2, expr) - two-variable form +// +// If args[1] is a plain identifier it is iterVar2 (2b); otherwise it is a +// boolean predicate (2a). This matches map's own disambiguation rule. +func (l *parallelLib) makeParallel3(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { + iterVar, err := requireIdent(mef, args[0]) + if err != nil { + return nil, err + } + + if args[1].Kind() == ast.IdentKind { + // Two-variable form (2b). 
+ iterVar2, err := requireIdent(mef, args[1]) + if err != nil { + return nil, err + } + if iterVar == iterVar2 { + return nil, mef.NewError(args[1].ID(), "parallel: iteration variables must be distinct") + } + if l.perCall < 2 { + return l.makeSeqComprehensionTwoVar(mef, target, iterVar, iterVar2, args[2]), nil + } + bodyCall := mef.NewCall(parallelBodySentinel, args[2]) + pb := parallelBody{ + iterVar: iterVar, + iterVar2: iterVar2, + nestedParallel: containsParallel(args[2]), + bodyCallID: bodyCall.ID(), + } + return l.emitSentinel(mef, target, pb, func(rangeExpr ast.Expr) ast.Expr { + accu := mef.NewIdent(parallelAccuVar) + return mef.NewComprehensionTwoVar( + rangeExpr, iterVar, iterVar2, parallelAccuVar, + mef.NewList(), + mef.NewLiteral(types.True), + mef.NewCall(operators.Add, accu, mef.NewList(bodyCall)), + mef.NewIdent(parallelAccuVar), + ) + }) + } + + // Predicate filter form (2a). + if l.perCall < 2 { + return l.makeSeqComprehension(mef, target, iterVar, args[1], args[2]), nil + } + predCall := mef.NewCall(parallelPredSentinel, args[1]) + bodyCall := mef.NewCall(parallelBodySentinel, args[2]) + pb := parallelBody{ + iterVar: iterVar, + hasPred: true, + nestedParallel: containsParallel(args[2]), + predCallID: predCall.ID(), + bodyCallID: bodyCall.ID(), + } + return l.emitSentinel(mef, target, pb, func(rangeExpr ast.Expr) ast.Expr { + accu := mef.NewIdent(parallelAccuVar) + // step: accu + (@parallel_pred(pred) ? [@parallel_body(expr)] : []) + return mef.NewComprehension( + rangeExpr, iterVar, parallelAccuVar, + mef.NewList(), + mef.NewLiteral(types.True), + mef.NewCall(operators.Add, + accu, + mef.NewCall(operators.Conditional, + predCall, + mef.NewList(bodyCall), + mef.NewList(), + ), + ), + mef.NewIdent(parallelAccuVar), + ) + }) +} + +// makeSeqComprehension emits a sequential one-variable comprehension using +// parallelAccuVar as the accumulator, used for the perCall==1 path. 
+// If pred is non-nil it is incorporated as a conditional step (filter form). +func (l *parallelLib) makeSeqComprehension(mef cel.MacroExprFactory, target ast.Expr, iterVar string, pred, body ast.Expr) ast.Expr { + accu := mef.NewIdent(parallelAccuVar) + bodyList := mef.NewList(mef.Copy(body)) + var step ast.Expr + if pred != nil { + step = mef.NewCall(operators.Add, + accu, + mef.NewCall(operators.Conditional, mef.Copy(pred), bodyList, mef.NewList()), + ) + } else { + step = mef.NewCall(operators.Add, accu, bodyList) + } + return mef.NewComprehension( + target, iterVar, parallelAccuVar, + mef.NewList(), + mef.NewLiteral(types.True), + step, + mef.NewIdent(parallelAccuVar), + ) +} + +// makeSeqComprehensionTwoVar emits a sequential two-variable comprehension +// using parallelAccuVar as the accumulator, used for the perCall==1 path. +func (l *parallelLib) makeSeqComprehensionTwoVar(mef cel.MacroExprFactory, target ast.Expr, iterVar, iterVar2 string, body ast.Expr) ast.Expr { + accu := mef.NewIdent(parallelAccuVar) + return mef.NewComprehensionTwoVar( + target, iterVar, iterVar2, parallelAccuVar, + mef.NewList(), + mef.NewLiteral(types.True), + mef.NewCall(operators.Add, accu, mef.NewList(mef.Copy(body))), + mef.NewIdent(parallelAccuVar), + ) +} + +// emitSentinel builds and records the @parallel(rangeExpr, fold) AST node. +// foldBuilder receives a copy of target as its range expression and must +// embed the bodyCall (and predCall for the filter form) in the fold step. +func (l *parallelLib) emitSentinel(mef cel.MacroExprFactory, target ast.Expr, pb parallelBody, foldBuilder func(rangeExpr ast.Expr) ast.Expr) (ast.Expr, *common.Error) { + fold := foldBuilder(mef.Copy(target)) + sentinel := mef.NewCall(parallelSentinel, mef.Copy(target), fold) + l.recordBody(sentinel.ID(), pb) + return sentinel, nil +} + +// parallelSentinel is the name of the no-op sentinel function the parallel +// macros emit as the outermost wrapper. 
Its two planned arguments are +// accessible to the decorator via the public InterpretableCall.Args() API: +// +// args[0] → planned range Interpretable (evaluates to the list or map) +// args[1] → planned fold Interpretable (discarded at runtime; present +// only so the type-checker sees a well-formed comprehension) +// +// The sentinel is never called at runtime: the decorator replaces every +// InterpretableCall whose Function() is parallelSentinel with a parallelFold. +const parallelSentinel = "@parallel" + +// requireIdent asserts that e is a plain identifier and returns its name. +func requireIdent(mef cel.MacroExprFactory, e ast.Expr) (string, *common.Error) { + if e.Kind() != ast.IdentKind { + return "", mef.NewError(e.ID(), "parallel: iteration variable must be a simple identifier") + } + name := e.AsIdent() + if name == parallelAccuVar || name == parser.AccumulatorName || name == parser.HiddenAccumulatorName { + return "", mef.NewError(e.ID(), "parallel: iteration variable conflicts with internal accumulator") + } + return name, nil +} + +// containsParallel reports whether expr contains a nested @parallel sentinel +// call anywhere in its subtree. Macros expand bottom-up, so when the outer +// parallel macro fires the inner is already rewritten to @parallel(...). 
+func containsParallel(expr ast.Expr) bool { + switch expr.Kind() { + case ast.CallKind: + c := expr.AsCall() + if c.FunctionName() == parallelSentinel { + return true + } + if c.IsMemberFunction() && containsParallel(c.Target()) { + return true + } + return slices.ContainsFunc(c.Args(), containsParallel) + case ast.ComprehensionKind: + c := expr.AsComprehension() + return containsParallel(c.IterRange()) || + containsParallel(c.AccuInit()) || + containsParallel(c.LoopCondition()) || + containsParallel(c.LoopStep()) || + containsParallel(c.Result()) + case ast.ListKind: + return slices.ContainsFunc(expr.AsList().Elements(), containsParallel) + case ast.MapKind: + m := expr.AsMap() + for _, entry := range m.Entries() { + e := entry.AsMapEntry() + if containsParallel(e.Key()) || containsParallel(e.Value()) { + return true + } + } + case ast.SelectKind: + return containsParallel(expr.AsSelect().Operand()) + case ast.StructKind: + for _, f := range expr.AsStruct().Fields() { + if containsParallel(f.AsStructField().Value()) { + return true + } + } + } + return false +} + +func (l *parallelLib) recordBody(id int64, pb parallelBody) { + l.mu.Lock() + defer l.mu.Unlock() + if l.bodies == nil { + l.bodies = make(map[int64]parallelBody) + } + l.bodies[id] = pb +} + +// parallelAccuVar is the accumulator variable name used in the inner folds +// emitted by the parallel macros. The leading '@' places it outside the space +// of valid user-written identifiers, preventing collisions. +// +// It must also be declared as a cel.Variable in CompileOptions so that the +// type-checker accepts the accumulator reference inside the comprehension body. +const parallelAccuVar = "@parallel_result" + +// parallelBodyDecorator intercepts planned @parallel_body calls, stores +// Args()[0] as the transform Interpretable keyed on the parent @parallel +// call's ID, and returns Args()[0] to make the sentinel transparent. 
+func (l *parallelLib) parallelBodyDecorator() interpreter.InterpretableDecorator { + return func(i interpreter.Interpretable) (interpreter.Interpretable, error) { + call, ok := i.(interpreter.InterpretableCall) + if !ok || call.Function() != parallelBodySentinel { + return i, nil + } + args := call.Args() + if len(args) != 1 { + return nil, fmt.Errorf("parallel: @parallel_body has %d args, expected 1", len(args)) + } + l.mu.Lock() + for parallelID, pb := range l.bodies { + if pb.bodyCallID == call.ID() { + if l.bodyImpls == nil { + l.bodyImpls = make(map[int64]interpreter.Interpretable) + } + l.bodyImpls[parallelID] = args[0] + break + } + } + l.mu.Unlock() + return args[0], nil + } +} + +// parallelBodySentinel is the name of a second no-op sentinel that wraps the +// per-element transform expression inside the fold step emitted by the macro. +// Because it is planned as an InterpretableCall, its sole argument — the +// planned transform Interpretable — is recoverable via Args()[0] using the +// fully public API. A dedicated decorator intercepts these calls and stores +// the transform Interpretable in parallelLib.bodyImpls. +// +// This eliminates the oracle strategy entirely: parallelFold.Eval calls +// body.Eval(elementActivation) directly, with no synthetic collections and +// no range-variable shadowing. Complex range expressions such as +// state.items.parallel(...) work correctly. +// +// TODO: if cel-go ever exports an InterpretableFold interface that exposes +// the step Interpretable of a planned comprehension, both sentinel functions +// become unnecessary and the whole implementation simplifies to a type +// assertion in the decorator. Worth filing upstream. +const parallelBodySentinel = "@parallel_body" + +// parallelPredDecorator intercepts planned @parallel_pred calls, stores +// Args()[0] as the predicate Interpretable keyed on the parent @parallel +// call's ID, and returns Args()[0] to make the sentinel transparent. 
+func (l *parallelLib) parallelPredDecorator() interpreter.InterpretableDecorator { + return func(i interpreter.Interpretable) (interpreter.Interpretable, error) { + call, ok := i.(interpreter.InterpretableCall) + if !ok || call.Function() != parallelPredSentinel { + return i, nil + } + args := call.Args() + if len(args) != 1 { + return nil, fmt.Errorf("parallel: @parallel_pred has %d args, expected 1", len(args)) + } + l.mu.Lock() + for parallelID, pb := range l.bodies { + if pb.predCallID == call.ID() { + if l.predImpls == nil { + l.predImpls = make(map[int64]interpreter.Interpretable) + } + l.predImpls[parallelID] = args[0] + break + } + } + l.mu.Unlock() + return args[0], nil + } +} + +// parallelPredSentinel is the name of a third no-op sentinel that wraps the +// predicate expression in the filter form range.parallel(v, pred, expr). +// It is treated identically to parallelBodySentinel: a dedicated decorator +// intercepts it and stores the planned predicate Interpretable in +// parallelLib.predImpls so that parallelFold.Eval can gate each body +// evaluation on the predicate without needing the fold's conditional step. +const parallelPredSentinel = "@parallel_pred" + +// parallelDecorator intercepts planned @parallel calls and replaces them with +// parallelFold Interpretables. Detection is by function name, which is robust +// against node ID renumbering by constant-folding optimisations. 
+func (l *parallelLib) parallelDecorator(perCall int) interpreter.InterpretableDecorator { + return func(i interpreter.Interpretable) (interpreter.Interpretable, error) { + call, ok := i.(interpreter.InterpretableCall) + if !ok || call.Function() != parallelSentinel { + return i, nil + } + args := call.Args() + if len(args) != 2 { + return nil, fmt.Errorf("parallel: sentinel call has %d args, expected 2", len(args)) + } + iterRange := args[0] // args[1] is the fold; discarded here + + l.mu.Lock() + pb, found := l.bodies[call.ID()] + body := l.bodyImpls[call.ID()] + pred := l.predImpls[call.ID()] // nil if !hasPred + l.mu.Unlock() + + // A @parallel call with no registered metadata means this decorator is + // running against an AST compiled by a different parallelLib instance. + // This is always a programming error: return a clear message rather than + // silently leaving an unevaluable sentinel call in the program. + if !found { + return nil, fmt.Errorf( + "parallel: @parallel call %d has no registered metadata; "+ + "ensure the same Parallel() lib instance is passed to both "+ + "cel.NewEnv and env.Program", call.ID()) + } + if body == nil { + return nil, fmt.Errorf( + "parallel: no body Interpretable for call %d; "+ + "ensure the same Parallel() lib instance is passed to both "+ + "cel.NewEnv and env.Program", call.ID()) + } + if pb.hasPred && pred == nil { + return nil, fmt.Errorf( + "parallel: no predicate Interpretable for call %d; "+ + "ensure the same Parallel() lib instance is passed to both "+ + "cel.NewEnv and env.Program", call.ID()) + } + + return ¶llelFold{ + id: call.ID(), + iterVar: pb.iterVar, + iterVar2: pb.iterVar2, + hasPred: pb.hasPred, + leafBody: !pb.nestedParallel, + iterRange: iterRange, + body: body, + pred: pred, + perCall: perCall, + globalSem: l.globalSem, + }, nil + } +} + +// parallelFold is the concurrent Interpretable that replaces each @parallel +// sentinel call at runtime. 
+// +// For each element the planned body Interpretable is evaluated directly: +// +// body.Eval(interpreter.NewHierarchicalActivation(outer, elementActivation)) +// +// where elementActivation binds iterVar (and iterVar2 for the two-variable +// form) to the current element's value(s). The outer activation provides all +// other names, including the range expression itself, so complex receivers +// such as state.items.parallel(...) work correctly. +// +// For the predicate form, pred.Eval(elementActivation) is called first; the +// body is skipped if the predicate is false, and the element is omitted from +// the result list. +// +// Elements are dispatched through a semaphore of size perCall. Results are +// written into a pre-allocated slice at the element's original index so that +// output order matches input order regardless of goroutine completion order. +type parallelFold struct { + id int64 + iterVar string + iterVar2 string + hasPred bool + leafBody bool // true when body is not itself a nested parallelFold + iterRange interpreter.Interpretable + body interpreter.Interpretable + pred interpreter.Interpretable // nil when !hasPred + perCall int + globalSem chan struct{} // nil when no global limit +} + +// ID implements interpreter.Interpretable. +func (p *parallelFold) ID() int64 { return p.id } + +// Eval implements interpreter.Interpretable. +func (p *parallelFold) Eval(activation interpreter.Activation) ref.Val { + // Evaluate the range expression once. + rangeVal := p.iterRange.Eval(activation) + if types.IsError(rangeVal) { + return rangeVal + } + + // Collect per-goroutine element descriptors. 
+ type elem struct { + primary ref.Val // iterVar binding + secondary ref.Val // iterVar2 binding; zero value for one-variable forms + } + + var elems []elem + twoVar := p.iterVar2 != "" + + switch rv := rangeVal.(type) { + case traits.Lister: + it := rv.Iterator() + if twoVar { + for idx := types.Int(0); it.HasNext() == types.True; idx++ { + elems = append(elems, elem{primary: idx, secondary: it.Next()}) + } + } else { + for it.HasNext() == types.True { + elems = append(elems, elem{primary: it.Next()}) + } + } + case traits.Mapper: + it := rv.Iterator() + if twoVar { + for it.HasNext() == types.True { + k := it.Next() + elems = append(elems, elem{primary: k, secondary: rv.Get(k)}) + } + } else { + for it.HasNext() == types.True { + elems = append(elems, elem{primary: it.Next()}) + } + } + default: + return types.NewErr("parallel: receiver must be a list or map, got %T", rangeVal) + } + + n := len(elems) + if n == 0 { + return types.DefaultTypeAdapter.NativeToValue([]ref.Val{}) + } + + var filtered []bool + if p.hasPred { + filtered = make([]bool, n) + } + results := make([]ref.Val, n) + + sem := make(chan struct{}, p.perCall) + var wg sync.WaitGroup + var errMu sync.Mutex + var seenErr ref.Val + + // Concurrent body.Eval and pred.Eval calls below assume that + // Interpretable.Eval is safe for concurrent use when each call receives + // an independent Activation. This holds for cel-go v0.28 — the interpreter + // structures (evalConst, evalAttr, evalOr, evalFold, etc.) carry no mutable + // per-evaluation state. This is not a documented guarantee. + // TestParallel_ConcurrentEvalRace is a race-detector trip wire that will + // catch regressions if a future cel-go release introduces shared mutable + // state into any Interpretable implementation. 
+ for i, e := range elems { + wg.Add(1) + sem <- struct{}{} + go func(idx int, e elem) { + defer wg.Done() + defer func() { <-sem }() + + if p.leafBody && p.globalSem != nil { + p.globalSem <- struct{}{} + defer func() { <-p.globalSem }() + } + + bindings := map[string]any{p.iterVar: e.primary} + if twoVar { + bindings[p.iterVar2] = e.secondary + } + childActivation, err := interpreter.NewActivation(bindings) + if err != nil { + v := types.WrapErr(err) + results[idx] = v + errMu.Lock() + if seenErr == nil { + seenErr = v + } + errMu.Unlock() + return + } + hier := interpreter.NewHierarchicalActivation(activation, childActivation) + + // Predicate check. + if p.hasPred { + predVal := p.pred.Eval(hier) + if types.IsError(predVal) { + results[idx] = predVal + errMu.Lock() + if seenErr == nil { + seenErr = predVal + } + errMu.Unlock() + return + } + if predVal != types.True { + filtered[idx] = true + return + } + } + + result := p.body.Eval(hier) + results[idx] = result + if types.IsError(result) { + errMu.Lock() + if seenErr == nil { + seenErr = result + } + errMu.Unlock() + } + }(i, e) + } + + wg.Wait() + + // Surface the first error in iteration order (deterministic across runs). + if seenErr != nil { + for _, r := range results { + if types.IsError(r) { + return r + } + } + return seenErr // unreachable in practice + } + + // Assemble the result list, omitting filtered elements. + out := make([]ref.Val, 0, n) + for i, r := range results { + if filtered != nil && filtered[i] { + continue + } + out = append(out, r) + } + return types.DefaultTypeAdapter.NativeToValue(out) +} diff --git a/lib/parallel_test.go b/lib/parallel_test.go new file mode 100644 index 0000000..5b74594 --- /dev/null +++ b/lib/parallel_test.go @@ -0,0 +1,730 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. 
Licensed under the Elastic License 2.0; +// you may not use this file except in compliance with the Elastic License 2.0. + +package lib + +import ( + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" + "github.com/google/cel-go/interpreter" +) + +// form (1): range.parallel(v, expr) + +func TestParallel_List_Basic(t *testing.T) { + env := parallelTestEnv(t, 0) + prg := parallelProgram(t, env, `items.parallel(x, x * 2)`) + out := parallelEval(t, prg, map[string]any{"items": []any{1, 2, 3, 4, 5}}) + assertOrderedList(t, out, []ref.Val{ + types.Int(2), types.Int(4), types.Int(6), types.Int(8), types.Int(10), + }) +} + +func TestParallel_List_Empty(t *testing.T) { + env := parallelTestEnv(t, 0) + prg := parallelProgram(t, env, `items.parallel(x, x)`) + out := parallelEval(t, prg, map[string]any{"items": []any{}}) + assertOrderedList(t, out, nil) +} + +func TestParallel_List_OuterScopeVisible(t *testing.T) { + env := parallelTestEnv(t, 0) + prg := parallelProgram(t, env, `items.parallel(x, x + suffix)`) + out := parallelEval(t, prg, map[string]any{ + "items": []any{"hello", "world"}, + "suffix": "!", + }) + assertOrderedList(t, out, []ref.Val{types.String("hello!"), types.String("world!")}) +} + +func TestParallel_List_PreservesOrder(t *testing.T) { + env := parallelTestEnv(t, 8) + prg := parallelProgram(t, env, `items.parallel(x, x)`) + input := make([]any, 200) + for i := range input { + input[i] = i + } + for run := 0; run < 20; run++ { + out := parallelEval(t, prg, map[string]any{"items": input}) + list := out.(traits.Lister) + sz := int(list.Size().(types.Int)) + for i := 0; i < sz; i++ { + got := list.Get(types.Int(i)) + if got.(types.Int) != types.Int(i) { + t.Fatalf("run %d index %d: got %v want %d", run, i, got, i) + } + } + } +} + +func TestParallel_List_ActivationIsolation(t *testing.T) { + // 
Each goroutine must see its own iterVar, not a shared value. + env := parallelTestEnv(t, 8) + prg := parallelProgram(t, env, `items.parallel(x, x * x)`) + input := make([]any, 50) + for i := range input { + input[i] = i + } + out := parallelEval(t, prg, map[string]any{"items": input}) + want := make([]ref.Val, 50) + for i := range want { + want[i] = types.Int(i * i) + } + assertOrderedList(t, out, want) +} + +func TestParallel_Map_OneVar(t *testing.T) { + // Over a map range, iterVar is bound to each key. + // The outer activation keeps the map itself in scope. + env := parallelTestEnv(t, 0) + prg := parallelProgram(t, env, `m.parallel(k, m[k] * 10)`) + out := parallelEval(t, prg, map[string]any{ + "m": map[string]any{"a": 1, "b": 2, "c": 3}, + }) + assertMultiset(t, out, []ref.Val{types.Int(10), types.Int(20), types.Int(30)}) +} + +// form (2a): range.parallel(v, pred, expr) + +func TestParallel_Pred_Basic(t *testing.T) { + env := parallelTestEnv(t, 0) + prg := parallelProgram(t, env, `items.parallel(x, x % 2 == 0, x * 2)`) + out := parallelEval(t, prg, map[string]any{"items": []any{1, 2, 3, 4}}) + assertOrderedList(t, out, []ref.Val{types.Int(4), types.Int(8)}) +} + +func TestParallel_Pred_NonePass(t *testing.T) { + env := parallelTestEnv(t, 0) + prg := parallelProgram(t, env, `items.parallel(x, x > 100, x)`) + out := parallelEval(t, prg, map[string]any{"items": []any{1, 2, 3}}) + assertOrderedList(t, out, nil) +} + +func TestParallel_Pred_AllPass(t *testing.T) { + env := parallelTestEnv(t, 0) + prg := parallelProgram(t, env, `items.parallel(x, x > 0, x * 3)`) + out := parallelEval(t, prg, map[string]any{"items": []any{1, 2, 3}}) + assertOrderedList(t, out, []ref.Val{types.Int(3), types.Int(6), types.Int(9)}) +} + +func TestParallel_Pred_PreservesOrder(t *testing.T) { + // Output must be in input order among the elements that pass the predicate. 
+ env := parallelTestEnv(t, 8) + prg := parallelProgram(t, env, `items.parallel(x, x % 2 == 0, x)`) + input := make([]any, 100) + for i := range input { + input[i] = i + } + for run := 0; run < 10; run++ { + out := parallelEval(t, prg, map[string]any{"items": input}) + list := out.(traits.Lister) + sz := int(list.Size().(types.Int)) + if sz != 50 { + t.Fatalf("run %d: expected 50 results, got %d", run, sz) + } + for i := 0; i < sz; i++ { + got := list.Get(types.Int(i)) + if got.(types.Int) != types.Int(i*2) { + t.Fatalf("run %d index %d: got %v want %d", run, i, got, i*2) + } + } + } +} + +func TestParallel_Pred_Map(t *testing.T) { + env := parallelTestEnv(t, 0) + prg := parallelProgram(t, env, `m.parallel(k, k.startsWith("h"), m[k] * 10)`) + out := parallelEval(t, prg, map[string]any{ + "m": map[string]any{"hello": 1, "world": 2, "hi": 3}, + }) + assertMultiset(t, out, []ref.Val{types.Int(10), types.Int(30)}) +} + +func TestParallel_Pred_OuterScopeVisible(t *testing.T) { + env := parallelTestEnv(t, 0) + prg := parallelProgram(t, env, `items.parallel(x, x > threshold, x)`) + out := parallelEval(t, prg, map[string]any{ + "items": []any{1, 2, 3, 4, 5}, + "threshold": 3, + }) + assertOrderedList(t, out, []ref.Val{types.Int(4), types.Int(5)}) +} + +// form (2b): range.parallel(v, v2, expr) + +func TestParallel_TwoVar_List(t *testing.T) { + // iterVar = index, iterVar2 = element value. + env := parallelTestEnv(t, 0) + prg := parallelProgram(t, env, `items.parallel(i, v, v + "_" + string(i))`) + out := parallelEval(t, prg, map[string]any{"items": []any{"a", "b", "c"}}) + assertOrderedList(t, out, []ref.Val{ + types.String("a_0"), types.String("b_1"), types.String("c_2"), + }) +} + +func TestParallel_TwoVar_List_IndexCorrect(t *testing.T) { + // The index variable must carry the real position, not always 0. 
+ env := parallelTestEnv(t, 8) + prg := parallelProgram(t, env, `items.parallel(i, v, i)`) + input := make([]any, 50) + for i := range input { + input[i] = i * 100 + } + for run := 0; run < 10; run++ { + out := parallelEval(t, prg, map[string]any{"items": input}) + list := out.(traits.Lister) + sz := int(list.Size().(types.Int)) + for j := 0; j < sz; j++ { + got := list.Get(types.Int(j)) + if got.(types.Int) != types.Int(j) { + t.Fatalf("run %d index %d: got %v want %d", run, j, got, j) + } + } + } +} + +func TestParallel_TwoVar_Map(t *testing.T) { + // iterVar = key, iterVar2 = value — value directly available without map lookup. + env := parallelTestEnv(t, 0) + prg := parallelProgram(t, env, `m.parallel(k, v, v * 100)`) + out := parallelEval(t, prg, map[string]any{ + "m": map[string]any{"x": 1, "y": 2}, + }) + assertMultiset(t, out, []ref.Val{types.Int(100), types.Int(200)}) +} + +// assertMultiset checks that val is a CEL list whose elements, treated as a +// multiset, equal want. Used when output order is non-deterministic (map ranges). 
+func assertMultiset(t *testing.T, val ref.Val, want []ref.Val) { + t.Helper() + list, ok := val.(traits.Lister) + if !ok { + t.Fatalf("expected list, got %T: %v", val, val) + } + sz := int(list.Size().(types.Int)) + if sz != len(want) { + t.Fatalf("list length: got %d want %d", sz, len(want)) + } + freq := make(map[string]int, len(want)) + for _, w := range want { + freq[fmt.Sprint(w)]++ + } + for i := 0; i < sz; i++ { + s := fmt.Sprint(list.Get(types.Int(i))) + if freq[s] == 0 { + t.Errorf("unexpected element %q in result", s) + } + freq[s]-- + } +} + +func TestParallel_TwoVar_DistinctVarsRequired(t *testing.T) { + env := parallelTestEnv(t, 0) + _, issues := env.Compile(`items.parallel(x, x, x)`) + if issues == nil || issues.Err() == nil { + t.Fatal("expected compile error for duplicate iteration variables, got none") + } +} + +// sequential degeneration (perCall == 1) + +func TestParallel_Sequential_OneVar(t *testing.T) { + env := parallelTestEnv(t, 1) + prg := parallelProgram(t, env, `items.parallel(x, x + 1)`) + out := parallelEval(t, prg, map[string]any{"items": []any{0, 1, 2}}) + assertOrderedList(t, out, []ref.Val{types.Int(1), types.Int(2), types.Int(3)}) +} + +func TestParallel_Sequential_Pred(t *testing.T) { + env := parallelTestEnv(t, 1) + prg := parallelProgram(t, env, `items.parallel(x, x > 1, x * 10)`) + out := parallelEval(t, prg, map[string]any{"items": []any{1, 2, 3}}) + assertOrderedList(t, out, []ref.Val{types.Int(20), types.Int(30)}) +} + +func TestParallel_Sequential_TwoVar(t *testing.T) { + env := parallelTestEnv(t, 1) + prg := parallelProgram(t, env, `items.parallel(i, v, string(i) + ":" + v)`) + out := parallelEval(t, prg, map[string]any{"items": []any{"a", "b"}}) + assertOrderedList(t, out, []ref.Val{types.String("0:a"), types.String("1:b")}) +} + +// error handling + +func TestParallel_Error_Surfaced(t *testing.T) { + env := parallelTestEnv(t, 0) + prg := parallelProgram(t, env, `items.parallel(x, 1 / x)`) + parallelEvalExpectErr(t, prg, 
map[string]any{"items": []any{1, 0, 2}}) +} + +func TestParallel_Error_FirstInIterationOrder(t *testing.T) { + // items[0] and items[2] both produce errors. The call must not panic, + // deadlock, or return nil, and must return an error deterministically. + env := parallelTestEnv(t, 4) + prg := parallelProgram(t, env, `items.parallel(x, 1 / x)`) + parallelEvalExpectErr(t, prg, map[string]any{"items": []any{0, 1, 0}}) +} + +func TestParallel_Pred_Error_InPredicate(t *testing.T) { + env := parallelTestEnv(t, 0) + prg := parallelProgram(t, env, `items.parallel(x, 1/x > 0, x)`) + parallelEvalExpectErr(t, prg, map[string]any{"items": []any{1, 0, 2}}) +} + +func parallelEvalExpectErr(t *testing.T, prg cel.Program, vars map[string]any) { + t.Helper() + _, _, err := prg.Eval(vars) + if err == nil { + t.Fatal("expected eval error, got nil") + } +} + +// concurrency properties + +func TestParallel_ActuallyConcurrent(t *testing.T) { + // Verify goroutines run concurrently by tracking peak in-flight count. + var inFlight, peak atomic.Int64 + + trackFn := cel.Function("track", + cel.Overload("track_int", + []*cel.Type{cel.IntType}, + cel.IntType, + cel.UnaryBinding(func(v ref.Val) ref.Val { + cur := inFlight.Add(1) + for { + p := peak.Load() + if cur <= p || peak.CompareAndSwap(p, cur) { + break + } + } + // Hold in-flight long enough for goroutines to overlap. 
+ time.Sleep(time.Millisecond) + inFlight.Add(-1) + return v + }), + ), + ) + + env := parallelTestEnv(t, 8, trackFn) + prg := parallelProgram(t, env, `items.parallel(x, track(x))`) + input := make([]any, 32) + for i := range input { + input[i] = i + } + out := parallelEval(t, prg, map[string]any{"items": input}) + want := make([]ref.Val, 32) + for i := range want { + want[i] = types.Int(i) + } + assertOrderedList(t, out, want) + + if peak.Load() <= 1 { + t.Errorf("expected concurrent execution (peak in-flight > 1), got %d", peak.Load()) + } +} + +func TestParallel_BoundedConcurrency(t *testing.T) { + // With perCall=2, no more than 2 goroutines should be in-flight at once. + var inFlight atomic.Int64 + var overLimit atomic.Bool + + trackFn := cel.Function("track2", + cel.Overload("track2_int", + []*cel.Type{cel.IntType}, + cel.IntType, + cel.UnaryBinding(func(v ref.Val) ref.Val { + cur := inFlight.Add(1) + if cur > 2 { + overLimit.Store(true) + } + ch := make(chan struct{}) + go func() { close(ch) }() + <-ch + inFlight.Add(-1) + return v + }), + ), + ) + + env := parallelTestEnv(t, 2, trackFn) + prg := parallelProgram(t, env, `items.parallel(x, track2(x))`) + input := make([]any, 20) + for i := range input { + input[i] = i + } + parallelEval(t, prg, map[string]any{"items": input}) + + if overLimit.Load() { + t.Error("concurrency limit violated: more than 2 goroutines in-flight simultaneously") + } +} + +// nesting + +func TestParallel_NestedInMap(t *testing.T) { + env := parallelTestEnv(t, 0) + prg := parallelProgram(t, env, `outer.map(row, row.parallel(x, x * 2))`) + out := parallelEval(t, prg, map[string]any{ + "outer": []any{[]any{1, 2}, []any{3, 4}}, + }) + outerList := out.(traits.Lister) + if int(outerList.Size().(types.Int)) != 2 { + t.Fatalf("outer size: got %d want 2", outerList.Size()) + } + assertOrderedList(t, outerList.Get(types.Int(0)), []ref.Val{types.Int(2), types.Int(4)}) + assertOrderedList(t, outerList.Get(types.Int(1)), []ref.Val{types.Int(6), 
types.Int(8)}) +} + +func TestParallel_MapNestedInParallel(t *testing.T) { + env := parallelTestEnv(t, 0) + prg := parallelProgram(t, env, `outer.parallel(row, row.map(x, x + 1))`) + out := parallelEval(t, prg, map[string]any{ + "outer": []any{[]any{0, 1}, []any{2, 3}}, + }) + outerList := out.(traits.Lister) + if int(outerList.Size().(types.Int)) != 2 { + t.Fatalf("outer size: got %d want 2", outerList.Size()) + } + assertOrderedList(t, outerList.Get(types.Int(0)), []ref.Val{types.Int(1), types.Int(2)}) + assertOrderedList(t, outerList.Get(types.Int(1)), []ref.Val{types.Int(3), types.Int(4)}) +} + +func TestParallel_NestedInParallel(t *testing.T) { + env := parallelTestEnv(t, 4) + prg := parallelProgram(t, env, `outer.parallel(row, row.parallel(x, x * 2))`) + out := parallelEval(t, prg, map[string]any{ + "outer": []any{[]any{1, 2, 3}, []any{4, 5, 6}}, + }) + outerList := out.(traits.Lister) + if int(outerList.Size().(types.Int)) != 2 { + t.Fatalf("outer size: got %d want 2", outerList.Size()) + } + assertOrderedList(t, outerList.Get(types.Int(0)), []ref.Val{types.Int(2), types.Int(4), types.Int(6)}) + assertOrderedList(t, outerList.Get(types.Int(1)), []ref.Val{types.Int(8), types.Int(10), types.Int(12)}) +} + +// compile-time validation + +func TestParallel_CompileError_NonIdentIterVar(t *testing.T) { + env := parallelTestEnv(t, 0) + _, issues := env.Compile(`items.parallel(1 + 1, x)`) + if issues == nil || issues.Err() == nil { + t.Fatal("expected compile error for non-identifier iterVar, got none") + } +} + +func TestParallel_CompileError_AccuVarConflict(t *testing.T) { + env := parallelTestEnv(t, 0) + _, issues := env.Compile(`items.parallel(__result__, x)`) + if issues == nil || issues.Err() == nil { + t.Fatal("expected compile error for reserved accumulator name, got none") + } +} + +// lib instance isolation + +func TestParallel_WrongLibInstance_ReturnsError(t *testing.T) { + // Compile with lib1, then construct a fresh env with lib2 (different + // 
instance) and attempt to program the AST using lib2's env. lib2 has no + // record of the @parallel sentinel call emitted during lib1's compilation, + // so its parallelDecorator must return an error. + lib1 := &parallelLib{perCall: 2} + lib2 := &parallelLib{perCall: 2} + + env1, err := cel.NewEnv( + cel.Lib(lib1), + cel.Variable("items", cel.ListType(cel.DynType)), + ) + if err != nil { + t.Fatalf("cel.NewEnv lib1: %v", err) + } + ast, issues := env1.Compile(`items.parallel(x, x)`) + if issues != nil && issues.Err() != nil { + t.Fatalf("compile: %v", issues.Err()) + } + + // Build a second env with lib2 only. lib2 has no bodies map entry for the + // @parallel node that lib1's macro emitted, so decoration must fail. + env2, err := cel.NewEnv( + cel.Lib(lib2), + cel.Variable("items", cel.ListType(cel.DynType)), + ) + if err != nil { + t.Fatalf("cel.NewEnv lib2: %v", err) + } + _, progErr := env2.Program(ast) + if progErr == nil { + t.Error("expected error when programming AST compiled by different lib instance, got nil") + } +} + +// activation contract + +func TestParallel_HierarchicalActivation(t *testing.T) { + // The body must see both the iter variable and outer state simultaneously. + env := parallelTestEnv(t, 4) + prg := parallelProgram(t, env, `items.parallel(x, x + base)`) + out := parallelEval(t, prg, map[string]any{ + "items": []any{1, 2, 3}, + "base": 100, + }) + assertOrderedList(t, out, []ref.Val{types.Int(101), types.Int(102), types.Int(103)}) +} + +func TestInterpreterActivationContract(t *testing.T) { + // Confirm our understanding of NewHierarchicalActivation: child takes + // priority, parent provides fallback, unknown names return not-found. 
+ child, err := interpreter.NewActivation(map[string]any{"x": 42}) + if err != nil { + t.Fatalf("NewActivation child: %v", err) + } + parent, err := interpreter.NewActivation(map[string]any{"y": 99}) + if err != nil { + t.Fatalf("NewActivation parent: %v", err) + } + hier := interpreter.NewHierarchicalActivation(parent, child) + + if v, ok := hier.ResolveName("x"); !ok || v != 42 { + t.Errorf("child var x: got (%v, %v) want (42, true)", v, ok) + } + if v, ok := hier.ResolveName("y"); !ok || v != 99 { + t.Errorf("parent var y: got (%v, %v) want (99, true)", v, ok) + } + if v, ok := hier.ResolveName("z"); ok { + t.Errorf("unknown var z: got (%v, %v) want (_, false)", v, ok) + } +} + +// parallelTestEnv builds a cel.Env with Parallel(perCall, 0) and a standard +// set of variable declarations covering all names used in the test suite. +// Additional env options (e.g. custom functions) may be appended via extra. +func parallelTestEnv(t *testing.T, perCall int, extra ...cel.EnvOption) *cel.Env { + t.Helper() + base := []cel.EnvOption{ + Parallel(perCall, 0), + // Scalar inputs used across tests. + cel.Variable("items", cel.ListType(cel.DynType)), + cel.Variable("outer", cel.ListType(cel.DynType)), + cel.Variable("m", cel.MapType(cel.StringType, cel.DynType)), + cel.Variable("suffix", cel.StringType), + cel.Variable("base", cel.IntType), + cel.Variable("threshold", cel.IntType), + } + env, err := cel.NewEnv(append(base, extra...)...) + if err != nil { + t.Fatalf("cel.NewEnv: %v", err) + } + return env +} + +// global cap + +func TestParallel_GlobalCap_BoundsTotal(t *testing.T) { + // With perCall=8 and globalCap=2, no more than 2 leaf bodies run at once. 
+ var inFlight atomic.Int64 + var overLimit atomic.Bool + + trackFn := cel.Function("track_global", + cel.Overload("track_global_int", + []*cel.Type{cel.IntType}, + cel.IntType, + cel.UnaryBinding(func(v ref.Val) ref.Val { + cur := inFlight.Add(1) + if cur > 2 { + overLimit.Store(true) + } + ch := make(chan struct{}) + go func() { close(ch) }() + <-ch + inFlight.Add(-1) + return v + }), + ), + ) + + env, err := cel.NewEnv( + Parallel(8, 2), + cel.Variable("items", cel.ListType(cel.DynType)), + trackFn, + ) + if err != nil { + t.Fatalf("cel.NewEnv: %v", err) + } + prg := parallelProgram(t, env, `items.parallel(x, track_global(x))`) + input := make([]any, 20) + for i := range input { + input[i] = i + } + for run := 0; run < 5; run++ { + overLimit.Store(false) + parallelEval(t, prg, map[string]any{"items": input}) + if overLimit.Load() { + t.Fatalf("run %d: global cap violated: more than 2 leaf bodies concurrent", run) + } + } +} + +func TestParallel_GlobalCap_NestedNoDeadlock(t *testing.T) { + // Nested parallel with a global cap must not deadlock: the outer parallel's + // goroutines don't consume global slots (they are not leaves). + env, err := cel.NewEnv( + Parallel(4, 3), + cel.Variable("outer", cel.ListType(cel.DynType)), + ) + if err != nil { + t.Fatalf("cel.NewEnv: %v", err) + } + prg := parallelProgram(t, env, `outer.parallel(row, row.parallel(x, x * 2))`) + + // Build 8 rows of 4 elements each to create scheduling pressure. 
+ rows := make([]any, 8) + for r := range rows { + row := make([]any, 4) + for c := range row { + row[c] = r*4 + c + 1 + } + rows[r] = row + } + for run := 0; run < 10; run++ { + out := parallelEval(t, prg, map[string]any{"outer": rows}) + outerList := out.(traits.Lister) + if int(outerList.Size().(types.Int)) != 8 { + t.Fatalf("run %d: outer size: got %d want 8", run, outerList.Size()) + } + for r := range rows { + inner := outerList.Get(types.Int(r)) + want := make([]ref.Val, 4) + for c := range want { + want[c] = types.Int((r*4 + c + 1) * 2) + } + assertOrderedList(t, inner, want) + } + } +} + +// assertOrderedList checks that val is a CEL list equal to want in order. +func assertOrderedList(t *testing.T, val ref.Val, want []ref.Val) { + t.Helper() + list, ok := val.(traits.Lister) + if !ok { + t.Fatalf("expected list, got %T: %v", val, val) + } + sz := int(list.Size().(types.Int)) + if sz != len(want) { + t.Fatalf("list length: got %d want %d (value=%v)", sz, len(want), val) + } + for i, w := range want { + got := list.Get(types.Int(i)) + if got.Equal(w) != types.True { + t.Errorf("[%d]: got %v want %v", i, got, w) + } + } +} + +func TestParallel_GlobalCap_ZeroMeansUnlimited(t *testing.T) { + // globalCap=0 means no global limit; verify actual concurrency. 
+ var peak atomic.Int64 + var inFlight atomic.Int64 + + trackFn := cel.Function("track_unlim", + cel.Overload("track_unlim_int", + []*cel.Type{cel.IntType}, + cel.IntType, + cel.UnaryBinding(func(v ref.Val) ref.Val { + cur := inFlight.Add(1) + for { + p := peak.Load() + if cur <= p || peak.CompareAndSwap(p, cur) { + break + } + } + ch := make(chan struct{}) + go func() { close(ch) }() + <-ch + inFlight.Add(-1) + return v + }), + ), + ) + + env, err := cel.NewEnv( + Parallel(8, 0), + cel.Variable("items", cel.ListType(cel.DynType)), + trackFn, + ) + if err != nil { + t.Fatalf("cel.NewEnv: %v", err) + } + prg := parallelProgram(t, env, `items.parallel(x, track_unlim(x))`) + input := make([]any, 32) + for i := range input { + input[i] = i + } + parallelEval(t, prg, map[string]any{"items": input}) + if peak.Load() <= 1 { + t.Errorf("expected concurrent execution (peak > 1), got %d", peak.Load()) + } +} + +// race-detector regression + +func TestParallel_ConcurrentEvalRace(t *testing.T) { + // Trip wire for Interpretable.Eval thread-safety. Exercises a variety of + // Interpretable types under high concurrency so the race detector has the + // best chance of spotting shared mutable state. This is not a correctness + // test — existing tests cover that. + env, err := cel.NewEnv( + Parallel(8, 0), + cel.Variable("items", cel.ListType(cel.DynType)), + cel.Variable("base", cel.IntType), + cel.Variable("m", cel.MapType(cel.StringType, cel.DynType)), + ) + if err != nil { + t.Fatalf("cel.NewEnv: %v", err) + } + // Binary ops, function call, list construction, conditional, comprehension, map access. + src := `items.parallel(x, x * 2 + base + int([x].map(y, y)[0]) + (x > 0 ? 
x : 0) + m["k"])` + prg := parallelProgram(t, env, src) + input := make([]any, 128) + for i := range input { + input[i] = i + } + vars := map[string]any{ + "items": input, + "base": 0, + "m": map[string]any{"k": 0}, + } + for iter := 0; iter < 10; iter++ { + parallelEval(t, prg, vars) + } +} + +func parallelProgram(t *testing.T, env *cel.Env, src string) cel.Program { + t.Helper() + ast, issues := env.Compile(src) + if issues != nil && issues.Err() != nil { + t.Fatalf("compile %q: %v", src, issues.Err()) + } + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("env.Program %q: %v", src, err) + } + return prg +} + +func parallelEval(t *testing.T, prg cel.Program, vars map[string]any) ref.Val { + t.Helper() + out, _, err := prg.Eval(vars) + if err != nil { + t.Fatalf("eval: %v", err) + } + return out +} diff --git a/mito.go b/mito.go index e7eb8b9..599c587 100644 --- a/mito.go +++ b/mito.go @@ -77,6 +77,8 @@ func Main() int { insecure := flag.Bool("insecure", false, "disable TLS verification in the HTTP client") logTrace := flag.Bool("log_requests", false, "log request traces to stderr (go1.21+)") maxTraceBody := flag.Int("max_log_body", 1000, "maximum length of body logged in request traces (go1.21+)") + maxConcurrency := flag.Int("max_concurrency", 3, "maximum concurrency in parallel macro") + globalConcurrency := flag.Int("global_concurrency", 0, "global concurrency limit across all parallel calls (0 = no limit)") fold := flag.Bool("fold", false, "apply constant folding optimisation") dumpState := flag.String("dump", "", "dump eval state ('always' or 'error')") coverage := flag.String("coverage", "", "file to write an execution coverage report to (prefix if multiple executions are run)") @@ -187,6 +189,7 @@ func Main() int { return 2 } } + libMap["parallel"] = lib.Parallel(*maxConcurrency, *globalConcurrency) if *use == "all" { for _, l := range libMap { libs = append(libs, l) @@ -479,6 +482,7 @@ var ( "strings": lib.Strings(), "printf": lib.Printf(), 
"xml": nil, // This will be populated by Main. + "parallel": nil, // This will be populated by Main. } mimetypes = map[string]interface{}{ diff --git a/mito_test.go b/mito_test.go index e8b8cdb..c573513 100644 --- a/mito_test.go +++ b/mito_test.go @@ -20,12 +20,15 @@ package mito import ( "encoding/base64" "flag" + "fmt" "net/http" "net/http/httptest" "os" "path/filepath" "regexp" + "strconv" "sync" + "sync/atomic" "testing" "time" @@ -102,8 +105,37 @@ func server(ts *testscript.TestScript, neg bool, name string, newServer func(han if neg { ts.Fatalf("unsupported: ! %s", name) } + + // Parse leading -flag value pairs. + var barrierN int + var timeout time.Duration + i := 0 + for i < len(args) && len(args[i]) > 0 && args[i][0] == '-' { + if i+1 >= len(args) { + ts.Fatalf("%s: flag %s requires a value", name, args[i]) + } + switch args[i] { + case "-barrier": + n, err := strconv.Atoi(args[i+1]) + if err != nil { + ts.Fatalf("%s: -barrier: %v", name, err) + } + barrierN = n + case "-timeout": + d, err := time.ParseDuration(args[i+1]) + if err != nil { + ts.Fatalf("%s: -timeout: %v", name, err) + } + timeout = d + default: + ts.Fatalf("%s: unknown flag %s", name, args[i]) + } + i += 2 + } + args = args[i:] + if len(args) != 1 && len(args) != 3 { - ts.Fatalf("usage: %s body [user password]", name) + ts.Fatalf("usage: %s [-barrier N] [-timeout duration] body [user password]", name) } var user, pass string body, err := os.ReadFile(ts.MkAbs(args[0])) @@ -112,21 +144,59 @@ func server(ts *testscript.TestScript, neg bool, name string, newServer func(han user = args[1] pass = args[2] } - srv := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + + checkAuth := func(w http.ResponseWriter, req *http.Request) bool { u, p, _ := req.BasicAuth() - // Obvious security anti-patterns are obvious; for testing. 
if user != "" && user != u { w.WriteHeader(http.StatusForbidden) w.Write([]byte("user mismatch")) - return + return false } if pass != "" && pass != p { w.WriteHeader(http.StatusForbidden) w.Write([]byte("password mismatch")) - return + return false } - w.Write(body) - })) + return true + } + + var handler http.Handler + if barrierN > 0 { + var arrived atomic.Int64 + gate := make(chan struct{}) + var once sync.Once + handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if !checkAuth(w, req) { + return + } + if arrived.Add(1) >= int64(barrierN) { + once.Do(func() { close(gate) }) + } + if timeout > 0 { + timer := time.NewTimer(timeout) + defer timer.Stop() + select { + case <-gate: + case <-timer.C: + w.WriteHeader(http.StatusServiceUnavailable) + fmt.Fprintf(w, "barrier timeout: only %d of %d requests arrived", arrived.Load(), barrierN) + return + } + } else { + <-gate + } + w.Write(body) + }) + } else { + handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if !checkAuth(w, req) { + return + } + w.Write(body) + }) + } + + srv := newServer(handler) ts.Setenv("URL", srv.URL) ts.Defer(func() { srv.Close() }) } diff --git a/testdata/map_barrier.txt b/testdata/map_barrier.txt new file mode 100644 index 0000000..e5fcec0 --- /dev/null +++ b/testdata/map_barrier.txt @@ -0,0 +1,17 @@ +serve -barrier 3 -timeout 5s body.json +expand src_var.cel src.cel + +mito -use http src.cel +! stderr . +cmp stdout want.txt + +-- body.json -- +{"ok":true} +-- src_var.cel -- +[1, 2, 3].map(x, string(get("${URL}").Body)) +-- want.txt -- +[ + "barrier timeout: only 1 of 3 requests arrived", + "barrier timeout: only 2 of 3 requests arrived", + "{\"ok\":true}\n" +] diff --git a/testdata/parallel.txt b/testdata/parallel.txt new file mode 100644 index 0000000..acec6ed --- /dev/null +++ b/testdata/parallel.txt @@ -0,0 +1,31 @@ +mito src.cel +! stderr . 
+cmp stdout want.txt + +-- src.cel -- +[ + [1,2,3].parallel(v, v), + + { + "one":1, + "two":2, + "three":3, + }.parallel(k, v, + {"key": k, "val": v} + ).transformMapEntry(_, v, + {v.key: v.val} + ), +] +-- want.txt -- +[ + [ + 1, + 2, + 3 + ], + { + "one": 1, + "three": 3, + "two": 2 + } +] diff --git a/testdata/parallel_barrier.txt b/testdata/parallel_barrier.txt new file mode 100644 index 0000000..db70a40 --- /dev/null +++ b/testdata/parallel_barrier.txt @@ -0,0 +1,26 @@ +serve -barrier 3 -timeout 5s body.json +expand src_var.cel src.cel + +mito -use http,json,parallel,collections src.cel +! stderr . +cmp stdout want.txt + +-- body.json -- +{"ok":true} +-- src_var.cel -- +[1, 2, 3].parallel(x, get("${URL}").Body.decode_json().with({"num": x})) +-- want.txt -- +[ + { + "num": 1, + "ok": true + }, + { + "num": 2, + "ok": true + }, + { + "num": 3, + "ok": true + } +]