Skip to content

Implement EmbeddingBag #3596

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
Closed

Implement EmbeddingBag #3596

wants to merge 11 commits into from

Conversation

cognaiger9
Copy link
Collaborator

  • Add EmbeddingBag operation with forward kernels.
  • Add driver and gtest for kernels.
  • MIOpen performs better if:
    • Mode: Max
    • Mode: Mean or Sum, when the tensor type is float and all tensors are contiguous and number of elements in the output exceeds 2^19
  • Note:
    • Forward solver only works with 2D tensor

Average improvement over ROCm

type fwd
float16 1.38
float 1.3
bfloat16 1.4

Detail Benchmark

float16 (mode max)
op_name dtype input size weight size cont direction ROCm MIOpen Improvement
EmbeddingBag float16 [32 655] [100 256] cont fwd 320271 253759 1.26
EmbeddingBag float16 [512 512] [256 1024] cont fwd 1198653 846078 1.42
EmbeddingBag float16 [512 512] [256 1024] noncont fwd 1210990 1012089 1.20
EmbeddingBag float16 [512 512] [512 2048] cont fwd 2385852 1520730 1.57
EmbeddingBag float16 [512 512] [512 2048] noncont fwd 2394267 1838890 1.30
EmbeddingBag float16 [1024 512] [256 1024] cont fwd 2391292 1555550 1.54
EmbeddingBag float16 [1024 512] [512 2048] cont fwd 4765832 2955600 1.61
EmbeddingBag float16 [1024 512] [512 2048] noncont fwd 4791704 3989040 1.20
EmbeddingBag float16 [768 768] [256 1024] cont fwd 2684508 1778770 1.51
EmbeddingBag float16 [768 768] [256 1024] noncont fwd 2706427 2111200 1.28
EmbeddingBag float16 [768 768] [512 2048] cont fwd 5352231 3353980 1.60
EmbeddingBag float16 [768 768] [512 2048] noncont fwd 5376823 4098639 1.31
EmbeddingBag float16 [256 256] [512 2048] cont fwd 607119 422861 1.44
EmbeddingBag float16 [256 256] [512 2048] noncont fwd 612143 508318 1.20
EmbeddingBag float16 [16 255] [100 256] cont fwd 129472 108169 1.20
float32 (mode max)
op_name dtype input size weight size cont direction ROCm MIOpen Improvement
EmbeddingBag float32 [512 512] [256 1024] cont fwd 1191598 863910 1.38
EmbeddingBag float32 [512 512] [256 1024] noncont fwd 1191582 1029700 1.16
EmbeddingBag float32 [512 512] [512 2048] cont fwd 2364124 1556490 1.52
EmbeddingBag float32 [512 512] [512 2048] noncont fwd 2376844 1863040 1.28
EmbeddingBag float32 [1024 512] [256 1024] cont fwd 2381244 1587020 1.50
EmbeddingBag float32 [1024 512] [512 2048] cont fwd 4769672 3037880 1.57
EmbeddingBag float32 [1024 512] [512 2048] noncont fwd 4799640 3896120 1.23
EmbeddingBag float32 [768 768] [256 1024] cont fwd 2664284 1819250 1.46
EmbeddingBag float32 [768 768] [256 1024] noncont fwd 2684236 2146680 1.25
EmbeddingBag float32 [768 768] [512 2048] cont fwd 5335863 3436330 1.55
EmbeddingBag float32 [768 768] [512 2048] noncont fwd 5371095 4152190 1.29
EmbeddingBag float32 [256 256] [512 2048] cont fwd 602367 435573 1.38
EmbeddingBag float32 [256 256] [512 2048] noncont fwd 607999 517136 1.18
EmbeddingBag float32 [128 512] [512 2048] cont fwd 600879 528356 1.14
bfloat16 (mode max)
op_name dtype input size weight size cont direction ROCm MIOpen Improvement
EmbeddingBag bfloat16 [512 512] [256 1024] cont fwd 1188750 834043 1.43
EmbeddingBag bfloat16 [512 512] [256 1024] noncont fwd 1188286 1010610 1.18
EmbeddingBag bfloat16 [512 512] [512 2048] cont fwd 2360604 1491550 1.58
EmbeddingBag bfloat16 [512 512] [512 2048] noncont fwd 2360636 1816100 1.30
EmbeddingBag bfloat16 [1024 512] [256 1024] cont fwd 2373468 1527450 1.55
EmbeddingBag bfloat16 [1024 512] [512 2048] cont fwd 4721576 2911920 1.62
EmbeddingBag bfloat16 [1024 512] [512 2048] noncont fwd 4735288 4048420 1.17
EmbeddingBag bfloat16 [768 768] [100 256] cont fwd 637519 525137 1.21
EmbeddingBag bfloat16 [768 768] [256 1024] cont fwd 2662716 1746410 1.52
EmbeddingBag bfloat16 [768 768] [256 1024] noncont fwd 2681467 2100390 1.28
EmbeddingBag bfloat16 [768 768] [512 2048] cont fwd 5294503 3315640 1.60
EmbeddingBag bfloat16 [768 768] [512 2048] noncont fwd 5323143 4110299 1.30
EmbeddingBag bfloat16 [256 256] [512 2048] cont fwd 600703 416995 1.44
float32 (mode mean)
op_name dtype size direction cont direction ROCm MIOpen Improvement
EmbeddingBag float32 [512 512] [256 1024] cont fwd 1154688 955314 1.21
EmbeddingBag float32 [512 512] [512 2048] cont fwd 2307201 1764120 1.31
EmbeddingBag float32 [1024 512] [256 1024] cont fwd 2303857 1798470 1.28
EmbeddingBag float32 [1024 512] [512 2048] cont fwd 4597474 3431640 1.34
EmbeddingBag float32 [768 768] [512 2048] cont fwd 2588381 2049459 1.26
EmbeddingBag float32 [768 768] [512 2048] cont fwd 5169386 3869800 1.34
EmbeddingBag float32 [256 256] [512 2048] cont fwd 589464 481217 1.22
EmbeddingBag float32 [16 255] [512 2048] cont fwd 125583 114259 1.10
EmbeddingBag float32 [16 255] [512 2048] cont fwd 142318 126775 1.12
EmbeddingBag float32 [128 128] [100 256] cont fwd 72623 65636 1.11

@BradPepersAMD
Copy link
Collaborator

MIOpen is moving to the new monorepo setup and all older unmerged PR's are being closed. Please re-open this as part of the new repo if these changes are still needed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants