Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion common/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,16 @@ cc_library(
srcs = ["signature.cc"],
hdrs = ["signature.h"],
deps = [
"//common:ast",
"//common:type",
"//common:type_kind",
"//common:type_spec_resolver",
"//internal:status_macros",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@com_google_protobuf//:protobuf",
],
)

Expand All @@ -159,11 +161,14 @@ cc_test(
srcs = ["signature_test.cc"],
deps = [
":signature",
"//common:ast",
"//common:type",
"//common:type_kind",
"//internal:testing",
"//internal:testing_descriptor_pool",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:status_matchers",
"@com_google_absl//absl/status:statusor",
"@com_google_protobuf//:protobuf",
],
Expand Down
332 changes: 332 additions & 0 deletions common/internal/signature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,29 @@
#include "common/internal/signature.h"

#include <cstddef>
#include <cstring>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "common/ast.h"
#include "common/type.h"
#include "common/type_kind.h"
#include "common/type_spec_resolver.h"
#include "internal/status_macros.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"

namespace cel::common_internal {

// Signature generator helper functions.
namespace {

void AppendEscaped(std::string* result, std::string_view str, bool escape_dot) {
Expand Down Expand Up @@ -208,4 +217,327 @@ absl::StatusOr<std::string> MakeOverloadSignature(

return result;
}

// Signature parser helper functions.
namespace {

std::string Unescape(std::string_view str) {
size_t first_escape = str.find('\\');
// NOMUTANTS -- equivalent mutant
if (first_escape == std::string_view::npos) {
return std::string(str);
}
std::string result;
result.reserve(str.size());
result.append(str.substr(0, first_escape));
bool escaped = false;
for (size_t i = first_escape; i < str.size(); ++i) {
char c = str[i];
if (escaped) {
result.push_back(c);
escaped = false;
} else if (c == '\\') {
escaped = true;
} else {
result.push_back(c);
}
}
if (escaped) {
result.push_back('\\');
}
return result;
}

class SignatureScanner {
public:
explicit SignatureScanner(std::string_view input,
std::string_view error_prefix = "Invalid signature")
: input_(input), error_prefix_(error_prefix) {}

absl::StatusOr<size_t> FindTopLevelChar(char target, bool find_last = false) {
size_t found_idx = std::string_view::npos;
int nesting = 0;
bool escaped = false;
// Scanning str for delimiter boundaries while ensuring
// brackets are balanced and escape backslashes are bypassed.
for (size_t i = 0; i < input_.size(); ++i) {
char c = input_[i];
if (escaped) {
escaped = false;
continue;
}
if (c == '\\') {
escaped = true;
continue;
}
if (c == target && nesting == 0) {
if (find_last || found_idx == std::string_view::npos) {
found_idx = i;
}
}
if (c == '<') {
nesting++;
} else if (c == '>') {
nesting--;
if (nesting < 0) {
return absl::InvalidArgumentError(
absl::StrCat(error_prefix_, ": mismatched brackets"));
}
}
}
if (nesting != 0) {
return absl::InvalidArgumentError(
absl::StrCat(error_prefix_, ": mismatched brackets"));
}
return found_idx;
}

absl::StatusOr<std::vector<std::string_view>> SplitTopLevel(char delimiter) {
std::vector<std::string_view> result;
int nesting = 0;
bool escaped = false;
size_t start = 0;
// Scanning str for delimiter while ensuring brackets are balanced and
// escape backslashes are bypassed.
for (size_t i = 0; i < input_.size(); ++i) {
char c = input_[i];
if (escaped) {
escaped = false;
continue;
}
if (c == '\\') {
escaped = true;
continue;
}
if (c == delimiter && nesting == 0) {
result.push_back(input_.substr(start, i - start));
start = i + 1;
}
if (c == '<') {
nesting++;
} else if (c == '>') {
nesting--;
if (nesting < 0) {
return absl::InvalidArgumentError(
absl::StrCat(error_prefix_, ": mismatched brackets"));
}
}
}
if (nesting != 0) {
return absl::InvalidArgumentError(
absl::StrCat(error_prefix_, ": mismatched brackets"));
}
result.push_back(input_.substr(start));
return result;
}

private:
std::string_view input_;
std::string_view error_prefix_;
};

absl::StatusOr<std::vector<std::string_view>> SplitTypeList(
std::string_view params) {
return SignatureScanner(params, "Invalid type signature").SplitTopLevel(',');
}

absl::StatusOr<TypeSpec> ParseTypeSignature(std::string_view signature) {
if (signature.empty()) {
return absl::InvalidArgumentError("Empty type signature");
}

if (signature[0] == '~') {
std::string_view param_name = signature.substr(1);
if (param_name.empty()) {
return absl::InvalidArgumentError(
"Invalid type signature: invalid type parameter name");
}
CEL_ASSIGN_OR_RETURN(size_t less_idx,
SignatureScanner(param_name)
.FindTopLevelChar('<', /*find_last=*/false));
CEL_ASSIGN_OR_RETURN(size_t comma_idx,
SignatureScanner(param_name)
.FindTopLevelChar(',', /*find_last=*/false));
if (less_idx != std::string_view::npos ||
comma_idx != std::string_view::npos) {
return absl::InvalidArgumentError(
"Invalid type signature: invalid type parameter name");
}
return TypeSpec(ParamTypeSpec(Unescape(param_name)));
}

CEL_ASSIGN_OR_RETURN(size_t less_idx,
SignatureScanner(signature, "Invalid type signature")
.FindTopLevelChar('<', /*find_last=*/false));

std::string name_str;
std::vector<TypeSpec> params;

if (less_idx != std::string_view::npos) {
// If the signature contains a '<', it must also contain a matching '>'.
if (signature.back() != '>') {
return absl::InvalidArgumentError(
"Invalid type signature: missing closing >");
}
name_str = Unescape(signature.substr(0, less_idx));
std::string_view params_str =
signature.substr(less_idx + 1, signature.size() - less_idx - 2);
CEL_ASSIGN_OR_RETURN(auto param_list, SplitTypeList(params_str));
for (std::string_view param_str : param_list) {
CEL_ASSIGN_OR_RETURN(auto param, ParseTypeSignature(param_str));
params.push_back(std::move(param));
}
} else {
name_str = Unescape(signature);
}

auto read_param_or_dyn = [&](size_t index) {
auto spec = std::make_unique<TypeSpec>(DynTypeSpec());
if (params.size() > index) {
*spec = std::move(params[index]);
}
return spec;
};

if (name_str == "null") return TypeSpec(NullTypeSpec());
if (name_str == "bool") return TypeSpec(PrimitiveType::kBool);
if (name_str == "int") return TypeSpec(PrimitiveType::kInt64);
if (name_str == "uint") return TypeSpec(PrimitiveType::kUint64);
if (name_str == "double") return TypeSpec(PrimitiveType::kDouble);
if (name_str == "string") return TypeSpec(PrimitiveType::kString);
if (name_str == "bytes") return TypeSpec(PrimitiveType::kBytes);
if (name_str == "any") return TypeSpec(WellKnownTypeSpec::kAny);
if (name_str == "timestamp") return TypeSpec(WellKnownTypeSpec::kTimestamp);
if (name_str == "duration") return TypeSpec(WellKnownTypeSpec::kDuration);
if (name_str == "dyn") return TypeSpec(DynTypeSpec());

// Handle standard Protobuf well-known wrapper types to preserve
// backward compatibility for users migrating YAML configuration files.
if (name_str == "bool_wrapper" || name_str == "google.protobuf.BoolValue")
return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool));
if (name_str == "int_wrapper" || name_str == "google.protobuf.Int64Value" ||
name_str == "google.protobuf.Int32Value")
return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64));
if (name_str == "uint_wrapper" || name_str == "google.protobuf.UInt64Value" ||
name_str == "google.protobuf.UInt32Value")
return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64));
if (name_str == "double_wrapper" ||
name_str == "google.protobuf.DoubleValue" ||
name_str == "google.protobuf.FloatValue")
return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble));
if (name_str == "string_wrapper" || name_str == "google.protobuf.StringValue")
return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString));
if (name_str == "bytes_wrapper" || name_str == "google.protobuf.BytesValue")
return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes));

