5
5
from .base import CalculatorFactory
6
6
from ..utils import FactoryPool
7
7
from itertools import starmap
8
+ from ..utils import BaseIO
8
9
9
10
10
- class SortedCoulombMatrix (object ):
11
+ class SortedCoulombMatrix (BaseIO ):
11
12
"""
12
13
Computes the Sorted Coulomb matrix representation [1].
13
14
@@ -38,8 +39,7 @@ class SortedCoulombMatrix(object):
38
39
Physical Review Letters, 108(5), 58301. https://doi.org/10.1103/PhysRevLett.108.058301
39
40
"""
40
41
41
- def __init__ (self , cutoff , sorting_algorithm = 'row_norm' , size = 10 , central_decay = - 1 , interaction_cutoff = 10 , interaction_decay = - 1 ,
42
- method = 'thread' , n_workers = 1 , disable_pbar = False ):
42
+ def __init__ (self , cutoff , sorting_algorithm = 'row_norm' , size = 10 , central_decay = - 1 , interaction_cutoff = 10 , interaction_decay = - 1 ):
43
43
self .name = 'sortedcoulomb'
44
44
self .size = size
45
45
self .hypers = dict ()
@@ -57,8 +57,6 @@ def __init__(self, cutoff, sorting_algorithm='row_norm', size=10, central_decay=
57
57
dict (name = 'neighbourlist' , args = dict (cutoff = cutoff )),
58
58
dict (name = 'strict' , args = dict (cutoff = cutoff ))
59
59
]
60
- self .misc = dict (method = method , n_workers = n_workers ,
61
- disable_pbar = disable_pbar )
62
60
63
61
def update_hyperparameters (self , ** hypers ):
64
62
"""Store the given dict of hyperparameters
@@ -91,8 +89,7 @@ def transform(self, frames):
91
89
92
90
self .size = self .get_size (frames .managers )
93
91
self .update_hyperparameters (size = self .size )
94
- hypers_str = json .dumps (self .hypers )
95
- self .rep_options = dict (name = self .name , args = [hypers_str ])
92
+ self .rep_options = dict (name = self .name , args = [self .hypers ])
96
93
self ._representation = CalculatorFactory (self .rep_options )
97
94
98
95
self ._representation .compute (frames .managers )
@@ -109,3 +106,20 @@ def get_size(self, managers):
109
106
Nneigh .append (center .nb_pairs + 1 )
110
107
size = int (np .max (Nneigh ))
111
108
return size
109
+
110
+ def get_init_params (self ):
111
+ init_params = dict (
112
+ cutoff = self .hypers ['central_cutoff' ],
113
+ sorting_algorithm = self .hypers ['sorting_algorithm' ],
114
+ size = self .hypers ['size' ],
115
+ central_decay = self .hypers ['central_decay' ],
116
+ interaction_cutoff = self .hypers ['interaction_cutoff' ],
117
+ interaction_decay = self .hypers ['interaction_decay' ]
118
+ )
119
+ return init_params
120
+
121
+ def _set_data (self , data ):
122
+ pass
123
+
124
+ def _get_data (self ):
125
+ return dict ()
0 commit comments