diff --git a/singer_sdk/batch.py b/singer_sdk/batch.py index f329465c4..6a0fee38e 100644 --- a/singer_sdk/batch.py +++ b/singer_sdk/batch.py @@ -12,9 +12,10 @@ from singer_sdk.helpers._batch import BatchConfig _T = t.TypeVar("_T") +_B = t.TypeVar("_B", bound="BaseBatcher") -def __getattr__(name: str) -> t.Any: # noqa: ANN401 +def __getattr__(name: str) -> t.Any: # noqa: ANN401 # pragma: no cover if name == "JSONLinesBatcher": warnings.warn( "The class JSONLinesBatcher was moved to singer_sdk.contrib.batch_encoder_jsonl.", # noqa: E501 @@ -98,23 +99,35 @@ def get_batches(self, records: t.Iterator[dict]) -> t.Iterator[list[str]]: Returns: A list of file paths (called a manifest). + """ + encoding_format = self.batch_config.encoding.format + batcher_type: type[BaseBatcher] = self.get_batcher(encoding_format) + batcher = batcher_type( + self.tap_name, + self.stream_name, + self.batch_config, + ) + return batcher.get_batches(records) + + @classmethod + def get_batcher(cls, name: str) -> type[_B]: + """Get a batcher by name. + + Args: + name: The name of the batcher. + + Returns: + The batcher class. Raises: - ValueError: If unsupported format given. + ValueError: If the batcher is not found. """ - encoding_format = self.batch_config.encoding.format plugins = entry_points(group="singer_sdk.batch_encoders") try: - plugin = next(filter(lambda x: x.name == encoding_format, plugins)) + plugin = next(filter(lambda x: x.name == name, plugins)) except StopIteration: - message = f"Unsupported batch format: {encoding_format}" + message = f"Unsupported batcher: {name}" raise ValueError(message) from None - batcher_type: type[Batcher] = plugin.load() - batcher = batcher_type( - self.tap_name, - self.stream_name, - self.batch_config, - ) - return batcher.get_batches(records) + return plugin.load() diff --git a/singer_sdk/sinks/core.py b/singer_sdk/sinks/core.py index f1eb60341..f3a1975cf 100644 --- a/singer_sdk/sinks/core.py +++ b/singer_sdk/sinks/core.py @@ -526,9 +526,6 @@ def process_batch_files( Raises: NotImplementedError: If the batch file encoding is not supported. """ - spec = importlib.util.find_spec("pyarrow") - if spec: - import pyarrow.parquet as pq file: GzipFile | t.IO storage: StorageTarget | None = None @@ -550,7 +547,12 @@ def process_batch_files( ) context = {"records": [json.loads(line) for line in context_file]} # type: ignore[attr-defined] self.process_batch(context) - elif spec and encoding.format == BatchFileFormat.PARQUET: + elif ( + importlib.util.find_spec("pyarrow") + and encoding.format == BatchFileFormat.PARQUET + ): + import pyarrow.parquet as pq + with storage.fs(create=False) as batch_fs, batch_fs.open( tail, mode="rb", diff --git a/tests/core/test_batch.py b/tests/core/test_batch.py index 9fd911995..c2076e37b 100644 --- a/tests/core/test_batch.py +++ b/tests/core/test_batch.py @@ -109,6 +109,11 @@ def test_storage_from_url(file_url: str, root: str): assert target.root == root +def test_get_unsupported_batcher(): + with pytest.raises(ValueError, match="Unsupported batcher"): + Batcher.get_batcher("unsupported") + + @pytest.mark.parametrize( "file_url,expected", [