Skip to content

Commit

Permalink
update setup
Browse files: browse the repository at this point in the history
  • Loading branch information
dame-cell committed Oct 26, 2024
1 parent f562cac commit e75eed2
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions triformer/forward_mlp.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -96,8 +96,7 @@ def grid(meta):
class TritonLinearFunction(Function):
@staticmethod
def forward(ctx, x, weight, bias, use_relu=True):
# Convert input to float16 if necessary
x = x.to(dtype=torch.float16)
# Input should already be float16, no conversion needed

# Add batch dimension if necessary
if x.ndim == 2:
Expand Down Expand Up @@ -152,12 +151,12 @@ def __init__(self, in_features, out_features, use_relu=True):
self.in_features = in_features
self.out_features = out_features
self.use_relu = use_relu
# Change weight and bias to float32
# Change weight and bias to float16
self.weight = nn.Parameter(
torch.empty(out_features, in_features, device='cuda', dtype=torch.float32)
torch.empty(out_features, in_features, device='cuda', dtype=torch.float16)
)
self.bias = nn.Parameter(
torch.zeros(out_features, device='cuda', dtype=torch.float32)
torch.zeros(out_features, device='cuda', dtype=torch.float16)
)
self.reset_parameters()

Expand Down

0 comments on commit e75eed2

Please sign in to comment.