diff --git a/flair/data.py b/flair/data.py
index 7ee32f40b9..6fd41e759c 100644
--- a/flair/data.py
+++ b/flair/data.py
@@ -1247,6 +1247,34 @@ def remove_labels(self, typename: str):
         # delete labels at object itself
         super().remove_labels(typename)
 
+    def _get_token_level_label_of_each_token(self, label_type: str) -> List[str]:
+        """Generates a label for each token in the sentence. Requires that the labels of label_type are token-level labels.
+
+        Args:
+            label_type: a string identifying the type of the labels, e.g., "pos"
+        """
+        list_of_labels = ["O" for _ in range(len(self.tokens))]
+        for label in self.get_labels(label_type):
+            label_token_index = label.data_point._internal_index
+            list_of_labels[label_token_index - 1] = label.value
+        return list_of_labels
+
+    def _get_span_level_label_of_each_token(self, label_type: str) -> List[str]:
+        """Generates a label in BIO format for each token in the sentence. Requires that the labels of label_type are span-level labels.
+
+        Args:
+            label_type: a string identifying the type of the labels, e.g., "ner"
+        """
+        list_of_labels = ["O" for _ in range(len(self.tokens))]
+        for label in self.get_labels(label_type):
+            tokens = label.data_point.tokens
+            start_token_index = tokens[0]._internal_index
+            list_of_labels[start_token_index - 1] = f"B-{label.value}"
+            for token in tokens[1:]:
+                token_index = token._internal_index
+                list_of_labels[token_index - 1] = f"I-{label.value}"
+        return list_of_labels
+
 
 class DataPair(DataPoint, typing.Generic[DT, DT2]):
     def __init__(self, first: DT, second: DT2) -> None:
diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py
index b2ab2f45dd..bf5aa4e9e9 100644
--- a/flair/datasets/sequence_labeling.py
+++ b/flair/datasets/sequence_labeling.py
@@ -11,15 +11,10 @@
 from collections import defaultdict
 from collections.abc import Iterable, Iterator
 from pathlib import Path
-from typing import (
-    Any,
-    Optional,
-    Union,
-    cast,
-)
+from typing import Any, List, Optional, Type, Union, cast
 
 import requests
-from torch.utils.data import ConcatDataset, Dataset
+from torch.utils.data import ConcatDataset, Dataset, Subset
 
 import flair
 from flair.data import (
@@ -28,7 +23,9 @@
     MultiCorpus,
     Relation,
     Sentence,
+    Span,
     Token,
+    _iter_dataset,
     get_spans_from_bio,
 )
 from flair.datasets.base import find_train_dev_test_files
@@ -463,6 +460,150 @@ def __init__(
             **corpusargs,
         )
 
+    @staticmethod
+    def _get_level_of_label(dataset: Optional[Dataset], label_type: str) -> Optional[Union[Type[Token], Type[Span]]]:
+        """Gets the level of a label type by checking the first label of that type in the dataset.
+
+        Raises:
+            NotImplementedError: if the level of label_type is neither Token nor Span
+        """
+        for sentence in _iter_dataset(dataset):
+            for label in sentence.get_labels(label_type):
+                if isinstance(label.data_point, Token):
+                    return Token
+                elif isinstance(label.data_point, Span):
+                    return Span
+                else:
+                    raise NotImplementedError(
+                        f"The level of {label_type} is neither token nor span. Only token-level and span-level labels are currently supported."
+                    )
+        log.warning(f"There is no label of type {label_type} in this dataset.")
+        return None
+
+    @staticmethod
+    def _write_dataset_to_file(
+        dataset: Optional[Dataset], label_types: List[str], file_path: Path, column_delimiter: str = "\t"
+    ) -> None:
+        """Writes a dataset to a file.
+
+        It follows two rules:
+        (1) the text and the label(s) of each token are written on one line, separated by column_delimiter
+        (2) each sentence is separated from the previous one by an empty line
+
+        Note:
+            Only labels corresponding to label_types will be written.
+            Only token-level and span-level sequence tagging labels are supported.
+            Currently, the whitespace_after attribute of each token is not preserved in the written file.
+
+        Args:
+            dataset: a dataset to write
+            label_types: a list of label types to write, e.g., ["ner", "pos"]
+            file_path: a path to store the file
+            column_delimiter: a string separating token texts and labels within a line; the default is a tab
+        """
+        if dataset:
+            label_type_tuples = []
+            for label_type in label_types:
+                level_of_label = ColumnCorpus._get_level_of_label(dataset, label_type)
+                label_type_tuples.append((label_type, level_of_label))
+
+            with open(file_path, mode="w") as output_file:
+                for sentence in _iter_dataset(dataset):
+                    texts = [token.text for token in sentence.tokens]
+                    texts_and_labels = [texts]
+                    for label_type, level in label_type_tuples:
+                        if level is None:
+                            texts_and_labels.append(["O" for _ in range(len(sentence))])
+                        elif level is Token:
+                            texts_and_labels.append(sentence._get_token_level_label_of_each_token(label_type))
+                        elif level is Span:
+                            texts_and_labels.append(sentence._get_span_level_label_of_each_token(label_type))
+                        else:
+                            raise NotImplementedError(f"The level of {label_type} is neither token nor span.")
+
+                    for text_and_labels_of_a_token in zip(*texts_and_labels):
+                        output_file.write(column_delimiter.join(text_and_labels_of_a_token) + "\n")
+                    output_file.write("\n")
+        else:
+            log.warning("dataset is None, did not write any file.")
+
+    @classmethod
+    def load_corpus_with_meta_data(cls, directory: Path) -> "ColumnCorpus":
+        """Creates a ColumnCorpus instance from a directory generated by 'write_to_directory'."""
+        with open(directory / "meta_data.json") as file:
+            meta_data = json.load(file)
+
+        # json stores dictionary keys as strings, so the integer column indices must be restored
+        meta_data["column_format"] = {int(key): value for key, value in meta_data["column_format"].items()}
+
+        return cls(
+            data_folder=directory,
+            autofind_splits=True,
+            skip_first_line=False,
+            **meta_data,
+        )
+
+    def _write_corpus_meta_data(
+        self, label_types: List[str], file_path: Path, column_delimiter: str, max_depth=5
+    ) -> None:
+        """Writes the meta data of this corpus to a json file.
+
+        Note:
+            Currently, the whitespace_after attribute of each token is not preserved. Only the default_whitespace_after attribute of each dataset is written to the file.
+        """
+        meta_data = {
+            "name": self.name,
+            "sample_missing_splits": False,
+            "column_delimiter": column_delimiter,
+        }
+
+        column_format = {0: "text"}
+        for label_type_index, label_type in enumerate(label_types):
+            column_format[label_type_index + 1] = label_type
+        meta_data["column_format"] = column_format
+
+        nonempty_dataset = self.train or self.dev or self.test
+        # Sometimes nonempty_dataset is a ConcatDataset or Subset, so we unwrap it to the original ColumnDataset
+        # to access the encoding, in_memory, banned_sentences and default_whitespace_after attributes
+        for _ in range(max_depth):
+            if type(nonempty_dataset) is ColumnDataset:
+                break
+            elif type(nonempty_dataset) is ConcatDataset:
+                nonempty_dataset = nonempty_dataset.datasets[0]
+            elif type(nonempty_dataset) is Subset:
+                nonempty_dataset = nonempty_dataset.dataset
+            else:
+                raise NotImplementedError(f"Unsupported dataset type: {type(nonempty_dataset)}")
+
+        if type(nonempty_dataset) is not ColumnDataset:
+            raise NotImplementedError(f"Could not unwrap a ColumnDataset within {max_depth} levels.")
+
+        meta_data["encoding"] = nonempty_dataset.encoding
+        meta_data["in_memory"] = nonempty_dataset.in_memory
+        meta_data["banned_sentences"] = nonempty_dataset.banned_sentences
+        meta_data["default_whitespace_after"] = nonempty_dataset.default_whitespace_after
+
+        with open(file_path, mode="w") as output_file:
+            json.dump(meta_data, output_file)
+
+    def write_to_directory(self, label_types: List[str], output_directory: Path, column_delimiter: str = "\t") -> None:
+        """Writes the train, dev and test datasets (if they exist) and the meta data of the corpus to a directory.
+
+        Note:
+            Only labels corresponding to label_types will be written.
+            Only token-level and span-level sequence tagging labels are supported.
+            Currently, the whitespace_after attribute of each token is not preserved in the written files.
+
+        Args:
+            label_types: a list of label types to write, e.g., ["ner", "pos"]
+            output_directory: a directory to store the files
+            column_delimiter: a string separating token texts and labels within a line; the default is a tab
+        """
+        os.makedirs(output_directory, exist_ok=True)
+        for dataset, file_name in [(self.train, "train.conll"), (self.dev, "dev.conll"), (self.test, "test.conll")]:
+            ColumnCorpus._write_dataset_to_file(dataset, label_types, output_directory / file_name, column_delimiter)
+        self._write_corpus_meta_data(label_types, output_directory / "meta_data.json", column_delimiter)
+
 
 class ColumnDataset(FlairDataset):
     # special key for space after
@@ -817,6 +958,26 @@ def _remap_label(self, tag):
             tag = self.label_name_map[tag]  # for example, transforming 'PER' to 'person'
         return tag
 
+    def write_dataset_to_file(self, label_types: List[str], file_path: Path, column_delimiter: str = "\t") -> None:
+        """Writes this dataset to a file.
+
+        It follows two rules:
+        (1) the text and the label(s) of each token are written on one line, separated by column_delimiter
+        (2) each sentence is separated from the previous one by an empty line
+
+        Note:
+            Only labels corresponding to label_types will be written.
+            Only token-level and span-level sequence tagging labels are supported.
+            Currently, the whitespace_after attribute of each token is not preserved in the written file.
+
+        Args:
+            label_types: a list of label types to write, e.g., ["ner", "pos"]
+            file_path: a path to store the file
+            column_delimiter: a string separating token texts and labels within a line; the default is a tab
+        """
+        file_path.parent.mkdir(exist_ok=True, parents=True)
+        ColumnCorpus._write_dataset_to_file(self, label_types, file_path, column_delimiter)
+
     def __line_completes_sentence(self, line: str) -> bool:
         sentence_completed = line.isspace() or line == ""
         return sentence_completed
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 25a99f87e0..2c76b40678 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -418,6 +418,22 @@ def test_load_universal_dependencies_conllu_corpus(tasks_base_path):
     _assert_universal_dependencies_conllu_dataset(corpus.train)
 
 
+def test_write_to_and_load_from_directory(tasks_base_path, tmp_path):
+    corpus = ColumnCorpus(
+        tasks_base_path / "column_with_whitespaces",
+        train_file="eng.train",
+        column_format={0: "text", 1: "ner"},
+        column_delimiter=" ",
+        skip_first_line=False,
+        sample_missing_splits=False,
+    )
+    # write the corpus to a temporary directory and restore it from there
+    corpus.write_to_directory(["ner"], tmp_path, column_delimiter="\t")
+    loaded_corpus = ColumnCorpus.load_corpus_with_meta_data(tmp_path)
+    assert len(loaded_corpus.train) == len(corpus.train)
+    assert loaded_corpus.train[0].to_tagged_string() == corpus.train[0].to_tagged_string()
+
+
 @pytest.mark.skip()
 def test_hipe_2022_corpus(tasks_base_path):
     # This test covers the complete HIPE 2022 dataset.
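
For illustration, a minimal sketch of what the two new private Sentence helpers return. The sentence and its labels are invented for this example, and the helpers only exist on this branch:

    from flair.data import Sentence

    sentence = Sentence("George Washington went to Washington")
    sentence[0:2].add_label("ner", "PER")  # span-level label over tokens 1-2
    sentence[2].add_label("pos", "VBD")    # token-level label on token 3

    # span-level labels are expanded to BIO tags, one per token:
    # ["B-PER", "I-PER", "O", "O", "O"]
    print(sentence._get_span_level_label_of_each_token("ner"))

    # token-level labels are placed at their token's position, "O" elsewhere:
    # ["O", "O", "VBD", "O", "O"]
    print(sentence._get_token_level_label_of_each_token("pos"))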
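
Usage sketch for the new round-trip API added by this diff. The data folder, column format and label type below are illustrative assumptions, not values from this PR:

    from pathlib import Path

    from flair.datasets import ColumnCorpus

    # hypothetical CoNLL-style corpus with a single "ner" column
    corpus = ColumnCorpus(
        Path("data/my_corpus"),  # assumed data folder
        column_format={0: "text", 1: "ner"},
    )

    # writes train.conll / dev.conll / test.conll (for the splits that exist) plus meta_data.json;
    # each line of the written files looks like: "Washington<TAB>B-LOC"
    out_dir = Path("exported_corpus")
    corpus.write_to_directory(["ner"], out_dir, column_delimiter="\t")

    # restores an equivalent corpus, reading column format and other settings from meta_data.json
    restored = ColumnCorpus.load_corpus_with_meta_data(out_dir)
    assert len(restored.train) == len(corpus.train)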