Skip to content

Commit de6bd89

Browse files
committed
Merge branch 'feat/split_pr_test_1' into feat/sparsified_soap
2 parents d8e15eb + b8dc829 commit de6bd89

31 files changed

+3681
-674
lines changed

bindings/bind_include.hh

+57-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#define BINDINGS_BIND_INCLUDE_HH_
3030

3131
#include "rascal/structure_managers/atomic_structure.hh"
32+
#include "rascal/utils/json_io.hh"
3233
#include "rascal/utils/utils.hh"
3334

3435
#include <pybind11/eigen.h>
@@ -54,9 +55,64 @@ PYBIND11_MAKE_OPAQUE(std::vector<rascal::AtomicStructure<3>>);
5455

5556
namespace py = pybind11;
5657

57-
namespace rascal {
58+
/**
59+
* Simplistic but robust implicit conversion of py::dict to/from nlohmann::json.
60+
* py::dict to nlohmann::json: json j = py::dict("one"_a=1, "b"_a="fgssdf");
61+
* nlohmann::json to py::dict:
62+
* py::dict d = j.get<py::dict>();
63+
*/
64+
namespace nlohmann {
65+
template <>
66+
struct adl_serializer<py::dict> {
67+
static void to_json(json & j, const py::dict & dic) {
68+
py::module py_json = py::module::import("json");
69+
j = json::parse(
70+
static_cast<std::string>(py::str(py_json.attr("dumps")(dic))));
71+
}
72+
static void from_json(const json & j, py::dict & dic) {
73+
py::module py_json = py::module::import("json");
74+
dic = py_json.attr("loads")(j.dump());
75+
}
76+
};
77+
} // namespace nlohmann
5878

79+
namespace rascal {
5980
namespace internal {
81+
/**
82+
* Expose to python the serialization of rascal objects as a python
83+
* dictionary.
84+
*
85+
* @tparam Object is expected to be nlohmann::json (de)serializable
86+
*
87+
* A copy and a json (de)serialization are necessary to make sure that if
88+
* the resulting dictionary is written in json, then it will be directly
89+
* convertible to the original object in C++ and vice-versa.
90+
*/
91+
template <class Object, class... Bases>
92+
void bind_dict_representation(py::class_<Object, Bases...> & obj) {
93+
// serialization to a python dictionary
94+
obj.def("to_dict", [](const Object & self) {
95+
json j;
96+
j = self; // implicit conversion to nlohmann::json
97+
return j.template get<py::dict>();
98+
});
99+
// construction from a python dictionary
100+
obj.def_static("from_dict", [](const py::dict & d) {
101+
json j;
102+
j = d; // implicit conversion to nlohmann::json
103+
return std::make_unique<Object>(j.template get<Object>());
104+
});
105+
// string representation
106+
obj.def("__str__", [](const Object & self) {
107+
json j = self; // implicit conversion to nlohmann::json
108+
std::string str = j.dump(2);
109+
std::string representation_name{internal::type_name<Object>()};
110+
std::string sep{" | Parameters: "};
111+
std::string prefix{"Class: "};
112+
return prefix + representation_name + sep + str;
113+
});
114+
}
115+
60116
/**
61117
* Transforms the template type to a string for the python bindings.
62118
* There are submodules in the python bindings with the class

bindings/bind_py_models.cc

+5-2
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ namespace rascal {
3535
py::module & /*m_internal*/) {
3636
std::string kernel_name = internal::GetBindingTypeName<Kernel>();
3737
py::class_<Kernel> kernel(mod, kernel_name.c_str());
38-
kernel.def(py::init([](std::string & hyper_str) {
38+
kernel.def(py::init([](const py::dict & hyper) {
3939
// convert to json
40-
json hypers = json::parse(hyper_str);
40+
json hypers = hyper;
4141
return std::make_unique<Kernel>(hypers);
4242
}));
4343

@@ -158,17 +158,20 @@ namespace rascal {
158158

159159
// Bind the interface of this kernel
160160
auto kernel = add_kernel<Kernel>(mod, m_internal);
161+
internal::bind_dict_representation(kernel);
161162
bind_kernel_compute_function<internal::KernelType::Cosine, Calc1_t,
162163
ManagerCollection_1_t>(kernel);
163164
bind_kernel_compute_function<internal::KernelType::Cosine, Calc1_t,
164165
ManagerCollection_2_t>(kernel);
165166

166167
// bind the sparse kernel and pseudo points class
167168
auto sparse_kernel = add_kernel<SparseKernel>(mod, m_internal);
169+
internal::bind_dict_representation(sparse_kernel);
168170
bind_sparse_kernel_compute_function<internal::SparseKernelType::GAP,
169171
Calc1_t, ManagerCollection_2_t,
170172
SparsePoints_1_t>(sparse_kernel);
171173
auto sparse_points = add_sparse_points<SparsePoints_1_t>(mod, m_internal);
172174
bind_sparse_points_push_back<ManagerCollection_2_t, Calc1_t>(sparse_points);
175+
internal::bind_dict_representation(sparse_points);
173176
}
174177
} // namespace rascal

bindings/bind_py_representation_calculator.cc

+3-6
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,12 @@ namespace rascal {
4040

4141
py::class_<Calculator, CalculatorBase> representation(
4242
mod, representation_name.c_str());
43-
// use custom constructor to pass json formatted string as initializer
44-
// an alternative would be to convert python dict to json internally
45-
// but needs some work on the pybind machinery
46-
representation.def(py::init([](std::string & hyper_str) {
43+
representation.def(py::init([](const py::dict & hyper) {
4744
// convert to json
48-
json hypers = json::parse(hyper_str);
45+
json hypers = hyper;
4946
return std::make_unique<Calculator>(hypers);
5047
}));
51-
48+
internal::bind_dict_representation(representation);
5249
return representation;
5350
}
5451

bindings/rascal/models/IP_ase_interface.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from ..utils import BaseIO
12
from ase.calculators.calculator import Calculator, all_changes
23
from copy import deepcopy
34

45

5-
class ASEMLCalculator(Calculator):
6+
class ASEMLCalculator(Calculator, BaseIO):
67
"""Wrapper class to use a rascal model as an interatomic potential in ASE
78
89
Parameters
@@ -23,10 +24,10 @@ class ASEMLCalculator(Calculator):
2324
nolabel = True
2425

2526
def __init__(self, model, representation, **kwargs):
26-
Calculator.__init__(self, **kwargs)
27-
27+
super(ASEMLCalculator, self).__init__(**kwargs)
2828
self.model = model
2929
self.representation = representation
30+
self.kwargs = kwargs
3031

3132
def calculate(self, atoms=None, properties=['energy'],
3233
system_changes=all_changes):
@@ -41,3 +42,14 @@ def calculate(self, atoms=None, properties=['energy'],
4142
self.results['energy'] = energy
4243
self.results['free_energy'] = energy
4344
self.results['forces'] = forces
45+
46+
def get_init_params(self):
47+
init_params = dict(model=self.model, representation=self.representation)
48+
init_params.update(**self.kwargs)
49+
return init_params
50+
51+
def _set_data(self, data):
52+
pass
53+
54+
def _get_data(self):
55+
return dict()

bindings/rascal/models/kernels.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
from ..lib._rascal.models.kernels import SparseKernel as SparseKernelcpp
33
from ..neighbourlist import AtomsList
44
from .sparse_points import SparsePoints
5+
from ..utils import BaseIO
56
import json
67

78

8-
class Kernel(object):
9+
class Kernel(BaseIO):
910

1011
"""
1112
Computes the kernel for a given representation. In the following
@@ -74,7 +75,7 @@ class Kernel(object):
7475

7576
def __init__(self, representation, name='Cosine', kernel_type='Full', target_type='Structure',
7677
**kwargs):
77-
78+
super(Kernel, self).__init__()
7879
# This case cannot be handled by the c++ side because c++ cannot deduce the
7980
# type from arguments inside a json, so it has to be cast in the c++
8081
# side. Therefore zeta has to be checked here.
@@ -94,17 +95,30 @@ def __init__(self, representation, name='Cosine', kernel_type='Full', target_typ
9495
raise RuntimeError("Kernel name must be one of: Cosine, GAP.")
9596
hypers = dict(name=name, target_type=target_type)
9697
hypers.update(**kwargs)
97-
hypers_str = json.dumps(hypers)
9898
self._rep = representation
9999
self._representation = representation._representation
100100
self.name = name
101101
self._kwargs = kwargs
102102
self.kernel_type = kernel_type
103103
self.target_type = target_type
104104
if 'Sparse' in kernel_type:
105-
self._kernel = SparseKernelcpp(hypers_str)
105+
self._kernel = SparseKernelcpp(hypers)
106106
else:
107-
self._kernel = Kernelcpp(hypers_str)
107+
self._kernel = Kernelcpp(hypers)
108+
109+
def get_init_params(self):
110+
init_params = dict(representation=self._rep,
111+
name=self.name,
112+
kernel_type=self.kernel_type,
113+
target_type=self.target_type)
114+
init_params.update(**self._kwargs)
115+
return init_params
116+
117+
def _set_data(self, data):
118+
pass
119+
120+
def _get_data(self):
121+
return dict()
108122

109123
def __call__(self, X, Y=None, grad=(False, False)):
110124
if isinstance(X, AtomsList):

bindings/rascal/models/krr.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from ..utils import BaseIO
2+
13
import numpy as np
24

35

4-
class KRR(object):
6+
class KRR(BaseIO):
57
"""Kernel Ridge Regression model. Only supports sparse GPR
68
training for the moment.
79
@@ -82,6 +84,19 @@ def predict(self, managers, compute_gradients=False):
8284
def get_weights(self):
8385
return self.weights
8486

87+
def get_init_params(self):
88+
init_params = dict(weights=self.weights, kernel=self.kernel,
89+
X_train=self.X_train, self_contributions=self.self_contributions)
90+
return init_params
91+
92+
def _set_data(self, data):
93+
pass
94+
95+
def _get_data(self):
96+
return dict()
97+
98+
def get_representation_calculator(self):
99+
return self.kernel._rep
85100

86101
def train_gap_model(kernel, managers, KNM_, X_pseudo, y_train, self_contributions, grad_train=None, lambdas=None, jitter=1e-8):
87102
"""

bindings/rascal/models/sparse_points.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
from ..lib._rascal.models import kernels
22
from ..neighbourlist import AtomsList
3+
from ..utils import BaseIO
34

45
# names of existing pseudo points implementation on the pybinding side.
5-
_pseudo_points = {}
6+
_sparse_points = {}
67
for k, v in kernels.__dict__.items():
78
if "SparsePoints" in k:
89
name = k
9-
_pseudo_points[name] = v
10+
_sparse_points[name] = v
1011

1112

12-
class SparsePoints(object):
13+
class SparsePoints(BaseIO):
1314
"""
1415
Holds features to be used as references / sparse points / pseudo points
1516
in sparse GPR methods.
@@ -44,24 +45,36 @@ class SparsePoints(object):
4445
"""
4546

4647
def __init__(self, representation):
48+
super(SparsePoints, self).__init__()
4749
self.representation = representation
4850
if 'SphericalInvariants' in str(representation):
49-
self._pseudo_points = _pseudo_points['SparsePointsBlockSparse_SphericalInvariants'](
51+
self._sparse_points = _sparse_points['SparsePointsBlockSparse_SphericalInvariants'](
5052
)
5153
else:
5254
raise ValueError(
5355
'No pseudo point is appropiate for ' + str(representation))
5456

57+
def get_init_params(self):
58+
init_params = dict(representation=self.representation)
59+
return init_params
60+
61+
def _set_data(self, data):
62+
self._sparse_points = self._sparse_points.from_dict(
63+
data['sparse_points'])
64+
65+
def _get_data(self):
66+
return dict(sparse_points=self._sparse_points.to_dict())
67+
5568
def extend(self, atoms_list, selected_indices):
5669
if isinstance(atoms_list, AtomsList):
57-
self._pseudo_points.extend(
70+
self._sparse_points.extend(
5871
self.representation._representation, atoms_list.managers, selected_indices)
5972
else:
60-
self._pseudo_points.extend(
73+
self._sparse_points.extend(
6174
self.representation._representation, atoms_list, selected_indices)
6275

6376
def size(self):
64-
return self._pseudo_points.size()
77+
return self._sparse_points.size()
6578

6679
def get_features(self):
67-
return self._pseudo_points.get_features()
80+
return self._sparse_points.get_features()

bindings/rascal/representations/coulomb_matrix.py

+21-7
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
from .base import CalculatorFactory
66
from ..utils import FactoryPool
77
from itertools import starmap
8+
from ..utils import BaseIO
89

910

10-
class SortedCoulombMatrix(object):
11+
class SortedCoulombMatrix(BaseIO):
1112
"""
1213
Computes the Sorted Coulomb matrix representation [1].
1314
@@ -38,8 +39,7 @@ class SortedCoulombMatrix(object):
3839
Physical Review Letters, 108(5), 58301. https://doi.org/10.1103/PhysRevLett.108.058301
3940
"""
4041

41-
def __init__(self, cutoff, sorting_algorithm='row_norm', size=10, central_decay=-1, interaction_cutoff=10, interaction_decay=-1,
42-
method='thread', n_workers=1, disable_pbar=False):
42+
def __init__(self, cutoff, sorting_algorithm='row_norm', size=10, central_decay=-1, interaction_cutoff=10, interaction_decay=-1):
4343
self.name = 'sortedcoulomb'
4444
self.size = size
4545
self.hypers = dict()
@@ -57,8 +57,6 @@ def __init__(self, cutoff, sorting_algorithm='row_norm', size=10, central_decay=
5757
dict(name='neighbourlist', args=dict(cutoff=cutoff)),
5858
dict(name='strict', args=dict(cutoff=cutoff))
5959
]
60-
self.misc = dict(method=method, n_workers=n_workers,
61-
disable_pbar=disable_pbar)
6260

6361
def update_hyperparameters(self, **hypers):
6462
"""Store the given dict of hyperparameters
@@ -91,8 +89,7 @@ def transform(self, frames):
9189

9290
self.size = self.get_size(frames.managers)
9391
self.update_hyperparameters(size=self.size)
94-
hypers_str = json.dumps(self.hypers)
95-
self.rep_options = dict(name=self.name, args=[hypers_str])
92+
self.rep_options = dict(name=self.name, args=[self.hypers])
9693
self._representation = CalculatorFactory(self.rep_options)
9794

9895
self._representation.compute(frames.managers)
@@ -109,3 +106,20 @@ def get_size(self, managers):
109106
Nneigh.append(center.nb_pairs + 1)
110107
size = int(np.max(Nneigh))
111108
return size
109+
110+
def get_init_params(self):
111+
init_params = dict(
112+
cutoff=self.hypers['central_cutoff'],
113+
sorting_algorithm=self.hypers['sorting_algorithm'],
114+
size=self.hypers['size'],
115+
central_decay=self.hypers['central_decay'],
116+
interaction_cutoff=self.hypers['interaction_cutoff'],
117+
interaction_decay=self.hypers['interaction_decay']
118+
)
119+
return init_params
120+
121+
def _set_data(self, data):
122+
pass
123+
124+
def _get_data(self):
125+
return dict()

0 commit comments

Comments
 (0)