
[Codegen] Add support to emulate unsupported float type #19943

Merged
merged 3 commits into iree-org:main on Feb 13, 2025

Conversation

pashu123 (Contributor) commented Feb 10, 2025

This change enables the conversion of types such as f8E4M3FNUZ and f8E5M2FNUZ into f32 operations, emulated via the existing upstream APIs. The conversion logic is tied to the executable target attribute, so it is applied only for the gfx942 target. This removes the need to manually configure the pass with source types and aligns the behaviour with the target's capabilities. To add a new conversion, populate the conversion target with the desired source and target types (see the sketch below).

FIX: #19921 (comment)
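For illustration only, here is a minimal sketch of the per-target type selection described above. The helper name populateSourceAndTargetTypes is hypothetical and the chipset check is an assumption; it is not the exact code in this PR.

// Hypothetical helper: choose which float types get emulated for a given
// target architecture. Only gfx942 is wired up for now.
static void populateSourceAndTargetTypes(MLIRContext *ctx, StringRef arch,
                                         SmallVectorImpl<Type> &sourceTypes,
                                         Type &targetType) {
  targetType = Float32Type::get(ctx);
  if (arch == "gfx942") {
    sourceTypes.push_back(Float8E4M3FNUZType::get(ctx));
    sourceTypes.push_back(Float8E5M2FNUZType::get(ctx));
  }
  // Supporting another target amounts to appending its unsupported source
  // types (and, if needed, a different target type) here.
}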

hanhanW (Contributor) left a comment:

I think it does not fix anything because you do not add the pass to any pipeline. Can you add an e2e test to https://github.com/iree-org/iree/tree/main/tests/e2e/linalg?

Also, we need better documentation for the pass and a better PR description.

E.g., people who don't have context might ask "what is an unsupported float type?"

pashu123 (Contributor Author) replied, quoting the comment above:

You're right! After further thought, I'll wrap those emulate APIs in a function and call it here:

arith::populateExpandBFloat16Patterns(patterns);

rather than creating a pass.

krzysz00 (Contributor) commented:

@pashu123 Because you need to run --arith-to-amdgpu to get the FP8 conversion instructions, you shouldn't just add the rewrites to ConvertToROCDL like that.

Given how things are architected today, this sort of thing likely needs to be a pass.

See upstream's --arith-emulate-unsupported-floats.

pashu123 (Contributor Author) commented Feb 10, 2025, quoting the comment above:

If you look at the current pass, it is already using the APIs from the upstream --arith-emulate-unsupported-floats pass. Would you like me to tailor the downstream pass to a subset of arith operations?
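For context, here is a minimal sketch of how a downstream pass can reuse those upstream emulate-unsupported-floats helpers. It illustrates the upstream arith APIs under stated assumptions (pass class name, type list) and is not the exact code in this PR.

void ConvertUnsupportedFloatArithPass::runOnOperation() {
  MLIRContext *context = &getContext();
  Operation *op = getOperation();

  // Types to emulate and the wider type to emulate them with.
  SmallVector<Type> sourceTypes = {Float8E4M3FNUZType::get(context),
                                   Float8E5M2FNUZType::get(context)};
  Type targetType = Float32Type::get(context);

  // Reuse the machinery behind --arith-emulate-unsupported-floats: operands
  // of arith ops on a source type are extf'd to the target type and results
  // are truncf'd back.
  TypeConverter converter;
  arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes,
                                                     targetType);
  RewritePatternSet patterns(context);
  arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter);
  ConversionTarget target(*context);
  arith::populateEmulateUnsupportedFloatsLegality(target, converter);

  if (failed(applyPartialConversion(op, target, std::move(patterns))))
    signalPassFailure();
}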

Review comment on this part of the new pass:

Operation *op = getOperation();

// Add the source types to be converted to the target type.
SmallVector<Type> sourceTypes = {Float8E4M3FNUZType::get(context)};
A contributor commented:

As a thought, could we merge the BF16 -> F32 arithmetic pass into this one?

Also, you'll want the E5M2 type here as well.

Converts arith operations on unsupported float types to f32.
Comment on lines 1129 to 1136
funcPassManager.addPass([&]() {
ConvertUnsupportedFloatArithPassOptions options;
// Convert arith operation with the given `source types` to `target`
// type.
options.sourceTypeStrs = {"f8E4M3FNUZ", "f8E5M2FNUZ"};
options.targetTypeStr = "f32";
return createConvertUnsupportedFloatArithPass(options);
});
A contributor commented:

The notion of which type is "supported" depends on the combination of (target architecture, specific operation, specific operand for that operation). For example, on CDNA3, we have matrix multiplication instructions that handle f8E4M3FNUZ and f8E5M2FNUZ for the LHS and RHS operands, but not for the accumulator operand. How is that nuance reflected in the logic being introduced in this PR?

pashu123 (Contributor Author) replied:

The upstream pass (https://github.com/llvm/llvm-project/blob/b04a980b5597c61a8df2b489c4894bc0240b8e13/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp#L122) doesn't touch those operations for now, but in the future we should add downstream patterns that decide what to look for based on the target. I should add the pass under a flag that checks the target.

A contributor replied:

The target should be looked up from the target attribute on the op that the pass is running on, not from a flag.

In the pass's runOnOperation() method, do something like auto target = IREE::HAL::ExecutableTargetAttr::lookup(getOperation());

That means that the list of unsupported types shouldn't be a pass option. The pass should be option-less and act on the executable target. Similar to what this PR did (also see its test specifying targets): #19922
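To illustrate the suggestion, here is a sketch of an option-less runOnOperation() that derives everything from the executable target. The "iree.gpu.target" key and the arch field mirror the test attribute shown later in this thread; the exact accessor names are assumptions, not the code this PR landed.

void ConvertUnsupportedFloatArithPass::runOnOperation() {
  Operation *op = getOperation();

  // Look up the executable target attribute from the op the pass runs on.
  auto execTarget = IREE::HAL::ExecutableTargetAttr::lookup(op);
  if (!execTarget)
    return;

  // Read the GPU target (e.g. arch = "gfx942") from the target configuration
  // and bail out on targets that need no emulation.
  DictionaryAttr config = execTarget.getConfiguration();
  auto gpuTarget =
      config ? config.getAs<IREE::GPU::TargetAttr>("iree.gpu.target")
             : IREE::GPU::TargetAttr();
  if (!gpuTarget || !gpuTarget.getArch().starts_with("gfx94"))
    return;

  // From here, populate the source/target types for this architecture and
  // run the upstream emulation patterns as sketched earlier in the thread.
}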

pashu123 (Contributor Author) replied:

Thanks for the info, that was helpful. I've updated the pass to key off the target architecture. For now it only populates the gfx94{*} source and target types, and it can be extended as needed.

Comment on the new lit test:

// CHECK: %[[NEG:.*]] = arith.negf %[[EXT]] : f32
// CHECK: %[[TRUNC:.*]] = arith.truncf %[[NEG]] {{.*}} : f32 to f8E4M3FNUZ
// CHECK: return %[[TRUNC]] : f8E4M3FNUZ
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
A contributor commented:

If possible, find ways to elide some irrelevant parts of this target attribute. For instance, the mma array could be empty. (Side note: we should make wgp optional for things like this.)

pashu123 (Contributor Author) replied:

Sure.

hanhanW (Contributor) left a comment:

Please update the PR description with a detailed description of the pass and of what this PR supports, e.g. the source and target conversion types for the gfx94{*} series.

pashu123 merged commit 0ff26a7 into iree-org:main on Feb 13, 2025
40 checks passed
Successfully merging this pull request may close these issues.

[ROCm][Codegen] llama 8b fp8 with attention segfault