Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support other content-types in REST streams #2762

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -45,14 +45,11 @@ def http_headers(self) -> dict:
Returns:
A dictionary of HTTP headers.
"""
headers = {}
if "user_agent" in self.config:
headers["User-Agent"] = self.config.get("user_agent")
{%- if cookiecutter.auth_method not in ("OAuth2", "JWT") %}
# If not using an authenticator, you may also provide inline auth headers:
# headers["Private-Token"] = self.config.get("auth_token")
{%- endif %}
return headers
return {}

def parse_response(self, response: requests.Response) -> t.Iterable[dict]:
"""Parse the response and return an iterator of result records.
Original file line number Diff line number Diff line change
@@ -132,14 +132,11 @@ def http_headers(self) -> dict:
Returns:
A dictionary of HTTP headers.
"""
headers = {}
if "user_agent" in self.config:
headers["User-Agent"] = self.config.get("user_agent")
{%- if cookiecutter.auth_method not in ("OAuth2", "JWT") %}
# If not using an authenticator, you may also provide inline auth headers:
# headers["Private-Token"] = self.config.get("auth_token") # noqa: ERA001
{%- endif %}
return headers
return {}

def get_new_paginator(self) -> BaseAPIPaginator:
"""Create a new pagination helper instance.
10 changes: 10 additions & 0 deletions docs/_templates/stream_class.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{{ fullname }}
{{ "=" * fullname|length }}

.. currentmodule:: {{ module }}

.. autoclass:: {{ name }}
:members:
:show-inheritance:
:inherited-members: Stream
:special-members: __init__
4 changes: 3 additions & 1 deletion docs/classes/singer_sdk.GraphQLStream.rst
Original file line number Diff line number Diff line change
@@ -5,4 +5,6 @@

.. autoclass:: GraphQLStream
:members:
:special-members: __init__, __call__
:show-inheritance:
:inherited-members: Stream
:special-members: __init__
4 changes: 3 additions & 1 deletion docs/classes/singer_sdk.RESTStream.rst
Original file line number Diff line number Diff line change
@@ -5,4 +5,6 @@

.. autoclass:: RESTStream
:members:
:special-members: __init__, __call__
:show-inheritance:
:inherited-members: Stream
:special-members: __init__
4 changes: 3 additions & 1 deletion docs/classes/singer_sdk.SQLStream.rst
Original file line number Diff line number Diff line change
@@ -5,4 +5,6 @@

.. autoclass:: SQLStream
:members:
:special-members: __init__, __call__
:show-inheritance:
:inherited-members: Stream
:special-members: __init__
4 changes: 3 additions & 1 deletion docs/classes/singer_sdk.Stream.rst
Original file line number Diff line number Diff line change
@@ -5,4 +5,6 @@

.. autoclass:: Stream
:members:
:special-members: __init__, __call__
:show-inheritance:
:inherited-members: Stream
:special-members: __init__
2 changes: 1 addition & 1 deletion docs/reference.rst
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@ Stream Classes

.. autosummary::
:toctree: classes
:template: class.rst
:template: stream_class.rst

Stream
RESTStream
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -170,7 +170,7 @@ filterwarnings = [
# https://github.com/meltano/sdk/issues/1354
"ignore:The function singer_sdk.testing.get_standard_tap_tests is deprecated:DeprecationWarning",
# https://github.com/meltano/sdk/issues/2744
"ignore::singer_sdk.helpers._compat.SingerSDKDeprecationWarning",
"default::singer_sdk.helpers._compat.SingerSDKDeprecationWarning",
# TODO: Address this SQLite warning in Python 3.13+
"ignore::ResourceWarning",
]
20 changes: 10 additions & 10 deletions singer_sdk/authenticators.py
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@
if t.TYPE_CHECKING:
import logging

from singer_sdk.streams.rest import RESTStream
from singer_sdk.streams.rest import _HTTPStream


def _add_parameters(initial_url: str, extra_parameters: dict) -> str:
@@ -91,7 +91,7 @@ class APIAuthenticatorBase:
auth_params: URL query parameters for authentication.
"""

def __init__(self, stream: RESTStream) -> None:
def __init__(self, stream: _HTTPStream) -> None:
"""Init authenticator.
Args:
@@ -156,7 +156,7 @@ class SimpleAuthenticator(APIAuthenticatorBase):

