Skip to content

Commit

Permalink
Add serialization tests
Browse files Browse the repository at this point in the history
  • Loading branch information
IAlibay committed Sep 21, 2024
1 parent 264ee59 commit add6464
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 44 deletions.
17 changes: 17 additions & 0 deletions src/pontibus/tests/protocols/solvation/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
from importlib import resources
import gzip


@pytest.fixture
def afe_solv_transformation_json() -> str:
"""
ASFE results object as created by quickrun.
generated with devtools/gent-serialized-results.py
"""
d = resources.files("pontibus.tests.data.solvation_protocol")
fname = "ASFEProtocol_json_results.gz"

with gzip.open((d / fname).as_posix(), "r") as f:
return f.read().decode()
83 changes: 39 additions & 44 deletions src/pontibus/tests/protocols/solvation/test_results.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import gzip
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
import itertools
import pytest
import json
from importlib import resources

import numpy as np
from openff.units import unit as offunit
Expand All @@ -11,35 +11,23 @@
from pontibus.protocols.solvation import ASFEProtocolResult


@pytest.fixture
def afe_solv_transformation_json() -> str:
"""
ASFE results object as created by quickrun.
generated with devtools/gent-serialized-results.py
"""
d = resources.files("pontibus.tests.data.solvation_protocol")
fname = "ASFEProtocol_json_results.gz"

with gzip.open((d / fname).as_posix(), 'r') as f:
return f.read().decode()


class TestProtocolResult:
@pytest.fixture()
def protocolresult(self, afe_solv_transformation_json):
d = json.loads(afe_solv_transformation_json,
cls=gufe.tokenization.JSON_HANDLER.decoder)
d = json.loads(
afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder
)

pr = openfe.ProtocolResult.from_dict(d['protocol_result'])
pr = openfe.ProtocolResult.from_dict(d["protocol_result"])

return pr

def test_reload_protocol_result(self, afe_solv_transformation_json):
d = json.loads(afe_solv_transformation_json,
cls=gufe.tokenization.JSON_HANDLER.decoder)
d = json.loads(
afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder
)

pr = ASFEProtocolResult.from_dict(d['protocol_result'])
pr = ASFEProtocolResult.from_dict(d["protocol_result"])

assert pr

Expand All @@ -63,14 +51,14 @@ def test_get_individual(self, protocolresult):
inds = protocolresult.get_individual_estimates()

assert isinstance(inds, dict)
assert isinstance(inds['solvent'], list)
assert isinstance(inds['vacuum'], list)
assert len(inds['solvent']) == len(inds['vacuum']) == 3
for e, u in itertools.chain(inds['solvent'], inds['vacuum']):
assert isinstance(inds["solvent"], list)
assert isinstance(inds["vacuum"], list)
assert len(inds["solvent"]) == len(inds["vacuum"]) == 3
for e, u in itertools.chain(inds["solvent"], inds["vacuum"]):
assert e.is_compatible_with(offunit.kilojoule_per_mole)
assert u.is_compatible_with(offunit.kilojoule_per_mole)

@pytest.mark.parametrize('key', ['solvent', 'vacuum'])
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
def test_get_forwards_etc(self, key, protocolresult):
far = protocolresult.get_forward_and_reverse_energy_analysis()

Expand All @@ -79,27 +67,34 @@ def test_get_forwards_etc(self, key, protocolresult):
far1 = far[key][0]
assert isinstance(far1, dict)

for k in ['fractions', 'forward_DGs', 'forward_dDGs',
'reverse_DGs', 'reverse_dDGs']:
for k in [
"fractions",
"forward_DGs",
"forward_dDGs",
"reverse_DGs",
"reverse_dDGs",
]:
assert k in far1

if k == 'fractions':
if k == "fractions":
assert isinstance(far1[k], np.ndarray)

@pytest.mark.parametrize('key', ['solvent', 'vacuum'])
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
def test_get_frwd_reverse_none_return(self, key, protocolresult):
# fetch the first result of type key
data = [i for i in protocolresult.data[key].values()][0][0]
# set the output to None
data.outputs['forward_and_reverse_energies'] = None
data.outputs["forward_and_reverse_energies"] = None

# now fetch the analysis results and expect a warning
wmsg = ("were found in the forward and reverse dictionaries "
f"of the repeats of the {key}")
wmsg = (
"were found in the forward and reverse dictionaries "
f"of the repeats of the {key}"
)
with pytest.warns(UserWarning, match=wmsg):
protocolresult.get_forward_and_reverse_energy_analysis()

@pytest.mark.parametrize('key', ['solvent', 'vacuum'])
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
def test_get_overlap_matrices(self, key, protocolresult):
ovp = protocolresult.get_overlap_matrices()

