Skip to content

Conversation

@windy-pig
Copy link
Collaborator

Description

添加rbm网络

Checklist:

@windy-pig windy-pig force-pushed the dev/add-rbm-network branch from 5f08017 to fbd4a73 Compare June 12, 2025 15:26
@windy-pig windy-pig self-assigned this Jun 13, 2025
@hzhangxyz hzhangxyz requested a review from Copilot June 16, 2025 11:26
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

Adds a Restricted Boltzmann Machine (RBM) implementation for modeling wavefunction magnitudes and a WaveFunctionNormal wrapper that combines the RBM (for magnitude) with an MLP (for phase) and provides sampling utilities.

  • Introduce RBM class with forward probability estimation and sample (Gibbs sampling).
  • Define WaveFunctionNormal to compute complex amplitudes via RBM and MLP, plus generate_conf for sample generation.
  • Utilize pack_int/unpack_int for integer-based configuration encoding.
Comments suppressed due to low confidence (4)

qmb/rbm.py:88

  • The parameter name mpl_hidden_size appears to be a typo; it should match the attribute mlp_hidden_size for consistency.
            mpl_hidden_size: tuple[int, ...],

qmb/rbm.py:48

  • There are currently no unit tests for RBM.sample; consider adding tests to verify that Gibbs sampling converges to the expected distribution over several iterations.
    def sample(self, v: torch.Tensor, k: int = 1) -> torch.Tensor:

qmb/rbm.py:42

  • [nitpick] The variable name mid2 is not descriptive; consider renaming it to something like hidden_pre_activation or hidden_linear_output for clarity.
        mid2 = torch.nn.functional.linear(v, self.weights.T, self.hidden_bias)  # pylint: disable=E1102

qmb/rbm.py:65

  • [nitpick] The intermediate name midh is ambiguous; renaming it to hidden_pre_activation would improve readability.
            midh = torch.nn.functional.linear(v, self.weights.T, self.hidden_bias)  # pylint: disable=E1102

qmb/rbm.py Outdated
self.visible_dim = visible_dim
self.hidden_dim = hidden_dim
self.weights = torch.nn.Parameter(torch.Tensor(self.visible_dim, self.hidden_dim))
init_range = 1.0 / torch.sqrt(torch.tensor(self.visible_dim).cuda())
Copy link

Copilot AI Jun 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid hard-coding .cuda() to keep the code device-agnostic; instead, derive the device from the model parameters (e.g., torch.tensor(self.visible_dim, device=self.weights.device)).

Suggested change
init_range = 1.0 / torch.sqrt(torch.tensor(self.visible_dim).cuda())
init_range = 1.0 / torch.sqrt(torch.tensor(self.visible_dim, device=self.weights.device))

Copilot uses AI. Check for mistakes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

你不应该在这里.cuda(),但这个copilot说的也不对,你应该直接

init_range = 1.0 / torch.sqrt(torch.tensor(self.visible_dim))

因为模型移动到不同的device(如gpu),是在model创建之后进行的。

qmb/rbm.py Outdated
) -> None:
super().__init__()
self.sites: int = sites
assert physical_dim == 2 # ??
Copy link

Copilot AI Jun 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] This bare assert lacks a message and the trailing # ?? is unclear; consider raising a ValueError with a descriptive message or adding an assertion message.

Suggested change
assert physical_dim == 2 # ??
assert physical_dim == 2, "physical_dim must be 2 to ensure compatibility with the RBM implementation."

Copilot uses AI. Check for mistakes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不用管他

super().__init__()
self.sites: int = sites
assert physical_dim == 2 # ??
assert is_complex == True # pylint: disable=singleton-comparison
Copy link

Copilot AI Jun 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefer assert is_complex over comparing to True to simplify the expression and satisfy linters without disabling them.

Suggested change
assert is_complex == True # pylint: disable=singleton-comparison
assert is_complex

Copilot uses AI. Check for mistakes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不用管他

qmb/rbm.py Outdated
self.dummy_param = torch.nn.Parameter(torch.empty(0))

@property
# A grammar sugar to determine device
Copy link

Copilot AI Jun 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Minor typo: change "grammar sugar" to the more standard term "syntactic sugar" for clarity.

Suggested change
# A grammar sugar to determine device
# A syntactic sugar to determine device

Copilot uses AI. Check for mistakes.
qmb/rbm.py Outdated
self.visible_dim = visible_dim
self.hidden_dim = hidden_dim
self.weights = torch.nn.Parameter(torch.Tensor(self.visible_dim, self.hidden_dim))
init_range = 1.0 / torch.sqrt(torch.tensor(self.visible_dim).cuda())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

你不应该在这里.cuda(),但这个copilot说的也不对,你应该直接

init_range = 1.0 / torch.sqrt(torch.tensor(self.visible_dim))

因为模型移动到不同的device(如gpu),是在model创建之后进行的。

qmb/rbm.py Outdated
super().__init__()
self.visible_dim = visible_dim
self.hidden_dim = hidden_dim
self.weights = torch.nn.Parameter(torch.Tensor(self.visible_dim, self.hidden_dim))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不要用torch.Tensor 这个东西, see: https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor

qmb/rbm.py Outdated
Probabilities.
"""
e1 = (v @ self.visible_bias).view(v.size()[:-1])
mid2 = torch.nn.functional.linear(v, self.weights.T, self.hidden_bias) # pylint: disable=E1102
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

给disable规则的具体名称,不要代号。

qmb/rbm.py Outdated
) -> None:
super().__init__()
self.sites: int = sites
assert physical_dim == 2 # ??
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不用管他

super().__init__()
self.sites: int = sites
assert physical_dim == 2 # ??
assert is_complex == True # pylint: disable=singleton-comparison
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不用管他

qmb/rbm.py Outdated
self.mlp_hidden_size: tuple[int, ...] = mpl_hidden_size

# Build Networks
self.probability: RBM = RBM(self.sites, self.rbm_hidden_dim)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

你在这里并不需要给typing,这个他可以推导出来,如果你真的想给typing,请颗粒度一致,不要一个RBM一个torch.nn.Module。

qmb/rbm.py Outdated

Returns
-------
samples, amplitude
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

你这段注释并不符合numpy style。

@windy-pig windy-pig force-pushed the dev/add-rbm-network branch from 27651fc to 95ab358 Compare June 16, 2025 14:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants