Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions qmb/haar.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .subcommand_dict import subcommand_dict
from .model_dict import ModelProto
from .optimizer import initialize_optimizer, scale_learning_rate
from .bitspack import pack_int


@dataclasses.dataclass
Expand Down Expand Up @@ -256,6 +257,8 @@ class HaarConfig:

common: typing.Annotated[CommonConfig, tyro.conf.OmitArgPrefixes]

# The initial configurations for the first step
initial_config: typing.Annotated[str, tyro.conf.arg(aliases=["-i"])] = ""
# The sampling count from neural network
sampling_count_from_neural_network: typing.Annotated[int, tyro.conf.arg(aliases=["-n"])] = 1024
# The sampling count from last iteration
Expand Down Expand Up @@ -364,6 +367,17 @@ def main(self, *, model_param: typing.Any = None, network_param: typing.Any = No
pool_configs, pool_psi = data["haar"]["pool"]
data["haar"]["pool"] = (pool_configs.to(device=self.common.device), pool_psi.to(device=self.common.device))

if self.initial_config != "":
if data["haar"]["pool"] is None:
config = pack_int(
torch.tensor([[int(i) for i in single_config] for single_config in self.initial_config.split(",")], dtype=torch.uint8, device=self.common.device),
size=1,
)
data["haar"]["pool"] = (config, network(config))
logging.info("The initial configuration is imported successfully.")
else:
logging.info("The initial configuration is provided, but the pool from the last iteration is not empty, so the initial configuration will be ignored.")

writer = torch.utils.tensorboard.SummaryWriter(log_dir=self.common.folder()) # type: ignore[no-untyped-call]

while True:
Expand Down