Skip to content

Commit ad6d42e

Browse files
authored
feat(rust, python): add strict parameter to decoding expressions (#6342)
1 parent 8eced68 commit ad6d42e

File tree

8 files changed

+77
-33
lines changed

8 files changed

+77
-33
lines changed

polars/polars-core/src/chunked_array/binary/encoding.rs

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,26 @@
1+
use std::borrow::Cow;
2+
13
use base64::engine::general_purpose;
24
use base64::Engine as _;
35
use hex;
46

57
use crate::prelude::*;
68

79
impl BinaryChunked {
8-
pub fn hex_decode(&self) -> PolarsResult<BinaryChunked> {
9-
self.try_apply(|s| {
10-
let bytes =
11-
hex::decode(s).map_err(|e| PolarsError::ComputeError(e.to_string().into()))?;
12-
Ok(bytes.into())
13-
})
10+
pub fn hex_decode(&self, strict: bool) -> PolarsResult<BinaryChunked> {
11+
if strict {
12+
self.try_apply(|s| {
13+
let bytes = hex::decode(s).map_err(|_e| {
14+
PolarsError::ComputeError(
15+
"Invalid 'hex' encoding found. Try setting 'strict' to false to ignore."
16+
.into(),
17+
)
18+
})?;
19+
Ok(bytes.into())
20+
})
21+
} else {
22+
Ok(self.apply_on_opt(|opt_s| opt_s.and_then(|s| hex::decode(s).ok().map(Cow::Owned))))
23+
}
1424
}
1525

1626
pub fn hex_encode(&self) -> Series {
@@ -19,13 +29,22 @@ impl BinaryChunked {
1929
.unwrap()
2030
}
2131

22-
pub fn base64_decode(&self) -> PolarsResult<BinaryChunked> {
23-
self.try_apply(|s| {
24-
let bytes = general_purpose::STANDARD
25-
.decode(s)
26-
.map_err(|e| PolarsError::ComputeError(e.to_string().into()))?;
27-
Ok(bytes.into())
28-
})
32+
pub fn base64_decode(&self, strict: bool) -> PolarsResult<BinaryChunked> {
33+
if strict {
34+
self.try_apply(|s| {
35+
let bytes = general_purpose::STANDARD.decode(s).map_err(|_e| {
36+
PolarsError::ComputeError(
37+
"Invalid 'base64' encoding found. Try setting 'strict' to false to ignore."
38+
.into(),
39+
)
40+
})?;
41+
Ok(bytes.into())
42+
})
43+
} else {
44+
Ok(self.apply_on_opt(|opt_s| {
45+
opt_s.and_then(|s| general_purpose::STANDARD.decode(s).ok().map(Cow::Owned))
46+
}))
47+
}
2948
}
3049

3150
pub fn base64_encode(&self) -> Series {

polars/polars-core/src/chunked_array/strings/encoding.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ impl Utf8Chunked {
1111
}
1212

1313
#[cfg(feature = "binary_encoding")]
14-
pub fn hex_decode(&self) -> PolarsResult<BinaryChunked> {
14+
pub fn hex_decode(&self, strict: bool) -> PolarsResult<BinaryChunked> {
1515
self.cast_unchecked(&DataType::Binary)?
1616
.binary()?
17-
.hex_decode()
17+
.hex_decode(strict)
1818
}
1919

2020
#[must_use]
@@ -28,10 +28,10 @@ impl Utf8Chunked {
2828
}
2929

3030
#[cfg(feature = "binary_encoding")]
31-
pub fn base64_decode(&self) -> PolarsResult<BinaryChunked> {
31+
pub fn base64_decode(&self, strict: bool) -> PolarsResult<BinaryChunked> {
3232
self.cast_unchecked(&DataType::Binary)?
3333
.binary()?
34-
.base64_decode()
34+
.base64_decode(strict)
3535
}
3636

3737
#[must_use]

py-polars/polars/internals/expr/binary.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,20 +56,23 @@ def starts_with(self, sub: bytes) -> pli.Expr:
5656
"""
5757
return pli.wrap_expr(self._pyexpr.binary_starts_with(sub))
5858

59-
def decode(self, encoding: TransferEncoding) -> pli.Expr:
59+
def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> pli.Expr:
6060
"""
6161
Decode a value using the provided encoding.
6262
6363
Parameters
6464
----------
6565
encoding : {'hex', 'base64'}
6666
The encoding to use.
67+
strict
68+
Raise an error if the underlying value cannot be decoded,
69+
otherwise mask out with a null value.
6770
6871
"""
6972
if encoding == "hex":
70-
return pli.wrap_expr(self._pyexpr.binary_hex_decode())
73+
return pli.wrap_expr(self._pyexpr.binary_hex_decode(strict))
7174
elif encoding == "base64":
72-
return pli.wrap_expr(self._pyexpr.binary_base64_decode())
75+
return pli.wrap_expr(self._pyexpr.binary_base64_decode(strict))
7376
else:
7477
raise ValueError(
7578
f"encoding must be one of {{'hex', 'base64'}}, got {encoding}"

py-polars/polars/internals/expr/string.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -665,20 +665,23 @@ def json_path_match(self, json_path: str) -> pli.Expr:
665665
"""
666666
return pli.wrap_expr(self._pyexpr.str_json_path_match(json_path))
667667

668-
def decode(self, encoding: TransferEncoding) -> pli.Expr:
668+
def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> pli.Expr:
669669
"""
670670
Decode a value using the provided encoding.
671671
672672
Parameters
673673
----------
674674
encoding : {'hex', 'base64'}
675675
The encoding to use.
676+
strict
677+
Raise an error if the underlying value cannot be decoded,
678+
otherwise mask out with a null value.
676679
677680
"""
678681
if encoding == "hex":
679-
return pli.wrap_expr(self._pyexpr.str_hex_decode())
682+
return pli.wrap_expr(self._pyexpr.str_hex_decode(strict))
680683
elif encoding == "base64":
681-
return pli.wrap_expr(self._pyexpr.str_base64_decode())
684+
return pli.wrap_expr(self._pyexpr.str_base64_decode(strict))
682685
else:
683686
raise ValueError(
684687
f"encoding must be one of {{'hex', 'base64'}}, got {encoding}"

py-polars/polars/internals/series/binary.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,17 @@ def starts_with(self, sub: bytes) -> pli.Series:
5656
5757
"""
5858

59-
def decode(self, encoding: TransferEncoding) -> pli.Series:
59+
def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> pli.Series:
6060
"""
6161
Decode a value using the provided encoding.
6262
6363
Parameters
6464
----------
6565
encoding : {'hex', 'base64'}
6666
The encoding to use.
67+
strict
68+
Raise an error if the underlying value cannot be decoded,
69+
otherwise mask out with a null value.
6770
6871
"""
6972

py-polars/polars/internals/series/string.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,14 +266,17 @@ def starts_with(self, sub: str) -> pli.Series:
266266
267267
"""
268268

269-
def decode(self, encoding: TransferEncoding) -> pli.Series:
269+
def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> pli.Series:
270270
"""
271271
Decode a value using the provided encoding.
272272
273273
Parameters
274274
----------
275275
encoding : {'hex', 'base64'}
276276
The encoding to use.
277+
strict
278+
Raise an error if the underlying value cannot be decoded,
279+
otherwise mask out with a null value.
277280
278281
"""
279282

py-polars/src/lazy/dsl.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -739,11 +739,11 @@ impl PyExpr {
739739
.with_fmt("str.hex_encode")
740740
.into()
741741
}
742-
pub fn str_hex_decode(&self) -> PyExpr {
742+
pub fn str_hex_decode(&self, strict: bool) -> PyExpr {
743743
self.clone()
744744
.inner
745745
.map(
746-
move |s| s.utf8()?.hex_decode().map(|s| s.into_series()),
746+
move |s| s.utf8()?.hex_decode(strict).map(|s| s.into_series()),
747747
GetOutput::same_type(),
748748
)
749749
.with_fmt("str.hex_decode")
@@ -760,11 +760,11 @@ impl PyExpr {
760760
.into()
761761
}
762762

763-
pub fn str_base64_decode(&self) -> PyExpr {
763+
pub fn str_base64_decode(&self, strict: bool) -> PyExpr {
764764
self.clone()
765765
.inner
766766
.map(
767-
move |s| s.utf8()?.base64_decode().map(|s| s.into_series()),
767+
move |s| s.utf8()?.base64_decode(strict).map(|s| s.into_series()),
768768
GetOutput::same_type(),
769769
)
770770
.with_fmt("str.base64_decode")
@@ -781,11 +781,11 @@ impl PyExpr {
781781
.with_fmt("binary.hex_encode")
782782
.into()
783783
}
784-
pub fn binary_hex_decode(&self) -> PyExpr {
784+
pub fn binary_hex_decode(&self, strict: bool) -> PyExpr {
785785
self.clone()
786786
.inner
787787
.map(
788-
move |s| s.binary()?.hex_decode().map(|s| s.into_series()),
788+
move |s| s.binary()?.hex_decode(strict).map(|s| s.into_series()),
789789
GetOutput::same_type(),
790790
)
791791
.with_fmt("binary.hex_decode")
@@ -802,11 +802,11 @@ impl PyExpr {
802802
.into()
803803
}
804804

805-
pub fn binary_base64_decode(&self) -> PyExpr {
805+
pub fn binary_base64_decode(&self, strict: bool) -> PyExpr {
806806
self.clone()
807807
.inner
808808
.map(
809-
move |s| s.binary()?.base64_decode().map(|s| s.into_series()),
809+
move |s| s.binary()?.base64_decode(strict).map(|s| s.into_series()),
810810
GetOutput::same_type(),
811811
)
812812
.with_fmt("binary.base64_decode")

py-polars/tests/unit/test_utf8.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pytest
2+
13
import polars as pl
24

35

@@ -22,3 +24,14 @@ def test_length_vs_nchars() -> None:
2224
]
2325
)
2426
assert df.rows() == [("café", 5, 4), ("東京", 6, 2)]
27+
28+
29+
def test_decode_strict() -> None:
30+
df = pl.DataFrame(
31+
{"strings": ["0IbQvTc3", "0J%2FQldCf0JA%3D", "0J%2FRgNC%2B0YHRgtC%2B"]}
32+
)
33+
assert df.select(pl.col("strings").str.decode("base64", strict=False)).to_dict(
34+
False
35+
) == {"strings": [b"\xd0\x86\xd0\xbd77", None, None]}
36+
with pytest.raises(pl.ComputeError):
37+
df.select(pl.col("strings").str.decode("base64", strict=True))

0 commit comments

Comments
 (0)