-
Notifications
You must be signed in to change notification settings - Fork 19.6k
My fix branch #21585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
My fix branch #21585
Changes from all commits
0f77687
ea0a40f
df028d0
1635157
f450869
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,6 +1,7 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import openvino.opset14 as ov_opset | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from openvino import Type | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from openvino.runtime import opset13 as ov | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from keras.src.backend import config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from keras.src.backend.common import dtypes | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -17,6 +18,109 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from keras.src.backend.openvino.core import ov_to_keras_type | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def diagonal(x, offset=0, axis1=0, axis2=1): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
x_node = ov.constant(x) # -> ov.Node | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
offset_const = ov_opset.constant(int(offset), dtype="i64") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# rank & normalize axes | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
shape = ov_opset.shape_of(x_node) # i64 vector | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
rank = ov_opset.shape_of(shape) # scalar i64 (len of shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
rank_val = ov_opset.squeeze(rank) # [] -> scalar | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
axis1_node = ov_opset.floor_mod( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ov_opset.add(ov_opset.constant(int(axis1), dtype="i64"), rank_val), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
rank_val, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
axis2_node = ov_opset.floor_mod( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ov_opset.add(ov_opset.constant(int(axis2), dtype="i64"), rank_val), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
rank_val, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+22
to
+36
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The implementation of To fix this and align with the style of other functions in this backend, please use the
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
arange = ov_opset.range( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ov_opset.constant(0, dtype="i64"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
rank_val, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ov_opset.constant(1, dtype="i64"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mask1 = ov_opset.equal(arange, axis1_node) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mask2 = ov_opset.equal(arange, axis2_node) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
not12 = ov_opset.logical_not(ov_opset.logical_or(mask1, mask2)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
others = ov_opset.squeeze( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ov_opset.non_zero(not12), [1] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) # gather positions != axis1, axis2 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
perm = ov_opset.concat( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
others, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ov_opset.reshape(axis1_node, [1]), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ov_opset.reshape(axis2_node, [1]), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
x_perm = ov_opset.transpose(x_node, perm) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
permuted_shape = ov_opset.shape_of(x_perm) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
d1 = ov_opset.gather( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
permuted_shape, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ov_opset.constant([-2], dtype="i64"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ov_opset.constant(0, dtype="i64"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
d2 = ov_opset.gather( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
permuted_shape, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ov_opset.constant([-1], dtype="i64"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ov_opset.constant(0, dtype="i64"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
d1 = ov_opset.squeeze(d1) # scalar | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
d2 = ov_opset.squeeze(d2) # scalar | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# start1 = max(0, offset), start2 = max(0, -offset) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
zero = ov_opset.constant(0, dtype="i64") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
start1 = ov_opset.maximum(zero, offset_const) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
start2 = ov_opset.maximum(zero, ov_opset.negative(offset_const)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# L = min(d1 - start1, d2 - start2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
l1 = ov_opset.subtract(d1, start1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
l2 = ov_opset.subtract(d2, start2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
L = ov_opset.minimum(l1, l2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# r = range(0, L, 1) -> shape [L] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
r = ov_opset.range(zero, L, ov_opset.constant(1, dtype="i64")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
idx_row = ov_opset.add(r, start1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
idx_col = ov_opset.add(r, start2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
idx_row = ov_opset.unsqueeze( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
idx_row, ov_opset.constant(1, dtype="i64") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) # [L,1] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
idx_col = ov_opset.unsqueeze( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
idx_col, ov_opset.constant(1, dtype="i64") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) # [L,1] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
diag_idx = ov_opset.concat([idx_row, idx_col], 1) # [L,2] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Broadcast indices to batch dims: target shape = (*batch, L, 2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# batch_rank = rank(x) - 2 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
two = ov_opset.constant(2, dtype="i64") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
batch_rank = ov_opset.subtract(rank_val, two) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# build target shape: concat(permuted_shape[:batch_rank], [L, 2]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
batch_shape = ov_opset.strided_slice( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
permuted_shape, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
begin=ov_opset.constant([0], dtype="i64"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
end=ov_opset.reshape(batch_rank, [1]), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
strides=ov_opset.constant([1], dtype="i64"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
begin_mask=[0], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
end_mask=[0], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
target_shape = ov_opset.concat( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
batch_shape, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ov_opset.reshape(L, [1]), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ov_opset.constant([2], dtype="i64"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
bcast_idx = ov_opset.broadcast(diag_idx, target_shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# GatherND with batch_dims = batch_rank | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
gathered = ov_opset.gather_nd(x_perm, bcast_idx, batch_rank) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return OpenVINOKerasTensor(gathered) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def add(x1, x2): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
element_type = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if isinstance(x1, OpenVINOKerasTensor): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -677,12 +781,6 @@ def diag(x, k=0): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
raise NotImplementedError("`diag` is not supported with openvino backend") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def diagonal(x, offset=0, axis1=0, axis2=1): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
raise NotImplementedError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"`diagonal` is not supported with openvino backend" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def diff(a, n=1, axis=-1): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if n == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return OpenVINOKerasTensor(get_ov_output(a)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This import from
opset13
is inconsistent with the rest of the file, which usesopset14
. By applying the suggestion to useget_ov_output
on line 22, this import will become unnecessary and can be removed.