def __init__(
self,
stream: RESTStream,
stream: _HTTPStream,
auth_headers: dict | None = None,
) -> None:
"""Create a new authenticator.
@@ -186,7 +186,7 @@ class APIKeyAuthenticator(APIAuthenticatorBase):

def __init__(
self,
stream: RESTStream,
stream: _HTTPStream,
key: str,
value: str,
location: str = "header",
@@ -221,7 +221,7 @@ def __init__(
@classmethod
def create_for_stream(
cls: type[APIKeyAuthenticator],
stream: RESTStream,
stream: _HTTPStream,
key: str,
value: str,
location: str,
@@ -249,7 +249,7 @@ class BearerTokenAuthenticator(APIAuthenticatorBase):
'Bearer '. The token will be merged with HTTP headers on the stream.
"""

def __init__(self, stream: RESTStream, token: str) -> None:
def __init__(self, stream: _HTTPStream, token: str) -> None:
"""Create a new authenticator.
Args:
@@ -266,7 +266,7 @@ def __init__(self, stream: RESTStream, token: str) -> None:
@classmethod
def create_for_stream(
cls: type[BearerTokenAuthenticator],
stream: RESTStream,
stream: _HTTPStream,
token: str,
) -> BearerTokenAuthenticator:
"""Create an Authenticator object specific to the Stream class.
@@ -299,7 +299,7 @@ class BasicAuthenticator(APIAuthenticatorBase):

