diff --git a/inseq/attr/feat/ops/lime.py b/inseq/attr/feat/ops/lime.py index 19cc869..ffb19ad 100644 --- a/inseq/attr/feat/ops/lime.py +++ b/inseq/attr/feat/ops/lime.py @@ -259,7 +259,7 @@ def perturb_func( ) def detach_to_list(t): - return t.detach().cpu().numpy().tolist() if type(t) == torch.Tensor else t + return t.detach().float().cpu().numpy().tolist() if type(t) == torch.Tensor else t # Additionally remove special_token_ids mask_special_token_ids = torch.Tensor(