
Commit 8ad69e7

fix: Fix padding token bug in featurize (#19)
* Fixed padding bug
* Bumped version
1 parent: 73dc242

2 files changed: +5 −13 lines changed

tokenlearn/featurize.py

Lines changed: 4 additions & 12 deletions
@@ -65,27 +65,19 @@ def featurize(
             continue  # Skip empty batches
 
         # Encode the batch to get token embeddings
-        token_embeddings = model.encode(
-            list_batch,
-            output_value="token_embeddings",
-            convert_to_tensor=True,
-        )
+        token_embeddings = model.encode(list_batch, output_value="token_embeddings", convert_to_numpy=True)
 
         # Tokenize the batch to get input IDs
         tokenized_ids = model.tokenize(list_batch)["input_ids"]
 
         for tokenized_id, token_embedding in zip(tokenized_ids, token_embeddings):
-            # Convert token IDs to tokens (excluding special tokens)
-            token_ids = tokenized_id[1:-1]
-            # Decode tokens to text
-            text = model.tokenizer.decode(token_ids)
+            # Decode the token IDs to get the text
+            text = model.tokenizer.decode(tokenized_id, skip_special_tokens=True)
             if text in seen:
                 continue
             seen.add(text)
             # Get the corresponding token embeddings (excluding special tokens)
-            token_embeds = token_embedding[1:-1]
-            # Convert embeddings to NumPy arrays
-            token_embeds = token_embeds.detach().cpu().numpy()
+            token_embeds = token_embedding[1:-1].detach().cpu().numpy()
             # Compute the mean of the token embeddings
             mean = np.mean(token_embeds, axis=0)
             txts.append(text)
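Why the old code was buggy: tokenizing a batch pads every sequence to the length of the longest one, so slicing `tokenized_id[1:-1]` strips only the leading [CLS] and one trailing token, leaving [SEP] and the remaining [PAD] tokens inside the decoded text. Decoding with `skip_special_tokens=True` removes all special tokens at once. Below is a minimal sketch of the difference, not part of the commit; it assumes a BERT-style tokenizer, and the model name and example strings are illustrative.

# Minimal sketch (not from the repository) of the padding bug this commit fixes.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Padding a batch makes the short sequence as long as the longest one.
input_ids = tokenizer(
    ["a short text", "a much longer text that forces the first one to be padded"],
    padding=True,
)["input_ids"]
short_ids = input_ids[0]

# Old behavior: slicing off only the first and last IDs drops [CLS] and one
# trailing [PAD], but keeps [SEP] and the remaining [PAD] tokens in the text.
print(tokenizer.decode(short_ids[1:-1]))   # e.g. "a short text [SEP] [PAD] [PAD] ..."

# Fixed behavior: skip_special_tokens=True strips [CLS], [SEP], and [PAD] in one go.
print(tokenizer.decode(short_ids, skip_special_tokens=True))  # "a short text"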

tokenlearn/version.py

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
-__version_triple__ = (0, 1, 1)
+__version_triple__ = (0, 1, 2)
 __version__ = ".".join(map(str, __version_triple__))