def __init__(
self,
stream: RESTStream,
stream: _HTTPStream,
username: str,
password: str,
) -> None:
@@ -323,7 +323,7 @@ def __init__(
@classmethod
def create_for_stream(
cls: type[BasicAuthenticator],
stream: RESTStream,
stream: _HTTPStream,
username: str,
password: str,
) -> BasicAuthenticator:
@@ -346,7 +346,7 @@ class OAuthAuthenticator(APIAuthenticatorBase):

def __init__(
self,
stream: RESTStream,
stream: _HTTPStream,
auth_endpoint: str | None = None,
oauth_scopes: str | None = None,
default_expiration: int | None = None,
185 changes: 137 additions & 48 deletions singer_sdk/streams/rest.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
import abc
import copy
import logging
import sys
import typing as t
from functools import cached_property
from http import HTTPStatus
@@ -17,6 +18,7 @@
from singer_sdk import metrics
from singer_sdk.authenticators import SimpleAuthenticator
from singer_sdk.exceptions import FatalAPIError, RetriableAPIError
from singer_sdk.helpers._compat import SingerSDKDeprecationWarning
from singer_sdk.helpers.jsonpath import extract_jsonpath
from singer_sdk.pagination import (
BaseAPIPaginator,
@@ -26,7 +28,13 @@
)
from singer_sdk.streams.core import Stream

if sys.version_info < (3, 13):
from typing_extensions import deprecated
else:
from warnings import deprecated # pragma: no cover

if t.TYPE_CHECKING:
from collections.abc import Iterable, Mapping
from datetime import datetime

from backoff.types import Details
@@ -41,28 +49,21 @@
_TToken = t.TypeVar("_TToken")


class RESTStream(Stream, t.Generic[_TToken], metaclass=abc.ABCMeta): # noqa: PLR0904
"""Abstract base class for REST API streams."""
class _HTTPStream(Stream, t.Generic[_TToken], metaclass=abc.ABCMeta): # noqa: PLR0904
"""Abstract base class for HTTP streams."""

_page_size: int = DEFAULT_PAGE_SIZE
_requests_session: requests.Session | None

#: HTTP method to use for requests. Defaults to "GET".
rest_method = "GET"

#: JSONPath expression to extract records from the API response.
records_jsonpath: str = "$[*]"

#: Response code reference for rate limit retries
extra_retry_statuses: t.Sequence[int] = [HTTPStatus.TOO_MANY_REQUESTS]

#: Optional JSONPath expression to extract a pagination token from the API response.
#: Example: `"$.next_page"`
next_page_token_jsonpath: str | None = None

#: Whether to follow HTTP redirects. Defaults to True; set to False to disable them.
allow_redirects: bool = True

#: Set this to True if the API expects a JSON payload in the request body.
payload_as_json: bool = False

# Private constants. May not be supported in future releases:
_LOG_REQUEST_METRICS: bool = True
# Disabled by default for safety:
@@ -90,7 +91,7 @@ def __init__(
schema: dict[str, t.Any] | Schema | None = None,
path: str | None = None,
) -> None:
"""Initialize the REST stream.
"""Initialize the HTTP stream.
Args:
tap: Singer Tap this stream belongs to.
@@ -103,8 +104,6 @@ def __init__(
self.path = path
self._http_headers: dict = {"User-Agent": self.user_agent}
self._requests_session = requests.Session()
self._compiled_jsonpath = None
self._next_page_token_compiled_jsonpath = None

@staticmethod
def _url_encode(val: str | datetime | bool | int | list[str]) -> str: # noqa: FBT001
@@ -140,6 +139,24 @@ def get_url(self, context: Context | None) -> str:

# HTTP Request functions

@property
@deprecated(
    "Use `http_method` instead.",
    category=SingerSDKDeprecationWarning,
)
def rest_method(self) -> str:
    """HTTP method to use for requests. Defaults to "GET".

    This getter is wrapped by ``@deprecated`` with the
    ``SingerSDKDeprecationWarning`` category, so it is kept only for
    backwards compatibility.

    .. deprecated:: 0.43.0
       Override :attr:`~singer_sdk.RESTStream.http_method` instead.

    Returns:
        The string ``"GET"``.
    """
    return "GET"

@property
def http_method(self) -> str:
    """HTTP method to use for requests. Defaults to "GET".

    For backwards compatibility, a ``rest_method`` override defined on a
    subclass is still honored. When ``rest_method`` has not been overridden,
    the default is returned directly so the deprecated base getter is never
    invoked; otherwise every stream would emit a deprecation warning each
    time a request is prepared, even if it never used ``rest_method``.

    Returns:
        The HTTP method name, e.g. ``"GET"`` or ``"POST"``.
    """
    # A class-level lookup yields the property object itself (or whatever a
    # subclass replaced it with) without running the deprecated getter.
    if type(self).rest_method is _HTTPStream.rest_method:
        return "GET"

    # A subclass supplied its own ``rest_method``; keep honoring it.
    return self.rest_method

@property
def requests_session(self) -> requests.Session:
"""Get requests session.
@@ -369,19 +386,25 @@ def prepare_request(
Build a request with the stream's URL, path, query parameters,
HTTP headers and authenticator.
"""
http_method = self.rest_method
http_method = self.http_method
url: str = self.get_url(context)
params: dict | str = self.get_url_params(context, next_page_token)
request_data = self.prepare_request_payload(context, next_page_token)
headers = self.http_headers

return self.build_prepared_request(
method=http_method,
url=url,
params=params,
headers=headers,
json=request_data,
)
prepare_kwargs: dict[str, t.Any] = {
"method": http_method,
"url": url,
"params": params,
"headers": headers,
}

if self.payload_as_json:
prepare_kwargs["json"] = request_data
else:
prepare_kwargs["data"] = request_data

return self.build_prepared_request(**prepare_kwargs)

def request_records(self, context: Context | None) -> t.Iterable[dict]:
"""Request records from REST endpoint(s), returning response records.
@@ -522,8 +545,16 @@ def prepare_request_payload(
self,
context: Context | None,
next_page_token: _TToken | None,
) -> dict | None:
"""Prepare the data payload for the REST API request.
) -> (
Iterable[bytes]
| str
| bytes
| list[tuple[t.Any, t.Any]]
| tuple[tuple[t.Any, t.Any]]
| Mapping[str, t.Any]
| None
):
"""Prepare the data payload for the HTTP request.
By default, no payload will be sent (return None).
@@ -537,27 +568,6 @@ def prepare_request_payload(
next page of data.
"""

def get_new_paginator(self) -> BaseAPIPaginator:
"""Get a fresh paginator for this API endpoint.
Returns:
A paginator instance.
"""
if hasattr(self, "get_next_page_token"):
warn(
"`RESTStream.get_next_page_token` is deprecated and will not be used "
"in a future version of the Meltano Singer SDK. "
"Override `RESTStream.get_new_paginator` instead.",
DeprecationWarning,
stacklevel=2,
)
return LegacyStreamPaginator(self)

if self.next_page_token_jsonpath:
return JSONPathPaginator(self.next_page_token_jsonpath)

return SimpleHeaderPaginator("X-Next-Page")

@property
def http_headers(self) -> dict:
"""Return headers dict to be used for HTTP requests.
@@ -601,6 +611,9 @@ def get_records(self, context: Context | None) -> t.Iterable[dict[str, t.Any]]:
continue
yield transformed_record

# Abstract methods:

@abc.abstractmethod
def parse_response(self, response: requests.Response) -> t.Iterable[dict]:
"""Parse the response and return an iterator of result records.
@@ -610,9 +623,16 @@ def parse_response(self, response: requests.Response) -> t.Iterable[dict]:
Yields:
One item for every item found in the response.
"""
yield from extract_jsonpath(self.records_jsonpath, input=response.json())
...

# Abstract methods:
@abc.abstractmethod
def get_new_paginator(self) -> BaseAPIPaginator:
    """Get a fresh paginator for this endpoint.

    Declared abstract: this base definition has no implementation and
    concrete stream classes are expected to provide one.

    Returns:
        A paginator instance.
    """
    ...

@property
def authenticator(self) -> Auth:
@@ -712,3 +732,72 @@ def backoff_runtime( # noqa: PLR6301
exception = yield # type: ignore[misc]
while True:
exception = yield value(exception)


class RESTStream(_HTTPStream, t.Generic[_TToken], metaclass=abc.ABCMeta):
    """Abstract base class for REST API streams.

    Adds JSONPath-based record extraction and a default paginator selection
    on top of the generic HTTP stream behavior.
    """

    #: JSONPath expression to extract records from the API response.
    records_jsonpath: str = "$[*]"

    #: Optional JSONPath expression to extract a pagination token from the API response.
    #: Example: `"$.next_page"`
    next_page_token_jsonpath: str | None = None

    payload_as_json: bool = True
    """Set this to False if the API expects something other than JSON in the request
    body.

    .. versionadded:: 0.43.0
    """

    def __init__(
        self,
        tap: Tap,
        name: str | None = None,
        schema: dict[str, t.Any] | Schema | None = None,
        path: str | None = None,
    ) -> None:
        """Set up the REST stream and its JSONPath caches.

        Args:
            tap: Singer Tap this stream belongs to.
            name: Name of this stream.
            schema: JSON schema for records in this stream.
            path: URL path for this entity stream.
        """
        super().__init__(tap, name, schema, path)
        # Both compiled-expression slots start out empty.
        self._compiled_jsonpath = self._next_page_token_compiled_jsonpath = None

    def parse_response(self, response: requests.Response) -> t.Iterable[dict]:
        """Yield the records found in a response body.

        Args:
            response: A raw :class:`requests.Response`

        Yields:
            One item for every match of ``records_jsonpath`` in the JSON body.
        """
        body = response.json()
        yield from extract_jsonpath(self.records_jsonpath, input=body)

    def get_new_paginator(self) -> BaseAPIPaginator:
        """Build a new paginator for this API endpoint.

        Returns:
            A legacy paginator when the stream defines ``get_next_page_token``,
            a JSONPath paginator when ``next_page_token_jsonpath`` is set, and
            an ``X-Next-Page`` header paginator otherwise.
        """
        if not hasattr(self, "get_next_page_token"):
            if self.next_page_token_jsonpath:
                return JSONPathPaginator(self.next_page_token_jsonpath)
            return SimpleHeaderPaginator("X-Next-Page")

        # Legacy hook present: tell the developer, then wrap the stream.
        message = (
            "`RESTStream.get_next_page_token` is deprecated and will not be used "
            "in a future version of the Meltano Singer SDK. "
            "Override `RESTStream.get_new_paginator` instead."
        )
        warn(message, DeprecationWarning, stacklevel=2)
        return LegacyStreamPaginator(self)
87 changes: 67 additions & 20 deletions tests/core/test_streams.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
import datetime
import logging
import typing as t
import urllib.parse

import pytest
import requests
@@ -15,15 +16,16 @@
)
from singer_sdk.helpers._classproperty import classproperty
from singer_sdk.helpers._compat import datetime_fromisoformat as parse
from singer_sdk.helpers.jsonpath import _compile_jsonpath, extract_jsonpath
from singer_sdk.pagination import first
from singer_sdk.helpers.jsonpath import _compile_jsonpath
from singer_sdk.streams.core import REPLICATION_FULL_TABLE, REPLICATION_INCREMENTAL
from singer_sdk.streams.graphql import GraphQLStream
from singer_sdk.streams.rest import RESTStream
from singer_sdk.typing import IntegerType, PropertiesList, Property, StringType
from tests.core.conftest import SimpleTestStream

if t.TYPE_CHECKING:
import requests_mock

from singer_sdk import Stream, Tap
from tests.core.conftest import SimpleTestTap

@@ -42,22 +44,16 @@ class RestTestStream(RESTStream):
).to_dict()
replication_key = "updatedAt"


class RestTestStreamLegacyPagination(RestTestStream):
"""Test RESTful stream class with pagination."""

def get_next_page_token(
self,
response: requests.Response,
previous_token: str | None, # noqa: ARG002
) -> str | None:
if not self.next_page_token_jsonpath:
return response.headers.get("X-Next-Page", None)

all_matches = extract_jsonpath(
self.next_page_token_jsonpath,
response.json(),
)
try:
return first(all_matches)
except StopIteration:
return None
response: requests.Response, # noqa: ARG002
previous_token: int | None,
) -> int:
return previous_token + 1 if previous_token is not None else 1


class GraphqlTestStream(GraphQLStream):
@@ -316,6 +312,21 @@ def test_jsonpath_rest_stream(tap: Tap, path: str, content: str, result: list[di
assert list(records) == result


def test_legacy_pagination(tap: Tap):
    """Check the deprecated ``get_next_page_token`` hook still works."""
    legacy_stream = RestTestStreamLegacyPagination(tap)

    # Requesting a paginator from a legacy stream must warn about deprecation.
    with pytest.deprecated_call():
        legacy_stream.get_new_paginator()

    # Each call advances the token by one, starting from no token at all.
    token: int | None = None
    for expected in (1, 2):
        token = legacy_stream.get_next_page_token(None, token)
        assert token == expected


def test_jsonpath_graphql_stream_default(tap: Tap):
"""Validate graphql JSONPath, defaults to the stream name."""
content = """{
@@ -437,11 +448,8 @@ def test_next_page_token_jsonpath(
RestTestStream.next_page_token_jsonpath = path
stream = RestTestStream(tap)

with pytest.warns(DeprecationWarning):
paginator = stream.get_new_paginator()

paginator = stream.get_new_paginator()
next_page = paginator.get_next(fake_response)

assert next_page == result


@@ -484,6 +492,45 @@ def calculate_test_cost(
assert f"Total Sync costs for stream {stream.name}" in record.message


def test_non_json_payload(tap: Tap, requests_mock: requests_mock.Mocker):
    """Check that a form-encoded (non-JSON) request payload is sent and used."""

    class NonJsonStream(RestTestStream):
        payload_as_json = False
        http_method = "POST"
        path = "/non-json"
        records_jsonpath = "$.data[*]"

        def prepare_request_payload(self, context, next_page_token):  # noqa: ARG002
            return {"my_key": "my_value"}

    def respond(request: requests.PreparedRequest, context: requests_mock.Context):  # noqa: ARG001
        # The body must arrive form-encoded rather than as JSON.
        assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"
        assert request.body == "my_key=my_value"

        form = urllib.parse.parse_qs(request.body)
        prefix = form["my_key"][0]
        # Echo the submitted value back inside each record.
        return {
            "data": [{"id": n, "value": f"{prefix}_{n}"} for n in (1, 2)],
        }

    requests_mock.post(
        "https://example.com/non-json",
        json=respond,
    )

    stream = NonJsonStream(tap)
    assert list(stream.request_records(None)) == [
        {"id": 1, "value": "my_value_1"},
        {"id": 2, "value": "my_value_2"},
    ]


@pytest.mark.parametrize(
"input_catalog,selection",
[