Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save tensors in lower precision #273

Merged
merged 16 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,24 @@

- Rows and columns in the visualization now have indices alongside tokens to facilitate index-based slicing, aggregation and alignment [#282](https://github.com/inseq-team/inseq/pull/282)

- Added a `scores_precision` to `FeatureAttributionOutput.save` to enable efficient saving in `float16` and `float8` formats. This is useful for saving large attribution outputs in a more memory-efficient way. [#273](https://github.com/inseq-team/inseq/pull/273)

```python
import inseq

attrib_model = inseq.load_model("gpt2", "attention")
out = attrib_model.attribute("Hello world", generation_kwargs={'max_new_tokens': 100})

# Previous usage, memory inefficient
out.save("output.json")

# Memory-efficient saving
out.save("output_fp16.json", scores_precision="float16") # or "float8"

# Automatic conversion to float32
out_loaded = inseq.FeatureAttributionOutput.load("output_fp16.json")
```

- - A new `SliceAggregator` (`"slices"`) is added to allow for slicing source (in encoder-decoder) or target (in decoder-only) tokens from a `FeatureAttributionSequenceOutput` object, using the same syntax of `ContiguousSpanAggregator`. The `__getitem__` method of the `FeatureAttributionSequenceOutput` is a shortcut for this, allowing slicing with `[start:stop]` syntax. [#282](https://github.com/inseq-team/inseq/pull/282)

```python
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ fix-style:

.PHONY: check-safety
check-safety:
$(PYTHON) -m safety check --full-report -i 70612
$(PYTHON) -m safety check --full-report -i 70612 -i 71670

.PHONY: lint
lint: fix-style check-safety
Expand Down
81 changes: 73 additions & 8 deletions inseq/data/attribution.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import logging
from copy import deepcopy
from dataclasses import dataclass, field
Expand All @@ -8,6 +9,8 @@
import torch

from ..utils import (
convert_from_safetensor,
convert_to_safetensor,
drop_padding,
get_sequences_from_batched_steps,
json_advanced_dump,
Expand All @@ -20,6 +23,7 @@
MultipleScoresPerSequenceTensor,
MultipleScoresPerStepTensor,
OneOrMoreTokenWithIdSequences,
ScorePrecision,
SequenceAttributionTensor,
SingleScorePerStepTensor,
SingleScoresPerSequenceTensor,
Expand Down Expand Up @@ -175,6 +179,55 @@ def __sub__(self, other: "FeatureAttributionSequenceOutput") -> "FeatureAttribut
raise ValueError(f"Cannot compare {type(other)} with {type(self)}")
return self.aggregate("pair", paired_attr=other, do_post_aggregation_checks=False)

def _convert_to_safetensors(self, scores_precision: ScorePrecision = "float32"):
"""
Converts tensor attributes within the class to the specified precision.
The conversion is based on the specified `scores_precision`.
If the input tensor is already of the desired precision, no conversion occurs.
For float8, the function performs scaling and converts to uint8, which can be later converted back to float16 upon reloading.

Args:
scores_precision (str, optional): Desired output data type precision. Defaults to "float32".
Returns:
self: The function modifies the class attributes in-place.
"""

if self.source_attributions is not None:
self.source_attributions = convert_to_safetensor(
self.source_attributions.contiguous(), scores_precision=scores_precision
)
if self.target_attributions is not None:
self.target_attributions = convert_to_safetensor(
self.target_attributions.contiguous(), scores_precision=scores_precision
)
if self.step_scores is not None:
self.step_scores = {
k: convert_to_safetensor(v.contiguous(), scores_precision=scores_precision)
for k, v in self.step_scores.items()
}
if self.sequence_scores is not None:
self.sequence_scores = {
k: convert_to_safetensor(v.contiguous(), scores_precision=scores_precision)
for k, v in self.sequence_scores.items()
}
return self

def _recover_from_safetensors(self):
"""
Converts tensor attributes within the class from b64-encoded safetensors to torch tensors.`.
"""
if self.source_attributions is not None:
self.source_attributions = convert_from_safetensor(base64.b64decode(self.source_attributions))
if self.target_attributions is not None:
self.target_attributions = convert_from_safetensor(base64.b64decode(self.target_attributions))
if self.step_scores is not None:
self.step_scores = {k: convert_from_safetensor(base64.b64decode(v)) for k, v in self.step_scores.items()}
if self.sequence_scores is not None:
self.sequence_scores = {
k: convert_from_safetensor(base64.b64decode(v)) for k, v in self.sequence_scores.items()
}
return self

@staticmethod
def get_remove_pad_fn(attr: "FeatureAttributionStepOutput", name: str) -> Callable:
if attr.source_attributions is None or name.startswith("decoder"):
Expand Down Expand Up @@ -562,6 +615,7 @@ def save(
ndarray_compact: bool = True,
use_primitives: bool = False,
split_sequences: bool = False,
scores_precision: ScorePrecision = "float32",
) -> None:
"""Save class contents to a JSON file.

Expand All @@ -583,22 +637,33 @@ def save(
If True, the output is split into multiple files, one per sequence. The file names are generated by
appending the sequence index to the given path (e.g. ``./out.json`` with two sequences ->
``./out_0.json``, ``./out_1.json``)
scores_precision (:obj:`str`, *optional*, defaults to "float32"):
Rounding precision for saved scores. Can be used to reduce space on disk but introduces rounding
errors. Can be combined with compress=True for further space reduction.
Accepted values: "float32", "float16", or "float8". Default: "float32" (no rounding).
"""
if not overwrite and Path(path).exists():
raise ValueError(f"{path} already exists. Override with overwrite=True.")
save_outs = []
paths = []
if split_sequences:
for i, seq in enumerate(self.sequence_attributions):
for seq_id in range(len(self.sequence_attributions)):
attr_out = deepcopy(self)
attr_out.sequence_attributions = [seq]
attr_out.sequence_attributions = [
attr_out.sequence_attributions[seq_id]._convert_to_safetensors(scores_precision=scores_precision)
] # this overwrites the original
attr_out.step_attributions = None
attr_out.info["input_texts"] = [attr_out.info["input_texts"][i]]
attr_out.info["generated_texts"] = [attr_out.info["generated_texts"][i]]
attr_out.info["input_texts"] = [attr_out.info["input_texts"][seq_id]]
attr_out.info["generated_texts"] = [attr_out.info["generated_texts"][seq_id]]
save_outs.append(attr_out)
paths.append(f"{str(path).split('.json')[0]}_{i}.json{'.gz' if compress else ''}")
paths.append(f"{str(path).split('.json')[0]}_{seq_id}.json{'.gz' if compress else ''}")
else:
save_outs.append(self)
self_out = deepcopy(self)
self_out.sequence_attributions = [
seq._convert_to_safetensors(scores_precision=scores_precision)
for seq in self_out.sequence_attributions
]
save_outs.append(self_out)
paths.append(path)
for attr_out, path_out in zip(save_outs, paths):
with open(path_out, f"w{'b' if compress else ''}") as f:
Expand Down Expand Up @@ -631,9 +696,9 @@ def load(
:class:`~inseq.data.FeatureAttributionOutput`: Loaded attribution output
"""
out = json_advanced_load(path, decompression=decompress)
out.sequence_attributions = [seq.torch() for seq in out.sequence_attributions]
out.sequence_attributions = [seq._recover_from_safetensors() for seq in out.sequence_attributions]
if out.step_attributions is not None:
out.step_attributions = [step.torch() for step in out.step_attributions]
out.step_attributions = [step._recover_from_safetensors() for step in out.step_attributions]
return out

def aggregate(
Expand Down
4 changes: 4 additions & 0 deletions inseq/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
from .torch_utils import (
aggregate_contiguous,
check_device,
convert_from_safetensor,
convert_to_safetensor,
euclidean_distance,
filter_logits,
find_block_stack,
Expand All @@ -71,6 +73,8 @@
"UnknownAttributionMethodError",
"MissingAlignmentsError",
"cache_results",
"convert_to_safetensor",
"convert_from_safetensor",
"optional",
"pad",
"pretty_list",
Expand Down
7 changes: 4 additions & 3 deletions inseq/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import base64
import json
from collections import OrderedDict
from json import JSONEncoder
Expand Down Expand Up @@ -59,6 +60,8 @@ def class_instance_encode(obj: EncodableObject, use_primitives: bool = True, **k
"""
if isinstance(obj, (list, dict)):
return obj
if isinstance(obj, bytes):
return base64.b64encode(obj).decode("UTF8")
if hasattr(obj, "__class__") and hasattr(obj, "__dict__"):
if not hasattr(obj, "__new__"):
raise TypeError(f"class '{obj.__class__}' does not have a __new__ method; ")
Expand All @@ -84,9 +87,7 @@ def class_instance_encode(obj: EncodableObject, use_primitives: bool = True, **k
dct["attributes"] = hashodict(obj.__dict__)
if use_primitives:
attrs = dct.get("attributes", {})
return attrs
else:
return dct
return attrs if use_primitives else dct
return obj


Expand Down
33 changes: 33 additions & 0 deletions inseq/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from inspect import signature
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union

import safetensors
import torch
import torch.nn.functional as F
from jaxtyping import Int, Num
Expand Down Expand Up @@ -40,6 +41,38 @@ def remap_from_filtered(
return new_source.scatter(0, index, filtered)


def convert_to_safetensor(tensor: torch.Tensor, scores_precision="float32") -> bytes:
"""
Converts a torch tensor to a safetensor.

Args:
tensor (torch.Tensor): some torch tensor
scores_precision (str): format to convert weights to: [float32, float16, float8]
Returns:
bytes: A safetensor in bytes format
Raises:
ValueError if `scores_precision` doesn't match the possible options

"""
if scores_precision == "float32":
return safetensors.torch.save({"attribution": tensor})
elif scores_precision == "float16":
return safetensors.torch.save({"attribution": tensor.to(torch.float16)})
elif scores_precision == "float8":
logger.warning("Float8 precision is experimental and may result in loss of precision.")
return safetensors.torch.save({"attribution": tensor.to(torch.float8_e4m3fn)})
else:
raise ValueError("`scores_precision` has to be one of [float32, float16, float8]")


def convert_from_safetensor(safetensor: bytes) -> torch.Tensor:
"""
Convert a safetensor to a torch tensor and convert weights to float32.
Adapted from https://huggingface.co/docs/safetensors/metadata_parsing
"""
return safetensors.torch.load(safetensor)["attribution"].to(torch.float32)


def postprocess_attribution_scores(func: Callable) -> Callable:
@wraps(func)
def postprocess_scores_wrapper(
Expand Down
4 changes: 3 additions & 1 deletion inseq/utils/typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional, Union
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union

import torch
from captum.attr._utils.attribution import Attribution
Expand Down Expand Up @@ -71,6 +71,8 @@ class TextSequences:
OneOrMoreTokenWithIdSequences = Sequence[Sequence[TokenWithId]]
OneOrMoreAttributionSequences = Sequence[Sequence[float]]

ScorePrecision = Literal["float32", "float16", "float8"]

IndexSpan = Union[tuple[int, int], Sequence[tuple[int, int]]]
OneOrMoreIndices = Union[int, list[int], tuple[int, int]]
OneOrMoreIndicesDict = dict[int, OneOrMoreIndices]
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ dependencies = [
"numpy>=1.21.6",
"jaxtyping>=0.2.25",
"typeguard<=2.13.3",
"torch>=2.1.1",
"torch>=2.0",
"matplotlib>=3.5.3",
"tqdm>=4.64.0",
"nvidia-cublas-cu11>=11.10.3.66; sys_platform=='Linux'",
Expand Down Expand Up @@ -84,7 +84,7 @@ lint = [
"ruff>=0.2.0"
]
sklearn = [
"scikit-learn>=1.4.0",
"scikit-learn>=1.5.1",
"joblib>=1.3.2"
]
datasets = [
Expand Down
10 changes: 5 additions & 5 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ asttokens==2.4.1
# via stack-data
attrs==23.2.0
# via aiohttp
authlib==1.3.0
authlib==1.3.1
# via safety
babel==2.14.0
# via sphinx
Expand Down Expand Up @@ -46,7 +46,7 @@ contourpy==1.2.0
# via matplotlib
coverage==7.4.1
# via pytest-cov
cryptography==42.0.5
cryptography==43.0.0
# via authlib
cycler==0.12.1
# via matplotlib
Expand Down Expand Up @@ -311,11 +311,11 @@ safety==3.1.0
# via inseq (pyproject.toml)
safety-schemas==0.0.2
# via safety
scikit-learn==1.4.0
scikit-learn==1.5.1
# via inseq (pyproject.toml)
scipy==1.12.0
# via scikit-learn
sentencepiece==0.1.99
sentencepiece==0.2.0
# via transformers
setuptools==69.1.0
# via
Expand Down Expand Up @@ -375,7 +375,7 @@ threadpoolctl==3.2.0
# via scikit-learn
tokenizers==0.15.2
# via transformers
torch==2.2.0
torch==2.3.1
# via
# inseq (pyproject.toml)
# captum
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ rich==13.7.0
# via inseq (pyproject.toml)
safetensors==0.4.2
# via transformers
sentencepiece==0.1.99
sentencepiece==0.2.0
# via transformers
six==1.16.0
# via python-dateutil
sympy==1.12
# via torch
tokenizers==0.15.2
# via transformers
torch==2.2.0
torch==2.3.1
# via
# inseq (pyproject.toml)
# captum
Expand Down
Loading
Loading