Setup Fewshot Binding Task and Mixture
PiperOrigin-RevId: 586754237
SeqIO Team authored and SeqIO committed Dec 1, 2023
1 parent 515d917 commit 2aa85f3
Showing 1 changed file with 104 additions and 12 deletions.
seqio/dataset_providers.py
@@ -130,7 +130,7 @@ def splits(self) -> Sequence[str]:
   def get_dataset(
       self,
       sequence_length: Optional[Mapping[str, int]] = None,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       use_cached: bool = False,
       shuffle: bool = True,
       seed: Optional[int] = None,
@@ -173,7 +173,7 @@ def add_provider(cls, name: str, provider):
     task_registry_provenance_tracking.maybe_record_provenance(
         frame=inspect.currentframe(),
         name=name,
-        provider_type=provider.__class__.__name__,
+        provider_type=provider.__class__.__name__,  # pylint:disable=attribute-error
     )
 
   @classmethod
@@ -325,7 +325,7 @@ def list_shards(self, split: str) -> Sequence[str]:
   @abc.abstractmethod
   def get_dataset(
       self,  # pytype: disable=signature-mismatch # overriding-default-value-checks
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       shuffle: bool = True,
       seed: Optional[int] = None,
       shard_info: Optional[ShardInfo] = None,
@@ -432,7 +432,7 @@ def __repr__(self):
 
   def get_dataset(
       self,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       shuffle: bool = True,
       seed: Optional[int] = None,
       shard_info: Optional[ShardInfo] = None,
@@ -550,7 +550,7 @@ def get_dataset(
       num_epochs: Optional[int] = 1,  # Unused
   ) -> tf.data.Dataset:
     if split is None:
-      split = tfds.Split.TRAIN
+      split = tfds.Split.TRAIN  # pylint:disable=attribute-error
     return self.tfds_dataset.load(
         split, shuffle_files=shuffle, seed=seed, shard_info=shard_info
     )
@@ -639,7 +639,7 @@ def __repr__(self):
 
   def get_dataset(
       self,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       shuffle: bool = True,
       seed: Optional[int] = None,
       shard_info: Optional[ShardInfo] = None,
@@ -1055,7 +1055,7 @@ class Task(DatasetProviderBase):
   def __init__(
       self,
       name: str,
-      source: DataSource,
+      source: Optional[DataSource],
       output_features: Mapping[str, Feature],
       preprocessors: Optional[Sequence[Callable[..., tf.data.Dataset]]] = None,
       postprocess_fn: Optional[Callable[..., Any]] = None,
@@ -1512,7 +1512,7 @@ def assert_cached(self) -> None:
     ), f"'{self.name}' does not exist in any of the task cache directories."
 
   def get_cached_stats(
-      self, split: str = tfds.Split.TRAIN
+      self, split: str = tfds.Split.TRAIN  # pylint:disable=attribute-error
   ) -> Mapping[str, Union[int, float]]:
     """Returns basic statistics for cached dataset."""
     self.assert_cached()
@@ -1529,7 +1529,7 @@ def get_cached_stats
   def get_dataset(
       self,  # pytype: disable=signature-mismatch # overriding-default-value-checks
       sequence_length: Optional[Mapping[str, int]] = None,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       use_cached: bool = False,
       shuffle: bool = True,
       shuffle_buffer_size: Optional[int] = None,  # Unique to Task
@@ -1614,7 +1614,7 @@ def get_dataset(
       )
     else:
       ds = source.get_dataset(split=split, shuffle=shuffle, seed=seed)
-      ds = ds.shard(shard_info.num_shards, shard_info.index)
+      ds = ds.shard(shard_info.num_shards, shard_info.index)  # pylint:disable=attribute-error
 
     num_shards = shard_info.num_shards if shard_info else 1
     if try_in_mem_cache and (
@@ -1915,7 +1915,7 @@ def get_task_dataset(
       task: Task,
       output_feature_keys: Set[str],
       sequence_length: Optional[Mapping[str, int]] = None,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       use_cached: bool = False,
       shuffle: bool = True,
       seed: Optional[int] = None,
@@ -1947,7 +1947,7 @@ def _get_all_mixing_rates(self, tasks):
   def get_dataset(  # pytype: disable=signature-mismatch # overriding-parameter-type-checks
       self,
       sequence_length: Optional[Mapping[str, int]] = None,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       use_cached: bool = False,
       shuffle: bool = True,
       seed: Optional[int] = None,
@@ -2115,6 +2115,98 @@ def _get_submixture_rate(self, mix: "Mixture") -> float:
     return float(rate)
 
 
+def get_dataset_iterator_from_tasks(
+    tasks: Union[
+        Sequence[SubtaskOrName], Sequence[Tuple[SubtaskOrName, MixtureRate]]
+    ],
+    sources: Sequence[grain.TfDataSource],
+    proportions: Sequence[float],
+    shard_info: Optional[ShardInfo],
+    seed: Optional[int],
+    num_epochs: Optional[int],
+    strict_transformations: bool,
+    shuffle: bool,
+    batch_size: Optional[int],
+    sequence_length: Optional[Mapping[str, int]],
+    trim_output_features: bool,
+    output_features: Mapping[str, str],
+    feature_converter: FeatureConverter,
+) -> grain.TfGrainDatasetIterator:
+  """Returns a deterministic DatasetIterator for the mixture."""
+  if shard_info is None:
+    shard_options = grain.NoSharding()
+  else:
+    shard_options = grain.ShardOptions(
+        shard_index=shard_info.index, shard_count=shard_info.num_shards
+    )
+
+  if num_epochs and num_epochs != 1:
+    raise ValueError(
+        "Epochs are not supported for mixtures. A mixture "
+        "always repeats indefinitely over its tasks."
+    )
+
+  if sequence_length is not None:
+    # Avoid the index feature being dropped. In the case of example packing we
+    # even need to pack it (but it should never be the limiting factor).
+    sequence_length = dict(sequence_length)
+    sequence_length[grain.INDEX] = max(sequence_length.values())
+
+  extra_args = {
+      "sequence_length": sequence_length,
+      "output_features": output_features,
+  }
+  add_kwargs = lambda t: utils.add_kwargs_to_transform(t, **extra_args)
+
+  transformations_per_source = []
+  for task in tasks:
+    transformations_per_source.append(
+        [add_kwargs(t) for t in task.preprocessors]
+    )  # pylint: disable=protected-access
+  # Transformations applied after combining all data sources.
+  transformations = [
+      seqio_preprocessors.ReshapeFeatures({grain.INDEX: [-1]}),
+      seqio_preprocessors.DropFeatures(
+          set(grain.META_FEATURES) - {grain.INDEX}
+      ),
+  ]
+  if trim_output_features:
+    transformations.append(seqio_preprocessors._TrimDataset())  # pylint: disable=protected-access
+  if hasattr(feature_converter, "get_grain_transforms"):
+    transformations += feature_converter.get_grain_transforms(
+        batch_size=batch_size, task_feature_lengths=sequence_length
+    )
+  elif strict_transformations:
+    raise NotImplementedError(
+        f"FeatureConverter {feature_converter} does "
+        "not implement get_grain_transforms()."
+    )
+  else:
+    transformations += [
+        functools.partial(
+            feature_converter, task_feature_lengths=sequence_length
+        )
+    ]
+  transformations = [add_kwargs(t) for t in transformations]
+
+  sampler = grain.TfMixtureIndexSampler(
+      [len(s) for s in sources],
+      shard_options=shard_options,
+      proportions=proportions,
+      shuffle=shuffle,
+      seed=seed,
+  )
+  data_loader = grain.TfMixtureDataLoader(
+      sources=sources,
+      sampler=sampler,
+      transformations_per_source=transformations_per_source,
+      transformations=transformations,
+      iterator_options=grain.IteratorOptions(drop_grain_meta_features=True),
+      strict_transformations=strict_transformations,
+  )
+  return iter(data_loader)  # pytype: disable=bad-return-type
+
+
 def _log_padding_fractions(dataset, sequence_length, num_examples=100):
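For context, a minimal sketch of how the new get_dataset_iterator_from_tasks helper might be invoked. The commit itself contains no call site, so the task names, the sources list, and the converter below are hypothetical stand-ins; only the keyword arguments mirror the new signature.

# Hypothetical usage sketch; task names and `sources` are illustrative
# stand-ins, not values defined in this commit.
import seqio
from seqio import dataset_providers

tasks = [
    seqio.get_mixture_or_task("fewshot_binding_task_a"),  # hypothetical name
    seqio.get_mixture_or_task("fewshot_binding_task_b"),  # hypothetical name
]
sources = [...]  # one grain.TfDataSource per task, constructed elsewhere

iterator = dataset_providers.get_dataset_iterator_from_tasks(
    tasks=tasks,
    sources=sources,
    proportions=[0.5, 0.5],  # sample both tasks equally
    shard_info=seqio.ShardInfo(index=0, num_shards=1),
    seed=42,
    num_epochs=None,  # mixtures repeat indefinitely; epochs beyond 1 raise
    strict_transformations=False,  # fall back to calling the converter directly
    shuffle=True,
    batch_size=8,
    sequence_length={"inputs": 512, "targets": 128},
    trim_output_features=True,
    output_features=tasks[0].output_features,  # per-feature spec mapping
    feature_converter=seqio.EncDecFeatureConverter(pack=True),
)
first_batch = next(iterator)  # deterministic, shard-aware mixture batches

Note the design choice visible in the diff: sharding is handed to grain.TfMixtureIndexSampler via ShardOptions rather than applied afterwards with tf.data.Dataset.shard, which is what lets the returned iterator stay deterministic across shards.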
