diff --git a/tests/unit/utilities/test_tensors.py b/tests/unit/utilities/test_tensors.py
index bbd2e0363..68e47c0e4 100644
--- a/tests/unit/utilities/test_tensors.py
+++ b/tests/unit/utilities/test_tensors.py
@@ -3,9 +3,11 @@
 This module tests the tensor utility functions, particularly the filter_dict_by_prefix function.
 """
 
+import numpy as np
+import pytest
 import torch
 
-from transformer_lens.utilities.tensors import filter_dict_by_prefix
+from transformer_lens.utilities.tensors import filter_dict_by_prefix, to_numpy
 
 
 class TestFilterDictByPrefix:
@@ -205,3 +207,73 @@ def test_filter_dict_multiple_prefixes_sequentially(self):
         assert "layer1.weight" in decoder_result
         # Ensure original dict is not modified
         assert len(test_dict) == 4
+
+
+class TestToNumpy:
+    """Test cases for the to_numpy function."""
+
+    def test_bfloat16_tensor(self):
+        """bfloat16 tensors should be upcast to float32 instead of raising a TypeError.
+
+        NumPy has no bfloat16 dtype, so calling .numpy() on a bfloat16 tensor raises
+        ``TypeError: Got unsupported ScalarType BFloat16``. bfloat16 is common in
+        TransformerLens because many pretrained models load in reduced precision.
+        """
+        tensor = torch.tensor([1.0, 2.0, -3.5], dtype=torch.bfloat16)
+        result = to_numpy(tensor)
+        assert isinstance(result, np.ndarray)
+        assert result.dtype == np.float32
+        # Values that are exactly representable in bfloat16 should round-trip exactly.
+        np.testing.assert_array_equal(result, np.array([1.0, 2.0, -3.5], dtype=np.float32))
+
+    def test_float32_tensor_passthrough(self):
+        """float32 tensors should convert without dtype changes."""
+        tensor = torch.tensor([1.5, 2.5], dtype=torch.float32)
+        result = to_numpy(tensor)
+        assert isinstance(result, np.ndarray)
+        assert result.dtype == np.float32
+        np.testing.assert_array_equal(result, np.array([1.5, 2.5], dtype=np.float32))
+
+    def test_float16_tensor(self):
+        """float16 tensors are representable in numpy and should be preserved."""
+        tensor = torch.tensor([1.0, 2.0], dtype=torch.float16)
+        result = to_numpy(tensor)
+        assert isinstance(result, np.ndarray)
+        assert result.dtype == np.float16
+
+    def test_int_tensor(self):
+        """Integer tensors should convert without modification."""
+        tensor = torch.tensor([1, 2, 3], dtype=torch.int64)
+        result = to_numpy(tensor)
+        assert isinstance(result, np.ndarray)
+        np.testing.assert_array_equal(result, np.array([1, 2, 3]))
+
+    def test_parameter_bfloat16(self):
+        """nn.Parameter wrapping a bfloat16 tensor should also be handled."""
+        param = torch.nn.Parameter(torch.tensor([4.0, 5.0], dtype=torch.bfloat16))
+        result = to_numpy(param)
+        assert isinstance(result, np.ndarray)
+        assert result.dtype == np.float32
+        np.testing.assert_array_equal(result, np.array([4.0, 5.0], dtype=np.float32))
+
+    def test_numpy_array_passthrough(self):
+        """numpy arrays should be returned unchanged."""
+        array = np.array([1.0, 2.0])
+        result = to_numpy(array)
+        assert result is array
+
+    def test_list_and_tuple(self):
+        """Lists and tuples should be converted to numpy arrays."""
+        np.testing.assert_array_equal(to_numpy([1, 2, 3]), np.array([1, 2, 3]))
+        np.testing.assert_array_equal(to_numpy((4, 5, 6)), np.array([4, 5, 6]))
+
+    def test_scalar(self):
+        """Python scalars should be converted to numpy arrays."""
+        result = to_numpy(3.5)
+        assert isinstance(result, np.ndarray)
+        assert result.item() == 3.5
+
+    def test_invalid_type_raises(self):
+        """Unsupported types should raise a ValueError."""
+        with pytest.raises(ValueError, match="invalid type"):
+            to_numpy({"a": 1})
diff --git a/transformer_lens/utilities/tensors.py b/transformer_lens/utilities/tensors.py
index bc2034694..5e0971b13 100644
--- a/transformer_lens/utilities/tensors.py
+++ b/transformer_lens/utilities/tensors.py
@@ -23,7 +23,13 @@ def to_numpy(tensor):
         array = np.array(tensor)
         return array
     elif isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)):
-        return tensor.detach().cpu().numpy()
+        tensor = tensor.detach().cpu()
+        # NumPy has no bfloat16 dtype, so calling .numpy() directly on a bfloat16
+        # tensor raises a TypeError. Upcast to float32 first (bfloat16 is common in
+        # TransformerLens since many pretrained models are loaded in reduced precision).
+        if tensor.dtype == torch.bfloat16:
+            tensor = tensor.to(torch.float32)
+        return tensor.numpy()
     elif isinstance(tensor, (int, float, bool, str)):
         return np.array(tensor)
     else: