Add a reinforcement learning based imaginary time evolution algorithm. #65
base: main
The PR adds one new file (150 lines); the portion shown in the diff:

```python
"""
This file implements a reinforcement learning based imaginary time evolution algorithm.
"""

import logging
import typing
import dataclasses
import torch
import torch.utils.tensorboard
import tyro
from .common import CommonConfig
from .subcommand_dict import subcommand_dict
from .optimizer import initialize_optimizer


@dataclasses.dataclass
class RlimConfig:
    """
    The reinforcement learning based imaginary time evolution algorithm.
    """

    # pylint: disable=too-many-instance-attributes

    common: typing.Annotated[CommonConfig, tyro.conf.OmitArgPrefixes]

    # The sampling count
    sampling_count: typing.Annotated[int, tyro.conf.arg(aliases=["-n"])] = 4000
    # The number of relative configurations to be used in energy calculation
    relative_count: typing.Annotated[int, tyro.conf.arg(aliases=["-c"])] = 40000
    # The learning rate for the local optimizer
    learning_rate: typing.Annotated[float, tyro.conf.arg(aliases=["-r"])] = 1e-3
    # The time step for the imaginary time evolution
    evolution_time: typing.Annotated[float, tyro.conf.arg(aliases=["-t"])] = 1e-3
    # The number of steps for the local optimizer
    local_step: typing.Annotated[int, tyro.conf.arg(aliases=["-s"])] = 32
    # The dropout of the loss function
    dropout: typing.Annotated[float, tyro.conf.arg(aliases=["-d"])] = 0.5

    def main(self) -> None:
        """
        The main function for the RLIM optimization.
        """
        # pylint: disable=too-many-statements
        # pylint: disable=too-many-locals

        model, network, data = self.common.main()
        ref_network = network
```
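The review below flags the last line above: `ref_network = network` only creates a second name for the same module. A minimal sketch of the fix, using a `torch.nn.Linear` as a stand-in for the PR's actual network (which comes from `self.common.main()`):

```python
import copy

import torch

# Stand-in for the network returned by self.common.main(); the real
# architecture is not shown in this diff.
network = torch.nn.Linear(4, 1)

# Freeze an independent reference copy so later optimizer updates to
# `network` do not leak into the reference wavefunction.
ref_network = copy.deepcopy(network)
for p in ref_network.parameters():
    p.requires_grad_(False)  # the reference is only evaluated, never trained

with torch.no_grad():
    network.weight.add_(1.0)  # simulate an optimizer update

# The deep-copied parameters are unaffected by the update:
print(torch.equal(network.weight, ref_network.weight))  # False
```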
Suggested change:

```diff
- ref_network = network
+ ref_network = type(network)()  # Create a new instance of the same model class
+ ref_network.load_state_dict(network.state_dict())  # Copy the parameters from the original network
```
**Copilot AI** commented on Jul 17, 2025:

> The reference wavefunction is computed with the main network instead of `ref_network`. Update this (and similar calls) to use `ref_network` so the reference model remains unchanged.
Suggested change:

```diff
- ref_psi_src = network(ref_configs_src)  # psi r
+ ref_psi_src = ref_network(ref_configs_src)  # psi r
```
**Copilot AI** commented on Jul 17, 2025:

> Dynamically adding attributes to tensor objects is not a clean practice. Consider returning a tuple or using a dataclass to pass both loss and energy values instead of monkey-patching the tensor object.
Suggested change:

```diff
- loss.energy = energy  # type: ignore[attr-defined]
- return loss
+ return LossEnergy(loss=loss, energy=energy)
  logging.info("Starting local optimization process")
  for i in range(self.local_step):
-     loss: torch.Tensor = optimizer.step(closure)  # type: ignore[assignment,arg-type]
-     energy: float = loss.energy  # type: ignore[attr-defined]
+     loss_energy: LossEnergy = optimizer.step(closure)  # type: ignore[assignment,arg-type]
+     loss, energy = loss_energy.loss, loss_energy.energy
```
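The suggested `LossEnergy` container can be sketched in isolation; plain floats stand in for the tensors so the sketch is dependency-free, and the values are placeholders:

```python
import dataclasses


@dataclasses.dataclass
class LossEnergy:
    """Typed container returned by the closure instead of monkey-patching the loss tensor."""

    loss: float  # would be a torch.Tensor in the real code
    energy: float


def closure() -> LossEnergy:
    loss, energy = 0.25, -1.5  # placeholder values for illustration
    return LossEnergy(loss=loss, energy=energy)


result = closure()
print(result.loss, result.energy)  # 0.25 -1.5, no `type: ignore` needed
```

One caveat worth checking before adopting this: first-order torch optimizers (SGD, Adam) pass the closure's return value through `step` unchanged, but line-search optimizers such as LBFGS call the closure internally and expect a scalar loss tensor back, so the dataclass return would need unwrapping there.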
**Copilot AI** commented on Jul 17, 2025:

> Accessing the dynamically added `energy` attribute requires `type: ignore` comments and makes the code fragile. This is a consequence of the monkey-patching approach on line 117.
**Copilot AI** commented on Jul 17, 2025:

> Passing a Tensor directly to `add_scalar` may cause type issues. Use `loss.item()` to log a Python float.
Suggested change:

```diff
- writer.add_scalar("rlim/loss", loss, data["rlim"]["local"])  # type: ignore[no-untyped-call]
+ writer.add_scalar("rlim/loss", loss.item(), data["rlim"]["local"])  # type: ignore[no-untyped-call]
```
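The `.item()` conversion the suggestion relies on is a one-liner: it extracts a plain Python number from a zero-dimensional tensor, which is what scalar loggers expect.

```python
import torch

loss = torch.tensor(0.125)  # a 0-dim tensor, as a scalar loss would be
value = loss.item()         # extract a plain Python float
print(type(value).__name__, value)  # float 0.125
```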
> Assigning `network` to `ref_network` creates an alias: updates to `network` will also affect `ref_network`. Consider using `copy.deepcopy(network)` to freeze the reference model.
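The aliasing problem this comment describes can be demonstrated without torch; `TinyNet` below is a hypothetical stand-in for the PR's module, used only to show the semantics (for a real `torch.nn.Module`, `copy.deepcopy` copies parameters and buffers the same way):

```python
import copy


class TinyNet:
    """Minimal stand-in for a torch.nn.Module, just to show aliasing semantics."""

    def __init__(self) -> None:
        self.weight = 1.0


network = TinyNet()
alias = network                  # same object under a second name
frozen = copy.deepcopy(network)  # independent copy with its own state

network.weight = 2.0             # simulate a training update
print(alias.weight, frozen.weight)  # 2.0 1.0
```

The alias observes every update made through `network`, while the deep copy keeps the state it had at copy time, which is exactly the behavior a frozen reference model needs.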