Skip to content

Commit

Permalink
Fix an int conversion error (#1325)
Browse files Browse the repository at this point in the history
fix an int conversion error

Signed-off-by: Jennifer Zhou <[email protected]>
  • Loading branch information
jennifgcrl authored Nov 13, 2024
1 parent 237b493 commit 943f1e0
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions transformer_engine/jax/csrc/extensions/activation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ Error_Type DActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act
auto *output = output_buf->untyped_data();

auto act_input_dims = act_input_buf.dimensions();
auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
auto n = act_input_dims.back();
auto m = static_cast<size_t>(product(act_input_dims, 0, act_input_dims.size() - 2));
auto n = static_cast<size_t>(act_input_dims.back());
auto act_len = act_input_dims.end()[-2];

auto input_shape = std::vector<size_t>{m, n};
Expand Down

0 comments on commit 943f1e0

Please sign in to comment.