Skip to content

Commit

Permalink
Merge pull request #88 from wiederm/endstate_correction
Browse files Browse the repository at this point in the history
integrating endstate corrections into TF
  • Loading branch information
JohannesKarwou authored Sep 16, 2022
2 parents 130d732 + 9d47aa6 commit 02b481d
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 0 deletions.
84 changes: 84 additions & 0 deletions transformato/bin/perform_correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# general imports
import os
import glob
from endstate_correction.system import create_charmm_system, gen_box
from openmm.app import (
CharmmParameterSet,
CharmmPsfFile,
PDBFile,
CharmmCrdFile,
)
from endstate_correction.analysis import plot_endstate_correction_results
from endstate_correction.protocol import perform_endstate_correction, Protocol
import mdtraj
from openmm import unit


# Variables will be generated by transformato
system_name = NAMEofSYSTEM
tlc = TLC

for env in ["waterbox", "vacuum"]:


# load the charmm specific files (psf, rtf, crd files)
psf_file = f"../{system_name}/intst1/lig_in_{env}.psf"
psf = CharmmPsfFile(psf_file)
pdb = PDBFile(f"../{system_name}/intst1/lig_in_{env}.pdb")
crd = CharmmCrdFile(f"../{system_name}/intst1/lig_in_{env}.crd")

# load forcefiled files (ligand.str and toppar files)
parms = ()
file = f"../{system_name}/intst1/{tlc.lower()}"
if os.path.isfile(f"{file}.str"):
parms += (f"{file}.str",)
else:
parms += (f"{file}_g.rtf",)
parms += (f"{file}.prm",)

parms += (f"../toppar/top_all36_cgenff.rtf",)
parms += (f"../toppar/par_all36_cgenff.prm",)
parms += (f"../toppar/toppar_water_ions.str",)

params = CharmmParameterSet(*parms)

# set up the treatment of the system for the specific environment
if env == "waterbox":
psf = gen_box(psf, crd)

# define region that should be treated with the qml
chains = list(psf.topology.chains())
ml_atoms = [atom.index for atom in chains[0].atoms()]
# define system
sim = create_charmm_system(psf=psf, parameters=params, env=env, ml_atoms=ml_atoms)

########################################################
########################################################
# ------------------- load samples ---------------------#

files = glob.glob(f"../{system_name}/intst1/**/lig_in_{env}.dcd", recursive=True)
traj = mdtraj.load(files, top=psf_file)

if env == "waterbox":
traj.image_molecules()

mm_samples = []
mm_samples.extend(traj.xyz * unit.nanometer) # NOTE: this is in nanometer!

####################################################
# ----------------------- FEP ----------------------
####################################################

fep_protocoll = Protocol(
method="NEQ",
direction="unidirectional",
sim=sim,
trajectories=[mm_samples],
nr_of_switches=500, # 500
neq_switching_length=5000, # 5000
)

r = perform_endstate_correction(fep_protocoll)
plot_endstate_correction_results(
system_name, r, f"results_neq_unidirectional_{env}.png"
)
16 changes: 16 additions & 0 deletions transformato/bin/slurm_switching.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@










export OPENMM_PRECISION='mixed'

pwd; hostname; date


python perform_correction.py
6 changes: 6 additions & 0 deletions transformato/mutate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from transformato.system import SystemStructure


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -70,6 +71,7 @@ def perform_mutations(
nr_of_mutation_steps_lj_of_hydrogens: int = 1,
nr_of_mutation_steps_lj_of_heavy_atoms: int = 1,
nr_of_mutation_steps_cc: int = 5,
endstate_correction: bool = False,
):
"""Performs the mutations necessary to mutate the physical endstate to the defined common core.
Expand Down Expand Up @@ -264,6 +266,10 @@ def perform_mutations(
mutation=mutation_list["transform"],
)

if endstate_correction:
i.endstate_correction()



@dataclass
class DummyRegion:
Expand Down
30 changes: 30 additions & 0 deletions transformato/state.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
import shutil
import glob
from io import StringIO
from typing import List

Expand Down Expand Up @@ -43,6 +44,35 @@ def __init__(
self.current_step = 1
self.multiple_runs = multiple_runs

def endstate_correction(self):

logger.info(f"Will create script for endstate correction")
try:
os.makedirs(f"{self.path}/../endstate_correction")
except:
logger.info(f"Folder for endstate correction exist already")

# copy submit script to the newly created folder
submit_switching_script_source = f"{self.configuration['bin_dir']}/slurm_switching.sh"
submit_switchting_script_target = f"{self.path}/../endstate_correction/slurm_switching.sh"
shutil.copyfile(submit_switching_script_source, submit_switchting_script_target)

# modify the perform_correction file from transformato bin and save it in the endstate_correcition folder
endstate_correction_script_source = f"{self.configuration['bin_dir']}/perform_correction.py"
endstate_correction_script_target = f"{self.path}/../endstate_correction/perform_correction.py"
fin = open(endstate_correction_script_source,"rt")
fout = open(endstate_correction_script_target,"wt")

for line in fin:
if "NAMEofSYSTEM" in line:
fout.write(line.replace("NAMEofSYSTEM",f'"{self.configuration["system"]["structure1"]["name"]}"'))
elif "TLC" in line:
fout.write(line.replace("TLC",f'"{self.configuration["system"]["structure1"]["tlc"]}"'))
else:
fout.write(line)
fin.close()
fout.close()

def write_state(
self,
mutation_conf: List,
Expand Down
28 changes: 28 additions & 0 deletions transformato/tests/test_absolute.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,31 @@ def test_compare_mda_and_mdtraj():
mda_results = analyse_asfe_with_module(module="mda")
mdtraj_results = analyse_asfe_with_module(module="mdtraj")
assert np.isclose(np.average(mda_results), np.average(mdtraj_results))




def test_perform_enstate_correction_asfe_system():

configuration = load_config_yaml(
config=f"{get_testsystems_dir()}/config/methanol-asfe.yaml",
input_dir=get_testsystems_dir(),
output_dir=get_test_output_dir(),
)

s1, mutation_list = create_asfe_system(configuration)

i = IntermediateStateFactory(system=s1, configuration=configuration)

perform_mutations(
configuration=configuration,
nr_of_mutation_steps_charge=2,
i=i,
mutation_list=mutation_list,
endstate_correction=True,
)


assert len(i.output_files) == 7
assert len((mutation_list)["charge"][0].atoms_to_be_mutated) == 6

0 comments on commit 02b481d

Please sign in to comment.