diff --git a/edsnlp/core/stream.py b/edsnlp/core/stream.py
index e9efc7cd3..0b4a57793 100644
--- a/edsnlp/core/stream.py
+++ b/edsnlp/core/stream.py
@@ -25,7 +25,7 @@ from typing_extensions import Literal
 
 import edsnlp.data
-from edsnlp.utils.batching import BatchBy, BatchFn, BatchSizeArg, batchify_fns
+from edsnlp.utils.batching import BatchBy, BatchFn, BatchSizeArg, batchify, batchify_fns
 from edsnlp.utils.collections import flatten, flatten_once, shuffle
 from edsnlp.utils.stream_sentinels import StreamSentinel
 
@@ -47,25 +47,6 @@ def deep_isgeneratorfunction(x):
     raise ValueError(f"{x} does not have a __call__ or batch_process method.")
 
 
-class _InferType:
-    # Singleton is important since the INFER object may be passed to
-    # other processes, i.e. pickled, depickled, while it should
-    # always be the same object.
-    instance = None
-
-    def __repr__(self):
-        return "INFER"
-
-    def __new__(cls, *args, **kwargs):
-        if cls.instance is None:
-            cls.instance = super().__new__(cls)
-        return cls.instance
-
-    def __bool__(self):
-        return False
-
-
-INFER = _InferType()
 CONTEXT = [{}]
 
 T = TypeVar("T")
@@ -125,8 +106,8 @@ def __init__(
     ):
         if batch_fn is None:
             if size is None:
-                size = INFER
-                batch_fn = INFER
+                size = None
+                batch_fn = None
             else:
                 batch_fn = batchify_fns["docs"]
         self.size = size
@@ -287,12 +268,12 @@ def __init__(
         reader: Optional[BaseReader] = None,
         writer: Optional[Union[BaseWriter, BatchWriter]] = None,
         ops: List[Any] = [],
-        config: Dict = {},
+        config: Optional[Dict] = None,
     ):
         self.reader = reader
         self.writer = writer
         self.ops: List[Op] = ops
-        self.config = config
+        self.config = config or {}
 
     @classmethod
     def validate_batching(cls, batch_size, batch_by):
@@ -302,17 +283,12 @@ def validate_batching(cls, batch_size, batch_by):
                 "Cannot use both a batch_size expression and a batch_by function"
             )
         batch_size, batch_by = BatchSizeArg.validate(batch_size)
-        if (
-            batch_size is not None
-            and batch_size is not INFER
-            and not isinstance(batch_size, int)
-        ):
+        if batch_size is not None and not isinstance(batch_size, int):
             raise ValueError(
                 f"Invalid batch_size (must be an integer or None): {batch_size}"
             )
         if (
             batch_by is not None
-            and batch_by is not INFER
             and batch_by not in batchify_fns
             and not callable(batch_by)
         ):
@@ -321,11 +297,11 @@
 
     @property
     def batch_size(self):
-        return self.config.get("batch_size", 1)
+        return self.config.get("batch_size", None)
 
     @property
     def batch_by(self):
-        return self.config.get("batch_by", "docs")
+        return self.config.get("batch_by", None)
 
     @property
     def disable_implicit_parallelism(self):
@@ -372,39 +348,36 @@ def deterministic(self):
     @with_non_default_args
     def set_processing(
         self,
-        batch_size: int = INFER,
-        batch_by: BatchBy = "docs",
-        split_into_batches_after: str = INFER,
-        num_cpu_workers: Optional[int] = INFER,
-        num_gpu_workers: Optional[int] = INFER,
+        batch_size: Optional[Union[int, str]] = None,
+        batch_by: BatchBy = None,
+        split_into_batches_after: str = None,
+        num_cpu_workers: Optional[int] = None,
+        num_gpu_workers: Optional[int] = None,
         disable_implicit_parallelism: bool = True,
-        backend: Optional[Literal["simple", "multiprocessing", "mp", "spark"]] = INFER,
-        autocast: Union[bool, Any] = INFER,
+        backend: Optional[Literal["simple", "multiprocessing", "mp", "spark"]] = None,
+        autocast: Union[bool, Any] = None,
         show_progress: bool = False,
-        gpu_pipe_names: Optional[List[str]] = INFER,
-        process_start_method: Optional[Literal["fork", "spawn"]] = INFER,
-        gpu_worker_devices: Optional[List[str]] = INFER,
-        cpu_worker_devices: Optional[List[str]] = INFER,
+        gpu_pipe_names: Optional[List[str]] = None,
+        process_start_method: Optional[Literal["fork", "spawn"]] = None,
+        gpu_worker_devices: Optional[List[str]] = None,
+        cpu_worker_devices: Optional[List[str]] = None,
         deterministic: bool = True,
         work_unit: Literal["record", "fragment"] = "record",
-        chunk_size: int = INFER,
+        chunk_size: int = None,
         sort_chunks: bool = False,
         _non_default_args: Iterable[str] = (),
     ) -> "Stream":
         """
         Parameters
         ----------
-        batch_size: int
-            Number of documents to process at a time in a GPU worker (or in the
-            main process if no workers are used). This is the global batch size
-            that is used for batching methods that do not provide their own
-            batching arguments.
+        batch_size: Optional[Union[int, str]]
+            The batch size. Can also be a batching expression like
+            "32 docs", "1024 words", "dataset", "fragment", etc.
         batch_by: BatchBy
             Function to compute the batches. If set, it should take an iterable of
             documents and return an iterable of batches. You can also set it to
             "docs", "words" or "padded_words" to use predefined batching functions.
-            Defaults to "docs". Only used for operations that do not provide their
-            own batching arguments.
+            Defaults to "docs".
         num_cpu_workers: int
             Number of CPU workers. A CPU worker handles the non deep-learning
             components and the preprocessing, collating and postprocessing of deep-learning
@@ -468,15 +441,15 @@ def set_processing(
         """
         kwargs = {k: v for k, v in locals().items() if k in _non_default_args}
         if (
-            kwargs.pop("chunk_size", INFER) is not INFER
-            or kwargs.pop("sort_chunks", INFER) is not INFER
+            kwargs.pop("chunk_size", None) is not None
+            or kwargs.pop("sort_chunks", None) is not None
         ):
             warnings.warn(
                 "chunk_size and sort_chunks are deprecated, use "
                 "map_batched(sort_fn, batch_size=chunk_size) instead.",
                 VisibleDeprecationWarning,
             )
-        if kwargs.pop("split_into_batches_after", INFER) is not INFER:
+        if kwargs.pop("split_into_batches_after", None) is not None:
             warnings.warn(
                 "split_into_batches_after is deprecated.", VisibleDeprecationWarning
             )
@@ -486,7 +459,7 @@ def set_processing(
             ops=self.ops,
             config={
                 **self.config,
-                **{k: v for k, v in kwargs.items() if v is not INFER},
+                **{k: v for k, v in kwargs.items() if v is not None},
             },
         )
 
@@ -690,8 +663,8 @@ def map_gpu(
     def map_pipeline(
         self,
         model: Pipeline,
-        batch_size: Optional[int] = INFER,
-        batch_by: BatchBy = INFER,
+        batch_size: Optional[Union[int, str]] = None,
+        batch_by: BatchBy = None,
     ) -> "Stream":
         """
         Maps a pipeline to the documents, i.e. adds each component of the pipeline to
@@ -974,16 +947,10 @@ def __getattr__(self, item):
     def _make_stages(self, split_torch_pipes: bool) -> List[Stage]:
         current_ops = []
         stages = []
-        self_batch_fn = batchify_fns.get(self.batch_by, self.batch_by)
-        self_batch_size = self.batch_size
-        assert self_batch_size is not None
 
         ops = [copy(op) for op in self.ops]
 
         for op in ops:
-            if isinstance(op, BatchifyOp):
-                op.batch_fn = self_batch_fn if op.batch_fn is INFER else op.batch_fn
-                op.size = self_batch_size if op.size is INFER else op.size
             if (
                 isinstance(op, MapBatchesOp)
                 and hasattr(op.pipe, "forward")
@@ -1005,15 +972,31 @@ def validate_ops(self, ops, update: bool = False):
         # Check batchify requirements
         requires_sentinels = set()
 
+        self_batch_size, self_batch_by = self.validate_batching(
+            self.batch_size, self.batch_by
+        )
+        if self_batch_by is None:
+            self_batch_by = "docs"
+        if self_batch_size is None:
+            self_batch_size = 1
+        self_batch_fn = batchify_fns.get(self_batch_by, self_batch_by)
+
         if hasattr(self.writer, "batch_fn") and hasattr(
             self.writer.batch_fn, "requires_sentinel"
         ):
             requires_sentinels.add(self.writer.batch_fn.requires_sentinel)
 
-        self_batch_fn = batchify_fns.get(self.batch_by, self.batch_by)
         for op in reversed(ops):
             if isinstance(op, BatchifyOp):
-                batch_fn = op.batch_fn or self_batch_fn
+                if op.batch_fn is None and op.size is None:
+                    batch_size = self_batch_size
+                    batch_fn = self_batch_fn
+                elif op.batch_fn is None:
+                    batch_size = op.size
+                    batch_fn = batchify
+                else:
+                    batch_size = op.size
+                    batch_fn = op.batch_fn
                 sentinel_mode = op.sentinel_mode or (
                     "auto"
                     if "sentinel_mode" in signature(batch_fn).parameters
@@ -1021,7 +1004,7 @@ def validate_ops(self, ops, update: bool = False):
                 )
                 if sentinel_mode == "auto":
                     sentinel_mode = "split" if requires_sentinels else "drop"
-                if requires_sentinels and op.sentinel_mode == "drop":
+                if requires_sentinels and sentinel_mode == "drop":
                     raise ValueError(
                         f"Operation {op} drops the stream sentinel values "
                         f"(markers for the end of a dataset or a dataset "
@@ -1031,10 +1014,12 @@ def validate_ops(self, ops, update: bool = False):
                         f"any upstream batching operation."
                    )
                 if update:
+                    op.size = batch_size
+                    op.batch_fn = batch_fn
                     op.sentinel_mode = sentinel_mode
 
-                if hasattr(batch_fn, "requires_sentinel"):
-                    requires_sentinels.add(batch_fn.requires_sentinel)
+                if hasattr(op.batch_fn, "requires_sentinel"):
+                    requires_sentinels.add(op.batch_fn.requires_sentinel)
 
         sentinel_str = ", ".join(requires_sentinels)
         if requires_sentinels and self.backend == "spark":
diff --git a/tests/data/test_stream.py b/tests/data/test_stream.py
index 141a04d18..b1026d486 100644
--- a/tests/data/test_stream.py
+++ b/tests/data/test_stream.py
@@ -58,19 +58,22 @@ def forward(batch):
     assert set(res.tolist()) == {i * 2 for i in range(15)}
 
 
+# fmt: off
 @pytest.mark.parametrize(
-    "sort,num_cpu_workers,batch_by,expected",
+    "sort,num_cpu_workers,batch_kwargs,expected",
     [
-        (False, 1, "words", [3, 1, 3, 1, 3, 1]),
-        (False, 1, "padded_words", [2, 1, 1, 2, 1, 1, 2, 1, 1]),
-        (False, 1, "docs", [10, 2]),
-        (False, 2, "words", [2, 1, 2, 1, 2, 1, 1, 1, 1]),
-        (False, 2, "padded_words", [2, 1, 2, 1, 2, 1, 1, 1, 1]),
-        (False, 2, "docs", [6, 6]),
-        (True, 2, "padded_words", [3, 3, 2, 1, 1, 1, 1]),
+        (False, 1, {"batch_size": 10, "batch_by": "words"}, [3, 1, 3, 1, 3, 1]),  # noqa: E501
+        (False, 1, {"batch_size": 10, "batch_by": "padded_words"}, [2, 1, 1, 2, 1, 1, 2, 1, 1]),  # noqa: E501
+        (False, 1, {"batch_size": 10, "batch_by": "docs"}, [10, 2]),  # noqa: E501
+        (False, 2, {"batch_size": 10, "batch_by": "words"}, [2, 1, 2, 1, 2, 1, 1, 1, 1]),  # noqa: E501
+        (False, 2, {"batch_size": 10, "batch_by": "padded_words"}, [2, 1, 2, 1, 2, 1, 1, 1, 1]),  # noqa: E501
+        (False, 2, {"batch_size": 10, "batch_by": "docs"}, [6, 6]),  # noqa: E501
+        (True, 2, {"batch_size": 10, "batch_by": "padded_words"}, [3, 3, 2, 1, 1, 1, 1]),  # noqa: E501
+        (False, 2, {"batch_size": "10 words"}, [2, 1, 2, 1, 2, 1, 1, 1, 1]),  # noqa: E501
     ],
 )
-def test_map_with_batching(sort, num_cpu_workers, batch_by, expected):
+# fmt: on
+def test_map_with_batching(sort, num_cpu_workers, batch_kwargs, expected):
     nlp = edsnlp.blank("eds")
     nlp.add_pipe(
         "eds.matcher",
@@ -94,8 +97,7 @@ def test_map_with_batching(sort, num_cpu_workers, batch_kwargs, expected):
     stream = stream.map_batches(len)
     stream = stream.set_processing(
         num_cpu_workers=num_cpu_workers,
-        batch_size=10,
-        batch_by=batch_by,
+        **batch_kwargs,
         chunk_size=1000,  # deprecated
         split_into_batches_after="matcher",
         show_progress=True,
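
For reviewers, below is a minimal usage sketch (not part of the diff) of the batching configuration this change enables, based on the updated set_processing signature and the new "10 words" test case. The eds.sentences pipe and the input texts are illustrative assumptions, not taken from this PR.

import edsnlp

# Illustrative pipeline and inputs (assumptions, not taken from this PR).
nlp = edsnlp.blank("eds")
nlp.add_pipe("eds.sentences")

texts = ["This is a short note.", "This one is a slightly longer document."]

# Stream-level batching with an integer size and an explicit policy ...
stream_a = edsnlp.data.from_iterable(texts).map_pipeline(nlp)
stream_a = stream_a.set_processing(num_cpu_workers=2, batch_size=10, batch_by="words")

# ... or, equivalently after this change, a single batching expression string,
# which validate_batching() resolves to the same (size, policy) pair.
stream_b = edsnlp.data.from_iterable(texts).map_pipeline(nlp)
stream_b = stream_b.set_processing(num_cpu_workers=2, batch_size="10 words")

docs_a, docs_b = list(stream_a), list(stream_b)

When no batching arguments are given anywhere, validate_ops now falls back to the previous defaults (batch_by="docs", batch_size=1) instead of relying on the removed INFER sentinel, and per-operation batching arguments still take precedence over the stream-level configuration.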