Skip to content

Commit c4b581a

Browse files
committed
Raise AttributeError on attempts to access unset oneof fields
This commit modifies `Message.__getattribute__` to raise `AttributeError` whenever an attempt is made to access an unset `oneof` field. This provides several benefits over the current approach: * There is no longer any risk of `betterproto` users accidentally relying on values of unset fields. * Pattern matching with `match/case` on messages containing `oneof` groups is now supported. The following is now possible: ``` @dataclasses.dataclass(eq=False, repr=False) class Test(betterproto.Message): x: int = betterproto.int32_field(1, group="g") y: str = betterproto.string_field(2, group="g") match Test(y="text"): case Test(x=v): print("x", v) case Test(y=v): print("y", v) ``` Before this commit the code above would output `x 0` instead of `y text`, but now the output is `y text` as expected. The reason this works is because an `AttributeError` in a `case` pattern does not propagate and instead simply skips the `case`. * We now have a type-checkable way to deconstruct `oneof`. When running `mypy` for the snippet above `v` has type `int` in the first `case` and type `str` in the second `case`. For versions of Python that do not support `match/case` (before 3.10) it is now possbile to use `try/except/else` blocks to achieve the same result: ``` t = Test(y="text") try: v0: int = t.x except AttributeError: v1: str = t.y # `oneof` contains `y` else: pass # `oneof` contains `x` ``` This is a breaking change. The previous behavior is still accessible via `Message.__unsafe_get`.
1 parent 098989e commit c4b581a

File tree

4 files changed

+43
-24
lines changed

4 files changed

+43
-24
lines changed

src/betterproto/__init__.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,19 @@ def __post_init__(self) -> None:
649649
def __raw_get(self, name: str) -> Any:
650650
return super().__getattribute__(name)
651651

652+
def __unsafe_get(self, name: str) -> Any:
653+
"""
654+
Lazily initialize default values to avoid infinite recursion for recursive
655+
message types
656+
"""
657+
value = super().__getattribute__(name)
658+
if value is not PLACEHOLDER:
659+
return value
660+
661+
value = self._get_field_default(name)
662+
super().__setattr__(name, value)
663+
return value
664+
652665
def __eq__(self, other) -> bool:
653666
if type(self) is not type(other):
654667
return False
@@ -691,17 +704,20 @@ def __repr__(self) -> str:
691704
if not TYPE_CHECKING:
692705

693706
def __getattribute__(self, name: str) -> Any:
694-
"""
695-
Lazily initialize default values to avoid infinite recursion for recursive
696-
message types
697-
"""
698-
value = super().__getattribute__(name)
699-
if value is not PLACEHOLDER:
700-
return value
707+
# Raise `AttributeError` on attempts to access unset `oneof` fields
708+
try:
709+
group_current = super().__getattribute__("_group_current")
710+
except AttributeError:
711+
pass
712+
else:
713+
if name not in {"__class__", "_betterproto"}:
714+
group = self._betterproto.oneof_group_by_field.get(name)
715+
if group is not None and group_current[group] != name:
716+
raise AttributeError(
717+
f"'{self.__class__.__name__}.{group}' is set to '{group_current[group]}', not '{name}'"
718+
)
701719

702-
value = self._get_field_default(name)
703-
super().__setattr__(name, value)
704-
return value
720+
return Message.__unsafe_get(self, name)
705721

