Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apigateway Account APIs #8119

Merged
merged 6 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions moto/apigateway/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
from .utils import create_id, to_path

STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_name}"
PATCH_OPERATIONS = ["add", "remove", "replace", "move", "copy", "test"]


class Deployment(CloudFormationModel):
Expand Down Expand Up @@ -1020,6 +1021,7 @@ class RestAPI(CloudFormationModel):
PROP_POLICY = "policy"
PROP_DISABLE_EXECUTE_API_ENDPOINT = "disableExecuteApiEndpoint"
PROP_MINIMUM_COMPRESSION_SIZE = "minimumCompressionSize"
PROP_ROOT_RESOURCE_ID = "rootResourceId"

# operations
OPERATION_ADD = "add"
Expand Down Expand Up @@ -1064,6 +1066,7 @@ def __init__(
self.models: Dict[str, Model] = {}
self.request_validators: Dict[str, RequestValidator] = {}
self.default = self.add_child("/") # Add default child
self.root_resource_id = self.default.id

def __repr__(self) -> str:
return str(self.id)
Expand All @@ -1082,6 +1085,7 @@ def to_dict(self) -> Dict[str, Any]:
self.PROP_POLICY: self.policy,
self.PROP_DISABLE_EXECUTE_API_ENDPOINT: self.disableExecuteApiEndpoint,
self.PROP_MINIMUM_COMPRESSION_SIZE: self.minimum_compression_size,
self.PROP_ROOT_RESOURCE_ID: self.root_resource_id,
}

def apply_patch_operations(self, patch_operations: List[Dict[str, Any]]) -> None:
Expand Down Expand Up @@ -1529,6 +1533,54 @@ def to_json(self) -> Dict[str, Any]:
return dct


class Account(BaseModel):
def __init__(self) -> None:
self.cloudwatch_role_arn: Optional[str] = None
self.throttle_settings: Dict[str, Any] = {
"burstLimit": 5000,
"rateLimit": 10000.0,
}
self.features: Optional[List[str]] = None
self.api_key_version: str = "1"

def apply_patch_operations(
self, patch_operations: List[Dict[str, Any]]
) -> "Account":
for op in patch_operations:
if "/cloudwatchRoleArn" in op["path"]:
self.cloudwatch_role_arn = op["value"]
elif "/features" in op["path"]:
if op["op"] == "add":
if self.features is None:
self.features = [op["value"]]
else:
self.features.append(op["value"])
elif op["op"] == "remove":
if op["value"] == "UsagePlans":
raise BadRequestException(
"Usage Plans cannot be disabled once enabled"
)
if self.features is not None:
self.features.remove(op["value"])
else:
raise NotImplementedError(
f'Patch operation "{op["op"]}" for "/features" not implemented'
)
else:
raise NotImplementedError(
f'Patch operation "{op["op"]}" for "{op["path"]}" not implemented'
)
return self

def to_json(self) -> Dict[str, Any]:
return {
"cloudwatchRoleArn": self.cloudwatch_role_arn,
"throttleSettings": self.throttle_settings,
"features": self.features,
"apiKeyVersion": self.api_key_version,
}


