task registration for fewshot binding #681

Open · wants to merge 1 commit into base: main
118 changes: 105 additions & 13 deletions seqio/dataset_providers.py
@@ -130,7 +130,7 @@ def splits(self) -> Sequence[str]:
   def get_dataset(
       self,
       sequence_length: Optional[Mapping[str, int]] = None,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       use_cached: bool = False,
       shuffle: bool = True,
       seed: Optional[int] = None,
@@ -173,7 +173,7 @@ def add_provider(cls, name: str, provider):
     task_registry_provenance_tracking.maybe_record_provenance(
         frame=inspect.currentframe(),
         name=name,
-        provider_type=provider.__class__.__name__,
+        provider_type=provider.__class__.__name__,  # pylint:disable=attribute-error
     )
 
   @classmethod
@@ -325,7 +325,7 @@ def list_shards(self, split: str) -> Sequence[str]:
   @abc.abstractmethod
   def get_dataset(
       self,  # pytype: disable=signature-mismatch  # overriding-default-value-checks
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       shuffle: bool = True,
       seed: Optional[int] = None,
       shard_info: Optional[ShardInfo] = None,
@@ -432,7 +432,7 @@ def __repr__(self):
 
   def get_dataset(
       self,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       shuffle: bool = True,
       seed: Optional[int] = None,
       shard_info: Optional[ShardInfo] = None,
@@ -550,7 +550,7 @@ def get_dataset(
       num_epochs: Optional[int] = 1,  # Unused
   ) -> tf.data.Dataset:
     if split is None:
-      split = tfds.Split.TRAIN
+      split = tfds.Split.TRAIN  # pylint:disable=attribute-error
     return self.tfds_dataset.load(
         split, shuffle_files=shuffle, seed=seed, shard_info=shard_info
     )
@@ -639,7 +639,7 @@ def __repr__(self):
 
   def get_dataset(
       self,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       shuffle: bool = True,
       seed: Optional[int] = None,
       shard_info: Optional[ShardInfo] = None,
@@ -694,7 +694,7 @@ def list_shards(self, split: str) -> Sequence[str]:
       return _list_files(pattern=filepattern)
 
     if not any(glob.has_magic(f) for f in filepattern):
-      return filepattern
+      return filepattern  # pytype: disable=bad-return-type
     else:
       return _list_files(pattern=filepattern)
 
@@ -1512,7 +1512,7 @@ def assert_cached(self) -> None:
     ), f"'{self.name}' does not exist in any of the task cache directories."
 
   def get_cached_stats(
-      self, split: str = tfds.Split.TRAIN
+      self, split: str = tfds.Split.TRAIN  # pylint:disable=attribute-error
   ) -> Mapping[str, Union[int, float]]:
     """Returns basic statistics for cached dataset."""
     self.assert_cached()
@@ -1526,10 +1526,10 @@ def get_cached_stats(
         self._stats[split] = json.load(f)
     return self._stats[split]
 
-  def get_dataset(
+  def get_dataset(  # pylint: disable=arguments-renamed
       self,  # pytype: disable=signature-mismatch  # overriding-default-value-checks
       sequence_length: Optional[Mapping[str, int]] = None,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       use_cached: bool = False,
       shuffle: bool = True,
       shuffle_buffer_size: Optional[int] = None,  # Unique to Task
@@ -1614,7 +1614,7 @@ def get_dataset(
         )
       else:
         ds = source.get_dataset(split=split, shuffle=shuffle, seed=seed)
-        ds = ds.shard(shard_info.num_shards, shard_info.index)
+        ds = ds.shard(shard_info.num_shards, shard_info.index)  # pylint:disable=attribute-error
 
     num_shards = shard_info.num_shards if shard_info else 1
     if try_in_mem_cache and (
@@ -1915,7 +1915,7 @@ def get_task_dataset(
     task: Task,
     output_feature_keys: Set[str],
     sequence_length: Optional[Mapping[str, int]] = None,
-    split: str = tfds.Split.TRAIN,
+    split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
     use_cached: bool = False,
     shuffle: bool = True,
     seed: Optional[int] = None,
@@ -1947,7 +1947,7 @@ def _get_all_mixing_rates(self, tasks):
   def get_dataset(  # pytype: disable=signature-mismatch  # overriding-parameter-type-checks
       self,
       sequence_length: Optional[Mapping[str, int]] = None,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       use_cached: bool = False,
       shuffle: bool = True,
       seed: Optional[int] = None,
@@ -2115,6 +2115,98 @@ def _get_submixture_rate(self, mix: "Mixture") -> float:
     return float(rate)
 
 
+def get_dataset_iterator_from_tasks(
+    tasks: Union[
+        Sequence[SubtaskOrName], Sequence[Tuple[SubtaskOrName, MixtureRate]]
+    ],
+    sources: Sequence[grain.TfDataSource],
+    proportions: Sequence[float],
+    shard_info: Optional[ShardInfo],
+    seed: Optional[int],
+    num_epochs: Optional[int],
+    strict_transformations: bool,
+    shuffle: bool,
+    batch_size: Optional[int],
+    sequence_length: Optional[Mapping[str, int]],
+    trim_output_features: bool,
+    output_features: Mapping[str, str],
+    feature_converter: FeatureConverter,
+) -> grain.TfGrainDatasetIterator:
+  """Returns a deterministic DatasetIterator for the mixture."""
+  if shard_info is None:
+    shard_options = grain.NoSharding()
+  else:
+    shard_options = grain.ShardOptions(
+        shard_index=shard_info.index, shard_count=shard_info.num_shards
+    )
+
+  if num_epochs and num_epochs != 1:
+    raise ValueError(
+        "Epochs are not supported for mixtures. A mixture "
+        "always repeats indefinitely over its tasks."
+    )
+
+  if sequence_length is not None:
+    # Avoid index being dropped. In case of example packing we even need to
+    # pack it (but it should never be the limiting factor).
+    sequence_length = dict(sequence_length)
+    sequence_length[grain.INDEX] = max(sequence_length.values())
+
+  extra_args = {
+      "sequence_length": sequence_length,
+      "output_features": output_features,
+  }
+  add_kwargs = lambda t: utils.add_kwargs_to_transform(t, **extra_args)
+
+  transformations_per_source = []
+  for task in tasks:
+    transformations_per_source.append(
+        [add_kwargs(t) for t in task.preprocessors]  # pytype: disable=attribute-error
+    )  # pylint: disable=protected-access
+  # Transformations applied after combining all data sources.
+  transformations = [
+      seqio_preprocessors.ReshapeFeatures({grain.INDEX: [-1]}),
+      seqio_preprocessors.DropFeatures(
+          set(grain.META_FEATURES) - {grain.INDEX}
+      ),
+  ]
+  if trim_output_features:
+    transformations.append(seqio_preprocessors._TrimDataset())  # pylint: disable=protected-access
+  if hasattr(feature_converter, "get_grain_transforms"):
+    transformations += feature_converter.get_grain_transforms(
+        batch_size=batch_size, task_feature_lengths=sequence_length
+    )
+  elif strict_transformations:
+    raise NotImplementedError(
+        f"FeatureConverter {feature_converter} does "
+        "not implement get_grain_transforms()."
+    )
+  else:
+    transformations += [
+        functools.partial(
+            feature_converter, task_feature_lengths=sequence_length
+        )
+    ]
+  transformations = [add_kwargs(t) for t in transformations]
+
+  sampler = grain.TfMixtureIndexSampler(
+      [len(s) for s in sources],
+      shard_options=shard_options,
+      proportions=proportions,
+      shuffle=shuffle,
+      seed=seed,
+  )
+  data_loader = grain.TfMixtureDataLoader(
+      sources=sources,
+      sampler=sampler,
+      transformations_per_source=transformations_per_source,
+      transformations=transformations,
+      iterator_options=grain.IteratorOptions(drop_grain_meta_features=True),
+      strict_transformations=strict_transformations,
+  )
+  return iter(data_loader)  # pytype: disable=bad-return-type
+
+
 
 
 def _log_padding_fractions(dataset, sequence_length, num_examples=100):
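For reference, a minimal sketch of how the new get_dataset_iterator_from_tasks helper might be driven, under stated assumptions: the mixture name "my_mixture" and the build_grain_source helper are hypothetical placeholders, and only the keyword arguments come from the signature added in this diff; seqio.get_mixture_or_task, seqio.ShardInfo, seqio.EncDecFeatureConverter, Mixture.tasks, Mixture.get_rate, and Mixture.output_features are existing seqio APIs.

import seqio
from seqio import dataset_providers

# Hypothetical setup: resolve a registered mixture and expose each sub-task
# as a grain.TfDataSource ("my_mixture" and build_grain_source are
# placeholders, not part of this PR).
mixture = seqio.get_mixture_or_task("my_mixture")
tasks = mixture.tasks
sources = [build_grain_source(task) for task in tasks]
proportions = [mixture.get_rate(task) for task in tasks]

iterator = dataset_providers.get_dataset_iterator_from_tasks(
    tasks=tasks,
    sources=sources,
    proportions=proportions,
    shard_info=seqio.ShardInfo(index=0, num_shards=1),
    seed=42,
    num_epochs=None,  # must be None or 1; a mixture repeats indefinitely
    strict_transformations=False,  # fall back to calling feature_converter
    shuffle=True,
    batch_size=32,
    sequence_length={"inputs": 512, "targets": 128},
    trim_output_features=True,
    output_features=mixture.output_features,
    feature_converter=seqio.EncDecFeatureConverter(pack=True),
)
for batch in iterator:  # yields feature-converted, batched examples
  ...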