Skip to content

Commit

Permalink
Checker support for declaring context messages.
Browse files Browse the repository at this point in the history
Add support for declaring a context message type to the C++ type checker.
The checker considers the top level fields of the type as variables in the
type check environment.

PiperOrigin-RevId: 699221362
  • Loading branch information
jnthntatum authored and copybara-github committed Nov 22, 2024
1 parent 241c9dd commit c39a717
Show file tree
Hide file tree
Showing 7 changed files with 336 additions and 8 deletions.
7 changes: 0 additions & 7 deletions checker/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,10 @@ cc_library(
":type_checker",
"//common:decl",
"//common:type",
"//internal:status_macros",
"//parser:macro",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_protobuf//:protobuf",
],
)

Expand Down
24 changes: 23 additions & 1 deletion checker/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,11 @@ cc_library(
"//common:type_kind",
"//internal:status_macros",
"//parser:macro",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand Down Expand Up @@ -189,6 +189,28 @@ cc_test(
],
)

cc_test(
name = "type_checker_builder_impl_test",
srcs = ["type_checker_builder_impl_test.cc"],
deps = [
":test_ast_helpers",
":type_checker_impl",
"//base/ast_internal:ast_impl",
"//base/ast_internal:expr",
"//checker:type_checker",
"//checker:validation_result",
"//common:decl",
"//common:type",
"//internal:testing",
"//internal:testing_descriptor_pool",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:status_matchers",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:optional",
],
)

