Skip to content

Commit

Permalink
Add protocol results test
Browse files Browse the repository at this point in the history
  • Loading branch information
IAlibay committed Sep 21, 2024
1 parent f5dccae commit 264ee59
Show file tree
Hide file tree
Showing 6 changed files with 363 additions and 15 deletions.
Binary file added devtools/ASFEProtocol_json_results.gz
Binary file not shown.
106 changes: 106 additions & 0 deletions devtools/gen_serialized_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""
Dev script to generate some result jsons that are used for testing
Generates
- ASFEProtocol_json_results.gz
"""
import gzip
import json
import logging
import pathlib
import tempfile
from openff.toolkit import (
Molecule, RDKitToolkitWrapper, AmberToolsToolkitWrapper
)
from openff.toolkit.utils.toolkit_registry import (
toolkit_registry_manager, ToolkitRegistry
)
from openff.units import unit
from kartograf.atom_aligner import align_mol_shape
from kartograf import KartografAtomMapper
import gufe
from gufe.tokenization import JSON_HANDLER
import openfe
from pontibus.protocols.solvation import ASFEProtocol
from pontibus.components import ExtendedSolventComponent


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Explicit-hydrogen SMILES of the ligand used to generate the results
# (CH3-CH2-C(=O)-CH2-CH3, i.e. pentan-3-one).
LIGA = "[H]C([H])([H])C([H])([H])C(=O)C([H])([H])C([H])([H])[H]"

# Toolkit registry restricted to the RDKit and AmberTools wrappers;
# get_molecule() activates it while building and charging the ligand.
amber_rdkit = ToolkitRegistry(
[RDKitToolkitWrapper(), AmberToolsToolkitWrapper()]
)


def get_molecule(smi, name):
    """
    Build a named, charged small molecule component from a SMILES string.

    Parameters
    ----------
    smi : str
        SMILES string handed to ``Molecule.from_smiles``.
    name : str
        Name given to the returned component.

    Returns
    -------
    openfe.SmallMoleculeComponent
        Component created from the OpenFF molecule after conformer
        generation and ``am1bcc`` partial charge assignment.
    """
    # Molecule construction and charging run under the RDKit + AmberTools
    # registry defined at module level.
    with toolkit_registry_manager(amber_rdkit):
        offmol = Molecule.from_smiles(smi)
        offmol.generate_conformers()
        offmol.assign_partial_charges(partial_charge_method="am1bcc")

    component = openfe.SmallMoleculeComponent.from_openff(offmol, name=name)
    return component


def execute_and_serialize(dag, protocol, simname):
    """
    Execute a protocol DAG and dump its results to a gzipped JSON file.

    Parameters
    ----------
    dag
        The DAG passed to ``gufe.protocols.execute_DAG``.
    protocol
        Protocol whose ``gather`` method turns the DAG result into a
        protocol result.
    simname : str
        Prefix of the output file, written to the current working directory
        as ``{simname}_json_results.gz``.
    """
    logger.info(f"running {simname}")

    # Shared and scratch data both live in a throwaway directory that is
    # removed as soon as the DAG has finished executing.
    with tempfile.TemporaryDirectory() as tmpdir:
        basedir = pathlib.Path(tmpdir)
        dag_result = gufe.protocols.execute_DAG(
            dag,
            shared_basedir=basedir,
            scratch_basedir=basedir,
            keep_shared=False,
            n_retries=3,
        )

    protocol_result = protocol.gather([dag_result])

    payload = {
        "estimate": protocol_result.get_estimate(),
        "uncertainty": protocol_result.get_uncertainty(),
        "protocol_result": protocol_result.to_dict(),
    }
    # One keyed dictionary per protocol unit result, indexed by its key.
    payload["unit_results"] = {
        unit_result.key: unit_result.to_keyed_dict()
        for unit_result in dag_result.protocol_unit_results
    }

    outfile = f"{simname}_json_results.gz"
    with gzip.open(outfile, 'wt') as handle:
        json.dump(payload, handle, cls=JSON_HANDLER.encoder)


def generate_ahfe_settings():
    """
    Return ``ASFEProtocol`` default settings adjusted for result generation.

    Equilibration lengths are set to 10 ps everywhere; production lengths
    are 10 ps for the pre-equilibration simulations and 500 ps for the
    solvent and vacuum simulations. Three protocol repeats are requested,
    with the vacuum leg on the CPU platform and the solvent leg on CUDA.
    """
    settings = ASFEProtocol.default_settings()

    # The NVT equilibration length is only set for the solvent
    # pre-equilibration settings.
    settings.solvent_equil_simulation_settings.equilibration_length_nvt = (
        10 * unit.picosecond
    )

    # (simulation settings, production length in picoseconds)
    simulation_lengths = (
        (settings.solvent_equil_simulation_settings, 10),
        (settings.solvent_simulation_settings, 500),
        (settings.vacuum_equil_simulation_settings, 10),
        (settings.vacuum_simulation_settings, 500),
    )
    for sim_settings, production_ps in simulation_lengths:
        sim_settings.equilibration_length = 10 * unit.picosecond
        sim_settings.production_length = production_ps * unit.picosecond

    settings.protocol_repeats = 3
    settings.vacuum_engine_settings.compute_platform = 'CPU'
    settings.solvent_engine_settings.compute_platform = 'CUDA'

    return settings


def generate_asfe_json(smc):
    """
    Run the ASFE protocol for one ligand and serialize the results.

    Parameters
    ----------
    smc
        The ligand component placed in the "ligand" slot of state A.

    Notes
    -----
    State A holds the ligand plus an ``ExtendedSolventComponent``; state B
    holds only an ``ExtendedSolventComponent``. Output is written through
    ``execute_and_serialize`` under the name "ASFEProtocol".
    """
    protocol = ASFEProtocol(settings=generate_ahfe_settings())

    solvated_components = {
        "ligand": smc,
        "solvent": ExtendedSolventComponent(),
    }
    solvent_only_components = {"solvent": ExtendedSolventComponent()}

    solvated_state = openfe.ChemicalSystem(solvated_components)
    solvent_only_state = openfe.ChemicalSystem(solvent_only_components)

    dag = protocol.create(
        stateA=solvated_state,
        stateB=solvent_only_state,
        mapping=None,
    )

    execute_and_serialize(dag, protocol, "ASFEProtocol")


if __name__ == "__main__":
    # Build the charged ligand and generate the serialized ASFE results.
    generate_asfe_json(get_molecule(LIGA, "ligandA"))
Binary file not shown.
149 changes: 149 additions & 0 deletions src/pontibus/tests/protocols/solvation/test_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import gzip
import itertools
import pytest
import json
from importlib import resources

import numpy as np
from openff.units import unit as offunit
import gufe
import openfe
from pontibus.protocols.solvation import ASFEProtocolResult


@pytest.fixture
def afe_solv_transformation_json() -> str:
    """
    Serialized ASFE results, returned as a JSON string.

    The archive is generated with devtools/gen_serialized_results.py and is
    read from the ``pontibus.tests.data.solvation_protocol`` package data.

    Returns
    -------
    str
        The decompressed, decoded contents of
        ``ASFEProtocol_json_results.gz``.
    """
    d = resources.files("pontibus.tests.data.solvation_protocol")
    fname = "ASFEProtocol_json_results.gz"

    # Open the resource through the Traversable API instead of converting
    # it to a filesystem path: ``as_posix`` is a pathlib method and is not
    # guaranteed to exist on what ``resources.files`` returns.
    with (d / fname).open('rb') as raw, gzip.open(raw, 'rb') as f:
        return f.read().decode()


class TestProtocolResult:
    """Checks on the protocol result reloaded from the serialized ASFE JSON."""

    @pytest.fixture()
    def protocolresult(self, afe_solv_transformation_json):
        """Protocol result rebuilt through ``openfe.ProtocolResult``."""
        decoder = gufe.tokenization.JSON_HANDLER.decoder
        loaded = json.loads(afe_solv_transformation_json, cls=decoder)

        return openfe.ProtocolResult.from_dict(loaded['protocol_result'])

    def test_reload_protocol_result(self, afe_solv_transformation_json):
        """The stored dictionary can be reloaded via ASFEProtocolResult."""
        decoder = gufe.tokenization.JSON_HANDLER.decoder
        loaded = json.loads(afe_solv_transformation_json, cls=decoder)

        reloaded = ASFEProtocolResult.from_dict(loaded['protocol_result'])

        assert reloaded

    def test_get_estimate(self, protocolresult):
        """The estimate is an energy quantity close to the stored value."""
        estimate = protocolresult.get_estimate()

        assert estimate
        assert estimate.m == pytest.approx(-2.47, abs=0.5)
        assert isinstance(estimate, offunit.Quantity)
        assert estimate.is_compatible_with(offunit.kilojoule_per_mole)

    def test_get_uncertainty(self, protocolresult):
        """The uncertainty is an energy quantity within the expected range."""
        uncertainty = protocolresult.get_uncertainty()

        assert uncertainty
        assert uncertainty.m == pytest.approx(0.2, abs=0.2)
        assert isinstance(uncertainty, offunit.Quantity)
        assert uncertainty.is_compatible_with(offunit.kilojoule_per_mole)

    def test_get_individual(self, protocolresult):
        """Each leg holds three (estimate, uncertainty) energy pairs."""
        individual = protocolresult.get_individual_estimates()

        assert isinstance(individual, dict)
        assert isinstance(individual['solvent'], list)
        assert isinstance(individual['vacuum'], list)
        assert len(individual['solvent']) == len(individual['vacuum']) == 3
        for leg in ('solvent', 'vacuum'):
            for estimate, uncertainty in individual[leg]:
                assert estimate.is_compatible_with(
                    offunit.kilojoule_per_mole
                )
                assert uncertainty.is_compatible_with(
                    offunit.kilojoule_per_mole
                )

    @pytest.mark.parametrize('key', ['solvent', 'vacuum'])
    def test_get_forwards_etc(self, key, protocolresult):
        """The forward/reverse analysis of a leg has the expected entries."""
        analysis = protocolresult.get_forward_and_reverse_energy_analysis()

        assert isinstance(analysis, dict)
        assert isinstance(analysis[key], list)
        first_repeat = analysis[key][0]
        assert isinstance(first_repeat, dict)

        # 'fractions' is additionally required to be a numpy array
        assert 'fractions' in first_repeat
        assert isinstance(first_repeat['fractions'], np.ndarray)

        remaining = ('forward_DGs', 'forward_dDGs',
                     'reverse_DGs', 'reverse_dDGs')
        for entry in remaining:
            assert entry in first_repeat

    @pytest.mark.parametrize('key', ['solvent', 'vacuum'])
    def test_get_frwd_reverse_none_return(self, key, protocolresult):
        """A missing forward/reverse analysis triggers a UserWarning."""
        # grab the first unit result stored for this leg
        first_unit = list(protocolresult.data[key].values())[0][0]
        # blank out its forward and reverse analysis
        first_unit.outputs['forward_and_reverse_energies'] = None

        # fetching the analysis must now warn
        wmsg = ("were found in the forward and reverse dictionaries "
                f"of the repeats of the {key}")
        with pytest.warns(UserWarning, match=wmsg):
            protocolresult.get_forward_and_reverse_energy_analysis()

    @pytest.mark.parametrize('key', ['solvent', 'vacuum'])
    def test_get_overlap_matrices(self, key, protocolresult):
        """Each leg has three overlap entries with 14x14 matrices."""
        overlaps = protocolresult.get_overlap_matrices()

        assert isinstance(overlaps, dict)
        assert isinstance(overlaps[key], list)
        assert len(overlaps[key]) == 3

        first_overlap = overlaps[key][0]
        assert isinstance(first_overlap['matrix'], np.ndarray)
        assert first_overlap['matrix'].shape == (14, 14)

    @pytest.mark.parametrize('key', ['solvent', 'vacuum'])
    def test_get_replica_transition_statistics(self, key, protocolresult):
        """Each leg has three sets of replica transition statistics."""
        statistics = protocolresult.get_replica_transition_statistics()

        assert isinstance(statistics, dict)
        assert isinstance(statistics[key], list)
        assert len(statistics[key]) == 3
        first_stats = statistics[key][0]
        assert 'eigenvalues' in first_stats
        assert 'matrix' in first_stats
        assert first_stats['eigenvalues'].shape == (14,)
        assert first_stats['matrix'].shape == (14, 14)

    @pytest.mark.parametrize('key', ['solvent', 'vacuum'])
    def test_equilibration_iterations(self, key, protocolresult):
        """Each leg reports three float equilibration iteration counts."""
        equilibration = protocolresult.equilibration_iterations()

        assert isinstance(equilibration, dict)
        assert isinstance(equilibration[key], list)
        assert len(equilibration[key]) == 3
        for value in equilibration[key]:
            assert isinstance(value, float)

    @pytest.mark.parametrize('key', ['solvent', 'vacuum'])
    def test_production_iterations(self, key, protocolresult):
        """Each leg reports three float production iteration counts."""
        production = protocolresult.production_iterations()

        assert isinstance(production, dict)
        assert isinstance(production[key], list)
        assert len(production[key]) == 3
        for value in production[key]:
            assert isinstance(value, float)

    def test_filenotfound_replica_states(self, protocolresult):
        """Fetching replica states raises ValueError for missing files."""
        errmsg = "File could not be found"

        with pytest.raises(ValueError, match=errmsg):
            protocolresult.get_replica_states()
71 changes: 64 additions & 7 deletions src/pontibus/tests/utils/test_interchange_packmol.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
_set_offmol_resname,
_get_offmol_resname,
_check_library_charges,
_check_charged_mols,
)
from pontibus.utils.molecules import WATER
from numpy.testing import assert_allclose, assert_equal
Expand Down Expand Up @@ -67,6 +68,13 @@ def water_off_named_charged():
return water


@pytest.fixture(scope="module")
def water_off_am1bcc():
    """Water converted with ``to_openff`` and charged with ``am1bcc``."""
    charged_water = WATER.to_openff()
    charged_water.assign_partial_charges(partial_charge_method="am1bcc")
    return charged_water


def test_get_and_set_offmol_resname(CN_molecule, caplog):
CN_off = CN_molecule.to_openff()

Expand Down Expand Up @@ -99,6 +107,20 @@ def test_check_library_charges_fail(methanol):
_check_library_charges(ff, methanol)


def test_check_charged_mols_pass(methanol):
    """A list holding only the methanol fixture passes without raising."""
    molecules = [methanol]
    _check_charged_mols(molecules)


def test_check_charged_mols_nocharge(water_off, methanol):
    """Passing ``water_off`` with methanol raises a "One or more" error."""
    errmsg = "One or more"
    molecules = [water_off, methanol]
    with pytest.raises(ValueError, match=errmsg):
        _check_charged_mols(molecules)


def test_check_charged_mols(water_off_am1bcc, water_off_named_charged):
    """The two charged water fixtures raise a "different charges" error."""
    errmsg = "different charges"
    molecules = [water_off_am1bcc, water_off_named_charged]
    with pytest.raises(ValueError, match=errmsg):
        _check_charged_mols(molecules)


def test_protein_component_fail(smc_components_benzene_named, T4_protein_component):
errmsg = "ProteinComponents is not currently supported"
with pytest.raises(ValueError, match=errmsg):
Expand Down Expand Up @@ -168,13 +190,11 @@ def test_solv_mismatch(
@pytest.mark.parametrize(
"assign_charges, errmsg",
[
(True, "PackmolSolvationSettings.assign_solvent_charges"),
(True, "do not have partial charges"),
(False, "No library charges"),
],
)
def test_noncharge_nolibrarycharges(
smc_components_benzene_named, assign_charges, errmsg
):
def test_charge_assignment_errors(smc_components_benzene_named, assign_charges, errmsg):
"""
True case: passing a Molecule without partial charges to Interchange
and asking to get charges from it will fail.
Expand Down Expand Up @@ -204,6 +224,44 @@ def test_noncharge_nolibrarycharges(
)


