Skip to content

Commit 5c81c2b

Browse files
l46kokcopybara-github
authored andcommitted
Fix CelValue adaptation for lists/maps
PiperOrigin-RevId: 875476558
1 parent 0737691 commit 5c81c2b

14 files changed

Lines changed: 169 additions & 80 deletions

File tree

common/src/main/java/dev/cel/common/exceptions/CelInvalidArgumentException.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,8 @@ public final class CelInvalidArgumentException extends CelRuntimeException {
2424
public CelInvalidArgumentException(Throwable cause) {
2525
super(cause, CelErrorCode.INVALID_ARGUMENT);
2626
}
27+
28+
public CelInvalidArgumentException(String message) {
29+
super(message, CelErrorCode.INVALID_ARGUMENT);
30+
}
2731
}

common/src/main/java/dev/cel/common/values/CelValueConverter.java

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,24 +41,39 @@ public static CelValueConverter getDefaultInstance() {
4141
return DEFAULT_INSTANCE;
4242
}
4343

44-
/** Adapts a {@link CelValue} to a plain old Java Object. */
45-
public Object unwrap(CelValue celValue) {
46-
Preconditions.checkNotNull(celValue);
44+
/**
45+
* Unwraps the {@code value} into its plain old Java Object representation.
46+
*
47+
* <p>The value may be a {@link CelValue}, a {@link Collection} or a {@link Map}.
48+
*/
49+
public Object maybeUnwrap(Object value) {
50+
if (value instanceof CelValue) {
51+
return unwrap((CelValue) value);
52+
}
4753

48-
if (celValue instanceof OptionalValue) {
49-
OptionalValue<Object, Object> optionalValue = (OptionalValue<Object, Object>) celValue;
50-
if (optionalValue.isZeroValue()) {
51-
return Optional.empty();
54+
if (value instanceof Collection) {
55+
Collection<Object> collection = (Collection<Object>) value;
56+
ImmutableList.Builder<Object> builder =
57+
ImmutableList.builderWithExpectedSize(collection.size());
58+
for (Object element : collection) {
59+
builder.add(maybeUnwrap(element));
5260
}
5361

54-
return Optional.of(optionalValue.value());
62+
return builder.build();
5563
}
5664

57-
if (celValue instanceof ErrorValue) {
58-
return celValue;
65+
if (value instanceof Map) {
66+
Map<Object, Object> map = (Map<Object, Object>) value;
67+
ImmutableMap.Builder<Object, Object> builder =
68+
ImmutableMap.builderWithExpectedSize(map.size());
69+
for (Map.Entry<Object, Object> entry : map.entrySet()) {
70+
builder.put(maybeUnwrap(entry.getKey()), maybeUnwrap(entry.getValue()));
71+
}
72+
73+
return builder.buildOrThrow();
5974
}
6075

61-
return celValue.value();
76+
return value;
6277
}
6378

6479
/**
@@ -101,6 +116,26 @@ protected Object normalizePrimitive(Object value) {
101116
return value;
102117
}
103118

119+
/** Adapts a {@link CelValue} to a plain old Java Object. */
120+
private static Object unwrap(CelValue celValue) {
121+
Preconditions.checkNotNull(celValue);
122+
123+
if (celValue instanceof OptionalValue) {
124+
OptionalValue<Object, Object> optionalValue = (OptionalValue<Object, Object>) celValue;
125+
if (optionalValue.isZeroValue()) {
126+
return Optional.empty();
127+
}
128+
129+
return Optional.of(optionalValue.value());
130+
}
131+
132+
if (celValue instanceof ErrorValue) {
133+
return celValue;
134+
}
135+
136+
return celValue.value();
137+
}
138+
104139
private ImmutableList<Object> toListValue(Collection<Object> iterable) {
105140
Preconditions.checkNotNull(iterable);
106141

common/src/test/java/dev/cel/common/values/CelValueConverterTest.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,16 @@ public void toRuntimeValue_optionalValue() {
3737
@Test
3838
@SuppressWarnings("unchecked") // Test only
3939
public void unwrap_optionalValue() {
40-
Optional<Long> result = (Optional<Long>) CEL_VALUE_CONVERTER.unwrap(OptionalValue.create(2L));
40+
Optional<Long> result =
41+
(Optional<Long>) CEL_VALUE_CONVERTER.maybeUnwrap(OptionalValue.create(2L));
4142

4243
assertThat(result).isEqualTo(Optional.of(2L));
4344
}
4445

4546
@Test
4647
@SuppressWarnings("unchecked") // Test only
4748
public void unwrap_emptyOptionalValue() {
48-
Optional<Long> result = (Optional<Long>) CEL_VALUE_CONVERTER.unwrap(OptionalValue.EMPTY);
49+
Optional<Long> result = (Optional<Long>) CEL_VALUE_CONVERTER.maybeUnwrap(OptionalValue.EMPTY);
4950

5051
assertThat(result).isEqualTo(Optional.empty());
5152
}

common/src/test/java/dev/cel/common/values/ProtoCelValueConverterTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public class ProtoCelValueConverterTest {
3535

3636
@Test
3737
public void unwrap_nullValue() {
38-
NullValue nullValue = (NullValue) PROTO_CEL_VALUE_CONVERTER.unwrap(NullValue.NULL_VALUE);
38+
NullValue nullValue = (NullValue) PROTO_CEL_VALUE_CONVERTER.maybeUnwrap(NullValue.NULL_VALUE);
3939

4040
// Note: No conversion is attempted. We're using dev.cel.common.values.NullValue.NULL_VALUE as
4141
// the

conformance/src/test/java/dev/cel/conformance/BUILD.bazel

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -175,18 +175,7 @@ _TESTS_TO_SKIP_PLANNER = [
175175
"timestamps/timestamp_range/sub_time_duration_under",
176176

177177
# Skip until fixed.
178-
"fields/qualified_identifier_resolution/map_key_float",
179-
"fields/qualified_identifier_resolution/map_key_null",
180-
"fields/qualified_identifier_resolution/map_value_repeat_key_heterogeneous",
181-
"optionals/optionals/map_null_entry_no_such_key",
182-
"optionals/optionals/map_present_key_invalid_field",
183178
"parse/receiver_function_names",
184-
"proto2/extensions_get/package_scoped_test_all_types_ext",
185-
"proto2/extensions_get/package_scoped_repeated_test_all_types",
186-
"proto2/extensions_get/message_scoped_nested_ext",
187-
"proto2/extensions_get/message_scoped_repeated_test_all_types",
188-
"proto2_ext/get_ext/package_scoped_repeated_test_all_types",
189-
"proto2_ext/get_ext/message_scoped_repeated_test_all_types",
190179

191180
# TODO: Fix null assignment to a field
192181
"proto2/set_null/single_message",

extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import dev.cel.runtime.CelFunctionBinding;
4747
import dev.cel.runtime.CelRuntime;
4848
import dev.cel.runtime.CelRuntimeFactory;
49+
import dev.cel.runtime.CelRuntimeImpl;
4950
import org.junit.Test;
5051
import org.junit.runner.RunWith;
5152

@@ -60,9 +61,14 @@ public final class CelProtoExtensionsTest {
6061
.addVar("msg", StructTypeReference.create("cel.expr.conformance.proto2.TestAllTypes"))
6162
.setContainer(CelContainer.ofName("cel.expr.conformance.proto2"))
6263
.build();
63-
6464
private static final CelRuntime CEL_RUNTIME =
65-
CelRuntimeFactory.standardCelRuntimeBuilder()
65+
CelRuntimeImpl.newBuilder()
66+
.setOptions(
67+
CelOptions.current()
68+
.enableTimestampEpoch(true)
69+
.enableHeterogeneousNumericComparisons(true)
70+
.build())
71+
// CEL-Internal-2
6672
.addFileTypes(TestAllTypesExtensions.getDescriptor())
6773
.build();
6874

runtime/src/main/java/dev/cel/runtime/CelValueRuntimeTypeProvider.java

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import dev.cel.common.exceptions.CelAttributeNotFoundException;
2323
import dev.cel.common.values.BaseProtoCelValueConverter;
2424
import dev.cel.common.values.BaseProtoMessageValueProvider;
25-
import dev.cel.common.values.CelValue;
2625
import dev.cel.common.values.CelValueProvider;
2726
import dev.cel.common.values.CombinedCelValueProvider;
2827
import dev.cel.common.values.SelectableValue;
@@ -64,7 +63,7 @@ static CelValueRuntimeTypeProvider newInstance(CelValueProvider valueProvider) {
6463

6564
@Override
6665
public Object createMessage(String messageName, Map<String, Object> values) {
67-
return maybeUnwrapCelValue(
66+
return protoCelValueConverter.maybeUnwrap(
6867
valueProvider
6968
.newValue(messageName, values)
7069
.orElseThrow(
@@ -87,7 +86,7 @@ public Object selectField(Object message, String fieldName) {
8786
SelectableValue<String> selectableValue = getSelectableValueOrThrow(message, fieldName);
8887
Object value = selectableValue.select(fieldName);
8988

90-
return maybeUnwrapCelValue(value);
89+
return protoCelValueConverter.maybeUnwrap(value);
9190
}
9291

9392
@Override
@@ -120,24 +119,12 @@ public Object adapt(String messageName, Object message) {
120119
}
121120

122121
if (message instanceof MessageLite) {
123-
return maybeUnwrapCelValue(protoCelValueConverter.toRuntimeValue(message));
122+
return protoCelValueConverter.maybeUnwrap(protoCelValueConverter.toRuntimeValue(message));
124123
}
125124

126125
return message;
127126
}
128127

129-
/**
130-
* DefaultInterpreter cannot handle CelValue and instead expects plain Java objects.
131-
*
132-
* <p>This will become unnecessary once we introduce a rewrite of a Cel runtime.
133-
*/
134-
private Object maybeUnwrapCelValue(Object object) {
135-
if (object instanceof CelValue) {
136-
return protoCelValueConverter.unwrap((CelValue) object);
137-
}
138-
return object;
139-
}
140-
141128
private static void throwInvalidFieldSelection(String fieldName) {
142129
throw CelAttributeNotFoundException.forFieldResolution(fieldName);
143130
}

runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ java_library(
128128
"//common/types",
129129
"//common/types:type_providers",
130130
"//common/values",
131-
"//common/values:cel_value",
132131
"//runtime:interpretable",
133132
"@maven//:com_google_errorprone_error_prone_annotations",
134133
"@maven//:com_google_guava_guava",
@@ -334,6 +333,7 @@ java_library(
334333
":localized_evaluation_exception",
335334
":planned_interpretable",
336335
"//common/exceptions:duplicate_key",
336+
"//common/exceptions:invalid_argument",
337337
"//runtime:evaluation_exception",
338338
"//runtime:interpretable",
339339
"@maven//:com_google_errorprone_error_prone_annotations",
@@ -379,7 +379,6 @@ java_library(
379379
"//common:error_codes",
380380
"//common/exceptions:runtime_exception",
381381
"//common/values",
382-
"//common/values:cel_value",
383382
"//runtime:evaluation_exception",
384383
"//runtime:interpretable",
385384
"//runtime:resolved_overload",

runtime/src/main/java/dev/cel/runtime/planner/EvalCreateMap.java

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
import com.google.common.base.Preconditions;
1818
import com.google.common.collect.ImmutableMap;
1919
import com.google.common.collect.Sets;
20+
import com.google.common.primitives.UnsignedLong;
2021
import com.google.errorprone.annotations.Immutable;
2122
import dev.cel.common.exceptions.CelDuplicateKeyException;
23+
import dev.cel.common.exceptions.CelInvalidArgumentException;
2224
import dev.cel.runtime.CelEvaluationException;
2325
import dev.cel.runtime.GlobalResolver;
2426
import java.util.HashSet;
@@ -46,13 +48,37 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval
4648
HashSet<Object> keysSeen = Sets.newHashSetWithExpectedSize(keys.length);
4749

4850
for (int i = 0; i < keys.length; i++) {
49-
Object key = keys[i].eval(resolver, frame);
50-
Object val = values[i].eval(resolver, frame);
51+
PlannedInterpretable keyInterpretable = keys[i];
52+
Object key = keyInterpretable.eval(resolver, frame);
53+
if (!(key instanceof String
54+
|| key instanceof Long
55+
|| key instanceof UnsignedLong
56+
|| key instanceof Boolean)) {
57+
throw new LocalizedEvaluationException(
58+
new CelInvalidArgumentException("Unsupported key type: " + key),
59+
keyInterpretable.exprId());
60+
}
5161

52-
if (!keysSeen.add(key)) {
53-
throw new LocalizedEvaluationException(CelDuplicateKeyException.of(key), keys[i].exprId());
62+
boolean isDuplicate = !keysSeen.add(key);
63+
if (!isDuplicate) {
64+
if (key instanceof Long) {
65+
long longVal = (Long) key;
66+
if (longVal >= 0) {
67+
isDuplicate = keysSeen.contains(UnsignedLong.valueOf(longVal));
68+
}
69+
} else if (key instanceof UnsignedLong) {
70+
UnsignedLong ulongVal = (UnsignedLong) key;
71+
isDuplicate = keysSeen.contains(ulongVal.longValue());
72+
}
5473
}
5574

75+
if (isDuplicate) {
76+
throw new LocalizedEvaluationException(
77+
CelDuplicateKeyException.of(key), keyInterpretable.exprId());
78+
}
79+
80+
Object val = values[i].eval(resolver, frame);
81+
5682
if (isOptional[i]) {
5783
if (!(val instanceof Optional)) {
5884
throw new IllegalArgumentException(

runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import com.google.common.base.Joiner;
1818
import dev.cel.common.CelErrorCode;
1919
import dev.cel.common.exceptions.CelRuntimeException;
20-
import dev.cel.common.values.CelValue;
2120
import dev.cel.common.values.CelValueConverter;
2221
import dev.cel.common.values.ErrorValue;
2322
import dev.cel.runtime.CelEvaluationException;
@@ -56,23 +55,21 @@ static Object evalStrictly(
5655
}
5756
}
5857

59-
static Object dispatch(CelResolvedOverload overload, CelValueConverter valueConverter, Object[] args) throws CelEvaluationException {
58+
static Object dispatch(
59+
CelResolvedOverload overload, CelValueConverter valueConverter, Object[] args)
60+
throws CelEvaluationException {
6061
try {
6162
Object result = overload.getDefinition().apply(args);
62-
Object runtimeValue = valueConverter.toRuntimeValue(result);
63-
if (runtimeValue instanceof CelValue) {
64-
return valueConverter.unwrap((CelValue) runtimeValue);
65-
}
66-
67-
return runtimeValue;
63+
return valueConverter.maybeUnwrap(valueConverter.toRuntimeValue(result));
6864
} catch (CelRuntimeException e) {
6965
// Function dispatch failure that's already been handled -- just propagate.
7066
throw e;
7167
} catch (RuntimeException e) {
7268
// Unexpected function dispatch failure.
73-
throw new IllegalArgumentException(String.format(
74-
"Function '%s' failed with arg(s) '%s'",
75-
overload.getOverloadId(), Joiner.on(", ").join(args)),
69+
throw new IllegalArgumentException(
70+
String.format(
71+
"Function '%s' failed with arg(s) '%s'",
72+
overload.getOverloadId(), Joiner.on(", ").join(args)),
7673
e);
7774
}
7875
}

0 commit comments

Comments
 (0)