if (name_str == "type") {
if (params.size() > 1) {
return absl::InvalidArgumentError(
"Invalid type signature: type expects at most 1 parameter");
}
return TypeSpec(read_param_or_dyn(0));
}

if (name_str == "list") {
if (params.size() > 1) {
return absl::InvalidArgumentError(
"Invalid type signature: list expects at most 1 parameter");
}
return TypeSpec(ListTypeSpec(read_param_or_dyn(0)));
}

if (name_str == "map") {
if (!params.empty() && params.size() != 2) {
return absl::InvalidArgumentError(
"Invalid type signature: map expects 0 or 2 parameters");
}
auto key = read_param_or_dyn(0);
auto value = read_param_or_dyn(1);
return TypeSpec(MapTypeSpec(std::move(key), std::move(value)));
}

if (name_str == "function") {
auto result_type = read_param_or_dyn(0);
std::vector<TypeSpec> arg_types;
for (size_t i = 1; i < params.size(); ++i) {
arg_types.push_back(std::move(params[i]));
}
return TypeSpec(
FunctionTypeSpec(std::move(result_type), std::move(arg_types)));
}

if (name_str.empty() || absl::StrContains(name_str, "..")) {
return absl::InvalidArgumentError(
"Invalid type signature: invalid identifier");
}

