train_script.py
# Import necessary libraries
import os

from transformers import AutoTokenizer
from datasets import load_dataset
import nextgenjax as ngj
import jax
import optax
import flax
import numpy as np
from tqdm import tqdm

# Initialize the NextGenJAX model and tokenizer
model_name = "VishwamAI/vishwamai" # This will be replaced with the actual NextGenJAX model name
config = ngj.NextGenJAXConfig() # Placeholder for actual config
model = ngj.NextGenJAXModel(config)
tokenizer = ngj.NextGenJAXTokenizer.from_pretrained(model_name)
# Load datasets
datasets = {
    "gaokao": load_dataset("MARIO-Math-Reasoning/Gaokao2023-Math-En"),
    "gemma": load_dataset("path_to_gemma_dataset"),          # Placeholder
    "phi": load_dataset("path_to_phi_dataset"),              # Placeholder
    "mmlu_math": load_dataset("path_to_mmlu_math_dataset")   # Placeholder
}

def preprocess_function(examples, dataset_name):
    if dataset_name == "gaokao":
        inputs = [f"Question: {q}\nAnswer:" for q in examples["question"]]
        targets = [f" {a}" for a in examples["answer"]]
    elif dataset_name in ["gemma", "phi", "mmlu_math"]:
        # Placeholder for dataset-specific preprocessing
        inputs = examples["input"]
        targets = examples["output"]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")
    labels = tokenizer(targets, max_length=512, truncation=True, padding="max_length")
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
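# Hypothetical example, assuming the tokenizer follows the usual Hugging Face
# convention: preprocess_function({"question": ["1 + 1 = ?"], "answer": ["2"]}, "gaokao")
# would return fixed-length (512-token) lists for input_ids, attention_mask and labels.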
# Preprocess and combine datasets
combined_dataset = []
for name, dataset in datasets.items():
    tokenized_dataset = dataset["train"].map(
        lambda examples: preprocess_function(examples, name),
        batched=True,
        remove_columns=dataset["train"].column_names
    )
    combined_dataset.extend(tokenized_dataset)
# Split the combined dataset into train and evaluation sets
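# Note: for a reproducible split, a fixed seed can be set first, e.g. np.random.seed(0).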
np.random.shuffle(combined_dataset)
split_index = int(len(combined_dataset) * 0.8)
train_dataset = combined_dataset[:split_index]
eval_dataset = combined_dataset[split_index:]
# Initialize optimizer
learning_rate = 1e-4
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(model.params)
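# Note: in practice the parameter update below is usually wrapped in a
# jax.jit-compiled step function so XLA compiles it once; the plain Python
# loop is kept here for readability.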
# Training loop
num_epochs = 3
batch_size = 8

for epoch in range(num_epochs):
    # Training
    train_losses = []
    for i in tqdm(range(0, len(train_dataset), batch_size)):
        # Collate the list of example dicts into batched arrays
        examples = train_dataset[i:i + batch_size]
        batch = {
            "input_ids": np.array([ex["input_ids"] for ex in examples]),
            "labels": np.array([ex["labels"] for ex in examples]),
        }

        def loss_fn(params):
            logits = model.apply(params, batch["input_ids"])
            # Mean token-level cross-entropy over the vocabulary
            loss = optax.softmax_cross_entropy_with_integer_labels(
                logits, batch["labels"]
            ).mean()
            return loss

        loss, grads = jax.value_and_grad(loss_fn)(model.params)
        updates, opt_state = optimizer.update(grads, opt_state)
        model.params = optax.apply_updates(model.params, updates)
        train_losses.append(loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {np.mean(train_losses)}")
    # Evaluation
    eval_losses = []
    for i in range(0, len(eval_dataset), batch_size):
        examples = eval_dataset[i:i + batch_size]
        batch = {
            "input_ids": np.array([ex["input_ids"] for ex in examples]),
            "labels": np.array([ex["labels"] for ex in examples]),
        }
        logits = model.apply(model.params, batch["input_ids"])
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits, batch["labels"]
        ).mean()
        eval_losses.append(loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Eval Loss: {np.mean(eval_losses)}")
    # Save a checkpoint of the model parameters after each epoch
    os.makedirs("./checkpoints", exist_ok=True)
    checkpoint_path = f"./checkpoints/model_epoch_{epoch+1}.ckpt"
    with open(checkpoint_path, 'wb') as f:
        f.write(flax.serialization.to_bytes(model.params))
# Final model saving (replace with appropriate method for NextGenJAX)
# model.save_pretrained("VishwamAI/vishwamai-finetuned")
# tokenizer.save_pretrained("VishwamAI/vishwamai-finetuned")
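# To restore a saved checkpoint later (a sketch, assuming the same model/config
# is re-created first so the parameter structure matches):
# with open(checkpoint_path, "rb") as f:
#     restored_params = flax.serialization.from_bytes(model.params, f.read())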
print("Training completed.")