Skip to content

Commit

Permalink
Update test so that it compares the results of running the modelforge…
Browse files Browse the repository at this point in the history
… implementation on a batched methane with the reference implementation on a single methane
  • Loading branch information
ArnNag committed Aug 21, 2024
1 parent 5f9666d commit 57e7b49
Showing 1 changed file with 28 additions and 24 deletions.
52 changes: 28 additions & 24 deletions modelforge/tests/test_spookynet.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,25 +59,29 @@ def test_forward():
).double()

single_methane = setup_single_methane_input()["modelforge_methane_input"]
model_input = NNPInput(
double_methane = NNPInput(
atomic_numbers=torch.cat([single_methane.atomic_numbers] * 2, dim=0),
positions=torch.cat([single_methane.positions] * 2, dim=0),
atomic_subsystem_indices=torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]),
total_charge=torch.cat([single_methane.total_charge] * 2, dim=0),
)
print(f"{torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], dtype=torch.int64).dtype=}")
print(f"test: {model_input.atomic_subsystem_indices.dtype=}")
ic(model_input)
model_input.positions = model_input.positions.double()
model_input.total_charge = model_input.total_charge.double()

print(f"test: {model_input.atomic_subsystem_indices.dtype=}")
spookynet.input_preparation._input_checks(model_input)

print(f"test: {model_input.atomic_subsystem_indices.dtype=}")
pairlist_output = spookynet.input_preparation.prepare_inputs(model_input)
print(f"test: {model_input.atomic_subsystem_indices.dtype=}")
calculated_results = spookynet.core_module.forward(model_input, pairlist_output)
print(f"test: {double_methane.atomic_subsystem_indices.dtype=}")
ic(double_methane)
single_methane.positions = single_methane.positions.double()
single_methane.total_charge = single_methane.total_charge.double()
double_methane.positions = double_methane.positions.double()
double_methane.total_charge = double_methane.total_charge.double()

print(f"test: {double_methane.atomic_subsystem_indices.dtype=}")
spookynet.input_preparation._input_checks(single_methane)
single_pairlist_output = spookynet.input_preparation.prepare_inputs(single_methane)

spookynet.input_preparation._input_checks(double_methane)
print(f"test: {double_methane.atomic_subsystem_indices.dtype=}")
double_pairlist_output = spookynet.input_preparation.prepare_inputs(double_methane)
print(f"test: {double_methane.atomic_subsystem_indices.dtype=}")
calculated_results = spookynet.core_module.forward(double_methane, double_pairlist_output)

ref_spookynet = RefSpookyNet(
num_features=config["potential"]["core_parameter"]["number_of_atom_features"],
Expand Down Expand Up @@ -395,19 +399,19 @@ def test_forward():
ref_spookynet.train()

# TODO: how are multiple systems passed into the reference SpookyNet
print(f"test: {model_input.atomic_subsystem_indices.dtype=}")
reference_calculated_results = ref_spookynet(
model_input.atomic_numbers,
model_input.total_charge,
(model_input.positions * unit.nanometer).m_as(unit.angstrom),
pairlist_output.pair_indices[0],
pairlist_output.pair_indices[1],
batch_seg=model_input.atomic_subsystem_indices.long(),
num_batch=2,
print(f"test: {double_methane.atomic_subsystem_indices.dtype=}")
energy, forces, dipole, f, ea, qa, ea_rep, ea_ele, ea_vdw, pa, c6 = ref_spookynet(
single_methane.atomic_numbers,
single_methane.total_charge,
(single_methane.positions * unit.nanometer).m_as(unit.angstrom),
single_pairlist_output.pair_indices[0],
single_pairlist_output.pair_indices[1],
batch_seg=None,
num_batch=1,
)

ic(calculated_results.keys())
ic(reference_calculated_results)
ic(calculated_results["per_atom_energy"])
ic(ea)


def test_spookynet_forward(single_batch_with_batchsize_64, model_parameter):
Expand Down

0 comments on commit 57e7b49

Please sign in to comment.