Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions checker/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ cc_library(
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand All @@ -179,6 +180,7 @@ cc_test(
"//checker:type_checker_builder",
"//checker:validation_result",
"//common:ast",
"//common:ast_proto",
"//common:container",
"//common:decl",
"//common:expr",
Expand All @@ -187,7 +189,10 @@ cc_test(
"//internal:status_macros",
"//internal:testing",
"//internal:testing_descriptor_pool",
"//parser",
"//parser:macro_registry",
"//testutil:baseline_tests",
"//testutil:test_macros",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/container:flat_hash_set",
Expand Down
77 changes: 72 additions & 5 deletions checker/internal/type_checker_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
Expand Down Expand Up @@ -59,6 +60,15 @@
namespace cel::checker_internal {
namespace {

bool MatchesBlock(const Expr& expr) {
if (!expr.has_call_expr()) {
return false;
}
const auto& call = expr.call_expr();
return call.function() == "cel.@block" && call.args().size() == 2 &&
call.args()[0].has_list_expr();
}

using AstType = cel::TypeSpec;
using Severity = TypeCheckIssue::Severity;

Expand Down Expand Up @@ -204,13 +214,23 @@ class ResolveVisitor : public AstVisitorBase {
arena_(arena),
current_scope_(&root_scope_) {}

void PreVisitExpr(const Expr& expr) override { expr_stack_.push_back(&expr); }
void PreVisitExpr(const Expr& expr) override {
expr_stack_.push_back(&expr);
if (expr_stack_.size() == 1 && MatchesBlock(expr)) {
ABSL_DCHECK_EQ(expr.call_expr().args().size(), 2);
ABSL_DCHECK(block_init_list_ == nullptr);
block_init_list_ = &expr.call_expr().args()[0];
}
}

void PostVisitExpr(const Expr& expr) override {
if (expr_stack_.empty()) {
return;
}
expr_stack_.pop_back();
if (expr_stack_.size() == 2 && expr_stack_.back() == block_init_list_) {
HandleBlockIndex(&expr);
}
}

void PostVisitConst(const Expr& expr, const Constant& constant) override;
Expand Down Expand Up @@ -389,6 +409,7 @@ class ResolveVisitor : public AstVisitorBase {
absl::string_view field_name);

void HandleOptSelect(const Expr& expr);
void HandleBlockIndex(const Expr* expr);

// Get the assigned type of the given subexpression. Should only be called if
// the given subexpression is expected to have already been checked.
Expand Down Expand Up @@ -421,6 +442,7 @@ class ResolveVisitor : public AstVisitorBase {
std::vector<const Expr*> expr_stack_;
absl::flat_hash_map<const Expr*, std::vector<std::string>>
maybe_namespaced_functions_;
const Expr* block_init_list_ = nullptr;
// Select operations that need to be resolved outside of the traversal.
// These are handled separately to disambiguate between namespaces and field
// accesses
Expand Down Expand Up @@ -609,8 +631,15 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) {
}

void ResolveVisitor::PostVisitList(const Expr& expr, const ListExpr& list) {
// Follows list type inferencing behavior in Go (see map comments above).
if (&expr == block_init_list_) {
// Don't try to coalesce list type here because it can influence the
// resolved type of the list elements. cel.@block is always list<dyn> and
// the elements are treated independently at runtime.
types_[&expr] = ListType();
return;
}

// Follows list type inferencing behavior in Go (see map comments above).
Type overall_elem_type =
inference_context_->InstantiateTypeParams(TypeParamType("E"));
auto assignability_context = inference_context_->CreateAssignabilityContext();
Expand Down Expand Up @@ -1172,6 +1201,44 @@ void ResolveVisitor::HandleOptSelect(const Expr& expr) {
}
}

void ResolveVisitor::HandleBlockIndex(const Expr* expr) {
ABSL_DCHECK(block_init_list_ != nullptr);
ABSL_DCHECK(block_init_list_->has_list_expr());
const auto& elements = block_init_list_->list_expr().elements();
int index = -1;
for (size_t i = 0; i < elements.size(); ++i) {
if (&elements[i].expr() == expr) {
index = i;
break;
}
}
if (index < 0) {
status_.Update(absl::InternalError(
"could not resolve expression as a cel.@block subexpression"));
return;
}
std::string var_name = absl::StrCat("@index", index);

// Block is typically manually assembled from logically separate
// expressions so fix the type instead of inferring any remaining free type
// params as for normal subexpressions.
auto type = inference_context_->FinalizeType(GetDeducedType(expr));

VariableDecl decl = MakeVariableDecl(var_name, std::move(type));

// The C++ runtime requires that the indexes are topologically ordered.
// They just come into scope in order as we walk the AST so we don't need
// to do any additional work to check references to other initializers in
// an init expr.
//
// We may want to relax this just requiring that the references are
// acyclic as in the Java implementation.
auto* scope =
comprehension_vars_.emplace_back(current_scope_->MakeNestedScope()).get();
scope->InsertVariableIfAbsent(std::move(decl));
current_scope_ = scope;
}

class ResolveRewriter : public AstRewriterBase {
public:
explicit ResolveRewriter(const ResolveVisitor& visitor,
Expand Down Expand Up @@ -1230,15 +1297,15 @@ class ResolveRewriter : public AstRewriterBase {

if (auto iter = visitor_.types().find(&expr);
iter != visitor_.types().end()) {
auto flattened_type =
FlattenType(inference_context_.FinalizeType(iter->second));
cel::Type finalized_type = inference_context_.FinalizeType(iter->second);
auto flattened_type = FlattenType(finalized_type);

if (!flattened_type.ok()) {
status_.Update(flattened_type.status());
return rewritten;
}
type_map_[expr.id()] = *std::move(flattened_type);
resolved_types_[expr.id()] = iter->second;
resolved_types_[expr.id()] = finalized_type;
rewritten = true;
}

Expand Down
95 changes: 95 additions & 0 deletions checker/internal/type_checker_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "checker/type_checker_builder.h"
#include "checker/validation_result.h"
#include "common/ast.h"
#include "common/ast_proto.h"
#include "common/container.h"
#include "common/decl.h"
#include "common/expr.h"
Expand All @@ -45,7 +46,10 @@
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "parser/macro_registry.h"
#include "parser/parser.h"
#include "testutil/baseline_tests.h"
#include "testutil/test_macros.h"
#include "cel/expr/conformance/proto2/test_all_types.pb.h"
#include "cel/expr/conformance/proto3/test_all_types.pb.h"
#include "google/protobuf/arena.h"
Expand Down Expand Up @@ -108,6 +112,17 @@ google::protobuf::Arena* absl_nonnull TestTypeArena() {
return &(*kArena);
}

absl::StatusOr<std::unique_ptr<Ast>> MakeTestParsedAstWithMacros(
absl::string_view expression, const cel::MacroRegistry& registry) {
CEL_ASSIGN_OR_RETURN(
auto source,
cel::NewSource(expression, /*description=*/std::string(expression)));
CEL_ASSIGN_OR_RETURN(auto parsed_expr, google::api::expr::parser::Parse(
*source, registry,
{.enable_optional_syntax = true}));
return cel::CreateAstFromParsedExpr(parsed_expr);
}

FunctionDecl MakeIdentFunction() {
auto decl = MakeFunctionDecl(
"identity",
Expand Down Expand Up @@ -272,6 +287,13 @@ absl::Status RegisterMinimalBuiltins(google::protobuf::Arena* absl_nonnull arena
/*return_type=*/TypeType(arena, TypeParamType("A")),
TypeParamType("A"))));

Type kParam(TypeParamType("T"));
CEL_ASSIGN_OR_RETURN(
auto block_decl,
MakeFunctionDecl("cel.@block", MakeOverloadDecl("cel_block_list", kParam,
ListType(), kParam)));
env.InsertFunctionIfAbsent(std::move(block_decl));

env.InsertFunctionIfAbsent(std::move(not_op));
env.InsertFunctionIfAbsent(std::move(not_strictly_false));
env.InsertFunctionIfAbsent(std::move(add_op));
Expand All @@ -289,6 +311,7 @@ absl::Status RegisterMinimalBuiltins(google::protobuf::Arena* absl_nonnull arena
env.InsertFunctionIfAbsent(std::move(to_type));
env.InsertFunctionIfAbsent(std::move(to_duration));
env.InsertFunctionIfAbsent(std::move(to_timestamp));
env.InsertFunctionIfAbsent(std::move(block_decl));

return absl::OkStatus();
}
Expand All @@ -308,6 +331,78 @@ TEST(TypeCheckerImplTest, SmokeTest) {
EXPECT_THAT(result.GetIssues(), IsEmpty());
}

TEST(TypeCheckerImplTest, BlockMacroSupport) {
TypeCheckEnv env(GetSharedTestingDescriptorPool());

google::protobuf::Arena arena;
ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());

MacroRegistry registry;
ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk());

TypeCheckerImpl impl(std::move(env));
ASSERT_OK_AND_ASSIGN(
auto ast,
MakeTestParsedAstWithMacros(
"cel.block([1, 2], cel.index(0) + cel.index(1))", registry));
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));

EXPECT_TRUE(result.IsValid());
EXPECT_THAT(result.GetIssues(), IsEmpty());

// Overall type should be int.
ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst());
auto root_id = checked_ast->root_expr().id();
EXPECT_EQ(checked_ast->type_map().at(root_id).primitive(),
PrimitiveType::kInt64);
}

TEST(TypeCheckerImplTest, BlockMacroSupportMixedTypes) {
TypeCheckEnv env(GetSharedTestingDescriptorPool());

google::protobuf::Arena arena;
ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());

MacroRegistry registry;
ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk());

TypeCheckerImpl impl(std::move(env));
ASSERT_OK_AND_ASSIGN(
auto ast, MakeTestParsedAstWithMacros("cel.block([1, 'a'], cel.index(1))",
registry));
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));

EXPECT_TRUE(result.IsValid());
EXPECT_THAT(result.GetIssues(), IsEmpty());

// cel.index(1) refers to 'a' which is string.
// So overall type should be string.
ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst());
auto root_id = checked_ast->root_expr().id();
EXPECT_EQ(checked_ast->type_map().at(root_id).primitive(),
PrimitiveType::kString);
}

