From 3e404193c10f0e7e319ed10573edc6407f48d12e Mon Sep 17 00:00:00 2001 From: IAlibay Date: Fri, 13 Sep 2024 07:39:55 -0400 Subject: [PATCH] linter --- .../components/extended_solvent_component.py | 9 ++-- src/pontibus/protocols/solvation/__init__.py | 4 +- .../protocols/solvation/asfe_protocol.py | 22 +++----- src/pontibus/protocols/solvation/base.py | 24 ++------- src/pontibus/protocols/solvation/settings.py | 3 +- .../tests/components/test_extendedsolvent.py | 28 ++++------- src/pontibus/tests/conftest.py | 50 +++++++++++-------- .../tests/utils/test_interchange_packmol.py | 50 +++++++++---------- src/pontibus/utils/system_creation.py | 10 ++-- 9 files changed, 87 insertions(+), 113 deletions(-) diff --git a/src/pontibus/components/extended_solvent_component.py b/src/pontibus/components/extended_solvent_component.py index 9245e04..ef1f6b1 100644 --- a/src/pontibus/components/extended_solvent_component.py +++ b/src/pontibus/components/extended_solvent_component.py @@ -131,8 +131,8 @@ def from_keyed_dict(cls, dct: dict): dct, lambda d: registry[GufeKey(d[":gufe-key:"])], is_gufe_key_dict, - mode='decode', - top=True + mode="decode", + top=True, ) return from_dict_depth_one(dct) @@ -189,10 +189,7 @@ def _from_dict_depth_one(dct: dict) -> GufeTokenizable: new_dct = {} for entry in dct: - if ( - isinstance(dct[entry], dict) and - '__qualname__' in dct[entry] - ): + if isinstance(dct[entry], dict) and "__qualname__" in dct[entry]: new_dct[entry] = _from_dict(dct[entry]) else: new_dct[entry] = dct[entry] diff --git a/src/pontibus/protocols/solvation/__init__.py b/src/pontibus/protocols/solvation/__init__.py index 69e3c91..b1c5c5b 100644 --- a/src/pontibus/protocols/solvation/__init__.py +++ b/src/pontibus/protocols/solvation/__init__.py @@ -4,9 +4,7 @@ Run absolute solvation free energy calculations using OpenMM and OpenMMTools. """ -from .settings import( - ASFESettings -) +from .settings import ASFESettings from .asfe_protocol import ( ASFEProtocol, ASFEProtocolResult, diff --git a/src/pontibus/protocols/solvation/asfe_protocol.py b/src/pontibus/protocols/solvation/asfe_protocol.py index e54d6a1..9a6af6f 100644 --- a/src/pontibus/protocols/solvation/asfe_protocol.py +++ b/src/pontibus/protocols/solvation/asfe_protocol.py @@ -1,10 +1,7 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -from typing import ( - Union, - Optional -) +from typing import Union, Optional import uuid import numpy as np @@ -200,8 +197,7 @@ def _validate_solvent(state: ChemicalSystem, nonbonded_method: str): * If there is a SolventComponent and the `nonbonded_method` is `nocutoff`. """ - solv = [comp for comp in state.values() - if isinstance(comp, SolventComponent)] + solv = [comp for comp in state.values() if isinstance(comp, SolventComponent)] if len(solv) > 0 and nonbonded_method.lower() == "nocutoff": errmsg = "nocutoff cannot be used for solvent transformations" @@ -238,12 +234,10 @@ def _create( # Validate the lambda schedule self._validate_lambda_schedule( - self.settings.lambda_settings, - self.settings.solvent_simulation_settings + self.settings.lambda_settings, self.settings.solvent_simulation_settings ) self._validate_lambda_schedule( - self.settings.lambda_settings, - self.settings.vacuum_simulation_settings + self.settings.lambda_settings, self.settings.vacuum_simulation_settings ) # Check nonbond & solvent compatibility @@ -349,9 +343,7 @@ def _get_components(self): # (of stateA since we enforce only one disappearing ligand) return alchem_comps, None, prot_comp, off_comps - def _handle_settings( - self - ) -> dict[str, gufe.settings.SettingsBaseModel]: + def _handle_settings(self) -> dict[str, gufe.settings.SettingsBaseModel]: """ Extract the relevant settings for a vacuum transformation. @@ -433,9 +425,7 @@ def _get_components(self): # disallowed on create return alchem_comps, solv_comp, prot_comp, off_comps - def _handle_settings( - self - ) -> dict[str, gufe.settings.SettingsBaseModel]: + def _handle_settings(self) -> dict[str, gufe.settings.SettingsBaseModel]: """ Extract the relevant settings for a vacuum transformation. diff --git a/src/pontibus/protocols/solvation/base.py b/src/pontibus/protocols/solvation/base.py index 5e460c7..d83a415 100644 --- a/src/pontibus/protocols/solvation/base.py +++ b/src/pontibus/protocols/solvation/base.py @@ -2,11 +2,7 @@ # For details, see https://github.com/OpenFreeEnergy/openfe import logging -from typing import ( - Any, - Optional, - Union -) +from typing import Any, Optional, Union import numpy.typing as npt from gufe import ( SmallMoleculeComponent, @@ -25,9 +21,7 @@ from openmm import app import openmmtools from openfe.utils import log_system_probe -from pontibus.components import ( - ExtendedSolventComponent -) +from pontibus.components import ExtendedSolventComponent from pontibus.protocols.solvation.settings import ( IntegratorSettings, OpenFFPartialChargeSettings, @@ -132,10 +126,7 @@ def _get_omm_objects( solvent_component: Optional[SolventComponent], smc_components: dict[SmallMoleculeComponent, OFFMolecule], ) -> tuple[ - app.Topology, - openmm.System, - openmm.unit.Quantity, - dict[str, npt.NDArray] + app.Topology, openmm.System, openmm.unit.Quantity, dict[str, npt.NDArray] ]: """ Get the OpenMM Topology, Positions and System of the @@ -311,9 +302,7 @@ def run( ) # 15. Run simulation - unit_result_dict = self._run_simulation( - sampler, reporter, settings, dry - ) + unit_result_dict = self._run_simulation(sampler, reporter, settings, dry) finally: # close reporter when you're done to prevent file handle clashes @@ -355,10 +344,7 @@ def _execute( ) -> dict[str, Any]: log_system_probe(logging.INFO, paths=[ctx.scratch]) - outputs = self.run( - scratch_basepath=ctx.scratch, - shared_basepath=ctx.shared - ) + outputs = self.run(scratch_basepath=ctx.scratch, shared_basepath=ctx.shared) return { "repeat_id": self._inputs["repeat_id"], diff --git a/src/pontibus/protocols/solvation/settings.py b/src/pontibus/protocols/solvation/settings.py index 1a11035..66c866a 100644 --- a/src/pontibus/protocols/solvation/settings.py +++ b/src/pontibus/protocols/solvation/settings.py @@ -81,8 +81,7 @@ def is_positive_distance(cls, v): # these are time units, not simulation steps if not v.is_compatible_with(unit.nanometer): raise ValueError( - "nonbonded_cutoff must be in distance units " - "(i.e. nanometers)" + "nonbonded_cutoff must be in distance units " "(i.e. nanometers)" ) if v < 0: errmsg = "nonbonded_cutoff must be a positive value" diff --git a/src/pontibus/tests/components/test_extendedsolvent.py b/src/pontibus/tests/components/test_extendedsolvent.py index e3bed97..47712ef 100644 --- a/src/pontibus/tests/components/test_extendedsolvent.py +++ b/src/pontibus/tests/components/test_extendedsolvent.py @@ -15,9 +15,9 @@ def test_defaults(): s = ExtendedSolventComponent() - assert s.smiles == '[H][O][H]' - assert s.positive_ion == 'Na+' - assert s.negative_ion == 'Cl-' + assert s.smiles == "[H][O][H]" + assert s.positive_ion == "Na+" + assert s.negative_ion == "Cl-" assert s.ion_concentration == 0.0 * unit.molar assert s.neutralize == False assert s.solvent_molecule == WATER @@ -26,33 +26,31 @@ def test_defaults(): def test_neq_different_smc(): water_off = WATER.to_openff() # Create a water with partial charges - water_off.assign_partial_charges(partial_charge_method='gasteiger') + water_off.assign_partial_charges(partial_charge_method="gasteiger") WATER2 = SmallMoleculeComponent.from_openff(water_off) s1 = ExtendedSolventComponent(solvent_molecule=WATER) s2 = ExtendedSolventComponent(solvent_molecule=WATER2) assert s1 != s2 - assert s1.smiles == '[H][O][H]' == s2.smiles + assert s1.smiles == "[H][O][H]" == s2.smiles def test_neq_different_solvent(): - meth_off = Molecule.from_smiles('C') + meth_off = Molecule.from_smiles("C") meth_off.generate_conformers() METH = SmallMoleculeComponent.from_openff(meth_off) s1 = ExtendedSolventComponent() s2 = ExtendedSolventComponent(solvent_molecule=METH) assert s1 != s2 - assert s1.smiles == '[H][O][H]' - assert s2.smiles == '[H][C]([H])([H])[H]' + assert s1.smiles == "[H][O][H]" + assert s2.smiles == "[H][C]([H])([H])[H]" assert s1.smiles != s2.smiles def test_dict_roundtrip_eq(): s1 = ExtendedSolventComponent() - s2 = ExtendedSolventComponent.from_dict( - s1.to_dict() - ) + s2 = ExtendedSolventComponent.from_dict(s1.to_dict()) assert s1 == s2 assert s1.solvent_molecule == s2.solvent_molecule # Smiles isn't a dict entry, so make sure it got preserved @@ -61,9 +59,7 @@ def test_dict_roundtrip_eq(): def test_keyed_dict_roundtrip_eq(): s1 = ExtendedSolventComponent() - s2 = ExtendedSolventComponent.from_keyed_dict( - s1.to_keyed_dict() - ) + s2 = ExtendedSolventComponent.from_keyed_dict(s1.to_keyed_dict()) assert s1 == s2 # Smiles isn't a dict entry, so make sure it got preserved assert s1.smiles == s2.smiles @@ -74,9 +70,7 @@ def test_keyed_dict_roundtrip_eq(): def test_shallow_dict_roundtrip_eq(): s1 = ExtendedSolventComponent() - s2 = ExtendedSolventComponent.from_shallow_dict( - s1.to_shallow_dict() - ) + s2 = ExtendedSolventComponent.from_shallow_dict(s1.to_shallow_dict()) assert s1 == s2 # Smiles isn't a dict entry, so make sure it got preserved assert s1.smiles == s2.smiles diff --git a/src/pontibus/tests/conftest.py b/src/pontibus/tests/conftest.py index 0a91488..6a16975 100644 --- a/src/pontibus/tests/conftest.py +++ b/src/pontibus/tests/conftest.py @@ -43,13 +43,16 @@ class SlowTests: To run the `slow` tests, either use the `--runslow` flag when invoking pytest, or set the environment variable `PONTIBUS_SLOW_TESTS` to `true` """ + def __init__(self, config): self.config = config @staticmethod def _modify_slow(items, config): - msg = ("need --runslow pytest cli option or the environment variable " - "`PONTIBUS_SLOW_TESTS` set to `True` to run") + msg = ( + "need --runslow pytest cli option or the environment variable " + "`PONTIBUS_SLOW_TESTS` set to `True` to run" + ) skip_slow = pytest.mark.skip(reason=msg) for item in items: if "slow" in item.keywords: @@ -57,19 +60,25 @@ def _modify_slow(items, config): @staticmethod def _modify_integration(items, config): - msg = ("need --gpu pytest cli option or the environment " - "variable `PONTIBUS_GPU_TESTS` set to `True` to run") + msg = ( + "need --gpu pytest cli option or the environment " + "variable `PONTIBUS_GPU_TESTS` set to `True` to run" + ) skip_int = pytest.mark.skip(reason=msg) for item in items: if "gpu" in item.keywords: item.add_marker(skip_int) def pytest_collection_modifyitems(self, items, config): - if (config.getoption('--gpu') or - os.getenv("PONTIBUS_GPU_TESTS", default="false").lower() == 'true'): + if ( + config.getoption("--gpu") + or os.getenv("PONTIBUS_GPU_TESTS", default="false").lower() == "true" + ): return - elif (config.getoption('--runslow') or - os.getenv("PONTIBUS_SLOW_TESTS", default="false").lower() == 'true'): + elif ( + config.getoption("--runslow") + or os.getenv("PONTIBUS_SLOW_TESTS", default="false").lower() == "true" + ): self._modify_integration(items, config) else: self._modify_integration(items, config) @@ -83,7 +92,9 @@ def pytest_addoption(parser): "--runslow", action="store_true", default=False, help="run slow tests" ) parser.addoption( - "--gpu", action="store_true", default=False, + "--gpu", + action="store_true", + default=False, help="run gpu tests", ) @@ -91,18 +102,17 @@ def pytest_addoption(parser): def pytest_configure(config): config.pluginmanager.register(SlowTests(config), "slow") config.addinivalue_line("markers", "slow: mark test as slow") - config.addinivalue_line( - "markers", "gpu: mark test as long integration test") + config.addinivalue_line("markers", "gpu: mark test as long integration test") -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def benzene_modifications(): files = {} - with importlib.resources.files('openfe.tests.data') as d: - fn = str(d / 'benzene_modifications.sdf') + with importlib.resources.files("openfe.tests.data") as d: + fn = str(d / "benzene_modifications.sdf") supp = Chem.SDMolSupplier(str(fn), removeHs=False) for rdmol in supp: - files[rdmol.GetProp('_Name')] = SmallMoleculeComponent(rdmol) + files[rdmol.GetProp("_Name")] = SmallMoleculeComponent(rdmol) return files @@ -111,8 +121,8 @@ def CN_molecule(): """ A basic CH3NH2 molecule for quick testing. """ - with resources.files('openfe.tests.data') as d: - fn = str(d / 'CN.sdf') + with resources.files("openfe.tests.data") as d: + fn = str(d / "CN.sdf") supp = Chem.SDMolSupplier(str(fn), removeHs=False) smc = [SmallMoleculeComponent(i) for i in supp][0] @@ -120,10 +130,10 @@ def CN_molecule(): return smc -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def T4_protein_component(): - with resources.files('openfe.tests.data') as d: - fn = str(d / '181l_only.pdb') + with resources.files("openfe.tests.data") as d: + fn = str(d / "181l_only.pdb") comp = gufe.ProteinComponent.from_pdb_file(fn, name="T4_protein") return comp diff --git a/src/pontibus/tests/utils/test_interchange_packmol.py b/src/pontibus/tests/utils/test_interchange_packmol.py index 276c6ec..35f700d 100644 --- a/src/pontibus/tests/utils/test_interchange_packmol.py +++ b/src/pontibus/tests/utils/test_interchange_packmol.py @@ -24,16 +24,16 @@ @pytest.fixture() def smc_components_benzene(benzene_modifications): - benzene_off = benzene_modifications['benzene'].to_openff() - benzene_off.assign_partial_charges(partial_charge_method='gasteiger') - return {benzene_modifications['benzene']: benzene_off} + benzene_off = benzene_modifications["benzene"].to_openff() + benzene_off.assign_partial_charges(partial_charge_method="gasteiger") + return {benzene_modifications["benzene"]: benzene_off} @pytest.fixture() def methane(): - m = Molecule.from_smiles('C') + m = Molecule.from_smiles("C") m.generate_conformers() - m.assign_partial_charges(partial_charge_method='gasteiger') + m.assign_partial_charges(partial_charge_method="gasteiger") return m @@ -46,7 +46,7 @@ def test_protein_component_fail(smc_components_benzene, T4_protein_component): smc_components=smc_components_benzene, protein_component=T4_protein_component, solvent_component=None, - solvent_offmol=None + solvent_offmol=None, ) @@ -57,28 +57,29 @@ def test_get_and_set_offmol_resname(CN_molecule, caplog): assert _get_offmol_resname(CN_off) is None # Boop the floof - _set_offmol_resname(CN_off, 'BOOP') + _set_offmol_resname(CN_off, "BOOP") # Does the floof accept the boop? - assert 'BOOP' == _get_offmol_resname(CN_off) + assert "BOOP" == _get_offmol_resname(CN_off) # Oh no, one of the atoms didn't like the boop! atom3 = list(CN_off.atoms)[2] - atom3.metadata['residue_name'] = 'NOBOOP' + atom3.metadata["residue_name"] = "NOBOOP" with caplog.at_level(logging.WARNING): assert _get_offmol_resname(CN_off) is None - assert 'Inconsistent residue name' in caplog.text + assert "Inconsistent residue name" in caplog.text -@pytest.mark.parametrize('neutralize, ion_conc', [ - [True, 0.0 * unit.molar], - [False, 0.1 * unit.molar], - [True, 0.1 * unit.molar], -]) -def test_wrong_solventcomp_settings( - neutralize, ion_conc, smc_components_benzene -): +@pytest.mark.parametrize( + "neutralize, ion_conc", + [ + [True, 0.0 * unit.molar], + [False, 0.1 * unit.molar], + [True, 0.1 * unit.molar], + ], +) +def test_wrong_solventcomp_settings(neutralize, ion_conc, smc_components_benzene): with pytest.raises(ValueError, match="Adding counterions"): interchange_packmol_creation( ffsettings=InterchangeFFSettings(), @@ -111,7 +112,7 @@ def test_solv_mismatch( smc_components_benzene, methane, ): - assert ExtendedSolventComponent().smiles == '[H][O][H]' + assert ExtendedSolventComponent().smiles == "[H][O][H]" with pytest.raises(ValueError, match="does not match"): interchange_packmol_creation( ffsettings=InterchangeFFSettings(), @@ -135,19 +136,18 @@ def test_vacuum(smc_components_benzene): assert len(comp_resids) == 1 assert list(smc_components_benzene)[0] in comp_resids - + # Get the topology out omm_topology = interchange.to_openmm_topology() residues = list(omm_topology.residues()) assert len(residues) == 1 assert len(list(omm_topology.atoms())) == 12 - assert residues[0].name == 'AAA' + assert residues[0].name == "AAA" # Get the openmm system out.. omm_system = interchange.to_openmm_system() - nonbond = [f for f in omm_system.getForces() - if isinstance(f, NonbondedForce)] + nonbond = [f for f in omm_system.getForces() if isinstance(f, NonbondedForce)] # One nonbonded force assert len(nonbond) == 1 @@ -155,12 +155,12 @@ def test_vacuum(smc_components_benzene): # Gas phase should be nonbonded assert nonbond[0].getNonbondedMethod() == 0 - bond = [f for f in omm_system.getForces() - if not isinstance(f, NonbondedForce)] + bond = [f for f in omm_system.getForces() if not isinstance(f, NonbondedForce)] # 3 bonded forces assert len(bond) == 3 + """ 4. Named solvent 5. Unamed solvent diff --git a/src/pontibus/utils/system_creation.py b/src/pontibus/utils/system_creation.py index ab50de2..4b05683 100644 --- a/src/pontibus/utils/system_creation.py +++ b/src/pontibus/utils/system_creation.py @@ -21,9 +21,7 @@ ForceField, Topology, ) -from openff.interchange import ( - Interchange -) +from openff.interchange import Interchange from openff.interchange.components._packmol import ( solvate_topology_nonwater, RHOMBIC_DODECAHEDRON, @@ -158,8 +156,10 @@ def interchange_packmol_creation( if not solvent_offmol.is_isomorphic_with( OFFMolecule.from_smiles(solvent_component.smiles) ): - errmsg = (f"Passed molecule: {solvent_offmol} does not match the " - f"the solvent component: {solvent_component.smiles}") + errmsg = ( + f"Passed molecule: {solvent_offmol} does not match the " + f"the solvent component: {solvent_component.smiles}" + ) raise ValueError(errmsg) # 2. Get the force field object