Skip to content

Commit 0d27741

Browse files
committed
fix
Signed-off-by: Dylan Chen <[email protected]>
1 parent 76a47c7 commit 0d27741

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,21 @@ def _(
121121
ret = mat_a.new_empty(shape, dtype=out_dtype)
122122
return ret
123123

124+
@torch.library.register_fake("trtllm::cuda_scaled_mm")
125+
def _(
126+
mat_a: torch.Tensor,
127+
mat_b: torch.Tensor,
128+
scale_a: torch.Tensor,
129+
scale_b: torch.Tensor,
130+
bias,
131+
out_dtype,
132+
userbuffers_id=False,
133+
):
134+
shape = [i for i in mat_a.shape]
135+
shape[-1] = mat_b.shape[-1]
136+
ret = mat_a.new_empty(shape, dtype=out_dtype)
137+
return ret
138+
124139
@torch.library.register_fake("trtllm::cublas_mm")
125140
def _(mat_a, mat_b, bias, out_dtype):
126141
shape = list(mat_a.shape)

0 commit comments

Comments
 (0)