diff --git a/eval/public/cel_options.cc b/eval/public/cel_options.cc index 8ca3c02f8..e0c8e1a4b 100644 --- a/eval/public/cel_options.cc +++ b/eval/public/cel_options.cc @@ -40,7 +40,8 @@ cel::RuntimeOptions ConvertToRuntimeOptions(const InterpreterOptions& options) { options.enable_lazy_bind_initialization, options.max_recursion_depth, options.enable_recursive_tracing, - options.enable_fast_builtins}; + options.enable_fast_builtins, + options.locale}; } } // namespace google::api::expr::runtime diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h index 62946ec91..f88f7a38a 100644 --- a/eval/public/cel_options.h +++ b/eval/public/cel_options.h @@ -17,6 +17,8 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_OPTIONS_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_OPTIONS_H_ +#include + #include "absl/base/attributes.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" @@ -196,6 +198,11 @@ struct InterpreterOptions { // // Currently applies to !_, @not_strictly_false, _==_, _!=_, @in bool enable_fast_builtins = false; + + // The locale to use for string formatting. + // + // Default is en_US. + std::string locale = "en_US"; }; // LINT.ThenChange(//depot/google3/runtime/runtime_options.h) diff --git a/extensions/BUILD b/extensions/BUILD index a38a4637f..b59ed8483 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -442,6 +442,7 @@ cc_library( srcs = ["strings.cc"], hdrs = ["strings.h"], deps = [ + ":formatting", "//checker:type_checker_builder", "//checker/internal:builtins_arena", "//common:casting", @@ -578,3 +579,59 @@ cc_test( "@com_google_absl//absl/status:status_matchers", ], ) + +cc_library( + name = "formatting", + srcs = ["formatting.cc"], + hdrs = ["formatting.h"], + deps = [ + "//common:value", + "//common:value_kind", + "//internal:status_macros", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@icu4c", + ], +) + +cc_test( + name = "formatting_test", + srcs = ["formatting_test.cc"], + deps = [ + ":formatting", + "//common:allocator", + "//common:value", + "//extensions/protobuf:runtime_adapter", + "//internal:parse_text_proto", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//internal:testing_message_factory", + "//parser", + "//parser:options", + "//runtime", + "//runtime:activation", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/extensions/formatting.cc b/extensions/formatting.cc new file mode 100644 index 000000000..0002dc10e --- /dev/null +++ b/extensions/formatting.cc @@ -0,0 +1,576 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/formatting.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/container/btree_map.h" +#include "absl/memory/memory.h" +#include "absl/numeric/bits.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/escaping.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "common/value_manager.h" +#include "internal/status_macros.h" +#include "runtime/function_adapter.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "unicode/decimfmt.h" +#include "unicode/errorcode.h" +#include "unicode/locid.h" +#include "unicode/numfmt.h" + +namespace cel::extensions { + +namespace { + +absl::StatusOr FormatString( + ValueManager& value_manager, const Value& value, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND); + +absl::StatusOr FormatFixed( + const Value& value, std::optional precision, const icu::Locale& locale, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND); + +absl::StatusOr>> ParsePrecision( + absl::string_view format) { + if (format[0] != '.') return std::pair{0, std::nullopt}; + + int64_t i = 1; + while (i < std::ssize(format) && absl::ascii_isdigit(format[i])) { + ++i; + } + if (i == std::ssize(format)) { + return absl::InvalidArgumentError( + "Unable to find end of precision specifier"); + } + int precision; + if (!absl::SimpleAtoi(format.substr(1, i - 1), &precision)) { + return absl::InvalidArgumentError( + "Unable to convert precision specifier to integer"); + } + return std::pair{i, precision}; +} + +absl::StatusOr> CreateDoubleNumberFormater( + std::optional min_precision, std::optional max_precision, + bool use_scientific_notation, const icu::Locale& locale) { + icu::ErrorCode error_code; // NOLINT + auto formatter = + absl::WrapUnique(icu::NumberFormat::createInstance(locale, error_code)); + if (formatter == nullptr || error_code.isFailure()) { + return absl::InternalError( + absl::StrCat("Failed to create localized number formatter: ", + error_code.errorName())); + } + formatter->setMinimumIntegerDigits(1); + static constexpr int kDefaultPrecision = 6; + formatter->setMinimumFractionDigits( + min_precision.value_or(kDefaultPrecision)); + formatter->setMaximumFractionDigits( + max_precision.value_or(kDefaultPrecision)); + + if (use_scientific_notation) { + auto dec_fmt = static_cast(formatter.get()); + dec_fmt->setExponentSignAlwaysShown(true); + dec_fmt->setMinimumExponentDigits(2); + } + return formatter; +} + +absl::StatusOr FormatDouble( + double value, std::optional min_precision, + std::optional max_precision, bool use_scientific_notation, + absl::string_view unit, const icu::Locale& locale, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + CEL_ASSIGN_OR_RETURN(auto formatter, CreateDoubleNumberFormater( + min_precision, max_precision, + use_scientific_notation, locale)); + icu::ErrorCode error_code; // NOLINT + icu::UnicodeString output; + formatter->format(value, output, error_code); + + if (error_code.isSuccess()) { + scratch.clear(); + output.toUTF8String(scratch); + absl::StrAppend(&scratch, unit); + return scratch; + } else { + return absl::InternalError(absl::StrCat("Failed to format fixed number: ", + error_code.errorName())); + } +} + +void StrAppendQuoted(ValueKind kind, absl::string_view value, + std::string& target) { + switch (kind) { + case ValueKind::kBytes: + target.push_back('b'); + [[fallthrough]]; + case ValueKind::kString: + target.push_back('\"'); + for (char c : value) { + if (c == '\\' || c == '\"') { + target.push_back('\\'); + } + target.push_back(c); + } + target.push_back('\"'); + break; + case ValueKind::kTimestamp: + absl::StrAppend(&target, "timestamp(\"", value, "\")"); + break; + case ValueKind::kDuration: + absl::StrAppend(&target, "duration(\"", value, "\")"); + break; + case ValueKind::kDouble: + if (value == "NaN") { + absl::StrAppend(&target, "\"NaN\""); + } else if (value == "+Inf") { + absl::StrAppend(&target, "\"+Inf\""); + } else if (value == "-Inf") { + absl::StrAppend(&target, "\"-Inf\""); + } else { + absl::StrAppend(&target, value); + } + break; + default: + absl::StrAppend(&target, value); + break; + } +} + +absl::StatusOr FormatList( + ValueManager& value_manager, const Value& value, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + CEL_ASSIGN_OR_RETURN(auto it, value.GetList().NewIterator(value_manager)); + scratch.clear(); + scratch.push_back('['); + std::string value_scratch; + + while (it->HasNext()) { + CEL_ASSIGN_OR_RETURN(auto next, it->Next(value_manager)); + absl::string_view next_str; + CEL_ASSIGN_OR_RETURN(next_str, + FormatString(value_manager, next, value_scratch)); + StrAppendQuoted(next.kind(), next_str, scratch); + absl::StrAppend(&scratch, ", "); + } + if (scratch.size() > 1) { + scratch.resize(scratch.size() - 2); + } + scratch.push_back(']'); + return scratch; +} + +absl::StatusOr FormatMap( + ValueManager& value_manager, const Value& value, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + absl::btree_map value_map; + std::string value_scratch; + CEL_RETURN_IF_ERROR(value.GetMap().ForEach( + value_manager, + [&](const Value& key, const Value& value) -> absl::StatusOr { + if (key.kind() != ValueKind::kString && + key.kind() != ValueKind::kBool && key.kind() != ValueKind::kInt && + key.kind() != ValueKind::kUint) { + return absl::InvalidArgumentError( + absl::StrCat("Map keys must be strings, booleans, integers, or " + "unsigned integers, was given ", + key.GetTypeName())); + } + CEL_ASSIGN_OR_RETURN(auto key_str, + FormatString(value_manager, key, value_scratch)); + std::string quoted_key_str; + StrAppendQuoted(key.kind(), key_str, quoted_key_str); + value_map.emplace(std::move(quoted_key_str), value); + return true; + })); + + scratch.clear(); + scratch.push_back('{'); + for (const auto& [key, value] : value_map) { + CEL_ASSIGN_OR_RETURN(auto value_str, + FormatString(value_manager, value, value_scratch)); + absl::StrAppend(&scratch, key, ":"); + StrAppendQuoted(value.kind(), value_str, scratch); + absl::StrAppend(&scratch, ", "); + } + if (scratch.size() > 1) { + scratch.resize(scratch.size() - 2); + } + scratch.push_back('}'); + return scratch; +} + +absl::StatusOr FormatString( + ValueManager& value_manager, const Value& value, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + switch (value.kind()) { + case ValueKind::kList: + return FormatList(value_manager, value, scratch); + case ValueKind::kMap: + return FormatMap(value_manager, value, scratch); + case ValueKind::kString: + return value.GetString().NativeString(scratch); + case ValueKind::kBytes: + return value.GetBytes().NativeString(scratch); + case ValueKind::kNull: + return "null"; + case ValueKind::kInt: + scratch.clear(); + absl::StrAppend(&scratch, value.GetInt().NativeValue()); + return scratch; + case ValueKind::kUint: + scratch.clear(); + absl::StrAppend(&scratch, value.GetUint().NativeValue()); + return scratch; + case ValueKind::kDouble: { + auto number = value.GetDouble().NativeValue(); + if (std::isnan(number)) { + return "NaN"; + } + if (number == std::numeric_limits::infinity()) { + return "+Inf"; + } + if (number == -std::numeric_limits::infinity()) { + return "-Inf"; + } + scratch.clear(); + absl::StrAppend(&scratch, number); + return scratch; + } + case ValueKind::kTimestamp: + scratch.clear(); + absl::StrAppend(&scratch, value.DebugString()); + return scratch; + case ValueKind::kDuration: + return FormatDouble(absl::ToDoubleSeconds(value.GetDuration()), + /*min_precision=*/0, /*max_precision=*/9, + /*use_scientific_notation=*/false, + /*unit=*/"s", icu::Locale::getDefault(), scratch); + case ValueKind::kBool: + if (value.GetBool().NativeValue()) { + return "true"; + } + return "false"; + case ValueKind::kType: + return value.GetType().name(); + default: + return absl::InvalidArgumentError(absl::StrFormat( + "Could not convert argument %s to string", value.GetTypeName())); + } +} + +absl::StatusOr FormatDecimal( + const Value& value, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + scratch.clear(); + switch (value.kind()) { + case ValueKind::kInt: + absl::StrAppend(&scratch, value.GetInt().NativeValue()); + return scratch; + case ValueKind::kUint: + absl::StrAppend(&scratch, value.GetUint().NativeValue()); + return scratch; + default: + return absl::InvalidArgumentError(absl::StrCat( + "Decimal clause can only be used on integers, was given ", + value.GetTypeName())); + } +} + +absl::StatusOr FormatBinary( + const Value& value, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + decltype(value.GetUint().NativeValue()) unsigned_value; + bool sign_bit = false; + switch (value.kind()) { + case ValueKind::kInt: { + auto tmp = value.GetInt().NativeValue(); + if (tmp < 0) { + sign_bit = true; + // Negating min int is undefined behavior, so we need to use unsigned + // arithmetic. + using unsigned_type = std::make_unsigned::type; + unsigned_value = -static_cast(tmp); + } else { + unsigned_value = tmp; + } + break; + } + case ValueKind::kUint: + unsigned_value = value.GetUint().NativeValue(); + break; + case ValueKind::kBool: + if (value.GetBool().NativeValue()) { + return "1"; + } + return "0"; + default: + return absl::InvalidArgumentError(absl::StrCat( + "Binary clause can only be used on integers and bools, was given ", + value.GetTypeName())); + } + + if (unsigned_value == 0) { + return "0"; + } + + int size = absl::bit_width(unsigned_value) + sign_bit; + scratch.resize(size); + for (int i = size - 1; i >= 0; --i) { + if (unsigned_value & 1) { + scratch[i] = '1'; + } else { + scratch[i] = '0'; + } + unsigned_value >>= 1; + } + if (sign_bit) { + scratch[0] = '-'; + } + return scratch; +} + +absl::StatusOr FormatHex( + const Value& value, bool use_upper_case, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + switch (value.kind()) { + case ValueKind::kString: + scratch = absl::BytesToHexString(value.GetString().NativeString(scratch)); + break; + case ValueKind::kBytes: + scratch = absl::BytesToHexString(value.GetBytes().NativeString(scratch)); + break; + case ValueKind::kInt: { + // Golang supports signed hex, but absl::StrFormat does not. To be + // compatible, we need to add a leading '-' if the value is negative. + auto tmp = value.GetInt().NativeValue(); + if (tmp < 0) { + // Negating min int is undefined behavior, so we need to use unsigned + // arithmetic. + using unsigned_type = std::make_unsigned::type; + scratch = absl::StrFormat("-%x", -static_cast(tmp)); + } else { + scratch = absl::StrFormat("%x", tmp); + } + break; + } + case ValueKind::kUint: + scratch = absl::StrFormat("%x", value.GetUint().NativeValue()); + break; + default: + return absl::InvalidArgumentError( + absl::StrCat("Hex clause can only be used on integers, byte buffers, " + "and strings, was given ", + value.GetTypeName())); + } + if (use_upper_case) { + absl::AsciiStrToUpper(&scratch); + } + return scratch; +} + +absl::StatusOr FormatOctal( + const Value& value, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + switch (value.kind()) { + case ValueKind::kInt: { + // Golang supports signed octals, but absl::StrFormat does not. To be + // compatible, we need to add a leading '-' if the value is negative. + auto tmp = value.GetInt().NativeValue(); + if (tmp < 0) { + // Negating min int is undefined behavior, so we need to use unsigned + // arithmetic. + using unsigned_type = std::make_unsigned::type; + scratch = absl::StrFormat("-%o", -static_cast(tmp)); + } else { + scratch = absl::StrFormat("%o", tmp); + } + return scratch; + } + case ValueKind::kUint: + scratch = absl::StrFormat("%o", value.GetUint().NativeValue()); + return scratch; + default: + return absl::InvalidArgumentError( + absl::StrCat("Octal clause can only be used on integers, was given ", + value.GetTypeName())); + } +} + +absl::StatusOr GetDouble(const Value& value, std::string& scratch) { + if (value.kind() == ValueKind::kString) { + auto str = value.GetString().NativeString(scratch); + if (str == "NaN") { + return std::nan(""); + } else if (str == "Infinity") { + return std::numeric_limits::infinity(); + } else if (str == "-Infinity") { + return -std::numeric_limits::infinity(); + } else { + return absl::InvalidArgumentError( + absl::StrCat("Only \"NaN\", \"Infinity\", and \"-Infinity\" are " + "supported for conversion to double: ", + str)); + } + } + if (value.kind() != ValueKind::kDouble) { + return absl::InvalidArgumentError( + absl::StrCat("Expected a double but got a ", value.GetTypeName())); + } + return value.GetDouble().NativeValue(); +} + +absl::StatusOr FormatFixed( + const Value& value, std::optional precision, const icu::Locale& locale, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + CEL_ASSIGN_OR_RETURN(auto number, GetDouble(value, scratch)); + return FormatDouble(number, precision, precision, + /*use_scientific_notation=*/false, /*unit=*/"", locale, + scratch); +} + +absl::StatusOr FormatScientific( + const Value& value, std::optional precision, const icu::Locale& locale, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + CEL_ASSIGN_OR_RETURN(auto number, GetDouble(value, scratch)); + return FormatDouble(number, precision, precision, + /*use_scientific_notation=*/true, /*unit=*/"", locale, + scratch); +} + +absl::StatusOr> ParseAndFormatClause( + ValueManager& value_manager, absl::string_view format, const Value& value, + const icu::Locale& locale, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + CEL_ASSIGN_OR_RETURN(auto precision_pair, ParsePrecision(format)); + auto [read, precision] = precision_pair; + switch (format[read]) { + case 's': { + CEL_ASSIGN_OR_RETURN(auto result, + FormatString(value_manager, value, scratch)); + return std::pair{read, result}; + } + case 'd': { + CEL_ASSIGN_OR_RETURN(auto result, FormatDecimal(value, scratch)); + return std::pair{read, result}; + } + case 'f': { + CEL_ASSIGN_OR_RETURN(auto result, + FormatFixed(value, precision, locale, scratch)); + return std::pair{read, result}; + } + case 'e': { + CEL_ASSIGN_OR_RETURN(auto result, + FormatScientific(value, precision, locale, scratch)); + return std::pair{read, result}; + } + case 'b': { + CEL_ASSIGN_OR_RETURN(auto result, FormatBinary(value, scratch)); + return std::pair{read, result}; + } + case 'x': + case 'X': { + CEL_ASSIGN_OR_RETURN( + auto result, + FormatHex(value, + /*use_upper_case=*/format[read] == 'X', scratch)); + return std::pair{read, result}; + } + case 'o': { + CEL_ASSIGN_OR_RETURN(auto result, FormatOctal(value, scratch)); + return std::pair{read, result}; + } + default: + return absl::InvalidArgumentError(absl::StrFormat( + "Unrecognized formatting clause \"%c\"", format[read])); + } +} + +absl::StatusOr Format(ValueManager& value_manager, + const StringValue& format_value, + const ListValue& args, const icu::Locale& locale) { + std::string format_scratch, clause_scratch; + absl::string_view format = format_value.NativeString(format_scratch); + std::string result; + result.reserve(format.size()); + int64_t arg_index = 0; + CEL_ASSIGN_OR_RETURN(int64_t args_size, args.Size()); + for (int64_t i = 0; i < std::ssize(format); ++i) { + if (format[i] != '%') { + result.push_back(format[i]); + continue; + } + ++i; + if (i >= std::ssize(format)) { + return absl::InvalidArgumentError("Unexpected end of format string"); + } + if (format[i] == '%') { + result.push_back('%'); + continue; + } + if (arg_index >= args_size) { + return absl::InvalidArgumentError( + absl::StrFormat("Index %d out of range", arg_index)); + } + CEL_ASSIGN_OR_RETURN(auto value, args.Get(value_manager, arg_index++)); + CEL_ASSIGN_OR_RETURN(auto clause, + ParseAndFormatClause(value_manager, format.substr(i), + value, locale, clause_scratch)); + absl::StrAppend(&result, clause.second); + i += clause.first; + } + return value_manager.CreateUncheckedStringValue(std::move(result)); +} + +} // namespace + +absl::Status RegisterStringFormattingFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + auto locale = icu::Locale::createCanonical(options.locale.c_str()); + if (locale.isBogus() || absl::string_view(locale.getISO3Language()).empty()) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse locale: ", options.locale)); + } + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, StringValue, ListValue>:: + CreateDescriptor("format", /*receiver_style=*/true), + BinaryFunctionAdapter, StringValue, ListValue>:: + WrapFunction([locale](ValueManager& value_manager, + const StringValue& format, + const ListValue& args) { + return Format(value_manager, format, args, locale); + }))); + return absl::OkStatus(); +} + +} // namespace cel::extensions diff --git a/extensions/formatting.h b/extensions/formatting.h new file mode 100644 index 000000000..bc2002006 --- /dev/null +++ b/extensions/formatting.h @@ -0,0 +1,30 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_FORMATTING_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_FORMATTING_H_ + +#include "absl/status/status.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel::extensions { + +// Register extension functions for string formatting. +absl::Status RegisterStringFormattingFunctions(FunctionRegistry& registry, + const RuntimeOptions& options); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_FORMATTING_H_ diff --git a/extensions/formatting_test.cc b/extensions/formatting_test.cc new file mode 100644 index 000000000..130208819 --- /dev/null +++ b/extensions/formatting_test.cc @@ -0,0 +1,912 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/formatting.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/allocator.h" +#include "common/value.h" +#include "extensions/protobuf/runtime_adapter.h" +#include "internal/parse_text_proto.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "parser/options.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::cel::expr::conformance::proto3::TestAllTypes; +using ::cel::expr::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::google::api::expr::parser::ParserOptions; +using ::testing::HasSubstr; +using ::testing::TestWithParam; +using ::testing::ValuesIn; + +struct FormattingTestCase { + std::string name; + std::string format; + std::string format_args; + absl::flat_hash_map> + dyn_args; + std::string expected; + std::string locale = "en_US"; + std::optional error = std::nullopt; +}; + +template +ParsedMessageValue MakeMessage(absl::string_view text) { + return ParsedMessageValue(internal::DynamicParseTextProto( + Allocator(NewDeleteAllocator<>{}), text, + internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory())); +} + +using StringFormatTest = TestWithParam; +TEST_P(StringFormatTest, TestStringFormatting) { + const FormattingTestCase& test_case = GetParam(); + google::protobuf::Arena arena; + const auto options = RuntimeOptions{.locale = test_case.locale}; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + auto registration_status = + RegisterStringFormattingFunctions(builder.function_registry(), options); + if (test_case.error.has_value() && !registration_status.ok()) { + EXPECT_THAT(registration_status.message(), HasSubstr(*test_case.error)); + return; + } else { + ASSERT_THAT(registration_status, IsOk()); + } + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + auto expr_str = absl::StrFormat("'''%s'''.format([%s])", test_case.format, + test_case.format_args); + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + Parse(expr_str, "", ParserOptions{})); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + Activation activation; + for (const auto& [name, value] : test_case.dyn_args) { + if (std::holds_alternative(value)) { + activation.InsertOrAssignValue(name, + StringValue{std::get(value)}); + } else if (std::holds_alternative(value)) { + activation.InsertOrAssignValue(name, BoolValue{std::get(value)}); + } else if (std::holds_alternative(value)) { + activation.InsertOrAssignValue(name, IntValue{std::get(value)}); + } else if (std::holds_alternative(value)) { + activation.InsertOrAssignValue(name, + UintValue{std::get(value)}); + } else if (std::holds_alternative(value)) { + activation.InsertOrAssignValue(name, + DoubleValue{std::get(value)}); + } else if (std::holds_alternative(value)) { + activation.InsertOrAssignValue( + name, DurationValue{std::get(value)}); + } else if (std::holds_alternative(value)) { + activation.InsertOrAssignValue( + name, TimestampValue{std::get(value)}); + } else if (std::holds_alternative(value)) { + activation.InsertOrAssignValue(name, std::get(value)); + } + } + auto result = program->Evaluate(&arena, activation); + if (test_case.error.has_value()) { + ASSERT_FALSE(result.ok()); + EXPECT_THAT(result.status().message(), HasSubstr(*test_case.error)); + } else { + ASSERT_TRUE(result.ok()); + ASSERT_TRUE(result->Is()); + EXPECT_THAT(result->GetString().ToString(), test_case.expected); + } +} + +INSTANTIATE_TEST_SUITE_P( + TestStringFormatting, StringFormatTest, + ValuesIn({ + { + .name = "Basic", + .format = "%s %s!", + .format_args = "'hello', 'world'", + .expected = "hello world!", + }, + { + .name = "EscapedPercentSign", + .format = "Percent sign %%!", + .format_args = "'hello', 'world'", + .expected = "Percent sign %!", + }, + { + .name = "IncompleteCase", + .format = "%", + .format_args = "'hello'", + .error = "Unexpected end of format string", + }, + { + .name = "MissingFormatArg", + .format = "%s", + .format_args = "", + .error = "Index 0 out of range", + }, + { + .name = "MissingFormatArg2", + .format = "%s, %s", + .format_args = "'hello'", + .error = "Index 1 out of range", + }, + { + .name = "InvalidPrecision", + .format = "%.6", + .format_args = "'hello'", + .error = "Unable to find end of precision specifier", + }, + { + .name = "InvalidPrecision2", + .format = "%.f", + .format_args = "'hello'", + .error = "Unable to convert precision specifier to integer", + }, + { + .name = "InvalidPrecision3", + .format = "%.", + .format_args = "'hello'", + .error = "Unable to find end of precision specifier", + }, + { + .name = "DecimalFormatingClause", + .format = "int %d, uint %d", + .format_args = "-1, uint(2)", + .expected = R"(int -1, uint 2)", + }, + { + .name = "DecimalDoesNotWorkWithDouble", + .format = "double %d", + .format_args = "double(\"-Inf\")", + .error = + "Decimal clause can only be used on integers, was given double", + }, + { + .name = "OctalFormatingClause", + .format = "int %o, uint %o", + .format_args = "-10, uint(20)", + .expected = R"(int -12, uint 24)", + }, + { + .name = "OctalDoesNotWorkWithDouble", + .format = "double %o", + .format_args = "double(\"-Inf\")", + .error = + "Octal clause can only be used on integers, was given double", + }, + { + .name = "HexFormatingClause", + .format = "int %x, uint %X, string %x, bytes %X", + .format_args = "-10, uint(255), 'hello', b'world'", + .expected = "int -a, uint FF, string 68656c6c6f, bytes 776F726C64", + }, + { + .name = "HexFormatingClauseLeadingZero", + .format = "string: %x", + .format_args = R"(b'\x00\x00hello\x00')", + .expected = "string: 000068656c6c6f00", + }, + { + .name = "HexDoesNotWorkWithDouble", + .format = "double %x", + .format_args = "double(\"-Inf\")", + .error = "Hex clause can only be used on integers, byte buffers, " + "and strings, was given double", + }, + { + .name = "BinaryFormatingClause", + .format = "int %b, uint %b, bool %b, bool %b", + .format_args = "-32, uint(20), false, true", + .expected = "int -100000, uint 10100, bool 0, bool 1", + }, + { + .name = "BinaryFormatingClauseLimits", + .format = "min_int %b, max_int %b, max_uint %b", + .format_args = + absl::StrCat(std::numeric_limits::min(), ",", + std::numeric_limits::max(), ",", + std::numeric_limits::max(), "u"), + .expected = "min_int " + "-10000000000000000000000000000000000000000000000000000" + "00000000000, max_int " + "111111111111111111111111111111111111111111111111111111" + "111111111, max_uint " + "111111111111111111111111111111111111111111111111111111" + "1111111111", + }, + { + .name = "BinaryFormatingClauseZero", + .format = "zero %b", + .format_args = "0", + .expected = "zero 0", + }, + { + .name = "HexFormatingClauseLimits", + .format = "min_int %x, max_int %x, max_uint %x", + .format_args = + absl::StrCat(std::numeric_limits::min(), ",", + std::numeric_limits::max(), ",", + std::numeric_limits::max(), "u"), + .expected = "min_int -8000000000000000, max_int 7fffffffffffffff, " + "max_uint ffffffffffffffff", + }, + { + .name = "OctalFormatingClauseLimits", + .format = "min_int %o, max_int %o, max_uint %o", + .format_args = + absl::StrCat(std::numeric_limits::min(), ",", + std::numeric_limits::max(), ",", + std::numeric_limits::max(), "u"), + .expected = + "min_int -1000000000000000000000, max_int " + "777777777777777777777, max_uint 1777777777777777777777", + }, + { + .name = "FixedClauseFormatting", + .format = "%f", + .format_args = "10000.1234", + .expected = "10,000.123400", + }, + { + .name = "FixedClauseFormattingWithPrecision", + .format = "%.2f", + .format_args = "10000.1234", + .expected = "10,000.12", + }, + { + .name = "FixedClauseFormattingWithLocale", + .format = "%.2f", + .format_args = "10000.1234", + .expected = "10.000,12", + .locale = "de_DE", + }, + { + .name = "FixedClauseFormattingWithC", + .format = "%.2f", + .format_args = "10000.1234", + .locale = "C", + .error = "Failed to parse locale: C", + }, + { + .name = "FixedClauseFormattingWithInvalidLocale", + .format = "%.2f", + .format_args = "10000.1234", + .locale = "bogus locale", + .error = "Failed to parse locale: bogus locale", + }, + { + .name = "ListSupportForStringWithQuotes", + .format = "%s", + .format_args = R"(["a\"b","a\\b"])", + .expected = R"(["a\"b", "a\\b"])", + }, + { + .name = "ListSupportForStringWithDouble", + .format = "%s", + .format_args = R"([double("NaN"),double("Inf"), double("-Inf")])", + .expected = R"(["NaN", "+Inf", "-Inf"])", + }, + { + .name = "FixedClauseFormattingWithDynArgs", + .format = "%.2f %d", + .format_args = "arg, message.single_int32", + .dyn_args = + { + {"arg", 10000.1234}, + {"message", + MakeMessage(R"pb(single_int32: 42)pb")}, + }, + .expected = "10,000.12 42", + }, + { + .name = "NoOp", + .format = "no substitution", + .expected = "no substitution", + }, + { + .name = "MidStringSubstitution", + .format = "str is %s and some more", + .format_args = "'filler'", + .expected = "str is filler and some more", + }, + { + .name = "PercentEscaping", + .format = "%% and also %%", + .expected = "% and also %", + }, + { + .name = "SubstitutionInsideEscapedPercentSigns", + .format = "%%%s%%", + .format_args = "'text'", + .expected = "%text%", + }, + { + .name = "SubstitutionWithOneEscapedPercentSignOnTheRight", + .format = "%s%%", + .format_args = "'percent on the right'", + .expected = "percent on the right%", + }, + { + .name = "SubstitutionWithOneEscapedPercentSignOnTheLeft", + .format = "%%%s", + .format_args = "'percent on the left'", + .expected = "%percent on the left", + }, + { + .name = "MultipleSubstitutions", + .format = "%d %d %d, %s %s %s, %d %d %d, %s %s %s", + .format_args = "1, 2, 3, 'A', 'B', 'C', 4, 5, 6, 'D', 'E', 'F'", + .expected = "1 2 3, A B C, 4 5 6, D E F", + }, + { + .name = "PercentSignEscapeSequenceSupport", + .format = "\u0025\u0025escaped \u0025s\u0025\u0025", + .format_args = "'percent'", + .expected = "%escaped percent%", + }, + { + .name = "FixedPointFormattingClause", + .format = "%.3f", + .format_args = "1.2345", + .expected = "1.234", + .locale = "en_US", + }, + { + .name = "BinaryFormattingClause", + .format = "this is 5 in binary: %b", + .format_args = "5", + .expected = "this is 5 in binary: 101", + }, + { + .name = "UintSupportForBinaryFormatting", + .format = "unsigned 64 in binary: %b", + .format_args = "uint(64)", + .expected = "unsigned 64 in binary: 1000000", + }, + { + .name = "BoolSupportForBinaryFormatting", + .format = "bit set from bool: %b", + .format_args = "true", + .expected = "bit set from bool: 1", + }, + { + .name = "OctalFormattingClause", + .format = "%o", + .format_args = "11", + .expected = "13", + }, + { + .name = "UintSupportForOctalFormattingClause", + .format = "this is an unsigned octal: %o", + .format_args = "uint(65535)", + .expected = "this is an unsigned octal: 177777", + }, + { + .name = "LowercaseHexadecimalFormattingClause", + .format = "%x is 20 in hexadecimal", + .format_args = "30", + .expected = "1e is 20 in hexadecimal", + }, + { + .name = "UppercaseHexadecimalFormattingClause", + .format = "%X is 20 in hexadecimal", + .format_args = "30", + .expected = "1E is 20 in hexadecimal", + }, + { + .name = "UnsignedSupportForHexadecimalFormattingClause", + .format = "%X is 6000 in hexadecimal", + .format_args = "uint(6000)", + .expected = "1770 is 6000 in hexadecimal", + }, + { + .name = "StringSupportWithHexadecimalFormattingClause", + .format = "%x", + .format_args = R"("Hello world!")", + .expected = "48656c6c6f20776f726c6421", + }, + { + .name = "StringSupportWithUppercaseHexadecimalFormattingClause", + .format = "%X", + .format_args = R"("Hello world!")", + .expected = "48656C6C6F20776F726C6421", + }, + { + .name = "ByteSupportWithHexadecimalFormattingClause", + .format = "%x", + .format_args = R"(b"byte string")", + .expected = "6279746520737472696e67", + }, + { + .name = "ByteSupportWithUppercaseHexadecimalFormattingClause", + .format = "%X", + .format_args = R"(b"byte string")", + .expected = "6279746520737472696E67", + }, + { + .name = "ScientificNotationFormattingClause", + .format = "%.6e", + .format_args = "1052.032911275", + .expected = "1.052033E+03", // Different from Golang formatting. + .locale = "en_US", + }, + { + .name = "LocaleSupport", + .format = "%.3f", + .format_args = "3.14", + .expected = "3,140", + .locale = "fr_FR", + }, + { + .name = "DefaultPrecisionForFixedPointClause", + .format = "%f", + .format_args = "2.71828", + .expected = "2.718280", + .locale = "en_US", + }, + { + .name = "DefaultPrecisionForScientificNotation", + .format = "%e", + .format_args = "2.71828", + .expected = "2.718280E+00", // Different from Golang formatting. + }, + { + .name = "UnicodeOutputForScientificNotation", + .format = "unescaped unicode: %e, escaped unicode: %e", + .format_args = "2.71828, 2.71828", + .expected = "unescaped unicode: 2.718280E+00, escaped unicode: " + "2.718280E+00", + }, + { + .name = "NaNSupportForFixedPoint", + .format = "%f", + .format_args = "\"NaN\"", + .expected = "NaN", + }, + { + .name = "PositiveInfinitySupportForFixedPoint", + .format = "%f", + .format_args = "\"Infinity\"", + .expected = "∞", + }, + { + .name = "NegativeInfinitySupportForFixedPoint", + .format = "%f", + .format_args = "\"-Infinity\"", + .expected = "-∞", + }, + { + .name = "UintSupportForDecimalClause", + .format = "%d", + .format_args = "uint(64)", + .expected = "64", + }, + { + .name = "NullSupportForString", + .format = "null: %s", + .format_args = "null", + .expected = "null: null", + }, + { + .name = "IntSupportForString", + .format = "%s", + .format_args = "999999999999", + .expected = "999999999999", + }, + { + .name = "BytesSupportForString", + .format = "some bytes: %s", + .format_args = "b\"xyz\"", + .expected = "some bytes: xyz", + }, + { + .name = "TypeSupportForString", + .format = "type is %s", + .format_args = "type(\"test string\")", + .expected = "type is string", + }, + { + .name = "TimestampSupportForString", + .format = "%s", + .format_args = "timestamp(\"2023-02-03T23:31:20+00:00\")", + .expected = "2023-02-03T23:31:20Z", + }, + { + .name = "DurationSupportForString", + .format = "%s", + .format_args = "duration(\"1h45m47s\")", + .expected = "6347s", + }, + { + .name = "ListSupportForString", + .format = "%s", + .format_args = + R"(["abc", 3.14, null, [9, 8, 7, 6], timestamp("2023-02-03T23:31:20Z")])", + .expected = + R"(["abc", 3.14, null, [9, 8, 7, 6], timestamp("2023-02-03T23:31:20Z")])", + }, + { + .name = "MapSupportForString", + .format = "%s", + .format_args = + R"({"key1": b"xyz", "key5": null, "key2": duration("7200s"), "key4": true, "key3": 2.71828})", + .expected = + R"({"key1":b"xyz", "key2":duration("7200s"), "key3":2.71828, "key4":true, "key5":null})", + .locale = "nl_NL", + }, + { + .name = "MapSupportAllKeyTypes", + .format = "map with multiple key types: %s", + .format_args = + R"({1: "value1", uint(2): "value2", true: double("NaN")})", + .expected = + R"(map with multiple key types: {1:"value1", 2:"value2", true:"NaN"})", + }, + { + .name = "MapAfterDecimalFormatting", + .format = "%d %s", + .format_args = R"(42, {"key": 1})", + .expected = "42 {\"key\":1}", + }, + { + .name = "BooleanSupportForString", + .format = "true bool: %s, false bool: %s", + .format_args = "true, false", + .expected = "true bool: true, false bool: false", + }, + { + .name = "DynTypeSupportForStringFormattingClause", + .format = "Dynamic String: %s", + .format_args = R"(dynStr)", + .dyn_args = {{"dynStr", "a string"}}, + .expected = "Dynamic String: a string", + }, + { + .name = "DynTypeSupportForNumbersWithStringFormattingClause", + .format = "Dynamic Int Str: %s Dynamic Double Str: %s", + .format_args = R"(dynIntStr, dynDoubleStr)", + .dyn_args = + { + {"dynIntStr", 32}, + {"dynDoubleStr", 56.8}, + }, + .expected = "Dynamic Int Str: 32 Dynamic Double Str: 56.8", + .locale = "en_US", + }, + { + .name = "DynTypeSupportForIntegerFormattingClause", + .format = "Dynamic Int: %d", + .format_args = R"(dynInt)", + .dyn_args = {{"dynInt", 128}}, + .expected = "Dynamic Int: 128", + }, + { + .name = "DynTypeSupportForIntegerFormattingClauseUnsigned", + .format = "Dynamic Unsigned Int: %d", + .format_args = R"(dynUnsignedInt)", + .dyn_args = {{"dynUnsignedInt", uint64_t{256}}}, + .expected = "Dynamic Unsigned Int: 256", + }, + { + .name = "DynTypeSupportForHexFormattingClause", + .format = "Dynamic Hex Int: %x", + .format_args = R"(dynHexInt)", + .dyn_args = {{"dynHexInt", 22}}, + .expected = "Dynamic Hex Int: 16", + }, + { + .name = "DynTypeSupportForHexFormattingClauseUppercase", + .format = "Dynamic Hex Int: %X (uppercase)", + .format_args = R"(dynHexInt)", + .dyn_args = {{"dynHexInt", 26}}, + .expected = "Dynamic Hex Int: 1A (uppercase)", + }, + { + .name = "DynTypeSupportForUnsignedHexFormattingClause", + .format = "Dynamic Hex Int: %x (unsigned)", + .format_args = R"(dynUnsignedHexInt)", + .dyn_args = {{"dynUnsignedHexInt", uint64_t{500}}}, + .expected = "Dynamic Hex Int: 1f4 (unsigned)", + }, + { + .name = "DynTypeSupportForFixedPointFormattingClause", + .format = "Dynamic Double: %.3f", + .format_args = R"(dynDouble)", + .dyn_args = {{"dynDouble", 4.5}}, + .expected = "Dynamic Double: 4.500", + .locale = "en_US", + }, + { + .name = "DynTypeSupportForFixedPointFormattingClauseCommaSeparatorL" + "ocale", + .format = "Dynamic Double: %f", + .format_args = R"(dynDouble)", + .dyn_args = {{"dynDouble", 4.5}}, + .expected = "Dynamic Double: 4,500000", + .locale = "fr_FR", + }, + { + .name = "DynTypeSupportForScientificNotation", + .format = "(Dynamic Type) E: %e", + .format_args = R"(dynE)", + .dyn_args = {{"dynE", 2.71828}}, + .expected = "(Dynamic Type) E: 2.718280E+00", + .locale = "en_US", + }, + { + .name = "DynTypeNaNInfinitySupportForFixedPoint", + .format = "NaN: %f, Infinity: %f", + .format_args = R"(dynNaN, dynInf)", + .dyn_args = {{"dynNaN", std::nan("")}, + {"dynInf", std::numeric_limits::infinity()}}, + .expected = "NaN: NaN, Infinity: ∞", + }, + { + .name = "DynTypeSupportForTimestamp", + .format = "Dynamic Type Timestamp: %s", + .format_args = R"(dynTime)", + .dyn_args = {{"dynTime", absl::FromUnixSeconds(1257894000)}}, + .expected = "Dynamic Type Timestamp: 2009-11-10T23:00:00Z", + }, + { + .name = "DynTypeSupportForDuration", + .format = "Dynamic Type Duration: %s", + .format_args = R"(dynDuration)", + .dyn_args = {{"dynDuration", absl::Hours(2) + absl::Minutes(25) + + absl::Seconds(47)}}, + .expected = "Dynamic Type Duration: 8747s", + }, + { + .name = "UnrecognizedFormattingClause", + .format = "%a", + .format_args = "1", + .error = "Unrecognized formatting clause \"a\"", + }, + { + .name = "OutOfBoundsArgIndex", + .format = "%d %d %d", + .format_args = "0, 1", + .error = "Index 2 out of range", + }, + { + .name = "StringSubstitutionIsNotAllowedWithBinaryClause", + .format = "string is %b", + .format_args = "\"abc\"", + .error = "Binary clause can only be used on integers and bools, " + "was given string", + }, + { + .name = "DurationSubstitutionIsNotAllowedWithDecimalClause", + .format = "%d", + .format_args = "duration(\"30m2s\")", + .error = "Decimal clause can only be used on integers, was given " + "google.protobuf.Duration", + }, + { + .name = "StringSubstitutionIsNotAllowedWithOctalClause", + .format = "octal: %o", + .format_args = "\"a string\"", + .error = + "Octal clause can only be used on integers, was given string", + }, + { + .name = "DoubleSubstitutionIsNotAllowedWithHexClause", + .format = "double is %x", + .format_args = "0.5", + .error = "Hex clause can only be used on integers, byte buffers, " + "and strings, was given double", + }, + { + .name = "UppercaseIsNotAllowedForScientificClause", + .format = "double is %E", + .format_args = "0.5", + .error = "Unrecognized formatting clause \"E\"", + }, + { + .name = "ObjectIsNotAllowed", + .format = "object is %s", + .format_args = "cel.expr.conformance.proto3.TestAllTypes{}", + .error = "Could not convert argument " + "cel.expr.conformance.proto3.TestAllTypes to string", + }, + { + .name = "ObjectInsideList", + .format = "%s", + .format_args = "[1, 2, cel.expr.conformance.proto3.TestAllTypes{}]", + .error = "Could not convert argument " + "cel.expr.conformance.proto3.TestAllTypes to string", + }, + { + .name = "ObjectInsideMap", + .format = "%s", + .format_args = + "{1: \"a\", 2: cel.expr.conformance.proto3.TestAllTypes{}}", + .error = "Could not convert argument " + "cel.expr.conformance.proto3.TestAllTypes to string", + }, + { + .name = "NullNotAllowedForDecimalClause", + .format = "null: %d", + .format_args = "null", + .error = "Decimal clause can only be used on integers, was given " + "null_type", + }, + { + .name = "NullNotAllowedForScientificNotationClause", + .format = "null: %e", + .format_args = "null", + .error = "Expected a double but got a null_type", + }, + { + .name = "NullNotAllowedForFixedPointClause", + .format = "null: %f", + .format_args = "null", + .error = "Expected a double but got a null_type", + }, + { + .name = "NullNotAllowedForHexadecimalClause", + .format = "null: %x", + .format_args = "null", + .error = "Hex clause can only be used on integers, byte buffers, " + "and strings, was given null_type", + }, + { + .name = "NullNotAllowedForUppercaseHexadecimalClause", + .format = "null: %X", + .format_args = "null", + .error = "Hex clause can only be used on integers, byte buffers, " + "and strings, was given null_type", + }, + { + .name = "NullNotAllowedForBinaryClause", + .format = "null: %b", + .format_args = "null", + .error = "Binary clause can only be used on integers and bools, " + "was given null_type", + }, + { + .name = "NullNotAllowedForOctalClause", + .format = "null: %o", + .format_args = "null", + .error = "Octal clause can only be used on integers, was given " + "null_type", + }, + { + .name = "NegativeBinaryFormattingClause", + .format = "this is -5 in binary: %b", + .format_args = "-5", + .expected = "this is -5 in binary: -101", + }, + { + .name = "NegativeOctalFormattingClause", + .format = "%o", + .format_args = "-11", + .expected = "-13", + }, + { + .name = "NegativeHexadecimalFormattingClause", + .format = "%x is -30 in hexadecimal", + .format_args = "-30", + .expected = "-1e is -30 in hexadecimal", + }, + { + .name = "DefaultPrecisionForString", + .format = "%s", + .format_args = "2.71", + .expected = "2.71", + .locale = "en_US", + }, + { + .name = "DefaultListPrecisionForString", + .format = "%s", + .format_args = "[2.71]", + .expected = + "[2.71]", // Different from Golang (2.710000) consistent with + // the precision of a double outside of a list. + .locale = "en_US", + }, + { + .name = "AutomaticRoundingForString", + .format = "%s", + .format_args = "10002.71", + .expected = "10002.7", // Different from Golang (10002.71) which + // does not round. + .locale = "en_US", + }, + { + .name = "DefaultScientificNotationForString", + .format = "%s", + .format_args = "0.000000002", + .expected = "2e-09", + .locale = "en_US", + }, + { + .name = "DefaultListScientificNotationForString", + .format = "%s", + .format_args = "[0.000000002]", + .expected = + "[2e-09]", // Different from Golang (0.000000) consistent with + // the notation of a double outside of a list. + .locale = "en_US", + }, + { + .name = "NaNSupportForString", + .format = "%s", + .format_args = R"(double("NaN"))", + .expected = "NaN", + }, + { + .name = "PositiveInfinitySupportForString", + .format = "%s", + .format_args = R"(double("Inf"))", + .expected = "+Inf", + }, + { + .name = "NegativeInfinitySupportForString", + .format = "%s", + .format_args = R"(double("-Inf"))", + .expected = "-Inf", + }, + { + .name = "InfinityListSupportForString", + .format = "%s", + .format_args = R"([double("NaN"), double("+Inf"), double("-Inf")])", + .expected = R"(["NaN", "+Inf", "-Inf"])", + }, + { + .name = "SmallDurationSupportForString", + .format = "%s", + .format_args = R"(duration("2ns"))", + .expected = "0.000000002s", + }, + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +} // namespace +} // namespace cel::extensions diff --git a/extensions/strings.cc b/extensions/strings.cc index 535416261..5ff77db34 100644 --- a/extensions/strings.cc +++ b/extensions/strings.cc @@ -36,6 +36,7 @@ #include "common/value_manager.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" +#include "extensions/formatting.h" #include "internal/status_macros.h" #include "internal/utf8.h" #include "runtime/function_adapter.h" @@ -431,6 +432,7 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry, int64_t>::CreateDescriptor("replace", /*receiver_style=*/true), VariadicFunctionAdapter, StringValue, StringValue, StringValue, int64_t>::WrapFunction(Replace2))); + CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions(registry, options)); return absl::OkStatus(); } diff --git a/extensions/strings_test.cc b/extensions/strings_test.cc index cb793e6f6..c02efd639 100644 --- a/extensions/strings_test.cc +++ b/extensions/strings_test.cc @@ -228,6 +228,29 @@ TEST(Strings, UpperAscii) { EXPECT_TRUE(result.GetBool().NativeValue()); } +TEST(Strings, Format) { + google::protobuf::Arena arena; + const auto options = RuntimeOptions{}; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + EXPECT_THAT(RegisterStringsFunctions(builder.function_registry(), options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("'abc %d'.format([2]) == 'abc 2'", + "", ParserOptions{})); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + Activation activation; + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + ASSERT_TRUE(result.Is()); + EXPECT_TRUE(result.GetBool().NativeValue()); +} + TEST(StringsCheckerLibrary, SmokeTest) { google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN( diff --git a/runtime/runtime_options.h b/runtime/runtime_options.h index 9159b19b1..69fdbb819 100644 --- a/runtime/runtime_options.h +++ b/runtime/runtime_options.h @@ -166,6 +166,11 @@ struct RuntimeOptions { // // Currently applies to !_, @not_strictly_false, _==_, _!=_, @in bool enable_fast_builtins = false; + + // The locale to use for string formatting. + // + // Default is the "en_US" locale. + std::string locale = "en_US"; }; // LINT.ThenChange(//depot/google3/eval/public/cel_options.h)