Skip to content

Commit 5f9ecc9

Browse files
authored
Merge pull request #101 from EleutherAI/more-tests
Extended tests, runtime type checking
2 parents 4c0c951 + 7e3ca38 commit 5f9ecc9

28 files changed

+727
-150
lines changed

delphi/__main__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,12 @@ def create_neighbours(
9191
elif constructor_cfg.neighbours_type == "decoder_similarity":
9292

9393
neighbour_calculator = NeighbourCalculator(
94-
autoencoder=saes[hookpoint].cuda(), number_of_neighbours=250
94+
autoencoder=saes[hookpoint].to("cuda"), number_of_neighbours=250
9595
)
9696

9797
elif constructor_cfg.neighbours_type == "encoder_similarity":
9898
neighbour_calculator = NeighbourCalculator(
99-
autoencoder=saes[hookpoint].cuda(), number_of_neighbours=250
99+
autoencoder=saes[hookpoint].to("cuda"), number_of_neighbours=250
100100
)
101101
else:
102102
raise ValueError(
@@ -136,7 +136,7 @@ async def process_cache(
136136
} # The latent range to explain
137137

138138
dataset = LatentDataset(
139-
raw_dir=str(latents_path),
139+
raw_dir=latents_path,
140140
sampler_cfg=run_cfg.sampler_cfg,
141141
constructor_cfg=run_cfg.constructor_cfg,
142142
modules=hookpoints,

delphi/clients/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self, model: str):
1717
@abstractmethod
1818
async def generate(
1919
self, prompt: Union[str, list[dict[str, str]]], **kwargs
20-
) -> Response:
20+
) -> str | Response:
2121
pass
2222

2323
# @abstractmethod

delphi/clients/offline.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,11 @@ def __init__(
7474
self.statistics_path = Path("statistics")
7575
self.statistics_path.mkdir(parents=True, exist_ok=True)
7676

77-
async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
77+
async def process_func(
78+
self,
79+
batches: Union[str, list[Union[dict[str, str], list[dict[str, str]]]]],
80+
kwargs,
81+
):
7882
"""
7983
Process a single request.
8084
"""
@@ -142,7 +146,9 @@ async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
142146
)
143147
return new_response
144148

145-
async def generate(self, prompt: Union[str, list[dict[str, str]]], **kwargs) -> str: # type: ignore
149+
async def generate(
150+
self, prompt: Union[str, list[dict[str, str]]], **kwargs
151+
) -> Response: # type: ignore
146152
"""
147153
Enqueue a request and wait for the result.
148154
"""

delphi/clients/openrouter.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from ..logger import logger
77
from .client import Client
8+
from .types import ChatFormatRequest
89

910
# Preferred provider routing arguments.
1011
# Change depending on what model you'd like to use.
@@ -37,7 +38,11 @@ def postprocess(self, response):
3738
return Response(msg)
3839

3940
async def generate( # type: ignore
40-
self, prompt: str, raw: bool = False, max_retries: int = 1, **kwargs # type: ignore
41+
self,
42+
prompt: ChatFormatRequest,
43+
raw: bool = False,
44+
max_retries: int = 1,
45+
**kwargs, # type: ignore
4146
) -> Response: # type: ignore
4247
kwargs.pop("schema", None)
4348
max_tokens = kwargs.pop("max_tokens", 500)

delphi/clients/types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from typing import Literal, TypedDict, Union
2+
3+
4+
class Message(TypedDict):
5+
content: str
6+
role: Literal["system", "user", "assistant"]
7+
8+
9+
ChatFormatRequest = Union[str, list[str], list[Message], None]

delphi/explainers/contrastive_explainer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55

66
from delphi.explainers.default.prompts import SYSTEM_CONTRASTIVE
7-
from delphi.explainers.explainer import Explainer, ExplainerResult
7+
from delphi.explainers.explainer import Explainer, ExplainerResult, Response
88
from delphi.latents.latents import ActivatingExample, LatentRecord, NonActivatingExample
99

1010

@@ -54,7 +54,11 @@ async def __call__(self, record: LatentRecord) -> ExplainerResult:
5454
)
5555

5656
try:
57-
explanation = self.parse_explanation(response.text)
57+
if isinstance(response, Response):
58+
response_text = response.text
59+
else:
60+
response_text = response
61+
explanation = self.parse_explanation(response_text)
5862
if self.verbose:
5963
from ..logger import logger
6064

delphi/explainers/explainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import aiofiles
1010

11-
from ..clients.client import Client
11+
from ..clients.client import Client, Response
1212
from ..latents.latents import ActivatingExample, LatentRecord
1313
from ..logger import logger
1414

@@ -44,6 +44,7 @@ async def __call__(self, record: LatentRecord) -> ExplainerResult:
4444
response = await self.client.generate(
4545
messages, temperature=self.temperature, **self.generation_kwargs
4646
)
47+
assert isinstance(response, Response)
4748

