import json
from collections import defaultdict
+ from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
- from typing import Callable

import numpy as np
import torch

from delphi.config import CacheConfig
from delphi.latents.collect_activations import collect_activations

- location_tensor_shape = Float[Tensor, "batch sequence num_latents"]
- token_tensor_shape = Float[Tensor, "batch sequence"]
+ location_tensor_type = Int[Tensor, "batch_sequence 3"]
+ activation_tensor_type = Float[Tensor, "batch_sequence"]
+ token_tensor_type = Int[Tensor, "batch sequence"]
+ latent_tensor_type = Float[Tensor, "batch sequence num_latents"]
+
+
+ def get_nonzeros_batch(
+     latents: latent_tensor_type,
+ ) -> tuple[
+     Float[Tensor, "batch sequence num_latents"], Float[Tensor, "batch sequence"]
+ ]:
+     """
+     Get non-zero activations for large batches that exceed int32 max value.
+
+     Args:
+         latents: Input latent activations.
+
+     Returns:
+         tuple[Tensor, Tensor]: Non-zero latent locations and activations.
+     """
+     # Calculate the maximum batch size whose element count stays within the int32 maximum
+     max_batch_size = torch.iinfo(torch.int32).max // (
+         latents.shape[1] * latents.shape[2]
+     )
+     nonzero_latent_locations = []
+     nonzero_latent_activations = []
+
+     for i in range(0, latents.shape[0], max_batch_size):
+         batch = latents[i : i + max_batch_size]
+
+         # Get nonzero locations and activations
+         batch_locations = torch.nonzero(batch.abs() > 1e-5)
+         batch_activations = batch[batch.abs() > 1e-5]
+
+         # Adjust indices to account for batching
+         batch_locations[:, 0] += i
+         nonzero_latent_locations.append(batch_locations)
+         nonzero_latent_activations.append(batch_activations)
+
+     # Concatenate results
+     nonzero_latent_locations = torch.cat(nonzero_latent_locations, dim=0)
+     nonzero_latent_activations = torch.cat(nonzero_latent_activations, dim=0)
+     return nonzero_latent_locations, nonzero_latent_activations


class InMemoryCache:
@@ -37,25 +78,25 @@ def __init__(
            filters: Filters for selecting specific latents.
            batch_size: Size of batches for processing. Defaults to 64.
        """
-         self.latent_locations_batches: dict[str, list[location_tensor_shape]] = (
+         self.latent_locations_batches: dict[str, list[location_tensor_type]] = (
            defaultdict(list)
        )
-         self.latent_activations_batches: dict[str, list[location_tensor_shape]] = (
+         self.latent_activations_batches: dict[str, list[latent_tensor_type]] = (
            defaultdict(list)
        )
-         self.tokens_batches: dict[str, list[token_tensor_shape]] = defaultdict(list)
+         self.tokens_batches: dict[str, list[token_tensor_type]] = defaultdict(list)

-         self.latent_locations: dict[str, location_tensor_shape] = {}
-         self.latent_activations: dict[str, location_tensor_shape] = {}
-         self.tokens: dict[str, token_tensor_shape] = {}
+         self.latent_locations: dict[str, location_tensor_type] = {}
+         self.latent_activations: dict[str, latent_tensor_type] = {}
+         self.tokens: dict[str, token_tensor_type] = {}

        self.filters = filters
        self.batch_size = batch_size

    def add(
        self,
-         latents: location_tensor_shape,
-         tokens: token_tensor_shape,
+         latents: latent_tensor_type,
+         tokens: token_tensor_type,
        batch_number: int,
        module_path: str,
    ):
@@ -96,47 +137,9 @@ def save(self):
            self.tokens_batches[module_path], dim=0
        )

-     def get_nonzeros_batch(
-         self, latents: location_tensor_shape
-     ) -> tuple[
-         Float[Tensor, "batch sequence num_latents"], Float[Tensor, "batch sequence "]
-     ]:
-         """
-         Get non-zero activations for large batches that exceed int32 max value.
-
-         Args:
-             latents: Input latent activations.
-
-         Returns:
-             tuple[Tensor, Tensor]: Non-zero latent locations and activations.
-         """
-         # Calculate the maximum batch size that fits within sys.maxsize
-         max_batch_size = torch.iinfo(torch.int32).max // (
-             latents.shape[1] * latents.shape[2]
-         )
-         nonzero_latent_locations = []
-         nonzero_latent_activations = []
-
-         for i in range(0, latents.shape[0], max_batch_size):
-             batch = latents[i : i + max_batch_size]
-
-             # Get nonzero locations and activations
-             batch_locations = torch.nonzero(batch.abs() > 1e-5)
-             batch_activations = batch[batch.abs() > 1e-5]
-
-             # Adjust indices to account for batching
-             batch_locations[:, 0] += i
-             nonzero_latent_locations.append(batch_locations)
-             nonzero_latent_activations.append(batch_activations)
-
-         # Concatenate results
-         nonzero_latent_locations = torch.cat(nonzero_latent_locations, dim=0)
-         nonzero_latent_activations = torch.cat(nonzero_latent_activations, dim=0)
-         return nonzero_latent_locations, nonzero_latent_activations
-
-     def get_nonzeros(self, latents: location_tensor_shape, module_path: str) -> tuple[
-         location_tensor_shape,
-         location_tensor_shape,
+     def get_nonzeros(self, latents: latent_tensor_type, module_path: str) -> tuple[
+         location_tensor_type,
+         activation_tensor_type,
    ]:
        """
        Get the nonzero latent locations and activations.
@@ -153,7 +156,7 @@ def get_nonzeros(self, latents: location_tensor_shape, module_path: str) -> tupl
            (
                nonzero_latent_locations,
                nonzero_latent_activations,
-             ) = self.get_nonzeros_batch(latents)
+             ) = get_nonzeros_batch(latents)
        else:
            nonzero_latent_locations = torch.nonzero(latents.abs() > 1e-5)
            nonzero_latent_activations = latents[latents.abs() > 1e-5]
@@ -209,8 +212,8 @@ def __init__(
            self.filter_submodules(filters)

    def load_token_batches(
-         self, n_tokens: int, tokens: token_tensor_shape
-     ) -> list[token_tensor_shape]:
+         self, n_tokens: int, tokens: token_tensor_type
+     ) -> list[token_tensor_type]:
        """
        Load and prepare token batches for processing.

@@ -248,7 +251,7 @@ def filter_submodules(self, filters: dict[str, Float[Tensor, "indices"]]):
        ]
        self.hookpoint_to_sparse_encode = filtered_submodules

-     def run(self, n_tokens: int, tokens: token_tensor_shape):
+     def run(self, n_tokens: int, tokens: token_tensor_type):
        """
        Run the latent caching process.

@@ -521,11 +524,11 @@ def generate_statistics_cache(
    print(f"Fraction of strong single token latents: {strong_token_fraction:%}")

    return CacheStatistics(
-         frac_alive=fraction_alive,
-         frac_fired_1pct=one_percent,
-         frac_fired_10pct=ten_percent,
-         frac_weak_single_token=single_token_fraction,
-         frac_strong_single_token=strong_token_fraction,
+         frac_alive=float(fraction_alive),
+         frac_fired_1pct=float(one_percent),
+         frac_fired_10pct=float(ten_percent),
+         frac_weak_single_token=float(single_token_fraction),
+         frac_strong_single_token=float(strong_token_fraction),
    )

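A note on the float() casts in the hunk above: the statistics are produced by numpy/torch reductions, so without the cast the CacheStatistics fields would hold numpy or tensor scalars rather than plain Python floats. Assuming those statistics are later serialized with the json module imported at the top of this file (an assumption; the serialization code is not part of this diff), the cast is what keeps json.dumps from failing. A minimal sketch, not code from this repository:

# Minimal sketch: why numpy scalars are cast to plain floats before being
# stored on a dataclass that may be JSON-serialized.
import json
from dataclasses import asdict, dataclass

import numpy as np


@dataclass
class Stats:  # stand-in for CacheStatistics, whose definition is not shown in this diff
    frac_alive: float


raw = np.float32(0.25)  # the kind of scalar a numpy reduction such as .mean() returns

try:
    json.dumps(asdict(Stats(frac_alive=raw)))
except TypeError as err:
    print(f"numpy scalar rejected: {err}")

print(json.dumps(asdict(Stats(frac_alive=float(raw)))))  # {"frac_alive": 0.25}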
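For readers skimming the get_nonzeros_batch function added at the top of this diff: per its docstring, torch.nonzero runs into trouble once the tensor's total element count exceeds the int32 maximum, so the function slices the batch dimension into chunks whose element count stays under that limit, offsets each chunk's batch indices by the chunk start, and concatenates the results. A short usage sketch; the import path delphi.latents.cache is a guess based on the other imports in this file and is not confirmed by the diff:

# Usage sketch only; the module path below is assumed, not taken from this diff.
import torch

from delphi.latents.cache import get_nonzeros_batch  # assumed location of this file

# Small (batch, sequence, num_latents) activations with two nonzero entries.
latents = torch.zeros(4, 8, 16)
latents[0, 1, 2] = 0.5
latents[3, 7, 15] = -1.0

locations, activations = get_nonzeros_batch(latents)
print(locations)    # two (batch, seq, latent) index rows: [0, 1, 2] and [3, 7, 15]
print(activations)  # tensor([ 0.5000, -1.0000])

# On an input this small the chunked path reduces to a single pass, so the
# result matches torch.nonzero / boolean masking applied directly.
assert torch.equal(locations, torch.nonzero(latents.abs() > 1e-5))
assert torch.equal(activations, latents[latents.abs() > 1e-5])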