forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[JAX]: Support jax.lax.select_n operation for JAX (openvinotoolkit#28025
) **Overview:** This pull request fixes openvinotoolkit#26570. **Testing:** - Tested the updated code. - Verified that other functionalities remain unaffected. ![Screenshot from 2024-12-12 00-09-17](https://github.com/user-attachments/assets/ae118efa-2047-4bac-a90d-8318396e5f63) **Dependencies:** - No dependencies on other pull requests. **CC:** - @rkazants --------- Signed-off-by: 11happy <[email protected]> Co-authored-by: Roman Kazantsev <[email protected]>
- Loading branch information
Showing
3 changed files
with
93 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/frontend/jax/node_context.hpp" | ||
#include "openvino/op/concat.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/convert.hpp" | ||
#include "openvino/op/gather_elements.hpp" | ||
#include "openvino/op/unsqueeze.hpp" | ||
#include "utils.hpp" | ||
|
||
using namespace ov::op; | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace jax { | ||
namespace op { | ||
|
||
OutputVector translate_select_n(const NodeContext& context) { | ||
num_inputs_check(context, 2); | ||
auto num_inputs = static_cast<int>(context.get_input_size()); | ||
Output<Node> which = context.get_input(0); | ||
if (which.get_element_type() == element::boolean) { | ||
which = std::make_shared<v0::Convert>(which, element::i32); | ||
} | ||
auto const_axis = ov::op::v0::Constant::create(element::i64, Shape{1}, std::vector<int64_t>{0}); | ||
OutputVector unsqueezed_cases(num_inputs - 1); | ||
unsqueezed_cases.reserve(num_inputs - 1); | ||
for (int ind = 1; ind < num_inputs; ++ind) { | ||
auto case_input = context.get_input(ind); | ||
auto unsqueeze = std::make_shared<v0::Unsqueeze>(case_input, const_axis); | ||
unsqueezed_cases[ind - 1] = unsqueeze; | ||
} | ||
Output<Node> cases = std::make_shared<v0::Concat>(unsqueezed_cases, 0); | ||
which = | ||
std::make_shared<v0::Unsqueeze>(which, | ||
ov::op::v0::Constant::create(element::i64, Shape{1}, std::vector<int64_t>{0})); | ||
Output<Node> result = std::make_shared<v6::GatherElements>(cases, which, 0); | ||
return {result}; | ||
}; | ||
|
||
} // namespace op | ||
} // namespace jax | ||
} // namespace frontend | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Copyright (C) 2018-2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import jax | ||
import numpy as np | ||
import pytest | ||
from jax import numpy as jnp | ||
|
||
from jax_layer_test_class import JaxLayerTest | ||
|
||
rng = np.random.default_rng(5402) | ||
|
||
|
||
class TestSelectN(JaxLayerTest): | ||
def _prepare_input(self): | ||
cases = [] | ||
if (self.case_num == 2): | ||
which = rng.choice([True, False], self.input_shape) | ||
else: | ||
which = rng.uniform(0, self.case_num, self.input_shape).astype(self.input_type) | ||
which = np.array(which) | ||
for i in range(self.case_num): | ||
cases.append(jnp.array(np.random.uniform(-1000, 1000, self.input_shape).astype(self.input_type))) | ||
cases = np.array(cases) | ||
return (which, cases) | ||
|
||
def create_model(self, input_shape, input_type, case_num): | ||
self.input_shape = input_shape | ||
self.input_type = input_type | ||
self.case_num = case_num | ||
|
||
def jax_select_n(which, cases): | ||
return jax.lax.select_n(which, *cases) | ||
|
||
return jax_select_n, None, 'select_n' | ||
|
||
@pytest.mark.parametrize("input_shape", [[], [1], [2, 3], [4, 5, 6], [7, 8, 9, 10]]) | ||
@pytest.mark.parametrize("input_type", [np.int32, np.int64]) | ||
@pytest.mark.parametrize("case_num", [2, 3, 4]) | ||
@pytest.mark.nightly | ||
@pytest.mark.precommit_jax_fe | ||
def test_select_n(self, ie_device, precision, ir_version, input_shape, input_type, case_num): | ||
self._test(*self.create_model(input_shape, input_type, case_num), | ||
ie_device, precision, | ||
ir_version) |