Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions conformance/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ cc_library(
"//common:ast",
"//common:ast_proto",
"//common:decl_proto_v1alpha1",
"//common:expr",
"//common:source",
"//common:value",
"//common/internal:value_conversion",
Expand All @@ -57,8 +56,6 @@ cc_library(
"//extensions/protobuf:enum_adapter",
"//internal:status_macros",
"//parser",
"//parser:macro",
"//parser:macro_expr_factory",
"//parser:macro_registry",
"//parser:options",
"//parser:standard_macros",
Expand All @@ -75,8 +72,6 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
"@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto",
"@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto",
Expand Down Expand Up @@ -302,6 +297,24 @@ gen_conformance_tests(
skip_tests = _TESTS_TO_SKIP_MODERN + _TESTS_TO_SKIP_CHECKED,
)

gen_conformance_tests(
name = "conformance_variadic",
checked = True,
data = _ALL_TESTS,
enable_variadic_logical_operators = True,
modern = True,
skip_tests = _TESTS_TO_SKIP_MODERN + _TESTS_TO_SKIP_CHECKED,
)

gen_conformance_tests(
name = "conformance_legacy_variadic",
checked = True,
data = _ALL_TESTS,
enable_variadic_logical_operators = True,
modern = False,
skip_tests = _TESTS_TO_SKIP_LEGACY + _TESTS_TO_SKIP_CHECKED,
)

# Generates a bunch of `cc_test` whose names follow the pattern
# `conformance_dashboard_..._{arena|refcount}_{optimized|unoptimized}_{recursive|iterative}`.
gen_conformance_tests(
Expand Down
17 changes: 11 additions & 6 deletions conformance/run.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _conformance_test_name(name, optimize, recursive):
],
)

def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard):
def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard, enable_variadic_logical_operators):
args = []
if modern:
args.append("--modern")
Expand All @@ -72,12 +72,14 @@ def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check,
args.append("--noskip_check")
if dashboard:
args.append("--dashboard")
if enable_variadic_logical_operators:
args.append("--enable_variadic_logical_operators")
return args

