diff --git a/tests/test_ratelimits.py b/tests/test_ratelimits.py new file mode 100644 index 0000000..264615d --- /dev/null +++ b/tests/test_ratelimits.py @@ -0,0 +1,241 @@ +from datetime import datetime, timedelta +from httpx import Response +import pytest +import asyncio + +from tests.common import get_response_json +from xbox.webapi.api.provider.ratelimitedprovider import RateLimitedProvider + +from xbox.webapi.common.exceptions import RateLimitExceededException, XboxException +from xbox.webapi.common.ratelimits import CombinedRateLimit +from xbox.webapi.common.ratelimits.models import TimePeriod + + +def helper_test_combinedratelimit( + crl: CombinedRateLimit, burstLimit: int, sustainLimit: int +): + burst = crl.get_limits_by_period(TimePeriod.BURST) + sustain = crl.get_limits_by_period(TimePeriod.SUSTAIN) + + # These functions should return a list with one element + assert type(burst) == list + assert type(sustain) == list + + assert len(burst) == 1 + assert len(sustain) == 1 + + # Check that their limits are what we expect + assert burst[0].get_limit() == burstLimit + assert sustain[0].get_limit() == sustainLimit + + +def test_ratelimitedprovider_rate_limits_same_rw_values(xbl_client): + class child_class(RateLimitedProvider): + RATE_LIMITS = {"burst": 1, "sustain": 2} + + instance = child_class(xbl_client) + + helper_test_combinedratelimit(instance.rate_limit_read, 1, 2) + helper_test_combinedratelimit(instance.rate_limit_write, 1, 2) + + +def test_ratelimitedprovider_rate_limits_diff_rw_values(xbl_client): + class child_class(RateLimitedProvider): + RATE_LIMITS = { + "burst": {"read": 1, "write": 2}, + "sustain": {"read": 3, "write": 4}, + } + + instance = child_class(xbl_client) + + helper_test_combinedratelimit(instance.rate_limit_read, 1, 3) + helper_test_combinedratelimit(instance.rate_limit_write, 2, 4) + + +def test_ratelimitedprovider_rate_limits_mixed(xbl_client): + class burst_diff(RateLimitedProvider): + RATE_LIMITS = {"burst": {"read": 1, "write": 2}, "sustain": 3} + + burst_diff_inst = burst_diff(xbl_client) + + # Sustain values are the same (third paramater) + helper_test_combinedratelimit(burst_diff_inst.rate_limit_read, 1, 3) + helper_test_combinedratelimit(burst_diff_inst.rate_limit_write, 2, 3) + + class sustain_diff(RateLimitedProvider): + RATE_LIMITS = {"burst": 4, "sustain": {"read": 5, "write": 6}} + + sustain_diff_inst = sustain_diff(xbl_client) + + # Burst values are the same (second paramater) + helper_test_combinedratelimit(sustain_diff_inst.rate_limit_read, 4, 5) + helper_test_combinedratelimit(sustain_diff_inst.rate_limit_write, 4, 6) + + +def test_ratelimitedprovider_rate_limits_missing_values_correct_type(xbl_client): + class child_class(RateLimitedProvider): + RATE_LIMITS = {"incorrect": "values"} + + with pytest.raises(XboxException) as exception: + child_class(xbl_client) + + ex: XboxException = exception.value + assert "RATE_LIMITS object missing required keys" in ex.args[0] + + +def test_ratelimitedprovider_rate_limits_not_set(xbl_client): + class child_class(RateLimitedProvider): + pass + + with pytest.raises(XboxException) as exception: + child_class(xbl_client) + + ex: XboxException = exception.value + assert "RateLimitedProvider as parent class but RATE_LIMITS not set!" in ex.args[0] + + +def test_ratelimitedprovider_rate_limits_incorrect_key_type(xbl_client): + class child_class(RateLimitedProvider): + RATE_LIMITS = {"burst": True, "sustain": False} + + with pytest.raises(XboxException) as exception: + child_class(xbl_client) + + ex: XboxException = exception.value + assert "RATE_LIMITS value types not recognised." in ex.args[0] + + +@pytest.mark.asyncio +async def test_ratelimits_exceeded_burst_only(respx_mock, xbl_client): + async def make_request(): + route = respx_mock.get("https://social.xboxlive.com").mock( + return_value=Response(200, json=get_response_json("people_summary_own")) + ) + ret = await xbl_client.people.get_friends_summary_own() + + assert route.called + + # Record the start time to ensure that the timeouts are the correct length + start_time = datetime.now() + + # Make as many requests as possible without exceeding + max_request_num = xbl_client.people.RATE_LIMITS["burst"] + for i in range(max_request_num): + await make_request() + + # Make another request, ensure that it raises the exception. + with pytest.raises(RateLimitExceededException) as exception: + await make_request() + + # Get the error instance from pytest + ex: RateLimitExceededException = exception.value + + # Assert that the counter matches the max request num (should not have incremented above max value) + assert ex.rate_limit.get_counter() == max_request_num + + # Get the timeout we were issued + try_again_in = ex.rate_limit.get_reset_after() + + # Assert that the timeout is the correct length + delta: timedelta = try_again_in - start_time + assert delta.seconds == TimePeriod.BURST.value # 15 seconds + + +async def helper_reach_and_wait_for_burst( + make_request, start_time, burst_limit: int, expected_counter: int +): + # Make as many requests as possible without exceeding the BURST limit. + for i in range(burst_limit): + await make_request() + + # Make another request, ensure that it raises the exception. + with pytest.raises(RateLimitExceededException) as exception: + await make_request() + + # Get the error instance from pytest + ex: RateLimitExceededException = exception.value + + # Assert that the counter matches the what we expect (burst, 2x burstm etc) + assert ex.rate_limit.get_counter() == expected_counter + + # Get the reset_after value + # (if we call it after waiting for it to expire the function will return None) + burst_resets_after = ex.rate_limit.get_reset_after() + + # Wait for the burst limit timeout to elapse. + await asyncio.sleep(TimePeriod.BURST.value) # 15 seconds + + # Assert that the reset_after value has passed. + assert burst_resets_after < datetime.now() + + +@pytest.mark.asyncio +async def test_ratelimits_exceeded_sustain_only(respx_mock, xbl_client): + async def make_request(): + route = respx_mock.get("https://social.xboxlive.com").mock( + return_value=Response(200, json=get_response_json("people_summary_own")) + ) + ret = await xbl_client.people.get_friends_summary_own() + + assert route.called + + # Record the start time to ensure that the timeouts are the correct length + start_time = datetime.now() + + # Get the max requests for this route. + max_request_num = xbl_client.people.RATE_LIMITS["sustain"] # 30 + burst_max_request_num = xbl_client.people.RATE_LIMITS["burst"] # 10 + + # In this case, the BURST limit is three times that of SUSTAIN, so we need to exceed the burst limit three times. + + # Exceed the burst limit and wait for it to reset (10 requests) + await helper_reach_and_wait_for_burst( + make_request, start_time, burst_limit=burst_max_request_num, expected_counter=10 + ) + + # Repeat: Exceed the burst limit and wait for it to reset (10 requests) + # Counter (the sustain one will be returned) + # For (CombinedRateLimit).get_counter(), the highest counter is returned. (sustain in this case) + await helper_reach_and_wait_for_burst( + make_request, start_time, burst_limit=burst_max_request_num, expected_counter=20 + ) + + # Now, make the rest of the requests (10 left, 20/30 done!) + for i in range(10): + await make_request() + + # Wait for the burst limit to 'reset'. + await asyncio.sleep(TimePeriod.BURST.value) # 15 seconds + + # Now, we have made 30 requests. + # The counters should be as follows: + # - BURST: 0* (will reset on next check) + # - SUSTAIN: 30 + # The next request we make should exceed the SUSTAIN rate limit. + + # Make another request, ensure that it raises the exception. + with pytest.raises(RateLimitExceededException) as exception: + await make_request() + + # Get the error instance from pytest + ex: RateLimitExceededException = exception.value + + # Get the SingleRateLimit objects from the exception + rl: CombinedRateLimit = ex.rate_limit + burst = rl.get_limits_by_period(TimePeriod.BURST)[0] + sustain = rl.get_limits_by_period(TimePeriod.SUSTAIN)[0] + + # Assert that we have only exceeded the sustain limit. + assert not burst.is_exceeded() + assert sustain.is_exceeded() + + # Assert that the counter matches the max request num (should not have incremented above max value) + assert ex.rate_limit.get_counter() == max_request_num + + # Get the timeout we were issued + try_again_in = ex.rate_limit.get_reset_after() + + # Assert that the timeout is the correct length + # The SUSTAIN counter has not been reset during this test, so the try again in should be 300 seconds since we started this test. + delta: timedelta = try_again_in - start_time + assert delta.seconds == TimePeriod.SUSTAIN.value # 300 seconds (5 minutes) diff --git a/xbox/webapi/api/client.py b/xbox/webapi/api/client.py index 6cf17ec..4276685 100644 --- a/xbox/webapi/api/client.py +++ b/xbox/webapi/api/client.py @@ -28,6 +28,8 @@ from xbox.webapi.api.provider.usersearch import UserSearchProvider from xbox.webapi.api.provider.userstats import UserStatsProvider from xbox.webapi.authentication.manager import AuthenticationManager +from xbox.webapi.common.exceptions import RateLimitExceededException +from xbox.webapi.common.ratelimits import RateLimit log = logging.getLogger("xbox.api") @@ -55,6 +57,9 @@ async def request( extra_params = kwargs.pop("extra_params", None) extra_data = kwargs.pop("extra_data", None) + # Rate limit object + rate_limits: RateLimit = kwargs.pop("rate_limits", None) + if include_auth: # Ensure tokens valid await self._auth_mgr.refresh_tokens() @@ -78,10 +83,20 @@ async def request( data = data or {} data.update(extra_data) - return await self._auth_mgr.session.request( + if rate_limits: + # Check if rate limits have been exceeded for this endpoint + if rate_limits.is_exceeded(): + raise RateLimitExceededException("Rate limit exceeded", rate_limits) + + response = await self._auth_mgr.session.request( method, url, **kwargs, headers=headers, params=params, data=data ) + if rate_limits: + rate_limits.increment() + + return response + async def get(self, url: str, **kwargs: Any) -> Response: return await self.request("GET", url, **kwargs) diff --git a/xbox/webapi/api/provider/achievements/__init__.py b/xbox/webapi/api/provider/achievements/__init__.py index 106bcba..27111ec 100644 --- a/xbox/webapi/api/provider/achievements/__init__.py +++ b/xbox/webapi/api/provider/achievements/__init__.py @@ -9,14 +9,16 @@ AchievementResponse, RecentProgressResponse, ) -from xbox.webapi.api.provider.baseprovider import BaseProvider +from xbox.webapi.api.provider.ratelimitedprovider import RateLimitedProvider -class AchievementsProvider(BaseProvider): +class AchievementsProvider(RateLimitedProvider): ACHIEVEMENTS_URL = "https://achievements.xboxlive.com" HEADERS_GAME_360_PROGRESS = {"x-xbl-contract-version": "1"} HEADERS_GAME_PROGRESS = {"x-xbl-contract-version": "2"} + RATE_LIMITS = {"burst": 100, "sustain": 300} + async def get_achievements_detail_item( self, xuid, service_config_id, achievement_id, **kwargs ) -> AchievementResponse: @@ -33,7 +35,10 @@ async def get_achievements_detail_item( """ url = f"{self.ACHIEVEMENTS_URL}/users/xuid({xuid})/achievements/{service_config_id}/{achievement_id}" resp = await self.client.session.get( - url, headers=self.HEADERS_GAME_PROGRESS, **kwargs + url, + headers=self.HEADERS_GAME_PROGRESS, + rate_limits=self.rate_limit_read, + **kwargs, ) resp.raise_for_status() return AchievementResponse(**resp.json()) @@ -54,7 +59,11 @@ async def get_achievements_xbox360_all( url = f"{self.ACHIEVEMENTS_URL}/users/xuid({xuid})/titleachievements?" params = {"titleId": title_id} resp = await self.client.session.get( - url, params=params, headers=self.HEADERS_GAME_360_PROGRESS, **kwargs + url, + params=params, + headers=self.HEADERS_GAME_360_PROGRESS, + rate_limits=self.rate_limit_read, + **kwargs, ) resp.raise_for_status() return Achievement360Response(**resp.json()) @@ -75,7 +84,11 @@ async def get_achievements_xbox360_earned( url = f"{self.ACHIEVEMENTS_URL}/users/xuid({xuid})/achievements?" params = {"titleId": title_id} resp = await self.client.session.get( - url, params=params, headers=self.HEADERS_GAME_360_PROGRESS, **kwargs + url, + params=params, + headers=self.HEADERS_GAME_360_PROGRESS, + rate_limits=self.rate_limit_read, + **kwargs, ) resp.raise_for_status() return Achievement360Response(**resp.json()) @@ -94,7 +107,10 @@ async def get_achievements_xbox360_recent_progress_and_info( """ url = f"{self.ACHIEVEMENTS_URL}/users/xuid({xuid})/history/titles" resp = await self.client.session.get( - url, headers=self.HEADERS_GAME_360_PROGRESS, **kwargs + url, + headers=self.HEADERS_GAME_360_PROGRESS, + rate_limits=self.rate_limit_read, + **kwargs, ) resp.raise_for_status() return Achievement360ProgressResponse(**resp.json()) @@ -115,7 +131,11 @@ async def get_achievements_xboxone_gameprogress( url = f"{self.ACHIEVEMENTS_URL}/users/xuid({xuid})/achievements?" params = {"titleId": title_id} resp = await self.client.session.get( - url, params=params, headers=self.HEADERS_GAME_PROGRESS, **kwargs + url, + params=params, + headers=self.HEADERS_GAME_PROGRESS, + rate_limits=self.rate_limit_read, + **kwargs, ) resp.raise_for_status() return AchievementResponse(**resp.json()) @@ -134,7 +154,10 @@ async def get_achievements_xboxone_recent_progress_and_info( """ url = f"{self.ACHIEVEMENTS_URL}/users/xuid({xuid})/history/titles" resp = await self.client.session.get( - url, headers=self.HEADERS_GAME_PROGRESS, **kwargs + url, + headers=self.HEADERS_GAME_PROGRESS, + rate_limits=self.rate_limit_read, + **kwargs, ) resp.raise_for_status() return RecentProgressResponse(**resp.json()) diff --git a/xbox/webapi/api/provider/people/__init__.py b/xbox/webapi/api/provider/people/__init__.py index 2c1ca7c..93d3882 100644 --- a/xbox/webapi/api/provider/people/__init__.py +++ b/xbox/webapi/api/provider/people/__init__.py @@ -3,7 +3,7 @@ """ from typing import List -from xbox.webapi.api.provider.baseprovider import BaseProvider +from xbox.webapi.api.provider.ratelimitedprovider import RateLimitedProvider from xbox.webapi.api.provider.people.models import ( PeopleDecoration, PeopleResponse, @@ -11,7 +11,7 @@ ) -class PeopleProvider(BaseProvider): +class PeopleProvider(RateLimitedProvider): SOCIAL_URL = "https://social.xboxlive.com" HEADERS_SOCIAL = {"x-xbl-contract-version": "2"} PEOPLE_URL = "https://peoplehub.xboxlive.com" @@ -21,6 +21,9 @@ class PeopleProvider(BaseProvider): } SEPERATOR = "," + # NOTE: Rate Limits are noted for social.xboxlive.com ONLY + RATE_LIMITS = {"burst": 10, "sustain": 30} + def __init__(self, client): """ Initialize Baseclass, set 'Accept-Language' header from client instance @@ -129,7 +132,9 @@ async def get_friends_summary_own(self, **kwargs) -> PeopleSummaryResponse: :class:`PeopleSummaryResponse`: People Summary Response """ url = self.SOCIAL_URL + "/users/me/summary" - resp = await self.client.session.get(url, headers=self.HEADERS_SOCIAL, **kwargs) + resp = await self.client.session.get( + url, headers=self.HEADERS_SOCIAL, rate_limits=self.rate_limit_read, **kwargs + ) resp.raise_for_status() return PeopleSummaryResponse(**resp.json()) @@ -146,7 +151,9 @@ async def get_friends_summary_by_xuid( :class:`PeopleSummaryResponse`: People Summary Response """ url = self.SOCIAL_URL + f"/users/xuid({xuid})/summary" - resp = await self.client.session.get(url, headers=self.HEADERS_SOCIAL, **kwargs) + resp = await self.client.session.get( + url, headers=self.HEADERS_SOCIAL, rate_limits=self.rate_limit_read, **kwargs + ) resp.raise_for_status() return PeopleSummaryResponse(**resp.json()) @@ -163,6 +170,8 @@ async def get_friends_summary_by_gamertag( :class:`PeopleSummaryResponse`: People Summary Response """ url = self.SOCIAL_URL + f"/users/gt({gamertag})/summary" - resp = await self.client.session.get(url, headers=self.HEADERS_SOCIAL, **kwargs) + resp = await self.client.session.get( + url, headers=self.HEADERS_SOCIAL, rate_limits=self.rate_limit_read, **kwargs + ) resp.raise_for_status() return PeopleSummaryResponse(**resp.json()) diff --git a/xbox/webapi/api/provider/profile/__init__.py b/xbox/webapi/api/provider/profile/__init__.py index cad9dd0..4cce30b 100644 --- a/xbox/webapi/api/provider/profile/__init__.py +++ b/xbox/webapi/api/provider/profile/__init__.py @@ -5,15 +5,17 @@ """ from typing import List -from xbox.webapi.api.provider.baseprovider import BaseProvider +from xbox.webapi.api.provider.ratelimitedprovider import RateLimitedProvider from xbox.webapi.api.provider.profile.models import ProfileResponse, ProfileSettings -class ProfileProvider(BaseProvider): +class ProfileProvider(RateLimitedProvider): PROFILE_URL = "https://profile.xboxlive.com" HEADERS_PROFILE = {"x-xbl-contract-version": "3"} SEPARATOR = "," + RATE_LIMITS = {"burst": 10, "sustain": 30} + async def get_profiles(self, xuid_list: List[str], **kwargs) -> ProfileResponse: """ Get profile info for list of xuids @@ -45,7 +47,11 @@ async def get_profiles(self, xuid_list: List[str], **kwargs) -> ProfileResponse: } url = self.PROFILE_URL + "/users/batch/profile/settings" resp = await self.client.session.post( - url, json=post_data, headers=self.HEADERS_PROFILE, **kwargs + url, + json=post_data, + headers=self.HEADERS_PROFILE, + rate_limits=self.rate_limit_read, + **kwargs, ) resp.raise_for_status() return ProfileResponse(**resp.json()) @@ -83,7 +89,11 @@ async def get_profile_by_xuid(self, target_xuid: str, **kwargs) -> ProfileRespon ) } resp = await self.client.session.get( - url, params=params, headers=self.HEADERS_PROFILE, **kwargs + url, + params=params, + headers=self.HEADERS_PROFILE, + rate_limits=self.rate_limit_read, + **kwargs, ) resp.raise_for_status() return ProfileResponse(**resp.json()) @@ -121,7 +131,11 @@ async def get_profile_by_gamertag(self, gamertag: str, **kwargs) -> ProfileRespo ) } resp = await self.client.session.get( - url, params=params, headers=self.HEADERS_PROFILE, **kwargs + url, + params=params, + headers=self.HEADERS_PROFILE, + rate_limits=self.rate_limit_read, + **kwargs, ) resp.raise_for_status() return ProfileResponse(**resp.json()) diff --git a/xbox/webapi/api/provider/ratelimitedprovider.py b/xbox/webapi/api/provider/ratelimitedprovider.py new file mode 100644 index 0000000..2661d68 --- /dev/null +++ b/xbox/webapi/api/provider/ratelimitedprovider.py @@ -0,0 +1,76 @@ +""" +RateLimitedProvider + +Subclassed by providers with rate limit support +""" + +from typing import Union, Dict +from xbox.webapi.api.provider.baseprovider import BaseProvider +from xbox.webapi.common.exceptions import XboxException +from xbox.webapi.common.ratelimits.models import LimitType, ParsedRateLimit, TimePeriod +from xbox.webapi.common.ratelimits import CombinedRateLimit + + +class RateLimitedProvider(BaseProvider): + # dict -> Dict (typing.dict) https://stackoverflow.com/a/63460173 + RATE_LIMITS: Dict[str, Union[int, Dict[str, int]]] + + def __init__(self, client): + """ + Initialize Baseclass + + Args: + client (:class:`XboxLiveClient`): Instance of XboxLiveClient + """ + super().__init__(client) + + # Check that RATE_LIMITS set defined in the child class + if hasattr(self, "RATE_LIMITS"): + # Note: we cannot check (type(self.RATE_LIMITS) == dict) as the type hints have already defined it as such + if "burst" and "sustain" in self.RATE_LIMITS: + # We have the required keys, attempt to parse. + # (type-checking for the values is performed in __parse_rate_limit_key) + self.__handle_rate_limit_setup() + else: + raise XboxException( + "RATE_LIMITS object missing required keys 'burst', 'sustain'" + ) + else: + raise XboxException( + "RateLimitedProvider as parent class but RATE_LIMITS not set!" + ) + + def __handle_rate_limit_setup(self): + # Retrieve burst and sustain from the dict + burst_key = self.RATE_LIMITS["burst"] + sustain_key = self.RATE_LIMITS["sustain"] + + # Parse the rate limit dict values + burst_rate_limits = self.__parse_rate_limit_key(burst_key, TimePeriod.BURST) + sustain_rate_limits = self.__parse_rate_limit_key( + sustain_key, TimePeriod.SUSTAIN + ) + + # Instanciate CombinedRateLimits for read and write respectively + self.rate_limit_read = CombinedRateLimit( + burst_rate_limits, sustain_rate_limits, type=LimitType.READ + ) + self.rate_limit_write = CombinedRateLimit( + burst_rate_limits, sustain_rate_limits, type=LimitType.WRITE + ) + + def __parse_rate_limit_key( + self, key: Union[int, Dict[str, int]], period: TimePeriod + ) -> ParsedRateLimit: + key_type = type(key) + if key_type == int: + return ParsedRateLimit(read=key, write=key, period=period) + elif key_type == dict: + # TODO: schema here? + # Since the key-value pairs match we can just pass the dict to the model + return ParsedRateLimit(**key, period=period) + # return ParsedRateLimit(read=key["read"], write=key["write"]) + else: + raise XboxException( + "RATE_LIMITS value types not recognised. Must be one of 'int, 'dict'." + ) diff --git a/xbox/webapi/api/provider/userstats/__init__.py b/xbox/webapi/api/provider/userstats/__init__.py index ed1f52c..4ad4151 100644 --- a/xbox/webapi/api/provider/userstats/__init__.py +++ b/xbox/webapi/api/provider/userstats/__init__.py @@ -3,19 +3,24 @@ """ from typing import List, Optional -from xbox.webapi.api.provider.baseprovider import BaseProvider +from xbox.webapi.api.provider.ratelimitedprovider import RateLimitedProvider from xbox.webapi.api.provider.userstats.models import ( GeneralStatsField, UserStatsResponse, ) -class UserStatsProvider(BaseProvider): +class UserStatsProvider(RateLimitedProvider): USERSTATS_URL = "https://userstats.xboxlive.com" HEADERS_USERSTATS = {"x-xbl-contract-version": "2"} HEADERS_USERSTATS_WITH_METADATA = {"x-xbl-contract-version": "3"} SEPERATOR = "," + # NOTE: Stats Read (userstats.xboxlive.com) and Stats Write (statswrite.xboxlive.com) + # Are mentioned as their own objects but their rate limits are the same and do not collide + # (Stats Read -> read rate limit, Stats Write -> write rate limit) + RATE_LIMITS = {"burst": 100, "sustain": 300} + async def get_stats( self, xuid: str, @@ -40,7 +45,10 @@ async def get_stats( url = f"{self.USERSTATS_URL}/users/xuid({xuid})/scids/{service_config_id}/stats/{stats}" resp = await self.client.session.get( - url, headers=self.HEADERS_USERSTATS, **kwargs + url, + headers=self.HEADERS_USERSTATS, + rate_limits=self.rate_limit_read, + **kwargs, ) resp.raise_for_status() return UserStatsResponse(**resp.json()) @@ -70,7 +78,11 @@ async def get_stats_with_metadata( url = f"{self.USERSTATS_URL}/users/xuid({xuid})/scids/{service_config_id}/stats/{stats}" params = {"include": "valuemetadata"} resp = await self.client.session.get( - url, params=params, headers=self.HEADERS_USERSTATS_WITH_METADATA, **kwargs + url, + params=params, + headers=self.HEADERS_USERSTATS_WITH_METADATA, + rate_limits=self.rate_limit_read, + **kwargs, ) resp.raise_for_status() return UserStatsResponse(**resp.json()) @@ -104,7 +116,11 @@ async def get_stats_batch( "xuids": xuids, } resp = await self.client.session.post( - url, json=post_data, headers=self.HEADERS_USERSTATS, **kwargs + url, + json=post_data, + headers=self.HEADERS_USERSTATS, + rate_limits=self.rate_limit_read, + **kwargs, ) resp.raise_for_status() return UserStatsResponse(**resp.json()) @@ -139,7 +155,11 @@ async def get_stats_batch_by_scid( "xuids": xuids, } resp = await self.client.session.post( - url, json=post_data, headers=self.HEADERS_USERSTATS, **kwargs + url, + json=post_data, + headers=self.HEADERS_USERSTATS, + rate_limits=self.rate_limit_read, + **kwargs, ) resp.raise_for_status() return UserStatsResponse(**resp.json()) diff --git a/xbox/webapi/common/exceptions.py b/xbox/webapi/common/exceptions.py index 44da117..0e278e5 100644 --- a/xbox/webapi/common/exceptions.py +++ b/xbox/webapi/common/exceptions.py @@ -3,6 +3,9 @@ """ +from xbox.webapi.common.ratelimits import RateLimit + + class XboxException(Exception): """Base exception for all Xbox exceptions to subclass""" @@ -46,3 +49,10 @@ class NotFoundException(XboxException): """Any exception raised due to a resource being missing will subclass this""" pass + + +class RateLimitExceededException(XboxException): + def __init__(self, message, rate_limit: RateLimit): + self.message = message + self.rate_limit = rate_limit + self.try_again_in = rate_limit.get_reset_after() diff --git a/xbox/webapi/common/ratelimits/__init__.py b/xbox/webapi/common/ratelimits/__init__.py new file mode 100644 index 0000000..9e4e9ec --- /dev/null +++ b/xbox/webapi/common/ratelimits/__init__.py @@ -0,0 +1,270 @@ +from datetime import datetime, timedelta +from typing import Union, List + +from xbox.webapi.common.ratelimits.models import ( + ParsedRateLimit, + TimePeriod, + LimitType, + IncrementResult, +) + +from abc import ABCMeta, abstractmethod + + +class RateLimit(metaclass=ABCMeta): + """ + Abstract class for varying implementations/types of rate limits. + All methods in this class are overriden in every implementation. + However, different implementations may have additional functions not present in this parent abstract class. + + A class implementing RateLimit functions without any external threads. + When the first increment request is recieved (after a counter reset or a new instaniciation) + a reset_after variable is set detailing when the rate limit(s) reset. + + Upon each function invokation, the reset_after variable is checked and the timer is automatically reset if the reset_after time has passed. + """ + + @abstractmethod + def get_counter(self) -> int: + # Docstrings are defined in child classes due to their differing implementations. + pass + + @abstractmethod + def get_reset_after(self) -> Union[datetime, None]: + # Docstrings are defined in child classes due to their differing implementations. + pass + + @abstractmethod + def is_exceeded(self) -> bool: + # Docstrings are defined in child classes due to their differing implementations. + pass + + @abstractmethod + def increment(self) -> IncrementResult: + """ + The increment function adds one to the rate limit request counter. + + If the reset_after time has passed, the counter will first be reset before counting the request. + + When the counter hits 1, the reset_after time is calculated and stored. + + This function returns an `IncrementResult` object, containing the keys `counter: int` and `exceeded: bool`. + This can be used by the caller to determine the current state of the rate-limit object without making an additional function call. + """ + + pass + + +class SingleRateLimit(RateLimit): + """ + A rate limit implementation for a single rate limit, such as a burst or sustain limit. + This class is mainly used by the CombinedRateLimit class. + """ + + def __init__(self, time_period: TimePeriod, type: LimitType, limit: int): + self.__time_period = time_period + self.__type = type + self.__limit = limit + + self.__exceeded: bool = False + self.__counter = 0 + # No requests so far, so reset_after is None. + self.__reset_after: Union[datetime, None] = None + + def get_counter(self) -> int: + """ + This function returns the current request counter variable. + """ + + return self.__counter + + def get_time_period(self) -> "TimePeriod": + return self.__time_period + + def get_limit(self) -> int: + return self.__limit + + def get_limit_type(self) -> "LimitType": + return self.__type + + def get_reset_after(self) -> Union[datetime, None]: + """ + This getter returns the current state of the reset_after counter. + + If the counter in use, it's corresponding `datetime` object is returned. + + If the counter is not in use, `None` is returned. + """ + + return self.__reset_after + + def is_exceeded(self) -> bool: + """ + This functions returns `True` if the rate limit has been exceeded. + """ + + self.__reset_counter_if_required() + return self.__exceeded + + def increment(self) -> IncrementResult: + # Call a function to check if the counter should be reset + self.__reset_counter_if_required() + + # Increment the counter + self.__counter += 1 + + # If the counter is 1, (first request after a reset) set the reset_after value. + if self.__counter == 1: + self.__set_reset_after() + + # Check to see if we have now exceeded the request limit + self.__check_if_exceeded() + + # Return an instance of IncrementResult + return IncrementResult(counter=self.__counter, exceeded=self.__exceeded) + + # Should be called after every inc of the counter + def __check_if_exceeded(self): + if not self.__exceeded: + if self.__counter >= self.__limit: + self.__exceeded = True + # reset-after is now dependent on the time since the first request of this cycle. + # self.__set_reset_after() + + def __reset_counter_if_required(self): + # Check to make sure reset_after is not None + # - This is the case if this function is called before the counter + # is incremented after a reset / new instantiation + if self.__reset_after is not None: + if self.__reset_after < datetime.now(): + self.__exceeded = False + self.__counter = 0 + self.__reset_after = None + + def __set_reset_after(self): + self.__reset_after = datetime.now() + timedelta( + seconds=self.get_time_period().value + ) + + +class CombinedRateLimit(RateLimit): + """ + A rate limit implementation for multiple rate limits, such as burst and sustain. + + """ + + def __init__(self, *parsed_limits: ParsedRateLimit, type: LimitType): + # *parsed_limits is a tuple + + # Create a SingleRateLimit instance for each limit + self.__limits: list[SingleRateLimit] = [] + + for limit in parsed_limits: + # Use the type param (enum LimitType) to determine which limit to select + limit_num = limit.read if type == LimitType.READ else limit.write + + # Create a new instance of SingleRateLimit and append it to the limits array. + srl = SingleRateLimit(limit.period, type, limit_num) + self.__limits.append(srl) + + def get_counter(self) -> int: + """ + This function returns the request counter with the **highest** value. + + A `CombinedRateLimit` consists of multiple different rate limits, which may have differing counter values. + """ + + # Map self.__limits to (limit).get_counter() + counter_map = map(lambda limit: limit.get_counter(), self.__limits) + counters = list(counter_map) + + # Sort the counters list by value + # reverse=True to get highest first + counters.sort(reverse=True) + + # Return the highest value + return counters[0] + + # We don't want a datetime response for a limit that has not been exceeded. + # Otherwise eg. 10 burst requests -> 300s timeout (should be 30 (burst exceeded), 300s (not exceeded) + def get_reset_after(self) -> Union[datetime, None]: + """ + This getter returns either a `datetime` object or `None` object depending on the status of the rate limit. + + If the counter is in use, the rate limit with the **latest** reset_after is returned. + + This is so that this function can reliably be used as a indicator of when all rate limits have been reset. + + If the counter is not in use, `None` is returned. + """ + + # Get a list of limits that *have been exceeded* + dates_exceeded_only = filter(lambda limit: limit.is_exceeded(), self.__limits) + + # Map self.__limits to (limit).get_reset_after() + dates_map = map(lambda limit: limit.get_reset_after(), dates_exceeded_only) + + # Convert the map object to a list + dates = list(dates_map) + + # Construct a new list with only elements of instance datetime + # (Effectively filtering out any None elements) + dates_valid = [elem for elem in dates if type(elem) == datetime] + + # If dates_valid has any elements, return the one with the *later* timestamp. + # This means that if two or more limits have been exceeded, we wait for both to have reset (by returning the later timestamp) + if len(dates_valid) != 0: + # By default dates are sorted with the earliest date first. + # We will set reverse=True so that the first element is the later date. + dates_valid.sort(reverse=True) + + # Return the datetime object. + return dates_valid[0] + + # dates_valid has no elements, return None + return None + + # list -> List (typing.List) https://stackoverflow.com/a/63460173 + def get_limits(self) -> List[SingleRateLimit]: + return self.__limits + + # list -> List (typing.List) https://stackoverflow.com/a/63460173 + def get_limits_by_period(self, period: TimePeriod) -> List[SingleRateLimit]: + # Filter the list for the given LimitType + matches = filter(lambda limit: limit.get_time_period() == period, self.__limits) + # Convert the filter object to a list and return it + return list(matches) + + def is_exceeded(self) -> bool: + """ + This function returns `True` if **any** rate limit has been exceeded. + + It behaves like an OR logic gate. + """ + + # Map self.__limits to (limit).is_exceeded() + is_exceeded_map = map(lambda limit: limit.is_exceeded(), self.__limits) + is_exceeded_list = list(is_exceeded_map) + + # Return True if any variable in list is True + return True in is_exceeded_list + + def increment(self) -> IncrementResult: + # Increment each limit + results: list[IncrementResult] = [] + for limit in self.__limits: + result = limit.increment() + results.append(result) + + # SPEC: Let's pick the *higher* counter + # By default, sorted() returns in ascending order, so let's set reverse=True + # This means that the result with the highest counter will be the first element. + results_sorted = sorted(results, key=lambda i: i.counter, reverse=True) + + # Create an instance of IncrementResult and return it. + return IncrementResult( + counter=results_sorted[ + 0 + ].counter, # Use the highest counter (sorted in descending order) + exceeded=self.is_exceeded(), # Call self.is_exceeded (True if any limit has been exceeded, like an OR gate.) + ) diff --git a/xbox/webapi/common/ratelimits/models.py b/xbox/webapi/common/ratelimits/models.py new file mode 100644 index 0000000..cde7612 --- /dev/null +++ b/xbox/webapi/common/ratelimits/models.py @@ -0,0 +1,23 @@ +from enum import Enum +from pydantic import BaseModel + + +class TimePeriod(Enum): + BURST = 15 # 15 seconds + SUSTAIN = 300 # 5 minutes (300s) + + +class LimitType(Enum): + WRITE = 0 + READ = 1 + + +class IncrementResult(BaseModel): + counter: int + exceeded: bool + + +class ParsedRateLimit(BaseModel): + read: int + write: int + period: TimePeriod