1 change: 1 addition & 0 deletions qmb/__main__.py
@@ -11,6 +11,7 @@
from . import free_fermion as _ # type: ignore[no-redef]
from . import ising as _ # type: ignore[no-redef]
from . import vmc as _ # type: ignore[no-redef]
from . import markov as _ # type: ignore[no-redef]
from . import haar as _ # type: ignore[no-redef]
from . import rldiag as _ # type: ignore[no-redef]
from . import precompile as _ # type: ignore[no-redef]
39 changes: 39 additions & 0 deletions qmb/ising.py
@@ -10,6 +10,7 @@
import tyro
from .mlp import WaveFunctionNormal as MlpWaveFunction
from .attention import WaveFunctionNormal as AttentionWaveFunction
from .peps import PepsFunction
from .hamiltonian import Hamiltonian
from .model_dict import model_dict, ModelProto, NetworkProto, NetworkConfigProto

@@ -325,3 +326,41 @@ def create(self, model: Model) -> NetworkProto:


Model.network_dict["attention"] = AttentionConfig


@dataclasses.dataclass
class PepsConfig:
"""
The configuration of the PEPS network.
"""

# The bond dimension of the network
D: typing.Annotated[int, tyro.conf.arg(aliases=["-d"])] = 4 # pylint: disable=invalid-name
# The cut-off bond dimension of the network
Dc: typing.Annotated[int, tyro.conf.arg(aliases=["-c"])] = 16 # pylint: disable=invalid-name

def create(self, model: Model) -> NetworkProto:
"""
Create a PEPS network for the model.
"""
logging.info(
"PEPS network configuration: "
"bond dimension: %d, "
"cut-off bond dimension: %d",
self.D,
self.Dc,
)

network = PepsFunction(
L1=model.m,
L2=model.n,
d=2,
D=self.D,
Dc=self.Dc,
use_complex=True,
)

return network


Model.network_dict["peps"] = PepsConfig
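
An illustrative sketch (not part of this diff) of how the new "peps" entry can be exercised: PepsConfig.create only reads the lattice dimensions model.m and model.n, so a hypothetical stand-in object with those two attributes is enough to construct the network; the qmb package from this branch is assumed to be importable.

import types
from qmb.ising import PepsConfig  # assumes the qmb package from this PR is on the path

stand_in_model = types.SimpleNamespace(m=4, n=4)  # hypothetical 4x4 lattice, not a real qmb Model
config = PepsConfig(D=4, Dc=16)                   # bond dimension 4, cut-off 16
network = config.create(stand_in_model)           # builds PepsFunction(L1=4, L2=4, d=2, D=4, Dc=16)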
186 changes: 186 additions & 0 deletions qmb/markov.py
@@ -0,0 +1,186 @@
"""
This file implements a VMC method based on Markov-chain sampling for solving quantum many-body problems.
"""

import logging
import typing
import dataclasses
import torch
import torch.utils.tensorboard
import tyro
from .common import CommonConfig
from .subcommand_dict import subcommand_dict
from .optimizer import initialize_optimizer
from .bitspack import pack_int, unpack_int


@dataclasses.dataclass
class MarkovConfig:
"""
The VMC optimization based on Markov-chain sampling for solving quantum many-body problems.
"""

# pylint: disable=too-many-instance-attributes

common: typing.Annotated[CommonConfig, tyro.conf.OmitArgPrefixes]

# The sampling count
sampling_count: typing.Annotated[int, tyro.conf.arg(aliases=["-n"])] = 4000
# The number of relative configurations to be used in energy calculation
relative_count: typing.Annotated[int, tyro.conf.arg(aliases=["-c"])] = 40000
# Whether to use the global optimizer
global_opt: typing.Annotated[bool, tyro.conf.arg(aliases=["-g"])] = False
# Whether to use LBFGS instead of Adam
use_lbfgs: typing.Annotated[bool, tyro.conf.arg(aliases=["-2"])] = False
# The learning rate for the local optimizer
learning_rate: typing.Annotated[float, tyro.conf.arg(aliases=["-r"], help_behavior_hint="(default: 1e-3 for Adam, 1 for LBFGS)")] = -1
# The number of steps for the local optimizer
local_step: typing.Annotated[int, tyro.conf.arg(aliases=["-s"])] = 1000
# The initial configurations for the first step
initial_config: typing.Annotated[str, tyro.conf.arg(aliases=["-i"])] = ""

def __post_init__(self) -> None:
if self.learning_rate == -1:
self.learning_rate = 1 if self.use_lbfgs else 1e-3

def main(self) -> None:
"""
The main function for the VMC optimization based on the Markov chain.
"""
# pylint: disable=too-many-statements
# pylint: disable=too-many-locals

model, network, data = self.common.main()

logging.info(
"Arguments Summary: "
"Sampling Count: %d, "
"Relative Count: %d, "
"Global Optimizer: %s, "
"Use LBFGS: %s, "
"Learning Rate: %.10f, "
"Local Steps: %d, ",
self.sampling_count,
self.relative_count,
"Yes" if self.global_opt else "No",
"Yes" if self.use_lbfgs else "No",
self.learning_rate,
self.local_step,
)

optimizer = initialize_optimizer(
network.parameters(),
use_lbfgs=self.use_lbfgs,
learning_rate=self.learning_rate,
state_dict=data.get("optimizer"),
)

if "markov" not in data:
data["markov"] = {"global": 0, "local": 0, "pool": None}

# TODO: how should the size be determined?
configs = pack_int(
torch.tensor([[int(i) for i in self.initial_config]], dtype=torch.uint8, device=self.common.device),
size=1,
)
if data["markov"]["pool"] is None:
data["markov"]["pool"] = configs
logging.info("The initial configuration is imported successfully.")
else:
logging.info("The initial configuration is provided, but the pool from the last iteration is not empty, so the initial configuration will be ignored.")

writer = torch.utils.tensorboard.SummaryWriter(log_dir=self.common.folder()) # type: ignore[no-untyped-call]

while True:
logging.info("Starting a new optimization cycle")

logging.info("Checking the configuration pool")
config = data["markov"]["pool"]
old_config = config.repeat(self.sampling_count // config.shape[0] + 1, 1)[:self.sampling_count]

logging.info("Hopping configurations")
def hop(config):
# TODO: use hamiltonian
x = unpack_int(config, size=1, last_dim=model.m * model.n)
x = x.view(-1, model.m, model.n)
batch, L1, L2 = x.shape
swap_dim = torch.randint(0, 2, (batch,))

out = x.clone()
for b in range(batch):
if swap_dim[b] == 0: # swap along the L1 direction
i = torch.randint(0, L1 - 1, (1,)).item()
j = torch.randint(0, L2, (1,)).item()
out[b, i, j], out[b, i+1, j] = x[b, i+1, j], x[b, i, j]
else: # swap along the L2 direction
i = torch.randint(0, L1, (1,)).item()
j = torch.randint(0, L2 - 1, (1,)).item()
out[b, i, j], out[b, i, j+1] = x[b, i, j+1], x[b, i, j]
x = out.view(-1, model.m * model.n)
return pack_int(x, size=1)
new_config = hop(old_config)
old_weight = network(old_config)
new_weight = network(new_config)
accept_prob = (new_weight / old_weight).abs().clamp(max=1)**2
accept = torch.rand_like(accept_prob) < accept_prob
configs = torch.where(accept.unsqueeze(-1), new_config, old_config)
configs_i = torch.unique(configs, dim=0)
psi_i = network(configs_i)
data["markov"]["pool"] = configs_i
logging.info("Sampling completed, configurations count: %d", len(configs_i))

logging.info("Calculating relative configurations")
if self.relative_count <= len(configs_i):
configs_src = configs_i
configs_dst = configs_i
else:
configs_src = configs_i
configs_dst = torch.cat([configs_i, model.find_relative(configs_i, psi_i, self.relative_count - len(configs_i))])
logging.info("Relative configurations calculated, count: %d", len(configs_dst))

optimizer = initialize_optimizer(
network.parameters(),
use_lbfgs=self.use_lbfgs,
learning_rate=self.learning_rate,
new_opt=not self.global_opt,
optimizer=optimizer,
)

def closure() -> torch.Tensor:
# Optimizing energy
optimizer.zero_grad()
psi_src = network(configs_src)
with torch.no_grad():
psi_dst = network(configs_dst)
hamiltonian_psi_dst = model.apply_within(configs_dst, psi_dst, configs_src)
num = psi_src.conj() @ hamiltonian_psi_dst
den = psi_src.conj() @ psi_src.detach()
energy = num / den
energy = energy.real
energy.backward() # type: ignore[no-untyped-call]
return energy

logging.info("Starting local optimization process")

for i in range(self.local_step):
energy: torch.Tensor = optimizer.step(closure) # type: ignore[assignment,arg-type]
logging.info("Local optimization in progress, step: %d, energy: %.10f, ref energy: %.10f", i, energy.item(), model.ref_energy)
writer.add_scalar("markov/energy", energy, data["markov"]["local"]) # type: ignore[no-untyped-call]
writer.add_scalar("markov/error", energy - model.ref_energy, data["markov"]["local"]) # type: ignore[no-untyped-call]
data["markov"]["local"] += 1

logging.info("Local optimization process completed")

writer.flush() # type: ignore[no-untyped-call]

logging.info("Saving model checkpoint")
data["markov"]["global"] += 1
data["network"] = network.state_dict()
data["optimizer"] = optimizer.state_dict()
self.common.save(data, data["markov"]["global"])
logging.info("Checkpoint successfully saved")

logging.info("Current optimization cycle completed")


subcommand_dict["markov"] = MarkovConfig
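
For reference, the acceptance step in the sampling loop above is the standard Metropolis rule p = min(1, |psi_new / psi_old|^2). A standalone sketch with placeholder amplitudes (not part of this diff; old_amp and new_amp stand in for the network outputs):

import torch

old_amp = torch.randn(8, dtype=torch.complex128)  # psi of the current configurations
new_amp = torch.randn(8, dtype=torch.complex128)  # psi of the hopped configurations

# min(1, |psi_new/psi_old|^2); clamping the ratio at 1 before squaring is equivalent
accept_prob = (new_amp / old_amp).abs().clamp(max=1) ** 2
accept = torch.rand_like(accept_prob) < accept_prob  # Bernoulli acceptance mask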
126 changes: 126 additions & 0 deletions qmb/peps.py
@@ -0,0 +1,126 @@
"""
This file implements the PEPS tensor network.
"""

import torch
from .bitspack import unpack_int


class PEPS(torch.nn.Module):
"""
The PEPS tensor network.
"""

# pylint: disable=invalid-name

def __init__(self, L1: int, L2: int, d: int, D: int, Dc: int, use_complex: bool = False) -> None: # pylint: disable=too-many-arguments, too-many-positional-arguments
super().__init__()
self.L1: int = L1
self.L2: int = L2
self.d: int = d
self.D: int = D
self.Dc: int = Dc
self.use_complex: bool = use_complex

self.tensors = torch.nn.Parameter(torch.randn([L1, L2, d, D, D, D, D], dtype=torch.complex128 if use_complex else torch.float64))

def _tensor(self, l1: int, l2: int, config: torch.Tensor) -> torch.Tensor:
"""
Get the tensor for a specific lattice site (l1, l2) and configuration.
"""
# pylint: disable=unsubscriptable-object
# Order: L, U, D, R
tensor: torch.Tensor = self.tensors[l1, l2, config.to(torch.int64)]
if l2 == 0:
tensor = tensor[:, :1, :, :, :]
if l1 == 0:
tensor = tensor[:, :, :1, :, :]
if l1 == self.L1 - 1:
tensor = tensor[:, :, :, :1, :]
if l2 == self.L2 - 1:
tensor = tensor[:, :, :, :, :1]
return tensor

def _bmps(self, line1: list[torch.Tensor], line2: list[torch.Tensor]) -> list[torch.Tensor]:
# pylint: disable=too-many-locals
# tensor in double: blLudrR
double = [torch.einsum("blumr,bLmdR->blLudrR", tensor1, tensor2) for tensor1, tensor2 in zip(line1, line2)]
# Merge the two left indices for the first tensor
# tensor shape should be: bludrR
double[0] = double[0].flatten(1, 2)
for l2 in range(self.L2 - 1):
# tensor shape: bludrR
# b for batch
# lud for q tensor
# rR for r tensor
tensor = double[l2]
b, l, u, d, r, R = tensor.shape
tensor = tensor.reshape([b, l * u * d, r * R])
q_tensor, r_tensor = torch.linalg.qr(tensor, mode="reduced") # pylint: disable=not-callable
double[l2] = q_tensor.reshape([b, l, u, d, -1])
remain = r_tensor.reshape([b, -1, r, R])
double[l2 + 1] = torch.einsum("blmM,bmMudrR->bludrR", remain, double[l2 + 1])
# Merge the two right indices for the last tensor
double[-1] = double[-1].flatten(4, 5)
# tensor shape is: bludr
for l2 in range(self.L2 - 1, 0, -1):
# tensor shape: bludr
# b for batch
# l for u tensor
# udr for v tensor
tensor = double[l2]
b, l, u, d, r = tensor.shape
tensor = tensor.reshape([b, l, u * d * r])
u_tensor, s_tensor, v_tensor = torch.linalg.svd(tensor, full_matrices=False) # pylint: disable=not-callable
middle_size = s_tensor.shape[-1]
if middle_size > self.Dc:
u_tensor = u_tensor[:, :, :self.Dc]
s_tensor = s_tensor[:, :self.Dc]
v_tensor = v_tensor[:, :self.Dc, :]
double[l2] = v_tensor.reshape([b, -1, u, d, r])
double[l2 - 1] = torch.einsum("bludm,bmr,br->bludr", double[l2 - 1], u_tensor, s_tensor)
# tensor shape is still: b l u d r
return double

def _contract(self, tensors: list[list[torch.Tensor]]) -> torch.Tensor:
candidates = tensors[0]
for l1 in range(1, self.L1):
candidates = self._bmps(candidates, tensors[l1])
result = candidates[0]
for l2 in range(1, self.L2):
result = result * candidates[l2]
return result[:, 0, 0, 0, 0]

def forward(self, configs: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the PEPS tensor network.
"""
tensors: list[list[torch.Tensor]] = [[self._tensor(l1, l2, configs[:, (l1 * self.L2) + l2]) for l2 in range(self.L2)] for l1 in range(self.L1)]
return self._contract(tensors)


class PepsFunction(torch.nn.Module):
"""
The PEPS tensor network used by qmb interface.
"""

def __init__(self, L1: int, L2: int, d: int, D: int, Dc: int, use_complex: bool = False) -> None: # pylint: disable=too-many-arguments, too-many-positional-arguments
super().__init__()
assert d == 2
self.sites = L1 * L2
self.model = PEPS(L1, L2, d, D, Dc, use_complex)

@torch.jit.export
def generate_unique(self, batch_size: int, block_num: int = 1) -> tuple[torch.Tensor, torch.Tensor, None, None]:
"""
Generate a batch of unique configurations.
"""
raise NotImplementedError("The generate_unique method is not implemented for this class.")

@torch.jit.export
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the PEPS tensor network.
"""
x = unpack_int(x, size=1, last_dim=self.sites)
return self.model(x)
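
The row compression in PEPS._bmps is a batched reduced QR sweep followed by an SVD truncated to at most Dc singular values. A self-contained sketch of that truncation primitive on a batch of matrices (not part of this diff; the shapes and the Dc value are illustrative assumptions):

import torch

Dc = 16  # illustrative cut-off bond dimension
matrices = torch.randn(8, 64, 48, dtype=torch.complex128)  # a batch of matrices to compress

q, r = torch.linalg.qr(matrices, mode="reduced")    # q: (8, 64, 48), r: (8, 48, 48)
u, s, v = torch.linalg.svd(r, full_matrices=False)  # batched SVD of the remainder
if s.shape[-1] > Dc:
    u, s, v = u[:, :, :Dc], s[:, :Dc], v[:, :Dc, :]  # keep the Dc largest singular values
low_rank = (u * s.to(u.dtype).unsqueeze(1)) @ v      # rank-Dc approximation of r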