Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL Spec][Joint Matrix] Add a new overload for joint_matrix_apply to be able to return result into a different matrix #13153

Merged
merged 6 commits into from
Sep 20, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -401,9 +401,14 @@ of the link:sycl_ext_intel_matrix.asciidoc[sycl_ext_intel_matrix]

Besides the `Group` and the `joint_matrix` arguments,
`joint_matrix_apply` takes a C++ Callable object which is invoked once
for each element of the matrix. This callable object must be invocable
with a single parameter of type `T&`. Commonly, applications pass a
lambda expression.
for each element of the matrix. There are two cases: (1) one matrix is
passed, (2) two matrices are passed.

===== Unary Operation
In this case, `joint_matrix_apply` takes one `joint_matrix`
argument. The callable object must be invocable with a single
parameter of type `T&`. Commonly, applications pass a lambda
expression.

```c++
namespace sycl::ext::oneapi::experimental::matrix {
Expand All @@ -427,6 +432,39 @@ joint_matrix_apply(sg, C, [=](T &x) {
});
```

===== Binary Operation
In this case, `joint_matrix_apply` takes two `joint_matrix` arguments:
`jm0` and `jm1` that have the same `use`, number of rows, number of
columns, and `layout`. `jm0` and `jm1` can be read-only, write-only,
or read and write arguments. The callable object must be invocable
with two parameters `x` and `y` of types `T0&` amd `T1&`, where `x` is
an element from `jm0` and `y` is an element from `jm1`. Moreover, `x`
and `y` are guaranteed to have identical coordinates in their
respective matrices. Commonly, applications pass a lambda expression.

```c++
namespace sycl::ext::oneapi::experimental::matrix {

template<typename Group, typename T0, typename T1, use Use,
size_t Rows, size_t Cols, layout Layout, typename F>
void joint_matrix_apply(Group g,
joint_matrix<Group, T0, Use, Rows, Cols, Layout>& jm0,
joint_matrix<Group, T1, Use, Rows, Cols, Layout>& jm1,
F&& func);

} // namespace sycl::ext::oneapi::experimental::matrix
```

In the following example, every element `x` of the matrix `C` is
multiplied by `alpha`. The result is returned into the element `y` of
the matrix `D`.

```c++
joint_matrix_apply(sg, C, D, [=](const T &x, T &y) {
y = x * alpha;
});
```
gmlueck marked this conversation as resolved.
Show resolved Hide resolved

==== Prefetch

```c++
Expand Down