Skip to content

Commit 1ba6970

Browse files
authored
[compute/cker] Fix RMSNorm shape assert error (#14247)
* [compute/cker] Fix RMSNorm shape assert error This commit fixes shape assert error when running model including RMSNorm operation. ONE-DCO-1.0-Signed-off-by: Seockho Kim [email protected] * [compute/cker] Add RMSNorm unittests Unit test added for RMSNorm to test rank 3 input. ONE-DCO-1.0-Signed-off-by: Seockho Kim [email protected]
1 parent 9c49d99 commit 1ba6970

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

compute/cker/include/cker/operation/RmsNorm.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ inline void RmsNorm(const RmsNormParams &params, const Shape &input_shape, const
6868
}
6969
else if (input_shape.DimensionsCount() == 3)
7070
{
71-
const int32_t heights = MatchingDim(input_shape, 1, output_shape, 0);
72-
const int32_t widths = MatchingDim(input_shape, 2, output_shape, 1);
73-
const int32_t channels = MatchingDim(input_shape, 3, output_shape, 2);
71+
const int32_t heights = MatchingDim(input_shape, 0, output_shape, 0);
72+
const int32_t widths = MatchingDim(input_shape, 1, output_shape, 1);
73+
const int32_t channels = MatchingDim(input_shape, 2, output_shape, 2);
7474

7575
for (int32_t height = 0; height < heights; height++)
7676
{

compute/cker/src/RmsNorm.test.cc

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ TEST(CKer_Operation, RmsNorm)
4343
EXPECT_NEAR(output[i], expected_output[i], 1e-5f);
4444
}
4545

46-
// Default gamma
46+
// rank 4
4747
{
4848
std::vector<float> input = {0, 1, 2, 3, 4, 5, 6, 7};
4949
nnfw::cker::Shape input_shape{1, 2, 2, 2};
@@ -65,6 +65,29 @@ TEST(CKer_Operation, RmsNorm)
6565
for (size_t i = 0; i < expected_output.size(); ++i)
6666
EXPECT_NEAR(output[i], expected_output[i], 1e-5f);
6767
}
68+
69+
// rank 3
70+
{
71+
std::vector<float> input = {0, 1, 2, 3, 4, 5, 6, 7};
72+
nnfw::cker::Shape input_shape{2, 2, 2};
73+
74+
std::vector<float> expected_output = {0, 1.412802, 0.784404, 1.176606,
75+
0.883431, 1.104288, 0.920347, 1.073738};
76+
std::vector<float> output(expected_output.size());
77+
nnfw::cker::Shape output_shape{2, 2, 2};
78+
79+
std::vector<float> gamma = {1, 1};
80+
nnfw::cker::Shape gamma_shape{2};
81+
82+
nnfw::cker::RmsNormParams param;
83+
param.epsilon = 0.001f;
84+
85+
nnfw::cker::RmsNorm(param, input_shape, input.data(), gamma_shape, gamma.data(), output_shape,
86+
output.data());
87+
88+
for (size_t i = 0; i < expected_output.size(); ++i)
89+
EXPECT_NEAR(output[i], expected_output[i], 1e-5f);
90+
}
6891
}
6992

7093
TEST(CKer_Operation, neg_RmsNormWrongInputDims)

0 commit comments

Comments
 (0)