Skip to content

Commit

Permalink
Merge pull request #32 from wiederm/psf_io
Browse files Browse the repository at this point in the history
Psf_io
  • Loading branch information
wiederm authored Oct 5, 2021
2 parents fbf599a + cbd3427 commit a087b00
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 110 deletions.
63 changes: 40 additions & 23 deletions transformato/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self, configuration: dict, structure_name: str):
self.N_k: dict = {}
self.thinning: int = 0
self.save_results_to_path: str = f"{self.configuration['system_dir']}/results/"
self.traj_files:list = []
self.traj_files = defaultdict(list)

def load_trajs(self, nr_of_max_snapshots: int = 300):
"""
Expand Down Expand Up @@ -141,15 +141,15 @@ def _generate_openMM_system(self, env: str, lambda_state: int) -> Simulation:
)
return simulation

def _thinning_traj(self, traj):
lenght = int(len(traj))
def _thinning(self, any_list):
lenght = int(len(any_list))
start = int(lenght / 4)
traj = traj[start:] # remove the first 25% confs
new_length = int(len(traj))
any_list = any_list[start:] # remove the first 25% confs
new_length = int(len(any_list))
further_thinning = max(
int(new_length / self.nr_of_max_snapshots), 1
) # thinning
return traj[::further_thinning][: self.nr_of_max_snapshots], start, further_thinning
return any_list[::further_thinning][: self.nr_of_max_snapshots], start, further_thinning

def _merge_trajs(self) -> Tuple[dict, dict, int, dict]:
"""
Expand All @@ -165,51 +165,62 @@ def _merge_trajs(self) -> Tuple[dict, dict, int, dict]:
nr_of_states = len(next(os.walk(f"{self.base_path}"))[1])

logger.info(f"Evaluating {nr_of_states} states.")
snapshots: dict = {}
unitcell:dict = {}
snapshots, unitcell = {}, {}

N_k: dict = defaultdict(list)
start = -1
stride = -1
start, stride = -1, -1

for env in self.envs:
confs = []
unitcell_ = []
conf_sub = self.configuration["system"][self.structure][env]
for lambda_state in tqdm(range(1, nr_of_states + 1)):
dcd_path = f"{self.base_path}/intst{lambda_state}/{conf_sub['intermediate-filename']}.dcd"
print(dcd_path)
psf_path = f"{self.base_path}/intst{lambda_state}/{conf_sub['intermediate-filename']}.psf"
if not os.path.isfile(dcd_path):
raise RuntimeError(f"{dcd_path} does not exist.")

traj = mdtraj.open(f"{dcd_path}")
# read trajs, determine offset, start, stride and unitcell lengths
if start == -1:
xyz, unitcell_lengths, _ = traj.read()
xyz, start, stride = self._thinning_traj(xyz)
print(f'Len: {len(xyz)}, Start: {start}, Stride: {stride}')
unitcell_lengths, _, _ = self._thinning_traj(unitcell_lengths)
xyz, start, stride = self._thinning(xyz)

else:
traj.seek(start)
xyz, unitcell_lengths, _ = traj.read(stride=stride)
xyz, unitcell_lengths = xyz[:self.nr_of_max_snapshots], unitcell_lengths[:self.nr_of_max_snapshots]
print(f'Len: {len(xyz)}, Start: {start}, Stride: {stride}')
xyz = xyz[:self.nr_of_max_snapshots]

print(f'Len: {len(xyz)}, Start: {start}, Stride: {stride}')

# check that we have enough samples
if len(xyz) < 10:
raise RuntimeError(
f"Below 10 conformations per lambda ({len(traj)}) -- decrease the thinning factor (currently: {self.thinning})."
)

# thin unitcell_lengths
# make sure that we can work with vacuum environments
if env != 'vacuum':
unitcell_lengths = unitcell_lengths[:self.nr_of_max_snapshots]
else:
unitcell_lengths = np.zeros(len(xyz))


confs.extend(xyz/10)
unitcell_.extend(unitcell_lengths/10)
logger.info(f"{dcd_path}")
logger.info(f"Nr of snapshots: {len(xyz)}")
N_k[env].append(len(xyz))
self.traj_files.append((dcd_path, psf_path))
self.traj_files[env].append((dcd_path, psf_path))

logger.info(f"Combined nr of snapshots: {len(confs)}")
snapshots[env] = confs
unitcell[env] = unitcell_
assert len(confs) == len(unitcell_)

print(len(confs))
print(N_k)
return (snapshots, unitcell, nr_of_states, N_k)

