Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion parity_tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
A parity tensor package.
"""

__all__ = ["__version__"]
__all__ = ["__version__", "ParityTensor"]

from .version import __version__
from .parity_tensor import ParityTensor
173 changes: 173 additions & 0 deletions parity_tensor/parity_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
"""
A parity tensor class.
"""

from __future__ import annotations

__all__ = ["ParityTensor"]

import dataclasses
import typing
import torch


@dataclasses.dataclass
class ParityTensor:
"""
A parity tensor class, which stores a tensor along with information about its edges.
Each dimension of the tensor is composed of an even and an odd part, represented as a pair of integers.
"""

edges: tuple[tuple[int, int], ...]
tensor: torch.Tensor

def __post_init__(self) -> None:
assert len(self.edges) == self.tensor.dim(), f"Edges length ({len(self.edges)}) must match tensor dimensions ({self.tensor.dim()})."
for dim, (even, odd) in zip(self.tensor.shape, self.edges):
assert even >= 0 and odd >= 0 and dim == even + odd, f"Dimension {dim} must equal sum of even ({even}) and odd ({odd}) parts, and both must be non-negative."
Comment on lines +25 to +27
Copy link

Copilot AI Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using assert statements for input validation in production code is not recommended as they can be disabled with the -O flag. Consider raising a ValueError instead for more robust error handling.

Suggested change
assert len(self.edges) == self.tensor.dim(), f"Edges length ({len(self.edges)}) must match tensor dimensions ({self.tensor.dim()})."
for dim, (even, odd) in zip(self.tensor.shape, self.edges):
assert even >= 0 and odd >= 0 and dim == even + odd, f"Dimension {dim} must equal sum of even ({even}) and odd ({odd}) parts, and both must be non-negative."
if len(self.edges) != self.tensor.dim():
raise ValueError(f"Edges length ({len(self.edges)}) must match tensor dimensions ({self.tensor.dim()}).")
for dim, (even, odd) in zip(self.tensor.shape, self.edges):
if even < 0 or odd < 0 or dim != even + odd:
raise ValueError(f"Dimension {dim} must equal sum of even ({even}) and odd ({odd}) parts, and both must be non-negative.")

Copilot uses AI. Check for mistakes.
Comment on lines +25 to +27
Copy link

Copilot AI Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using assert statements for input validation in production code is not recommended as they can be disabled with the -O flag. Consider raising a ValueError instead for more robust error handling.

Suggested change
assert len(self.edges) == self.tensor.dim(), f"Edges length ({len(self.edges)}) must match tensor dimensions ({self.tensor.dim()})."
for dim, (even, odd) in zip(self.tensor.shape, self.edges):
assert even >= 0 and odd >= 0 and dim == even + odd, f"Dimension {dim} must equal sum of even ({even}) and odd ({odd}) parts, and both must be non-negative."
if len(self.edges) != self.tensor.dim():
raise ValueError(f"Edges length ({len(self.edges)}) must match tensor dimensions ({self.tensor.dim()}).")
for dim, (even, odd) in zip(self.tensor.shape, self.edges):
if even < 0 or odd < 0 or dim != even + odd:
raise ValueError(f"Dimension {dim} must equal sum of even ({even}) and odd ({odd}) parts, and both must be non-negative.")

Copilot uses AI. Check for mistakes.

def _validate_edge_compatibility(self, other: ParityTensor) -> None:
"""
Validate that the edges of two ParityTensor instances are compatible for arithmetic operations.
"""
assert self.edges == other.edges, f"Edges must match for arithmetic operations. Got {self.edges} and {other.edges}."
Copy link

Copilot AI Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using assert statements for input validation in production code is not recommended as they can be disabled with the -O flag. Consider raising a ValueError instead for more robust error handling.

Suggested change
assert self.edges == other.edges, f"Edges must match for arithmetic operations. Got {self.edges} and {other.edges}."
if self.edges != other.edges:
raise ValueError(f"Edges must match for arithmetic operations. Got {self.edges} and {other.edges}.")

Copilot uses AI. Check for mistakes.

def __pos__(self) -> ParityTensor:
return ParityTensor(
edges=self.edges,
tensor=+self.tensor,
)

def __neg__(self) -> ParityTensor:
return ParityTensor(
edges=self.edges,
tensor=-self.tensor,
)

def __add__(self, other: typing.Any) -> ParityTensor:
if isinstance(other, ParityTensor):
self._validate_edge_compatibility(other)
return ParityTensor(
edges=self.edges,
tensor=self.tensor + other.tensor,
)
Copy link

Copilot AI Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No input validation is performed on the 'other' parameter before attempting tensor operations. This could lead to unexpected behavior or errors if 'other' is not a valid operand for tensor arithmetic.

Suggested change
)
)
if not isinstance(other, (int, float, torch.Tensor)):
raise TypeError(f"Unsupported operand type(s) for +: 'ParityTensor' and '{type(other).__name__}'")

Copilot uses AI. Check for mistakes.
result = self.tensor + other
if isinstance(result, torch.Tensor):
return ParityTensor(
edges=self.edges,
tensor=result,
)
return NotImplemented

def __radd__(self, other: typing.Any) -> ParityTensor:
result = other + self.tensor
if isinstance(result, torch.Tensor):
return ParityTensor(
edges=self.edges,
tensor=result,
)
return NotImplemented

def __iadd__(self, other: typing.Any) -> ParityTensor:
if isinstance(other, ParityTensor):
self._validate_edge_compatibility(other)
self.tensor += other.tensor
else:
self.tensor += other
return self

def __sub__(self, other: typing.Any) -> ParityTensor:
if isinstance(other, ParityTensor):
self._validate_edge_compatibility(other)
return ParityTensor(
edges=self.edges,
tensor=self.tensor - other.tensor,
)
result = self.tensor - other
if isinstance(result, torch.Tensor):
return ParityTensor(
edges=self.edges,
tensor=result,
)
return NotImplemented

def __rsub__(self, other: typing.Any) -> ParityTensor:
result = other - self.tensor
if isinstance(result, torch.Tensor):
return ParityTensor(
edges=self.edges,
tensor=result,
)
return NotImplemented

def __isub__(self, other: typing.Any) -> ParityTensor:
if isinstance(other, ParityTensor):
self._validate_edge_compatibility(other)
self.tensor -= other.tensor
else:
self.tensor -= other
return self

def __mul__(self, other: typing.Any) -> ParityTensor:
if isinstance(other, ParityTensor):
self._validate_edge_compatibility(other)
Copy link

Copilot AI Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Element-wise multiplication between ParityTensors with the same edges may not preserve the parity structure. Consider whether multiplication should validate edge compatibility or if it should have different behavior than addition/subtraction.

Suggested change
self._validate_edge_compatibility(other)
self._validate_edge_compatibility(other)
self._validate_parity_structure(other)

Copilot uses AI. Check for mistakes.
return ParityTensor(
edges=self.edges,
tensor=self.tensor * other.tensor,
)
result = self.tensor * other
if isinstance(result, torch.Tensor):
return ParityTensor(
edges=self.edges,
tensor=result,
)
return NotImplemented

def __rmul__(self, other: typing.Any) -> ParityTensor:
result = other * self.tensor
if isinstance(result, torch.Tensor):
return ParityTensor(
edges=self.edges,
tensor=result,
)
return NotImplemented

def __imul__(self, other: typing.Any) -> ParityTensor:
if isinstance(other, ParityTensor):
self._validate_edge_compatibility(other)
self.tensor *= other.tensor
else:
self.tensor *= other
return self

def __truediv__(self, other: typing.Any) -> ParityTensor:
if isinstance(other, ParityTensor):
self._validate_edge_compatibility(other)
return ParityTensor(
edges=self.edges,
tensor=self.tensor / other.tensor,
)
Comment on lines +146 to +149
Copy link

Copilot AI Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Element-wise division between ParityTensors with the same edges may not preserve the parity structure. Consider whether division should validate edge compatibility or if it should have different behavior than addition/subtraction.

Suggested change
return ParityTensor(
edges=self.edges,
tensor=self.tensor / other.tensor,
)
result = ParityTensor(
edges=self.edges,
tensor=self.tensor / other.tensor,
)
result._validate_parity_preservation()
return result

Copilot uses AI. Check for mistakes.
result = self.tensor / other
if isinstance(result, torch.Tensor):
return ParityTensor(
edges=self.edges,
tensor=result,
)
return NotImplemented

def __rtruediv__(self, other: typing.Any) -> ParityTensor:
result = other / self.tensor
if isinstance(result, torch.Tensor):
return ParityTensor(
edges=self.edges,
tensor=result,
)
return NotImplemented

def __itruediv__(self, other: typing.Any) -> ParityTensor:
if isinstance(other, ParityTensor):
self._validate_edge_compatibility(other)
self.tensor /= other.tensor
else:
self.tensor /= other
return self
Loading