diff --git a/rest/authentication/authentication.py b/rest/authentication/authentication.py index e4f40d90..c9f75c97 100644 --- a/rest/authentication/authentication.py +++ b/rest/authentication/authentication.py @@ -15,6 +15,7 @@ Internal, CredentialsInvalid, BillingPlanInvalid, + TokenInvalid, ) from authentication.oidc_providers import oidc_providers from authentication.user import OIDCUser, SHUser @@ -60,7 +61,7 @@ def authenticate_user_oidc(self, access_token, oidc_provider_id): user_id = userinfo["sub"] try: - user = OIDCUser(user_id, oidc_userinfo=userinfo) + user = OIDCUser(user_id, oidc_userinfo=userinfo, access_token=access_token) except BillingPlanInvalid: return None diff --git a/rest/authentication/user.py b/rest/authentication/user.py index 410e1570..95faebc4 100644 --- a/rest/authentication/user.py +++ b/rest/authentication/user.py @@ -25,12 +25,15 @@ def get_user_info(self): user_info["default_plan"] = self.default_plan.name return user_info + def get_leftover_credits(self): + pass + def report_usage(self, pu_spent, job_id=None): pass class OIDCUser(User): - def __init__(self, user_id=None, oidc_userinfo={}): + def __init__(self, user_id=None, oidc_userinfo={}, access_token=None): super().__init__(user_id) self.entitlements = [ self.convert_entitlement(entitlement) for entitlement in oidc_userinfo.get("eduperson_entitlement", []) @@ -38,6 +41,7 @@ def __init__(self, user_id=None, oidc_userinfo={}): self.oidc_userinfo = oidc_userinfo self.default_plan = OpenEOPBillingPlan.get_billing_plan(self.entitlements) self.session = central_user_sentinelhub_session + self.access_token = access_token def __str__(self): return f"{self.__class__.__name__}: {self.user_id}" @@ -60,6 +64,9 @@ def get_user_info(self): user_info["info"] = {"oidc_userinfo": self.oidc_userinfo} return user_info + def get_leftover_credits(self): + return usageReporting.get_leftover_credits_for_user(self.access_token) + def report_usage(self, pu_spent, job_id=None): usageReporting.report_usage(self.user_id, pu_spent, job_id) diff --git a/rest/openeoerrors.py b/rest/openeoerrors.py index 5117f689..e04b74af 100644 --- a/rest/openeoerrors.py +++ b/rest/openeoerrors.py @@ -153,3 +153,9 @@ def __init__(self, width, height) -> None: error_code = "ImageDimensionInvalid" http_code = 400 + + +class InsufficientCredits(SHOpenEOError): + error_code = "InsufficientCredits" + http_code = 402 + message = "You do not have sufficient credits to perform this request. Please visit https://portal.terrascope.be/pages/pricing to find more information on how to buy additional credits." diff --git a/rest/processing/processing.py b/rest/processing/processing.py index 36a3b66b..3a49e136 100644 --- a/rest/processing/processing.py +++ b/rest/processing/processing.py @@ -12,7 +12,7 @@ from dynamodb.utils import get_user_defined_processes_graphs from dynamodb import JobsPersistence from const import openEOBatchJobStatus -from openeoerrors import JobNotFound, Timeout +from openeoerrors import InsufficientCredits, JobNotFound, Timeout def check_process_graph_conversion_validity(process_graph): @@ -63,6 +63,9 @@ def start_new_batch_job(sentinel_hub, process, job_id): raise JobNotFound() estimated_sentinelhub_pu, _, _ = create_or_get_estimate_values_from_db(job, new_batch_request_id) + + check_leftover_credits(estimated_sentinelhub_pu) + JobsPersistence.update_key( job["id"], "sum_costs", str(round(float(job.get("sum_costs", 0)) + estimated_sentinelhub_pu, 3)) ) @@ -100,6 +103,9 @@ def start_batch_job(batch_request_id, process, deployment_endpoint, job_id): raise JobNotFound() estimated_sentinelhub_pu, _, _ = create_or_get_estimate_values_from_db(job, job["batch_request_id"]) + + check_leftover_credits(estimated_sentinelhub_pu) + JobsPersistence.update_key( job["id"], "sum_costs", str(round(float(job.get("sum_costs", 0)) + estimated_sentinelhub_pu, 3)) ) @@ -222,3 +228,10 @@ def create_or_get_estimate_values_from_db(job, batch_request_id): estimated_file_size = float(job.get("estimated_file_size", 0)) return estimated_sentinelhub_pu, estimated_platform_credits, estimated_file_size + + +def check_leftover_credits(estimated_pu): + leftover_credits = g.user.get_leftover_credits() + estimated_pu_as_credits = estimated_pu * 0.15 # platform credits === SH PU's * 0.15 + if leftover_credits is not None and leftover_credits < estimated_pu_as_credits: + raise InsufficientCredits() diff --git a/rest/usage_reporting/report_usage.py b/rest/usage_reporting/report_usage.py index f576ec84..e85cbf7f 100644 --- a/rest/usage_reporting/report_usage.py +++ b/rest/usage_reporting/report_usage.py @@ -74,6 +74,26 @@ def reporting_check_health(self): return r.status_code == 200 and content["status"] == "ok" + def get_leftover_credits_for_user(self, user_access_token): + user_url = f"{self.base_url}user" + + headers = {"content-type": "application/json", "Authorization": f"Bearer {user_access_token}"} + + if not self.reporting_check_health(): + log(ERROR, "Services for usage reporting are not healthy") + raise Internal("Services for usage reporting are not healthy") + + r = requests.get(user_url, headers=headers) + + if r.status_code == 200: + content = r.json() + platform_credits = content.get("credits") + + return platform_credits + else: + log(ERROR, f"Error fetching leftover credits: {r.status_code} {r.text}") + raise Internal(f"Problems during fetching leftover credits: {r.status_code} {r.text}") + def report_usage(self, user_id, pu_spent, job_id=None, max_tries=5): reporting_token = self.get_token()