@staticmethod
Expand Down Expand Up @@ -327,10 +338,16 @@ def _evaluate_traj_with_CHARMM(
def _evaluate_e_on_all_snapshots_CHARMM(
self, snapshots: mdtraj.Trajectory, lambda_state: int, env: str
):

if env == "waterbox":
unitcell_lengths = [
(snapshots.unitcell_lengths[ts][0],
snapshots.unitcell_lengths[ts][1],
snapshots.unitcell_lengths[ts][2])
for ts in range(len(snapshots))
]

volumn_list = [
self._get_V_for_ts(snapshots, env, ts)
self._get_V_for_ts(unitcell_lengths, env, ts)
for ts in range(snapshots.n_frames)
]

Expand Down Expand Up @@ -410,23 +427,23 @@ def _analyse_results_using_mbar(
u_kn = np.stack([r_i for r_i in r])

elif engine == "CHARMM":
del snapshots
confs = []
# write out traj in self.base_path
for (dcd, psf) in self.traj_files:
for (dcd, psf) in self.traj_files[env]:
traj = mdtraj.load(
f"{dcd}",
top=f"{psf}",
)
self._thinning_traj(traj)
# return and append thinned trajs
traj, _, _ = self._thinning(traj)
confs.append(traj)

joined_trajs = mdtraj.join(confs, check_topology=True)
joined_trajs.save_dcd(f"{self.base_path}/traj.dcd")
u_kn = np.stack(
[
self._evaluate_e_on_all_snapshots_CHARMM(
snapshots, lambda_state, env
joined_trajs, lambda_state, env
)
for lambda_state in range(1, self.nr_of_states + 1)
]
Expand Down
6 changes: 4 additions & 2 deletions transformato/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

logger = logging.getLogger(__name__)

platform = "CPU" # CPU or GPU
default_platform = "GPU" # CPU or GPU
temperature = 303.15 * unit.kelvin
kB = unit.BOLTZMANN_CONSTANT_kB * unit.AVOGADRO_CONSTANT_NA
kT = kB * temperature
Expand All @@ -23,7 +23,9 @@ def initialize_NUM_PROC(n_proc):
print(msg)


def change_platform(configuration: dict, change_to="CPU"):
def change_platform(configuration: dict):

change_to = default_platform

if change_to.upper() == "GPU":
configuration["simulation"]["GPU"] = True
Expand Down
11 changes: 8 additions & 3 deletions transformato/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,12 +943,17 @@ def _write_psf(psf, output_file_base: str, env: str):
"""
Writes the new psf and pdb file.
"""

with open(f"{output_file_base}/lig_in_{env}.psf", "w+") as f:
psf.write_psf(f)

string_object = StringIO()
psf.write_psf(string_object)
# read in psf and correct some aspects of the file not suitable for CHARMM
corrected_psf = psf_correction(string_object)
f = open(f"{output_file_base}/lig_in_{env}.psf", "w+")
f.write(corrected_psf)
f.close()
with open(f"{output_file_base}/lig_in_{env}_corr.psf", "w+") as f:
f.write(corrected_psf)
# write pdb
psf.write_pdb(f"{output_file_base}/lig_in_{env}.pdb")

def _init_intermediate_state_dir(self, nr: int):
Expand Down
6 changes: 0 additions & 6 deletions transformato/tests/test_mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1820,8 +1820,6 @@ def setup_2OJ9_tautomer_pair_rsfe(
from ..mutate import mutate_pure_tautomers
from ..constants import change_platform

change_platform(configuration)

s1 = SystemStructure(configuration, "structure1")
s2 = SystemStructure(configuration, "structure2")
s1_to_s2 = ProposeMutationRoute(s1, s2)
Expand All @@ -1846,8 +1844,6 @@ def setup_2OJ9_tautomer_pair_rbfe(
from ..mutate import mutate_pure_tautomers
from ..constants import change_platform

change_platform(configuration)

s1 = SystemStructure(configuration, "structure1")
s2 = SystemStructure(configuration, "structure2")
s1_to_s2 = ProposeMutationRoute(s1, s2)
Expand Down Expand Up @@ -1875,8 +1871,6 @@ def setup_acetylacetone_tautomer_pair(
conf = "transformato/tests/config/test-acetylacetone-tautomer-rsfe.yaml"
configuration = load_config_yaml(config=conf, input_dir="data/", output_dir=".")

change_platform(configuration)

s1 = SystemStructure(configuration, "structure1")
s2 = SystemStructure(configuration, "structure2")
s1_to_s2 = ProposeMutationRoute(s1, s2)
Expand Down
Loading

0 comments on commit a087b00

Please sign in to comment.