diff --git a/qmb/haar.py b/qmb/haar.py index a7211da..81170b7 100644 --- a/qmb/haar.py +++ b/qmb/haar.py @@ -16,6 +16,7 @@ from .subcommand_dict import subcommand_dict from .model_dict import ModelProto from .optimizer import initialize_optimizer, scale_learning_rate +from .bitspack import pack_int @dataclasses.dataclass @@ -256,6 +257,8 @@ class HaarConfig: common: typing.Annotated[CommonConfig, tyro.conf.OmitArgPrefixes] + # The initial configurations for the first step + initial_config: typing.Annotated[str, tyro.conf.arg(aliases=["-i"])] = "" # The sampling count from neural network sampling_count_from_neural_network: typing.Annotated[int, tyro.conf.arg(aliases=["-n"])] = 1024 # The sampling count from last iteration @@ -364,6 +367,17 @@ def main(self, *, model_param: typing.Any = None, network_param: typing.Any = No pool_configs, pool_psi = data["haar"]["pool"] data["haar"]["pool"] = (pool_configs.to(device=self.common.device), pool_psi.to(device=self.common.device)) + if self.initial_config != "": + if data["haar"]["pool"] is None: + config = pack_int( + torch.tensor([[int(i) for i in single_config] for single_config in self.initial_config.split(",")], dtype=torch.uint8, device=self.common.device), + size=1, + ) + data["haar"]["pool"] = (config, network(config)) + logging.info("The initial configuration is imported successfully.") + else: + logging.info("The initial configuration is provided, but the pool from the last iteration is not empty, so the initial configuration will be ignored.") + writer = torch.utils.tensorboard.SummaryWriter(log_dir=self.common.folder()) # type: ignore[no-untyped-call] while True: