Commit 8f22fdc

update readme
1 parent 21ccb3d commit 8f22fdc

1 file changed

+94
-6
lines changed

README.md

Lines changed: 94 additions & 6 deletions
@@ -10,9 +10,95 @@ that's it nothing special .
1010
pip install -U triformer
1111
```
1212
### Usage
13+
- Using TritonLayerNorm
14+
```python
15+
import torch
16+
from triformer import TritonLayerNorm
1317

14-
Coming Soon
18+
# Create dummy data
19+
batch_size, seq_len, hidden_dim = 32, 64, 512
20+
x = torch.randn(batch_size, seq_len, hidden_dim).cuda()
1521

22+
# Initialize and use LayerNorm
23+
layer_norm = TritonLayerNorm(hidden_dim).cuda()
24+
ln_output = layer_norm(x)
25+
26+
# Print information about the tensors
27+
print("Input shape:", x.shape)
28+
print("Output shape:", ln_output.shape)
29+
30+
31+
# Print a small sample
32+
print("\nSample of output (first 10 values of first sequence):")
33+
print(ln_output[0, 0, :10].cpu().detach().numpy())
34+
```
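To sanity-check the printed sample, it can help to recall what layer normalization computes for a single feature vector. The block below is a pure-Python sketch of that arithmetic, not triformer's Triton kernel. The function name `layer_norm_reference` and the `eps` value are assumptions for illustration, and it assumes the learnable scale and shift are at their usual initial values of 1 and 0.

```python
import math


def layer_norm_reference(row, eps=1e-5):
    """Normalize one feature vector to zero mean and (nearly) unit variance.

    Hypothetical helper for illustration; eps=1e-5 is an assumed default.
    """
    mean = sum(row) / len(row)
    # Biased variance (divide by N), the convention layer norm normally uses.
    var = sum((v - mean) ** 2 for v in row) / len(row)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std for v in row]


row = [1.0, 2.0, 3.0, 4.0]
out = layer_norm_reference(row)

# Each normalized row should have mean ~0 and variance ~1.
out_mean = sum(out) / len(out)
out_var = sum((v - out_mean) ** 2 for v in out) / len(out)
print("mean:", out_mean, "variance:", out_var)
```

If `TritonLayerNorm` follows the same convention, every `ln_output[b, s, :]` slice in the example above should show a mean close to 0 and a variance close to 1.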
35+
```python
36+
# Softmax Example
37+
from triformer import TritonSoftmax
38+
import torch
39+
batch_size, seq_len = 32, 64
40+
attention_scores = torch.randn(batch_size, seq_len, seq_len).cuda()
41+
42+
# Regular softmax
43+
softmax = TritonSoftmax(is_causal=False).cuda()
44+
regular_attention = softmax(attention_scores)
45+
46+
# Causal softmax
47+
causal_softmax = TritonSoftmax(is_causal=True).cuda()
48+
causal_attention = causal_softmax(attention_scores)
49+
50+
print("\n=== Softmax ===")
51+
print("Input shape:", attention_scores.shape)
52+
print("Output shape:", regular_attention.shape)
53+
print("\nRegular softmax sample (first 5 values):")
54+
print(regular_attention[0, 0, :5].cpu().detach().numpy())
55+
print("\nCausal softmax sample (first 5 values):")
56+
print(causal_attention[0, 0, :5].cpu().detach().numpy())
57+
print("\nRow sums (should be 1.0):")
58+
print("Regular:", regular_attention[0, 0].sum().item())
59+
print("Causal:", causal_attention[0, 0].sum().item())
60+
```
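The difference between the regular and causal variants is easiest to see on a tiny matrix. The following pure-Python sketch assumes the usual causal convention, where row `i` may only attend to columns `0..i` and masked positions receive probability 0. Whether `TritonSoftmax(is_causal=True)` masks in exactly this way is an assumption, and `softmax_row` is a hypothetical helper, not part of triformer.

```python
import math


def softmax_row(scores, n_visible=None):
    """Softmax over the first n_visible scores; masked positions get 0.0."""
    n = len(scores) if n_visible is None else n_visible
    visible = scores[:n]
    peak = max(visible)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in visible]
    total = sum(exps)
    return [e / total for e in exps] + [0.0] * (len(scores) - n)


scores = [
    [0.5, 1.0, -1.0],
    [2.0, 0.0, 1.0],
    [0.1, 0.2, 0.3],
]

regular = [softmax_row(r) for r in scores]
# Causal: row i only sees columns 0..i.
causal = [softmax_row(r, n_visible=i + 1) for i, r in enumerate(scores)]

for name, rows in (("regular", regular), ("causal", causal)):
    print(name, [round(sum(r), 6) for r in rows])
```

Under this convention the first causal row is exactly `[1.0, 0.0, 0.0]`, so a causal sample taken at `[0, 0, :5]` would be a 1 followed by zeros, while both variants still sum to 1 along each row.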
61+
62+
```python
63+
from triformer import TritonDropout
64+
import torch
65+
batch_size, seq_len, hidden_dim = 32, 64, 512
66+
x = torch.ones(batch_size, seq_len, hidden_dim).cuda() # Using ones for clearer demonstration
67+
68+
training_output = TritonDropout.apply(x, 0.5, 42)
69+
70+
71+
print("\n=== Dropout ===")
72+
print("Input shape:", x.shape)
73+
print("Output shape:", training_output.shape)
74+
print("\nSample output (first 10 values, showing dropout pattern):")
75+
print(training_output[0, 0, :10].cpu().detach().numpy())
76+
print("\nFraction of non-zero values (should be ~0.5):")
77+
print((training_output != 0).float().mean().item())
78+
```
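For comparison, here is a pure-Python sketch of seeded dropout. It assumes the common "inverted dropout" convention, where surviving values are scaled by `1 / (1 - p)`, and it assumes the three arguments in the example above are the input, the drop probability, and a seed. Neither assumption is confirmed by the README, and `dropout_reference` is a hypothetical helper.

```python
import random


def dropout_reference(values, p, seed):
    """Zero each value with probability p and rescale the survivors.

    Hypothetical reference; the 1 / (1 - p) scaling is an assumed convention.
    """
    rng = random.Random(seed)  # a fixed seed makes the mask reproducible
    scale = 1.0 / (1.0 - p)
    return [v * scale if rng.random() >= p else 0.0 for v in values]


x = [1.0] * 10000
out = dropout_reference(x, p=0.5, seed=42)

kept_fraction = sum(1 for v in out if v != 0.0) / len(out)
print("fraction kept:", kept_fraction)
print("distinct output values:", sorted(set(out)))
```

With an all-ones input, this convention yields only two distinct values, 0.0 and 2.0, which is why an input of ones makes the dropout pattern easy to read.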
79+
```python
80+
from triformer import TritonCrossEntropyLoss
81+
import torch
82+
batch_size, seq_len, vocab_size = 32, 64, 30000
83+
logits = torch.randn(batch_size * seq_len, vocab_size).cuda()
84+
targets = torch.randint(0, vocab_size, (batch_size * seq_len,)).cuda()
85+
86+
criterion = TritonCrossEntropyLoss(
87+
    pad_token_id=0,
88+
    reduction='mean',
89+
    n_chunks=1
90+
).cuda()
91+
92+
loss = criterion(logits, targets)
93+
94+
print("\n=== Cross Entropy Loss ===")
95+
print("Logits shape:", logits.shape)
96+
print("Targets shape:", targets.shape)
97+
print("Loss value:", loss.item())
98+
print("\nSample logits (first 5 values for first item):")
99+
print(logits[0, :5].cpu().detach().numpy())
100+
print("Corresponding target:", targets[0].item())
101+
```
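A small reference computation makes the loss value easier to interpret. The sketch below is pure Python and assumes that targets equal to `pad_token_id` are excluded from the mean, in the way `ignore_index` works in PyTorch's loss. Whether `TritonCrossEntropyLoss` treats padding this way is an assumption, and `cross_entropy_reference` is a hypothetical helper.

```python
import math


def cross_entropy_reference(logits, targets, pad_token_id=0):
    """Mean cross entropy over rows whose target is not pad_token_id."""
    losses = []
    for row, target in zip(logits, targets):
        if target == pad_token_id:
            continue  # assumed behavior: padded positions do not contribute
        peak = max(row)  # stabilize log-sum-exp
        log_sum_exp = peak + math.log(sum(math.exp(v - peak) for v in row))
        losses.append(log_sum_exp - row[target])
    return sum(losses) / len(losses) if losses else 0.0


vocab_size = 8
uniform_logits = [[0.0] * vocab_size for _ in range(3)]
targets = [3, 0, 5]  # the middle target is the pad id and is skipped

loss = cross_entropy_reference(uniform_logits, targets)
print("loss:", loss, "ln(vocab_size):", math.log(vocab_size))
```

With uniform logits the loss equals `ln(vocab_size)`, a useful baseline when reading the value printed above. Note that the example draws targets with `torch.randint(0, vocab_size, ...)`, so a few targets can equal `pad_token_id=0`.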
16102
# Benchmarking
17103
Benchmarks were run on an NVIDIA L40S GPU.
18104

@@ -91,8 +177,10 @@ pytest tests/test_cross_entropy.py
91177

92178
## Future Plans - To Do
93179
- [ ] Create a library specifically for transformers in vision and language
94-
- [x] Implement the layernorm in Triton
95-
- [x] Implement the softmax in Triton
96-
- [x] Implement the dropout in Triton
97-
- [x] Implement the cross entropy loss in Triton
98-
-
180+
- [x] Core Operations:
181+
  - [x] LayerNorm in Triton
182+
  - [x] Softmax in Triton
183+
  - [x] Dropout in Triton
184+
  - [x] Cross Entropy Loss in Triton
185+
- [ ] Feed Forward Network (fused GeLU + Linear in Triton)
186+
- [ ] The complete transformer model
