diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 1c560cdb9..22bc75b9a 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -155,6 +155,7 @@ cc_library( "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -179,6 +180,7 @@ cc_test( "//checker:type_checker_builder", "//checker:validation_result", "//common:ast", + "//common:ast_proto", "//common:container", "//common:decl", "//common:expr", @@ -187,7 +189,10 @@ cc_test( "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", + "//parser", + "//parser:macro_registry", "//testutil:baseline_tests", + "//testutil:test_macros", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_set", diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 05601fdbb..97dca9c59 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -25,6 +25,7 @@ #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" @@ -59,6 +60,15 @@ namespace cel::checker_internal { namespace { +bool MatchesBlock(const Expr& expr) { + if (!expr.has_call_expr()) { + return false; + } + const auto& call = expr.call_expr(); + return call.function() == "cel.@block" && call.args().size() == 2 && + call.args()[0].has_list_expr(); +} + using AstType = cel::TypeSpec; using Severity = TypeCheckIssue::Severity; @@ -204,13 +214,23 @@ class ResolveVisitor : public AstVisitorBase { arena_(arena), current_scope_(&root_scope_) {} - void PreVisitExpr(const Expr& expr) override { expr_stack_.push_back(&expr); } + void PreVisitExpr(const Expr& expr) override { + expr_stack_.push_back(&expr); + if (expr_stack_.size() == 1 && MatchesBlock(expr)) { + ABSL_DCHECK_EQ(expr.call_expr().args().size(), 2); + ABSL_DCHECK(block_init_list_ == nullptr); + block_init_list_ = &expr.call_expr().args()[0]; + } + } void PostVisitExpr(const Expr& expr) override { if (expr_stack_.empty()) { return; } expr_stack_.pop_back(); + if (expr_stack_.size() == 2 && expr_stack_.back() == block_init_list_) { + HandleBlockIndex(&expr); + } } void PostVisitConst(const Expr& expr, const Constant& constant) override; @@ -389,6 +409,7 @@ class ResolveVisitor : public AstVisitorBase { absl::string_view field_name); void HandleOptSelect(const Expr& expr); + void HandleBlockIndex(const Expr* expr); // Get the assigned type of the given subexpression. Should only be called if // the given subexpression is expected to have already been checked. @@ -421,6 +442,7 @@ class ResolveVisitor : public AstVisitorBase { std::vector expr_stack_; absl::flat_hash_map> maybe_namespaced_functions_; + const Expr* block_init_list_ = nullptr; // Select operations that need to be resolved outside of the traversal. // These are handled separately to disambiguate between namespaces and field // accesses @@ -609,8 +631,15 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) { } void ResolveVisitor::PostVisitList(const Expr& expr, const ListExpr& list) { - // Follows list type inferencing behavior in Go (see map comments above). + if (&expr == block_init_list_) { + // Don't try to coalesce list type here because it can influence the + // resolved type of the list elements. cel.@block is always list and + // the elements are treated independently at runtime. + types_[&expr] = ListType(); + return; + } + // Follows list type inferencing behavior in Go (see map comments above). Type overall_elem_type = inference_context_->InstantiateTypeParams(TypeParamType("E")); auto assignability_context = inference_context_->CreateAssignabilityContext(); @@ -1172,6 +1201,44 @@ void ResolveVisitor::HandleOptSelect(const Expr& expr) { } } +void ResolveVisitor::HandleBlockIndex(const Expr* expr) { + ABSL_DCHECK(block_init_list_ != nullptr); + ABSL_DCHECK(block_init_list_->has_list_expr()); + const auto& elements = block_init_list_->list_expr().elements(); + int index = -1; + for (size_t i = 0; i < elements.size(); ++i) { + if (&elements[i].expr() == expr) { + index = i; + break; + } + } + if (index < 0) { + status_.Update(absl::InternalError( + "could not resolve expression as a cel.@block subexpression")); + return; + } + std::string var_name = absl::StrCat("@index", index); + + // Block is typically manually assembled from logically separate + // expressions so fix the type instead of inferring any remaining free type + // params as for normal subexpressions. + auto type = inference_context_->FinalizeType(GetDeducedType(expr)); + + VariableDecl decl = MakeVariableDecl(var_name, std::move(type)); + + // The C++ runtime requires that the indexes are topologically ordered. + // They just come into scope in order as we walk the AST so we don't need + // to do any additional work to check references to other initializers in + // an init expr. + // + // We may want to relax this just requiring that the references are + // acyclic as in the Java implementation. + auto* scope = + comprehension_vars_.emplace_back(current_scope_->MakeNestedScope()).get(); + scope->InsertVariableIfAbsent(std::move(decl)); + current_scope_ = scope; +} + class ResolveRewriter : public AstRewriterBase { public: explicit ResolveRewriter(const ResolveVisitor& visitor, @@ -1230,15 +1297,15 @@ class ResolveRewriter : public AstRewriterBase { if (auto iter = visitor_.types().find(&expr); iter != visitor_.types().end()) { - auto flattened_type = - FlattenType(inference_context_.FinalizeType(iter->second)); + cel::Type finalized_type = inference_context_.FinalizeType(iter->second); + auto flattened_type = FlattenType(finalized_type); if (!flattened_type.ok()) { status_.Update(flattened_type.status()); return rewritten; } type_map_[expr.id()] = *std::move(flattened_type); - resolved_types_[expr.id()] = iter->second; + resolved_types_[expr.id()] = finalized_type; rewritten = true; } diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index e6cd641d6..4b21c50f5 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -36,6 +36,7 @@ #include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" +#include "common/ast_proto.h" #include "common/container.h" #include "common/decl.h" #include "common/expr.h" @@ -45,7 +46,10 @@ #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" +#include "parser/macro_registry.h" +#include "parser/parser.h" #include "testutil/baseline_tests.h" +#include "testutil/test_macros.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" @@ -108,6 +112,17 @@ google::protobuf::Arena* absl_nonnull TestTypeArena() { return &(*kArena); } +absl::StatusOr> MakeTestParsedAstWithMacros( + absl::string_view expression, const cel::MacroRegistry& registry) { + CEL_ASSIGN_OR_RETURN( + auto source, + cel::NewSource(expression, /*description=*/std::string(expression))); + CEL_ASSIGN_OR_RETURN(auto parsed_expr, google::api::expr::parser::Parse( + *source, registry, + {.enable_optional_syntax = true})); + return cel::CreateAstFromParsedExpr(parsed_expr); +} + FunctionDecl MakeIdentFunction() { auto decl = MakeFunctionDecl( "identity", @@ -272,6 +287,13 @@ absl::Status RegisterMinimalBuiltins(google::protobuf::Arena* absl_nonnull arena /*return_type=*/TypeType(arena, TypeParamType("A")), TypeParamType("A")))); + Type kParam(TypeParamType("T")); + CEL_ASSIGN_OR_RETURN( + auto block_decl, + MakeFunctionDecl("cel.@block", MakeOverloadDecl("cel_block_list", kParam, + ListType(), kParam))); + env.InsertFunctionIfAbsent(std::move(block_decl)); + env.InsertFunctionIfAbsent(std::move(not_op)); env.InsertFunctionIfAbsent(std::move(not_strictly_false)); env.InsertFunctionIfAbsent(std::move(add_op)); @@ -289,6 +311,7 @@ absl::Status RegisterMinimalBuiltins(google::protobuf::Arena* absl_nonnull arena env.InsertFunctionIfAbsent(std::move(to_type)); env.InsertFunctionIfAbsent(std::move(to_duration)); env.InsertFunctionIfAbsent(std::move(to_timestamp)); + env.InsertFunctionIfAbsent(std::move(block_decl)); return absl::OkStatus(); } @@ -308,6 +331,78 @@ TEST(TypeCheckerImplTest, SmokeTest) { EXPECT_THAT(result.GetIssues(), IsEmpty()); } +TEST(TypeCheckerImplTest, BlockMacroSupport) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + MacroRegistry registry; + ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, + MakeTestParsedAstWithMacros( + "cel.block([1, 2], cel.index(0) + cel.index(1))", registry)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + // Overall type should be int. + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + auto root_id = checked_ast->root_expr().id(); + EXPECT_EQ(checked_ast->type_map().at(root_id).primitive(), + PrimitiveType::kInt64); +} + +TEST(TypeCheckerImplTest, BlockMacroSupportMixedTypes) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + MacroRegistry registry; + ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAstWithMacros("cel.block([1, 'a'], cel.index(1))", + registry)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + // cel.index(1) refers to 'a' which is string. + // So overall type should be string. + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + auto root_id = checked_ast->root_expr().id(); + EXPECT_EQ(checked_ast->type_map().at(root_id).primitive(), + PrimitiveType::kString); +} + +TEST(TypeCheckerImplTest, BadIndex) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + MacroRegistry registry; + ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAstWithMacros("cel.block([1, 'a'], cel.index(2))", + registry)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatError(), + HasSubstr("undeclared reference to '@index2' (in container")); +} + TEST(TypeCheckerImplTest, SimpleIdentsResolved) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); diff --git a/compiler/BUILD b/compiler/BUILD index 50bc1c9fa..d4a0ab4ac 100644 --- a/compiler/BUILD +++ b/compiler/BUILD @@ -42,10 +42,12 @@ cc_library( hdrs = ["compiler_factory.h"], deps = [ ":compiler", + "//checker:type_check_issue", "//checker:type_checker", "//checker:type_checker_builder", "//checker:type_checker_builder_factory", "//checker:validation_result", + "//common:ast", "//common:source", "//internal:noop_delete", "//internal:status_macros", diff --git a/compiler/compiler_factory.cc b/compiler/compiler_factory.cc index 14586825e..6df28cab2 100644 --- a/compiler/compiler_factory.cc +++ b/compiler/compiler_factory.cc @@ -17,16 +17,19 @@ #include #include #include +#include #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "checker/type_check_issue.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" #include "checker/type_checker_builder_factory.h" #include "checker/validation_result.h" +#include "common/ast.h" #include "common/source.h" #include "compiler/compiler.h" #include "internal/status_macros.h" @@ -55,9 +58,26 @@ class CompilerImpl : public Compiler { google::protobuf::Arena* arena) const override { CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(expression, std::string(description))); - CEL_ASSIGN_OR_RETURN(auto ast, parser_->Parse(*source)); + std::vector parse_issues; + absl::StatusOr> ast = + parser_->Parse(*source, &parse_issues); + if (!ast.ok()) { + if (ast.status().code() != absl::StatusCode::kInvalidArgument || + parse_issues.empty()) { + return ast.status(); + } + std::vector check_issues; + check_issues.reserve(parse_issues.size()); + for (const auto& issue : parse_issues) { + check_issues.push_back(TypeCheckIssue::CreateError( + issue.location(), std::string(issue.message()))); + } + ValidationResult result(std::move(check_issues)); + result.SetSource(std::move(source)); + return result; + } CEL_ASSIGN_OR_RETURN(ValidationResult result, - type_checker_->Check(std::move(ast), arena)); + type_checker_->Check(*std::move(ast), arena)); result.SetSource(std::move(source)); if (!validator_.validations().empty()) { diff --git a/compiler/compiler_factory_test.cc b/compiler/compiler_factory_test.cc index 214c23765..b31f7457c 100644 --- a/compiler/compiler_factory_test.cc +++ b/compiler/compiler_factory_test.cc @@ -413,5 +413,17 @@ TEST(CompilerFactoryTest, SpecifyArenaKeepsResolvedTypes) { it->second.GetOptional().GetParameter().GetList().GetElement().IsInt()); } +TEST(CompilerFactoryTest, ReturnsIssuesFromParser) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile("a +")); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), testing::Not(testing::IsEmpty())); +} + } // namespace } // namespace cel diff --git a/conformance/BUILD b/conformance/BUILD index 139739891..9b527cf35 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -69,6 +69,7 @@ cc_library( "//runtime:reference_resolver", "//runtime:runtime_options", "//runtime:standard_runtime_builder_factory", + "//testutil:test_macros", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -221,7 +222,7 @@ _TESTS_TO_SKIP_LEGACY = _TESTS_TO_SKIP + [ "proto3/set_null/list_value", "proto3/set_null/single_struct", - # cel.@block + # no optional support for legacy types "block_ext/basic/optional_list", "block_ext/basic/optional_map", "block_ext/basic/optional_map_chained", @@ -231,7 +232,7 @@ _TESTS_TO_SKIP_LEGACY = _TESTS_TO_SKIP + [ _TESTS_TO_SKIP_CHECKED = [ # block is a post-check optimization that inserts internal variables. The C++ type checker # needs support for a proper optimizer for this to work. - "block_ext", + # "block_ext", ] _TESTS_TO_SKIP_LEGACY_DASHBOARD = [ diff --git a/conformance/service.cc b/conformance/service.cc index 3edc214e6..463334bb5 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -14,7 +14,6 @@ #include "conformance/service.h" -#include #include #include #include @@ -36,11 +35,8 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" -#include "absl/types/optional.h" -#include "absl/types/span.h" #include "checker/optional.h" #include "checker/standard_library.h" #include "checker/type_checker_builder.h" @@ -48,7 +44,6 @@ #include "common/ast.h" #include "common/ast_proto.h" #include "common/decl_proto_v1alpha1.h" -#include "common/expr.h" #include "common/internal/value_conversion.h" #include "common/source.h" #include "common/value.h" @@ -72,8 +67,6 @@ #include "extensions/select_optimization.h" #include "extensions/strings.h" #include "internal/status_macros.h" -#include "parser/macro.h" -#include "parser/macro_expr_factory.h" #include "parser/macro_registry.h" #include "parser/options.h" #include "parser/parser.h" @@ -85,6 +78,7 @@ #include "runtime/runtime.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" +#include "testutil/test_macros.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "cel/expr/conformance/proto2/test_all_types_extensions.pb.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" @@ -106,109 +100,6 @@ namespace google::api::expr::runtime { namespace { -bool IsCelNamespace(const cel::Expr& target) { - return target.has_ident_expr() && target.ident_expr().name() == "cel"; -} - -absl::optional CelBlockMacroExpander(cel::MacroExprFactory& factory, - cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& bindings_arg = args[0]; - if (!bindings_arg.has_list_expr()) { - return factory.ReportErrorAt( - bindings_arg, "cel.block requires the first arg to be a list literal"); - } - return factory.NewCall("cel.@block", args); -} - -absl::optional CelIndexMacroExpander(cel::MacroExprFactory& factory, - cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& index_arg = args[0]; - if (!index_arg.has_const_expr() || !index_arg.const_expr().has_int_value()) { - return factory.ReportErrorAt( - index_arg, "cel.index requires a single non-negative int constant arg"); - } - int64_t index = index_arg.const_expr().int_value(); - if (index < 0) { - return factory.ReportErrorAt( - index_arg, "cel.index requires a single non-negative int constant arg"); - } - return factory.NewIdent(absl::StrCat("@index", index)); -} - -absl::optional CelIterVarMacroExpander( - cel::MacroExprFactory& factory, cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& depth_arg = args[0]; - if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || - depth_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - depth_arg, "cel.iterVar requires two non-negative int constant args"); - } - cel::Expr& unique_arg = args[1]; - if (!unique_arg.has_const_expr() || - !unique_arg.const_expr().has_int_value() || - unique_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - unique_arg, "cel.iterVar requires two non-negative int constant args"); - } - return factory.NewIdent( - absl::StrCat("@it:", depth_arg.const_expr().int_value(), ":", - unique_arg.const_expr().int_value())); -} - -absl::optional CelAccuVarMacroExpander( - cel::MacroExprFactory& factory, cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& depth_arg = args[0]; - if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || - depth_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - depth_arg, "cel.accuVar requires two non-negative int constant args"); - } - cel::Expr& unique_arg = args[1]; - if (!unique_arg.has_const_expr() || - !unique_arg.const_expr().has_int_value() || - unique_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - unique_arg, "cel.accuVar requires two non-negative int constant args"); - } - return factory.NewIdent( - absl::StrCat("@ac:", depth_arg.const_expr().int_value(), ":", - unique_arg.const_expr().int_value())); -} - -absl::Status RegisterCelBlockMacros(cel::MacroRegistry& registry) { - CEL_ASSIGN_OR_RETURN(auto block_macro, - cel::Macro::Receiver("block", 2, CelBlockMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(block_macro)); - CEL_ASSIGN_OR_RETURN(auto index_macro, - cel::Macro::Receiver("index", 1, CelIndexMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(index_macro)); - CEL_ASSIGN_OR_RETURN( - auto iter_var_macro, - cel::Macro::Receiver("iterVar", 2, CelIterVarMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(iter_var_macro)); - CEL_ASSIGN_OR_RETURN( - auto accu_var_macro, - cel::Macro::Receiver("accuVar", 2, CelAccuVarMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(accu_var_macro)); - return absl::OkStatus(); -} - google::rpc::Code ToGrpcCode(absl::StatusCode code) { return static_cast(code); } @@ -250,7 +141,7 @@ absl::Status LegacyParse(const conformance::v1alpha1::ParseRequest& request, CEL_RETURN_IF_ERROR(cel::extensions::RegisterBindingsMacros(macros, options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathMacros(macros, options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtoMacros(macros, options)); - CEL_RETURN_IF_ERROR(RegisterCelBlockMacros(macros)); + CEL_RETURN_IF_ERROR(cel::test::RegisterTestMacros(macros)); CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(request.cel_source(), request.source_location())); CEL_ASSIGN_OR_RETURN(auto parsed_expr, @@ -285,6 +176,8 @@ absl::Status CheckImpl(google::protobuf::Arena* arena, if (!request.no_std_env()) { CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCheckerLibrary())); CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::OptionalCheckerLibrary())); + CEL_RETURN_IF_ERROR( + builder->AddLibrary(cel::extensions::BindingsCheckerLibrary())); CEL_RETURN_IF_ERROR( builder->AddLibrary(cel::extensions::StringsCheckerLibrary())); CEL_RETURN_IF_ERROR( diff --git a/extensions/BUILD b/extensions/BUILD index ff37e2c3f..05104a4a5 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -215,7 +215,10 @@ cc_library( srcs = ["bindings_ext.cc"], hdrs = ["bindings_ext.h"], deps = [ - "//common:ast", + "//checker:type_checker_builder", + "//common:decl", + "//common:expr", + "//common:type", "//compiler", "//internal:status_macros", "//parser:macro", diff --git a/extensions/bindings_ext.cc b/extensions/bindings_ext.cc index f097709ca..c59f724bd 100644 --- a/extensions/bindings_ext.cc +++ b/extensions/bindings_ext.cc @@ -21,7 +21,10 @@ #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "absl/types/span.h" -#include "common/ast.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/type.h" #include "compiler/compiler.h" #include "internal/status_macros.h" #include "parser/macro.h" @@ -34,6 +37,8 @@ namespace { static constexpr char kCelNamespace[] = "cel"; static constexpr char kBind[] = "bind"; +static constexpr char kBlock[] = "cel.@block"; +static constexpr char kBlockOverloadId[] = "cel_block_list"; static constexpr char kUnusedIterVar[] = "#unused"; bool IsTargetNamespace(const Expr& target) { @@ -47,6 +52,19 @@ inline absl::Status ConfigureParser(ParserBuilder& parser_builder) { return absl::OkStatus(); } +absl::Status ConfigureChecker(int version, + TypeCheckerBuilder& type_checker_builder) { + if (version < 1) { + return absl::OkStatus(); + } + static Type kParam(TypeParamType("T")); + CEL_ASSIGN_OR_RETURN( + auto decl, + MakeFunctionDecl(kBlock, MakeOverloadDecl(kBlockOverloadId, kParam, + ListType(), kParam))); + return type_checker_builder.AddFunction(std::move(decl)); +} + } // namespace std::vector bindings_macros() { @@ -70,8 +88,16 @@ std::vector bindings_macros() { return {*cel_bind}; } -CompilerLibrary BindingsCompilerLibrary() { - return CompilerLibrary("cel.lib.ext.bindings", &ConfigureParser); +CompilerLibrary BindingsCompilerLibrary(int version) { + return CompilerLibrary( + "cel.lib.ext.bindings", &ConfigureParser, + [version](auto& b) { return ConfigureChecker(version, b); }); +} + +CheckerLibrary BindingsCheckerLibrary(int version) { + return CheckerLibrary{"cel.lib.ext.bindings", [version](auto& b) { + return ConfigureChecker(version, b); + }}; } } // namespace cel::extensions diff --git a/extensions/bindings_ext.h b/extensions/bindings_ext.h index a338b24f6..40b83a37f 100644 --- a/extensions/bindings_ext.h +++ b/extensions/bindings_ext.h @@ -25,6 +25,7 @@ namespace cel::extensions { +constexpr int kBindingsVersionLatest = 1; // bindings_macros() returns a macro for cel.bind() which can be used to support // local variable bindings within expressions. std::vector bindings_macros(); @@ -35,7 +36,10 @@ inline absl::Status RegisterBindingsMacros(MacroRegistry& registry, } // Declarations for the bindings extension library. -CompilerLibrary BindingsCompilerLibrary(); +CompilerLibrary BindingsCompilerLibrary(int version = kBindingsVersionLatest); + +// Declarations for the bindings extension library. +CheckerLibrary BindingsCheckerLibrary(int version = kBindingsVersionLatest); } // namespace cel::extensions diff --git a/extensions/math_ext_test.cc b/extensions/math_ext_test.cc index 3088e6fa8..8d10caa35 100644 --- a/extensions/math_ext_test.cc +++ b/extensions/math_ext_test.cc @@ -23,7 +23,6 @@ #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" @@ -110,19 +109,6 @@ struct MacroTestCase { absl::string_view err = ""; }; -std::string FormatIssues(const cel::ValidationResult& result) { - std::string issues; - for (const auto& issue : result.GetIssues()) { - if (!issues.empty()) { - absl::StrAppend(&issues, "\n", - issue.ToDisplayString(*result.GetSource())); - } else { - issues = issue.ToDisplayString(*result.GetSource()); - } - } - return issues; -} - class TestFunction : public CelFunction { public: explicit TestFunction(absl::string_view name) @@ -381,16 +367,16 @@ TEST_P(MathExtMacroParamsTest, ParserAndCheckerTests) { ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*compiler_builder).Build()); - auto result = compiler->Compile(test_case.expr, ""); + ASSERT_OK_AND_ASSIGN(auto result, + compiler->Compile(test_case.expr, "")); if (!test_case.err.empty()) { - EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr(test_case.err))); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.err)); return; } - ASSERT_THAT(result, IsOk()); - ASSERT_TRUE(result->IsValid()) << FormatIssues(*result); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); RuntimeOptions opts; ASSERT_OK_AND_ASSIGN( @@ -411,9 +397,8 @@ TEST_P(MathExtMacroParamsTest, ParserAndCheckerTests) { IsOk()); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build()); - - ASSERT_OK_AND_ASSIGN(auto program, - runtime->CreateProgram(*result->ReleaseAst())); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); google::protobuf::Arena arena; cel::Activation activation; diff --git a/parser/BUILD b/parser/BUILD index 63813bb59..6650d9fe9 100644 --- a/parser/BUILD +++ b/parser/BUILD @@ -244,9 +244,11 @@ cc_library( ":options", "//common:ast", "//common:source", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", ], ) diff --git a/parser/parser.cc b/parser/parser.cc index f4ee3a1c5..709e2fd41 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -112,13 +112,12 @@ struct ParserError { }; std::string DisplayParserError(const cel::Source& source, - const ParserError& error) { - auto location = - source.GetLocation(error.range.begin).value_or(SourceLocation{}); + SourceLocation location, + absl::string_view message) { return absl::StrCat(absl::StrFormat("ERROR: %s:%zu:%zu: %s", source.description(), location.line, // add one to the 0-based column - location.column + 1, error.message), + location.column + 1, message), source.DisplayErrorLocation(location)); } @@ -209,7 +208,7 @@ class ParserMacroExprFactory final : public MacroExprFactory { bool HasErrors() const { return error_count_ != 0; } - std::string ErrorMessage() { + std::vector CollectIssues() { // Errors are collected as they are encountered, not by their location // within the source. To have a more stable error message as implementation // details change, we sort the collected errors by their source location @@ -226,20 +225,23 @@ class ParserMacroExprFactory final : public MacroExprFactory { }); // Build the summary error message using the sorted errors. bool errors_truncated = error_count_ > 100; - std::vector messages; - messages.reserve( + std::vector issues; + issues.reserve( errors_.size() + errors_truncated); // Reserve space for the transform and an // additional element when truncation occurs. - std::transform(errors_.begin(), errors_.end(), std::back_inserter(messages), - [this](const ParserError& error) { - return cel::DisplayParserError(source_, error); - }); + std::transform( + errors_.begin(), errors_.end(), std::back_inserter(issues), + [this](const ParserError& error) { + auto location = + source_.GetLocation(error.range.begin).value_or(SourceLocation{}); + return cel::ParseIssue(location, error.message); + }); if (errors_truncated) { - messages.emplace_back( - absl::StrCat(error_count_ - 100, " more errors were truncated.")); + issues.push_back(cel::ParseIssue( + absl::StrCat(error_count_ - 100, " more errors were truncated."))); } - return absl::StrJoin(messages, "\n"); + return issues; } void AddMacroCall(int64_t macro_id, absl::string_view function, @@ -602,6 +604,15 @@ Expr ExpressionBalancer::BalancedTree(int lo, int hi) { return factory_.NewCall(ops_[mid], function_, std::move(arguments)); } +std::string FormatIssues(const cel::Source& source, + absl::Span issues) { + return absl::StrJoin( + issues, "\n", [&source](std::string* out, const cel::ParseIssue& issue) { + absl::StrAppend(out, cel::DisplayParserError(source, issue.location(), + issue.message())); + }); +} + class ParserVisitor final : public CelBaseVisitor, public antlr4::BaseErrorListener { public: @@ -673,7 +684,7 @@ class ParserVisitor final : public CelBaseVisitor, const std::string& msg, std::exception_ptr e) override; bool HasErrored() const; - std::string ErrorMessage(); + std::vector CollectIssues(); private: template @@ -1434,7 +1445,9 @@ void ParserVisitor::syntaxError(antlr4::Recognizer* recognizer, bool ParserVisitor::HasErrored() const { return factory_.HasErrors(); } -std::string ParserVisitor::ErrorMessage() { return factory_.ErrorMessage(); } +std::vector ParserVisitor::CollectIssues() { + return factory_.CollectIssues(); +} Expr ParserVisitor::GlobalCallOrMacroImpl(int64_t expr_id, absl::string_view function, @@ -1638,9 +1651,10 @@ struct ParseResult { EnrichedSourceInfo enriched_source_info; }; -absl::StatusOr ParseImpl(const cel::Source& source, - const cel::MacroRegistry& registry, - const ParserOptions& options) { +absl::StatusOr ParseImpl( + const cel::Source& source, const cel::MacroRegistry& registry, + const ParserOptions& options, + std::vector* parse_issues = nullptr) { try { CodePointStream input(source.content(), source.description()); if (input.size() > options.expression_size_codepoint_limit) { @@ -1673,13 +1687,23 @@ absl::StatusOr ParseImpl(const cel::Source& source, expr = ExprFromAny(visitor.visit(parser.start())); } catch (const ParseCancellationException& e) { if (visitor.HasErrored()) { - return absl::InvalidArgumentError(visitor.ErrorMessage()); + auto issues = visitor.CollectIssues(); + std::string error_message = FormatIssues(source, issues); + if (parse_issues != nullptr) { + *parse_issues = std::move(issues); + } + return absl::InvalidArgumentError(error_message); } return absl::CancelledError(e.what()); } if (visitor.HasErrored()) { - return absl::InvalidArgumentError(visitor.ErrorMessage()); + auto issues = visitor.CollectIssues(); + std::string error_message = FormatIssues(source, issues); + if (parse_issues != nullptr) { + *parse_issues = std::move(issues); + } + return absl::InvalidArgumentError(error_message); } return { @@ -1706,10 +1730,12 @@ class ParserImpl : public cel::Parser { macro_registry_(std::move(macro_registry)), library_ids_(std::move(library_ids)) {} - absl::StatusOr> Parse( - const cel::Source& source) const override { + absl::StatusOr> ParseImpl( + const cel::Source& source, + std::vector* parse_issues) const override { CEL_ASSIGN_OR_RETURN(auto parse_result, - ParseImpl(source, macro_registry_, options_)); + ::google::api::expr::parser::ParseImpl( + source, macro_registry_, options_, parse_issues)); return std::make_unique(std::move(parse_result.expr), std::move(parse_result.source_info)); } diff --git a/parser/parser_interface.h b/parser/parser_interface.h index 7cc21ff26..ad6e8ca84 100644 --- a/parser/parser_interface.h +++ b/parser/parser_interface.h @@ -16,10 +16,14 @@ #include #include +#include +#include +#include "absl/base/nullability.h" #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "common/ast.h" #include "common/source.h" #include "parser/macro.h" @@ -73,6 +77,26 @@ class ParserBuilder { virtual absl::StatusOr> Build() = 0; }; +// Information about a parse failure. +class ParseIssue { + public: + explicit ParseIssue(std::string message) : message_(std::move(message)) {} + ParseIssue(SourceLocation location, std::string message) + : location_(location), message_(std::move(message)) {} + + ParseIssue(const ParseIssue& other) = default; + ParseIssue& operator=(const ParseIssue& other) = default; + ParseIssue(ParseIssue&& other) = default; + ParseIssue& operator=(ParseIssue&& other) = default; + + SourceLocation location() const { return location_; } + absl::string_view message() const { return message_; } + + private: + SourceLocation location_; + std::string message_; +}; + // Interface for stateful CEL parser objects for use with a `Compiler` // (bundled parse and type check). This is not needed for most users: // prefer using the free functions in `parser.h` for more flexibility. @@ -81,13 +105,35 @@ class Parser { virtual ~Parser() = default; // Parses the given source into a CEL AST. - virtual absl::StatusOr> Parse( - const cel::Source& source) const = 0; + absl::StatusOr> Parse( + const cel::Source& source) const; + + // Parses the given source into a CEL AST, collecting parse errors in + // `issues`. If `issues` is non-null, it will be cleared and all parse + // issues will be appended to it. + absl::StatusOr> Parse( + const cel::Source& source, std::vector* issues) const; // Returns a builder initialized with the configuration of this parser. virtual std::unique_ptr ToBuilder() const = 0; + + protected: + virtual absl::StatusOr> ParseImpl( + const cel::Source& source, + std::vector* absl_nullable parse_issues) const = 0; }; +inline absl::StatusOr> Parser::Parse( + const cel::Source& source) const { + return ParseImpl(source, nullptr); +} + +inline absl::StatusOr> Parser::Parse( + const cel::Source& source, std::vector* issues) const { + if (issues != nullptr) issues->clear(); + return ParseImpl(source, issues); +} + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_PARSER_PARSER_INTERFACE_H_ diff --git a/parser/parser_test.cc b/parser/parser_test.cc index a1a65481d..587b63a30 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1782,6 +1782,25 @@ TEST(NewParserBuilderTest, ToBuilderPreservesStdlibAndOptionalFromOptions) { EXPECT_FALSE(ast->IsChecked()); } +TEST(ParserTest, ParseFailurePopulatesIssues) { + auto builder = cel::NewParserBuilder(); + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("a +", "test.cel")); + std::vector issues; + auto ast_result = parser->Parse(*source, &issues); + EXPECT_THAT(ast_result, Not(IsOk())); + ASSERT_THAT(issues, testing::SizeIs(1)); + EXPECT_THAT(ast_result.status().message(), + HasSubstr("ERROR: test.cel:1:4: Syntax error: mismatched input " + "'' expecting")); + EXPECT_THAT(issues[0].message(), + HasSubstr("Syntax error: mismatched input '' expecting")); + EXPECT_EQ(issues[0].location().line, 1); + // 0-based, but adjusted to 1-based in error message. + EXPECT_EQ(issues[0].location().column, 3); +} + std::string TestName(const testing::TestParamInfo& test_info) { std::string name = absl::StrCat(test_info.index, "-", test_info.param.I); absl::c_replace_if(name, [](char c) { return !absl::ascii_isalnum(c); }, '_'); diff --git a/testutil/BUILD b/testutil/BUILD index 292696033..782c95ca6 100644 --- a/testutil/BUILD +++ b/testutil/BUILD @@ -62,6 +62,26 @@ cc_library( deps = ["//internal:proto_matchers"], ) +cc_library( + name = "test_macros", + testonly = True, + srcs = ["test_macros.cc"], + hdrs = ["test_macros.h"], + deps = [ + "//common:expr", + "//internal:status_macros", + "//parser:macro", + "//parser:macro_expr_factory", + "//parser:macro_registry", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "baseline_tests", testonly = True, diff --git a/testutil/test_macros.cc b/testutil/test_macros.cc new file mode 100644 index 000000000..158135762 --- /dev/null +++ b/testutil/test_macros.cc @@ -0,0 +1,175 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "testutil/test_macros.h" + +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/expr.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "parser/macro_expr_factory.h" +#include "parser/macro_registry.h" + +namespace cel::test { + +namespace { + +bool IsCelNamespace(const Expr& target) { + return target.has_ident_expr() && target.ident_expr().name() == "cel"; +} + +absl::optional CelBlockMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + Expr& bindings_arg = args[0]; + if (!bindings_arg.has_list_expr()) { + return factory.ReportErrorAt( + bindings_arg, "cel.block requires the first arg to be a list literal"); + } + return factory.NewCall("cel.@block", args); +} + +absl::optional CelIndexMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + Expr& index_arg = args[0]; + if (!index_arg.has_const_expr() || !index_arg.const_expr().has_int_value()) { + return factory.ReportErrorAt( + index_arg, "cel.index requires a single non-negative int constant arg"); + } + int64_t index = index_arg.const_expr().int_value(); + if (index < 0) { + return factory.ReportErrorAt( + index_arg, "cel.index requires a single non-negative int constant arg"); + } + return factory.NewIdent(absl::StrCat("@index", index)); +} + +absl::optional CelIterVarMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + Expr& depth_arg = args[0]; + if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || + depth_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + depth_arg, "cel.iterVar requires two non-negative int constant args"); + } + Expr& unique_arg = args[1]; + if (!unique_arg.has_const_expr() || + !unique_arg.const_expr().has_int_value() || + unique_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + unique_arg, "cel.iterVar requires two non-negative int constant args"); + } + return factory.NewIdent( + absl::StrCat("@it:", depth_arg.const_expr().int_value(), ":", + unique_arg.const_expr().int_value())); +} + +absl::optional CelAccuVarMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + Expr& depth_arg = args[0]; + if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || + depth_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + depth_arg, "cel.accuVar requires two non-negative int constant args"); + } + Expr& unique_arg = args[1]; + if (!unique_arg.has_const_expr() || + !unique_arg.const_expr().has_int_value() || + unique_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + unique_arg, "cel.accuVar requires two non-negative int constant args"); + } + return factory.NewIdent( + absl::StrCat("@ac:", depth_arg.const_expr().int_value(), ":", + unique_arg.const_expr().int_value())); +} + +Macro MakeCelBlockMacro() { + auto macro_or_status = Macro::Receiver("block", 2, CelBlockMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +Macro MakeCelIndexMacro() { + auto macro_or_status = Macro::Receiver("index", 1, CelIndexMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +Macro MakeCelIterVarMacro() { + auto macro_or_status = Macro::Receiver("iterVar", 2, CelIterVarMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +Macro MakeCelAccuVarMacro() { + auto macro_or_status = Macro::Receiver("accuVar", 2, CelAccuVarMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +} // namespace + +const Macro& CelBlockMacro() { + static const absl::NoDestructor macro(MakeCelBlockMacro()); + return *macro; +} + +const Macro& CelIndexMacro() { + static const absl::NoDestructor macro(MakeCelIndexMacro()); + return *macro; +} + +const Macro& CelIterVarMacro() { + static const absl::NoDestructor macro(MakeCelIterVarMacro()); + return *macro; +} + +const Macro& CelAccuVarMacro() { + static const absl::NoDestructor macro(MakeCelAccuVarMacro()); + return *macro; +} + +absl::Status RegisterTestMacros(MacroRegistry& registry) { + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelBlockMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelIndexMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelIterVarMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelAccuVarMacro())); + return absl::OkStatus(); +} + +} // namespace cel::test diff --git a/testutil/test_macros.h b/testutil/test_macros.h new file mode 100644 index 000000000..cad897999 --- /dev/null +++ b/testutil/test_macros.h @@ -0,0 +1,33 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_MACROS_H_ +#define THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_MACROS_H_ + +#include "absl/status/status.h" +#include "parser/macro.h" +#include "parser/macro_registry.h" + +namespace cel::test { + +const Macro& CelBlockMacro(); +const Macro& CelIndexMacro(); +const Macro& CelIterVarMacro(); +const Macro& CelAccuVarMacro(); + +absl::Status RegisterTestMacros(MacroRegistry& registry); + +} // namespace cel::test + +#endif // THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_MACROS_H_