return TypeSpec(AbstractType(name_str, std::move(params)));
}

} // namespace

absl::StatusOr<ParsedFunctionOverload> ParseFunctionSignature(
std::string_view signature) {
if (signature.empty()) {
return absl::InvalidArgumentError("Empty function signature");
}

CEL_ASSIGN_OR_RETURN(size_t paren_idx,
SignatureScanner(signature, "Invalid function signature")
.FindTopLevelChar('(', /*find_last=*/false));

if (paren_idx == std::string_view::npos || signature.back() != ')') {
return absl::InvalidArgumentError("Invalid function signature");
}

std::string_view prefix = signature.substr(0, paren_idx);
std::string_view args_str =
signature.substr(paren_idx + 1, signature.size() - paren_idx - 2);

std::vector<TypeSpec> arg_types;
ParsedFunctionOverload out;

CEL_ASSIGN_OR_RETURN(size_t dot_idx,
SignatureScanner(prefix, "Invalid function signature")
.FindTopLevelChar('.', /*find_last=*/true));

if (dot_idx != std::string_view::npos) {
out.is_member = true;
std::string_view receiver_str = prefix.substr(0, dot_idx);
std::string_view func_str = prefix.substr(dot_idx + 1);

CEL_ASSIGN_OR_RETURN(auto receiver_param, ParseTypeSignature(receiver_str));
arg_types.push_back(std::move(receiver_param));
out.function_name = Unescape(func_str);
} else {
out.is_member = false;
out.function_name = Unescape(prefix);
}

if (out.function_name.empty()) {
return absl::InvalidArgumentError(
"Invalid function signature: empty function name");
}

if (!args_str.empty()) {
CEL_ASSIGN_OR_RETURN(auto arg_list, SplitTypeList(args_str));
for (std::string_view arg_str : arg_list) {
CEL_ASSIGN_OR_RETURN(auto arg_param, ParseTypeSignature(arg_str));
arg_types.push_back(std::move(arg_param));
}
}

auto result_type = std::make_unique<TypeSpec>(DynTypeSpec());
out.signature_type =
TypeSpec(FunctionTypeSpec(std::move(result_type), std::move(arg_types)));

return out;
}

absl::StatusOr<Type> ParseType(std::string_view signature, google::protobuf::Arena* arena,
const google::protobuf::DescriptorPool& pool) {
CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSignature(signature));
return cel::ConvertTypeSpecToType(type_spec, arena, pool);
}

} // namespace cel::common_internal
Loading