Skip to content

Commit

Permalink
linter
Browse files Browse the repository at this point in the history
  • Loading branch information
IAlibay committed Sep 13, 2024
1 parent 7fe9e75 commit 3e40419
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 113 deletions.
9 changes: 3 additions & 6 deletions src/pontibus/components/extended_solvent_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ def from_keyed_dict(cls, dct: dict):
dct,
lambda d: registry[GufeKey(d[":gufe-key:"])],
is_gufe_key_dict,
mode='decode',
top=True
mode="decode",
top=True,
)

return from_dict_depth_one(dct)
Expand Down Expand Up @@ -189,10 +189,7 @@ def _from_dict_depth_one(dct: dict) -> GufeTokenizable:
new_dct = {}

for entry in dct:
if (
isinstance(dct[entry], dict) and
'__qualname__' in dct[entry]
):
if isinstance(dct[entry], dict) and "__qualname__" in dct[entry]:
new_dct[entry] = _from_dict(dct[entry])
else:
new_dct[entry] = dct[entry]
Expand Down
4 changes: 1 addition & 3 deletions src/pontibus/protocols/solvation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
Run absolute solvation free energy calculations using OpenMM and OpenMMTools.
"""
from .settings import(
ASFESettings
)
from .settings import ASFESettings
from .asfe_protocol import (
ASFEProtocol,
ASFEProtocolResult,
Expand Down
22 changes: 6 additions & 16 deletions src/pontibus/protocols/solvation/asfe_protocol.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe

from typing import (
Union,
Optional
)
from typing import Union, Optional
import uuid

import numpy as np
Expand Down Expand Up @@ -200,8 +197,7 @@ def _validate_solvent(state: ChemicalSystem, nonbonded_method: str):
* If there is a SolventComponent and the `nonbonded_method` is
`nocutoff`.
"""
solv = [comp for comp in state.values()
if isinstance(comp, SolventComponent)]
solv = [comp for comp in state.values() if isinstance(comp, SolventComponent)]

