Skip to content

Commit a8fe235

Browse files
BarcoMasileMarco Basile
andauthored
Improved validation constraints for API signatures (#485)
* improved validation of API signatures * corrected Sacct schema instance_id validator max allowed size * adjust __validate_request func to allow not raising exception for missing request.json body to use with PC proxy * split PCProxy schema in two (args and body) and added request size validation for PC proxy APIs * adoption of changed and new schemas * added test for `raise_on_missing_body` * added test for custom validator Co-authored-by: Marco Basile <[email protected]>
1 parent 4c9de80 commit a8fe235

File tree

6 files changed

+102
-36
lines changed

6 files changed

+102
-36
lines changed

api/PclusterApiHandler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from api.security.csrf.csrf import csrf_needed
2828
from api.utils import disable_auth
2929
from api.validation import validated
30-
from api.validation.schemas import PCProxy
30+
from api.validation.schemas import PCProxyArgs, PCProxyBody
3131

3232
USER_POOL_ID = os.getenv("USER_POOL_ID")
3333
AUTH_PATH = os.getenv("AUTH_PATH")
@@ -700,15 +700,15 @@ def _get_params(_request):
700700

701701
@pc.get('/', strict_slashes=False)
702702
@authenticated({'admin'})
703-
@validated(params=PCProxy)
703+
@validated(params=PCProxyArgs)
704704
def pc_proxy_get():
705705
response = sigv4_request(request.method, API_BASE_URL, request.args.get("path"), _get_params(request))
706706
return response.json(), response.status_code
707707

708708
@pc.route('/', methods=['POST','PUT','PATCH','DELETE'], strict_slashes=False)
709709
@authenticated({'admin'})
710710
@csrf_needed
711-
@validated(params=PCProxy)
711+
@validated(params=PCProxyArgs, body=PCProxyBody, raise_on_missing_body=False)
712712
def pc_proxy():
713713
body = None
714714
try:
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import pytest
2+
from marshmallow import ValidationError
3+
4+
from api.validation.validators import size_not_exceeding
5+
6+
7+
def test_size_not_exceeding():
8+
max_size = 300
9+
test_str_not_exceeding = 'a' * (max_size - 2) # save 2 chars for double quotes
10+
11+
size_not_exceeding(test_str_not_exceeding, max_size)
12+
13+
def test_size_not_exceeding_failing():
14+
max_size = 300
15+
test_str_not_exceeding = 'a' * max_size # will produce "aaa...", max_size + 2
16+
17+
with pytest.raises(ValidationError):
18+
size_not_exceeding(test_str_not_exceeding, max_size)

api/tests/validation/test_api_validation.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from unittest.mock import Mock, PropertyMock
2+
13
import pytest
24
from marshmallow import Schema, fields, validate, ValidationError
35

@@ -30,6 +32,33 @@ def test_valid_request_successful(mock_csrf_needed, mock_enable_auth):
3032

3133
assert errors == {}
3234

35+
def test_invalid_request_successful_with_raise_on_missing_body_enabled(mock_csrf_needed, mock_enable_auth):
36+
request = MockRequest()
37+
mock_json_property = PropertyMock(side_effect=Exception)
38+
original_json_property = MockRequest.json
39+
MockRequest.json = mock_json_property
40+
41+
with pytest.raises(ValueError):
42+
__validate_request(request, body_schema=MockRequestJsonSchema(), params_schema=MockRequestArgsSchema(),
43+
cookies_schema=MockRequestCookiesSchema())
44+
45+
mock_json_property.assert_called_once()
46+
MockRequest.json = original_json_property
47+
48+
def test_invalid_request_successful_with_raise_on_missing_body_disabled(mock_csrf_needed, mock_enable_auth):
49+
request = MockRequest()
50+
mock_json_property = PropertyMock(side_effect=Exception)
51+
original_json_property = MockRequest.json
52+
MockRequest.json = mock_json_property
53+
54+
errors = __validate_request(request, body_schema=MockRequestJsonSchema(), params_schema=MockRequestArgsSchema(),
55+
cookies_schema=MockRequestCookiesSchema(), raise_on_missing_body=False)
56+
57+
mock_json_property.assert_called_once()
58+
assert errors == {}
59+
60+
MockRequest.json = original_json_property
61+
3362

3463
def test_valid_request_failure(mock_csrf_needed, mock_enable_auth):
3564
request = MockRequest()

api/validation/__init__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
from api.validation.schemas import EC2Action
77

88

9-
def __validate_request(_request: Request, *, body_schema: Schema = None, params_schema: Schema = None, cookies_schema: Schema = None):
9+
def __validate_request(_request: Request, *, body_schema: Schema = None, params_schema: Schema = None, cookies_schema: Schema = None, raise_on_missing_body = True):
1010
errors = {}
1111
if body_schema:
1212
try:
1313
errors.update(body_schema.validate(_request.json))
1414
except:
15-
raise ValueError('Expected json body')
15+
if raise_on_missing_body:
16+
raise ValueError('Expected json body')
1617

1718
if params_schema:
1819
errors.update(params_schema.validate(_request.args))
@@ -23,11 +24,11 @@ def __validate_request(_request: Request, *, body_schema: Schema = None, params_
2324
return errors
2425

2526

26-
def validated(*, body: Schema = None, params: Schema = None, cookies: Schema = None):
27+
def validated(*, body: Schema = None, params: Schema = None, cookies: Schema = None, raise_on_missing_body = True):
2728
def wrapper(func):
2829
@wraps(func)
2930
def decorated(*pargs, **kwargs):
30-
errors = __validate_request(request, body_schema=body, params_schema=params, cookies_schema=cookies)
31+
errors = __validate_request(request, body_schema=body, params_schema=params, cookies_schema=cookies, raise_on_missing_body=raise_on_missing_body)
3132
if errors:
3233
raise ValidationError(f'Input validation failed for {request.path}', data=errors)
3334
return func(*pargs, **kwargs)

api/validation/schemas.py

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
1-
from marshmallow import Schema, fields, validate, INCLUDE
1+
from marshmallow import Schema, fields, validate, INCLUDE, validates_schema, ValidationError
22

3-
from api.validation.validators import comma_splittable, aws_region_validator, is_alphanumeric_with_hyphen, valid_api_log_levels_predicate
4-
from api.logging import VALID_LOG_LEVELS
3+
from api.validation.validators import comma_splittable, aws_region_validator, is_alphanumeric_with_hyphen, \
4+
valid_api_log_levels_predicate, size_not_exceeding
55

66

77
class EC2ActionSchema(Schema):
88
action = fields.String(required=True, validate=validate.OneOf(['stop_instances', 'start_instances']))
9-
instance_ids = fields.String(required=True, validate=comma_splittable)
9+
instance_ids = fields.String(required=True, validate=validate.And(comma_splittable, validate.Length(max=2048)))
1010
region = fields.String(validate=aws_region_validator)
1111

1212

1313
EC2Action = EC2ActionSchema(unknown=INCLUDE)
1414

1515

1616
class CreateUserSchema(Schema):
17-
Username = fields.Email(required=True)
18-
Phonenumber = fields.String()
17+
Username = fields.Email(required=True, validate=validate.Length(max=320)) # Email RFC allows max 320 chars
18+
Phonenumber = fields.String(validate=validate.Length(max=15)) # ITU-T E.164 allows phone numbers no more than 15 digits
1919

2020

2121
CreateUser = CreateUserSchema(unknown=INCLUDE)
@@ -29,13 +29,13 @@ class DeleteUserSchema(Schema):
2929

3030
class GetClusterConfigSchema(Schema):
3131
region = fields.String(validate=aws_region_validator)
32-
cluster_name = fields.String(required=True, validate=is_alphanumeric_with_hyphen)
32+
cluster_name = fields.String(required=True, validate=validate.And(is_alphanumeric_with_hyphen, validate.Length(max=60))) # PC allow cluster name of max 60 chars
3333

3434
GetClusterConfig = GetClusterConfigSchema(unknown=INCLUDE)
3535

3636

3737
class GetCustomImageConfigSchema(Schema):
38-
image_id = fields.String(required=True, validate=is_alphanumeric_with_hyphen)
38+
image_id = fields.String(required=True, validate=validate.And(is_alphanumeric_with_hyphen, validate.Length(min=1, max=1024))) # AMI id min 1, max 1024 chars
3939

4040
GetCustomImageConfig = GetCustomImageConfigSchema(unknown=INCLUDE)
4141

@@ -53,72 +53,82 @@ class GetInstanceTypesSchema(Schema):
5353

5454

5555
class GetDcvSessionSchema(Schema):
56-
user = fields.String()
57-
instance_id = fields.String(required=True)
56+
user = fields.String(validate=validate.Length(max=64))
57+
instance_id = fields.String(required=True, validate=validate.Length(max=60))
5858
region = fields.String(validate=aws_region_validator)
5959

6060
GetDcvSession = GetDcvSessionSchema(unknown=INCLUDE)
6161

6262

6363
class QueueStatusSchema(Schema):
64-
user = fields.String()
65-
instance_id = fields.String(required=True)
64+
user = fields.String(validate=validate.Length(max=64))
65+
instance_id = fields.String(required=True, validate=validate.Length(max=60))
6666
region = fields.String(required=True, validate=aws_region_validator)
6767

6868
QueueStatus = QueueStatusSchema(unknown=INCLUDE)
6969

7070

7171
class ScontrolJobSchema(Schema):
72-
user = fields.String()
73-
instance_id = fields.String(required=True)
74-
job_id = fields.String(required=True)
72+
user = fields.String(validate=validate.Length(max=64))
73+
instance_id = fields.String(required=True, validate=validate.Length(max=60))
74+
job_id = fields.String(required=True, validate=validate.Length(max=256))
7575
region = fields.String(required=True, validate=aws_region_validator)
7676

7777
ScontrolJob = ScontrolJobSchema(unknown=INCLUDE)
7878

7979

8080
class CancelJobSchema(Schema):
81-
user = fields.String()
82-
instance_id = fields.String(required=True)
83-
job_id = fields.String(required=True)
81+
user = fields.String(validate=validate.Length(max=64))
82+
instance_id = fields.String(required=True, validate=validate.Length(max=60))
83+
job_id = fields.String(required=True, validate=validate.Length(max=256))
8484
region = fields.String(required=True, validate=aws_region_validator)
8585

8686
CancelJob = CancelJobSchema(unknown=INCLUDE)
8787

8888

8989
class SacctSchema(Schema):
90-
user = fields.String()
91-
instance_id = fields.String(required=True)
92-
cluster_name = fields.String(required=True, validate=is_alphanumeric_with_hyphen)
90+
user = fields.String(validate=validate.Length(max=64))
91+
instance_id = fields.String(required=True, validate=validate.Length(max=60))
92+
cluster_name = fields.String(required=True, validate=validate.And(is_alphanumeric_with_hyphen, validate.Length(max=60)))
9393
region = fields.String(required=True, validate=aws_region_validator)
9494

9595
Sacct = SacctSchema(unknown=INCLUDE)
9696

9797

9898
class LoginSchema(Schema):
99-
code = fields.String(required=True)
99+
code = fields.String(required=True, validate=validate.Length(max=128))
100100

101101
Login = LoginSchema(unknown=INCLUDE)
102102

103103

104104
class PushLogSchema(Schema):
105105
class PushLogEntrySchema(Schema):
106106
level = fields.String(required=True, validate=valid_api_log_levels_predicate)
107-
message = fields.String(required=True)
107+
message = fields.String(required=True, validate=validate.Length(max=246000)) # CW limit is 256k, leaving 1k to extra and level
108108
extra = fields.Dict()
109109

110110
logs = fields.List(fields.Nested(PushLogEntrySchema), required=True)
111111

112112
PushLog = PushLogSchema(unknown=INCLUDE)
113113

114114
class PriceEstimateSchema(Schema):
115-
cluster_name = fields.String(required=True)
116-
queue_name = fields.String(required=True)
115+
cluster_name = fields.String(required=True, validate=validate.Length(max=60))
116+
queue_name = fields.String(required=True, validate=validate.Length(max=60))
117117
region = fields.String(validate=aws_region_validator, required=True)
118118

119119
PriceEstimate = PriceEstimateSchema(unknown=INCLUDE)
120120

121-
class PCProxySchema(Schema):
122-
path = fields.String(required=True)
121+
class PCProxyArgsSchema(Schema):
122+
path = fields.String(required=True, validate=validate.Length(max=512))
123123

124-
PCProxy = PCProxySchema(unknown=INCLUDE)
124+
PCProxyArgs = PCProxyArgsSchema(unknown=INCLUDE)
125+
126+
class PCProxyBodySchema(Schema):
127+
def __init__(self, max_size, **kwargs):
128+
super().__init__(**kwargs)
129+
self.max_size = max_size
130+
@validates_schema(pass_original=False)
131+
def request_body_not_exceeding(self, data, **kwargs):
132+
size_not_exceeding(data, self.max_size)
133+
134+
PCProxyBody = PCProxyBodySchema(max_size=8192,unknown=INCLUDE)

api/validation/validators.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from marshmallow import validate
1+
import json
2+
3+
from marshmallow import validate, ValidationError
24
import re
35

46
from api.logging import VALID_LOG_LEVELS
@@ -33,3 +35,9 @@ def is_alphanumeric_with_hyphen(arg: str):
3335

3436
def valid_api_log_levels_predicate(loglevel):
3537
return loglevel.lower() in VALID_LOG_LEVELS
38+
39+
def size_not_exceeding(data, size):
40+
bytes_ = bytes(json.dumps(data), 'utf-8')
41+
byte_size = len(bytes_)
42+
if byte_size > size:
43+
raise ValidationError(f'Request body exceeded max size of {size}')

0 commit comments

Comments
 (0)