diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index feba4f6f072..3dc873ac21c 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -793,17 +793,35 @@ def register_ported_op_all_packed_dims():
 
 
 # Ported ops that support their own prepacking.
-@update_features(
-    [
-        exir_ops.edge.aten.embedding.default,
-        exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
-    ]
-)
-def register_ported_ops_with_prepacking():
+@update_features(exir_ops.edge.aten.embedding.default)
+def register_embedding_op():
+    return OpFeatures(
+        inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
+        supports_prepacking=True,
+        supports_resize=True,
+    )
+
+
+@update_features(exir_ops.edge.aten._native_batch_norm_legit_no_training.default)
+def register_batch_norm_op():
+    def check_batch_norm_node(node: torch.fx.Node) -> bool:
+        x = node.args[0]
+        if not isinstance(x, torch.fx.Node):
+            return False
+        x_val = x.meta.get("val", None)
+        if x_val is None:
+            return False
+        x_shape = x_val.size()
+        # Only support 4-D input tensors, since this restriction is enforced
+        # by the operator implementation.
+        # TODO(ssjia): Add shape-agnostic support for batch norm
+        return len(x_shape) == 4
+
     return OpFeatures(
         inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
         supports_prepacking=True,
         supports_resize=True,
+        are_node_inputs_supported_fn=check_batch_norm_node,
     )
 
 
diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
index 6e638a3275c..8cae626f614 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
@@ -64,6 +64,9 @@ $else:
 #include "broadcasting_utils.h"
 #include "indexing_utils.h"
 
+$if MASK_PADDING:
+  #define MASK_PADDING
+
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
 ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
@@ -140,11 +143,29 @@ void main() {
     other_texel = other_texel.xxxx;
   }
 
-  write_texel_lpos(
-    t_out,
-    lpos,
-    VEC4_OUT_T(op(in_texel, other_texel, alpha)),
-    out_axis_map);
+  VEC4_OUT_T out_texel = VEC4_OUT_T(op(in_texel, other_texel, alpha));
+
+#ifdef MASK_PADDING
+  // Handle padding elements in the last texel to prevent NaN propagation.
+  // When the packed dimension size is not a multiple of 4, the last texel
+  // has padding elements. For division ops, a padding lane computes 0 / 0,
+  // which produces NaN values that propagate through subsequent reductions.
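+  // Worked example (illustrative): for a packed-dim size of 3, nspill ==
+  // mod4(3) == 3 and texels_per_batch == divup4(3) == 1, so component 3 of
+  // the lone texel is padding and is zeroed below (this assumes mod4 /
+  // divup4 keep their usual remainder / ceiling-division semantics).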
+  const int nspill = mod4(out_sizes[packed_dim]);
+
+  if (nspill > 0) {
+    const int texels_per_batch = divup4(out_sizes[packed_dim]);
+    const bool is_last_texel = (lpos[packed_dim] % texels_per_batch) == (texels_per_batch - 1);
+
+    if (is_last_texel) {
+      // Explicitly set padding elements to 0 to avoid NaN
+      [[unroll]] for (int i = nspill; i < 4; i++) {
+        out_texel[i] = 0;
+      }
+    }
+  }
+#endif
+
+  write_texel_lpos(t_out, lpos, out_texel, out_axis_map);
 }
 
 #endif
diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml
index 70793628d80..ee96b5c05b4 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml
@@ -10,6 +10,7 @@ binary_op:
     NDIM: 3
     DTYPE: float
     PACKING: C_packed
+    MASK_PADDING: 0
   generate_variant_forall:
     STORAGE:
       - VALUE: texture3d
@@ -26,10 +27,12 @@ binary_op:
       OPERATOR: X * Y
     - NAME: binary_div
      OPERATOR: X / Y
+      MASK_PADDING: 1
    - NAME: binary_pow
      OPERATOR: pow(X, Y)
    - NAME: binary_floor_divide
      OPERATOR: floor(X / Y)
+      MASK_PADDING: 1
    - NAME: binary_minimum
      OPERATOR: min(X, Y)
    - NAME: binary_eq_int32
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index 03a3263c293..2c0bc12b7cc 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -1061,6 +1061,24 @@ def forward(self, x):
             sample_inputs,
         )
 
+    def test_vulkan_backend_batch_norm_after_linear(self):
+        class LinearBatchNormModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(128, 64)
+                self.bn = torch.nn.BatchNorm1d(num_features=64)
+
+            def forward(self, x):
+                x = self.linear(x)
+                return self.bn(x)
+
+        sample_inputs = (torch.randn(size=(4, 128), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            LinearBatchNormModule(),
+            sample_inputs,
+        )
+
     def test_vulkan_backend_full(self):
         class FullModule(torch.nn.Module):
             def __init__(self):
@@ -1767,6 +1785,52 @@ def forward(self, x):
             (torch.randn(size=(1, 6, 40, 50), dtype=torch.float32),),
         )
 
+    def test_vulkan_backend_div_with_padding_nan_propagation(self):
+        """
+        Test division operations with non-multiple-of-4 channels followed by convolution.
+
+        This test verifies the fix for NaN propagation in padding texels during division.
+        When the packed dimension (channels=3) is not a multiple of 4, texture-backed
+        tensors have padding elements in their last texel. Without masking, division
+        produces NaN values (from 0/0) in the padding regions, which then propagate
+        through subsequent operations such as convolution, corrupting the results.
+
+        This simulates a common real-world pattern: per-channel image normalization
+        (subtract mean, divide by std) followed by convolution.
+        """
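+        # Illustrative note (assumption, mirroring the shader-side fix): along
+        # the packed channel dim the last texel holds [c0, c1, c2, pad]; the
+        # pad lane computes 0 / 0 == NaN, and a following conv2d would spread
+        # that NaN to every output its receptive field touches.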
+ """ + + class NormalizationConvModule(torch.nn.Module): + def __init__(self): + super().__init__() + # Per-channel mean and std for normalization (shape: [1, 3, 1, 1]) + self.mean = torch.tensor([[[[0.485]], [[0.456]], [[0.406]]]]) + self.std = torch.tensor([[[[0.229]], [[0.224]], [[0.215]]]]) + + # Conv2d layer to process normalized image + self.conv = torch.nn.Conv2d( + in_channels=3, # Non-multiple-of-4 to trigger padding + out_channels=16, + kernel_size=3, + padding=1, + stride=1, + bias=True, + ) + + def forward(self, x): + # Simulate image normalization: (x - mean) / std + # This is where NaN could appear in padding texels without the fix + x = x - self.mean + x = x / self.std + # Convolution operation that would be corrupted by NaN propagation + x = self.conv(x) + return x + + module = NormalizationConvModule() + # Use a typical image tensor size [batch=1, channels=3, height=256, width=256] + sample_inputs = (torch.randn(size=(1, 3, 256, 256), dtype=torch.float32),) + + self.lower_module_and_test_output(module, sample_inputs) + def test_vulkan_backend_grid_priors(self): class GridPriorsModule(torch.nn.Module): def __init__(self):