Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule]Fix the bug when loading database_tuning_record.json if there is pad_einsum primitive #17413

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

YXY-0922
Copy link
Contributor

When loading from database_tuning_record.json in Meta Schedule (this line: B_reindex_pad_shared_dyn[v0, v1] = T.if_then_else(v0 < 1, B[v1, v0], T.float16(0))), the parameter dtype of the primitive pad_einsum is read as int64, causing a block iterator v0 that should be int32 to be inferred as int64. This results in an InternalError: Check failed: (ret_ex.dtype() == var.dtype()) is false: substituting v0:int32 -> v0:int64. This commit performs a type conversion within UnpackedApplyToSchedule to fix this bug.

for ax0_ax1_fused in range(2048):
    with T.block("B_reindex_pad_shared.dyn"):
        v0 = T.axis.spatial(16, ax0_ax1_fused // 128)
        v1 = T.axis.spatial(4096, ax2_0_0 * 128 + ax0_ax1_fused % 128)
        T.reads(B[v1, v0])
        T.writes(B_reindex_pad_shared_dyn[v0, v1])
        T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1})
        B_reindex_pad_shared_dyn[v0, v1] = T.if_then_else(v0 < 1, B[v1, v0], T.float16(0))

…edule, the parameter dtype of the primitive pad_einsum is read as int64, causing a block iterator v0 that should be int32 to be inferred as int64. This results in an InternalError: Check failed: (ret_ex.dtype() == var.dtype()) is false: substituting v0:int32 -> v0:int64. This commit performs a type conversion within UnpackedApplyToSchedule to fix this bug.
@YXY-0922 YXY-0922 changed the title [MetaSchedule]Fix the bug when load database_tuning_record.json if there is pad_einsum primitive. [MetaSchedule]Fix the bug when loading database_tuning_record.json if there is pad_einsum primitive. Sep 24, 2024
@YXY-0922 YXY-0922 changed the title [MetaSchedule]Fix the bug when loading database_tuning_record.json if there is pad_einsum primitive. [MetaSchedule]Fix the bug when loading database_tuning_record.json if there is pad_einsum primitive Sep 24, 2024
@Hzfengsy
Copy link
Member

Hzfengsy commented Oct 2, 2024

Sorry for the late response, and thanks for the great catch. One thing is that did you test if the TIR function itself is i64-based? It would be good if you could add test cases

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants