Skip to content

Commit 741ad14

Browse files
l46kokcopybara-github
authored andcommitted
Support container resolution for calls and struct creation in planner
PiperOrigin-RevId: 839411411
1 parent 5a75f0f commit 741ad14

File tree

3 files changed

+135
-29
lines changed

3 files changed

+135
-29
lines changed

runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java

Lines changed: 92 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import dev.cel.common.ast.CelExpr.CelCall;
2929
import dev.cel.common.ast.CelExpr.CelList;
3030
import dev.cel.common.ast.CelExpr.CelMap;
31+
import dev.cel.common.ast.CelExpr.CelSelect;
3132
import dev.cel.common.ast.CelExpr.CelStruct;
3233
import dev.cel.common.ast.CelExpr.CelStruct.Entry;
3334
import dev.cel.common.ast.CelReference;
@@ -58,6 +59,7 @@ public final class ProgramPlanner {
5859
private final CelValueProvider valueProvider;
5960
private final DefaultDispatcher dispatcher;
6061
private final AttributeFactory attributeFactory;
62+
private final CelContainer container;
6163

6264
/**
6365
* Plans a {@link Program} from the provided parsed-only or type-checked {@link
@@ -168,7 +170,6 @@ private Interpretable planCall(CelExpr expr, PlannerContext ctx) {
168170
evaluatedArgs[argIndex + offset] = plan(args.get(argIndex), ctx);
169171
}
170172

171-
// TODO: Handle all specialized calls (logical operators, conditionals, equals etc)
172173
String functionName = resolvedFunction.functionName();
173174
Operator operator = Operator.findReverse(functionName).orElse(null);
174175
if (operator != null) {
@@ -209,16 +210,7 @@ private Interpretable planCall(CelExpr expr, PlannerContext ctx) {
209210

210211
private Interpretable planCreateStruct(CelExpr celExpr, PlannerContext ctx) {
211212
CelStruct struct = celExpr.struct();
212-
CelType structType =
213-
typeProvider
214-
.findType(struct.messageName())
215-
.orElseThrow(
216-
() -> new IllegalArgumentException("Undefined type name: " + struct.messageName()));
217-
if (!structType.kind().equals(CelKind.STRUCT)) {
218-
throw new IllegalArgumentException(
219-
String.format(
220-
"Expected struct type for %s, got %s", structType.name(), structType.kind()));
221-
}
213+
StructType structType = resolveStructType(struct);
222214

223215
ImmutableList<Entry> entries = struct.entries();
224216
String[] keys = new String[entries.size()];
@@ -230,7 +222,7 @@ private Interpretable planCreateStruct(CelExpr celExpr, PlannerContext ctx) {
230222
values[i] = plan(entry.value(), ctx);
231223
}
232224

233-
return EvalCreateStruct.create(valueProvider, (StructType) structType, keys, values);
225+
return EvalCreateStruct.create(valueProvider, structType, keys, values);
234226
}
235227

236228
private Interpretable planCreateList(CelExpr celExpr, PlannerContext ctx) {
@@ -269,7 +261,7 @@ private Interpretable planCreateMap(CelExpr celExpr, PlannerContext ctx) {
269261
private ResolvedFunction resolveFunction(
270262
CelExpr expr, ImmutableMap<Long, CelReference> referenceMap) {
271263
CelCall call = expr.call();
272-
Optional<CelExpr> target = call.target();
264+
Optional<CelExpr> maybeTarget = call.target();
273265
String functionName = call.function();
274266

275267
CelReference reference = referenceMap.get(expr.id());
@@ -281,22 +273,89 @@ private ResolvedFunction resolveFunction(
281273
.setFunctionName(functionName)
282274
.setOverloadId(reference.overloadIds().get(0));
283275

284-
target.ifPresent(builder::setTarget);
276+
maybeTarget.ifPresent(builder::setTarget);
285277

286278
return builder.build();
287279
}
288280
}
289281

290-
// Parsed-only.
291-
// TODO: Handle containers.
292-
if (!target.isPresent()) {
282+
// Parsed-only function resolution.
283+
//
284+
// There are two distinct cases we must handle:
285+
//
286+
// 1. Non-qualified function calls. This will resolve into either:
287+
// - A simple global call foo()
288+
// - A fully qualified global call through normal container resolution foo.bar.qux()
289+
// 2. Qualified function calls:
290+
// - A member call on an identifier foo.bar()
291+
// - A fully qualified global call, through normal container resolution or abbreviations
292+
// foo.bar.qux()
293+
if (!maybeTarget.isPresent()) {
294+
for (String cand : container.resolveCandidateNames(functionName)) {
295+
CelResolvedOverload overload = dispatcher.findOverload(cand).orElse(null);
296+
if (overload != null) {
297+
return ResolvedFunction.newBuilder().setFunctionName(cand).build();
298+
}
299+
}
300+
301+
// Normal global call
293302
return ResolvedFunction.newBuilder().setFunctionName(functionName).build();
294-
} else {
295-
return ResolvedFunction.newBuilder()
296-
.setFunctionName(functionName)
297-
.setTarget(target.get())
298-
.build();
299303
}
304+
305+
CelExpr target = maybeTarget.get();
306+
String qualifiedPrefix = toQualifiedName(target).orElse(null);
307+
if (qualifiedPrefix != null) {
308+
String qualifiedName = qualifiedPrefix + "." + functionName;
309+
for (String cand : container.resolveCandidateNames(qualifiedName)) {
310+
CelResolvedOverload overload = dispatcher.findOverload(cand).orElse(null);
311+
if (overload != null) {
312+
return ResolvedFunction.newBuilder().setFunctionName(cand).build();
313+
}
314+
}
315+
}
316+
317+
// Normal member call
318+
return ResolvedFunction.newBuilder().setFunctionName(functionName).setTarget(target).build();
319+
}
320+
321+
private StructType resolveStructType(CelStruct struct) {
322+
String messageName = struct.messageName();
323+
for (String typeName : container.resolveCandidateNames(messageName)) {
324+
CelType structType = typeProvider.findType(typeName).orElse(null);
325+
if (structType == null) {
326+
continue;
327+
}
328+
329+
if (!structType.kind().equals(CelKind.STRUCT)) {
330+
throw new IllegalArgumentException(
331+
String.format(
332+
"Expected struct type for %s, got %s", structType.name(), structType.kind()));
333+
}
334+
335+
return (StructType) structType;
336+
}
337+
338+
throw new IllegalArgumentException("Undefined type name: " + messageName);
339+
}
340+
341+
/** Converts a given expression into a qualified name, if possible. */
342+
private Optional<String> toQualifiedName(CelExpr operand) {
343+
switch (operand.getKind()) {
344+
case IDENT:
345+
return Optional.of(operand.ident().name());
346+
case SELECT:
347+
CelSelect select = operand.select();
348+
String maybeQualified = toQualifiedName(select.operand()).orElse(null);
349+
if (maybeQualified != null) {
350+
return Optional.of(maybeQualified + "." + select.field());
351+
}
352+
353+
break;
354+
default:
355+
// fall-through
356+
}
357+
358+
return Optional.empty();
300359
}
301360

302361
@AutoValue
@@ -338,17 +397,22 @@ private static PlannerContext create(CelAbstractSyntaxTree ast) {
338397
}
339398

340399
public static ProgramPlanner newPlanner(
341-
CelTypeProvider typeProvider, CelValueProvider valueProvider, DefaultDispatcher dispatcher) {
342-
return new ProgramPlanner(typeProvider, valueProvider, dispatcher);
400+
CelTypeProvider typeProvider,
401+
CelValueProvider valueProvider,
402+
DefaultDispatcher dispatcher,
403+
CelContainer container) {
404+
return new ProgramPlanner(typeProvider, valueProvider, dispatcher, container);
343405
}
344406

345407
private ProgramPlanner(
346-
CelTypeProvider typeProvider, CelValueProvider valueProvider, DefaultDispatcher dispatcher) {
408+
CelTypeProvider typeProvider,
409+
CelValueProvider valueProvider,
410+
DefaultDispatcher dispatcher,
411+
CelContainer container) {
347412
this.typeProvider = typeProvider;
348413
this.valueProvider = valueProvider;
349-
// TODO: Container support
350414
this.dispatcher = dispatcher;
351-
this.attributeFactory =
352-
AttributeFactory.newAttributeFactory(CelContainer.newBuilder().build(), typeProvider);
415+
this.container = container;
416+
this.attributeFactory = AttributeFactory.newAttributeFactory(container, typeProvider);
353417
}
354418
}

runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ java_library(
1818
"//common:cel_descriptor_util",
1919
"//common:cel_source",
2020
"//common:compiler_common",
21+
"//common:container",
2122
"//common:error_codes",
2223
"//common:operator",
2324
"//common:options",

runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
3232
import com.google.testing.junit.testparameterinjector.TestParameters;
3333
import dev.cel.common.CelAbstractSyntaxTree;
34+
import dev.cel.common.CelContainer;
3435
import dev.cel.common.CelDescriptorUtil;
3536
import dev.cel.common.CelErrorCode;
3637
import dev.cel.common.CelOptions;
@@ -99,9 +100,11 @@ public final class ProgramPlannerTest {
99100
DynamicProto.create(DefaultMessageFactory.create(DESCRIPTOR_POOL));
100101
private static final CelValueProvider VALUE_PROVIDER =
101102
ProtoMessageValueProvider.newInstance(CelOptions.DEFAULT, DYNAMIC_PROTO);
103+
private static final CelContainer CEL_CONTAINER =
104+
CelContainer.newBuilder().setName("cel.expr.conformance.proto3").build();
102105

103106
private static final ProgramPlanner PLANNER =
104-
ProgramPlanner.newPlanner(TYPE_PROVIDER, VALUE_PROVIDER, newDispatcher());
107+
ProgramPlanner.newPlanner(TYPE_PROVIDER, VALUE_PROVIDER, newDispatcher(), CEL_CONTAINER);
105108
private static final CelCompiler CEL_COMPILER =
106109
CelCompilerFactory.standardCelCompilerBuilder()
107110
.addVar("map_var", MapType.create(SimpleType.STRING, SimpleType.DYN))
@@ -114,6 +117,10 @@ public final class ProgramPlannerTest {
114117
"neg",
115118
newGlobalOverload("neg_int", SimpleType.INT, SimpleType.INT),
116119
newGlobalOverload("neg_double", SimpleType.DOUBLE, SimpleType.DOUBLE)),
120+
newFunctionDeclaration(
121+
"cel.expr.conformance.proto3.power",
122+
newGlobalOverload(
123+
"power_int_int", SimpleType.INT, SimpleType.INT, SimpleType.INT)),
117124
newFunctionDeclaration(
118125
"concat",
119126
newGlobalOverload(
@@ -122,6 +129,7 @@ public final class ProgramPlannerTest {
122129
"bytes_concat_bytes", SimpleType.BYTES, SimpleType.BYTES, SimpleType.BYTES)))
123130
.addMessageTypes(TestAllTypes.getDescriptor())
124131
.addLibraries(CelExtensions.optional())
132+
.setContainer(CEL_CONTAINER)
125133
.build();
126134

127135
/**
@@ -174,6 +182,14 @@ private static DefaultDispatcher newDispatcher() {
174182
"neg",
175183
CelFunctionBinding.from("neg_int", Long.class, arg -> -arg),
176184
CelFunctionBinding.from("neg_double", Double.class, arg -> -arg));
185+
addBindings(
186+
builder,
187+
"cel.expr.conformance.proto3.power",
188+
CelFunctionBinding.from(
189+
"power_int_int",
190+
Long.class,
191+
Long.class,
192+
(value, power) -> (long) Math.pow(value, power)));
177193
addBindings(
178194
builder,
179195
"concat",
@@ -379,6 +395,16 @@ public void planCreateStruct_withFields() throws Exception {
379395
.isEqualTo(TestAllTypes.newBuilder().setSingleString("foo").setSingleBool(true).build());
380396
}
381397

398+
@Test
399+
public void plan_createStruct_withContainer() throws Exception {
400+
CelAbstractSyntaxTree ast = compile("TestAllTypes{}");
401+
Program program = PLANNER.plan(ast);
402+
403+
TestAllTypes result = (TestAllTypes) program.eval();
404+
405+
assertThat(result).isEqualTo(TestAllTypes.getDefaultInstance());
406+
}
407+
382408
@Test
383409
public void plan_call_zeroArgs() throws Exception {
384410
CelAbstractSyntaxTree ast = compile("zero()");
@@ -555,6 +581,21 @@ public void plan_call_conditional_throws(String expression) throws Exception {
555581
assertThat(e.getErrorCode()).isEqualTo(CelErrorCode.DIVIDE_BY_ZERO);
556582
}
557583

584+
@Test
585+
@TestParameters("{expression: 'power(2,3)'}")
586+
@TestParameters("{expression: 'proto3.power(2,3)'}")
587+
@TestParameters("{expression: 'conformance.proto3.power(2,3)'}")
588+
@TestParameters("{expression: 'expr.conformance.proto3.power(2,3)'}")
589+
@TestParameters("{expression: 'cel.expr.conformance.proto3.power(2,3)'}")
590+
public void plan_call_withContainer(String expression) throws Exception {
591+
CelAbstractSyntaxTree ast = compile(expression); // invokes cel.expr.conformance.proto3.power
592+
Program program = PLANNER.plan(ast);
593+
594+
Long result = (Long) program.eval();
595+
596+
assertThat(result).isEqualTo(8);
597+
}
598+
558599
private CelAbstractSyntaxTree compile(String expression) throws Exception {
559600
CelAbstractSyntaxTree ast = CEL_COMPILER.parse(expression).getAst();
560601
if (isParseOnly) {

0 commit comments

Comments
 (0)