if len(solv) > 0 and nonbonded_method.lower() == "nocutoff":
errmsg = "nocutoff cannot be used for solvent transformations"
Expand Down Expand Up @@ -238,12 +234,10 @@ def _create(

# Validate the lambda schedule
self._validate_lambda_schedule(
self.settings.lambda_settings,
self.settings.solvent_simulation_settings
self.settings.lambda_settings, self.settings.solvent_simulation_settings
)
self._validate_lambda_schedule(
self.settings.lambda_settings,
self.settings.vacuum_simulation_settings
self.settings.lambda_settings, self.settings.vacuum_simulation_settings
)

# Check nonbond & solvent compatibility
Expand Down Expand Up @@ -349,9 +343,7 @@ def _get_components(self):
# (of stateA since we enforce only one disappearing ligand)
return alchem_comps, None, prot_comp, off_comps

def _handle_settings(
self
) -> dict[str, gufe.settings.SettingsBaseModel]:
def _handle_settings(self) -> dict[str, gufe.settings.SettingsBaseModel]:
"""
Extract the relevant settings for a vacuum transformation.
Expand Down Expand Up @@ -433,9 +425,7 @@ def _get_components(self):
# disallowed on create
return alchem_comps, solv_comp, prot_comp, off_comps

def _handle_settings(
self
) -> dict[str, gufe.settings.SettingsBaseModel]:
def _handle_settings(self) -> dict[str, gufe.settings.SettingsBaseModel]:
"""
Extract the relevant settings for a vacuum transformation.
Expand Down
24 changes: 5 additions & 19 deletions src/pontibus/protocols/solvation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@
# For details, see https://github.com/OpenFreeEnergy/openfe

import logging
from typing import (
Any,
Optional,
Union
)
from typing import Any, Optional, Union
import numpy.typing as npt
from gufe import (
SmallMoleculeComponent,
Expand All @@ -25,9 +21,7 @@
from openmm import app
import openmmtools
from openfe.utils import log_system_probe
from pontibus.components import (
ExtendedSolventComponent
)
from pontibus.components import ExtendedSolventComponent
from pontibus.protocols.solvation.settings import (
IntegratorSettings,
OpenFFPartialChargeSettings,
Expand Down Expand Up @@ -132,10 +126,7 @@ def _get_omm_objects(
solvent_component: Optional[SolventComponent],
smc_components: dict[SmallMoleculeComponent, OFFMolecule],
) -> tuple[
app.Topology,
openmm.System,
openmm.unit.Quantity,
dict[str, npt.NDArray]
app.Topology, openmm.System, openmm.unit.Quantity, dict[str, npt.NDArray]
]:
"""
Get the OpenMM Topology, Positions and System of the
Expand Down Expand Up @@ -311,9 +302,7 @@ def run(
)

# 15. Run simulation
unit_result_dict = self._run_simulation(
sampler, reporter, settings, dry
)
unit_result_dict = self._run_simulation(sampler, reporter, settings, dry)

finally:
# close reporter when you're done to prevent file handle clashes
Expand Down Expand Up @@ -355,10 +344,7 @@ def _execute(
) -> dict[str, Any]:
log_system_probe(logging.INFO, paths=[ctx.scratch])

outputs = self.run(
scratch_basepath=ctx.scratch,
shared_basepath=ctx.shared
)
outputs = self.run(scratch_basepath=ctx.scratch, shared_basepath=ctx.shared)

return {
"repeat_id": self._inputs["repeat_id"],
Expand Down
3 changes: 1 addition & 2 deletions src/pontibus/protocols/solvation/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ def is_positive_distance(cls, v):
# these are time units, not simulation steps
if not v.is_compatible_with(unit.nanometer):
raise ValueError(
"nonbonded_cutoff must be in distance units "
"(i.e. nanometers)"
"nonbonded_cutoff must be in distance units " "(i.e. nanometers)"
)
if v < 0:
errmsg = "nonbonded_cutoff must be a positive value"
Expand Down
28 changes: 11 additions & 17 deletions src/pontibus/tests/components/test_extendedsolvent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
def test_defaults():
s = ExtendedSolventComponent()

assert s.smiles == '[H][O][H]'
assert s.positive_ion == 'Na+'
assert s.negative_ion == 'Cl-'
assert s.smiles == "[H][O][H]"
assert s.positive_ion == "Na+"
assert s.negative_ion == "Cl-"
assert s.ion_concentration == 0.0 * unit.molar
assert s.neutralize == False
assert s.solvent_molecule == WATER
Expand All @@ -26,33 +26,31 @@ def test_defaults():
def test_neq_different_smc():
water_off = WATER.to_openff()
# Create a water with partial charges
water_off.assign_partial_charges(partial_charge_method='gasteiger')
water_off.assign_partial_charges(partial_charge_method="gasteiger")
WATER2 = SmallMoleculeComponent.from_openff(water_off)
s1 = ExtendedSolventComponent(solvent_molecule=WATER)
s2 = ExtendedSolventComponent(solvent_molecule=WATER2)

assert s1 != s2
assert s1.smiles == '[H][O][H]' == s2.smiles
assert s1.smiles == "[H][O][H]" == s2.smiles


def test_neq_different_solvent():
meth_off = Molecule.from_smiles('C')
meth_off = Molecule.from_smiles("C")
meth_off.generate_conformers()
METH = SmallMoleculeComponent.from_openff(meth_off)
s1 = ExtendedSolventComponent()
s2 = ExtendedSolventComponent(solvent_molecule=METH)

assert s1 != s2
assert s1.smiles == '[H][O][H]'
assert s2.smiles == '[H][C]([H])([H])[H]'
assert s1.smiles == "[H][O][H]"
assert s2.smiles == "[H][C]([H])([H])[H]"
assert s1.smiles != s2.smiles


def test_dict_roundtrip_eq():
s1 = ExtendedSolventComponent()
s2 = ExtendedSolventComponent.from_dict(
s1.to_dict()
)
s2 = ExtendedSolventComponent.from_dict(s1.to_dict())
assert s1 == s2
assert s1.solvent_molecule == s2.solvent_molecule
# Smiles isn't a dict entry, so make sure it got preserved
Expand All @@ -61,9 +59,7 @@ def test_dict_roundtrip_eq():

def test_keyed_dict_roundtrip_eq():
s1 = ExtendedSolventComponent()
s2 = ExtendedSolventComponent.from_keyed_dict(
s1.to_keyed_dict()
)
s2 = ExtendedSolventComponent.from_keyed_dict(s1.to_keyed_dict())
assert s1 == s2
# Smiles isn't a dict entry, so make sure it got preserved
assert s1.smiles == s2.smiles
Expand All @@ -74,9 +70,7 @@ def test_keyed_dict_roundtrip_eq():

def test_shallow_dict_roundtrip_eq():
s1 = ExtendedSolventComponent()
s2 = ExtendedSolventComponent.from_shallow_dict(
s1.to_shallow_dict()
)
s2 = ExtendedSolventComponent.from_shallow_dict(s1.to_shallow_dict())
assert s1 == s2
# Smiles isn't a dict entry, so make sure it got preserved
assert s1.smiles == s2.smiles
Expand Down
50 changes: 30 additions & 20 deletions src/pontibus/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,33 +43,42 @@ class SlowTests:
To run the `slow` tests, either use the `--runslow` flag when invoking
pytest, or set the environment variable `PONTIBUS_SLOW_TESTS` to `true`
"""

def __init__(self, config):
self.config = config

@staticmethod
def _modify_slow(items, config):
msg = ("need --runslow pytest cli option or the environment variable "
"`PONTIBUS_SLOW_TESTS` set to `True` to run")
msg = (
"need --runslow pytest cli option or the environment variable "
"`PONTIBUS_SLOW_TESTS` set to `True` to run"
)
skip_slow = pytest.mark.skip(reason=msg)
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)

@staticmethod
def _modify_integration(items, config):
msg = ("need --gpu pytest cli option or the environment "
"variable `PONTIBUS_GPU_TESTS` set to `True` to run")
msg = (
"need --gpu pytest cli option or the environment "
"variable `PONTIBUS_GPU_TESTS` set to `True` to run"
)
skip_int = pytest.mark.skip(reason=msg)
for item in items:
if "gpu" in item.keywords:
item.add_marker(skip_int)

def pytest_collection_modifyitems(self, items, config):
if (config.getoption('--gpu') or
os.getenv("PONTIBUS_GPU_TESTS", default="false").lower() == 'true'):
if (
config.getoption("--gpu")
or os.getenv("PONTIBUS_GPU_TESTS", default="false").lower() == "true"
):
return
elif (config.getoption('--runslow') or
os.getenv("PONTIBUS_SLOW_TESTS", default="false").lower() == 'true'):
elif (
config.getoption("--runslow")
or os.getenv("PONTIBUS_SLOW_TESTS", default="false").lower() == "true"
):
self._modify_integration(items, config)
else:
self._modify_integration(items, config)
Expand All @@ -83,26 +92,27 @@ def pytest_addoption(parser):
"--runslow", action="store_true", default=False, help="run slow tests"
)
parser.addoption(
"--gpu", action="store_true", default=False,
"--gpu",
action="store_true",
default=False,
help="run gpu tests",
)


def pytest_configure(config):
config.pluginmanager.register(SlowTests(config), "slow")
config.addinivalue_line("markers", "slow: mark test as slow")
config.addinivalue_line(
"markers", "gpu: mark test as long integration test")
config.addinivalue_line("markers", "gpu: mark test as long integration test")


@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def benzene_modifications():
files = {}
with importlib.resources.files('openfe.tests.data') as d:
fn = str(d / 'benzene_modifications.sdf')
with importlib.resources.files("openfe.tests.data") as d:
fn = str(d / "benzene_modifications.sdf")
supp = Chem.SDMolSupplier(str(fn), removeHs=False)
for rdmol in supp:
files[rdmol.GetProp('_Name')] = SmallMoleculeComponent(rdmol)
files[rdmol.GetProp("_Name")] = SmallMoleculeComponent(rdmol)
return files


Expand All @@ -111,19 +121,19 @@ def CN_molecule():
"""
A basic CH3NH2 molecule for quick testing.
"""
with resources.files('openfe.tests.data') as d:
fn = str(d / 'CN.sdf')
with resources.files("openfe.tests.data") as d:
fn = str(d / "CN.sdf")
supp = Chem.SDMolSupplier(str(fn), removeHs=False)

smc = [SmallMoleculeComponent(i) for i in supp][0]

return smc


@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def T4_protein_component():
with resources.files('openfe.tests.data') as d:
fn = str(d / '181l_only.pdb')
with resources.files("openfe.tests.data") as d:
fn = str(d / "181l_only.pdb")
comp = gufe.ProteinComponent.from_pdb_file(fn, name="T4_protein")

return comp
Loading

0 comments on commit 3e40419

Please sign in to comment.