diff --git a/extensions/strings.cc b/extensions/strings.cc index 62d41aae1..06d53cb61 100644 --- a/extensions/strings.cc +++ b/extensions/strings.cc @@ -25,6 +25,7 @@ #include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/strings/cord.h" +#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "common/casting.h" #include "common/type.h" @@ -89,6 +90,17 @@ absl::StatusOr Join1(ValueManager& value_manager, return Join2(value_manager, value, StringValue{}); } +absl::StatusOr Replace(ValueManager& value_manager, + const StringValue& string, + const StringValue& original, + const StringValue& replacement) { + std::string content = string.NativeString(); + absl::StrReplaceAll({{original.NativeString(), replacement.NativeString()}}, + &content); + // We assume the original strings were well-formed. + return value_manager.CreateUncheckedStringValue(std::move(content)); +} + struct SplitWithEmptyDelimiter { ValueManager& value_manager; int64_t& limit; @@ -230,6 +242,12 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry, CreateDescriptor("join", /*receiver_style=*/true), BinaryFunctionAdapter, ListValue, StringValue>::WrapFunction(Join2))); + CEL_RETURN_IF_ERROR(registry.Register( + VariadicFunctionAdapter< + absl::StatusOr, StringValue, StringValue, + StringValue>::CreateDescriptor("replace", /*receiver_style=*/true), + VariadicFunctionAdapter, StringValue, StringValue, + StringValue>::WrapFunction(Replace))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, StringValue, StringValue>:: CreateDescriptor("split", /*receiver_style=*/true), diff --git a/extensions/strings_test.cc b/extensions/strings_test.cc index 070c7a26d..c6a2ba15b 100644 --- a/extensions/strings_test.cc +++ b/extensions/strings_test.cc @@ -39,6 +39,34 @@ using ::google::api::expr::v1alpha1::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::api::expr::parser::ParserOptions; +TEST(Strings, Replace) { + MemoryManagerRef memory_manager = MemoryManagerRef::ReferenceCounting(); + const auto options = RuntimeOptions{}; + ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder(options)); + EXPECT_OK(RegisterStringsFunctions(builder.function_registry(), options)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + Parse("foo.replace(\"_\", \" \") == \"hello world!\"", + "", ParserOptions{})); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + common_internal::LegacyValueManager value_factory(memory_manager, + runtime->GetTypeProvider()); + + Activation activation; + activation.InsertOrAssignValue("foo", + StringValue{absl::Cord("hello_world!")}); + + ASSERT_OK_AND_ASSIGN(Value result, + program->Evaluate(activation, value_factory)); + ASSERT_TRUE(result.Is()); + EXPECT_TRUE(result.As().NativeValue()); +} + TEST(Strings, SplitWithEmptyDelimiterCord) { MemoryManagerRef memory_manager = MemoryManagerRef::ReferenceCounting(); const auto options = RuntimeOptions{};