def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard):
def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard, enable_variadic_logical_operators):
cc_test(
name = _conformance_test_name(name, optimize, recursive),
args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard) + ["$(rlocationpath {})".format(test) for test in data],
args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard, enable_variadic_logical_operators) + ["$(rlocationpath {})".format(test) for test in data],
env = select(
{
"@platforms//os:windows": {"CEL_SKIP_TESTS": ",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS)},
Expand All @@ -89,18 +91,20 @@ def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_
tags = tags,
)

def gen_conformance_tests(name, data, modern = False, checked = False, select_opt = False, dashboard = False, skip_tests = [], tags = []):
def gen_conformance_tests(name, data, modern = False, checked = False, select_opt = False, dashboard = False, skip_tests = [], tags = [], enable_variadic_logical_operators = False):
"""Generates conformance tests.

Args:
name: prefix for all tests
data: textproto targets describing conformance tests
modern: run using modern APIs
checked: whether to apply type checking
data: textproto targets describing conformance tests
select_opt: enable select optimization
dashboard: enable dashboard mode
skip_tests: tests to skip in the format of the cel-spec test runner. See documentation
in github.com/google/cel-spec/tests/simple/simple_test.go
tags: tags added to the generated targets
dashboard: enable dashboard mode
enable_variadic_logical_operators: enable variadic logical operators
"""
skip_check = not checked
tests = []
Expand All @@ -119,6 +123,7 @@ def gen_conformance_tests(name, data, modern = False, checked = False, select_op
skip_tests = _expand_tests_to_skip(skip_tests),
tags = tags,
dashboard = dashboard,
enable_variadic_logical_operators = enable_variadic_logical_operators,
)
native.test_suite(
name = name,
Expand Down
5 changes: 5 additions & 0 deletions conformance/run.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ ABSL_FLAG(std::vector<std::string>, skip_tests, {}, "Tests to skip");
ABSL_FLAG(bool, dashboard, false, "Dashboard mode, ignore test failures");
ABSL_FLAG(bool, skip_check, true, "Skip type checking the expressions");
ABSL_FLAG(bool, select_optimization, false, "Enable select optimization.");
ABSL_FLAG(bool, enable_variadic_logical_operators, false,
"Enable parsing logical AND & OR operators as a single flat variadic "
"call.");

namespace {

Expand Down Expand Up @@ -261,6 +264,8 @@ NewConformanceServiceFromFlags() {
.modern = absl::GetFlag(FLAGS_modern),
.recursive = absl::GetFlag(FLAGS_recursive),
.select_optimization = absl::GetFlag(FLAGS_select_optimization),
.enable_variadic_logical_operators =
absl::GetFlag(FLAGS_enable_variadic_logical_operators),
});
ABSL_CHECK_OK(status_or_service);
return std::shared_ptr<cel_conformance::ConformanceServiceInterface>(
Expand Down
46 changes: 30 additions & 16 deletions conformance/service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,15 @@ cel::expr::Expr ExtractExpr(

absl::Status LegacyParse(const conformance::v1alpha1::ParseRequest& request,
conformance::v1alpha1::ParseResponse& response,
bool enable_optional_syntax) {
bool enable_optional_syntax,
bool enable_variadic_logical_operators) {
if (request.cel_source().empty()) {
return absl::InvalidArgumentError("no source code");
}
cel::ParserOptions options;
options.enable_optional_syntax = enable_optional_syntax;
options.enable_quoted_identifiers = true;
options.enable_variadic_logical_operators = enable_variadic_logical_operators;
cel::MacroRegistry macros;
CEL_RETURN_IF_ERROR(cel::RegisterStandardMacros(macros, options));
CEL_RETURN_IF_ERROR(
Expand Down Expand Up @@ -236,7 +238,8 @@ absl::Status CheckImpl(google::protobuf::Arena* arena,
class LegacyConformanceServiceImpl : public ConformanceServiceInterface {
public:
static absl::StatusOr<std::unique_ptr<LegacyConformanceServiceImpl>> Create(
bool optimize, bool recursive, bool select_optimization) {
bool optimize, bool recursive, bool select_optimization,
bool enable_variadic_logical_operators) {
static auto* constant_arena = new Arena();

google::protobuf::LinkMessageReflection<
Expand Down Expand Up @@ -313,14 +316,15 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface {
CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathExtensionFunctions(
builder->GetRegistry(), options));

return absl::WrapUnique(
new LegacyConformanceServiceImpl(std::move(builder)));
return absl::WrapUnique(new LegacyConformanceServiceImpl(
std::move(builder), enable_variadic_logical_operators));
}

void Parse(const conformance::v1alpha1::ParseRequest& request,
conformance::v1alpha1::ParseResponse& response) override {
auto status =
LegacyParse(request, response, /*enable_optional_syntax=*/false);
LegacyParse(request, response, /*enable_optional_syntax=*/false,
enable_variadic_logical_operators_);
if (!status.ok()) {
auto* issue = response.add_issues();
issue->set_code(ToGrpcCode(status.code()));
Expand Down Expand Up @@ -418,17 +422,20 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface {
}

private:
explicit LegacyConformanceServiceImpl(
std::unique_ptr<CelExpressionBuilder> builder)
: builder_(std::move(builder)) {}
LegacyConformanceServiceImpl(std::unique_ptr<CelExpressionBuilder> builder,
bool enable_variadic_logical_operators)
: builder_(std::move(builder)),
enable_variadic_logical_operators_(enable_variadic_logical_operators) {}

std::unique_ptr<CelExpressionBuilder> builder_;
bool enable_variadic_logical_operators_;
};

class ModernConformanceServiceImpl : public ConformanceServiceInterface {
public:
static absl::StatusOr<std::unique_ptr<ModernConformanceServiceImpl>> Create(
bool optimize, bool recursive, bool select_optimization) {
bool optimize, bool recursive, bool select_optimization,
bool enable_variadic_logical_operators) {
google::protobuf::LinkMessageReflection<
cel::expr::conformance::proto3::TestAllTypes>();
google::protobuf::LinkMessageReflection<
Expand Down Expand Up @@ -470,8 +477,9 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface {
options.max_recursion_depth = 48;
}

return absl::WrapUnique(new ModernConformanceServiceImpl(
options, optimize, select_optimization));
return absl::WrapUnique(
new ModernConformanceServiceImpl(options, optimize, select_optimization,
enable_variadic_logical_operators));
}

absl::StatusOr<std::unique_ptr<const cel::Runtime>> Setup(
Expand Down Expand Up @@ -523,7 +531,8 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface {
void Parse(const conformance::v1alpha1::ParseRequest& request,
conformance::v1alpha1::ParseResponse& response) override {
auto status =
LegacyParse(request, response, /*enable_optional_syntax=*/true);
LegacyParse(request, response, /*enable_optional_syntax=*/true,
enable_variadic_logical_operators_);
if (!status.ok()) {
auto* issue = response.add_issues();
issue->set_code(ToGrpcCode(status.code()));
Expand Down Expand Up @@ -614,10 +623,12 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface {
private:
ModernConformanceServiceImpl(const RuntimeOptions& options,
bool enable_optimizations,
bool enable_select_optimization)
bool enable_select_optimization,
bool enable_variadic_logical_operators)
: options_(options),
enable_optimizations_(enable_optimizations),
enable_select_optimization_(enable_select_optimization) {}
enable_select_optimization_(enable_select_optimization),
enable_variadic_logical_operators_(enable_variadic_logical_operators) {}

static absl::StatusOr<std::unique_ptr<cel::TraceableProgram>> Plan(
const cel::Runtime& runtime,
Expand Down Expand Up @@ -648,6 +659,7 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface {
RuntimeOptions options_;
bool enable_optimizations_;
bool enable_select_optimization_;
bool enable_variadic_logical_operators_;
};

} // namespace
Expand All @@ -660,10 +672,12 @@ absl::StatusOr<std::unique_ptr<ConformanceServiceInterface>>
NewConformanceService(const ConformanceServiceOptions& options) {
if (options.modern) {
return google::api::expr::runtime::ModernConformanceServiceImpl::Create(
options.optimize, options.recursive, options.select_optimization);
options.optimize, options.recursive, options.select_optimization,
options.enable_variadic_logical_operators);
} else {
return google::api::expr::runtime::LegacyConformanceServiceImpl::Create(
options.optimize, options.recursive, options.select_optimization);
options.optimize, options.recursive, options.select_optimization,
options.enable_variadic_logical_operators);
}
}

Expand Down
1 change: 1 addition & 0 deletions conformance/service.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ struct ConformanceServiceOptions {
bool arena;
bool recursive;
bool select_optimization;
bool enable_variadic_logical_operators = false;
};

absl::StatusOr<std::unique_ptr<ConformanceServiceInterface>>
Expand Down
1 change: 1 addition & 0 deletions eval/compiler/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ cc_test(
"//internal:status_macros",
"//internal:testing",
"//parser",
"//parser:options",
"//runtime:function",
"//runtime:function_adapter",
"//runtime:runtime_options",
Expand Down
80 changes: 41 additions & 39 deletions eval/compiler/flat_expr_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2154,7 +2154,7 @@ void BinaryCondVisitor::PreVisit(const cel::Expr* expr) {
case BinaryCond::kOr:
visitor_->ValidateOrError(
!expr->call_expr().has_target() &&
expr->call_expr().args().size() == 2,
expr->call_expr().args().size() >= 2,
"Invalid argument count for a binary function call.");
break;
case BinaryCond::kOptionalOr:
Expand All @@ -2172,28 +2172,40 @@ void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) {
return;
}
const int last_arg_index = expr->call_expr().args().size() - 1;
if (short_circuiting_ && arg_num < last_arg_index &&
(cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr)) {
// If first branch evaluation result is enough to determine output,
// jump over the second branch and provide result of the first argument as
// final output.
// Retain pointers to the jump steps so we can update the target after
// planning the next arguments.
std::unique_ptr<JumpStepBase> jump_step;
switch (cond_) {
case BinaryCond::kAnd:
jump_step = CreateCondJumpStep(false, true, {}, expr->id());
break;
case BinaryCond::kOr:
jump_step = CreateCondJumpStep(true, true, {}, expr->id());
break;
default:
ABSL_UNREACHABLE();
if (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr) {
if (arg_num > 0) {
switch (cond_) {
case BinaryCond::kAnd:
visitor_->AddStep(CreateAndStep(expr->id()));
break;
case BinaryCond::kOr:
visitor_->AddStep(CreateOrStep(expr->id()));
break;
default:
break;
}
if (short_circuiting_ && !jump_steps_.empty()) {
visitor_->SetProgressStatusIfError(
jump_steps_.back().set_target(visitor_->GetCurrentIndex()));
}
}
ProgramStepIndex index = visitor_->GetCurrentIndex();
if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step));
jump_step_ptr) {
jump_steps_.push_back(Jump(index, jump_step_ptr));
if (short_circuiting_ && arg_num < last_arg_index) {
std::unique_ptr<JumpStepBase> jump_step;
switch (cond_) {
case BinaryCond::kAnd:
jump_step = CreateCondJumpStep(false, true, {}, expr->id());
break;
case BinaryCond::kOr:
jump_step = CreateCondJumpStep(true, true, {}, expr->id());
break;
default:
ABSL_UNREACHABLE();
}
ProgramStepIndex index = visitor_->GetCurrentIndex();
if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step));
jump_step_ptr) {
jump_steps_.push_back(Jump(index, jump_step_ptr));
}
}
}
}
Expand Down Expand Up @@ -2251,17 +2263,9 @@ void BinaryCondVisitor::PostVisit(const cel::Expr* expr) {
return;
}

int args_count = (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr)
? expr->call_expr().args().size()
: 2;
for (int i = 0; i < args_count - 1; ++i) {
if (cond_ == BinaryCond::kOptionalOr ||
cond_ == BinaryCond::kOptionalOrValue) {
switch (cond_) {
case BinaryCond::kAnd:
visitor_->AddStep(CreateAndStep(expr->id()));
break;
case BinaryCond::kOr:
visitor_->AddStep(CreateOrStep(expr->id()));
break;
case BinaryCond::kOptionalOr:
visitor_->AddStep(
CreateOptionalOrStep(/*is_or_value=*/false, expr->id()));
Expand All @@ -2273,13 +2277,11 @@ void BinaryCondVisitor::PostVisit(const cel::Expr* expr) {
default:
ABSL_UNREACHABLE();
}
}
if (short_circuiting_) {
// If short-circuiting is enabled, point the conditional jump past the
// boolean operator step.
for (auto& jump : jump_steps_) {
visitor_->SetProgressStatusIfError(
jump.set_target(visitor_->GetCurrentIndex()));
if (short_circuiting_) {
for (auto& jump : jump_steps_) {
visitor_->SetProgressStatusIfError(
jump.set_target(visitor_->GetCurrentIndex()));
}
}
}
}
Expand Down
Loading
Loading