
Add index select backward #359

Open

wants to merge 32 commits into master

Conversation

AdvancedCompiler (Contributor):

PR Category

Operator

Type of Change

New Feature

Description

Implement index_select_backward operator

Issue

#316

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

[screenshot: index_select_backward_pref]

Accuracy

[screenshot: index_select_backward_accuracy]

@iclementine self-assigned this Dec 13, 2024
dim = dim % len(self_sizes)
grad_shape = list(grad.shape)
assert grad_shape[dim] == index_shape[0], "Index out of range"
grad = dim_compress(grad, dim)
Collaborator:

The function dim_compress moves the specified dim to the innermost dimension and makes the tensor contiguous. It is designed for reduction/scan operators, where the other dimensions are treated as batch dimensions and the reduction is performed on each 1-d sub-tensor, which is loaded and iterated over. Being contiguous is good for that case.

For the backward of index_select, it is actually the opposite: dim is used as an indexing dimension, and several (n-1)-d sub-tensors are inserted into a zeros tensor. So being contiguous along dim is not what we want; we would rather have it be the outermost dimension, as sketched below.
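
A minimal sketch of the layout change the reviewer is describing (the helper name is illustrative, not from the PR):

import torch

def move_dim_outermost(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Permute `dim` to the front so each index addresses one contiguous
    # (n-1)-d slice of the gradient.
    order = [dim] + [i for i in range(x.ndim) if i != dim]
    return x.permute(order).contiguous()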

Contributor:

OK, we will change it.

Contributor:

Hello, there is a precision issue when calling atomic_add, and I'm not sure how to resolve it.
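
One common mitigation (an assumption on my part, not necessarily this PR's eventual fix) is to accumulate in float32 and cast back at the end, since fp16/bf16 atomic adds lose precision and their ordering is nondeterministic. A pure-PyTorch sketch of the idea:

import torch

def index_select_backward_sketch(grad, dim, index, self_sizes):
    # Accumulate in float32, then cast back to the gradient dtype;
    # index_add_ sums the rows that share a duplicated index.
    out = torch.zeros(self_sizes, dtype=torch.float32, device=grad.device)
    out.index_add_(dim, index, grad.to(torch.float32))
    return out.to(grad.dtype)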

Bowen12992 previously approved these changes Dec 19, 2024

@Bowen12992 (Collaborator):

LGTM for CI coverage

batch_dim = [i for i in range(dim) if i not in dims]
sorted_reduction_dim = sorted(dims, key=lambda x: stride[x], reverse=True)
order = sorted_reduction_dim + batch_dim
return inp.permute(order).contiguous()
Collaborator:

looks good.
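
To make the new ordering concrete (the shapes here are illustrative, not from the PR): with a default-strided (4, 5, 6) tensor and dims=[1], the permute order becomes [1, 0, 2], so the indexed dimension lands outermost.

import torch

inp = torch.randn(4, 5, 6)
dims = [1]
stride = inp.stride()
batch_dim = [i for i in range(inp.ndim) if i not in dims]
sorted_reduction_dim = sorted(dims, key=lambda x: stride[x], reverse=True)
order = sorted_reduction_dim + batch_dim  # [1, 0, 2]
print(inp.permute(order).contiguous().shape)  # torch.Size([5, 4, 6])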

Comment on lines 108 to 109
for i in index:
    assert i >= 0 and i < self_sizes[dim], "Index out of range"
Collaborator:

I suggest removing this out-of-bounds check, since it slices and compares for each index, which adds a large overhead.
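
If some validation were kept, a single vectorized comparison (an alternative I am sketching here, not something the PR adopts) would avoid the per-element Python loop. In the same context as the snippet above:

# One fused device-side comparison and a single host sync,
# instead of one slice-and-compare per index.
assert torch.all((index >= 0) & (index < self_sizes[dim])), "Index out of range"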

Contributor:

ok


index_size = inp.size(dim)
index = torch.randint(0, index_size, [floor(index_size * 0.8)], device="cuda")
index = torch.unique(index)
Collaborator:

What if there are duplicated indices? I don't think the current implementation can handle this.

Contributor:

Yes, torch.autograd.grad can handle duplicated indices, and our code was also designed to handle them. But in the accuracy test, out_grad and index are generated randomly, so duplicated indices are not paired with matching gradient values; that is why we added the torch.unique call.

For example: index = 1, out_grad = 2.

Normally, when an index is duplicated, the corresponding value is overwritten and the result is correct.

Collaborator:

No, I don't see how it can handle duplicated indices. With duplicated indices, a sub-tensor is selected multiple times in the forward op, so the corresponding gradients should accumulate into the same input gradient.

There is no such accumulation in the kernel. I added a test, and it always fails.

@pytest.mark.index_select_backward
@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
@pytest.mark.parametrize("dim", DIM_LIST)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_index_select_backward(shape, dim, dtype):
    inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
    ref_inp = to_reference(inp)
    from math import floor

    index_size = inp.size(dim)
    index = torch.tensor([0, 0, 0, 0], device="cuda")
    # index = torch.unique(index)
    if len(index) == 0:
        pass
    else:
        ref_index = to_reference(index)
        ref_out = torch.index_select(ref_inp, dim, ref_index)
        out_grad = torch.randn_like(ref_out)
        ref_grad = to_reference(out_grad)
        (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad)

        with flag_gems.use_gems():
            res_out = torch.index_select(inp, dim, index)
            (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad)
        res_out = to_reference(res_out)
        res_in_grad = to_reference(res_in_grad)
        gems_assert_equal(res_out, ref_out)
        gems_assert_equal(res_in_grad, ref_in_grad)

Contributor (Author):

ok

Comment on lines 133 to 134
tem_shape = list(grad.shape[1:])
tem_shape[-1] = 1
Collaborator:

What is this shape tem_shape intended to be?

Collaborator:

Maybe choose a more semantically meaningful name.

Contributor (Author):

ok

grad_off = (pid_x * num_blocks_per_CTA + i) * N + cols_offsets
out_off = (indices * num_blocks_per_CTA + i) * N + cols_offsets
selected = tl.load(grad + grad_off, mask=grad_mask, other=0.0)
tl.atomic_add(out + out_off, selected, mask=grad_mask)
Collaborator:

Since this kernel uses atomic_add together with autotune, you should add out to reset_to_zero, so the output is not accumulated into multiple times while configs are being benchmarked.
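
For reference, reset_to_zero takes the argument name. A simplified sketch (the config values and kernel body are assumptions; each program here handles one row of N elements, masked to BLOCK_N):

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 128}, num_warps=4),
        triton.Config({"BLOCK_N": 256}, num_warps=8),
    ],
    key=["N"],
    # Zero `out` before each timing run so repeated benchmark launches
    # of this atomic_add kernel do not keep accumulating into it.
    reset_to_zero=["out"],
)
@triton.jit
def index_select_backward_kernel(grad, out, indices, N, BLOCK_N: tl.constexpr):
    pid = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    idx = tl.load(indices + pid)
    vals = tl.load(grad + pid * N + cols, mask=mask, other=0.0)
    tl.atomic_add(out + idx * N + cols, vals, mask=mask)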

)
yield inp, dim, index

bench = TensorSelectBenchmark(
Collaborator:

You need to override

    def get_gbps(self, args, latency=None):
        # """Return the dynamic input iterator for each Operator."""
        raise NotImplementedError(
            "Each Benchmark must implement its own input iterator."
        )

for it to compute the gbps metric.
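
A possible override (a sketch only; the I/O estimate and the assumption that latency is in milliseconds are mine, not the benchmark framework's):

    def get_gbps(self, args, latency=None):
        # The backward reads grad once and writes a zeros tensor of the
        # input's shape: bytes ~= read(grad) + write(in_grad).
        inp, dim, index = args
        grad_numel = inp.numel() // inp.size(dim) * index.numel()
        io_amount = (grad_numel + inp.numel()) * inp.element_size()
        return io_amount * 1e-9 / (latency * 1e-3)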

Contributor (Author):

ok