def test_assign_duplicate_resnames(caplog):
    """
    Two components share the residue name 'FOO': expect a logged warning
    about the duplicate and the rename, and both components in the output.
    """
    # Build both molecules the same way: conformers, gasteiger charges and
    # the shared residue name.
    offmols = []
    for smiles in ('C', 'CCC'):
        offmol = Molecule.from_smiles(smiles)
        offmol.generate_conformers()
        offmol.assign_partial_charges(partial_charge_method='gasteiger')
        _set_offmol_resname(offmol, 'FOO')
        offmols.append(offmol)

    a, b = offmols
    smc_a = SmallMoleculeComponent.from_openff(a)
    smc_b = SmallMoleculeComponent.from_openff(b)

    smcs = {smc_a: a, smc_b: b}

    with caplog.at_level(logging.WARNING):
        ffsettings = InterchangeFFSettings(
            forcefields=[
                "openff-2.0.0.offxml",
            ]
        )
        _, smc_comps = interchange_packmol_creation(
            ffsettings=ffsettings,
            solvation_settings=None,
            smc_components=smcs,
            protein_component=None,
            solvent_component=None,
            solvent_offmol=None,
        )

    for expected_text in ("Duplicate", "residue name to AAA"):
        assert expected_text in caplog.text

    assert len(smc_comps) == 2
    assert smc_comps[smc_a][0] == 0
    assert smc_comps[smc_b][0] == 1


@pytest.mark.parametrize(
"smiles",
[
Expand Down Expand Up @@ -250,7 +308,7 @@ def test_noncharge_nolibrarycharges(
"c1ccncc1",
],
)
def test_solvent_packing(smc_components_benzene_named, smiles):
def test_nonwater_solvent(smc_components_benzene_named, smiles):
solvent_offmol = Molecule.from_smiles(smiles)
solvent_offmol.assign_partial_charges(partial_charge_method="gasteiger")

Expand Down Expand Up @@ -640,9 +698,7 @@ def test_virtual_sites(self, omm_system, num_waters, num_particles, nonbonds):


"""
4. Named solvent
5. Unamed solvent
- Check we get the new residue names
- Check we get warned about renaming
6. Named solvent with inconsistent name
7. Duplicate named smcs
Expand All @@ -653,4 +709,5 @@ def test_virtual_sites(self, omm_system, num_waters, num_particles, nonbonds):
- with a solvent w/ virtual sites
- check omm topology indices match virtual sites (it doesn't!)
14. Check nonbonded cutoffs set via ffsettings
15. Check charged mols tests.
"""
Loading

0 comments on commit 264ee59

Please sign in to comment.