-
Notifications
You must be signed in to change notification settings - Fork 88
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor to use thrust::reduce on any. #685
base: main
Are you sure you want to change the base?
Changes from 4 commits
d0f3789
6dc560a
15358ec
350b292
55d78c0
b8ea6c2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,7 @@ | |
|
||
#pragma once | ||
|
||
#include <thrust/reduce.h> | ||
|
||
#include "matx/core/type_utils.h" | ||
#include "matx/operators/base_operator.h" | ||
|
@@ -40,8 +41,6 @@ | |
|
||
namespace matx { | ||
|
||
|
||
|
||
namespace detail { | ||
template<typename OpA, int ORank> | ||
class AnyOp : public BaseOp<AnyOp<OpA, ORank>> | ||
|
@@ -71,8 +70,17 @@ namespace detail { | |
}; | ||
|
||
template <typename Out, typename Executor> | ||
void Exec(Out &&out, Executor &&ex) const { | ||
any_impl(cuda::std::get<0>(out), a_, ex); | ||
void Exec(Out &&out, Executor) const { | ||
auto output_ = cuda::std::get<0>(out); | ||
using out_t = decltype(output_); | ||
using value_t = typename out_t::value_type; | ||
ZelboK marked this conversation as resolved.
Show resolved
Hide resolved
|
||
using output_t = typename detail::base_type_t<out_t>; | ||
|
||
output_t out_base = output_; | ||
auto op = detail::reduceOpAny<value_t>(); | ||
|
||
auto rv = ReduceInputThrust(std::forward<OpA>(a_), std::forward<out_t>(output_), std::forward<decltype(op)>(op)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We construct the |
||
MATX_ASSERT_STR_EXP(rv, cudaSuccess, matxCudaError, "Error in any"); | ||
} | ||
|
||
static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -798,17 +798,24 @@ template <typename T> class reduceOpMax { | |
* Performs a reduction of two values of type T by returning 1 if either | ||
* of the values are non-zero. | ||
*/ | ||
template <typename T> class reduceOpAny { | ||
template <typename T> | ||
class reduceOpAny { | ||
public: | ||
using type = T; // This type is for Thrust | ||
using matx_reduce = bool; | ||
using matx_no_cub_reduce = bool; // Don't use CUB for this reduction type | ||
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ T Reduce(const T &v1, const T &v2) | ||
{ | ||
|
||
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ T operator()(const T &v1, const T &v2) const { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. needs to be const for Thrust. |
||
return (v1 != 0) || (v2 != 0); | ||
} | ||
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ T operator()(T &v1, T &v2) { v1 = ((v1 != 0) || (v2 != 0)); return v1; } | ||
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ T Init() { return (T)(0); } | ||
__MATX_DEVICE__ __MATX_INLINE__ void atomicReduce(T *addr, T val) { atomicAny(addr, val); } | ||
|
||
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ T Init() const { | ||
return static_cast<T>(0); | ||
} | ||
|
||
__MATX_DEVICE__ __MATX_INLINE__ void atomicReduce(T *addr, T val) const { | ||
atomicAny(addr, val); | ||
} | ||
}; | ||
|
||
/** | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cliffburdick
So this function is responsible for dealing with the facct that
In
may not necessarily be atensor_t
but rather an operator of some sort. It'll get the respective offsets and construct an iterator for thrust to use, which, thankfully, does seem to be perfectly compatible.Consequentially it would seem the code has also become potentially simpler than it's counterpart on
main
. Thrust is now responsible for deciding whom exactly to use in CUB rather than MatX. I might be missing some context though on whether or not MatX needing to be responsible for what function to use in CUB/Thrust.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In hindsight, I realize now that there might be issues with this approach. Does
matXBinaryOp
for example need to be utilized? I see that it has methods likePreRun
PostRun
etc that make use of the Executor(which I have used here)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @ZelboK since a tensor is an operator, the iterator wrapper can turn any MatX operator type into an iterator. However, we do the contiguous check to allow CUB/thrust an optimization if it's a flat pointer with contiguous strides.
Regarding your second comment, the
ReduceInput
function shouldn't/doesn't need to know whether it's CUB or thrust. That's something you pass in your lambda that you give to the function. I'm saying this all without actually trying it, but if you assume the iterator type is compatible between both libraries then there may not be any changes to that code. This was how the example I pointed to previously worked (just search ReduceInput in cub.h).I'm not sure what your question about
matxBinaryOp
is, but that's a wrapper class for any binary type and should be completely separate from this.matxBinaryOp
, like most of our types is an operator, and you can pass it to thrust/cub and have the iterator pull from it.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cliffburdick I see. I should have clarified - I actually did try to follow that example you've shown me like this(but I did it incorrectly). If you're curious, here's what I did.
which brought me to my question of
matXBinaryOp
. This led to errors likewhere
InputIterator
was of typeI should not be passing the base types though. Let me push the fix. Thanks for the help.