diff --git a/stdlib/src/collections/string/string.mojo b/stdlib/src/collections/string/string.mojo index b31dbf6a49..d768b2bae2 100644 --- a/stdlib/src/collections/string/string.mojo +++ b/stdlib/src/collections/string/string.mojo @@ -1668,14 +1668,15 @@ struct String( """Return the number of non-overlapping occurrences of substring `substr` in the string. - If sub is empty, returns the number of empty strings between characters - which is the length of the string plus one. - Args: - substr: The substring to count. + substr: The substring to count. Returns: - The number of occurrences of `substr`. + The number of occurrences of `substr`. + + Notes: + If sub is empty, returns the number of empty strings between characters + which is the length of the string plus one. """ return self.as_string_slice().count(substr) @@ -1885,51 +1886,54 @@ struct String( Returns: The string where all occurrences of `old` are replaced with `new`. """ - if not old: - return self._interleave(new) - - var occurrences = self.count(old) - if occurrences == -1: - return self - - var self_start = self.unsafe_ptr() - var self_ptr = self.unsafe_ptr() + var s_ptr = self.unsafe_ptr() var new_ptr = new.unsafe_ptr() - var self_len = self.byte_length() + var s_len = self.byte_length() var old_len = old.byte_length() var new_len = new.byte_length() - var res = Self._buffer_type() - res.reserve(self_len + (old_len - new_len) * occurrences + 1) - - for _ in range(occurrences): - var curr_offset = int(self_ptr) - int(self_start) - - var idx = self.find(old, curr_offset) + if old_len == 0: + var capacity = s_len + new_len * self.byte_length() + 1 + var res_ptr = UnsafePointer[Byte].alloc(capacity) + var offset = 0 + for s in self: + memcpy(res_ptr + offset, new_ptr, new_len) + offset += new_len + memcpy(res_ptr + offset, s.unsafe_ptr(), s.byte_length()) + offset += s.byte_length() + res_ptr[capacity - 1] = 0 + return String(ptr=res_ptr, length=capacity) + + # FIXME(#3792): this should use self.as_bytes().count(old) which will be + # faster because returning unicode offsets has overhead and will return + # less bytes than necessary and cause a segfault + var occurrences = self.count(old) + if occurrences == 0: + return self - debug_assert(idx >= 0, "expected to find occurrence during find") + var capacity = s_len + (new_len - old_len) * occurrences + 1 + var res_ptr = UnsafePointer[Byte].alloc(capacity) + var s_offset = 0 + var res_offset = 0 + while s_offset < s_len: + # FIXME(#3548): this should use raw bytes self.as_bytes().find(...) + var idx = self.find(old, s_offset) + if idx == -1: + memcpy(res_ptr + res_offset, s_ptr + s_offset, s_len - s_offset) + break # Copy preceding unchanged chars - for _ in range(curr_offset, idx): - res.append(self_ptr[]) - self_ptr += 1 - + var length = idx - s_offset + memcpy(res_ptr + res_offset, s_ptr + s_offset, length) + res_offset += length + s_offset += length + old_len # Insert a copy of the new replacement string - for i in range(new_len): - res.append(new_ptr[i]) + memcpy(res_ptr + res_offset, new_ptr, new_len) + res_offset += new_len - self_ptr += old_len - - while True: - var val = self_ptr[] - if val == 0: - break - res.append(self_ptr[]) - self_ptr += 1 - - res.append(0) - return String(res^) + res_ptr[capacity - 1] = 0 + return String(ptr=res_ptr, length=capacity) fn strip(self, chars: StringSlice) -> StringSlice[__origin_of(self)]: """Return a copy of the string with leading and trailing characters @@ -2019,18 +2023,6 @@ struct String( """ hasher._update_with_bytes(self.unsafe_ptr(), self.byte_length()) - fn _interleave(self, val: String) -> String: - var res = Self._buffer_type() - var val_ptr = val.unsafe_ptr() - var self_ptr = self.unsafe_ptr() - res.reserve(val.byte_length() * self.byte_length() + 1) - for i in range(self.byte_length()): - for j in range(val.byte_length()): - res.append(val_ptr[j]) - res.append(self_ptr[i]) - res.append(0) - return String(res^) - fn lower(self) -> String: """Returns a copy of the string with all cased characters converted to lowercase.