2
2
import gc
3
3
import logging
4
4
import multiprocessing as mp
5
- import os
5
+ import os , sys
6
6
import pickle
7
7
import subprocess
8
8
from collections import defaultdict
21
21
from tqdm import tqdm
22
22
23
23
from transformato .constants import temperature
24
- from transformato .utils import get_structure_name
24
+ from transformato .utils import get_structure_name , isnotebook
25
25
26
26
logger = logging .getLogger (__name__ )
27
27
@@ -76,7 +76,10 @@ def __init__(self, configuration: dict, structure_name: str):
76
76
# decide if the name of the system corresponds to structure1 or structure2
77
77
structure = get_structure_name (configuration , structure_name )
78
78
79
- if configuration ["simulation" ]["free-energy-type" ] == "rsfe" :
79
+ if (
80
+ configuration ["simulation" ]["free-energy-type" ] == "rsfe"
81
+ or configuration ["simulation" ]["free-energy-type" ] == "asfe"
82
+ ):
80
83
self .envs = ("vacuum" , "waterbox" )
81
84
self .mbar_results = {"waterbox" : None , "vacuum" : None }
82
85
@@ -96,15 +99,15 @@ def __init__(self, configuration: dict, structure_name: str):
96
99
self .save_results_to_path : str = f"{ self .configuration ['system_dir' ]} /results/"
97
100
self .traj_files = defaultdict (list )
98
101
99
- def load_trajs (self , nr_of_max_snapshots : int = 300 ):
102
+ def load_trajs (self , nr_of_max_snapshots : int = 300 , multiple_runs : int = 0 ):
100
103
"""
101
104
load trajectories, thin trajs and merge themn.
102
105
Also calculate N_k for mbar.
103
106
"""
104
107
105
108
assert type (nr_of_max_snapshots ) == int
106
109
self .nr_of_max_snapshots = nr_of_max_snapshots
107
- self .snapshots , self .unitcell , self .nr_of_states , self .N_k = self ._merge_trajs ()
110
+ self .snapshots , self .unitcell , self .nr_of_states , self .N_k = self ._merge_trajs (multiple_runs )
108
111
109
112
def _generate_openMM_system (self , env : str , lambda_state : int ) -> Simulation :
110
113
# read in necessary files
@@ -156,7 +159,7 @@ def _thinning(self, any_list):
156
159
further_thinning ,
157
160
)
158
161
159
- def _merge_trajs (self ) -> Tuple [dict , dict , int , dict ]:
162
+ def _merge_trajs (self , multiple_runs : int ) -> Tuple [dict , dict , int , dict ]:
160
163
"""
161
164
load trajectories, thin trajs and merge themn.
162
165
Also calculate N_k for mbar.
@@ -179,7 +182,11 @@ def _merge_trajs(self) -> Tuple[dict, dict, int, dict]:
179
182
unitcell_ = []
180
183
conf_sub = self .configuration ["system" ][self .structure ][env ]
181
184
for lambda_state in tqdm (range (1 , nr_of_states + 1 )):
182
- dcd_path = f"{ self .base_path } /intst{ lambda_state } /{ conf_sub ['intermediate-filename' ]} .dcd"
185
+ if multiple_runs :
186
+ dcd_path = f"{ self .base_path } /intst{ lambda_state } /run_{ multiple_runs } /{ conf_sub ['intermediate-filename' ]} .dcd"
187
+ else :
188
+ dcd_path = f"{ self .base_path } /intst{ lambda_state } /{ conf_sub ['intermediate-filename' ]} .dcd"
189
+
183
190
psf_path = f"{ self .base_path } /intst{ lambda_state } /{ conf_sub ['intermediate-filename' ]} .psf"
184
191
if not os .path .isfile (dcd_path ):
185
192
raise RuntimeError (f"{ dcd_path } does not exist." )
@@ -425,9 +432,8 @@ def _evaluate_e_with_openMM(
425
432
return energies
426
433
427
434
def energy_at_lambda (
428
- self , lambda_state : int , env : str , nr_of_max_snapshots : int , in_memory : bool
435
+ self , lambda_state : int , env : str , nr_of_max_snapshots : int , in_memory : bool , multiple_runs : int ,
429
436
) -> Tuple :
430
-
431
437
gc .enable ()
432
438
logger .info (f"Analysing lambda state { lambda_state } of { self .nr_of_states } " )
433
439
conf_sub = self .configuration ["system" ][self .structure ][env ]
@@ -440,7 +446,11 @@ def energy_at_lambda(
440
446
441
447
for lambda_state in range (1 , self .nr_of_states + 1 ):
442
448
443
- dcd_path = f"{ self .base_path } /intst{ lambda_state } /{ conf_sub ['intermediate-filename' ]} .dcd"
449
+ if not multiple_runs :
450
+ dcd_path = f"{ self .base_path } /intst{ lambda_state } /{ conf_sub ['intermediate-filename' ]} .dcd"
451
+ else :
452
+ dcd_path = f"{ self .base_path } /intst{ lambda_state } /run_{ multiple_runs } /{ conf_sub ['intermediate-filename' ]} .dcd"
453
+
444
454
if not os .path .isfile (dcd_path ):
445
455
raise RuntimeError (f"{ dcd_path } does not exist." )
446
456
@@ -497,10 +507,13 @@ def _analyse_results_using_mda(
497
507
engine : str ,
498
508
num_proc : int ,
499
509
nr_of_max_snapshots : int ,
510
+ multiple_runs : int ,
500
511
in_memory : bool = False ,
501
512
):
502
513
503
514
logger .info (f"Evaluating with { engine } , using { num_proc } CPUs" )
515
+ if not os .path .isdir (self .base_path ):
516
+ sys .exit (f"{ self .base_path } does not exist" )
504
517
self .nr_of_states = len (next (os .walk (f"{ self .base_path } " ))[1 ])
505
518
506
519
if engine == "openMM" :
@@ -520,6 +533,7 @@ def _analyse_results_using_mda(
520
533
repeat (env ),
521
534
repeat (nr_of_max_snapshots ),
522
535
repeat (in_memory ),
536
+ repeat (multiple_runs ),
523
537
),
524
538
)
525
539
)
@@ -535,12 +549,19 @@ def _analyse_results_using_mda(
535
549
)
536
550
537
551
if save_results :
538
- file = f"{ self .save_results_to_path } /mbar_data_for_{ self .structure_name } _in_{ env } .pickle"
552
+ if multiple_runs :
553
+ file = f"{ self .save_results_to_path } /mbar_data_for_{ self .structure_name } _in_{ env } _run_{ multiple_runs } .pickle"
554
+ else :
555
+ file = f"{ self .save_results_to_path } /mbar_data_for_{ self .structure_name } _in_{ env } .pickle"
556
+
539
557
logger .info (f"Saving results: { file } " )
540
558
results = {"u_kn" : u_kn , "N_k" : N_k }
541
559
pickle .dump (results , open (file , "wb+" ))
542
-
543
- return self .calculate_dG_using_mbar (u_kn , N_k , env )
560
+
561
+ return self .calculate_dG_using_mbar (u_kn , N_k , env )
562
+
563
+ else :
564
+ return self .calculate_dG_using_mbar (u_kn , N_k , env )
544
565
545
566
def _analyse_results_using_mdtraj (
546
567
self ,
@@ -611,6 +632,7 @@ def calculate_dG_to_common_core(
611
632
num_proc : int = 1 ,
612
633
in_memory : bool = False ,
613
634
nr_of_max_snapshots : int = - 1 ,
635
+ multiple_runs : int = 0 ,
614
636
):
615
637
"""
616
638
Calculate mbar results using either the python package mdtraj
@@ -636,13 +658,14 @@ def calculate_dG_to_common_core(
636
658
engine ,
637
659
num_proc ,
638
660
nr_of_max_snapshots ,
661
+ multiple_runs ,
639
662
)
640
663
elif analyze_traj_with == "mdtraj" :
641
664
self .mbar_results [env ] = self ._analyse_results_using_mdtraj (
642
665
env , self .snapshots [env ], self .unitcell [env ], save_results , engine
643
666
)
644
667
else :
645
- raise RuntimeError ("Either mda or mdtray " )
668
+ raise RuntimeError ("Either mda or mdtraj " )
646
669
647
670
def load_waterbox_results (self , file : str ):
648
671
self .mbar_results ["waterbox" ] = self ._load_mbar_results (file )
@@ -765,8 +788,8 @@ def plot_free_energy_overlap(self, env: str):
765
788
plt .savefig (
766
789
f"{ self .save_results_to_path } /ddG_to_common_core_overlap_{ env } _for_{ self .structure_name } .png"
767
790
)
768
-
769
- plt .show ()
791
+ if isnotebook ():
792
+ plt .show ()
770
793
plt .close ()
771
794
772
795
def plot_free_energy (self , env : str ):
@@ -805,31 +828,17 @@ def plot_free_energy(self, env: str):
805
828
plt .savefig (
806
829
f"{ self .save_results_to_path } /ddG_to_common_core_line_plot_{ env } _for_{ self .structure_name } .png"
807
830
)
808
- plt .show ()
831
+ if isnotebook ():
832
+ plt .show ()
809
833
plt .close ()
810
834
811
- def plot_vacuum_free_energy_overlap (self ):
812
- self .plot_free_energy_overlap ("vacuum" )
813
-
814
- def plot_complex_free_energy_overlap (self ):
815
- self .plot_free_energy_overlap ("complex" )
816
-
817
- def plot_waterbox_free_energy_overlap (self ):
818
- self .plot_free_energy_overlap ("waterbox" )
819
-
820
- def plot_vacuum_free_energy (self ):
821
- self .plot_free_energy ("vacuum" )
822
-
823
- def plot_complex_free_energy (self ):
824
- self .plot_free_energy ("complex" )
825
-
826
- def plot_waterbox_free_energy (self ):
827
- self .plot_free_energy ("waterbox" )
828
-
829
835
@property
830
836
def end_state_free_energy_difference (self ):
831
837
"""DeltaF[lambda=1 --> lambda=0]"""
832
- if self .configuration ["simulation" ]["free-energy-type" ] == "rsfe" :
838
+ if (
839
+ self .configuration ["simulation" ]["free-energy-type" ] == "rsfe"
840
+ or self .configuration ["simulation" ]["free-energy-type" ] == "asfe"
841
+ ):
833
842
return (
834
843
self .waterbox_free_energy_differences [0 , - 1 ]
835
844
- self .vacuum_free_energy_differences [0 , - 1 ],
@@ -848,26 +857,20 @@ def end_state_free_energy_difference(self):
848
857
raise RuntimeError ()
849
858
850
859
def show_summary (self ):
851
- from transformato .utils import isnotebook
852
-
853
- if self .configuration ["simulation" ]["free-energy-type" ] == "rsfe" :
854
- if isnotebook :
855
- # only show this if we are in a notebook
856
- self .plot_vacuum_free_energy_overlap ()
857
- self .plot_waterbox_free_energy_overlap ()
858
- self .plot_vacuum_free_energy ()
859
- self .plot_waterbox_free_energy ()
860
- self .detailed_overlap ("waterbox" )
861
- self .detailed_overlap ("vacuum" )
860
+
861
+ if (
862
+ self .configuration ["simulation" ]["free-energy-type" ] == "rsfe"
863
+ or self .configuration ["simulation" ]["free-energy-type" ] == "asfe"
864
+ ):
865
+ self .plot_free_energy_overlap ("vacuum" )
866
+ self .plot_free_energy_overlap ("waterbox" )
867
+ self .plot_free_energy ("vacuum" )
868
+ self .plot_free_energy ("waterbox" )
862
869
else :
863
- if isnotebook :
864
- # only show this if we are in a notebook
865
- self .plot_complex_free_energy_overlap ()
866
- self .plot_waterbox_free_energy_overlap ()
867
- self .plot_complex_free_energy ()
868
- self .plot_waterbox_free_energy ()
869
- self .detailed_overlap ("complex" )
870
- self .detailed_overlap ("waterbox" )
870
+ self .plot_free_energy_overlap ("waterbox" )
871
+ self .plot_free_energy_overlap ("complex" )
872
+ self .plot_free_energy ("waterbox" )
873
+ self .plot_free_energy ("complex" )
871
874
872
875
energy_estimate , uncertainty = self .end_state_free_energy_difference
873
876
print (
0 commit comments