Skip to content

Commit

Permalink
changed file structure
Browse files Browse the repository at this point in the history
add custom permutation set
  • Loading branch information
abelcarreras committed Mar 31, 2024
1 parent b7da6d2 commit 4157b44
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 105 deletions.
37 changes: 19 additions & 18 deletions posym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@
from posym.basis import BasisFunction
from posym.config import Configuration
from posym.tools import uniform_euler_scan, collapse_limit
from posym.permutation import generate_permutation_set
from posym.permutation.hungarian import get_permutation_hungarian
from scipy.spatial.transform import Rotation as R
from scipy.optimize import minimize
import numpy as np
import pandas as pd
import warnings


cache_orientation = {}
Expand Down Expand Up @@ -126,7 +129,7 @@ class SymmetryMolecule(SymmetryObject):
"""
Symmetry of molecular geometry
"""
def __init__(self, group, coordinates, symbols, total_state=None, orientation_angles=None, center=None):
def __init__(self, group, coordinates, symbols, total_state=None, orientation_angles=None, center=None, permutation_set=None):
"""
:param group: symmetry point group
Expand All @@ -137,7 +140,7 @@ def __init__(self, group, coordinates, symbols, total_state=None, orientation_an
:param center: center of symmetry group [x, y, z]
"""

self._setup_structure(coordinates, symbols, group, center, orientation_angles)
self._setup_structure(coordinates, symbols, group, center, orientation_angles, permutation_set=permutation_set)

