-
Notifications
You must be signed in to change notification settings - Fork 0
Add the basic parity tensor class and its arithmetic operators. #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6bf1100
40ba62a
4ee3380
93e2bb8
bee4047
6453b41
ae6aaf9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,173 @@ | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| A parity tensor class. | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| __all__ = ["ParityTensor"] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| import dataclasses | ||||||||||||||||||||||
| import typing | ||||||||||||||||||||||
| import torch | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @dataclasses.dataclass | ||||||||||||||||||||||
| class ParityTensor: | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| A parity tensor class, which stores a tensor along with information about its edges. | ||||||||||||||||||||||
| Each dimension of the tensor is composed of an even and an odd part, represented as a pair of integers. | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| edges: tuple[tuple[int, int], ...] | ||||||||||||||||||||||
| tensor: torch.Tensor | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def __post_init__(self) -> None: | ||||||||||||||||||||||
| assert len(self.edges) == self.tensor.dim(), f"Edges length ({len(self.edges)}) must match tensor dimensions ({self.tensor.dim()})." | ||||||||||||||||||||||
| for dim, (even, odd) in zip(self.tensor.shape, self.edges): | ||||||||||||||||||||||
| assert even >= 0 and odd >= 0 and dim == even + odd, f"Dimension {dim} must equal sum of even ({even}) and odd ({odd}) parts, and both must be non-negative." | ||||||||||||||||||||||
|
Comment on lines
+25
to
+27
|
||||||||||||||||||||||
| assert len(self.edges) == self.tensor.dim(), f"Edges length ({len(self.edges)}) must match tensor dimensions ({self.tensor.dim()})." | |
| for dim, (even, odd) in zip(self.tensor.shape, self.edges): | |
| assert even >= 0 and odd >= 0 and dim == even + odd, f"Dimension {dim} must equal sum of even ({even}) and odd ({odd}) parts, and both must be non-negative." | |
| if len(self.edges) != self.tensor.dim(): | |
| raise ValueError(f"Edges length ({len(self.edges)}) must match tensor dimensions ({self.tensor.dim()}).") | |
| for dim, (even, odd) in zip(self.tensor.shape, self.edges): | |
| if even < 0 or odd < 0 or dim != even + odd: | |
| raise ValueError(f"Dimension {dim} must equal sum of even ({even}) and odd ({odd}) parts, and both must be non-negative.") |
Copilot
AI
Jul 30, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using assert statements for input validation in production code is not recommended as they can be disabled with the -O flag. Consider raising a ValueError instead for more robust error handling.
| assert self.edges == other.edges, f"Edges must match for arithmetic operations. Got {self.edges} and {other.edges}." | |
| if self.edges != other.edges: | |
| raise ValueError(f"Edges must match for arithmetic operations. Got {self.edges} and {other.edges}.") |
Copilot
AI
Jul 30, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No input validation is performed on the 'other' parameter before attempting tensor operations. This could lead to unexpected behavior or errors if 'other' is not a valid operand for tensor arithmetic.
| ) | |
| ) | |
| if not isinstance(other, (int, float, torch.Tensor)): | |
| raise TypeError(f"Unsupported operand type(s) for +: 'ParityTensor' and '{type(other).__name__}'") |
Copilot
AI
Jul 30, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Element-wise multiplication between ParityTensors with the same edges may not preserve the parity structure. Consider whether multiplication should validate edge compatibility or if it should have different behavior than addition/subtraction.
| self._validate_edge_compatibility(other) | |
| self._validate_edge_compatibility(other) | |
| self._validate_parity_structure(other) |
Copilot
AI
Jul 30, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Element-wise division between ParityTensors with the same edges may not preserve the parity structure. Consider whether division should validate edge compatibility or if it should have different behavior than addition/subtraction.
| return ParityTensor( | |
| edges=self.edges, | |
| tensor=self.tensor / other.tensor, | |
| ) | |
| result = ParityTensor( | |
| edges=self.edges, | |
| tensor=self.tensor / other.tensor, | |
| ) | |
| result._validate_parity_preservation() | |
| return result |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using assert statements for input validation in production code is not recommended as they can be disabled with the -O flag. Consider raising a ValueError instead for more robust error handling.