cc_library(
name = "type_inference_context",
srcs = ["type_inference_context.cc"],
Expand Down
63 changes: 63 additions & 0 deletions checker/internal/type_checker_builder_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/no_destructor.h"
#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "checker/internal/type_check_env.h"
#include "checker/internal/type_checker_impl.h"
#include "checker/type_checker.h"
Expand All @@ -34,6 +37,7 @@
#include "common/type_introspector.h"
#include "internal/status_macros.h"
#include "parser/macro.h"
#include "google/protobuf/descriptor.h"

namespace cel::checker_internal {
namespace {
Expand Down Expand Up @@ -78,8 +82,34 @@ absl::Status CheckStdMacroOverlap(const FunctionDecl& decl) {

} // namespace

absl::Status TypeCheckerBuilderImpl::AddContextDeclarationVariables(
absl::Nonnull<const google::protobuf::Descriptor*> descriptor) {
for (int i = 0; i < descriptor->field_count(); i++) {
const google::protobuf::FieldDescriptor* proto_field = descriptor->field(i);
MessageTypeField cel_field(proto_field);
cel_field.name();
Type field_type = cel_field.GetType();
if (field_type.IsEnum()) {
field_type = IntType();
}
if (!env_.InsertVariableIfAbsent(
MakeVariableDecl(std::string(cel_field.name()), field_type))) {
return absl::AlreadyExistsError(
absl::StrCat("variable '", cel_field.name(),
"' already exists (from context declaration: '",
descriptor->full_name(), "')"));
}
}

return absl::OkStatus();
}

absl::StatusOr<std::unique_ptr<TypeChecker>>
TypeCheckerBuilderImpl::Build() && {
for (const auto* type : context_types_) {
CEL_RETURN_IF_ERROR(AddContextDeclarationVariables(type));
}

auto checker = std::make_unique<checker_internal::TypeCheckerImpl>(
std::move(env_), options_);
return checker;
Expand Down Expand Up @@ -108,6 +138,39 @@ absl::Status TypeCheckerBuilderImpl::AddVariable(const VariableDecl& decl) {
return absl::OkStatus();
}

absl::Status TypeCheckerBuilderImpl::AddContextDeclaration(
absl::string_view type) {
CEL_ASSIGN_OR_RETURN(absl::optional<Type> resolved_type,
env_.LookupTypeName(type));

if (!resolved_type.has_value()) {
return absl::NotFoundError(
absl::StrCat("context declaration '", type, "' not found"));
}

if (!resolved_type->IsStruct()) {
return absl::InvalidArgumentError(
absl::StrCat("context declaration '", type, "' is not a struct"));
}

if (!resolved_type->AsStruct()->IsMessage()) {
return absl::InvalidArgumentError(
absl::StrCat("context declaration '", type,
"' is not protobuf message backed struct"));
}

const google::protobuf::Descriptor* descriptor =
&(**(resolved_type->AsStruct()->AsMessage()));

if (absl::c_linear_search(context_types_, descriptor)) {
return absl::AlreadyExistsError(
absl::StrCat("context declaration '", type, "' already exists"));
}

context_types_.push_back(descriptor);
return absl::OkStatus();
}

absl::Status TypeCheckerBuilderImpl::AddFunction(const FunctionDecl& decl) {
CEL_RETURN_IF_ERROR(CheckStdMacroOverlap(decl));
bool inserted = env_.InsertFunctionIfAbsent(decl);
Expand Down
5 changes: 5 additions & 0 deletions checker/internal/type_checker_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder {
absl::Status AddLibrary(CheckerLibrary library) override;

absl::Status AddVariable(const VariableDecl& decl) override;
absl::Status AddContextDeclaration(absl::string_view type) override;
absl::Status AddFunction(const FunctionDecl& decl) override;

void SetExpectedType(const Type& type) override;
Expand All @@ -71,9 +72,13 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder {
const CheckerOptions& options() const override { return options_; }

private:
absl::Status AddContextDeclarationVariables(
absl::Nonnull<const google::protobuf::Descriptor*> descriptor);

CheckerOptions options_;
std::vector<CheckerLibrary> libraries_;
absl::flat_hash_set<std::string> library_ids_;
std::vector<absl::Nonnull<const google::protobuf::Descriptor*>> context_types_;

checker_internal::TypeCheckEnv env_;
};
Expand Down
212 changes: 212 additions & 0 deletions checker/internal/type_checker_builder_impl_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "checker/internal/type_checker_builder_impl.h"

#include <memory>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "base/ast_internal/ast_impl.h"
#include "base/ast_internal/expr.h"
#include "checker/internal/test_ast_helpers.h"
#include "checker/type_checker.h"
#include "checker/validation_result.h"
#include "common/decl.h"
#include "common/type.h"
#include "common/type_introspector.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"

namespace cel::checker_internal {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::cel::ast_internal::AstImpl;

using AstType = cel::ast_internal::Type;

struct ContextDeclsTestCase {
std::string expr;
AstType expected_type;
};

class ContextDeclsFieldsDefinedTest
: public testing::TestWithParam<ContextDeclsTestCase> {};

TEST_P(ContextDeclsFieldsDefinedTest, ContextDeclsFieldsDefined) {
TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
{});
ASSERT_THAT(
builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"),
IsOk());
ASSERT_OK_AND_ASSIGN(std::unique_ptr<TypeChecker> type_checker,
std::move(builder).Build());
ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(GetParam().expr));
ASSERT_OK_AND_ASSIGN(ValidationResult result,
type_checker->Check(std::move(ast)));

ASSERT_TRUE(result.IsValid());

const auto& ast_impl = AstImpl::CastFromPublicAst(*result.GetAst());

EXPECT_EQ(ast_impl.GetReturnType(), GetParam().expected_type);
}

INSTANTIATE_TEST_SUITE_P(
TestAllTypes, ContextDeclsFieldsDefinedTest,
testing::Values(
ContextDeclsTestCase{"single_int64",
AstType(ast_internal::PrimitiveType::kInt64)},
ContextDeclsTestCase{"single_uint32",
AstType(ast_internal::PrimitiveType::kUint64)},
ContextDeclsTestCase{"single_double",
AstType(ast_internal::PrimitiveType::kDouble)},
ContextDeclsTestCase{"single_string",
AstType(ast_internal::PrimitiveType::kString)},
ContextDeclsTestCase{"single_any",
AstType(ast_internal::WellKnownType::kAny)},
ContextDeclsTestCase{"single_duration",
AstType(ast_internal::WellKnownType::kDuration)},
ContextDeclsTestCase{"single_bool_wrapper",
AstType(ast_internal::PrimitiveTypeWrapper(
ast_internal::PrimitiveType::kBool))},
ContextDeclsTestCase{
"list_value",
AstType(ast_internal::ListType(
std::make_unique<AstType>(ast_internal::DynamicType())))},
ContextDeclsTestCase{
"standalone_message",
AstType(ast_internal::MessageType(
"cel.expr.conformance.proto3.TestAllTypes.NestedMessage"))},
ContextDeclsTestCase{"standalone_enum",
AstType(ast_internal::PrimitiveType::kInt64)},
ContextDeclsTestCase{
"repeated_bytes",
AstType(ast_internal::ListType(std::make_unique<AstType>(
ast_internal::PrimitiveType::kBytes)))},
ContextDeclsTestCase{
"repeated_nested_message",
AstType(ast_internal::ListType(std::make_unique<
AstType>(ast_internal::MessageType(
"cel.expr.conformance.proto3.TestAllTypes.NestedMessage"))))},
ContextDeclsTestCase{
"map_int32_timestamp",
AstType(ast_internal::MapType(
std::make_unique<AstType>(ast_internal::PrimitiveType::kInt64),
std::make_unique<AstType>(
ast_internal::WellKnownType::kTimestamp)))},
ContextDeclsTestCase{
"single_struct",
AstType(ast_internal::MapType(
std::make_unique<AstType>(ast_internal::PrimitiveType::kString),
std::make_unique<AstType>(ast_internal::DynamicType())))}));

TEST(ContextDeclsTest, ErrorOnDuplicateContextDeclaration) {
TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
{});
ASSERT_THAT(
builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"),
IsOk());
EXPECT_THAT(
builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"),
StatusIs(absl::StatusCode::kAlreadyExists,
"context declaration 'cel.expr.conformance.proto3.TestAllTypes' "
"already exists"));
}

TEST(ContextDeclsTest, ErrorOnContextDeclarationNotFound) {
TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
{});
EXPECT_THAT(
builder.AddContextDeclaration("com.example.UnknownType"),
StatusIs(absl::StatusCode::kNotFound,
"context declaration 'com.example.UnknownType' not found"));
}

TEST(ContextDeclsTest, ErrorOnNonStructMessageType) {
TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
{});
EXPECT_THAT(
builder.AddContextDeclaration("google.protobuf.Timestamp"),
StatusIs(
absl::StatusCode::kInvalidArgument,
"context declaration 'google.protobuf.Timestamp' is not a struct"));
}

TEST(ContextDeclsTest, CustomStructNotSupported) {
TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
{});
class MyTypeProvider : public cel::TypeIntrospector {
public:
absl::StatusOr<absl::optional<Type>> FindTypeImpl(
absl::string_view name) const override {
if (name == "com.example.MyStruct") {
return common_internal::MakeBasicStructType("com.example.MyStruct");
}
return absl::nullopt;
}
};

builder.AddTypeProvider(std::make_unique<MyTypeProvider>());

EXPECT_THAT(builder.AddContextDeclaration("com.example.MyStruct"),
StatusIs(absl::StatusCode::kInvalidArgument,
"context declaration 'com.example.MyStruct' is not "
"protobuf message backed struct"));
}

TEST(ContextDeclsTest, ErrorOnOverlappingContextDeclaration) {
TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
{});
ASSERT_THAT(
builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"),
IsOk());
// We resolve the context declaration variables at the Build() call, so the
// error surfaces then.
ASSERT_THAT(
builder.AddContextDeclaration("cel.expr.conformance.proto2.TestAllTypes"),
IsOk());

EXPECT_THAT(
std::move(builder).Build(),
StatusIs(absl::StatusCode::kAlreadyExists,
"variable 'single_int32' already exists (from context "
"declaration: 'cel.expr.conformance.proto2.TestAllTypes')"));
}

TEST(ContextDeclsTest, ErrorOnOverlappingVariableDeclaration) {
TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
{});
ASSERT_THAT(
builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"),
IsOk());
ASSERT_THAT(builder.AddVariable(MakeVariableDecl("single_int64", IntType())),
IsOk());

EXPECT_THAT(
std::move(builder).Build(),
StatusIs(absl::StatusCode::kAlreadyExists,
"variable 'single_int64' already exists (from context "
"declaration: 'cel.expr.conformance.proto3.TestAllTypes')"));
}

} // namespace
} // namespace cel::checker_internal
Loading

0 comments on commit c39a717

Please sign in to comment.