Expand All @@ -108,23 +103,23 @@ def test_get_overlap_matrices(self, key, protocolresult):
assert len(ovp[key]) == 3

ovp1 = ovp[key][0]
assert isinstance(ovp1['matrix'], np.ndarray)
assert ovp1['matrix'].shape == (14, 14)
assert isinstance(ovp1["matrix"], np.ndarray)
assert ovp1["matrix"].shape == (14, 14)

@pytest.mark.parametrize('key', ['solvent', 'vacuum'])
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
def test_get_replica_transition_statistics(self, key, protocolresult):
rpx = protocolresult.get_replica_transition_statistics()

assert isinstance(rpx, dict)
assert isinstance(rpx[key], list)
assert len(rpx[key]) == 3
rpx1 = rpx[key][0]
assert 'eigenvalues' in rpx1
assert 'matrix' in rpx1
assert rpx1['eigenvalues'].shape == (14,)
assert rpx1['matrix'].shape == (14, 14)
assert "eigenvalues" in rpx1
assert "matrix" in rpx1
assert rpx1["eigenvalues"].shape == (14,)
assert rpx1["matrix"].shape == (14, 14)

@pytest.mark.parametrize('key', ['solvent', 'vacuum'])
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
def test_equilibration_iterations(self, key, protocolresult):
eq = protocolresult.equilibration_iterations()

Expand All @@ -133,7 +128,7 @@ def test_equilibration_iterations(self, key, protocolresult):
assert len(eq[key]) == 3
assert all(isinstance(v, float) for v in eq[key])

@pytest.mark.parametrize('key', ['solvent', 'vacuum'])
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
def test_production_iterations(self, key, protocolresult):
prod = protocolresult.production_iterations()

Expand Down
106 changes: 106 additions & 0 deletions src/pontibus/tests/protocols/solvation/test_tokenization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
import json
import pytest

import gufe
from gufe.tests.test_tokenization import GufeTokenizableTestsMixin
import openfe
from pontibus.protocols.solvation import (
ASFEProtocol,
ASFEProtocolResult,
ASFESolventUnit,
ASFEVacuumUnit,
)
from pontibus.components import ExtendedSolventComponent


@pytest.fixture
def protocol():
return ASFEProtocol(ASFEProtocol.default_settings())


@pytest.fixture
def protocol_units(protocol, benzene_modifications):
pus = protocol.create(
stateA=openfe.ChemicalSystem(
{
"solute": benzene_modifications["benzene"],
"solvent": ExtendedSolventComponent(),
}
),
stateB=openfe.ChemicalSystem({"solvent": ExtendedSolventComponent()}),
mapping=None,
)
return list(pus.protocol_units)


@pytest.fixture
def solvent_protocol_unit(protocol_units):
for pu in protocol_units:
if isinstance(pu, ASFESolventUnit):
return pu


@pytest.fixture
def vacuum_protocol_unit(protocol_units):
for pu in protocol_units:
if isinstance(pu, ASFEVacuumUnit):
return pu


@pytest.fixture
def protocol_result(afe_solv_transformation_json):
d = json.loads(
afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder
)
pr = ASFEProtocolResult.from_dict(d["protocol_result"])
return pr


class TestProtocol(GufeTokenizableTestsMixin):
cls = ASFEProtocol
key = "ASFEProtocol-798d96f939ae6898c385e31e48caae6d"
repr = f"<{key}>"

@pytest.fixture()
def instance(self, protocol):
return protocol


class TestSolventUnit(GufeTokenizableTestsMixin):
cls = ASFESolventUnit
repr = "ASFESolventUnit(Absolute Solvation, benzene solvent leg: repeat 2 generation 0)"
key = None

@pytest.fixture()
def instance(self, solvent_protocol_unit):
return solvent_protocol_unit

def test_key_stable(self):
pytest.skip()


class TestVacuumUnit(GufeTokenizableTestsMixin):
cls = ASFEVacuumUnit
repr = (
"ASFEVacuumUnit(Absolute Solvation, benzene vacuum leg: repeat 2 generation 0)"
)
key = None

@pytest.fixture()
def instance(self, vacuum_protocol_unit):
return vacuum_protocol_unit

def test_key_stable(self):
pytest.skip()


class TestProtocolResult(GufeTokenizableTestsMixin):
cls = ASFEProtocolResult
key = "ASFEProtocolResult-f1172ed96a55d778bdfcc8d9ce0299f2"
repr = f"<{key}>"

@pytest.fixture()
def instance(self, protocol_result):
return protocol_result

0 comments on commit add6464

Please sign in to comment.