diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index c905dfc6..2de93833 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -59,25 +59,29 @@ def test_forward(): ).double() single_methane = setup_single_methane_input()["modelforge_methane_input"] - model_input = NNPInput( + double_methane = NNPInput( atomic_numbers=torch.cat([single_methane.atomic_numbers] * 2, dim=0), positions=torch.cat([single_methane.positions] * 2, dim=0), atomic_subsystem_indices=torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]), total_charge=torch.cat([single_methane.total_charge] * 2, dim=0), ) print(f"{torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], dtype=torch.int64).dtype=}") - print(f"test: {model_input.atomic_subsystem_indices.dtype=}") - ic(model_input) - model_input.positions = model_input.positions.double() - model_input.total_charge = model_input.total_charge.double() - - print(f"test: {model_input.atomic_subsystem_indices.dtype=}") - spookynet.input_preparation._input_checks(model_input) - - print(f"test: {model_input.atomic_subsystem_indices.dtype=}") - pairlist_output = spookynet.input_preparation.prepare_inputs(model_input) - print(f"test: {model_input.atomic_subsystem_indices.dtype=}") - calculated_results = spookynet.core_module.forward(model_input, pairlist_output) + print(f"test: {double_methane.atomic_subsystem_indices.dtype=}") + ic(double_methane) + single_methane.positions = single_methane.positions.double() + single_methane.total_charge = single_methane.total_charge.double() + double_methane.positions = double_methane.positions.double() + double_methane.total_charge = double_methane.total_charge.double() + + print(f"test: {double_methane.atomic_subsystem_indices.dtype=}") + spookynet.input_preparation._input_checks(single_methane) + single_pairlist_output = spookynet.input_preparation.prepare_inputs(single_methane) + + spookynet.input_preparation._input_checks(double_methane) + print(f"test: {double_methane.atomic_subsystem_indices.dtype=}") + double_pairlist_output = spookynet.input_preparation.prepare_inputs(double_methane) + print(f"test: {double_methane.atomic_subsystem_indices.dtype=}") + calculated_results = spookynet.core_module.forward(double_methane, double_pairlist_output) ref_spookynet = RefSpookyNet( num_features=config["potential"]["core_parameter"]["number_of_atom_features"], @@ -395,19 +399,19 @@ def test_forward(): ref_spookynet.train() # TODO: how are multiple systems passed into the reference SpookyNet - print(f"test: {model_input.atomic_subsystem_indices.dtype=}") - reference_calculated_results = ref_spookynet( - model_input.atomic_numbers, - model_input.total_charge, - (model_input.positions * unit.nanometer).m_as(unit.angstrom), - pairlist_output.pair_indices[0], - pairlist_output.pair_indices[1], - batch_seg=model_input.atomic_subsystem_indices.long(), - num_batch=2, + print(f"test: {double_methane.atomic_subsystem_indices.dtype=}") + energy, forces, dipole, f, ea, qa, ea_rep, ea_ele, ea_vdw, pa, c6 = ref_spookynet( + single_methane.atomic_numbers, + single_methane.total_charge, + (single_methane.positions * unit.nanometer).m_as(unit.angstrom), + single_pairlist_output.pair_indices[0], + single_pairlist_output.pair_indices[1], + batch_seg=None, + num_batch=1, ) - ic(calculated_results.keys()) - ic(reference_calculated_results) + ic(calculated_results["per_atom_energy"]) + ic(ea) def test_spookynet_forward(single_batch_with_batchsize_64, model_parameter):