Skip to content

Commit

Permalink
removed dedicated multilingual dataset class
Browse files Browse the repository at this point in the history
  • Loading branch information
lrei committed Oct 8, 2023
1 parent 5f14b06 commit cc8b311
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 50 deletions.
45 changes: 1 addition & 44 deletions mbdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,47 +447,4 @@ def print_featurize(self):

def get_label_encoder(self) -> MultiLabelBinarizer:
    """Return the fitted label encoder (MultiLabelBinarizer) used to
    transform example label lists into multi-hot vectors."""
    return self.mlb


class MLDataset(MLDatasetWithFloats):
    """Multilabel dataset.

    Batch-tokenizes example texts with a HuggingFace tokenizer and attaches
    a multi-hot float label matrix (suitable as BCE-style training targets)
    produced by the fitted ``MultiLabelBinarizer``.
    """

    id2label: Dict[int, str]        # label index -> label name
    label2id: Dict[str, int]        # label name -> label index
    num_labels: int                 # number of distinct labels
    mlb: MultiLabelBinarizer        # fitted multi-hot label encoder
    label_names: List[str]          # ordered label vocabulary

    def _featurize_one(self, ex: MBExample):
        """Per-example featurization is not supported; use ``_featurize``.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Not implemented")

    def _featurize(self, exs: List[MBExample]):
        """Tokenize a batch of examples and attach float multi-hot labels.

        Args:
            exs: examples providing ``.text`` and ``.labels``.

        Returns:
            Tokenizer output (``input_ids``, masks, ...) as PyTorch tensors,
            with an added ``"labels"`` FloatTensor of shape
            (len(exs), num_labels).
        """
        # Tokenize all texts in one call; pad/truncate to max_seq_length.
        features = self.tokenizer(
            text=[ex.text for ex in exs],
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=self.max_seq_length,
        )

        # Multi-hot encode the label lists as a float tensor.
        features["labels"] = torch.FloatTensor(
            self.mlb.transform([ex.labels for ex in exs])
        )

        return features

    def __len__(self):
        """Length of dataset corresponds to the number of examples."""
        return len(self.examples)

    def __getitem__(self, i):
        """Return the i-th example's features as a dict of tensors."""
        return {k: self.features[k][i] for k in self.features.keys()}  # type: ignore

    def get_label_list(self) -> Union[List[List[str]], List[str]]:
        """Not supported for this dataset variant.

        Raises:
            NotImplementedError: always.
        """
        # Fix: removed the unreachable `return self.mlb` that followed the
        # raise (dead code, and its type contradicted the annotation anyway).
        raise NotImplementedError("Not implemented")
8 changes: 4 additions & 4 deletions multilingual_train_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
DataCollatorWithPadding,
)
from processors import MultiLabelTSVProcessor
from mbdataset import MLDataset
from mbdataset import MLDatasetWithFloats
from mutils import compute_metrics

TRN_SRC = "/data/exp/emotions_translation/emolit.tsv"
Expand Down Expand Up @@ -64,11 +64,11 @@

# load train dataset
proc_trn = MultiLabelTSVProcessor(TRN_SRC)
ds_trn = MLDataset(proc_trn, tkz, max_seq_length=SEQLEN)
ds_trn = MLDatasetWithFloats(proc_trn, tkz, max_seq_length=SEQLEN)

# load eval dataset
proc_gold = MultiLabelTSVProcessor(GOLD_SRC)
ds_gold = MLDataset(proc_gold, tkz, max_seq_length=SEQLEN, le=ds_trn.get_label_encoder())
ds_gold = MLDatasetWithFloats(proc_gold, tkz, max_seq_length=SEQLEN, le=ds_trn.get_label_encoder())

logger.info(f"Train: {len(ds_trn)} examples")
logger.info(f"Gold: {len(ds_gold)} examples")
Expand Down Expand Up @@ -110,7 +110,7 @@
# create a processor for this language
proc_lang = MultiLabelTSVProcessor(GOLD_SRC, lang=lang)
# create a dataset for this language
ds_lang = MLDataset(proc_lang, tkz, max_seq_length=SEQLEN, le=ds_trn.get_label_encoder())
ds_lang = MLDatasetWithFloats(proc_lang, tkz, max_seq_length=SEQLEN, le=ds_trn.get_label_encoder())
# evaluate
res = trainer.evaluate(ds_lang)
print(res)
Expand Down
8 changes: 6 additions & 2 deletions processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
split: Optional[str] = None,
split_col: Optional[str] = None,
random_state: int = 1984,
print_counts: bool = False,
):
"""See class."""
super(MultiLabelTSVProcessor, self).__init__()
Expand All @@ -91,6 +92,8 @@ def __init__(
self.split_col = split_col
self.split = split

self.print_counts = print_counts

self.not_labels = [self.text_col, self.lang_col, self.split_col]
self.not_labels = [c for c in self.not_labels if c is not None]
self.not_labels += ["tid", "id", "split", "og_split", "text", "lang"]
Expand Down Expand Up @@ -179,8 +182,9 @@ def _read_tsv(self) -> List[MBExample]:
examples.append(ex)

logger.info(f"Read: {len(examples)} examples")
for k in sorted(counter.keys()):
logger.info(f"\t{k}:\t{counter[k]}")
if self.print_counts:
for k in sorted(counter.keys()):
logger.info(f"\t{k}:\t{counter[k]}")
return examples

def get_examples(self) -> List[MBExample]:
Expand Down

0 comments on commit cc8b311

Please sign in to comment.