diff --git a/paratemp/coordinate_analysis.py b/paratemp/coordinate_analysis.py index 5fcb954..6f361b3 100644 --- a/paratemp/coordinate_analysis.py +++ b/paratemp/coordinate_analysis.py @@ -207,6 +207,8 @@ def calculate_distances(self, *args, recalculate=False, first_group = self.select_atoms('protein and not protein') second_group = self.select_atoms('protein and not protein') column_names = [] + groups_CoM = [] + column_names_CoM = [] if len(args) == 0 and len(kwargs) == 0: args = ['all'] if len(args) != 0: @@ -249,6 +251,14 @@ def calculate_distances(self, *args, recalculate=False, ' currently supported.\nAt your ' 'own risk you can try assigning ' 'to self._data[{}].'.format(key)) + except TypeError: + selections = [] + for g in atoms: + sel_string = 'bynum ' + ' '.join([str(i) for i in g]) + selections.append(self.select_atoms(sel_string)) + groups_CoM.append(selections) + column_names_CoM.append(key) + continue first_group += self.select_atoms('bynum '+str(atoms[0])) second_group += self.select_atoms('bynum '+str(atoms[1])) column_names += [key] @@ -267,6 +277,14 @@ def calculate_distances(self, *args, recalculate=False, nc) + '\nThis should not happen.\nPossibly invalid ' 'atom selection.') + n_groups = len(groups_CoM) + n_group_names = len(column_names_CoM) + if not n_groups == n_group_names: + raise SyntaxError('Different numbers of atom groups or number' + ' of column labels for CoM calculations' + '({} and {}, respectively).\n' + 'This should not happen.'.format(n_groups, + n_group_names)) if self._num_frames != self.trajectory.n_frames: if self._verbosity: print('Current trajectory has {} frames, '.format( @@ -275,14 +293,21 @@ def calculate_distances(self, *args, recalculate=False, '{} frames.'.format(self._num_frames)) if not ignore_file_change: raise FileChangedError() - dists = np.zeros((self._num_frames, n1)) + dists = np.zeros((self._num_frames, n1 + n_groups)) + positions_1 = np.zeros((n1+n_groups, 3), dtype=np.float32) + positions_2 = np.zeros((n1+n_groups, 3), dtype=np.float32) for i in range(self._num_frames): self.trajectory[i] - MDa.lib.distances.calc_bonds(first_group.positions, - second_group.positions, + positions_1[:n1] = first_group.positions + positions_2[:n1] = second_group.positions + for j, group in enumerate(groups_CoM): + positions_1[n1+j] = group[0].center_of_mass() + positions_2[n1+j] = group[1].center_of_mass() + MDa.lib.distances.calc_bonds(positions_1, + positions_2, box=self.dimensions, result=dists[i]) - for i, column in enumerate(column_names): + for i, column in enumerate(column_names + column_names_CoM): self._data[column] = dists[:, i] if save_data: self.save_data() diff --git a/tests/conftest.py b/tests/conftest.py index 92d3d53..db29c4d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,6 +33,12 @@ def ref_a_dists(): names=['a'], index_col=0) +@pytest.fixture +def ref_g_dists(): + import numpy + return numpy.load('tests/ref-data/spc2-g-dists.npy') + + @pytest.fixture def ref_delta_g(): return np.load('tests/ref-data/spc2-fes1d-delta-gs.npy') diff --git a/tests/ref-data/spc2-g-dists.npy b/tests/ref-data/spc2-g-dists.npy new file mode 100644 index 0000000..5306311 Binary files /dev/null and b/tests/ref-data/spc2-g-dists.npy differ diff --git a/tests/test_coordinate_analysis.py b/tests/test_coordinate_analysis.py index cd1faeb..21390e2 100644 --- a/tests/test_coordinate_analysis.py +++ b/tests/test_coordinate_analysis.py @@ -114,6 +114,12 @@ def test_distance_pbc(self, univ_pbc, ref_a_pbc_dists): read_data=False, save_data=False) assert np.isclose(ref_a_pbc_dists['a'], univ_pbc.data['a']).all() + def test_distances_com(self, univ, ref_g_dists): + univ.calculate_distances( + read_data=False, save_data=False, + g=((1, 2), (3, 4))) + assert np.isclose(ref_g_dists, univ.data).all() + def test_calculate_distance_raises(self, univ): with pytest.raises(SyntaxError): univ.calculate_distances(1, read_data=False, save_data=False) @@ -132,7 +138,6 @@ def test_calculate_distance_warns(self, univ): match='following positional arguments were given'): univ.calculate_distances('fail', read_data=False, save_data=False) - def test_fes_1d_data_str(self, univ_w_a, ref_delta_g, ref_bins): """ :type univ_w_a: paratemp.coordinate_analysis.Universe