From e4ecc068908a593635f5599d5f4d54cfea54e702 Mon Sep 17 00:00:00 2001 From: Yu Zhao <160552605+yuzhaouoe@users.noreply.github.com> Date: Mon, 15 Apr 2024 13:06:49 +0100 Subject: [PATCH] Update lime.py Convert tensor's dtype to "float" before moving to CPU, which avoids errors when using "bf16" --- inseq/attr/feat/ops/lime.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inseq/attr/feat/ops/lime.py b/inseq/attr/feat/ops/lime.py index 19cc869c..ffb19ad1 100644 --- a/inseq/attr/feat/ops/lime.py +++ b/inseq/attr/feat/ops/lime.py @@ -259,7 +259,7 @@ def perturb_func( ) def detach_to_list(t): - return t.detach().cpu().numpy().tolist() if type(t) == torch.Tensor else t + return t.detach().float().cpu().numpy().tolist() if type(t) == torch.Tensor else t # Additionally remove special_token_ids mask_special_token_ids = torch.Tensor(