Skip to content

Commit

Permalink
[Feature] Store SOAP condition matrices as the dtype of their parameters (#335)
Browse files (browse the repository at this point in the history)

* Added parameter dtype support throughout the conditioner code

* Update pytorch_optimizer/optimizer/soap.py

* Update pytorch_optimizer/optimizer/soap.py

* Update pytorch_optimizer/optimizer/soap.py

---------

Co-authored-by: Kyle Vedder <[email protected]>
Co-authored-by: Hyeongchan Kim <[email protected]>
  • Loading branch information
3 people authored Feb 2, 2025
1 parent 4b439ab commit aca76b6
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions pytorch_optimizer/optimizer/soap.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def project(

for mat in state['Q']:
if len(mat) > 0:
grad = torch.tensordot(grad, mat.to(grad.dtype), dims=[[0], [0 if project_type == 'forward' else 1]])
grad = torch.tensordot(grad, mat, dims=[[0], [0 if project_type == 'forward' else 1]])
else:
grad = grad.permute([*list(range(1, len(grad.shape))), 0])

Expand All @@ -123,9 +123,11 @@ def get_orthogonal_matrix(mat: torch.Tensor) -> List[torch.Tensor]:
continue

try:
_, q = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device))
_, q = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device, dtype=m.dtype))
except Exception: # pragma: no cover
_, q = torch.linalg.eigh(m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device))
_, q = torch.linalg.eigh(
m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device, dtype=torch.float64)
)
q = q.to(m.dtype)

q = torch.flip(q, dims=[1])
Expand Down Expand Up @@ -156,7 +158,13 @@ def get_orthogonal_matrix_qr(self, state, max_precondition_dim: int = 10000, mer

power_iter = m @ o[:, sort_idx]

q, _ = torch.linalg.qr(power_iter)
# Compute QR decomposition
# We cast to float32 because:
# - torch.linalg.qr does not have support for types like bfloat16 as of PyTorch 2.5.1
# - the correctness / numerical stability of the Q orthogonalization is important for the stability
# of the optimizer
q, _ = torch.linalg.qr(power_iter.to(torch.float32))
q = q.to(power_iter.dtype)

matrices.append(q)

Expand Down Expand Up @@ -185,7 +193,7 @@ def init_pre_conditioner(
if not precondition_1d or grad.shape[0] > max_precondition_dim:
state['GG'].append([])
else:
state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device))
state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
else:
if merge_dims:
grad = grad.reshape(merge_small_dims(grad.size(), max_precondition_dim))
Expand All @@ -194,7 +202,7 @@ def init_pre_conditioner(
if sh > max_precondition_dim:
state['GG'].append([])
else:
state['GG'].append(torch.zeros(sh, sh, device=grad.device))
state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))

state['Q'] = None
state['precondition_frequency'] = precondition_frequency
Expand Down

0 comments on commit aca76b6

Please sign in to comment.