Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,17 +793,35 @@ def register_ported_op_all_packed_dims():


# Ported ops that support their own prepacking.
@update_features(
[
exir_ops.edge.aten.embedding.default,
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
]
)
def register_ported_ops_with_prepacking():
@update_features(exir_ops.edge.aten.embedding.default)
def register_embedding_op():
    """Build the feature description for the edge ``aten.embedding`` op.

    The op is declared as taking channels-packed texture inputs, and as
    supporting both prepacking and resize.
    """
    embedding_features = OpFeatures(
        inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
        supports_prepacking=True,
        supports_resize=True,
    )
    return embedding_features


@update_features(exir_ops.edge.aten._native_batch_norm_legit_no_training.default)
def register_batch_norm_op():
    """Build the feature description for the no-training batch norm op.

    Inputs are declared as channels-packed textures, with prepacking and
    resize supported. A per-node check additionally restricts support to
    nodes whose first argument is a 4-D tensor.
    """

    def check_batch_norm_node(node: torch.fx.Node) -> bool:
        """Return True only if the node's first argument is a 4-D tensor node."""
        input_node = node.args[0]
        if not isinstance(input_node, torch.fx.Node):
            return False

        # A missing "val" entry means the input shape cannot be inspected,
        # so the node is reported as unsupported.
        input_val = input_node.meta.get("val")

        # Only support 4-D input tensors since this is a restriction enforced
        # by the operator implementation.
        # TODO(ssjia): Add shape agnostic support for batch norm
        return input_val is not None and len(input_val.size()) == 4

    batch_norm_features = OpFeatures(
        inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
        supports_prepacking=True,
        supports_resize=True,
        are_node_inputs_supported_fn=check_batch_norm_node,
    )
    return batch_norm_features


Expand Down
31 changes: 26 additions & 5 deletions backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ $else:
#include "broadcasting_utils.h"
#include "indexing_utils.h"

$if MASK_PADDING:
#define MASK_PADDING

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
Expand Down Expand Up @@ -140,11 +143,29 @@ void main() {
other_texel = other_texel.xxxx;
}

write_texel_lpos(
t_out,
lpos,
VEC4_OUT_T(op(in_texel, other_texel, alpha)),
out_axis_map);
VEC4_OUT_T out_texel = VEC4_OUT_T(op(in_texel, other_texel, alpha));

#ifdef MASK_PADDING
// Handle padding elements in the last texel to prevent NaN propagation.
// When the packed dimension size is not a multiple of 4, the last texel
// will have padding elements. For division operations, padding elements
// (which are 0/0) can produce NaN values that propagate through reductions.
const int nspill = mod4(out_sizes[packed_dim]);

if (nspill > 0) {
const int texels_per_batch = divup4(out_sizes[packed_dim]);
const bool is_last_texel = (lpos[packed_dim] % texels_per_batch) == (texels_per_batch - 1);

if (is_last_texel) {
// Explicitly set padding elements to 0 to avoid NaN
[[unroll]] for (int i = nspill; i < 4; i++) {
out_texel[i] = 0;
}
}
}
#endif

write_texel_lpos(t_out, lpos, out_texel, out_axis_map);
}

#endif
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ binary_op:
NDIM: 3
DTYPE: float
PACKING: C_packed
MASK_PADDING: 0
generate_variant_forall:
STORAGE:
- VALUE: texture3d
Expand All @@ -26,10 +27,12 @@ binary_op:
OPERATOR: X * Y
- NAME: binary_div
OPERATOR: X / Y
MASK_PADDING: 1
- NAME: binary_pow
OPERATOR: pow(X, Y)
- NAME: binary_floor_divide
OPERATOR: floor(X / Y)
MASK_PADDING: 1
- NAME: binary_minimum
OPERATOR: min(X, Y)
- NAME: binary_eq_int32
Expand Down
64 changes: 64 additions & 0 deletions backends/vulkan/test/test_vulkan_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,6 +1061,24 @@ def forward(self, x):
sample_inputs,
)

def test_vulkan_backend_batch_norm_after_linear(self):
    """Lower a Linear -> BatchNorm1d module and check its output.

    The batch norm layer here receives the 2-D ``(4, 64)`` output of the
    linear layer rather than a 4-D tensor.
    """

    class LinearBatchNormModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(128, 64)
            self.bn = torch.nn.BatchNorm1d(num_features=64)

        def forward(self, x):
            return self.bn(self.linear(x))

    # Draw the sample input before constructing the module so the random
    # number generator is consumed in the same order as before.
    batch = torch.randn(size=(4, 128), dtype=torch.float32)

    self.lower_module_and_test_output(LinearBatchNormModule(), (batch,))

def test_vulkan_backend_full(self):
class FullModule(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -1767,6 +1785,52 @@ def forward(self, x):
(torch.randn(size=(1, 6, 40, 50), dtype=torch.float32),),
)

def test_vulkan_backend_div_with_padding_nan_propagation(self):
    """
    Check division on a non-multiple-of-4 channel count followed by a convolution.

    This guards the fix for NaN propagation from padding elements during division.
    With channels=3 as the packed dimension, texture-backed tensors carry padding
    elements in the last texel. If those are not masked, division yields NaN (0/0)
    in the padding, and the NaN then flows into later operations such as
    convolution and corrupts their results.

    The module mirrors a common real-world pattern: per-channel image
    normalization (subtract mean, divide by std) followed by a convolution.
    """

    class NormalizationConvModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Per-channel normalization constants, each of shape [1, 3, 1, 1].
            self.mean = torch.tensor([[[[0.485]], [[0.456]], [[0.406]]]])
            self.std = torch.tensor([[[[0.229]], [[0.224]], [[0.215]]]])

            # Convolution applied to the normalized image. Three input
            # channels is deliberately not a multiple of 4, to force padding.
            self.conv = torch.nn.Conv2d(
                in_channels=3,
                out_channels=16,
                kernel_size=3,
                padding=1,
                stride=1,
                bias=True,
            )

        def forward(self, x):
            # (x - mean) / std: the division is where NaN could show up in
            # padding elements without the fix.
            normalized = (x - self.mean) / self.std
            # The convolution is the op whose output NaN propagation would corrupt.
            return self.conv(normalized)

    module = NormalizationConvModule()
    # Typical image tensor: [batch=1, channels=3, height=256, width=256].
    image = torch.randn(size=(1, 3, 256, 256), dtype=torch.float32)

    self.lower_module_and_test_output(module, (image,))

def test_vulkan_backend_grid_priors(self):
class GridPriorsModule(torch.nn.Module):
def __init__(self):
Expand Down
Loading