Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions qmb/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from . import vmc as _ # type: ignore[no-redef]
from . import imag as _ # type: ignore[no-redef]
from . import rldiag as _ # type: ignore[no-redef]
from . import rlim as _ # type: ignore[no-redef]
from . import precompile as _ # type: ignore[no-redef]
from . import list_loss as _ # type: ignore[no-redef]
from . import chop_imag as _ # type: ignore[no-redef]
Expand Down
14 changes: 11 additions & 3 deletions qmb/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,22 @@ def __init__(self, dim_input: int, dim_output: int, hidden_size: tuple[int, ...]

dimensions: list[int] = [dim_input] + list(hidden_size) + [dim_output]
linears: list[torch.nn.Module] = [select_linear_layer(i, j) for i, j in itertools.pairwise(dimensions)]
modules: list[torch.nn.Module] = [module for linear in linears for module in (linear, torch.nn.SiLU())][:-1]
self.model: torch.nn.Module = torch.nn.Sequential(*modules)
self.layers: torch.nn.ModuleList = torch.nn.ModuleList(linears)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the MLP.
"""
return self.model(x)
for index, layer in enumerate(self.layers):
y = layer(x)
if x.shape != y.shape:
x = y
else:
x = x + y
if index != len(self.layers) - 1:
x = torch.nn.functional.normalize(x, dim=-1)
x = torch.nn.functional.silu(x)
return x


class WaveFunctionElectronUpDown(torch.nn.Module):
Expand Down
150 changes: 150 additions & 0 deletions qmb/rlim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""
This file implements a reinforcement learning based imaginary time evolution algorithm.
"""

import logging
import typing
import dataclasses
import torch
import torch.utils.tensorboard
import tyro
from .common import CommonConfig
from .subcommand_dict import subcommand_dict
from .optimizer import initialize_optimizer


@dataclasses.dataclass
class RlimConfig:
"""
The reinforcement learning based imaginary time evolution algorithm.
"""

# pylint: disable=too-many-instance-attributes

common: typing.Annotated[CommonConfig, tyro.conf.OmitArgPrefixes]

# The sampling count
sampling_count: typing.Annotated[int, tyro.conf.arg(aliases=["-n"])] = 4000
# The number of relative configurations to be used in energy calculation
relative_count: typing.Annotated[int, tyro.conf.arg(aliases=["-c"])] = 40000
# The learning rate for the local optimizer
learning_rate: typing.Annotated[float, tyro.conf.arg(aliases=["-r"])] = 1e-3
# The learning rate for the imaginary time evolution
evolution_time: typing.Annotated[float, tyro.conf.arg(aliases=["-t"])] = 1e-3
# The number of steps for the local optimizer
local_step: typing.Annotated[int, tyro.conf.arg(aliases=["-s"])] = 32
# The dropout of the loss function
dropout: typing.Annotated[float, tyro.conf.arg(aliases=["-d"])] = 0.5

def main(self) -> None:
"""
The main function for the RLIM optimization.
"""
# pylint: disable=too-many-statements
# pylint: disable=too-many-locals

model, network, data = self.common.main()
ref_network = network
Copy link

Copilot AI Jul 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assigning ref_network to network creates an alias—updates to network will also affect ref_network. Consider using copy.deepcopy(network) to freeze the reference model.

Suggested change
ref_network = network
ref_network = copy.deepcopy(network)

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Jul 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assigning ref_network = network creates a reference to the same object rather than an independent copy. This means both networks will be updated simultaneously during optimization, which may not be the intended behavior for a reference network that should remain stable.

Suggested change
ref_network = network
ref_network = type(network)() # Create a new instance of the same model class
ref_network.load_state_dict(network.state_dict()) # Copy the parameters from the original network

Copilot uses AI. Check for mistakes.

logging.info(
"Arguments Summary: "
"Sampling Count: %d, "
"Relative Count: %d, "
"Learning Rate: %.10f, "
"Evolution Time: %.10f, "
"Local Steps: %d, "
"Dropout: %.2f",
self.sampling_count,
self.relative_count,
self.learning_rate,
self.evolution_time,
self.local_step,
self.dropout,
)

optimizer = initialize_optimizer(
network.parameters(),
use_lbfgs=False,
learning_rate=self.learning_rate,
state_dict=data.get("optimizer"),
)

if "rlim" not in data:
data["rlim"] = {"global": 0, "local": 0}

writer = torch.utils.tensorboard.SummaryWriter(log_dir=self.common.folder()) # type: ignore[no-untyped-call]

while True:
logging.info("Starting a new optimization cycle")

logging.info("Sampling configurations")
configs_i, psi_i, _, _ = network.generate_unique(self.sampling_count)
ref_configs_i, ref_psi_i, _, _ = ref_network.generate_unique(self.sampling_count)
logging.info("Sampling completed, unique configurations count: %d, reference unique configurations count: %d", len(configs_i), len(ref_configs_i))

logging.info("Calculating relative configurations")
if self.relative_count <= len(configs_i):
configs_src = configs_i
configs_dst = configs_i
else:
configs_src = configs_i
configs_dst = torch.cat([configs_i, model.find_relative(configs_i, psi_i, self.relative_count - len(configs_i))])
logging.info("Relative configurations calculated, count: %d", len(configs_dst))
if self.relative_count <= len(ref_configs_i):
ref_configs_src = ref_configs_i
ref_configs_dst = ref_configs_i
else:
ref_configs_src = ref_configs_i
ref_configs_dst = torch.cat([ref_configs_i, model.find_relative(ref_configs_i, ref_psi_i, self.relative_count - len(ref_configs_i))])
logging.info("Reference relative configurations calculated, count: %d", len(ref_configs_dst))

def closure() -> torch.Tensor:
# Optimizing loss
optimizer.zero_grad()
psi_src = network(configs_src) # psi s
ref_psi_src = network(ref_configs_src) # psi r
Copy link

Copilot AI Jul 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reference wavefunction is computed with the main network instead of ref_network. Update this (and similar calls) to use ref_network so the reference model remains unchanged.

Suggested change
ref_psi_src = network(ref_configs_src) # psi r
ref_psi_src = ref_network(ref_configs_src) # psi r

Copilot uses AI. Check for mistakes.
with torch.no_grad():
psi_dst = network(configs_dst) # psi s'
ref_psi_dst = network(ref_configs_dst) # psi r'
hamiltonian_psi_dst = model.apply_within(configs_dst, psi_dst, configs_src) # H ss' psi s'
ref_hamiltonian_psi_dst = model.apply_within(ref_configs_dst, ref_psi_dst, ref_configs_src) # H rr' psi r'
a = torch.outer(psi_src.detach(), ref_psi_src) - torch.outer(psi_src, ref_psi_src.detach())
b = torch.outer(hamiltonian_psi_dst, ref_psi_src) - torch.outer(psi_src, ref_hamiltonian_psi_dst)
diff = torch.nn.functional.dropout(torch.view_as_real(a - self.evolution_time * b).abs(), p=self.dropout).flatten()
loss = diff @ diff
loss.backward() # type: ignore[no-untyped-call]
# Calculate energy
with torch.no_grad():
num = psi_src.conj() @ hamiltonian_psi_dst
den = psi_src.conj() @ psi_src
energy = (num / den).real
loss.energy = energy # type: ignore[attr-defined]
return loss

logging.info("Starting local optimization process")

for i in range(self.local_step):
loss: torch.Tensor = optimizer.step(closure) # type: ignore[assignment,arg-type]
energy: float = loss.energy # type: ignore[attr-defined]
Comment on lines +121 to +128
Copy link

Copilot AI Jul 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dynamically adding attributes to tensor objects is not a clean practice. Consider returning a tuple or using a dataclass to pass both loss and energy values instead of monkey-patching the tensor object.

Suggested change
loss.energy = energy # type: ignore[attr-defined]
return loss
logging.info("Starting local optimization process")
for i in range(self.local_step):
loss: torch.Tensor = optimizer.step(closure) # type: ignore[assignment,arg-type]
energy: float = loss.energy # type: ignore[attr-defined]
return LossEnergy(loss=loss, energy=energy)
logging.info("Starting local optimization process")
for i in range(self.local_step):
loss_energy: LossEnergy = optimizer.step(closure) # type: ignore[assignment,arg-type]
loss, energy = loss_energy.loss, loss_energy.energy

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Jul 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accessing the dynamically added energy attribute requires type ignore comments and makes the code fragile. This is a consequence of the monkey-patching approach on line 117.

Copilot uses AI. Check for mistakes.
logging.info("Local optimization in progress, step: %d, loss: %.10f, energy: %.10f, ref energy: %.10f, energy error: %.10f", i, loss.item(), energy, model.ref_energy,
energy - model.ref_energy)
writer.add_scalar("rlim/energy", energy, data["rlim"]["local"]) # type: ignore[no-untyped-call]
writer.add_scalar("rlim/error", energy - model.ref_energy, data["rlim"]["local"]) # type: ignore[no-untyped-call]
writer.add_scalar("rlim/loss", loss, data["rlim"]["local"]) # type: ignore[no-untyped-call]
Copy link

Copilot AI Jul 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing a Tensor directly to add_scalar may cause type issues—use loss.item() to log a Python float.

Suggested change
writer.add_scalar("rlim/loss", loss, data["rlim"]["local"]) # type: ignore[no-untyped-call]
writer.add_scalar("rlim/loss", loss.item(), data["rlim"]["local"]) # type: ignore[no-untyped-call]

Copilot uses AI. Check for mistakes.
data["rlim"]["local"] += 1

logging.info("Local optimization process completed")

writer.flush() # type: ignore[no-untyped-call]

logging.info("Saving model checkpoint")
data["rlim"]["global"] += 1
data["network"] = network.state_dict()
data["optimizer"] = optimizer.state_dict()
self.common.save(data, data["rlim"]["global"])
logging.info("Checkpoint successfully saved")

logging.info("Current optimization cycle completed")


subcommand_dict["rlim"] = RlimConfig