Add TVD Loss Kernel #324

Open · wants to merge 7 commits into main
2 changes: 2 additions & 0 deletions README.md
@@ -270,6 +270,7 @@ loss.backward()
| KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
| JSD | `liger_kernel.transformers.LigerJSD` |
| FusedLinearJSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
| TVD | `liger_kernel.transformers.LigerTVDLoss` |

- **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
- **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
@@ -286,6 +287,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
- **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
- **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size.
- **TVD**: [TVD](https://aclanthology.org/2023.acl-long.605.pdf) (total variation distance) is implemented by computing both the loss and gradient in the forward pass. It achieves ~2X speed and ~15% memory reduction for 128k vocab size.
- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192.


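For reference, the total variation distance added here is $\mathrm{TVD}(P, Q) = \tfrac{1}{2}\sum_i |P_i - Q_i|$, applied row-wise over the vocabulary dimension. Below is a minimal usage sketch of the `LigerTVDLoss` module this PR introduces (the import path is taken from the benchmark script further down; a CUDA device and an installed build of this branch are assumed):

```python
import torch

from liger_kernel.transformers.tvd import LigerTVDLoss

# Two row-wise categorical distributions over a vocabulary of size V.
BT, V = 4, 1024
logits = torch.randn(BT, V, device="cuda", requires_grad=True)
p = logits.softmax(dim=-1)
q = torch.randn(BT, V, device="cuda").softmax(dim=-1)

tvd = LigerTVDLoss(reduction="batchmean")  # same default reduction as the kernel
loss = tvd(p, q)  # 0.5 * sum(|p - q|), averaged over the BT rows
loss.backward()   # gradients flow back to `logits` through the softmax

print(loss.item(), logits.grad.shape)
```

The reduction semantics ("none", "sum", "mean", "batchmean") match the `TorchTVDLoss` reference implementation in the benchmark script below.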
37 changes: 37 additions & 0 deletions benchmark/data/all_benchmark_data.csv
@@ -505,6 +505,42 @@ fused_linear_jsd,torch,full,memory,MB,BT,B x T,1024,10609.005859375,10609.005859
fused_linear_jsd,torch,full,memory,MB,BT,B x T,2048,17146.009765625,17146.009765625,17146.009765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
fused_linear_jsd,torch,full,memory,MB,BT,B x T,4096,30220.017578125,30220.017578125,30220.017578125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
fused_linear_jsd,torch,full,memory,MB,BT,B x T,8192,56368.015625,56368.015625,56368.015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
tvd,liger,full,memory,MB,V,vocab size,4096,1792.0009765625,1792.0009765625,1792.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,liger,full,memory,MB,V,vocab size,8192,3584.0009765625,3584.0009765625,3584.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,liger,full,memory,MB,V,vocab size,16384,7168.0009765625,7168.0009765625,7168.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,liger,full,memory,MB,V,vocab size,32768,14336.0009765625,14336.0009765625,14336.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,liger,full,memory,MB,V,vocab size,65536,28672.0,28672.0,28672.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,liger,full,memory,MB,V,vocab size,131072,57344.0,57344.0,57344.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,torch,full,memory,MB,V,vocab size,4096,2048.0009765625,2048.0009765625,2048.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,torch,full,memory,MB,V,vocab size,8192,4096.0009765625,4096.0009765625,4096.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,torch,full,memory,MB,V,vocab size,16384,8192.0009765625,8192.0009765625,8192.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,torch,full,memory,MB,V,vocab size,32768,16384.0,16384.0,16384.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,torch,full,memory,MB,V,vocab size,65536,32768.0,32768.0,32768.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,torch,full,memory,MB,V,vocab size,131072,65536.0,65536.0,65536.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
tvd,liger,forward,speed,ms,V,vocab size,4096,0.47814399003982544,0.4774720072746277,0.4790079891681671,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
tvd,liger,forward,speed,ms,V,vocab size,8192,0.906495988368988,0.905951976776123,0.9073920249938965,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
tvd,liger,forward,speed,ms,V,vocab size,16384,1.8787360191345215,1.8778239488601685,1.8797119855880737,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
tvd,liger,forward,speed,ms,V,vocab size,32768,3.5788800716400146,3.5772159099578857,3.58076810836792,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
tvd,liger,forward,speed,ms,V,vocab size,65536,7.008831977844238,7.007718086242676,7.010636806488037,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
tvd,liger,forward,speed,ms,V,vocab size,131072,13.88646411895752,13.88128662109375,13.890560150146484,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
tvd,torch,forward,speed,ms,V,vocab size,4096,1.308608055114746,1.306502342224121,1.3104127645492554,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
tvd,torch,forward,speed,ms,V,vocab size,8192,2.4735519886016846,2.472287893295288,2.4749441146850586,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
tvd,torch,forward,speed,ms,V,vocab size,16384,4.828320026397705,4.826848030090332,4.830643177032471,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
tvd,torch,forward,speed,ms,V,vocab size,32768,9.5206880569458,9.517024040222168,9.525145530700684,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
tvd,torch,forward,speed,ms,V,vocab size,65536,19.01535987854004,19.011123657226562,19.01806640625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
tvd,torch,forward,speed,ms,V,vocab size,131072,38.022865295410156,38.01945877075195,38.02627182006836,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
tvd,liger,full,speed,ms,V,vocab size,4096,2.626512050628662,2.621260643005371,2.646751880645752,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
tvd,liger,full,speed,ms,V,vocab size,8192,4.661711692810059,4.657618999481201,4.662930965423584,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
tvd,liger,full,speed,ms,V,vocab size,16384,9.088272094726562,9.080741882324219,9.092268943786621,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
tvd,liger,full,speed,ms,V,vocab size,32768,18.116064071655273,18.112728118896484,18.118234634399414,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
tvd,liger,full,speed,ms,V,vocab size,65536,35.85124969482422,35.849971771240234,35.85252380371094,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
tvd,liger,full,speed,ms,V,vocab size,131072,71.1648941040039,71.1648941040039,71.1648941040039,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
tvd,torch,full,speed,ms,V,vocab size,4096,4.361599922180176,4.360159873962402,4.3639678955078125,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
tvd,torch,full,speed,ms,V,vocab size,8192,8.11302375793457,8.11075210571289,8.114463806152344,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
tvd,torch,full,speed,ms,V,vocab size,16384,15.841055870056152,15.837087631225586,15.841856002807617,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
tvd,torch,full,speed,ms,V,vocab size,32768,31.71219253540039,31.706951141357422,31.715898513793945,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
tvd,torch,full,speed,ms,V,vocab size,65536,63.17919921875,63.17919921875,63.17919921875,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
tvd,torch,full,speed,ms,V,vocab size,131072,126.0436782836914,126.0436782836914,126.0436782836914,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
group_norm,liger,forward,speed,ms,C,num_channels,32,0.03481600061058998,0.03379200026392937,0.03993599861860275,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
group_norm,liger,forward,speed,ms,C,num_channels,64,0.05222399905323982,0.05119999870657921,0.05222399905323982,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
group_norm,liger,forward,speed,ms,C,num_channels,128,0.08499199897050858,0.08396799862384796,0.08499199897050858,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
@@ -619,3 +655,4 @@ layer_norm,huggingface,full,memory,MB,N,hidden size,2048,160.09375,160.09375,160
layer_norm,huggingface,full,memory,MB,N,hidden size,4096,320.15625,320.15625,320.15625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1
layer_norm,huggingface,full,memory,MB,N,hidden size,8192,640.28125,640.28125,640.28125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1
layer_norm,huggingface,full,memory,MB,N,hidden size,16384,1280.53125,1280.53125,1280.53125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1

136 changes: 136 additions & 0 deletions benchmark/scripts/benchmark_tvd.py
@@ -0,0 +1,136 @@
import torch
import triton
from utils import (
QUANTILES,
SingleBenchmarkRunInput,
SingleBenchmarkRunOutput,
_test_memory,
parse_benchmark_script_args,
run_benchmarks,
)

from liger_kernel.transformers.tvd import LigerTVDLoss


class TorchTVDLoss(torch.nn.Module):
def __init__(self, reduction="batchmean"):
super(TorchTVDLoss, self).__init__()
self.reduction = reduction

def forward(self, p, q):
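# Elementwise total variation distance: 0.5 * |p - q|; the reduction is applied below.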
tvd = torch.abs(p - q) / 2.0
if self.reduction == "mean":
return torch.sum(tvd) / (p.size(0) * p.size(1))
elif self.reduction == "sum":
return torch.sum(tvd)
elif self.reduction == "none":
return tvd
elif self.reduction == "batchmean":
return torch.sum(tvd) / p.size(0)
else:
raise ValueError("Invalid reduction type.")


S, E = 12, 18


def bench_speed_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
reduction = "batchmean"
V = input.x
B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
torch_tvd = TorchTVDLoss(reduction=reduction)
liger_tvd = LigerTVDLoss(reduction=reduction)

_input = torch.randn(B * T, V, requires_grad=True, device="cuda").softmax(dim=-1)
target = torch.randn(B * T, V, device="cuda").softmax(dim=-1)

def fwd():
if input.kernel_provider == "liger":
return liger_tvd(_input, target)
else:
return torch_tvd(_input, target)

if input.kernel_operation_mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
elif input.kernel_operation_mode == "backward":
y = fwd()

ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
quantiles=QUANTILES,
grad_to_none=[_input],
rep=100,
)
elif input.kernel_operation_mode == "full":

def full():
y = fwd()
y.backward(retain_graph=True)

ms_50, ms_20, ms_80 = triton.testing.do_bench(
full, quantiles=QUANTILES, rep=100
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)


def bench_memory_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
reduction = "batchmean"
torch_tvd = TorchTVDLoss(reduction=reduction)
liger_tvd = LigerTVDLoss(reduction=reduction)

V = input.x
B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]

_input = torch.randn(B * T, V, requires_grad=True, device="cuda").softmax(dim=-1)
target = torch.randn(B * T, V, device="cuda").softmax(dim=-1)

def fwd():
if input.kernel_provider == "liger":
return liger_tvd(_input, target)
else:
return torch_tvd(_input, target)

def full():
y = fwd()
y.backward(retain_graph=True)

mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)

return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)


if __name__ == "__main__":
args = parse_benchmark_script_args()
common_args = {
"kernel_name": "tvd",
"x_name": "V",
"x_label": "vocab size",
"x_values": [2**i for i in range(12, 18)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [{"B": 8, "T": 2048}],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_memory_tvd,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_args,
)

run_benchmarks(
bench_test_fn=bench_speed_tvd,
kernel_operation_modes=["forward", "full"],
metric_name="speed",
metric_unit="ms",
**common_args,
)
176 changes: 176 additions & 0 deletions src/liger_kernel/ops/tvd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
from typing import Literal

import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import ensure_contiguous

MAX_FUSED_SIZE = 65536 // 4

REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]

_REDUCTION_MODE_NONE = tl.constexpr(0)
_REDUCTION_MODE_SUM = tl.constexpr(1)
_REDUCTION_MODE_MEAN = tl.constexpr(2)
_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)

_str_to_reduction_mode = {
"none": _REDUCTION_MODE_NONE.value,
"sum": _REDUCTION_MODE_SUM.value,
"mean": _REDUCTION_MODE_MEAN.value,
"batchmean": _REDUCTION_MODE_BATCHMEAN.value,
}


def get_num_warps(BLOCK_SIZE):
num_warps = 4
if BLOCK_SIZE >= 32768:
num_warps = 32
elif BLOCK_SIZE >= 8192:
num_warps = 16
elif BLOCK_SIZE >= 2048:
num_warps = 8

return num_warps


@triton.jit
def _tv_distance_kernel(
p_ptr,
p_stride,
q_ptr,
q_stride,
loss_ptr,
loss_stride,
grads_ptr,
grads_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
):
pid = tl.program_id(0).to(tl.int64)
p_ptr += pid * p_stride
q_ptr += pid * q_stride
loss_ptr += pid * loss_stride
grads_ptr += pid * grads_stride

base_offsets = tl.arange(0, BLOCK_SIZE)

loss_sum = 0.0
for i in range(0, n_cols, BLOCK_SIZE):
offsets = i + base_offsets
mask = offsets < n_cols

p = tl.load(p_ptr + offsets, mask=mask, other=0.0)
q = tl.load(q_ptr + offsets, mask=mask, other=0.0)

# TVD(P || Q) = 0.5 * |P - Q|
tv_loss = 0.5 * tl.abs(p - q)

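# Subgradient of 0.5 * |p - q| with respect to p: +0.5 where p > q, -0.5 otherwise
# (ties p == q fall into the -0.5 branch).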
grad_res = tl.where(p > q, 0.5, -0.5)
Collaborator comment:
Since we already compute the gradients in the forward pass, we could divide them by BT (or BT * V, depending on the reduction mode) right here, avoiding extra calculations in the backward pass and the need to save the reduction mode in ctx.
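# A rough sketch of the suggestion above (hypothetical; `n_rows`, i.e. BT, would need
# to be passed into the kernel as an extra argument):
#
#     if reduction == _REDUCTION_MODE_BATCHMEAN:
#         grad_res = grad_res / n_rows
#     elif reduction == _REDUCTION_MODE_MEAN:
#         grad_res = grad_res / (n_rows * n_cols)
#
# The host-side wrapper below could then return `grads` unscaled, and backward()
# would only multiply by grad_output.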


tl.store(grads_ptr + offsets, grad_res, mask=mask)

if reduction == _REDUCTION_MODE_NONE:
tl.store(loss_ptr + offsets, tv_loss, mask=mask)
else:
loss_sum += tl.sum(tv_loss, axis=0)

if reduction != _REDUCTION_MODE_NONE:
tl.store(loss_ptr, loss_sum)


def tv_distance_forward_triton(p, q, reduction):
BT, V = p.shape

BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
num_warps = get_num_warps(BLOCK_SIZE)

grid = (BT,)

reduction = _str_to_reduction_mode[reduction]

out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
grads = torch.empty_like(p)

_tv_distance_kernel[grid](
p,
p.stride(0),
q,
q.stride(0),
output_tensor,
output_tensor.stride(0),
grads,
grads.stride(0),
V,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
reduction=reduction,
)

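# Apply the chosen reduction to both the accumulated loss and the stored gradients,
# so that backward() only needs to scale by grad_output.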
if reduction == _REDUCTION_MODE_BATCHMEAN.value:
return output_tensor.sum() / BT, grads / BT
elif reduction == _REDUCTION_MODE_SUM.value:
return output_tensor.sum(dim=0), grads
elif reduction == _REDUCTION_MODE_MEAN.value:
return output_tensor.sum() / (BT * V), grads / (BT * V)
else:
return output_tensor, grads


def tvd_backward_triton(grad_output, grads):

# If the TVD loss is the last layer, grad_output is 1.0. Skip the multiplication then.
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
return grads

return grads * grad_output


class LigerTVDLossFunction(torch.autograd.Function):
"""
Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
"""

@staticmethod
@ensure_contiguous
def forward(
ctx,
p: torch.Tensor,
q: torch.Tensor,
reduction: REDUCTION_LITERAL = "batchmean",
) -> torch.Tensor:
"""A forward pass for the Total Variation Distance Loss.

Args:
ctx: Torch autograd context
p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".

Returns:
torch.Tensor: The computed Total Variation Distance Loss.
"""
loss, grads = tv_distance_forward_triton(p, q, reduction)
ctx.save_for_backward(grads)
return loss

@staticmethod
@ensure_contiguous
def backward(ctx, grad_output: torch.Tensor) -> tuple:
"""A backward pass for the Total Variation Distance Loss.

Args:
ctx: Torch autograd context
grad_output (torch.Tensor): The gradient of the loss with respect to the output.

Returns:
tuple[torch.Tensor, None, None]: The gradient of the loss with respect to the inputs.
"""
(grads,) = ctx.saved_tensors

grads = tvd_backward_triton(grad_output, grads)

return grads, None, None
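As a quick sanity check of the autograd wiring above: with the default "batchmean" reduction, the stored gradient should equal the elementwise subgradient divided by the number of rows. A minimal sketch (assuming a CUDA device and this branch installed; the import path follows the file location in this diff):

```python
import torch

from liger_kernel.ops.tvd import LigerTVDLossFunction

BT, V = 8, 4096
p = torch.randn(BT, V, device="cuda").softmax(dim=-1).requires_grad_()
q = torch.randn(BT, V, device="cuda").softmax(dim=-1)

loss = LigerTVDLossFunction.apply(p, q, "batchmean")
loss.backward()

# Expected gradient for "batchmean": 0.5 * sign(p - q) / BT
# (ties p == q land in the -0.5 branch of the kernel).
expected = torch.where(p > q, torch.full_like(p, 0.5), torch.full_like(p, -0.5)) / BT
print(torch.allclose(p.grad, expected))  # should print True
```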