From 565284c3fe3b6e30af307884b44cc0e6cddb1331 Mon Sep 17 00:00:00 2001 From: Manuel Saelices Date: Mon, 6 Jan 2025 23:41:09 +0100 Subject: [PATCH 1/3] Add a new validate parameter to the b64decode() function Signed-off-by: Manuel Saelices --- stdlib/src/base64/base64.mojo | 19 +++++++++++++------ stdlib/test/base64/test_base64.mojo | 10 +++++++++- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/stdlib/src/base64/base64.mojo b/stdlib/src/base64/base64.mojo index 6a9d585fb5..39c6a2de3e 100644 --- a/stdlib/src/base64/base64.mojo +++ b/stdlib/src/base64/base64.mojo @@ -112,9 +112,12 @@ fn b64encode(input_bytes: List[UInt8, _]) -> String: @always_inline -fn b64decode(str: String) -> String: +fn b64decode[validate: Bool = False](str: String) raises -> String: """Performs base64 decoding on the input string. + Parameters: + validate: If true, the function will validate the input string. + Args: str: A base64 encoded string. @@ -122,7 +125,11 @@ fn b64decode(str: String) -> String: The decoded string. """ var n = str.byte_length() - debug_assert(n % 4 == 0, "Input length must be divisible by 4") + + @parameter + if validate: + if n % 4 != 0: + raise Error("ValueError: Input length must be divisible by 4") var p = String._buffer_type(capacity=n + 1) @@ -133,10 +140,10 @@ fn b64decode(str: String) -> String: var c = _ascii_to_value(str[i + 2]) var d = _ascii_to_value(str[i + 3]) - debug_assert( - a >= 0 and b >= 0 and c >= 0 and d >= 0, - "Unexpected character encountered", - ) + @parameter + if validate: + if a < 0 or b < 0 or c < 0 or d < 0: + raise Error("ValueError: Unexpected character encountered") p.append((a << 2) | (b >> 4)) if str[i + 2] == "=": diff --git a/stdlib/test/base64/test_base64.mojo b/stdlib/test/base64/test_base64.mojo index 6512844905..dd66e5e399 100644 --- a/stdlib/test/base64/test_base64.mojo +++ b/stdlib/test/base64/test_base64.mojo @@ -14,7 +14,7 @@ from base64 import b16decode, b16encode, b64decode, b64encode -from testing import assert_equal +from testing import assert_equal, assert_raises def test_b64encode(): @@ -60,6 +60,14 @@ def test_b64decode(): assert_equal(b64decode("QUJDREVGYWJjZGVm"), "ABCDEFabcdef") + with assert_raises( + contains="ValueError: Input length must be divisible by 4" + ): + _ = b64decode[validate=True]("invalid base64 string") + + with assert_raises(contains="ValueError: Unexpected character encountered"): + _ = b64decode[validate=True]("invalid base64 string!!!") + def test_b16encode(): assert_equal(b16encode("a"), "61") From b32f2f591eb061b882af08c215e61d4a93879082 Mon Sep 17 00:00:00 2001 From: Manuel Saelices Date: Mon, 6 Jan 2025 23:47:25 +0100 Subject: [PATCH 2/3] Add changelog entry for the b64decode() new validate param Signed-off-by: Manuel Saelices --- docs/changelog.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/changelog.md b/docs/changelog.md index 69c235f28c..bb4de3c75f 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -32,6 +32,8 @@ what we publish. ### Standard library changes +- Add a new `validate` parameter to the `b64decode()` function. + - `UnsafePointer`'s `bitcast` method has now been split into `bitcast` for changing the type, `origin_cast` for changing mutability, `static_alignment_cast` for changing alignment, From c1873d6ad78d23b45d335be2a9ffc2fba136426f Mon Sep 17 00:00:00 2001 From: Manuel Saelices Date: Tue, 21 Jan 2025 14:45:54 +0100 Subject: [PATCH 3/3] Improve error messages including the length and the problematic char Signed-off-by: Manuel Saelices --- stdlib/src/base64/base64.mojo | 25 ++++++++++++++----------- stdlib/test/base64/test_base64.mojo | 6 ++++-- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/stdlib/src/base64/base64.mojo b/stdlib/src/base64/base64.mojo index b063916ddc..f0929a0f3e 100644 --- a/stdlib/src/base64/base64.mojo +++ b/stdlib/src/base64/base64.mojo @@ -34,7 +34,7 @@ from ._b64encode import b64encode_with_buffers as _b64encode_with_buffers @always_inline -fn _ascii_to_value(char: StringSlice) -> Int: +fn _ascii_to_value[validate: Bool = False](char: StringSlice) raises -> Int: """Converts an ASCII character to its integer value for base64 decoding. Args: @@ -58,6 +58,12 @@ fn _ascii_to_value(char: StringSlice) -> Int: elif char == "/": return 63 else: + + @parameter + if validate: + raise Error( + 'ValueError: Unexpected character "{}" encountered'.format(char) + ) return -1 @@ -129,21 +135,18 @@ fn b64decode[validate: Bool = False](str: StringSlice) raises -> String: @parameter if validate: if n % 4 != 0: - raise Error("ValueError: Input length must be divisible by 4") + raise Error( + "ValueError: Input length {} must be divisible by 4".format(n) + ) var p = String._buffer_type(capacity=n + 1) # This algorithm is based on https://arxiv.org/abs/1704.00605 for i in range(0, n, 4): - var a = _ascii_to_value(str[i]) - var b = _ascii_to_value(str[i + 1]) - var c = _ascii_to_value(str[i + 2]) - var d = _ascii_to_value(str[i + 3]) - - @parameter - if validate: - if a < 0 or b < 0 or c < 0 or d < 0: - raise Error("ValueError: Unexpected character encountered") + var a = _ascii_to_value[validate](str[i]) + var b = _ascii_to_value[validate](str[i + 1]) + var c = _ascii_to_value[validate](str[i + 2]) + var d = _ascii_to_value[validate](str[i + 3]) p.append((a << 2) | (b >> 4)) if str[i + 2] == "=": diff --git a/stdlib/test/base64/test_base64.mojo b/stdlib/test/base64/test_base64.mojo index dd66e5e399..f8b7993c00 100644 --- a/stdlib/test/base64/test_base64.mojo +++ b/stdlib/test/base64/test_base64.mojo @@ -61,11 +61,13 @@ def test_b64decode(): assert_equal(b64decode("QUJDREVGYWJjZGVm"), "ABCDEFabcdef") with assert_raises( - contains="ValueError: Input length must be divisible by 4" + contains="ValueError: Input length 21 must be divisible by 4" ): _ = b64decode[validate=True]("invalid base64 string") - with assert_raises(contains="ValueError: Unexpected character encountered"): + with assert_raises( + contains='ValueError: Unexpected character " " encountered' + ): _ = b64decode[validate=True]("invalid base64 string!!!")