-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patharithmetic.cpp
48 lines (41 loc) · 1.15 KB
/
arithmetic.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include "gradstudent/internal/utils.h"
#include "gradstudent/iter.h"
#include "gradstudent/ops.h"
namespace gs {

namespace {

/**
 * @brief Applies a binary element-wise operation to two tensors after
 *        broadcasting them to a common shape.
 *
 * Shared implementation of the element-wise binary operators in this file,
 * so the broadcast/allocate/iterate sequence lives in one place.
 *
 * @param left  Left-hand operand.
 * @param right Right-hand operand.
 * @param op    Callable invoked as `op(l, r)` for each pair of corresponding
 *              elements of the broadcast operands.
 * @return A new tensor with the broadcast shape whose elements are `op(l, r)`.
 *
 * NOTE(review): how incompatible shapes are reported is decided by
 * `broadcast`, which is not visible here — confirm it throws.
 */
template <typename BinaryOp>
Tensor broadcastBinaryOp(const Tensor &left, const Tensor &right,
                         BinaryOp op) {
  auto [bleft, bright] = broadcast(left, right);
  // Both broadcast operands share one shape, so either can size the result.
  Tensor result(bleft.shape());
  for (const auto &[res, lt, rt] : TensorIter(result, bleft, bright)) {
    res = op(lt, rt);
  }
  return result;
}

} // namespace

/**
 * @brief Element-wise sum of two tensors, with broadcasting.
 * @return A new tensor of the broadcast shape holding `left + right`.
 */
Tensor operator+(const Tensor &left, const Tensor &right) {
  return broadcastBinaryOp(left, right,
                           [](const auto &a, const auto &b) { return a + b; });
}

/**
 * @brief Element-wise (Hadamard) product of two tensors, with broadcasting.
 *
 * This is not a matrix product.
 *
 * @return A new tensor of the broadcast shape holding `left * right`.
 */
Tensor operator*(const Tensor &left, const Tensor &right) {
  return broadcastBinaryOp(left, right,
                           [](const auto &a, const auto &b) { return a * b; });
}

/**
 * @brief Element-wise negation.
 * @return A new tensor with the same shape as @p tensor holding `-tensor`.
 */
Tensor operator-(const Tensor &tensor) {
  Tensor result(tensor.shape());
  for (const auto &[res, val] : TensorIter(result, tensor)) {
    res = -val;
  }
  return result;
}

/**
 * @brief Element-wise difference of two tensors, with broadcasting.
 *
 * Computed directly as `l - r` on the broadcast operands. The former
 * `left + (-right)` formulation allocated and filled an intermediate negated
 * tensor and made an extra pass over the data.
 *
 * @return A new tensor of the broadcast shape holding `left - right`.
 */
Tensor operator-(const Tensor &left, const Tensor &right) {
  return broadcastBinaryOp(left, right,
                           [](const auto &a, const auto &b) { return a - b; });
}

/**
 * @brief Exact element-wise equality of two tensors.
 *
 * No broadcasting is applied. Shapes are first validated by
 * `checkCompatibleShape`.
 * NOTE(review): presumably that call throws on mismatched shapes rather than
 * letting this function return false — confirm.
 *
 * Elements are compared with `!=` and no tolerance. If the element type is
 * floating point, any NaN element makes the tensors compare unequal.
 *
 * @return true if every pair of corresponding elements compares equal.
 */
bool operator==(const Tensor &left, const Tensor &right) {
  checkCompatibleShape(left, right);
  // Explicit loop kept instead of std::all_of; stops at the first mismatch.
  // NOLINTNEXTLINE(readability-use-anyofallof)
  for (const auto &[lt, rt] : TensorIter(left, right)) {
    if (lt != rt) {
      return false;
    }
  }
  return true;
}

} // namespace gs