Add index select backward #359
base: master
Conversation
src/flag_gems/ops/index_select.py
Outdated
```python
dim = dim % len(self_sizes)
grad_shape = list(grad.shape)
assert grad_shape[dim] == index_shape[0], "Index out of range"
grad = dim_compress(grad, dim)
```
Function `dim_compress` moves the specified dim to the innermost dimension and makes the tensor contiguous. It is designed for reduction/scan operators, where the other dimensions are treated as batch dimensions and the reduction is performed over each 1d sub-tensor, which is loaded and iterated over; being contiguous on that dimension is good for this case.
For the backward of `index_select`, it is actually the opposite: dimension `dim` is used as an indexing dimension, and several (n-1)-d sub-tensors are inserted into a zeros tensor. So being contiguous on dimension `dim` is not what we want; we rather want it to be the outermost dimension.
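A minimal sketch of that layout change, using a hypothetical helper rather than this PR's code: permute `dim` to the front before scattering.

```python
import torch

def move_dim_to_front(grad: torch.Tensor, dim: int) -> torch.Tensor:
    # Make `dim` the outer-most dimension, so each index selects one
    # contiguous (n-1)-d slice of the gradient.
    order = [dim] + [i for i in range(grad.ndim) if i != dim]
    return grad.permute(order).contiguous()
```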
OK, we will change it.
Hello, there is a precision issue when calling `atomic_add`, and I'm not sure how to resolve it.
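One common workaround, offered here as an assumption rather than the fix this PR adopted: accumulate into a float32 buffer and cast back at the end, since atomics on half-precision buffers lose precision.

```python
import torch

# Sketch under the stated assumption: the kernel's tl.atomic_add targets a
# float32 accumulator instead of the half-precision gradient buffer.
grad = torch.randn(8, 16, dtype=torch.float16, device="cuda")
acc = torch.zeros_like(grad, dtype=torch.float32)
# ... launch the Triton kernel so tl.atomic_add writes into `acc` ...
in_grad = acc.to(grad.dtype)  # cast back to the original dtype afterwards
```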
LGTM for CI coverage
```python
batch_dim = [i for i in range(dim) if i not in dims]
sorted_reduction_dim = sorted(dims, key=lambda x: stride[x], reverse=True)
order = sorted_reduction_dim + batch_dim
return inp.permute(order).contiguous()
```
looks good.
src/flag_gems/ops/index_select.py
Outdated
```python
for i in index:
    assert i >= 0 and i < self_sizes[dim], "Index out of range"
```
I suggest removing this out-of-bounds check, since it involves a slice and a compare for each index, which means a large overhead.
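If a check is still wanted, a single vectorized comparison avoids the per-element Python loop; a sketch, not part of this PR:

```python
import torch

def check_index_bounds(index: torch.Tensor, dim_size: int) -> None:
    # One fused comparison over the whole index tensor instead of a loop.
    assert bool(((index >= 0) & (index < dim_size)).all()), "Index out of range"
```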
ok
tests/test_reduction_ops.py
Outdated
```python
index_size = inp.size(dim)
index = torch.randint(0, index_size, [floor(index_size * 0.8)], device="cuda")
index = torch.unique(index)
```
What if there are duplicated indices? I don't think the current implementation can handle this.
Yes, for duplicated indices, `torch.autograd.grad` can handle this case, and our code can handle it as well. But in the accuracy test, `out_grad` and `index` are generated randomly, so the condition that the same index always carries the same value is not met; that is why we added the `unique` call.
Normally, when an index is duplicated, the corresponding value is overwritten, and the result is correct.
No, I don't see how it can handle duplicated indices. With duplicated indices, a sub-tensor is selected multiple times in the forward op, so the corresponding gradients should accumulate into the same input gradient.
There is no such accumulation in the kernel. I added a test, and it always fails:
```python
@pytest.mark.index_select_backward
@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
@pytest.mark.parametrize("dim", DIM_LIST)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_index_select_backward(shape, dim, dtype):
    inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
    ref_inp = to_reference(inp)
    from math import floor

    index_size = inp.size(dim)
    index = torch.tensor([0, 0, 0, 0], device="cuda")
    # index = torch.unique(index)
    if len(index) == 0:
        pass
    else:
        ref_index = to_reference(index)
        ref_out = torch.index_select(ref_inp, dim, ref_index)
        out_grad = torch.randn_like(ref_out)
        ref_grad = to_reference(out_grad)
        (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad)
        with flag_gems.use_gems():
            res_out = torch.index_select(inp, dim, index)
            (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad)
        res_out = to_reference(res_out)
        res_in_grad = to_reference(res_in_grad)
        gems_assert_equal(res_out, ref_out)
        gems_assert_equal(res_in_grad, ref_in_grad)
```
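For reference, the expected accumulation semantics can be written with `index_add_`; this is a sketch of the semantics PyTorch implements, not this PR's kernel:

```python
import torch

def index_select_backward_reference(out_grad, self_shape, dim, index):
    # Rows selected multiple times in the forward pass must receive the
    # *sum* of their gradients, not the last write.
    in_grad = torch.zeros(self_shape, dtype=out_grad.dtype, device=out_grad.device)
    return in_grad.index_add_(dim, index, out_grad)
```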
ok
src/flag_gems/ops/index_select.py
Outdated
```python
tem_shape = list(grad.shape[1:])
tem_shape[-1] = 1
```
What is this shape `tem_shape` intended to be?
Maybe choose a more semantically meaningful name.
ok
```python
grad_off = (pid_x * num_blocks_per_CTA + i) * N + cols_offsets
out_off = (indices * num_blocks_per_CTA + i) * N + cols_offsets
selected = tl.load(grad + grad_off, mask=grad_mask, other=0.0)
tl.atomic_add(out + out_off, selected, mask=grad_mask)
```
Since this kernel uses `atomic_add` and is also autotuned, you should add `out` to `reset_to_zero`, to avoid it being accumulated into multiple times across tuning runs.
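A sketch of what that looks like; the config values and kernel signature here are placeholders, not this PR's actual tuning space:

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 128}, num_warps=4),
        triton.Config({"BLOCK_N": 256}, num_warps=8),
    ],
    key=["N"],
    reset_to_zero=["out"],  # zero `out` before every tuning run of the kernel
)
@triton.jit
def index_select_backward_kernel(grad, out, N, BLOCK_N: tl.constexpr):
    pass  # kernel body omitted in this sketch
```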
```python
)
yield inp, dim, index

bench = TensorSelectBenchmark(
```
You need to override

```python
def get_gbps(self, args, latency=None):
    # """Return the dynamic input iterator for each Operator."""
    raise NotImplementedError(
        "Each Benchmark must implement its own input iterator."
    )
```

for it to compute the gbps metric.
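A hedged sketch of such an override; the traffic model (read each selected slice once, write it once) and the argument unpacking are assumptions based on the iterator above, not the PR's final code:

```python
def get_gbps(self, args, latency=None):
    inp, dim, index = args
    # Bytes in one (n-1)-d slice along `dim`.
    slice_bytes = inp.numel() // inp.size(dim) * inp.element_size()
    io_bytes = 2 * index.numel() * slice_bytes  # read + write of selected slices
    return io_bytes / latency / 1e6  # latency in ms -> GB/s
```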
ok
PR Category
Operator

Type of Change
New Feature

Description
Implement the index_select_backward operator.

Issue
#316

Progress
Performance
Accuracy