From 1d7a1d38af231a093b21c4e35038c91afc8b9b06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nagy=20Kriszti=C3=A1n?= Date: Tue, 19 Mar 2024 15:44:54 +0000 Subject: [PATCH] OPS: add host header if it's missing Co-authored-by: Laszlo Losonczy --- README.md | 3 +- escherauth/escherauth.py | 14 +-- tests/test_escherrequest.py | 195 ++++++++++++++++++++++++++++-------- 3 files changed, 165 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index 792ec0f..4f7a0da 100644 --- a/README.md +++ b/README.md @@ -20,8 +20,9 @@ from escherauth import Escher request = { 'method': 'POST', 'url': '/', + 'host': 'example.com', 'headers': [ - ['Host', 'example.com'], + ['X-Foo', 'bar'], ], 'body': '{"this_is": "a_request_body"}', } diff --git a/escherauth/escherauth.py b/escherauth/escherauth.py index cf38e6f..2a3ffd3 100644 --- a/escherauth/escherauth.py +++ b/escherauth/escherauth.py @@ -59,7 +59,8 @@ def method(self): def host(self): if self.type is requests.models.PreparedRequest: - return self.request.host + parts = urlsplit(self.request.url) + return parts.netloc if self.type is dict: return self.request['host'] @@ -74,12 +75,9 @@ def has_query_param(self, query_param): def headers(self): if self.type is requests.models.PreparedRequest: - headers = [] - for key, value in self.request.headers.items(): - headers.append([key, value]) - return headers + return [[key, value] for key, value in self.request.headers.items()] if self.type is dict: - return self.request['headers'] + return self.request.get('headers', []) def has_header(self, header): return header.lower() in [key.lower() for key, value in self.headers()] @@ -99,6 +97,8 @@ def add_header(self, header, value): if self.type is requests.models.PreparedRequest: self.request.headers[header] = value if self.type is dict: + if 'headers' not in self.request: + self.request['headers'] = [] self.request['headers'].append([header, value]) def set_presigned_url(self, is_presigned_url): @@ -254,6 +254,8 @@ def sign_request(self, request, headers_to_sign=None): current_time = self.current_time or datetime.datetime.now(datetime.timezone.utc) + if not request.has_header('host'): + request.add_header('Host', request.host()) if not request.has_header(self.date_header_name): if self.date_header_name.lower() == 'date': request.add_header(self.date_header_name, self.header_date(current_time)) diff --git a/tests/test_escherrequest.py b/tests/test_escherrequest.py index e11545a..f7ab97c 100644 --- a/tests/test_escherrequest.py +++ b/tests/test_escherrequest.py @@ -1,56 +1,171 @@ import unittest -from escherauth.escherauth import EscherRequest +import requests + +from escherauth.escherauth import EscherRequest, EscherException class EscherRequestTest(unittest.TestCase): - def test_object_basic(self): + def test_invalid_http_method(self): + with self.assertRaisesRegex(EscherException, 'The request method is invalid'): + EscherRequest({ + 'method': 'INVALID', + 'url': '/?foo=bar', + }) + + def test_invalid_path(self): + with self.assertRaisesRegex(EscherException, 'The request url shouldn\'t contains http or https'): + EscherRequest({ + 'method': 'GET', + 'url': 'http://localhost/?foo=bar', + }) + + def test_no_body(self): + with self.assertRaisesRegex(EscherException, 'The request body shouldn\'t be empty if the request method is POST'): + EscherRequest({ + 'method': 'POST', + 'url': '/?foo=bar', + }) + + def test_dict_method(self): request = EscherRequest({ 'method': 'GET', - 'host': 'host.foo.com', 'url': '/?foo=bar', - 'headers': [ - ('Date', 'Mon, 09 Sep 2011 23:36:00 GMT'), - ('Host', 'host.foo.com'), - ], }) + self.assertEqual(request.method(), 'GET') - self.assertEqual(request.host(), 'host.foo.com') + + def test_prepared_request_method(self): + request = EscherRequest(requests.Request('GET', 'http://localhost/?foo=bar').prepare()) + + self.assertEqual(request.method(), 'GET') + + def test_dict_host(self): + request = EscherRequest({ + 'method': 'GET', + 'url': '/?foo=bar', + 'host': 'localhost:8080' + }) + + self.assertEqual(request.host(), 'localhost:8080') + + def test_prepared_request_host(self): + request = EscherRequest(requests.Request('GET', 'http://localhost:8080/?foo=bar').prepare()) + + self.assertEqual(request.host(), 'localhost:8080') + + def test_dict_path(self): + request = EscherRequest({ + 'method': 'GET', + 'url': '/?foo=bar', + }) + + self.assertEqual(request.path(), '/') + + def test_prepared_request_path(self): + request = EscherRequest(requests.Request('GET', 'http://localhost/?foo=bar').prepare()) + self.assertEqual(request.path(), '/') - self.assertListEqual(request.query_parts(), [ - ('foo', 'bar'), - ]) - self.assertListEqual(request.headers(), [ - ('Date', 'Mon, 09 Sep 2011 23:36:00 GMT'), - ('Host', 'host.foo.com'), - ]) - self.assertEqual(request.body(), None) # there was no body specified - def test_object_complex(self): + def test_dict_query_parts(self): request = EscherRequest({ - 'method': 'POST', - 'host': 'host.foo.com', - 'url': '/example/path/?foo=bar&abc=cba', - 'headers': [], - 'body': 'HELLO WORLD!', - }) - self.assertEqual(request.method(), 'POST') - self.assertEqual(request.host(), 'host.foo.com') - self.assertEqual(request.path(), '/example/path/') - self.assertListEqual(request.query_parts(), [ - ('foo', 'bar'), - ('abc', 'cba'), - ]) - self.assertListEqual(request.headers(), []) - self.assertEqual(request.body(), 'HELLO WORLD!') - - def test_object_add_header(self): + 'method': 'GET', + 'url': '/?foo=bar', + }) + + self.assertEqual(request.query_parts(), [('foo', 'bar')]) + + def test_prepared_request_query_parts(self): + request = EscherRequest(requests.Request('GET', 'http://localhost/?foo=bar').prepare()) + + self.assertEqual(request.query_parts(), [('foo', 'bar')]) + + def test_dict_has_query_param(self): + request = EscherRequest({ + 'method': 'GET', + 'url': '/?foo=bar', + }) + + self.assertTrue(request.has_query_param('foo')) + self.assertFalse(request.has_query_param('bar')) + + def test_prepared_request_has_query_param(self): + request = EscherRequest(requests.Request('GET', 'http://localhost/?foo=bar').prepare()) + + self.assertTrue(request.has_query_param('foo')) + self.assertFalse(request.has_query_param('bar')) + + def test_dict_headers(self): + request = EscherRequest({ + 'method': 'GET', + 'url': '/?foo=bar', + 'headers': [['Foo', 'bar']], + }) + + self.assertListEqual(request.headers(), [['Foo', 'bar']]) + + def test_prepared_request_headers(self): + request = EscherRequest(requests.Request('GET', 'http://localhost/?foo=bar', headers={'Foo': 'bar'}).prepare()) + + self.assertListEqual(request.headers(), [['Foo', 'bar']]) + + def test_dict_has_header(self): + request = EscherRequest({ + 'method': 'GET', + 'url': '/?foo=bar', + 'headers': [['Foo', 'bar']], + }) + + self.assertTrue(request.has_header('foo')) + self.assertFalse(request.has_header('bar')) + + def test_prepared_request_has_header(self): + request = EscherRequest(requests.Request('GET', 'http://localhost/?foo=bar', headers={'Foo': 'bar'}).prepare()) + + self.assertTrue(request.has_header('foo')) + self.assertFalse(request.has_header('bar')) + + def test_presigned_url_body(self): + request = EscherRequest({ + 'method': 'GET', + 'url': '/?foo=bar', + }) + request.set_presigned_url(True) + + self.assertEqual(request.body(), 'UNSIGNED-PAYLOAD') + + def test_dict_body(self): request = EscherRequest({ 'method': 'POST', - 'host': 'host.foo.com', - 'url': '/example/path/?foo=bar&abc=cba', - 'headers': [], - 'body': 'HELLO WORLD!', + 'url': '/?foo=bar', + 'body': 'foo', }) - request.add_header('Foo', 'Bar') - self.assertListEqual(request.headers(), [['Foo', 'Bar']]) + + self.assertEqual(request.body(), 'foo') + + def test_prepared_request_body(self): + request = EscherRequest(requests.Request('POST', 'http://localhost/?foo=bar', data='foo').prepare()) + + self.assertEqual(request.body(), 'foo') + + def test_prepared_request_byte_body(self): + request = EscherRequest(requests.Request('POST', 'http://localhost/?foo=bar', data=b'foo').prepare()) + + self.assertEqual(request.body(), 'foo') + + def test_dict_add_header(self): + request = EscherRequest({ + 'method': 'GET', + 'url': '/?foo=bar', + }) + + self.assertFalse(request.has_header('bar')) + request.add_header('Bar', 'baz') + self.assertTrue(request.has_header('bar')) + + def test_prepared_request_add_header(self): + request = EscherRequest(requests.Request('GET', 'http://localhost/?foo=bar').prepare()) + + self.assertFalse(request.has_header('bar')) + request.add_header('Bar', 'baz') + self.assertTrue(request.has_header('bar'))