Skip to content

Commit df400fb

Browse files
author
Vladimir Vargas Calderón
committed
Add Givens orthogonal layer
1 parent d08b203 commit df400fb

7 files changed

Lines changed: 432 additions & 20 deletions

File tree

dwave/plugins/torch/nn/modules/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@
1414
#
1515

1616
from dwave.plugins.torch.nn.modules.linear import *
17+
from dwave.plugins.torch.nn.modules.orthogonal import *
1718
from dwave.plugins.torch.nn.modules.utils import *
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
# Copyright 2025 D-Wave
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from collections import deque
16+
17+
import torch
18+
import torch.nn as nn
19+
from einops import einsum
20+
21+
from dwave.plugins.torch.nn.modules.utils import store_config
22+
23+
__all__ = ["GivensRotation"]
24+
25+
26+
class _RoundRobinGivens(torch.autograd.Function):
    """Implements custom forward and backward passes for the parallel Givens-rotation
    algorithms of https://arxiv.org/abs/2106.00003

    .. note::
        We adopt the notation from the paper, but instead of using the rows of U to compute
        rotations, we follow the standard convention of using the columns of U. Since U is
        orthogonal, this does not affect the result.
    """

    @staticmethod
    def forward(ctx, angles: torch.Tensor, blocks: torch.Tensor, n: int) -> torch.Tensor:
        """Creates a rotation matrix in n dimensions using parallel Givens transformations by
        blocks.

        Args:
            ctx (context): Stores information for backward propagation.
            angles (torch.Tensor): A ``((n - 1) * n // 2,)`` shaped tensor containing all rotations
                between pairs of dimensions, laid out contiguously block by block.
            blocks (torch.Tensor): A ``(n - 1, n // 2, 2)`` shaped tensor (``(n, (n - 1) // 2, 2)``
                when ``n`` is odd) containing the indices that specify rotations between pairs of
                dimensions. Pairs within a block are disjoint, so their rotations are independent.
            n (int): Dimension of the space.

        Returns:
            torch.Tensor: The ``n x n`` rotation matrix.
        """
        # blocks has shape (n_blocks, block_size, 2) holding coordinate-pair indices.
        # Within a block the rotations act on disjoint pairs, hence commute, and can be
        # applied simultaneously as vectorized paired-column updates.
        U = torch.eye(n, device=angles.device, dtype=angles.dtype)
        block_size = n // 2
        idx_block = torch.arange(block_size, device=angles.device)
        B = blocks  # to keep the same notation as in the paper
        for b, block in enumerate(B):
            # angles stores the angles of consecutive blocks contiguously.
            angles_in_block = angles[idx_block + b * block_size]  # shape (block_size,)
            c = torch.cos(angles_in_block).unsqueeze(0)
            s = torch.sin(angles_in_block).unsqueeze(0)
            i_idx = block[:, 0]
            j_idx = block[:, 1]
            # Right-multiply U by the block's Givens product via paired column updates.
            r_i = c * U[:, i_idx] + s * U[:, j_idx]
            r_j = -s * U[:, i_idx] + c * U[:, j_idx]
            U[:, i_idx] = r_i
            U[:, j_idx] = r_j
        ctx.save_for_backward(angles, B, U)
        return U

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
        """Computes the vector-Jacobian product needed for backward propagation.

        Args:
            ctx (context): Contains information for backward propagation.
            grad_output (torch.Tensor): A tensor containing the partial derivatives of the loss
                with respect to the output of the forward pass, i.e., dL/dU.

        Returns:
            tuple[torch.Tensor, None, None]: The gradient of the loss with respect to the input
                angles. No calculation of gradients with respect to blocks or n is needed (cf.
                forward method), so None is returned for these.
        """
        angles, B, Ufwd_saved = ctx.saved_tensors
        # Work on copies: both Ufwd and M are updated in place below. ``grad_output.t()``
        # alone is a *view*, so writing to its columns would corrupt the incoming gradient
        # for any other consumer of grad_output — hence the clone.
        Ufwd = Ufwd_saved.clone()
        M = grad_output.t().clone()  # dL/dU transposed; grad_output is of shape (n, n)
        n = M.size(1)
        block_size = n // 2
        A = torch.zeros((block_size, n), device=angles.device, dtype=angles.dtype)
        grad_theta = torch.zeros_like(angles)
        idx_block = torch.arange(block_size, device=angles.device)
        for b, block in enumerate(B):
            i_idx = block[:, 0]
            j_idx = block[:, 1]
            angles_in_block = angles[idx_block + b * block_size]  # shape (block_size,)
            c = torch.cos(angles_in_block)
            s = torch.sin(angles_in_block)
            # Peel block b off the saved forward result: after this row update, Ufwd equals
            # the product of the rotations that FOLLOW block b.
            r_i = c.unsqueeze(1) * Ufwd[i_idx] + s.unsqueeze(1) * Ufwd[j_idx]
            r_j = -s.unsqueeze(1) * Ufwd[i_idx] + c.unsqueeze(1) * Ufwd[j_idx]
            Ufwd[i_idx] = r_i
            Ufwd[j_idx] = r_j
            # Fold block b into M: M now equals grad^T times the rotations UP TO and
            # including block b.
            r_i = c.unsqueeze(0) * M[:, i_idx] + s.unsqueeze(0) * M[:, j_idx]
            r_j = -s.unsqueeze(0) * M[:, i_idx] + c.unsqueeze(0) * M[:, j_idx]
            M[:, i_idx] = r_i
            M[:, j_idx] = r_j
            # Per-pair angle gradients via the paper's trace identity, evaluated pairwise.
            A[:] = M[:, j_idx].T * Ufwd[i_idx] - M[:, i_idx].T * Ufwd[j_idx]
            grad_theta[idx_block + b * block_size] = A.sum(dim=1)
        return grad_theta, None, None
