From e75eed2c3f966a7f18eb11fd84196f37561f2207 Mon Sep 17 00:00:00 2001 From: win10 Date: Sat, 26 Oct 2024 15:18:03 +0530 Subject: [PATCH] Use float16 for TritonLinear weight/bias and drop redundant input cast --- triformer/forward_mlp.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/triformer/forward_mlp.py b/triformer/forward_mlp.py index a318b00..118715c 100644 --- a/triformer/forward_mlp.py +++ b/triformer/forward_mlp.py @@ -96,8 +96,7 @@ def grid(meta): class TritonLinearFunction(Function): @staticmethod def forward(ctx, x, weight, bias, use_relu=True): - # Convert input to float16 if necessary - x = x.to(dtype=torch.float16) + # Input should already be float16, no conversion needed # Add batch dimension if necessary if x.ndim == 2: @@ -152,12 +151,12 @@ def __init__(self, in_features, out_features, use_relu=True): self.in_features = in_features self.out_features = out_features self.use_relu = use_relu - # Change weight and bias to float32 + # Change weight and bias to float16 self.weight = nn.Parameter( - torch.empty(out_features, in_features, device='cuda', dtype=torch.float32) + torch.empty(out_features, in_features, device='cuda', dtype=torch.float16) ) self.bias = nn.Parameter( - torch.zeros(out_features, device='cuda', dtype=torch.float32) + torch.zeros(out_features, device='cuda', dtype=torch.float16) ) self.reset_parameters()