class APIGatewayBackend(BaseBackend):
"""
API Gateway mock.
Expand Down Expand Up @@ -1558,6 +1610,7 @@ class APIGatewayBackend(BaseBackend):

def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.account: Account = Account()
self.apis: Dict[str, RestAPI] = {}
self.keys: Dict[str, ApiKey] = {}
self.usage_plans: Dict[str, UsagePlan] = {}
Expand Down Expand Up @@ -2485,5 +2538,12 @@ def delete_gateway_response(self, rest_api_id: str, response_type: str) -> None:
api = self.get_rest_api(rest_api_id)
api.delete_gateway_response(response_type)

def update_account(self, patch_operations: List[Dict[str, Any]]) -> Account:
account = self.account.apply_patch_operations(patch_operations)
return account

def get_account(self) -> Account:
return self.account


apigateway_backends = BackendDict(APIGatewayBackend, "apigateway")
9 changes: 9 additions & 0 deletions moto/apigateway/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,3 +851,12 @@ def delete_gateway_response(self) -> TYPE_RESPONSE:
rest_api_id=rest_api_id, response_type=response_type
)
return 202, {}, json.dumps(dict())

def update_account(self) -> str:
patch_operations = self._get_param("patchOperations")
account = self.backend.update_account(patch_operations)
return json.dumps(account.to_json())

def get_account(self) -> str:
account = self.backend.get_account()
return json.dumps(account.to_json())
1 change: 1 addition & 0 deletions moto/apigateway/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"{0}/restapis/(?P<api_id>[^/]+)/gatewayresponses/(?P<response_type>[^/]+)/?$": APIGatewayResponse.dispatch,
"{0}/vpclinks$": APIGatewayResponse.dispatch,
"{0}/vpclinks/(?P<vpclink_id>[^/]+)": APIGatewayResponse.dispatch,
"{0}/account$": APIGatewayResponse.dispatch,
}

# Also manages the APIGatewayV2
Expand Down
81 changes: 81 additions & 0 deletions tests/test_apigateway/test_apigateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_create_and_get_rest_api():
name="my_api", description="this is my api", disableExecuteApiEndpoint=True
)
api_id = response["id"]
root_resource_id = response["rootResourceId"]

response = client.get_rest_api(restApiId=api_id)

Expand All @@ -34,6 +35,7 @@ def test_create_and_get_rest_api():
"endpointConfiguration": {"types": ["EDGE"]},
"tags": {},
"disableExecuteApiEndpoint": True,
"rootResourceId": root_resource_id,
}


Expand All @@ -42,6 +44,7 @@ def test_update_rest_api():
client = boto3.client("apigateway", region_name="us-west-2")
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
root_resource_id = response["rootResourceId"]
patchOperations = [
{"op": "replace", "path": "/name", "value": "new-name"},
{"op": "replace", "path": "/description", "value": "new-description"},
Expand Down Expand Up @@ -71,6 +74,7 @@ def test_update_rest_api():
"endpointConfiguration": {"types": ["EDGE"]},
"tags": {},
"disableExecuteApiEndpoint": True,
"rootResourceId": root_resource_id,
}
# should fail with wrong apikeysoruce
patchOperations = [
Expand Down Expand Up @@ -2486,3 +2490,80 @@ def test_update_path_mapping_with_unknown_stage():
assert ex.value.response["Error"]["Message"] == "Invalid stage identifier specified"
assert ex.value.response["Error"]["Code"] == "BadRequestException"
assert ex.value.response["ResponseMetadata"]["HTTPStatusCode"] == 400


@mock_aws
def test_update_account():
client = boto3.client("apigateway", region_name="eu-west-1")

patch_operations = [
{
"op": "replace",
"path": "/cloudwatchRoleArn",
"value": "arn:aws:iam:123456789012:role/moto-test-apigw-role-1",
},
{"op": "add", "path": "/features", "value": "UsagePlans"},
{"op": "add", "path": "/features", "value": "TestFeature"},
]

account = client.update_account(patchOperations=patch_operations)

assert (
account["cloudwatchRoleArn"]
== "arn:aws:iam:123456789012:role/moto-test-apigw-role-1"
)
assert account["features"] == ["UsagePlans", "TestFeature"]

patch_operations = [
{
"op": "replace",
"path": "/cloudwatchRoleArn",
"value": "arn:aws:iam:123456789012:role/moto-test-apigw-role-2",
},
{"op": "remove", "path": "/features", "value": "TestFeature"},
]

account = client.update_account(patchOperations=patch_operations)

assert (
account["cloudwatchRoleArn"]
== "arn:aws:iam:123456789012:role/moto-test-apigw-role-2"
)
assert account["throttleSettings"]["burstLimit"] == 5000
assert account["throttleSettings"]["rateLimit"] == 10000.0
assert account["apiKeyVersion"] == "1"
assert account["features"] == ["UsagePlans"]


@mock_aws
def test_update_account_error():
client = boto3.client("apigateway", region_name="eu-west-1")
patch_operations = [
{
"op": "remove",
"path": "/features",
"value": "UsagePlans",
},
]

with pytest.raises(ClientError) as ex:
client.update_account(patchOperations=patch_operations)

assert (
ex.value.response["Error"]["Message"]
== "Usage Plans cannot be disabled once enabled"
)
assert ex.value.response["Error"]["Code"] == "BadRequestException"
assert ex.value.response["ResponseMetadata"]["HTTPStatusCode"] == 400


@mock_aws
def test_get_account():
client = boto3.client("apigateway", region_name="eu-west-1")
account = client.get_account()

assert account["throttleSettings"]["burstLimit"] == 5000
assert account["throttleSettings"]["rateLimit"] == 10000.0
assert account["apiKeyVersion"] == "1"
assert "features" not in account
assert "cloudwatchRoleArn" not in account
Loading