-
Notifications
You must be signed in to change notification settings - Fork 663
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Codegen] Add support to emulate unsupported float type #19943
Conversation
Signed-off-by: Chi Liu<[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it does not fix anything because you do not add the pass to any pipeline.. Can you add an e2e test to https://github.com/iree-org/iree/tree/main/tests/e2e/linalg?
Also, we need better documentation (and description) for the pass and the PR description.
E.g., people who dont have context might ask "what is unsupported float type"
You're right! After further thinking I'll wrap those emulate APIs in a fn and call it here:
|
@pashu123 Because you need to run This sort of thing does, with how stuff is architected today, likely need to be a pass. See upstreams |
If you look at the current pass, it's using APIs from --arith-emulate-unsupported-floats pass. Would you like me to tailor the downstream pass for a subset of arith operations? |
Operation *op = getOperation(); | ||
|
||
// Add the source types to be converted to the target type. | ||
SmallVector<Type> sourceTypes = {Float8E4M3FNUZType::get(context)}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As a thought, could we merge the BF16 -> F32 arithmetic pass into this one?
Also, you'll also want the E5M2 type here
Converts arith operations on unsupported float types to f32.
funcPassManager.addPass([&]() { | ||
ConvertUnsupportedFloatArithPassOptions options; | ||
// Convert arith operation with the given `source types` to `target` | ||
// type. | ||
options.sourceTypeStrs = {"f8E4M3FNUZ", "f8E5M2FNUZ"}; | ||
options.targetTypeStr = "f32"; | ||
return createConvertUnsupportedFloatArithPass(options); | ||
}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The notion of which type is "supported" depends on the combination of (target architecture, specific operation, specific operand for that operation). For example, on CDNA3, we have matrix multiplication instructions that handle f8E4M3FNUZ and f8E5M2FNUZ for the LHS and RHS operands, but not for the accumulator operand. How is that nuance reflected in the logic being introduced in this PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://github.com/llvm/llvm-project/blob/b04a980b5597c61a8df2b489c4894bc0240b8e13/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp#L122 This doesn't touch those operations for now but in future should add the pattern downstream of what it should look for based on the target. I should add the pass under a flag checking the target.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The target should be looked up from the target attribute on the op that the pass is running on, not from a flag.
In the pass's runOnOperation() method, do something like auto target = IREE::HAL::ExecutableTargetAttr::lookup(getOperation());
.
That means that the list of unsupported types shouldn't be a pass option. The pass should be option-less and act on the executable target. Similar to what this PR did (also see its test specifying targets): #19922
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the info! It was constructive. I've updated the pass based on the target arch. This only populates gfx94{*}
source and target types for now and can be updated based on the need.
compiler/src/iree/compiler/Codegen/Common/ConvertUnsupportedFloatArithPass.cpp
Outdated
Show resolved
Hide resolved
compiler/src/iree/compiler/Codegen/Common/ConvertUnsupportedFloatArithPass.cpp
Outdated
Show resolved
Hide resolved
compiler/src/iree/compiler/Codegen/Common/ConvertUnsupportedFloatArithPass.cpp
Outdated
Show resolved
Hide resolved
compiler/src/iree/compiler/Codegen/Common/ConvertUnsupportedFloatArithPass.cpp
Outdated
Show resolved
Hide resolved
compiler/src/iree/compiler/Codegen/Common/ConvertUnsupportedFloatArithPass.cpp
Outdated
Show resolved
Hide resolved
compiler/src/iree/compiler/Codegen/Common/ConvertUnsupportedFloatArithPass.cpp
Outdated
Show resolved
Hide resolved
compiler/src/iree/compiler/Codegen/Common/ConvertUnsupportedFloatArithPass.cpp
Outdated
Show resolved
Hide resolved
compiler/src/iree/compiler/Codegen/Common/ConvertUnsupportedFloatArithPass.cpp
Outdated
Show resolved
Hide resolved
0fd49a2
to
1966456
Compare
// CHECK: %[[NEG:.*]] = arith.negf %[[EXT]] : f32 | ||
// CHECK: %[[TRUNC:.*]] = arith.truncf %[[NEG]] {{.*}} : f32 to f8E4M3FNUZ | ||
// CHECK: return %[[TRUNC]] : f8E4M3FNUZ | ||
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If possible, find ways to elide some irrelevant parts of this target attribute. For instance, the mma
array could be empty. (Side note: we should make wgp
optional for things like this.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please update the PR description with detailed pass description and what is supported in the PR. E.g., Add what source and target conversion types for gfx94{*} series.
This change enables the conversion of types such as f8E4M3FNUZ and f8E5M2FNUZ (emulated via the existing APIs) into f32 operations. The conversion logic is now tightly coupled with the executable target attribute, so that it is applied only for the gfx942 target. This removes the need for manual pass configuration to specify source types and aligns the behaviour with the target’s capabilities. For any new conversion, just populate the conversion target with source and target types.
FIX: #19921 (comment)