Skip to content

Commit

Permalink
Add distributed_mdarray extent method (#568)
Browse files Browse the repository at this point in the history
* add distributed_mdarray.extent method

* wave equation: use mdarray.extent method
  • Loading branch information
tkarna authored Oct 2, 2023
1 parent e5c3a08 commit 0ecaef9
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 36 deletions.
48 changes: 12 additions & 36 deletions benchmarks/gbench/mhp/wave_equation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ void rhs(Array &u, Array &v, Array &e, Array &dudt, Array &dvdt, Array &dedt,
};
{
std::array<std::size_t, 2> start{1, 0};
std::array<std::size_t, 2> end{
static_cast<std::size_t>(e.mdspan().extent(0) - 1),
static_cast<std::size_t>(e.mdspan().extent(1))};
std::array<std::size_t, 2> end{e.extent(0) - 1, e.extent(1)};
auto e_view = dr::mhp::views::submdspan(e.view(), start, end);
auto dudt_view = dr::mhp::views::submdspan(dudt.view(), start, end);
dr::mhp::stencil_for_each(rhs_dedx, e_view, dudt_view);
Expand All @@ -95,9 +93,7 @@ void rhs(Array &u, Array &v, Array &e, Array &dudt, Array &dvdt, Array &dedt,
};
{
std::array<std::size_t, 2> start{0, 1};
std::array<std::size_t, 2> end{
static_cast<std::size_t>(e.mdspan().extent(0)),
static_cast<std::size_t>(e.mdspan().extent(1))};
std::array<std::size_t, 2> end{e.extent(0), e.extent(1)};
auto e_view = dr::mhp::views::submdspan(e.view(), start, end);
auto dvdt_view = dr::mhp::views::submdspan(dvdt.view(), start, end);
dr::mhp::stencil_for_each(rhs_dedy, e_view, dvdt_view);
Expand All @@ -113,9 +109,7 @@ void rhs(Array &u, Array &v, Array &e, Array &dudt, Array &dvdt, Array &dedt,
};
{
std::array<std::size_t, 2> start{1, 0};
std::array<std::size_t, 2> end{
static_cast<std::size_t>(u.mdspan().extent(0)),
static_cast<std::size_t>(u.mdspan().extent(1))};
std::array<std::size_t, 2> end{u.extent(0), u.extent(1)};
auto u_view = dr::mhp::views::submdspan(u.view(), start, end);
auto v_view = dr::mhp::views::submdspan(v.view(), start, end);
auto dedt_view = dr::mhp::views::submdspan(dedt.view(), start, end);
Expand All @@ -140,9 +134,7 @@ void stage1(Array &u, Array &v, Array &e, Array &u1, Array &v1, Array &e1,
};
{
std::array<std::size_t, 2> start{1, 0};
std::array<std::size_t, 2> end{
static_cast<std::size_t>(e.mdspan().extent(0) - 1),
static_cast<std::size_t>(e.mdspan().extent(1))};
std::array<std::size_t, 2> end{e.extent(0) - 1, e.extent(1)};
auto e_view = dr::mhp::views::submdspan(e.view(), start, end);
auto u_view = dr::mhp::views::submdspan(u.view(), start, end);
auto u1_view = dr::mhp::views::submdspan(u1.view(), start, end);
Expand All @@ -158,9 +150,7 @@ void stage1(Array &u, Array &v, Array &e, Array &u1, Array &v1, Array &e1,
};
{
std::array<std::size_t, 2> start{0, 1};
std::array<std::size_t, 2> end{
static_cast<std::size_t>(e.mdspan().extent(0)),
static_cast<std::size_t>(e.mdspan().extent(1))};
std::array<std::size_t, 2> end{e.extent(0), e.extent(1)};
auto e_view = dr::mhp::views::submdspan(e.view(), start, end);
auto v_view = dr::mhp::views::submdspan(v.view(), start, end);
auto v1_view = dr::mhp::views::submdspan(v1.view(), start, end);
Expand All @@ -179,9 +169,7 @@ void stage1(Array &u, Array &v, Array &e, Array &u1, Array &v1, Array &e1,
};
{
std::array<std::size_t, 2> start{1, 0};
std::array<std::size_t, 2> end{
static_cast<std::size_t>(u.mdspan().extent(0)),
static_cast<std::size_t>(u.mdspan().extent(1))};
std::array<std::size_t, 2> end{u.extent(0), u.extent(1)};
auto e_view = dr::mhp::views::submdspan(e.view(), start, end);
auto u_view = dr::mhp::views::submdspan(u.view(), start, end);
auto v_view = dr::mhp::views::submdspan(v.view(), start, end);
Expand Down Expand Up @@ -209,9 +197,7 @@ void stage2(Array &u, Array &v, Array &e, Array &u1, Array &v1, Array &e1,
};
{
std::array<std::size_t, 2> start{1, 0};
std::array<std::size_t, 2> end{
static_cast<std::size_t>(e.mdspan().extent(0) - 1),
static_cast<std::size_t>(e.mdspan().extent(1))};
std::array<std::size_t, 2> end{e.extent(0) - 1, e.extent(1)};
auto e1_view = dr::mhp::views::submdspan(e1.view(), start, end);
auto u1_view = dr::mhp::views::submdspan(u1.view(), start, end);
auto u_view = dr::mhp::views::submdspan(u.view(), start, end);
Expand All @@ -228,9 +214,7 @@ void stage2(Array &u, Array &v, Array &e, Array &u1, Array &v1, Array &e1,
};
{
std::array<std::size_t, 2> start{0, 1};
std::array<std::size_t, 2> end{
static_cast<std::size_t>(e.mdspan().extent(0)),
static_cast<std::size_t>(e.mdspan().extent(1))};
std::array<std::size_t, 2> end{e.extent(0), e.extent(1)};
auto e1_view = dr::mhp::views::submdspan(e1.view(), start, end);
auto v1_view = dr::mhp::views::submdspan(v1.view(), start, end);
auto v_view = dr::mhp::views::submdspan(v.view(), start, end);
Expand All @@ -250,9 +234,7 @@ void stage2(Array &u, Array &v, Array &e, Array &u1, Array &v1, Array &e1,
};
{
std::array<std::size_t, 2> start{1, 0};
std::array<std::size_t, 2> end{
static_cast<std::size_t>(u.mdspan().extent(0)),
static_cast<std::size_t>(u.mdspan().extent(1))};
std::array<std::size_t, 2> end{u.extent(0), u.extent(1)};
auto e1_view = dr::mhp::views::submdspan(e1.view(), start, end);
auto u1_view = dr::mhp::views::submdspan(u1.view(), start, end);
auto v1_view = dr::mhp::views::submdspan(v1.view(), start, end);
Expand Down Expand Up @@ -282,9 +264,7 @@ void stage3(Array &u, Array &v, Array &e, Array &u2, Array &v2, Array &e2,
};
{
std::array<std::size_t, 2> start{1, 0};
std::array<std::size_t, 2> end{
static_cast<std::size_t>(e.mdspan().extent(0) - 1),
static_cast<std::size_t>(e.mdspan().extent(1))};
std::array<std::size_t, 2> end{e.extent(0) - 1, e.extent(1)};
auto e2_view = dr::mhp::views::submdspan(e2.view(), start, end);
auto u2_view = dr::mhp::views::submdspan(u2.view(), start, end);
auto u_view = dr::mhp::views::submdspan(u.view(), start, end);
Expand All @@ -301,9 +281,7 @@ void stage3(Array &u, Array &v, Array &e, Array &u2, Array &v2, Array &e2,
};
{
std::array<std::size_t, 2> start{0, 1};
std::array<std::size_t, 2> end{
static_cast<std::size_t>(e.mdspan().extent(0)),
static_cast<std::size_t>(e.mdspan().extent(1))};
std::array<std::size_t, 2> end{e.extent(0), e.extent(1)};
auto e2_view = dr::mhp::views::submdspan(e2.view(), start, end);
auto v2_view = dr::mhp::views::submdspan(v2.view(), start, end);
auto v_view = dr::mhp::views::submdspan(v.view(), start, end);
Expand All @@ -323,9 +301,7 @@ void stage3(Array &u, Array &v, Array &e, Array &u2, Array &v2, Array &e2,
};
{
std::array<std::size_t, 2> start{1, 0};
std::array<std::size_t, 2> end{
static_cast<std::size_t>(u.mdspan().extent(0)),
static_cast<std::size_t>(u.mdspan().extent(1))};
std::array<std::size_t, 2> end{u.extent(0), u.extent(1)};
auto e2_view = dr::mhp::views::submdspan(e2.view(), start, end);
auto u2_view = dr::mhp::views::submdspan(u2.view(), start, end);
auto v2_view = dr::mhp::views::submdspan(v2.view(), start, end);
Expand Down
1 change: 1 addition & 0 deletions include/dr/mhp/containers/distributed_mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ template <typename T, std::size_t Rank> class distributed_mdarray {
auto &halo() const { return dr::mhp::halo(dv_); }

auto mdspan() const { return md_view_.mdspan(); }
auto extent(std::size_t r) const { return mdspan().extent(r); }
auto grid() { return md_view_.grid(); }
auto view() const { return md_view_; }

Expand Down

0 comments on commit 0ecaef9

Please sign in to comment.