Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 73 additions & 1 deletion tests/unit/utilities/test_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
This module tests the tensor utility functions, particularly the filter_dict_by_prefix function.
"""

import numpy as np
import pytest
import torch

from transformer_lens.utilities.tensors import filter_dict_by_prefix
from transformer_lens.utilities.tensors import filter_dict_by_prefix, to_numpy


class TestFilterDictByPrefix:
Expand Down Expand Up @@ -205,3 +207,73 @@ def test_filter_dict_multiple_prefixes_sequentially(self):
assert "layer1.weight" in decoder_result
# Ensure original dict is not modified
assert len(test_dict) == 4


class TestToNumpy:
    """Unit tests covering each input type accepted by ``to_numpy``."""

    def test_bfloat16_tensor(self):
        """A bfloat16 tensor comes back as a float32 array instead of failing.

        NumPy lacks a bfloat16 dtype, and a direct ``.numpy()`` call on such a
        tensor raises ``TypeError: Got unsupported ScalarType BFloat16``. Reduced
        precision checkpoints make bfloat16 a common dtype in TransformerLens.
        """
        values = [1.0, 2.0, -3.5]
        converted = to_numpy(torch.tensor(values, dtype=torch.bfloat16))
        assert isinstance(converted, np.ndarray)
        assert converted.dtype == np.float32
        # The chosen values are exactly representable in bfloat16, so the
        # comparison can be exact rather than approximate.
        expected = np.array(values, dtype=np.float32)
        np.testing.assert_array_equal(converted, expected)

    def test_float32_tensor_passthrough(self):
        """A float32 tensor keeps its dtype and values."""
        values = [1.5, 2.5]
        converted = to_numpy(torch.tensor(values, dtype=torch.float32))
        assert isinstance(converted, np.ndarray)
        assert converted.dtype == np.float32
        expected = np.array(values, dtype=np.float32)
        np.testing.assert_array_equal(converted, expected)

    def test_float16_tensor(self):
        """A float16 tensor stays float16, since NumPy can represent it."""
        half_tensor = torch.tensor([1.0, 2.0], dtype=torch.float16)
        converted = to_numpy(half_tensor)
        assert isinstance(converted, np.ndarray)
        assert converted.dtype == np.float16

    def test_int_tensor(self):
        """An integer tensor converts with its values intact."""
        values = [1, 2, 3]
        converted = to_numpy(torch.tensor(values, dtype=torch.int64))
        assert isinstance(converted, np.ndarray)
        np.testing.assert_array_equal(converted, np.array(values))

    def test_parameter_bfloat16(self):
        """A bfloat16 ``nn.Parameter`` is upcast to float32 like a plain tensor."""
        values = [4.0, 5.0]
        weight = torch.nn.Parameter(torch.tensor(values, dtype=torch.bfloat16))
        converted = to_numpy(weight)
        assert isinstance(converted, np.ndarray)
        assert converted.dtype == np.float32
        expected = np.array(values, dtype=np.float32)
        np.testing.assert_array_equal(converted, expected)

    def test_numpy_array_passthrough(self):
        """An existing numpy array is returned as the very same object."""
        original = np.array([1.0, 2.0])
        assert to_numpy(original) is original

    def test_list_and_tuple(self):
        """Lists and tuples are turned into numpy arrays with equal contents."""
        cases = (
            ([1, 2, 3], np.array([1, 2, 3])),
            ((4, 5, 6), np.array([4, 5, 6])),
        )
        for sequence, expected in cases:
            np.testing.assert_array_equal(to_numpy(sequence), expected)

    def test_scalar(self):
        """A Python scalar is wrapped in a numpy array holding that value."""
        converted = to_numpy(3.5)
        assert isinstance(converted, np.ndarray)
        assert converted.item() == 3.5

    def test_invalid_type_raises(self):
        """An unsupported input type (here a dict) raises ``ValueError``."""
        unsupported = {"a": 1}
        with pytest.raises(ValueError, match="invalid type"):
            to_numpy(unsupported)
8 changes: 7 additions & 1 deletion transformer_lens/utilities/tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@ def to_numpy(tensor):
array = np.array(tensor)
return array
elif isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)):
return tensor.detach().cpu().numpy()
tensor = tensor.detach().cpu()
# NumPy has no bfloat16 dtype, so calling .numpy() directly on a bfloat16
# tensor raises a TypeError. Upcast to float32 first (bfloat16 is common in
# TransformerLens since many pretrained models are loaded in reduced precision).
if tensor.dtype == torch.bfloat16:
tensor = tensor.to(torch.float32)
return tensor.numpy()
elif isinstance(tensor, (int, float, bool, str)):
return np.array(tensor)
else:
Expand Down
Loading