if total_state is None:
# m = self.measure_pos
Expand All @@ -162,12 +165,11 @@ def __init__(self, group, coordinates, symbols, total_state=None, orientation_an
total_state = pd.Series(self._operator_measures, index=self._pg.op_labels)

if not self.check_permutation_coherence and not collapse_limit(self.symmetrized_coordinates):
import warnings
warnings.warn('Incoherence found in symmetrized structure. Symmetry measure may be incorrect')

super().__init__(group, total_state)

def _setup_structure(self, coordinates, symbols, group, center, orientation_angles):
def _setup_structure(self, coordinates, symbols, group, center, orientation_angles, permutation_set=None):

conf = Configuration()

Expand All @@ -188,6 +190,11 @@ def _setup_structure(self, coordinates, symbols, group, center, orientation_angl
else:
self._angles = orientation_angles

# manual permutation
if permutation_set is not None:
self._permutation_set[tuple(orientation_angles)] = {gen: perm for gen, perm in
zip(self._pg.generators,permutation_set)}

self._generate_permutation_set(self._angles)

def get_orientation(self, fast_optimization=True, scan_step=20, guess_angles=None):
Expand Down Expand Up @@ -275,7 +282,7 @@ def get_oriented_operations(self):
return operations_list

def print_operations_info(self):
from posym.operations.permutation import Permutation
from posym.permutation import Permutation

print('\nOperations list (molecule orientation)'
'\n--------------------------------------')
Expand Down Expand Up @@ -303,14 +310,6 @@ def print_operations_info(self):

def _generate_permutation_set(self, angles, force_reset=False):

from posym.operations.permutation import generate_permutation_set
from posym.operations import get_permutation_aprox

conf = Configuration()
use_approx = True
if conf.algorithm == 'exact':
use_approx = False

rotmol = R.from_euler('zyx', angles, degrees=True)
dict_key = tuple(angles)

Expand All @@ -320,16 +319,16 @@ def _generate_permutation_set(self, angles, force_reset=False):
self._permutation_set[dict_key] = next(generate_permutation_set(self._pg.generators, self._symbols))
return

# approximations
if use_approx:
# Hungarian algorithm (approximated)
if Configuration().algorithm == 'hungarian':
permutation_set = {}
for gen in self._pg.generators:
rot_coor = rotmol.inv().apply(self._coordinates)
permutation_set[gen] = get_permutation_aprox(gen.matrix_representation, rot_coor, self._symbols, gen._order)
permutation_set[gen] = get_permutation_hungarian(gen.matrix_representation, rot_coor, self._symbols)
self._permutation_set[dict_key] = permutation_set

else:
# exact
# Brute force algorithm (exact)
elif Configuration().algorithm == 'exact':
ir_rep_diff_max = -100

class NotValidPermutation(Exception): pass
Expand Down Expand Up @@ -357,6 +356,8 @@ class NotValidPermutation(Exception): pass

except NotValidPermutation:
continue
else:
raise Exception('Permutation algorithm not recognized ')

for operation in self._pg.operations:
operation.set_permutation_set(self._permutation_set[dict_key], self._symbols, ignore_compatibility=True)
Expand Down
2 changes: 1 addition & 1 deletion posym/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from copy import deepcopy
import math
import itertools
from posym.integrals import product_poly_coeff, gaussian_integral, gaussian_integral_2
from posym.integrals import product_poly_coeff, gaussian_integral, gaussian_integral_2 # noqa
from scipy.special import comb


Expand Down
84 changes: 1 addition & 83 deletions posym/operations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,5 @@
import numpy as np
from functools import lru_cache
# from posym.permutations import get_cross_distance_table
from posym.permutations import get_permutation_annealing, get_permutation_brute, fix_permutation, validate_permutation # noqa
from scipy.optimize import linear_sum_assignment
from posym.config import Configuration, CustomPerm
from posym.operations.permutation import Permutation


@lru_cache(maxsize=100)
def get_submatrix_indices(symbols):
# separate distance_table in submatrices corresponding to a single symbol
submatrices_indices = []
for s in np.unique(symbols):
submatrices_indices.append([j for j, s2 in enumerate(symbols) if s2 == s])

return submatrices_indices


def get_permutation_labels(distance_table, symbols, permutation_function):
"""
This function restricts permutations by the use of atom labels
returns the permutation vector that minimizes its trace using custom algorithms.
"""
submatrices_indices = get_submatrix_indices(symbols)

# determine the permutation for each submatrix
perm_submatrices = []
for index in submatrices_indices:
submatrix = np.array(distance_table)[index, :][:, index]
perm_sub = permutation_function(submatrix)
perm_submatrices.append(perm_sub)

# restore global permutation by joining permutations of submatrices
global_permutation = np.zeros(len(distance_table), dtype=int)
for index, perm in zip(submatrices_indices, perm_submatrices):
index = np.array(index)
global_permutation[index] = index[perm]

return np.array(global_permutation)


def cache_permutation(func):
cache_dict = {}

def wrapper_cache(operation, coordinates, symbols, order):
hash_key = (np.array2string(operation), np.array2string(coordinates), tuple(symbols))
if hash_key in cache_dict:
return cache_dict[hash_key]

cache_dict[hash_key] = func(operation, coordinates, symbols, order)
return cache_dict[hash_key]

return wrapper_cache


@cache_permutation
def get_permutation_aprox(operation, coordinates, symbols, order):

operated_coor = np.dot(operation, coordinates.T).T
symbols = tuple(int.from_bytes(num.encode(), 'big') for num in symbols)

dot_table = -np.dot(coordinates, operated_coor.T)
# dot_table = get_cross_distance_table(coordinates, operated_coor)

# permutation algorithms functions
def hungarian_algorithm(sub_matrix):
row_ind, col_ind = linear_sum_assignment(sub_matrix)
perm = np.zeros_like(row_ind)
perm[row_ind] = col_ind
return perm

def annealing_algorithm(dot_matrix):
return get_permutation_annealing(dot_matrix, order, 1)

def brute_force_algorithm(dot_matrix):
return get_permutation_brute(dot_matrix, order, 1)

# algorithms list
algorithm_dict = {'hungarian': hungarian_algorithm,
'annealing': annealing_algorithm,
'brute_force': brute_force_algorithm}

return get_permutation_labels(dot_table, symbols, algorithm_dict[Configuration().algorithm])
from posym.permutation.permutations import validate_permutation # noqa


class Operation:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@ def get_orbits(self):
if not p in track_pos:
track_pos.append(p)
orbit = [p]
# print('perm', p)
while orbit[0] != self._permutation[p]:
p = self._permutation[p]
# print('p', p)
track_pos.append(p)
orbit.append(p)
self._orbits.append(orbit)
Expand Down
70 changes: 70 additions & 0 deletions posym/permutation/hungarian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from functools import lru_cache
# from posym.permutations import get_cross_distance_table
from scipy.optimize import linear_sum_assignment
import numpy as np


@lru_cache(maxsize=100)
def get_submatrix_indices(symbols):
# separate distance_table in submatrices corresponding to a single symbol
submatrices_indices = []
for s in np.unique(symbols):
submatrices_indices.append([j for j, s2 in enumerate(symbols) if s2 == s])

return submatrices_indices


def get_permutation_labels(distance_table, symbols, permutation_function):
"""
This function restricts permutations by the use of atom labels
returns the permutation vector that minimizes its trace using custom algorithms.
"""
submatrices_indices = get_submatrix_indices(symbols)

# determine the permutation for each submatrix
perm_submatrices = []
for index in submatrices_indices:
submatrix = np.array(distance_table)[index, :][:, index]
perm_sub = permutation_function(submatrix)
perm_submatrices.append(perm_sub)

# restore global permutation by joining permutations of submatrices
global_permutation = np.zeros(len(distance_table), dtype=int)
for index, perm in zip(submatrices_indices, perm_submatrices):
index = np.array(index)
global_permutation[index] = index[perm]

return np.array(global_permutation)


def cache_permutation(func):
cache_dict = {}

def wrapper_cache(operation, coordinates, symbols):
hash_key = (np.array2string(operation), np.array2string(coordinates), tuple(symbols))
if hash_key in cache_dict:
return cache_dict[hash_key]

cache_dict[hash_key] = func(operation, coordinates, symbols)
return cache_dict[hash_key]

return wrapper_cache


@cache_permutation
def get_permutation_hungarian(operation, coordinates, symbols):

operated_coor = np.dot(operation, coordinates.T).T
symbols = tuple(int.from_bytes(num.encode(), 'big') for num in symbols)

dot_table = -np.dot(coordinates, operated_coor.T)
# dot_table = get_cross_distance_table(coordinates, operated_coor)

# permutation algorithms functions
def hungarian_algorithm(sub_matrix):
row_ind, col_ind = linear_sum_assignment(sub_matrix)
perm = np.zeros_like(row_ind)
perm[row_ind] = col_ind
return perm

return get_permutation_labels(dot_table, symbols, hungarian_algorithm)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def check_compiler():
include_dirs=include_dirs_numpy,
sources=['c/integrals.c'])

permutations = Extension('posym.permutations',
permutations = Extension('posym.permutation.permutations',
extra_compile_args=['-std=c99'],
include_dirs=include_dirs_numpy,
sources=['c/permutations.c'])
Expand Down

0 comments on commit 4157b44

Please sign in to comment.