
Commit 7d060db

SS-JIA and ssjia authored
[ET-VK] Re-implement split_with_sizes (#15793)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* #15829
* #15796
* #15795
* #15794
* __->__ #15793

As title. The current implementation of split_with_sizes uses functions from the `Copy.[h|cpp]` file, in particular `add_copy_channel_offset_node`. However, the shaders dispatched by this function have a critical bug: the output tensor is bound multiple times with different access types, i.e.

```cpp
graph.execute_nodes().emplace_back(new DispatchNode(
    graph,
    VK_KERNEL_FROM_STR(kernel_name),
    global_size,
    local_size,
    // Inputs and Outputs
    {
        {out, vkapi::kWrite},
        {out, vkapi::kRead},
        {in, vkapi::kRead},
    },
```

This produces many validation layer errors because the memory barriers for the resource cannot be formed properly; the shader essentially relies on undefined behaviour to work correctly.

To fix this, this diff re-implements the operator from scratch with a dedicated compute shader.

Differential Revision: [D86910642](https://our.internmc.facebook.com/intern/diff/D86910642/)

---------

Co-authored-by: ssjia <[email protected]>
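For illustration only (this is not code taken from the diff): with a dedicated shader, each resource can be bound exactly once with a single access type, so the graph can insert a correct memory barrier for the output. A minimal sketch of such a binding list, assuming the same `DispatchNode` API as in the snippet above:

```cpp
// Hypothetical sketch, not from this diff: each resource appears exactly
// once, with one access type, so a correct barrier can be formed for `out`.
graph.execute_nodes().emplace_back(new DispatchNode(
    graph,
    VK_KERNEL_FROM_STR(kernel_name),
    global_size,
    local_size,
    // Inputs and Outputs
    {
        {out, vkapi::kWrite},
        {in, vkapi::kRead},
    },
```

(The argument list is truncated here just as in the snippet above.)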
1 parent 3af3385 commit 7d060db

20 files changed: +295 −1291 lines

backends/vulkan/op_registry.py

Lines changed: 1 addition & 3 deletions
@@ -740,6 +740,7 @@ def register_cat_op():
     [
         exir_ops.edge.aten.select_copy.int,
         exir_ops.edge.aten.slice_copy.Tensor,
+        exir_ops.edge.aten.split_with_sizes_copy.default,
     ]
 )
 def register_transfer_ops():

@@ -782,10 +783,7 @@ def register_ported_op():
 # Ops ported from PyTorch Vulkan backend. These ops are in a separate registry because they support all packed dimensions
 @update_features(
     [
-        # Tensor combination
         exir_ops.edge.aten.repeat.default,
-        exir_ops.edge.aten.split_with_sizes_copy.default,
-        exir_ops.edge.aten.split.Tensor,
     ]
 )
 def register_ported_op_all_packed_dims():

backends/vulkan/runtime/graph/ops/glsl/common.glslh

Lines changed: 11 additions & 0 deletions
@@ -86,4 +86,15 @@ int quantize_and_pack(const vec4 vals, const float inv_scale, const int zp) {
   return pack_into_int32(quantized);
 }
 
+#ifdef DEBUG_MODE
+
+#define printf debugPrintfEXT
+
+void printVec4(vec4 texel) {
+  debugPrintfEXT(
+      "texel: %f, %f, %f, %f\\n", texel.x, texel.y, texel.z, texel.w);
+}
+
+#endif // DEBUG_MODE
+
 #endif // COMMON_GLSLH
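A side note on the debug helpers added above: `debugPrintfEXT` only produces output when the shader enables the `GL_EXT_debug_printf` extension and the application enables the Debug Printf feature of the Khronos validation layer. A minimal, hypothetical host-side sketch (not part of this diff) of enabling it at instance creation:

```cpp
// Hypothetical sketch, not from this diff: enable Debug Printf so that
// debugPrintfEXT() calls from shaders (e.g. printVec4 above) are reported
// by the validation layer.
#include <vulkan/vulkan.h>

VkValidationFeatureEnableEXT enables[] = {
    VK_VALIDATION_FEATURE_ENABLE_DEBUG_PRINTF_EXT,
};

VkValidationFeaturesEXT validation_features = {};
validation_features.sType = VK_STRUCTURE_TYPE_VALIDATION_FEATURES_EXT;
validation_features.enabledValidationFeatureCount = 1;
validation_features.pEnabledValidationFeatures = enables;

VkInstanceCreateInfo instance_info = {};
instance_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
instance_info.pNext = &validation_features;
// ... the rest of instance creation (application info, enabling the
// "VK_LAYER_KHRONOS_validation" layer, vkCreateInstance) is omitted.
```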

backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.glsl

Lines changed: 0 additions & 80 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml

Lines changed: 0 additions & 12 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl

Lines changed: 0 additions & 68 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml

Lines changed: 0 additions & 17 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.glsl

Lines changed: 0 additions & 135 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml

Lines changed: 0 additions & 12 deletions
This file was deleted.