4849
try:
4950
explanation = self.parse_explanation(response.text)

delphi/latents/cache.py

Lines changed: 64 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import json
22
from collections import defaultdict
3+
from collections.abc import Callable
34
from dataclasses import dataclass
45
from pathlib import Path
5-
from typing import Callable
66

77
import numpy as np
88
import torch
@@ -15,8 +15,49 @@
1515
from delphi.config import CacheConfig
1616
from delphi.latents.collect_activations import collect_activations
1717

18-
location_tensor_shape = Float[Tensor, "batch sequence num_latents"]
19-
token_tensor_shape = Float[Tensor, "batch sequence"]
18+
location_tensor_type = Int[Tensor, "batch_sequence 3"]
19+
activation_tensor_type = Float[Tensor, "batch_sequence"]
20+
token_tensor_type = Int[Tensor, "batch sequence"]
21+
latent_tensor_type = Float[Tensor, "batch sequence num_latents"]
22+
23+
24+
def get_nonzeros_batch(
    latents: latent_tensor_type,
) -> tuple[location_tensor_type, activation_tensor_type]:
    """
    Get non-zero activations for large batches that exceed int32 max value.

    The input is processed in slices along its first dimension, sized so that
    each slice holds at most ``torch.iinfo(torch.int32).max`` elements, and
    the per-slice results are concatenated in order.

    Args:
        latents: Input latent activations (indexed here as a 3-D tensor).

    Returns:
        tuple[Tensor, Tensor]: Non-zero latent locations (one row of indices
        per retained element, with the first index expressed relative to the
        full input) and the matching activation values.
    """
    # Number of elements contributed by one entry of the first dimension.
    elements_per_row = latents.shape[1] * latents.shape[2]
    # Largest number of rows whose element count stays within int32 max
    # (not sys.maxsize). Both clamps matter:
    #   * max(1, elements_per_row) avoids dividing by zero when a trailing
    #     dimension is empty;
    #   * the outer max(1, ...) keeps the range() step non-zero when a single
    #     row alone already exceeds the limit (the slice is then one row; it
    #     may still be larger than int32 max).
    max_batch_size = max(
        1, torch.iinfo(torch.int32).max // max(1, elements_per_row)
    )
    nonzero_latent_locations = []
    nonzero_latent_activations = []

    for i in range(0, latents.shape[0], max_batch_size):
        batch = latents[i : i + max_batch_size]

        # Build the threshold mask once and reuse it for both outputs.
        active = batch.abs() > 1e-5
        batch_locations = torch.nonzero(active)
        batch_activations = batch[active]

        # Adjust indices to account for batching
        batch_locations[:, 0] += i
        nonzero_latent_locations.append(batch_locations)
        nonzero_latent_activations.append(batch_activations)

    if not nonzero_latent_locations:
        # Empty first dimension: the loop never ran and torch.cat rejects an
        # empty list, so build the (empty) result from the input directly.
        active = latents.abs() > 1e-5
        return torch.nonzero(active), latents[active]

    # Concatenate results
    return (
        torch.cat(nonzero_latent_locations, dim=0),
        torch.cat(nonzero_latent_activations, dim=0),
    )
2061

2162