706722
def __setattr__(self, attr: str, value: Any) -> None:
707723
if (
@@ -761,7 +777,7 @@ def __bytes__(self) -> bytes:
761777
"""
762778
output = bytearray()
763779
for field_name, meta in self._betterproto.meta_by_field_name.items():
764-
value = getattr(self, field_name)
780+
value = self.__unsafe_get(field_name)
765781

766782
if value is None:
767783
# Optional items should be skipped. This is used for the Google
@@ -1016,7 +1032,7 @@ def parse(self: T, data: bytes) -> T:
10161032
parsed.wire_type, meta, field_name, parsed.value
10171033
)
10181034

1019-
current = getattr(self, field_name)
1035+
current = self.__unsafe_get(field_name)
10201036
if meta.proto_type == TYPE_MAP:
10211037
# Value represents a single key/value pair entry in the map.
10221038
current[value.key] = value.value
@@ -1077,7 +1093,7 @@ def to_dict(
10771093
defaults = self._betterproto.default_gen
10781094
for field_name, meta in self._betterproto.meta_by_field_name.items():
10791095
field_is_repeated = defaults[field_name] is list
1080-
value = getattr(self, field_name)
1096+
value = self.__unsafe_get(field_name)
10811097
cased_name = casing(field_name).rstrip("_") # type: ignore
10821098
if meta.proto_type == TYPE_MESSAGE:
10831099
if isinstance(value, datetime):
@@ -1209,7 +1225,7 @@ def from_dict(self: T, value: Mapping[str, Any]) -> T:
12091225

12101226
if value[key] is not None:
12111227
if meta.proto_type == TYPE_MESSAGE:
1212-
v = getattr(self, field_name)
1228+
v = self.__unsafe_get(field_name)
12131229
cls = self._betterproto.cls_by_field[field_name]
12141230
if isinstance(v, list):
12151231
if cls == datetime:
@@ -1486,7 +1502,6 @@ def _validate_field_groups(cls, values):
14861502
field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore
14871503

14881504
for group, field_set in group_to_one_ofs.items():
1489-
14901505
if len(field_set) == 1:
14911506
(field,) = field_set
14921507
field_name = field.name

tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def test_bytes_are_the_same_for_oneof():
5050

5151
# None of these fields were explicitly set BUT they should not actually be null
5252
# themselves
53-
assert isinstance(message.foo, Foo)
54-
assert isinstance(message2.foo, Foo)
53+
assert isinstance(message._Message__unsafe_get("foo"), Foo)
54+
assert isinstance(message2._Message__unsafe_get("foo"), Foo)
5555

5656
assert isinstance(message_reference.foo, ReferenceFoo)
5757
assert isinstance(message_reference2.foo, ReferenceFoo)

tests/inputs/oneof_enum/test_oneof_enum.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def test_which_one_of_returns_enum_with_default_value():
1818
get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json
1919
)
2020

21-
assert message.move == Move(
21+
assert message._Message__unsafe_get("move") == Move(
2222
x=0, y=0
2323
) # Proto3 will default this as there is no null
2424
assert message.signal == Signal.PASS
@@ -33,7 +33,7 @@ def test_which_one_of_returns_enum_with_non_default_value():
3333
message.from_json(
3434
get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json
3535
)
36-
assert message.move == Move(
36+
assert message._Message__unsafe_get("move") == Move(
3737
x=0, y=0
3838
) # Proto3 will default this as there is no null
3939
assert message.signal == Signal.RESIGN
@@ -44,5 +44,5 @@ def test_which_one_of_returns_second_field_when_set():
4444
message = Test()
4545
message.from_json(get_test_case_json_data("oneof_enum")[0].json)
4646
assert message.move == Move(x=2, y=3)
47-
assert message.signal == Signal.PASS
47+
assert message._Message__unsafe_get("signal") == Signal.PASS
4848
assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3))

tests/test_features.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
Optional,
1919
)
2020

21+
import pytest
22+
2123
import betterproto
2224

2325

@@ -151,17 +153,19 @@ class Foo(betterproto.Message):
151153
foo.baz = "test"
152154

153155
# Other oneof fields should now be unset
154-
assert foo.bar == 0
156+
with pytest.raises(AttributeError):
157+
foo.bar
155158
assert betterproto.which_one_of(foo, "group1")[0] == "baz"
156159

157-
foo.sub.val = 1
160+
foo.sub = Sub(val=1)
158161
assert betterproto.serialized_on_wire(foo.sub)
159162

160163
foo.abc = "test"
161164

162165
# Group 1 shouldn't be touched, group 2 should have reset
163-
assert foo.sub.val == 0
164-
assert betterproto.serialized_on_wire(foo.sub) is False
166+
with pytest.raises(AttributeError):
167+
foo.sub.val
168+
assert betterproto.serialized_on_wire(foo._Message__unsafe_get("sub")) is False
165169
assert betterproto.which_one_of(foo, "group2")[0] == "abc"
166170

167171
# Zero value should always serialize for one-of

0 commit comments

Comments
 (0)