From 25527c4a18785865d1552ceb1e9be0f00539eb80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez=20Mondrag=C3=B3n?= <16805946+edgarrmondragon@users.noreply.github.com> Date: Wed, 13 Nov 2024 21:33:55 -0600 Subject: [PATCH] feat: Support other content-types in REST streams (#2762) * WIP: Support other content-types for HTTP streams * Add test * Fix types * Specifically test legacy pagination --- .../graphql-client.py | 5 +- .../rest-client.py | 5 +- docs/_templates/stream_class.rst | 10 + docs/classes/singer_sdk.GraphQLStream.rst | 4 +- docs/classes/singer_sdk.RESTStream.rst | 4 +- docs/classes/singer_sdk.SQLStream.rst | 4 +- docs/classes/singer_sdk.Stream.rst | 4 +- docs/reference.rst | 2 +- pyproject.toml | 2 +- singer_sdk/authenticators.py | 20 +- singer_sdk/streams/rest.py | 185 +++++++++++++----- tests/core/test_streams.py | 87 ++++++-- 12 files changed, 240 insertions(+), 92 deletions(-) create mode 100644 docs/_templates/stream_class.rst diff --git a/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/graphql-client.py b/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/graphql-client.py index d1289efa4..289199b11 100644 --- a/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/graphql-client.py +++ b/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/graphql-client.py @@ -45,14 +45,11 @@ def http_headers(self) -> dict: Returns: A dictionary of HTTP headers. """ - headers = {} - if "user_agent" in self.config: - headers["User-Agent"] = self.config.get("user_agent") {%- if cookiecutter.auth_method not in ("OAuth2", "JWT") %} # If not using an authenticator, you may also provide inline auth headers: # headers["Private-Token"] = self.config.get("auth_token") {%- endif %} - return headers + return {} def parse_response(self, response: requests.Response) -> t.Iterable[dict]: """Parse the response and return an iterator of result records. diff --git a/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/rest-client.py b/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/rest-client.py index c1aa634a5..35a53303c 100644 --- a/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/rest-client.py +++ b/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/rest-client.py @@ -132,14 +132,11 @@ def http_headers(self) -> dict: Returns: A dictionary of HTTP headers. """ - headers = {} - if "user_agent" in self.config: - headers["User-Agent"] = self.config.get("user_agent") {%- if cookiecutter.auth_method not in ("OAuth2", "JWT") %} # If not using an authenticator, you may also provide inline auth headers: # headers["Private-Token"] = self.config.get("auth_token") # noqa: ERA001 {%- endif %} - return headers + return {} def get_new_paginator(self) -> BaseAPIPaginator: """Create a new pagination helper instance. diff --git a/docs/_templates/stream_class.rst b/docs/_templates/stream_class.rst new file mode 100644 index 000000000..547bd6454 --- /dev/null +++ b/docs/_templates/stream_class.rst @@ -0,0 +1,10 @@ +{{ fullname }} +{{ "=" * fullname|length }} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ name }} + :members: + :show-inheritance: + :inherited-members: Stream + :special-members: __init__ diff --git a/docs/classes/singer_sdk.GraphQLStream.rst b/docs/classes/singer_sdk.GraphQLStream.rst index 41953196f..2e801b64b 100644 --- a/docs/classes/singer_sdk.GraphQLStream.rst +++ b/docs/classes/singer_sdk.GraphQLStream.rst @@ -5,4 +5,6 @@ .. autoclass:: GraphQLStream :members: - :special-members: __init__, __call__ \ No newline at end of file + :show-inheritance: + :inherited-members: Stream + :special-members: __init__ \ No newline at end of file diff --git a/docs/classes/singer_sdk.RESTStream.rst b/docs/classes/singer_sdk.RESTStream.rst index 9710c6303..6ed4d7b47 100644 --- a/docs/classes/singer_sdk.RESTStream.rst +++ b/docs/classes/singer_sdk.RESTStream.rst @@ -5,4 +5,6 @@ .. autoclass:: RESTStream :members: - :special-members: __init__, __call__ \ No newline at end of file + :show-inheritance: + :inherited-members: Stream + :special-members: __init__ \ No newline at end of file diff --git a/docs/classes/singer_sdk.SQLStream.rst b/docs/classes/singer_sdk.SQLStream.rst index f72894088..bc0546f31 100644 --- a/docs/classes/singer_sdk.SQLStream.rst +++ b/docs/classes/singer_sdk.SQLStream.rst @@ -5,4 +5,6 @@ .. autoclass:: SQLStream :members: - :special-members: __init__, __call__ \ No newline at end of file + :show-inheritance: + :inherited-members: Stream + :special-members: __init__ \ No newline at end of file diff --git a/docs/classes/singer_sdk.Stream.rst b/docs/classes/singer_sdk.Stream.rst index db028a912..946f040d6 100644 --- a/docs/classes/singer_sdk.Stream.rst +++ b/docs/classes/singer_sdk.Stream.rst @@ -5,4 +5,6 @@ .. autoclass:: Stream :members: - :special-members: __init__, __call__ \ No newline at end of file + :show-inheritance: + :inherited-members: Stream + :special-members: __init__ \ No newline at end of file diff --git a/docs/reference.rst b/docs/reference.rst index 71e0d6ddb..6522c8a36 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -21,7 +21,7 @@ Stream Classes .. autosummary:: :toctree: classes - :template: class.rst + :template: stream_class.rst Stream RESTStream diff --git a/pyproject.toml b/pyproject.toml index 55771a960..589fb6336 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -170,7 +170,7 @@ filterwarnings = [ # https://github.com/meltano/sdk/issues/1354 "ignore:The function singer_sdk.testing.get_standard_tap_tests is deprecated:DeprecationWarning", # https://github.com/meltano/sdk/issues/2744 - "ignore::singer_sdk.helpers._compat.SingerSDKDeprecationWarning", + "default::singer_sdk.helpers._compat.SingerSDKDeprecationWarning", # TODO: Address this SQLite warning in Python 3.13+ "ignore::ResourceWarning", ] diff --git a/singer_sdk/authenticators.py b/singer_sdk/authenticators.py index 3a7fd8833..916669ef7 100644 --- a/singer_sdk/authenticators.py +++ b/singer_sdk/authenticators.py @@ -22,7 +22,7 @@ if t.TYPE_CHECKING: import logging - from singer_sdk.streams.rest import RESTStream + from singer_sdk.streams.rest import _HTTPStream def _add_parameters(initial_url: str, extra_parameters: dict) -> str: @@ -91,7 +91,7 @@ class APIAuthenticatorBase: auth_params: URL query parameters for authentication. """ - def __init__(self, stream: RESTStream) -> None: + def __init__(self, stream: _HTTPStream) -> None: """Init authenticator. Args: @@ -156,7 +156,7 @@ class SimpleAuthenticator(APIAuthenticatorBase): def __init__( self, - stream: RESTStream, + stream: _HTTPStream, auth_headers: dict | None = None, ) -> None: """Create a new authenticator. @@ -186,7 +186,7 @@ class APIKeyAuthenticator(APIAuthenticatorBase): def __init__( self, - stream: RESTStream, + stream: _HTTPStream, key: str, value: str, location: str = "header", @@ -221,7 +221,7 @@ def __init__( @classmethod def create_for_stream( cls: type[APIKeyAuthenticator], - stream: RESTStream, + stream: _HTTPStream, key: str, value: str, location: str, @@ -249,7 +249,7 @@ class BearerTokenAuthenticator(APIAuthenticatorBase): 'Bearer '. The token will be merged with HTTP headers on the stream. """ - def __init__(self, stream: RESTStream, token: str) -> None: + def __init__(self, stream: _HTTPStream, token: str) -> None: """Create a new authenticator. Args: @@ -266,7 +266,7 @@ def __init__(self, stream: RESTStream, token: str) -> None: @classmethod def create_for_stream( cls: type[BearerTokenAuthenticator], - stream: RESTStream, + stream: _HTTPStream, token: str, ) -> BearerTokenAuthenticator: """Create an Authenticator object specific to the Stream class. @@ -299,7 +299,7 @@ class BasicAuthenticator(APIAuthenticatorBase): def __init__( self, - stream: RESTStream, + stream: _HTTPStream, username: str, password: str, ) -> None: @@ -323,7 +323,7 @@ def __init__( @classmethod def create_for_stream( cls: type[BasicAuthenticator], - stream: RESTStream, + stream: _HTTPStream, username: str, password: str, ) -> BasicAuthenticator: @@ -346,7 +346,7 @@ class OAuthAuthenticator(APIAuthenticatorBase): def __init__( self, - stream: RESTStream, + stream: _HTTPStream, auth_endpoint: str | None = None, oauth_scopes: str | None = None, default_expiration: int | None = None, diff --git a/singer_sdk/streams/rest.py b/singer_sdk/streams/rest.py index 7f1f6fb89..dcfdf7104 100644 --- a/singer_sdk/streams/rest.py +++ b/singer_sdk/streams/rest.py @@ -5,6 +5,7 @@ import abc import copy import logging +import sys import typing as t from functools import cached_property from http import HTTPStatus @@ -17,6 +18,7 @@ from singer_sdk import metrics from singer_sdk.authenticators import SimpleAuthenticator from singer_sdk.exceptions import FatalAPIError, RetriableAPIError +from singer_sdk.helpers._compat import SingerSDKDeprecationWarning from singer_sdk.helpers.jsonpath import extract_jsonpath from singer_sdk.pagination import ( BaseAPIPaginator, @@ -26,7 +28,13 @@ ) from singer_sdk.streams.core import Stream +if sys.version_info < (3, 13): + from typing_extensions import deprecated +else: + from warnings import deprecated # pragma: no cover + if t.TYPE_CHECKING: + from collections.abc import Iterable, Mapping from datetime import datetime from backoff.types import Details @@ -41,28 +49,21 @@ _TToken = t.TypeVar("_TToken") -class RESTStream(Stream, t.Generic[_TToken], metaclass=abc.ABCMeta): # noqa: PLR0904 - """Abstract base class for REST API streams.""" +class _HTTPStream(Stream, t.Generic[_TToken], metaclass=abc.ABCMeta): # noqa: PLR0904 + """Abstract base class for HTTP streams.""" _page_size: int = DEFAULT_PAGE_SIZE _requests_session: requests.Session | None - #: HTTP method to use for requests. Defaults to "GET". - rest_method = "GET" - - #: JSONPath expression to extract records from the API response. - records_jsonpath: str = "$[*]" - #: Response code reference for rate limit retries extra_retry_statuses: t.Sequence[int] = [HTTPStatus.TOO_MANY_REQUESTS] - #: Optional JSONPath expression to extract a pagination token from the API response. - #: Example: `"$.next_page"` - next_page_token_jsonpath: str | None = None - #: Optional flag to disable HTTP redirects. Defaults to False. allow_redirects: bool = True + #: Set this to True if the API expects a JSON payload in the request body. + payload_as_json: bool = False + # Private constants. May not be supported in future releases: _LOG_REQUEST_METRICS: bool = True # Disabled by default for safety: @@ -90,7 +91,7 @@ def __init__( schema: dict[str, t.Any] | Schema | None = None, path: str | None = None, ) -> None: - """Initialize the REST stream. + """Initialize the HTTP stream. Args: tap: Singer Tap this stream belongs to. @@ -103,8 +104,6 @@ def __init__( self.path = path self._http_headers: dict = {"User-Agent": self.user_agent} self._requests_session = requests.Session() - self._compiled_jsonpath = None - self._next_page_token_compiled_jsonpath = None @staticmethod def _url_encode(val: str | datetime | bool | int | list[str]) -> str: # noqa: FBT001 @@ -140,6 +139,24 @@ def get_url(self, context: Context | None) -> str: # HTTP Request functions + @property + @deprecated( + "Use `http_method` instead.", + category=SingerSDKDeprecationWarning, + ) + def rest_method(self) -> str: + """HTTP method to use for requests. Defaults to "GET". + + .. deprecated:: 0.43.0 + Override :meth:`~singer_sdk.RESTStream.http_method` instead. + """ + return "GET" + + @property + def http_method(self) -> str: + """HTTP method to use for requests. Defaults to "GET".""" + return self.rest_method + @property def requests_session(self) -> requests.Session: """Get requests session. @@ -369,19 +386,25 @@ def prepare_request( Build a request with the stream's URL, path, query parameters, HTTP headers and authenticator. """ - http_method = self.rest_method + http_method = self.http_method url: str = self.get_url(context) params: dict | str = self.get_url_params(context, next_page_token) request_data = self.prepare_request_payload(context, next_page_token) headers = self.http_headers - return self.build_prepared_request( - method=http_method, - url=url, - params=params, - headers=headers, - json=request_data, - ) + prepare_kwargs: dict[str, t.Any] = { + "method": http_method, + "url": url, + "params": params, + "headers": headers, + } + + if self.payload_as_json: + prepare_kwargs["json"] = request_data + else: + prepare_kwargs["data"] = request_data + + return self.build_prepared_request(**prepare_kwargs) def request_records(self, context: Context | None) -> t.Iterable[dict]: """Request records from REST endpoint(s), returning response records. @@ -522,8 +545,16 @@ def prepare_request_payload( self, context: Context | None, next_page_token: _TToken | None, - ) -> dict | None: - """Prepare the data payload for the REST API request. + ) -> ( + Iterable[bytes] + | str + | bytes + | list[tuple[t.Any, t.Any]] + | tuple[tuple[t.Any, t.Any]] + | Mapping[str, t.Any] + | None + ): + """Prepare the data payload for the HTTP request. By default, no payload will be sent (return None). @@ -537,27 +568,6 @@ def prepare_request_payload( next page of data. """ - def get_new_paginator(self) -> BaseAPIPaginator: - """Get a fresh paginator for this API endpoint. - - Returns: - A paginator instance. - """ - if hasattr(self, "get_next_page_token"): - warn( - "`RESTStream.get_next_page_token` is deprecated and will not be used " - "in a future version of the Meltano Singer SDK. " - "Override `RESTStream.get_new_paginator` instead.", - DeprecationWarning, - stacklevel=2, - ) - return LegacyStreamPaginator(self) - - if self.next_page_token_jsonpath: - return JSONPathPaginator(self.next_page_token_jsonpath) - - return SimpleHeaderPaginator("X-Next-Page") - @property def http_headers(self) -> dict: """Return headers dict to be used for HTTP requests. @@ -601,6 +611,9 @@ def get_records(self, context: Context | None) -> t.Iterable[dict[str, t.Any]]: continue yield transformed_record + # Abstract methods: + + @abc.abstractmethod def parse_response(self, response: requests.Response) -> t.Iterable[dict]: """Parse the response and return an iterator of result records. @@ -610,9 +623,16 @@ def parse_response(self, response: requests.Response) -> t.Iterable[dict]: Yields: One item for every item found in the response. """ - yield from extract_jsonpath(self.records_jsonpath, input=response.json()) + ... - # Abstract methods: + @abc.abstractmethod + def get_new_paginator(self) -> BaseAPIPaginator: + """Get a fresh paginator for this endpoint. + + Returns: + A paginator instance. + """ + ... @property def authenticator(self) -> Auth: @@ -712,3 +732,72 @@ def backoff_runtime( # noqa: PLR6301 exception = yield # type: ignore[misc] while True: exception = yield value(exception) + + +class RESTStream(_HTTPStream, t.Generic[_TToken], metaclass=abc.ABCMeta): + """Abstract base class for REST API streams.""" + + #: JSONPath expression to extract records from the API response. + records_jsonpath: str = "$[*]" + + #: Optional JSONPath expression to extract a pagination token from the API response. + #: Example: `"$.next_page"` + next_page_token_jsonpath: str | None = None + + payload_as_json: bool = True + """Set this to False if the API expects something other than JSON in the request + body. + + .. versionadded:: 0.43.0 + """ + + def __init__( + self, + tap: Tap, + name: str | None = None, + schema: dict[str, t.Any] | Schema | None = None, + path: str | None = None, + ) -> None: + """Initialize the REST stream. + + Args: + tap: Singer Tap this stream belongs to. + schema: JSON schema for records in this stream. + name: Name of this stream. + path: URL path for this entity stream. + """ + super().__init__(tap, name, schema, path) + self._compiled_jsonpath = None + self._next_page_token_compiled_jsonpath = None + + def parse_response(self, response: requests.Response) -> t.Iterable[dict]: + """Parse the response and return an iterator of result records. + + Args: + response: A raw :class:`requests.Response` + + Yields: + One item for every item found in the response. + """ + yield from extract_jsonpath(self.records_jsonpath, input=response.json()) + + def get_new_paginator(self) -> BaseAPIPaginator: + """Get a fresh paginator for this API endpoint. + + Returns: + A paginator instance. + """ + if hasattr(self, "get_next_page_token"): + warn( + "`RESTStream.get_next_page_token` is deprecated and will not be used " + "in a future version of the Meltano Singer SDK. " + "Override `RESTStream.get_new_paginator` instead.", + DeprecationWarning, + stacklevel=2, + ) + return LegacyStreamPaginator(self) + + if self.next_page_token_jsonpath: + return JSONPathPaginator(self.next_page_token_jsonpath) + + return SimpleHeaderPaginator("X-Next-Page") diff --git a/tests/core/test_streams.py b/tests/core/test_streams.py index 1a25defac..038f87657 100644 --- a/tests/core/test_streams.py +++ b/tests/core/test_streams.py @@ -5,6 +5,7 @@ import datetime import logging import typing as t +import urllib.parse import pytest import requests @@ -15,8 +16,7 @@ ) from singer_sdk.helpers._classproperty import classproperty from singer_sdk.helpers._compat import datetime_fromisoformat as parse -from singer_sdk.helpers.jsonpath import _compile_jsonpath, extract_jsonpath -from singer_sdk.pagination import first +from singer_sdk.helpers.jsonpath import _compile_jsonpath from singer_sdk.streams.core import REPLICATION_FULL_TABLE, REPLICATION_INCREMENTAL from singer_sdk.streams.graphql import GraphQLStream from singer_sdk.streams.rest import RESTStream @@ -24,6 +24,8 @@ from tests.core.conftest import SimpleTestStream if t.TYPE_CHECKING: + import requests_mock + from singer_sdk import Stream, Tap from tests.core.conftest import SimpleTestTap @@ -42,22 +44,16 @@ class RestTestStream(RESTStream): ).to_dict() replication_key = "updatedAt" + +class RestTestStreamLegacyPagination(RestTestStream): + """Test RESTful stream class with pagination.""" + def get_next_page_token( self, - response: requests.Response, - previous_token: str | None, # noqa: ARG002 - ) -> str | None: - if not self.next_page_token_jsonpath: - return response.headers.get("X-Next-Page", None) - - all_matches = extract_jsonpath( - self.next_page_token_jsonpath, - response.json(), - ) - try: - return first(all_matches) - except StopIteration: - return None + response: requests.Response, # noqa: ARG002 + previous_token: int | None, + ) -> int: + return previous_token + 1 if previous_token is not None else 1 class GraphqlTestStream(GraphQLStream): @@ -316,6 +312,21 @@ def test_jsonpath_rest_stream(tap: Tap, path: str, content: str, result: list[di assert list(records) == result +def test_legacy_pagination(tap: Tap): + """Validate legacy pagination is handled correctly.""" + stream = RestTestStreamLegacyPagination(tap) + + with pytest.deprecated_call(): + stream.get_new_paginator() + + page: int | None = None + page = stream.get_next_page_token(None, page) + assert page == 1 + + page = stream.get_next_page_token(None, page) + assert page == 2 + + def test_jsonpath_graphql_stream_default(tap: Tap): """Validate graphql JSONPath, defaults to the stream name.""" content = """{ @@ -437,11 +448,8 @@ def test_next_page_token_jsonpath( RestTestStream.next_page_token_jsonpath = path stream = RestTestStream(tap) - with pytest.warns(DeprecationWarning): - paginator = stream.get_new_paginator() - + paginator = stream.get_new_paginator() next_page = paginator.get_next(fake_response) - assert next_page == result @@ -484,6 +492,45 @@ def calculate_test_cost( assert f"Total Sync costs for stream {stream.name}" in record.message +def test_non_json_payload(tap: Tap, requests_mock: requests_mock.Mocker): + """Test non-JSON payload is handled correctly.""" + + def callback(request: requests.PreparedRequest, context: requests_mock.Context): # noqa: ARG001 + assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" + assert request.body == "my_key=my_value" + + data = urllib.parse.parse_qs(request.body) + + return { + "data": [ + {"id": 1, "value": f"{data['my_key'][0]}_1"}, + {"id": 2, "value": f"{data['my_key'][0]}_2"}, + ] + } + + class NonJsonStream(RestTestStream): + payload_as_json = False + http_method = "POST" + path = "/non-json" + records_jsonpath = "$.data[*]" + + def prepare_request_payload(self, context, next_page_token): # noqa: ARG002 + return {"my_key": "my_value"} + + stream = NonJsonStream(tap) + + requests_mock.post( + "https://example.com/non-json", + json=callback, + ) + + records = list(stream.request_records(None)) + assert records == [ + {"id": 1, "value": "my_value_1"}, + {"id": 2, "value": "my_value_2"}, + ] + + @pytest.mark.parametrize( "input_catalog,selection", [