Skip to content

Commit e1be1f7

Browse files
authored
Merge pull request #438 from s-mayani/fix_interpolate
Fix interpolation on GPU
2 parents ac4b79b + bc8136e commit e1be1f7

File tree

6 files changed

+274
-45
lines changed

6 files changed

+274
-45
lines changed

cmake/FailingTests.cmake

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,6 @@
1010

1111
if(BUILD_TESTING AND IPPL_SKIP_FAILING_TESTS)
1212
set(IPPL_DISABLED_TEST_LIST
13-
AssembleRHS
14-
InterpolateDiracs
1513
ParticleSendRecv
1614
ORB
1715
PIC

src/FEM/FEMInterpolate.hpp

Lines changed: 46 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -11,21 +11,17 @@ namespace ippl {
1111
* Assumes the input x is strictly inside the computational domain so that
1212
* for each dimension d: 0 ≤ (x[d]-origin[d])/h[d] < nr[d]-1.
1313
*/
14-
template <typename T, unsigned Dim, class Mesh>
14+
template <typename T, unsigned Dim>
1515
KOKKOS_INLINE_FUNCTION void
16-
locate_element_nd_and_xi(const Mesh& mesh,
17-
const ippl::Vector<T,Dim>& x,
18-
ippl::Vector<size_t,Dim>& e_nd,
19-
ippl::Vector<T,Dim>& xi) {
20-
21-
const auto nr = mesh.getGridsize(); // vertices per axis
22-
const auto h = mesh.getMeshSpacing();
23-
const auto org = mesh.getOrigin();
24-
16+
locate_element_nd_and_xi(const Vector<T, Dim>& hr,
17+
const Vector<T, Dim>& origin,
18+
const Vector<T, Dim>& x,
19+
Vector<size_t, Dim>& e_nd,
20+
Vector<T, Dim>& xi) {
2521

2622
for (unsigned d = 0; d < Dim; ++d) {
27-
const T s = (x[d] - org[d]) / h[d]; // To cell units
28-
const size_t e = static_cast<size_t>(std::floor(s));
23+
const T s = (x[d] - origin[d]) / hr[d]; // To cell units
24+
const size_t e = static_cast<size_t>(Kokkos::floor(s));
2925
e_nd[d] = e;
3026
xi[d] = s - static_cast<T>(e);
3127
}
@@ -79,35 +75,41 @@ namespace ippl {
7975
// Mesh / layout (for locating + indexing into the field view)
8076
const mesh_type& mesh = f.get_mesh();
8177

82-
const ippl::FieldLayout<Dim>& layout = f.getLayout();
83-
const ippl::NDIndex<Dim>& lDom = layout.getLocalNDIndex();
84-
const int nghost = f.getNghost();
78+
const auto hr = mesh.getMeshSpacing();
79+
const auto origin = mesh.getOrigin();
80+
81+
const FieldLayout<Dim>& layout = f.getLayout();
82+
const NDIndex<Dim>& lDom = layout.getLocalNDIndex();
83+
const int nghost = f.getNghost();
8584

8685
// Particle attribute/device views
8786
auto d_attr = attrib.getView(); // scalar weight per particle (e.g. charge)
8887
auto d_pos = pp.getView(); // positions (Vector<T,Dim>) per particle
8988

89+
// make device copy of space
90+
auto device_space = space.getDeviceMirror();
91+
9092
Kokkos::parallel_for("assemble_rhs_from_particles_P1", iteration_policy,
9193
KOKKOS_LAMBDA(const size_t p) {
92-
const Vector<T,Dim> x = d_pos(p);
94+
const Vector<T, Dim> x = d_pos(p);
9395
const T val = d_attr(p);
9496

95-
Vector<size_t,Dim> e_nd;
96-
Vector<T,Dim> xi;
97-
locate_element_nd_and_xi<T,Dim>(mesh, x, e_nd, xi);
98-
// Convert to the element's linear index
99-
const size_t e_lin = space.getElementIndex(e_nd);
97+
Vector<size_t, Dim> e_nd;
98+
Vector<T, Dim> xi;
99+
100+
locate_element_nd_and_xi<T, Dim>(hr, origin, x, e_nd, xi);
100101

101102
// DOFs for this element
102-
const auto dofs = space.getGlobalDOFIndices(e_lin);
103+
const auto dofs = device_space.getGlobalDOFIndices(e_nd);
103104

104105
// Deposit into each vertex/DOF
105106
for (size_t a = 0; a < dofs.dim; ++a) {
106-
const size_t local = space.getLocalDOFIndex(e_lin, dofs[a]);
107-
const T w = space.evaluateRefElementShapeFunction(local, xi);
107+
const size_t local = device_space.getLocalDOFIndex(e_nd, dofs[a]);
108+
const T w = device_space.evaluateRefElementShapeFunction(local, xi);
108109

109-
const auto v_nd = space.getMeshVertexNDIndex(dofs[a]); // ND coords (global, vertex-centered)
110-
ippl::Vector<size_t,Dim> I; // indices into view
110+
// ND coords (global, vertex-centered)
111+
const auto v_nd = device_space.getMeshVertexNDIndex(dofs[a]);
112+
Vector<size_t, Dim> I; // indices into view
111113

112114
for (unsigned d = 0; d < Dim; ++d) {
113115
I[d] = static_cast<size_t>(v_nd[d] - lDom.first()[d] + nghost);
@@ -167,38 +169,41 @@ namespace ippl {
167169
IpplTimings::getTimer("interpolate_field_to_particles(P1)");
168170
IpplTimings::startTimer(timer);
169171

170-
view_type view = coeffs.getView();
171-
const mesh_type& M = coeffs.get_mesh();
172+
view_type view = coeffs.getView();
173+
const mesh_type& mesh = coeffs.get_mesh();
172174

175+
const auto hr = mesh.getMeshSpacing();
176+
const auto origin = mesh.getOrigin();
173177

174-
const ippl::FieldLayout<Dim>& layout = coeffs.getLayout();
175-
const ippl::NDIndex<Dim>& lDom = layout.getLocalNDIndex();
176-
const int nghost = coeffs.getNghost();
178+
const FieldLayout<Dim>& layout = coeffs.getLayout();
179+
const NDIndex<Dim>& lDom = layout.getLocalNDIndex();
180+
const int nghost = coeffs.getNghost();
177181

178182
// Particle device views
179183
auto d_pos = pp.getView();
180184
auto d_out = attrib_out.getView();
181185

186+
// make device copy of space
187+
auto device_space = space.getDeviceMirror();
182188
Kokkos::parallel_for("interpolate_to_diracs_P1", iteration_policy,
183189
KOKKOS_LAMBDA(const size_t p) {
184190

185-
const Vector<T,Dim> x = d_pos(p);
191+
const Vector<T, Dim> x = d_pos(p);
186192

187-
ippl::Vector<size_t,Dim> e_nd;
188-
ippl::Vector<T,Dim> xi;
189-
locate_element_nd_and_xi<T,Dim>(M, x, e_nd, xi);
190-
const size_t e_lin = space.getElementIndex(e_nd);
193+
Vector<size_t, Dim> e_nd;
194+
Vector<T, Dim> xi;
195+
locate_element_nd_and_xi<T, Dim>(hr, origin, x, e_nd, xi);
191196

192-
const auto dofs = space.getGlobalDOFIndices(e_lin);
197+
const auto dofs = device_space.getGlobalDOFIndices(e_nd);
193198

194199
field_value_type up = field_value_type(0);
195200

196201
for (size_t a = 0; a < dofs.dim; ++a) {
197-
const size_t local = space.getLocalDOFIndex(e_lin, dofs[a]);
198-
const field_value_type w = space.evaluateRefElementShapeFunction(local, xi);
202+
const size_t local = device_space.getLocalDOFIndex(e_nd, dofs[a]);
203+
const field_value_type w = device_space.evaluateRefElementShapeFunction(local, xi);
199204

200-
const auto v_nd = space.getMeshVertexNDIndex(dofs[a]);
201-
ippl::Vector<size_t,Dim> I;
205+
const auto v_nd = device_space.getMeshVertexNDIndex(dofs[a]);
206+
Vector<size_t, Dim> I;
202207
for (unsigned d = 0; d < Dim; ++d) {
203208
I[d] = static_cast<size_t>(v_nd[d] - lDom.first()[d] + nghost);
204209
}

src/FEM/LagrangeSpace.h

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -267,6 +267,30 @@ namespace ippl {
267267
*/
268268
T computeAvg(const FieldLHS& u_h) const;
269269

270+
///////////////////////////////////////////////////////////////////////
271+
/// Device struct for copies //////////////////////////////////////////
272+
///////////////////////////////////////////////////////////////////////
273+
struct DeviceStruct {
274+
// members we need to copy for the following functions:
275+
// works since numElementDOFs in LagrangeSpace is static constexpr
276+
static constexpr unsigned numElementDOFs = LagrangeSpace::numElementDOFs;
277+
Vector<size_t, Dim> nr_m;
278+
ElementType ref_element_m;
279+
280+
// these are the functions needed for interpolation to the space
281+
KOKKOS_FUNCTION indices_t getMeshVertexNDIndex(const size_t& vertex_index) const;
282+
283+
KOKKOS_FUNCTION size_t getLocalDOFIndex(const indices_t& elementNDIndex,
284+
const size_t& globalDOFIndex) const;
285+
KOKKOS_FUNCTION Vector<size_t, numElementDOFs> getGlobalDOFIndices(
286+
const indices_t& elementNDIndex) const;
287+
288+
KOKKOS_FUNCTION T evaluateRefElementShapeFunction(const size_t& localDOF,
289+
const point_t& localPoint) const;
290+
};
291+
292+
DeviceStruct getDeviceMirror() const;
293+
270294
private:
271295
/**
272296
* @brief Check if a DOF is on the boundary of the mesh
@@ -285,6 +309,10 @@ namespace ippl {
285309
return false;
286310
}
287311

312+
///////////////////////////////////////////////////////////////////////
313+
/// Private member containing the element indices owned by ////////////
314+
/// my MPI rank. //////////////////////////////////////////////////////
315+
///////////////////////////////////////////////////////////////////////
288316
Kokkos::View<size_t*> elementIndices;
289317
};
290318

0 commit comments

Comments
 (0)