
Commit 49079ee

pszemraj (Peter Szemraj) authored
add train script (#2)
* add enwiki8
* add train script
* 🎨
* upd base_decoding for samba shapes
* save config to json
* ignore outputs
* 🚧 prefer rotary-embedding-torch for rope impl
* fix loss reporting
* clean up, improve IO
* 📝 pay homage

Signed-off-by: Peter Szemraj <[email protected]>
Co-authored-by: Peter Szemraj <[email protected]>
1 parent 6f9376d commit 49079ee

File tree

9 files changed, +345 / -135 lines


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,9 @@ samba_pytorch/_version.py
 *.pyd
 *.pt*
 
+# outputs
+out/*
+
 # <<< END CUSTOM
 # Byte-compiled / optimized / DLL files
 __pycache__/

README.md

Lines changed: 11 additions & 1 deletion
@@ -7,7 +7,7 @@ This aims to be a simpler implementation of the [original repo](https://github.c
 ## Installation
 
 > [!TIP]
-> While the `pip install` command _should_ install all deps and the package, in practice some of the more CUDA-heavy deps are better installed separately from source. See section below for more details.
+> The pip install command _should_ install all dependencies and the package, but some CUDA-heavy dependencies are better installed separately. See below for more details.
 
 ```bash
 git clone https://github.com/pszemraj/samba-pytorch.git
@@ -40,6 +40,16 @@ model = GPT(cfg)
 model
 ```
 
+### Training
+
+A minimalist training script for a character-level language model on enwiki8:
+
+```python
+python train.py
+```
+
+Credit to [nGPT-pytorch](https://github.com/lucidrains/nGPT-pytorch) for the enwik8 data set and the training script, which has been adapted for this repo.
+
 ## repo structure
 
 ```text
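For readers unfamiliar with the setup, the sketch below shows what one step of character/byte-level language-model training looks like in plain PyTorch. It is only an illustration of the general recipe a script like the new `train.py` follows, not the repo's actual code; the model class (`TinyByteLM`), shapes, and hyperparameters are all invented for the example.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TinyByteLM(nn.Module):
    """Stand-in byte-level model (hypothetical, not the repo's GPT/Samba model)."""

    def __init__(self, vocab_size: int = 256, dim: int = 128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.rnn = nn.GRU(dim, dim, batch_first=True)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, x):
        h, _ = self.rnn(self.embed(x))
        return self.head(h)


model = TinyByteLM()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# One step: predict the next byte at every position of a (random) batch.
batch = torch.randint(0, 256, (4, 129))   # (batch, seq_len + 1) of byte ids
inputs, targets = batch[:, :-1], batch[:, 1:]

optimizer.zero_grad()
logits = model(inputs)                     # (batch, seq_len, 256)
loss = F.cross_entropy(logits.reshape(-1, 256), targets.reshape(-1))
loss.backward()
optimizer.step()
print(loss.item())
```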

data/README.md

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+# Data source
+
+Credit to [nGPT-pytorch](https://github.com/lucidrains/nGPT-pytorch) for the enwik8 dataset. The enwik8 data was (_originally_) downloaded from the Hutter prize page: <http://prize.hutter1.net/>
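As a rough sketch of how a gzipped byte-level corpus like this is typically consumed (the 95 MB read and the 90M/5M train/validation split below follow the usual enwik8 convention and are assumptions, not something this commit spells out):

```python
import gzip

import numpy as np
import torch

# Read the raw bytes and split into train/validation chunks.
# The 95M read and 90M/5M split mirror the common enwik8 convention (assumption).
with gzip.open("data/enwik8.gz", "rb") as f:
    data = np.frombuffer(f.read(95_000_000), dtype=np.uint8).copy()

train_bytes, valid_bytes = np.split(data, [90_000_000])
train_data = torch.from_numpy(train_bytes).long()  # byte ids 0..255 as token ids
valid_data = torch.from_numpy(valid_bytes).long()

print(train_data.shape, valid_data.shape)
```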

data/enwik8.gz

34.9 MB
Binary file not shown.

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ dependencies = [
     "flash-attn>=2.0.0.post1",
     "mamba-ssm",
     "numpy",
+    "rotary-embedding-torch",
     "sentencepiece",
     "torch>=2.0.0",
     "tqdm",

samba_pytorch/config.py

Lines changed: 1 addition & 0 deletions
@@ -101,6 +101,7 @@ def from_name(cls, name: str, **kwargs: Any) -> Self:
     @property
     def mlp_class(self) -> Type:
         from samba_pytorch import samba
+
         # `self._mlp_class` cannot be the type to keep the config json serializable
         return getattr(samba, self._mlp_class)
 
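The comment in that hunk points at a small pattern worth spelling out: the config stores the MLP class as a string so the config itself stays JSON-serializable, and resolves the string to a real class lazily. A generic sketch of the same idea follows; `ToyConfig` and `ToyMLP` are invented names, and the real code resolves the name against the `samba` module rather than `globals()`.

```python
import json
from dataclasses import asdict, dataclass


class ToyMLP:  # stand-in for a module class looked up by name
    pass


@dataclass
class ToyConfig:
    # store the class *name*, not the class object, so asdict()/json.dumps() work
    _mlp_class: str = "ToyMLP"

    @property
    def mlp_class(self) -> type:
        # resolve the string to the actual class only when it is needed
        return globals()[self._mlp_class]


cfg = ToyConfig()
print(json.dumps(asdict(cfg)))  # {"_mlp_class": "ToyMLP"}
print(cfg.mlp_class)            # <class '__main__.ToyMLP'>
```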

samba_pytorch/modules/rmsnorm.py

Lines changed: 17 additions & 21 deletions
@@ -1,13 +1,14 @@
 import torch
 from torch import nn
-from torch.nn import functional as F
 from einops import rearrange
 from typing import Optional, Tuple, Union
 
+
 def maybe_align(x: torch.Tensor, alignment_in_bytes: int = 16) -> torch.Tensor:
     """Ensures memory alignment by cloning if necessary."""
     return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()
 
+
 def dropout_add_layer_norm(
     x0: torch.Tensor,
     residual: Optional[torch.Tensor],
@@ -54,7 +55,7 @@ def dropout_add_layer_norm(
 
     # Apply row scaling if provided
     if rowscale is not None:
-        x0 = x0 * rearrange(rowscale, 'b -> b 1')
+        x0 = x0 * rearrange(rowscale, "b -> b 1")
 
     # Compute normalization (either LayerNorm or RMSNorm)
     if is_rms_norm:
@@ -74,21 +75,23 @@
         return output, mask
     return output
 
+
 class DropoutAddLayerNorm(nn.Module):
     """
     Module that combines dropout, residual connection, and layer normalization.
     """
+
     def __init__(
         self,
         hidden_size: int,
         prenorm: bool = False,
         p: float = 0.0,
         eps: float = 1e-5,
         residual_in_fp32: bool = False,
-        device = None,
-        dtype = None,
+        device=None,
+        dtype=None,
     ):
-        factory_kwargs = {'device': device, 'dtype': dtype}
+        factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         self.prenorm = prenorm
         self.p = p
@@ -101,7 +104,7 @@ def forward(
         self,
         x0: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
-        rowscale: Optional[torch.Tensor] = None
+        rowscale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return dropout_add_layer_norm(
             x0,
@@ -112,28 +115,24 @@ def forward(
             self.eps,
             rowscale=rowscale,
             prenorm=self.prenorm,
-            residual_in_fp32=self.residual_in_fp32
+            residual_in_fp32=self.residual_in_fp32,
         )
 
     def reset_parameters(self):
         """Reset parameters to default initialization."""
         nn.init.ones_(self.weight)
         nn.init.zeros_(self.bias)
 
+
 class RMSNorm(nn.Module):
     """
     Root Mean Square Layer Normalization.
 
     Implementation follows the paper: https://arxiv.org/abs/1910.07467
     """
-    def __init__(
-        self,
-        hidden_size: int,
-        eps: float = 1e-5,
-        device = None,
-        dtype = None
-    ):
-        factory_kwargs = {'device': device, 'dtype': dtype}
+
+    def __init__(self, hidden_size: int, eps: float = 1e-5, device=None, dtype=None):
+        factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs))
         self.eps = eps
@@ -145,14 +144,11 @@ def reset_parameters(self):
         """Reset parameters to default initialization."""
         nn.init.ones_(self.weight)
 
-def rms_norm(
-    x: torch.Tensor,
-    weight: torch.Tensor,
-    epsilon: float
-) -> torch.Tensor:
+
+def rms_norm(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor:
     """
     Applies RMS normalization to the input tensor.
     """
     norm_x = torch.mean(x * x, dim=-1, keepdim=True)
     x_normed = x * torch.rsqrt(norm_x + epsilon)
-    return x_normed * weight
+    return x_normed * weight
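For reference, the `rms_norm` function in this file is small enough to sanity-check standalone: scale each vector by the reciprocal root of its mean square, then apply a learned per-feature weight. The sketch below reuses the function body exactly as it appears in the diff; the input shapes are arbitrary.

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor:
    # mean of squares over the last (feature) dimension, then rescale
    norm_x = torch.mean(x * x, dim=-1, keepdim=True)
    x_normed = x * torch.rsqrt(norm_x + epsilon)
    return x_normed * weight


x = torch.randn(2, 4, 8)   # (batch, seq, hidden)
weight = torch.ones(8)     # learned scale, initialized to ones
out = rms_norm(x, weight, 1e-5)

# each position now has (approximately) unit root-mean-square norm
print(out.pow(2).mean(dim=-1))
```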
