Skip to content

Commit

Permalink
[JAX]: Support jax.lax.select_n operation for JAX (openvinotoolkit#28025
Browse files Browse the repository at this point in the history
)

**Overview:**
This pull request fixes openvinotoolkit#26570.

**Testing:**
- Tested the updated code.
- Verified that other functionalities remain unaffected.
![Screenshot from 2024-12-12
00-09-17](https://github.com/user-attachments/assets/ae118efa-2047-4bac-a90d-8318396e5f63)

**Dependencies:**
- No dependencies on other pull requests.

**CC:**
- @rkazants

---------

Signed-off-by: 11happy <[email protected]>
Co-authored-by: Roman Kazantsev <[email protected]>
  • Loading branch information
11happy and rkazants authored Dec 31, 2024
1 parent ca501ca commit 362f073
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 0 deletions.
46 changes: 46 additions & 0 deletions src/frontends/jax/src/op/select_n.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/jax/node_context.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/gather_elements.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils.hpp"

using namespace ov::op;

namespace ov {
namespace frontend {
namespace jax {
namespace op {

OutputVector translate_select_n(const NodeContext& context) {
num_inputs_check(context, 2);
auto num_inputs = static_cast<int>(context.get_input_size());
Output<Node> which = context.get_input(0);
if (which.get_element_type() == element::boolean) {
which = std::make_shared<v0::Convert>(which, element::i32);
}
auto const_axis = ov::op::v0::Constant::create(element::i64, Shape{1}, std::vector<int64_t>{0});
OutputVector unsqueezed_cases(num_inputs - 1);
unsqueezed_cases.reserve(num_inputs - 1);
for (int ind = 1; ind < num_inputs; ++ind) {
auto case_input = context.get_input(ind);
auto unsqueeze = std::make_shared<v0::Unsqueeze>(case_input, const_axis);
unsqueezed_cases[ind - 1] = unsqueeze;
}
Output<Node> cases = std::make_shared<v0::Concat>(unsqueezed_cases, 0);
which =
std::make_shared<v0::Unsqueeze>(which,
ov::op::v0::Constant::create(element::i64, Shape{1}, std::vector<int64_t>{0}));
Output<Node> result = std::make_shared<v6::GatherElements>(cases, which, 0);
return {result};
};

} // namespace op
} // namespace jax
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/jax/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ OP_CONVERTER(translate_reduce_window_max);
OP_CONVERTER(translate_reduce_window_sum);
OP_CONVERTER(translate_reshape);
OP_CONVERTER(translate_rsqrt);
OP_CONVERTER(translate_select_n);
OP_CONVERTER(translate_slice);
OP_CONVERTER(translate_square);
OP_CONVERTER(translate_squeeze);
Expand Down Expand Up @@ -92,6 +93,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_jaxpr() {
{"transpose", op::translate_transpose},
{"rsqrt", op::translate_rsqrt},
{"reshape", op::translate_reshape},
{"select_n", op::translate_select_n},
{"slice", op::translate_slice},
{"square", op::translate_square},
{"sqrt", op::translate_1to1_match_1_input<v0::Sqrt>},
Expand Down
45 changes: 45 additions & 0 deletions tests/layer_tests/jax_tests/test_select_n.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import jax
import numpy as np
import pytest
from jax import numpy as jnp

from jax_layer_test_class import JaxLayerTest

rng = np.random.default_rng(5402)


class TestSelectN(JaxLayerTest):
def _prepare_input(self):
cases = []
if (self.case_num == 2):
which = rng.choice([True, False], self.input_shape)
else:
which = rng.uniform(0, self.case_num, self.input_shape).astype(self.input_type)
which = np.array(which)
for i in range(self.case_num):
cases.append(jnp.array(np.random.uniform(-1000, 1000, self.input_shape).astype(self.input_type)))
cases = np.array(cases)
return (which, cases)

def create_model(self, input_shape, input_type, case_num):
self.input_shape = input_shape
self.input_type = input_type
self.case_num = case_num

def jax_select_n(which, cases):
return jax.lax.select_n(which, *cases)

return jax_select_n, None, 'select_n'

@pytest.mark.parametrize("input_shape", [[], [1], [2, 3], [4, 5, 6], [7, 8, 9, 10]])
@pytest.mark.parametrize("input_type", [np.int32, np.int64])
@pytest.mark.parametrize("case_num", [2, 3, 4])
@pytest.mark.nightly
@pytest.mark.precommit_jax_fe
def test_select_n(self, ie_device, precision, ir_version, input_shape, input_type, case_num):
self._test(*self.create_model(input_shape, input_type, case_num),
ie_device, precision,
ir_version)

0 comments on commit 362f073

Please sign in to comment.