TEST(TypeCheckerImplTest, BadIndex) {
TypeCheckEnv env(GetSharedTestingDescriptorPool());

google::protobuf::Arena arena;
ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());

MacroRegistry registry;
ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk());

TypeCheckerImpl impl(std::move(env));
ASSERT_OK_AND_ASSIGN(
auto ast, MakeTestParsedAstWithMacros("cel.block([1, 'a'], cel.index(2))",
registry));
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));

EXPECT_FALSE(result.IsValid());
EXPECT_THAT(result.FormatError(),
HasSubstr("undeclared reference to '@index2' (in container"));
}

TEST(TypeCheckerImplTest, SimpleIdentsResolved) {
TypeCheckEnv env(GetSharedTestingDescriptorPool());

Expand Down
2 changes: 2 additions & 0 deletions compiler/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ cc_library(
hdrs = ["compiler_factory.h"],
deps = [
":compiler",
"//checker:type_check_issue",
"//checker:type_checker",
"//checker:type_checker_builder",
"//checker:type_checker_builder_factory",
"//checker:validation_result",
"//common:ast",
"//common:source",
"//internal:noop_delete",
"//internal:status_macros",
Expand Down
24 changes: 22 additions & 2 deletions compiler/compiler_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "checker/type_check_issue.h"
#include "checker/type_checker.h"
#include "checker/type_checker_builder.h"
#include "checker/type_checker_builder_factory.h"
#include "checker/validation_result.h"
#include "common/ast.h"
#include "common/source.h"
#include "compiler/compiler.h"
#include "internal/status_macros.h"
Expand Down Expand Up @@ -55,9 +58,26 @@ class CompilerImpl : public Compiler {
google::protobuf::Arena* arena) const override {
CEL_ASSIGN_OR_RETURN(auto source,
cel::NewSource(expression, std::string(description)));
CEL_ASSIGN_OR_RETURN(auto ast, parser_->Parse(*source));
std::vector<cel::ParseIssue> parse_issues;
absl::StatusOr<std::unique_ptr<cel::Ast>> ast =
parser_->Parse(*source, &parse_issues);
if (!ast.ok()) {
if (ast.status().code() != absl::StatusCode::kInvalidArgument ||
parse_issues.empty()) {
return ast.status();
}
std::vector<TypeCheckIssue> check_issues;
check_issues.reserve(parse_issues.size());
for (const auto& issue : parse_issues) {
check_issues.push_back(TypeCheckIssue::CreateError(
issue.location(), std::string(issue.message())));
}
ValidationResult result(std::move(check_issues));
result.SetSource(std::move(source));
return result;
}
CEL_ASSIGN_OR_RETURN(ValidationResult result,
type_checker_->Check(std::move(ast), arena));
type_checker_->Check(*std::move(ast), arena));

result.SetSource(std::move(source));
if (!validator_.validations().empty()) {
Expand Down
12 changes: 12 additions & 0 deletions compiler/compiler_factory_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,5 +413,17 @@ TEST(CompilerFactoryTest, SpecifyArenaKeepsResolvedTypes) {
it->second.GetOptional().GetParameter().GetList().GetElement().IsInt());
}

TEST(CompilerFactoryTest, ReturnsIssuesFromParser) {
ASSERT_OK_AND_ASSIGN(
auto builder,
NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool()));

ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build());

ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile("a +"));
EXPECT_FALSE(result.IsValid());
EXPECT_THAT(result.GetIssues(), testing::Not(testing::IsEmpty()));
}

} // namespace
} // namespace cel
Loading