```
pip install -U triformer
```

That's it, nothing special.

### Usage

- Using TritonLayerNorm

```python
import torch
from triformer import TritonLayerNorm

# Create dummy data
batch_size, seq_len, hidden_dim = 32, 64, 512
x = torch.randn(batch_size, seq_len, hidden_dim).cuda()

# Initialize and use LayerNorm
layer_norm = TritonLayerNorm(hidden_dim).cuda()
ln_output = layer_norm(x)

# Print information about the tensors
print("Input shape:", x.shape)
print("Output shape:", ln_output.shape)

# Print a small sample
print("\nSample of output (first 10 values of first sequence):")
print(ln_output[0, 0, :10].cpu().detach().numpy())
```
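
The examples in this section assume a CUDA-capable GPU, since the Triton kernels run on the GPU. As a quick sanity check, `TritonLayerNorm` can be compared against PyTorch's built-in `torch.nn.LayerNorm`. This is only a sketch, under the assumption that `TritonLayerNorm(hidden_dim)` follows the same semantics and default initialization (weight = 1, bias = 0) as the PyTorch module:

```python
import torch
from triformer import TritonLayerNorm

hidden_dim = 512
x = torch.randn(8, 16, hidden_dim).cuda()

# Assumption: TritonLayerNorm normalizes over the last dimension with a
# learnable weight and bias, like torch.nn.LayerNorm, and both modules
# start from the default weight = 1, bias = 0.
triton_ln = TritonLayerNorm(hidden_dim).cuda()
torch_ln = torch.nn.LayerNorm(hidden_dim).cuda()

out_triton = triton_ln(x)
out_torch = torch_ln(x)

# Cast to float32 in case the Triton kernel returns a lower-precision dtype.
diff = (out_triton.float() - out_torch).abs().max().item()
print("Max abs difference:", diff)
print("Close:", torch.allclose(out_triton.float(), out_torch, atol=1e-4))
```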

```python
# Softmax Example
from triformer import TritonSoftmax
import torch

batch_size, seq_len = 32, 64
attention_scores = torch.randn(batch_size, seq_len, seq_len).cuda()

# Regular softmax
softmax = TritonSoftmax(is_causal=False).cuda()
regular_attention = softmax(attention_scores)

# Causal softmax
causal_softmax = TritonSoftmax(is_causal=True).cuda()
causal_attention = causal_softmax(attention_scores)

print("\n=== Softmax ===")
print("Input shape:", attention_scores.shape)
print("Output shape:", regular_attention.shape)
print("\nRegular softmax sample (first 5 values):")
print(regular_attention[0, 0, :5].cpu().detach().numpy())
print("\nCausal softmax sample (first 5 values):")
print(causal_attention[0, 0, :5].cpu().detach().numpy())
print("\nRow sums (should be 1.0):")
print("Regular:", regular_attention[0, 0].sum().item())
print("Causal:", causal_attention[0, 0].sum().item())
```
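
To see what the causal variant does, its output can be compared against a plain `torch.softmax` over manually masked scores. A minimal sketch, assuming `is_causal=True` applies a standard upper-triangular mask (each query position attends only to itself and earlier positions):

```python
import torch
from triformer import TritonSoftmax

scores = torch.randn(4, 16, 16).cuda()

# Triformer's causal softmax (assumed to mask out future positions).
causal = TritonSoftmax(is_causal=True).cuda()
out_triton = causal(scores)

# Reference: set future positions to -inf, then apply a regular softmax.
seq_len = scores.shape[-1]
mask = torch.triu(
    torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device),
    diagonal=1,
)
out_torch = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)

# Cast to float32 in case the Triton kernel returns a lower-precision dtype.
print("Max abs difference:", (out_triton.float() - out_torch).abs().max().item())
```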

```python
from triformer import TritonDropout
import torch

batch_size, seq_len, hidden_dim = 32, 64, 512
x = torch.ones(batch_size, seq_len, hidden_dim).cuda()  # Using ones for a clearer demonstration

training_output = TritonDropout.apply(x, 0.5, 42)

print("\n=== Dropout ===")
print("Input shape:", x.shape)
print("Output shape:", training_output.shape)
print("\nSample output (first 10 values, showing the dropout pattern):")
print(training_output[0, 0, :10].cpu().detach().numpy())
print("\nFraction of non-zero values (should be ~0.5):")
print((training_output != 0).float().mean().item())
```
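
Because `TritonDropout.apply` takes an explicit third argument that looks like a seed, the example above should be reproducible. The check below is a sketch built on two assumptions not confirmed by the snippet above: that the same seed reproduces the same mask, and that the kernel implements inverted dropout, so surviving elements of an all-ones input come out as 1 / (1 - p) = 2.0 for p = 0.5.

```python
import torch
from triformer import TritonDropout

x = torch.ones(32, 64, 512).cuda()

# Assumption: identical seeds produce identical dropout masks.
out_a = TritonDropout.apply(x, 0.5, 42)
out_b = TritonDropout.apply(x, 0.5, 42)
print("Same seed, same mask:", torch.equal(out_a, out_b))

# Assumption: inverted dropout, i.e. kept values are scaled by 1 / (1 - p).
kept = out_a[out_a != 0]
print("Unique kept values (expected 2.0):", kept.unique())
```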

```python
from triformer import TritonCrossEntropyLoss
import torch

batch_size, seq_len, vocab_size = 32, 64, 30000
logits = torch.randn(batch_size * seq_len, vocab_size).cuda()
targets = torch.randint(0, vocab_size, (batch_size * seq_len,)).cuda()

criterion = TritonCrossEntropyLoss(
    pad_token_id=0,
    reduction='mean',
    n_chunks=1
).cuda()

loss = criterion(logits, targets)

print("\n=== Cross Entropy Loss ===")
print("Logits shape:", logits.shape)
print("Targets shape:", targets.shape)
print("Loss value:", loss.item())
print("\nSample logits (first 5 values for first item):")
print(logits[0, :5].cpu().detach().numpy())
print("Corresponding target:", targets[0].item())
```
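
As a rough correctness check, the result can be compared with PyTorch's `torch.nn.functional.cross_entropy`. This sketch assumes that `pad_token_id` behaves like `ignore_index` (padding targets are excluded from the mean) and that `n_chunks` only affects memory use, not the value of the loss:

```python
import torch
import torch.nn.functional as F
from triformer import TritonCrossEntropyLoss

vocab_size = 1000
logits = torch.randn(256, vocab_size).cuda()
targets = torch.randint(0, vocab_size, (256,)).cuda()

criterion = TritonCrossEntropyLoss(pad_token_id=0, reduction='mean', n_chunks=1).cuda()
triton_loss = criterion(logits, targets)

# Reference loss; assumes pad_token_id acts like ignore_index.
torch_loss = F.cross_entropy(logits, targets, ignore_index=0, reduction='mean')

print("Triton loss: ", triton_loss.item())
print("PyTorch loss:", torch_loss.item())
```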
# Benchmarking

The benchmarking was done on an NVIDIA L40S GPU.
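
The snippet below is not the project's benchmark harness; it is a generic CUDA-event timing sketch showing how a comparison of this kind can be reproduced, using `TritonLayerNorm` against `torch.nn.LayerNorm` as an example (the `time_forward` helper is defined here purely for illustration):

```python
import torch
from triformer import TritonLayerNorm

def time_forward(module, x, iters=100):
    """Average forward-pass time in milliseconds, measured with CUDA events."""
    for _ in range(10):  # warm-up: kernel compilation, caches
        module(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        module(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

x = torch.randn(32, 512, 2048).cuda()
print("TritonLayerNorm:   ", time_forward(TritonLayerNorm(2048).cuda(), x), "ms")
print("torch.nn.LayerNorm:", time_forward(torch.nn.LayerNorm(2048).cuda(), x), "ms")
```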
## Future Plans - To Do
- [ ] Create a library specifically for transformers in vision and language
- [x] Core Operations:
  - [x] LayerNorm in Triton
  - [x] Softmax in Triton
  - [x] Dropout in Triton
  - [x] Cross Entropy Loss in Triton
- [ ] Feed Forward Network (fused GeLU + Linear in Triton; a plain PyTorch reference sketch is shown below)
- [ ] The complete transformer model
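
For context on the feed-forward item, the block being targeted is the standard Linear -> GeLU -> Linear pattern. The reference below is plain PyTorch, not triformer code; it only illustrates the computation that a fused Triton kernel would replace.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ReferenceFeedForward(nn.Module):
    """Unfused transformer FFN (Linear -> GeLU -> Linear), shown only as a
    PyTorch baseline for the planned fused Triton version."""

    def __init__(self, hidden_dim: int, ff_dim: int):
        super().__init__()
        self.up = nn.Linear(hidden_dim, ff_dim)
        self.down = nn.Linear(ff_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(F.gelu(self.up(x)))

ffn = ReferenceFeedForward(512, 2048).cuda()
x = torch.randn(32, 64, 512).cuda()
print(ffn(x).shape)  # torch.Size([32, 64, 512])
```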