
add argmin (#318) #346

Open · wants to merge 3 commits into master

Conversation


@wyjoutstanding wyjoutstanding commented Dec 7, 2024

PR Category

Operator, OP Test, Benchmark

Type of Change

New Feature

Description

Support argmin; see #318 for details.
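
For reference, the operator follows torch.argmin semantics: with dim=None the input is flattened and the index of the global minimum is returned, while a given dim reduces along that dimension:

```python
import torch

x = torch.tensor([[3.0, 1.0], [2.0, 0.5]])
torch.argmin(x)         # tensor(3): flattened index of the global minimum (0.5)
torch.argmin(x, dim=1)  # tensor([1, 1]): per-row index of the minimum
```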

Issue

Resolves #318

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

A100 test results:

Operator: argmin Performance Test (dtype=torch.float16, mode=cuda, level=comprehensive)

| Status | Torch Latency (ms) | Gems Latency (ms) | Gems Speedup | Size Detail |
| --- | --- | --- | --- | --- |
| SUCCESS | 0.019456 | 0.013312 | 1.462 | [torch.Size([1048576])] |
| SUCCESS | 0.010240 | 0.010240 | 1.000 | [torch.Size([64, 64])] |
| SUCCESS | 0.055296 | 0.043008 | 1.286 | [torch.Size([4096, 4096])] |
| SUCCESS | 0.055296 | 0.043008 | 1.286 | [torch.Size([64, 512, 512])] |
| SUCCESS | 1.888256 | 1.249280 | 1.511 | [torch.Size([1024, 1024, 1024])] |
| SUCCESS | 0.007168 | 0.009216 | 0.778 | [torch.Size([4])] |
| SUCCESS | 0.008192 | 0.009216 | 0.889 | [torch.Size([1024])] |
| SUCCESS | 1.886208 | 1.249280 | 1.510 | [torch.Size([1073741824])] |
| SUCCESS | 0.008192 | 0.009216 | 0.889 | [torch.Size([1024, 1])] |
| SUCCESS | 0.014336 | 0.010240 | 1.400 | [torch.Size([1024, 16])] |
| SUCCESS | 0.017408 | 0.010240 | 1.700 | [torch.Size([1024, 256])] |
| SUCCESS | 0.026624 | 0.018432 | 1.444 | [torch.Size([1024, 4096])] |
| SUCCESS | 0.158720 | 0.106496 | 1.490 | [torch.Size([1024, 65536])] |
| SUCCESS | 0.009216 | 0.010240 | 0.900 | [torch.Size([64, 64, 1])] |
| SUCCESS | 0.032768 | 0.010240 | 3.200 | [torch.Size([64, 64, 16])] |
| SUCCESS | 0.019456 | 0.012288 | 1.583 | [torch.Size([64, 64, 256])] |
| SUCCESS | 0.055296 | 0.043008 | 1.286 | [torch.Size([64, 64, 4096])] |

Operator: argmin Performance Test (dtype=torch.float32, mode=cuda, level=comprehensive)

| Status | Torch Latency (ms) | Gems Latency (ms) | Gems Speedup | Size Detail |
| --- | --- | --- | --- | --- |
| SUCCESS | 0.020480 | 0.014336 | 1.429 | [torch.Size([1048576])] |
| SUCCESS | 0.010240 | 0.009216 | 1.111 | [torch.Size([64, 64])] |
| SUCCESS | 0.074752 | 0.066560 | 1.123 | [torch.Size([4096, 4096])] |
| SUCCESS | 0.074752 | 0.066560 | 1.123 | [torch.Size([64, 512, 512])] |
| SUCCESS | 2.638848 | 2.404352 | 1.098 | [torch.Size([1024, 1024, 1024])] |
| SUCCESS | 0.007168 | 0.009216 | 0.778 | [torch.Size([4])] |
| SUCCESS | 0.008192 | 0.009216 | 0.889 | [torch.Size([1024])] |
| SUCCESS | 2.638848 | 2.402304 | 1.098 | [torch.Size([1073741824])] |
| SUCCESS | 0.008192 | 0.009216 | 0.889 | [torch.Size([1024, 1])] |
| SUCCESS | 0.015360 | 0.010240 | 1.500 | [torch.Size([1024, 16])] |
| SUCCESS | 0.017408 | 0.011264 | 1.545 | [torch.Size([1024, 256])] |
| SUCCESS | 0.033792 | 0.027648 | 1.222 | [torch.Size([1024, 4096])] |
| SUCCESS | 0.214016 | 0.178176 | 1.201 | [torch.Size([1024, 65536])] |
| SUCCESS | 0.009216 | 0.009216 | 1.000 | [torch.Size([64, 64, 1])] |
| SUCCESS | 0.033792 | 0.010240 | 3.300 | [torch.Size([64, 64, 16])] |
| SUCCESS | 0.019456 | 0.014336 | 1.357 | [torch.Size([64, 64, 256])] |
| SUCCESS | 0.074752 | 0.066560 | 1.123 | [torch.Size([64, 64, 4096])] |

Operator: argmin Performance Test (dtype=torch.bfloat16, mode=cuda, level=comprehensive)

| Status | Torch Latency (ms) | Gems Latency (ms) | Gems Speedup | Size Detail |
| --- | --- | --- | --- | --- |
| SUCCESS | 0.019456 | 0.013312 | 1.462 | [torch.Size([1048576])] |
| SUCCESS | 0.010240 | 0.010240 | 1.000 | [torch.Size([64, 64])] |
| SUCCESS | 0.056320 | 0.044032 | 1.279 | [torch.Size([4096, 4096])] |
| SUCCESS | 0.056320 | 0.044032 | 1.279 | [torch.Size([64, 512, 512])] |
| SUCCESS | 1.954816 | 1.280000 | 1.527 | [torch.Size([1024, 1024, 1024])] |
| SUCCESS | 0.007168 | 0.009216 | 0.778 | [torch.Size([4])] |
| SUCCESS | 0.008192 | 0.009216 | 0.889 | [torch.Size([1024])] |
| SUCCESS | 1.972224 | 1.277952 | 1.543 | [torch.Size([1073741824])] |
| SUCCESS | 0.008192 | 0.009216 | 0.889 | [torch.Size([1024, 1])] |
| SUCCESS | 0.014336 | 0.010240 | 1.400 | [torch.Size([1024, 16])] |
| SUCCESS | 0.017408 | 0.010240 | 1.700 | [torch.Size([1024, 256])] |
| SUCCESS | 0.026624 | 0.019456 | 1.368 | [torch.Size([1024, 4096])] |
| SUCCESS | 0.162816 | 0.106496 | 1.529 | [torch.Size([1024, 65536])] |
| SUCCESS | 0.009216 | 0.010240 | 0.900 | [torch.Size([64, 64, 1])] |
| SUCCESS | 0.033792 | 0.010240 | 3.300 | [torch.Size([64, 64, 16])] |
| SUCCESS | 0.019456 | 0.012288 | 1.583 | [torch.Size([64, 64, 256])] |
| SUCCESS | 0.056320 | 0.044032 | 1.279 | [torch.Size([64, 64, 4096])] |

@iclementine iclementine self-assigned this Dec 9, 2024
```python
import triton.language as tl

from ..utils import libentry
from ..utils.shape_utils import can_use_int32_index
```
@iclementine iclementine (Collaborator) commented Dec 10, 2024:

Please follow the new changes in argmax and use triton_lang_extension, since we decided to use int64 indexing everywhere to prevent unexpected overflow.

We used to take a more conservative approach, computing the maximum element offset of a tensor, but we have now decided to keep it simple.

See also #327
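
For context, the switch to triton_lang_extension amounts to deriving block offsets from a 64-bit program id instead of bounding offsets via can_use_int32_index. A minimal sketch, assuming tle.program_id mirrors tl.program_id but returns int64 (the kernel is illustrative, not the PR's actual argmin kernel):

```python
import triton
import triton.language as tl

from ..utils import triton_lang_extension as tle


@triton.jit
def block_min_kernel(inp, out, n_elements, BLOCK_SIZE: tl.constexpr):
    # tle.program_id is assumed to behave like tl.program_id but return
    # int64, so the offset arithmetic below happens in 64 bits and cannot
    # overflow for tensors with more than 2**31 elements (e.g. the
    # 1073741824-element case in the benchmarks above).
    pid = tle.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(inp + offsets, mask=mask, other=float("inf"))
    # Write one partial minimum per block.
    tl.store(out + pid, tl.min(vals, axis=0))
```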

@wyjoutstanding (Author) replied:
DONE

@iclementine iclementine (Collaborator) left a comment:
LGTM

@iclementine iclementine (Collaborator) left a comment:
Please add a test case where dim is None.
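
A minimal sketch of such a case, assuming the suite's usual pattern of comparing against eager torch under the flag_gems.use_gems() context (shapes are illustrative):

```python
import pytest
import torch

import flag_gems


@pytest.mark.parametrize("shape", [(1024,), (64, 64), (64, 64, 16)])
def test_argmin_dim_none(shape):
    # With dim=None, argmin flattens the input and returns the index of
    # the global minimum as a 0-d tensor; randn makes ties improbable.
    inp = torch.randn(shape, dtype=torch.float32, device="cuda")
    ref = torch.argmin(inp)
    with flag_gems.use_gems():
        res = torch.argmin(inp)
    assert torch.equal(ref, res)
```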

@wyjoutstanding (Author) replied:

> Please add a test case where dim is None.
@wyjoutstanding wyjoutstanding changed the title add argmin(#318) add argmin (#318) Dec 27, 2024
@wyjoutstanding
Copy link
Author

@iclementine Rebased; please help review.

```python
from ..utils import libentry
from ..utils import triton_lang_extension as tle

torch_dtype_to_tl_dtype_and_max_value = {
```
A collaborator commented:

The formatting here fails the formatter checks. No space is needed before the colon; just add one after it.
Please run it through the pre-commit tooling:

`pip install pre-commit`, then run `pre-commit install` in the project directory.
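
For reference, the style the hooks enforce in a dict literal is no space before a colon and one space after. The entries below are illustrative of the mapping's apparent intent (each torch dtype paired with its Triton dtype and the dtype's maximum value, e.g. to initialize a running minimum):

```python
import torch
import triton.language as tl

# No space before the colon, exactly one space after it.
torch_dtype_to_tl_dtype_and_max_value = {
    torch.float16: (tl.float16, torch.finfo(torch.float16).max),
    torch.float32: (tl.float32, torch.finfo(torch.float32).max),
    torch.bfloat16: (tl.bfloat16, torch.finfo(torch.bfloat16).max),
}
```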