112+
113+
114+
class GivensRotation(nn.Module):
    """An orthogonal layer implementing a rotation using a sequence of Givens rotations arranged
    in a round-robin fashion.

    Angles are arranged into blocks, where each block references rotations that can be applied in
    parallel because these rotations commute.

    Args:
        n (int): Dimension of the input and output space. Must be at least 2.
        bias (bool): If True, adds a learnable bias to the output. Default: True.

    Raises:
        ValueError: If ``n`` is not an integer greater than 1, or ``bias`` is not a boolean.
    """

    @store_config
    def __init__(self, n: int, bias: bool = True):
        super().__init__()
        if not isinstance(n, int) or n <= 1:
            raise ValueError(f"n must be an integer greater than 1, {n} was passed")
        if not isinstance(bias, bool):
            raise ValueError(f"bias must be a boolean, {bias} was passed")
        self.n = n
        # One learnable angle per unordered pair of coordinates.
        self.n_angles = n * (n - 1) // 2
        self.angles = nn.Parameter(torch.randn(self.n_angles))
        # Registered as a buffer (not a parameter): the pairing schedule is fixed, but it
        # must follow the module across .to()/.cuda() and be saved in state_dict.
        blocks_edges = self._get_blocks_edges(n)
        self.register_buffer("blocks", blocks_edges)
        if bias:
            self.bias = nn.Parameter(torch.zeros(n))
        else:
            self.register_parameter("bias", None)

    @staticmethod
    def _get_blocks_edges(n: int) -> torch.Tensor:
        """Uses the circle method for Round Robin pairing to create blocks of edges for parallel
        Givens rotations.

        A block is a list of pairs of indices indicating which coordinates to rotate together.
        Pairs in the same block can be rotated in parallel since they commute.

        Args:
            n (int): Dimension of the vector space onto which an orthogonal layer will be built.

        Returns:
            torch.Tensor: Blocks of edges for parallel Givens rotations stored in a tensor of
                shape ``(n - 1, n // 2, 2)``.

        .. note::
            If n is odd, a dummy dimension is added to make it even. When using the resulting
            blocks to build an orthogonal transformation, rotations involving the dummy dimension
            should be ignored.
        """
        is_odd = bool(n % 2 != 0)
        if is_odd:
            # The circle method requires an even number of nodes, so we add a dummy dimension;
            # the pairs involving this dimension are dropped below.
            n += 1

        def circle_method(sequence):
            # Fold the sequence in half and zip the halves to pair opposite nodes.
            seq_first_half = sequence[: len(sequence) // 2]
            seq_second_half = sequence[len(sequence) // 2 :][::-1]
            return list(zip(seq_first_half, seq_second_half))

        blocks = []
        sequence = list(range(n))
        sequence_deque = deque(sequence[1:])
        for _ in range(n - 1):
            pairs = circle_method(sequence)
            if is_odd:
                # Remove pairs involving the dummy dimension:
                pairs = [pair for pair in pairs if n - 1 not in pair]
            blocks.append(pairs)
            # Circle method: node 0 stays fixed, all other nodes rotate one slot per round.
            sequence_deque.rotate(1)
            sequence[1:] = list(sequence_deque)
        return torch.tensor(blocks, dtype=torch.long)

    def _create_rotation_matrix(self) -> torch.Tensor:
        """Computes the Givens rotation matrix via the custom autograd function."""
        return _RoundRobinGivens.apply(self.angles, self.blocks, self.n)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the Givens rotation to the input tensor ``x``.

        Args:
            x (torch.Tensor): Input tensor of shape ``(..., n)``.

        Returns:
            torch.Tensor: Rotated tensor of shape ``(..., n)``.
        """
        unitary = self._create_rotation_matrix()
        # Apply U to the last axis: out_o = sum_i x_i * U[o, i].
        rotated_x = einsum(x, unitary, "... i, o i -> ... o")
        if self.bias is not None:
            rotated_x += self.bias
        return rotated_x

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ dependencies = [
3232
"dimod",
3333
"dwave-system",
3434
"dwave-hybrid",
35+
"einops",
3536
]
3637

3738
[project.readme]
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
features:
3+
- |
4+
Add orthogonal rotation layer using Givens rotations.

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ torch==2.9.1
66
dimod==0.12.21
77
dwave-system==1.34.0
88
dwave-hybrid==0.6.14
9+
einops==0.8.1
910

1011
# Development requirements
1112
reno==4.1.0

tests/helper_models.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import torch
2+
import torch.nn as nn
3+
from einops import einsum
4+
5+
6+
class NaiveGivensRotationLayer(nn.Module):
7+
"""Naive implementation of a Givens rotation layer.
8+
9+
Sequentially applies all Givens rotations to implement an orthogonal transformation in an order
10+
provided by blocks, which are of shape (n_blocks, n/2, 2), and where usually each block contains
11+
pairs of indices such that no index appears more than once in a block. However, this
12+
implementation does not rely on that assumption, so that indeces can appear multiple times in a
13+
block; however, all pairs of indices must appear exactly once in the entire blocks tensor.
14+
15+
Args:
16+
in_features (int): Number of input features.
17+
out_features (int): Number of output features.
18+
bias (bool): If True, adds a learnable bias to the output. Default: True.
19+
20+
Note:
21+
This layer defines an nxn SO(n) rotation matrix, so in_features must be equal to
22+
out_features.
23+
"""
24+
25+
def __init__(self, in_features: int, out_features: int, bias: bool = True):
26+
super().__init__()
27+
assert in_features == out_features, (
28+
"This layer defines an nxn SO(n) rotation matrix, so in_features must be equal to "
29+
"out_features."
30+
)
31+
self.n = in_features
32+
self.angles = nn.Parameter(torch.randn(in_features * (in_features - 1) // 2))
33+
if bias:
34+
self.bias = nn.Parameter(torch.zeros(out_features))
35+
else:
36+
self.register_parameter("bias", None)
37+
38+
def _create_rotation_matrix(self, angles, blocks: torch.Tensor | None = None):
39+
"""Creates the rotation matrix from the Givens angles by applying the Givens rotations in
40+
order and sequentially, as specified by blocks.
41+
42+
Args:
43+
angles (torch.Tensor): Givens rotation angles.
44+
blocks (torch.Tensor | None, optional): Blocks specifying the order of rotations. If
45+
None, all possible pairs of dimensions will be shaped into (n-1, n/2, 2) to create
46+
the blocks. Defaults to None.
47+
48+
Returns:
49+
torch.Tensor: Rotation matrix.
50+
"""
51+
block_size = self.n // 2
52+
if blocks is None:
53+
# Create dummy blocks from triu indices:
54+
triu_indices = torch.triu_indices(self.n, self.n, offset=1)
55+
blocks = triu_indices.t().view(-1, block_size, 2)
56+
U = torch.eye(self.n, dtype=angles.dtype)
57+
for b, block in enumerate(blocks):
58+
for k in range(block_size):
59+
i = block[k, 0].item()
60+
j = block[k, 1].item()
61+
angle = angles[b * block_size + k]
62+
c = torch.cos(angle)
63+
s = torch.sin(angle)
64+
Ge = torch.eye(self.n, dtype=angles.dtype)
65+
Ge[i, i] = c
66+
Ge[j, j] = c
67+
Ge[i, j] = -s
68+
Ge[j, i] = s
69+
# Explicit Givens rotation
70+
U = U @ Ge
71+
return U
72+
73+
def forward(self, x: torch.Tensor, blocks: torch.Tensor) -> torch.Tensor:
74+
"""Applies the Givens rotation to the input tensor ``x``.
75+
76+
Args:
77+
x (torch.Tensor): Input tensor of shape (..., n).
78+
blocks (torch.Tensor): Blocks specifying the order of rotations.
79+
80+
Returns:
81+
torch.Tensor: Rotated tensor of shape (..., n).
82+
"""
83+
W = self._create_rotation_matrix(self.angles, blocks)
84+
x = einsum(x, W, "... i, o i -> ... o")
85+
if self.bias is not None:
86+
x = x + self.bias
87+
return x

0 commit comments

Comments
 (0)