2263
class InMemoryCache:
@@ -37,25 +78,25 @@ def __init__(
3778
filters: Filters for selecting specific latents.
3879
batch_size: Size of batches for processing. Defaults to 64.
3980
"""
40-
self.latent_locations_batches: dict[str, list[location_tensor_shape]] = (
81+
self.latent_locations_batches: dict[str, list[location_tensor_type]] = (
4182
defaultdict(list)
4283
)
43-
self.latent_activations_batches: dict[str, list[location_tensor_shape]] = (
84+
self.latent_activations_batches: dict[str, list[latent_tensor_type]] = (
4485
defaultdict(list)
4586
)
46-
self.tokens_batches: dict[str, list[token_tensor_shape]] = defaultdict(list)
87+
self.tokens_batches: dict[str, list[token_tensor_type]] = defaultdict(list)
4788

48-
self.latent_locations: dict[str, location_tensor_shape] = {}
49-
self.latent_activations: dict[str, location_tensor_shape] = {}
50-
self.tokens: dict[str, token_tensor_shape] = {}
89+
self.latent_locations: dict[str, location_tensor_type] = {}
90+
self.latent_activations: dict[str, latent_tensor_type] = {}
91+
self.tokens: dict[str, token_tensor_type] = {}
5192

5293
self.filters = filters
5394
self.batch_size = batch_size
5495

5596
def add(
5697
self,
57-
latents: location_tensor_shape,
58-
tokens: token_tensor_shape,
98+
latents: latent_tensor_type,
99+
tokens: token_tensor_type,
59100
batch_number: int,
60101
module_path: str,
61102
):
@@ -96,47 +137,9 @@ def save(self):
96137
self.tokens_batches[module_path], dim=0
97138
)
98139

99-
def get_nonzeros_batch(
100-
self, latents: location_tensor_shape
101-
) -> tuple[
102-
Float[Tensor, "batch sequence num_latents"], Float[Tensor, "batch sequence "]
103-
]:
104-
"""
105-
Get non-zero activations for large batches that exceed int32 max value.
106-
107-
Args:
108-
latents: Input latent activations.
109-
110-
Returns:
111-
tuple[Tensor, Tensor]: Non-zero latent locations and activations.
112-
"""
113-
# Calculate the maximum batch size that fits within sys.maxsize
114-
max_batch_size = torch.iinfo(torch.int32).max // (
115-
latents.shape[1] * latents.shape[2]
116-
)
117-
nonzero_latent_locations = []
118-
nonzero_latent_activations = []
119-
120-
for i in range(0, latents.shape[0], max_batch_size):
121-
batch = latents[i : i + max_batch_size]
122-
123-
# Get nonzero locations and activations
124-
batch_locations = torch.nonzero(batch.abs() > 1e-5)
125-
batch_activations = batch[batch.abs() > 1e-5]
126-
127-
# Adjust indices to account for batching
128-
batch_locations[:, 0] += i
129-
nonzero_latent_locations.append(batch_locations)
130-
nonzero_latent_activations.append(batch_activations)
131-
132-
# Concatenate results
133-
nonzero_latent_locations = torch.cat(nonzero_latent_locations, dim=0)
134-
nonzero_latent_activations = torch.cat(nonzero_latent_activations, dim=0)
135-
return nonzero_latent_locations, nonzero_latent_activations
136-
137-
def get_nonzeros(self, latents: location_tensor_shape, module_path: str) -> tuple[
138-
location_tensor_shape,
139-
location_tensor_shape,
140+
def get_nonzeros(self, latents: latent_tensor_type, module_path: str) -> tuple[
141+
location_tensor_type,
142+
activation_tensor_type,
140143
]:
141144
"""
142145
Get the nonzero latent locations and activations.
@@ -153,7 +156,7 @@ def get_nonzeros(self, latents: location_tensor_shape, module_path: str) -> tupl
153156
(
154157
nonzero_latent_locations,
155158
nonzero_latent_activations,
156-
) = self.get_nonzeros_batch(latents)
159+
) = get_nonzeros_batch(latents)
157160
else:
158161
nonzero_latent_locations = torch.nonzero(latents.abs() > 1e-5)
159162
nonzero_latent_activations = latents[latents.abs() > 1e-5]
@@ -209,8 +212,8 @@ def __init__(
209212
self.filter_submodules(filters)
210213

211214
def load_token_batches(
212-
self, n_tokens: int, tokens: token_tensor_shape
213-
) -> list[token_tensor_shape]:
215+
self, n_tokens: int, tokens: token_tensor_type
216+
) -> list[token_tensor_type]:
214217
"""
215218
Load and prepare token batches for processing.
216219
@@ -248,7 +251,7 @@ def filter_submodules(self, filters: dict[str, Float[Tensor, "indices"]]):
248251
]
249252
self.hookpoint_to_sparse_encode = filtered_submodules
250253

251-
def run(self, n_tokens: int, tokens: token_tensor_shape):
254+
def run(self, n_tokens: int, tokens: token_tensor_type):
252255
"""
253256
Run the latent caching process.
254257
@@ -521,11 +524,11 @@ def generate_statistics_cache(
521524
print(f"Fraction of strong single token latents: {strong_token_fraction:%}")
522525

523526
return CacheStatistics(
524-
frac_alive=fraction_alive,
525-
frac_fired_1pct=one_percent,
526-
frac_fired_10pct=ten_percent,
527-
frac_weak_single_token=single_token_fraction,
528-
frac_strong_single_token=strong_token_fraction,
527+
frac_alive=float(fraction_alive),
528+
frac_fired_1pct=float(one_percent),
529+
frac_fired_10pct=float(ten_percent),
530+
frac_weak_single_token=float(single_token_fraction),
531+
frac_strong_single_token=float(strong_token_fraction),
529532
)
530533

531534

delphi/latents/collect_activations.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ def collect_activations(
2525
handles = []
2626

2727
def create_hook(hookpoint: str, transcode: bool = False):
28-
def hook_fn(module: nn.Module, input: Any, output: Tensor) -> Tensor | None:
28+
def hook_fn(
29+
module: nn.Module, input: Any, output: Tensor | tuple[Tensor]
30+
) -> Tensor | None:
2931
# If output is a tuple (like in some transformer layers), take first element
3032
if transcode:
3133
if isinstance(input, tuple):

0 commit comments

Comments
 (0)