# eval_colt5.py
# Evaluate a trained CoLT5 model on the TriviaQA validation split and analyze
# how often the attention (KV) and feed-forward routers select the same tokens.

import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import T5Tokenizer, DataCollatorWithPadding
from tqdm import tqdm # Import tqdm for the progress bar
from colt5_attention.colt5_model import CoLT5
from colt5_attention.transformer_block import CoordinateDescentRouter
import torch.nn as nn
import pickle


def extract_router_history(model):
    """
    Extract the routing history from all CoordinateDescentRouter instances within the model.

    Args:
        model (nn.Module): The CoLT5 model instance.

    Returns:
        dict: A dictionary where keys are router names (as per the model's module hierarchy)
            and values are their corresponding routing histories.
    """
    router_histories = {}
    for name, module in model.named_modules():
        if isinstance(module, CoordinateDescentRouter):
            router_histories[name] = module.routing_history
    return router_histories


# Load the model and tokenizer
model = CoLT5(num_layers=6, dim=512).to('cuda') # Adjust this if your architecture changes
model.load_state_dict(torch.load('./checkpoints_925/best_colt5.pth'))
model.eval() # Set the model to evaluation mode
tokenizer = T5Tokenizer.from_pretrained('t5-small')
# Load and preprocess the test dataset
test_dataset = load_dataset('trivia_qa', 'unfiltered', split='validation') # Or 'test' if available


def preprocess_function(examples):
    inputs = [f"trivia question: {question}" for question in examples['question']]
    model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding='max_length', return_tensors='pt')

    # Handle answers: use the normalized answer string, or an empty string if missing
    answers = [answer['value'] if len(answer['value']) > 0 else "" for answer in examples['answer']]
    labels = tokenizer(answers, max_length=128, truncation=True, padding='max_length', return_tensors='pt')
    model_inputs['labels'] = labels['input_ids']

    # Convert attention_mask to boolean type
    model_inputs['attention_mask'] = model_inputs['attention_mask'].bool()
    return model_inputs


# Tokenize the test dataset
tokenized_test_dataset = test_dataset.map(preprocess_function, batched=True)
# Remove unnecessary columns after tokenization
tokenized_dataset = tokenized_test_dataset.remove_columns(['question', 'question_id', 'question_source', 'entity_pages', 'search_results', 'answer'])
# Print data types of all columns in the tokenized dataset
for column in tokenized_dataset.features:
    print(f"Column: {column}, Type: {tokenized_dataset.features[column]}")
# Use the DataCollatorWithPadding to pad inputs dynamically
data_collator = DataCollatorWithPadding(tokenizer, padding=True)
# DataLoader for the test set
test_loader = DataLoader(tokenized_dataset, batch_size=64, shuffle=True, collate_fn=data_collator)

# Evaluate the model
total_loss = 0
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)  # Cross-entropy, ignoring pad positions

with torch.no_grad():  # Disable gradient calculation for evaluation
    loop = tqdm(test_loader, leave=True, desc="Evaluating")
    for batch in loop:
        input_ids = batch['input_ids'].to('cuda')
        labels = batch['labels'].to('cuda')
        # Re-cast the mask to bool: the bool dtype set in preprocessing is not preserved through collation
        mask = batch['attention_mask'].to('cuda').bool()

        # Build decoder inputs by right-shifting the labels (teacher forcing), starting from the pad token
        decoder_input_ids = torch.full(labels.shape, tokenizer.pad_token_id, dtype=torch.long).to('cuda')
        decoder_input_ids[:, 1:] = labels[:, :-1]

        # Forward pass
        logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, mask=mask, keep_routing_history=True)

        loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
        total_loss += loss.item()
        loop.set_postfix(loss=loss.item())  # Update the progress bar with the current loss

average_loss = total_loss / len(test_loader)
print(f"Average Test Loss: {average_loss}")
# Extract routing histories
router_histories = extract_router_history(model)
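
# Optional: persist the extracted routing histories for offline analysis
# (pickle is imported above; the output filename here is arbitrary).
with open('router_histories.pkl', 'wb') as f:
    pickle.dump(router_histories, f)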
# for router_name, history in router_histories.items():
#     print(f"Router: {router_name}")
#     print(f"Selected Indices: {len(history['selected_indices'])}")

# Earlier local version of the similarity check, superseded by routing_history_analysis:
# def compare_similarity(router_name):
#     kv_router = ".conditional_attn.kv_router"
#     ffn_router = ".conditional_ff.router"
#     selected_kv = router_histories[router_name + kv_router]['selected_indices'][0]
#     selected_ffn = router_histories[router_name + ffn_router]['selected_indices'][0]
#     common_indices = set(selected_kv).intersection(selected_ffn)
#     similarity = len(common_indices) / len(selected_kv)  # fraction of KV-selected tokens also routed by the FFN
#     return similarity
from routing_history_analysis import compare_similarity_per_batch, plot_similarity_histogram
# Define the number of encoder layers
num_encoder_layers = 6 # Adjust based on your model's architecture
# Initialize a dictionary to store similarity scores per layer
layer_similarity = {}
# Compute similarity scores for each layer
for layer in range(num_encoder_layers):
    print(f"\nProcessing Layer {layer}...")
    similarity_scores = compare_similarity_per_batch(layer, router_histories)
    layer_similarity[layer] = similarity_scores
    print(f"Completed Layer {layer}: {len(similarity_scores)} batches compared.")

# Example: plot a histogram of similarity scores for each layer
for layer, scores in layer_similarity.items():
    plot_similarity_histogram(scores, layer)
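
# Optional: print a quick per-layer summary alongside the histograms. This assumes each
# entry returned by compare_similarity_per_batch is a numeric overlap score per batch.
for layer, scores in layer_similarity.items():
    if scores:
        mean_similarity = sum(scores) / len(scores)
        print(f"Layer {layer}: mean KV/FFN routing overlap across batches = {mean_similarity:.3f}")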
print("Evaluation complete. Routing history analysis finished.")