-
Notifications
You must be signed in to change notification settings - Fork 0
Dev/add rbm network #35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
5f08017 to
fbd4a73
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
Adds a Restricted Boltzmann Machine (RBM) implementation for modeling wavefunction magnitudes and a WaveFunctionNormal wrapper that combines the RBM (for magnitude) with an MLP (for phase) and provides sampling utilities.
- Introduce
RBMclass withforwardprobability estimation andsample(Gibbs sampling). - Define
WaveFunctionNormalto compute complex amplitudes via RBM and MLP, plusgenerate_conffor sample generation. - Utilize
pack_int/unpack_intfor integer-based configuration encoding.
Comments suppressed due to low confidence (4)
qmb/rbm.py:88
- The parameter name
mpl_hidden_sizeappears to be a typo; it should match the attributemlp_hidden_sizefor consistency.
mpl_hidden_size: tuple[int, ...],
qmb/rbm.py:48
- There are currently no unit tests for
RBM.sample; consider adding tests to verify that Gibbs sampling converges to the expected distribution over several iterations.
def sample(self, v: torch.Tensor, k: int = 1) -> torch.Tensor:
qmb/rbm.py:42
- [nitpick] The variable name
mid2is not descriptive; consider renaming it to something likehidden_pre_activationorhidden_linear_outputfor clarity.
mid2 = torch.nn.functional.linear(v, self.weights.T, self.hidden_bias) # pylint: disable=E1102
qmb/rbm.py:65
- [nitpick] The intermediate name
midhis ambiguous; renaming it tohidden_pre_activationwould improve readability.
midh = torch.nn.functional.linear(v, self.weights.T, self.hidden_bias) # pylint: disable=E1102
qmb/rbm.py
Outdated
| self.visible_dim = visible_dim | ||
| self.hidden_dim = hidden_dim | ||
| self.weights = torch.nn.Parameter(torch.Tensor(self.visible_dim, self.hidden_dim)) | ||
| init_range = 1.0 / torch.sqrt(torch.tensor(self.visible_dim).cuda()) |
Copilot
AI
Jun 16, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Avoid hard-coding .cuda() to keep the code device-agnostic; instead, derive the device from the model parameters (e.g., torch.tensor(self.visible_dim, device=self.weights.device)).
| init_range = 1.0 / torch.sqrt(torch.tensor(self.visible_dim).cuda()) | |
| init_range = 1.0 / torch.sqrt(torch.tensor(self.visible_dim, device=self.weights.device)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
你不应该在这里.cuda(),但这个copilot说的也不对,你应该直接
init_range = 1.0 / torch.sqrt(torch.tensor(self.visible_dim))
因为模型移动到不同的device(如gpu),是在model创建之后进行的。
qmb/rbm.py
Outdated
| ) -> None: | ||
| super().__init__() | ||
| self.sites: int = sites | ||
| assert physical_dim == 2 # ?? |
Copilot
AI
Jun 16, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] This bare assert lacks a message and the trailing # ?? is unclear; consider raising a ValueError with a descriptive message or adding an assertion message.
| assert physical_dim == 2 # ?? | |
| assert physical_dim == 2, "physical_dim must be 2 to ensure compatibility with the RBM implementation." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不用管他
| super().__init__() | ||
| self.sites: int = sites | ||
| assert physical_dim == 2 # ?? | ||
| assert is_complex == True # pylint: disable=singleton-comparison |
Copilot
AI
Jun 16, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Prefer assert is_complex over comparing to True to simplify the expression and satisfy linters without disabling them.
| assert is_complex == True # pylint: disable=singleton-comparison | |
| assert is_complex |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不用管他
qmb/rbm.py
Outdated
| self.dummy_param = torch.nn.Parameter(torch.empty(0)) | ||
|
|
||
| @property | ||
| # A grammar sugar to determine device |
Copilot
AI
Jun 16, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] Minor typo: change "grammar sugar" to the more standard term "syntactic sugar" for clarity.
| # A grammar sugar to determine device | |
| # A syntactic sugar to determine device |
qmb/rbm.py
Outdated
| self.visible_dim = visible_dim | ||
| self.hidden_dim = hidden_dim | ||
| self.weights = torch.nn.Parameter(torch.Tensor(self.visible_dim, self.hidden_dim)) | ||
| init_range = 1.0 / torch.sqrt(torch.tensor(self.visible_dim).cuda()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
你不应该在这里.cuda(),但这个copilot说的也不对,你应该直接
init_range = 1.0 / torch.sqrt(torch.tensor(self.visible_dim))
因为模型移动到不同的device(如gpu),是在model创建之后进行的。
qmb/rbm.py
Outdated
| super().__init__() | ||
| self.visible_dim = visible_dim | ||
| self.hidden_dim = hidden_dim | ||
| self.weights = torch.nn.Parameter(torch.Tensor(self.visible_dim, self.hidden_dim)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不要用torch.Tensor 这个东西, see: https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor
qmb/rbm.py
Outdated
| Probabilities. | ||
| """ | ||
| e1 = (v @ self.visible_bias).view(v.size()[:-1]) | ||
| mid2 = torch.nn.functional.linear(v, self.weights.T, self.hidden_bias) # pylint: disable=E1102 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
给disable规则的具体名称,不要代号。
qmb/rbm.py
Outdated
| ) -> None: | ||
| super().__init__() | ||
| self.sites: int = sites | ||
| assert physical_dim == 2 # ?? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不用管他
| super().__init__() | ||
| self.sites: int = sites | ||
| assert physical_dim == 2 # ?? | ||
| assert is_complex == True # pylint: disable=singleton-comparison |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不用管他
qmb/rbm.py
Outdated
| self.mlp_hidden_size: tuple[int, ...] = mpl_hidden_size | ||
|
|
||
| # Build Networks | ||
| self.probability: RBM = RBM(self.sites, self.rbm_hidden_dim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
你在这里并不需要给typing,这个他可以推导出来,如果你真的想给typing,请颗粒度一致,不要一个RBM一个torch.nn.Module。
qmb/rbm.py
Outdated
|
|
||
| Returns | ||
| ------- | ||
| samples, amplitude |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
你这段注释并不符合numpy style。
27651fc to
95ab358
Compare
修复了大部分小问题。
95ab358 to
b52bc91
Compare
Description
添加rbm网络
Checklist: