# setfit_datamodule.py
# Generated from ashleve/lightning-hydra-template.
import os
from typing import Dict, Optional, Tuple

import pandas as pd
from datasets import Dataset, load_dataset
from pytorch_lightning import LightningDataModule
from setfit import sample_dataset
from setfit.data import SetFitDataset
from torch.utils.data import DataLoader

class DataModule(LightningDataModule):
    def __init__(
        self,
        dataset_id: str = "sst2",
        max_input_length: Optional[int] = None,
        num_samples: int = 16,
        batch_size: int = 16,
        pin_memory: bool = True,
        num_workers: int = os.cpu_count(),
        column_mapping: Optional[Dict] = None,
    ):
        super().__init__()
        # this line allows access to the init params via the 'self.hparams' attribute
        self.save_hyperparameters(logger=False)

        dataset = load_dataset(self.hparams.dataset_id)
        if self.hparams.column_mapping:
            for key in dataset:
                dataset[key] = self._apply_column_mapping(
                    dataset[key], self.hparams.column_mapping
                )

        # few-shot sampling: draw `num_samples` examples per class from the train split
        self.train_dataset = sample_dataset(
            dataset["train"], label_column="label", num_samples=self.hparams.num_samples
        )
        # the validation and test sets are both carved out of the original validation split
        self.valid_dataset, self.test_dataset = self._separate_dataset(dataset["validation"])

        # silence the tokenizers fork warning
        # (https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning)
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
    def _apply_column_mapping(
        self, dataset: "Dataset", column_mapping: Dict[str, str]
    ) -> "Dataset":
        """Applies the provided column mapping to the dataset, renaming columns accordingly.

        Extra features not in the column mapping are prefixed with `"feat_"`.
        """
        dataset = dataset.rename_columns(
            {
                **column_mapping,
                **{
                    col: f"feat_{col}"
                    for col in dataset.column_names
                    if col not in column_mapping
                },
            }
        )
        dset_format = dataset.format
        dataset = dataset.with_format(
            type=dset_format["type"],
            columns=dataset.column_names,
            output_all_columns=dset_format["output_all_columns"],
            **dset_format["format_kwargs"],
        )
        return dataset
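
    # For example (illustrative): on sst2, whose splits have the columns
    # ("idx", "sentence", "label"), passing
    # column_mapping={"sentence": "text", "label": "label"} produces the columns
    # ("feat_idx", "text", "label"), which is the layout `setup` expects below.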
    def _separate_dataset(self, dataset: "Dataset") -> Tuple["Dataset", "Dataset"]:
        """Splits a dataset in half: the first half becomes validation, the second test."""
        df = pd.DataFrame(dataset)
        val_df = df.iloc[: len(df) // 2].reset_index(drop=True)
        test_df = df.iloc[len(df) // 2 :].reset_index(drop=True)
        return Dataset.from_pandas(val_df), Dataset.from_pandas(test_df)
    def prepare_data(self):
        # the dataset is already downloaded and split in __init__
        pass

    def setup(self, stage: Optional[str] = None):
        # fall back to the sentence-transformer body's max sequence length
        # when no explicit max_input_length is given
        max_length = (
            self.hparams.max_input_length
            if self.hparams.max_input_length
            else self.trainer.model.model_body.get_max_seq_length()
        )
        tokenizer = self.trainer.model.model_body.tokenizer

        if stage == "fit" or stage is None:
            self.setfit_train_dataset: SetFitDataset = SetFitDataset(
                self.train_dataset["text"],
                self.train_dataset["label"],
                tokenizer=tokenizer,
                max_length=max_length,
            )
            self.setfit_valid_dataset: SetFitDataset = SetFitDataset(
                self.valid_dataset["text"],
                self.valid_dataset["label"],
                tokenizer=tokenizer,
                max_length=max_length,
            )
        if stage == "test" or stage is None:
            # the test set is carved out of the validation split, since the
            # sst2 test split is unlabeled
            self.setfit_test_dataset: SetFitDataset = SetFitDataset(
                self.test_dataset["text"],
                self.test_dataset["label"],
                tokenizer=tokenizer,
                max_length=max_length,
            )
    def train_dataloader(self):
        return DataLoader(
            dataset=self.setfit_train_dataset,
            batch_size=self.hparams.batch_size,
            collate_fn=SetFitDataset.collate_fn,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.setfit_valid_dataset,
            batch_size=self.hparams.batch_size,
            collate_fn=SetFitDataset.collate_fn,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.setfit_test_dataset,
            batch_size=self.hparams.batch_size,
            collate_fn=SetFitDataset.collate_fn,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )
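

# Minimal usage sketch (illustrative; assumes the datasets/setfit packages are
# installed and the Hugging Face Hub is reachable). `setup` pulls the tokenizer
# and max sequence length from `self.trainer.model.model_body`, so the full
# pipeline needs a Trainer whose LightningModule wraps a SetFit model; here we
# only smoke-test the few-shot sampling and the validation/test split.
if __name__ == "__main__":
    dm = DataModule(
        dataset_id="sst2",
        num_samples=8,
        # sst2 stores its text under "sentence"; remap it (and keep "label")
        # so that `setup` can read dataset["text"] / dataset["label"].
        column_mapping={"sentence": "text", "label": "label"},
    )
    # 8 samples per class -> 16 training examples for binary sst2; the
    # original validation split is halved into validation and test sets.
    print(len(dm.train_dataset), len(dm.valid_dataset), len(dm.test_dataset))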