Skip to content

Commit

Permalink
OPS: add host header if it's missing
Browse files Browse the repository at this point in the history
Co-authored-by: Laszlo Losonczy <[email protected]>
  • Loading branch information
knagy and losonczylaci committed Mar 19, 2024
1 parent f8d9bad commit 1d7a1d3
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 47 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ from escherauth import Escher
request = {
'method': 'POST',
'url': '/',
'host': 'example.com',
'headers': [
['Host', 'example.com'],
['X-Foo', 'bar'],
],
'body': '{"this_is": "a_request_body"}',
}
Expand Down
14 changes: 8 additions & 6 deletions escherauth/escherauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def method(self):

def host(self):
if self.type is requests.models.PreparedRequest:
return self.request.host
parts = urlsplit(self.request.url)
return parts.netloc
if self.type is dict:
return self.request['host']

Expand All @@ -74,12 +75,9 @@ def has_query_param(self, query_param):

def headers(self):
if self.type is requests.models.PreparedRequest:
headers = []
for key, value in self.request.headers.items():
headers.append([key, value])
return headers
return [[key, value] for key, value in self.request.headers.items()]
if self.type is dict:
return self.request['headers']
return self.request.get('headers', [])

def has_header(self, header):
return header.lower() in [key.lower() for key, value in self.headers()]
Expand All @@ -99,6 +97,8 @@ def add_header(self, header, value):
if self.type is requests.models.PreparedRequest:
self.request.headers[header] = value
if self.type is dict:
if 'headers' not in self.request:
self.request['headers'] = []
self.request['headers'].append([header, value])

def set_presigned_url(self, is_presigned_url):
Expand Down Expand Up @@ -254,6 +254,8 @@ def sign_request(self, request, headers_to_sign=None):

current_time = self.current_time or datetime.datetime.now(datetime.timezone.utc)

if not request.has_header('host'):
request.add_header('Host', request.host())
if not request.has_header(self.date_header_name):
if self.date_header_name.lower() == 'date':
request.add_header(self.date_header_name, self.header_date(current_time))
Expand Down
195 changes: 155 additions & 40 deletions tests/test_escherrequest.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,171 @@
import unittest

from escherauth.escherauth import EscherRequest
import requests

from escherauth.escherauth import EscherRequest, EscherException


class EscherRequestTest(unittest.TestCase):
def test_object_basic(self):
def test_invalid_http_method(self):
with self.assertRaisesRegex(EscherException, 'The request method is invalid'):
EscherRequest({
'method': 'INVALID',
'url': '/?foo=bar',
})

def test_invalid_path(self):
with self.assertRaisesRegex(EscherException, 'The request url shouldn\'t contains http or https'):
EscherRequest({
'method': 'GET',
'url': 'http://localhost/?foo=bar',
})

def test_no_body(self):
with self.assertRaisesRegex(EscherException, 'The request body shouldn\'t be empty if the request method is POST'):
EscherRequest({
'method': 'POST',
'url': '/?foo=bar',
})

def test_dict_method(self):
request = EscherRequest({
'method': 'GET',
'host': 'host.foo.com',
'url': '/?foo=bar',
'headers': [
('Date', 'Mon, 09 Sep 2011 23:36:00 GMT'),
('Host', 'host.foo.com'),
],
})

self.assertEqual(request.method(), 'GET')
self.assertEqual(request.host(), 'host.foo.com')

def test_prepared_request_method(self):
request = EscherRequest(requests.Request('GET', 'http://localhost/?foo=bar').prepare())

self.assertEqual(request.method(), 'GET')

def test_dict_host(self):
request = EscherRequest({
'method': 'GET',
'url': '/?foo=bar',
'host': 'localhost:8080'
})

self.assertEqual(request.host(), 'localhost:8080')

def test_prepared_request_host(self):
request = EscherRequest(requests.Request('GET', 'http://localhost:8080/?foo=bar').prepare())

self.assertEqual(request.host(), 'localhost:8080')

def test_dict_path(self):
request = EscherRequest({
'method': 'GET',
'url': '/?foo=bar',
})

self.assertEqual(request.path(), '/')

def test_prepared_request_path(self):
request = EscherRequest(requests.Request('GET', 'http://localhost/?foo=bar').prepare())

self.assertEqual(request.path(), '/')
self.assertListEqual(request.query_parts(), [
('foo', 'bar'),
])
self.assertListEqual(request.headers(), [
('Date', 'Mon, 09 Sep 2011 23:36:00 GMT'),
('Host', 'host.foo.com'),
])
self.assertEqual(request.body(), None) # there was no body specified

def test_object_complex(self):
def test_dict_query_parts(self):
request = EscherRequest({
'method': 'POST',
'host': 'host.foo.com',
'url': '/example/path/?foo=bar&abc=cba',
'headers': [],
'body': 'HELLO WORLD!',
})
self.assertEqual(request.method(), 'POST')
self.assertEqual(request.host(), 'host.foo.com')
self.assertEqual(request.path(), '/example/path/')
self.assertListEqual(request.query_parts(), [
('foo', 'bar'),
('abc', 'cba'),
])
self.assertListEqual(request.headers(), [])
self.assertEqual(request.body(), 'HELLO WORLD!')

def test_object_add_header(self):
'method': 'GET',
'url': '/?foo=bar',
})

self.assertEqual(request.query_parts(), [('foo', 'bar')])

def test_prepared_request_query_parts(self):
request = EscherRequest(requests.Request('GET', 'http://localhost/?foo=bar').prepare())

self.assertEqual(request.query_parts(), [('foo', 'bar')])

def test_dict_has_query_param(self):
request = EscherRequest({
'method': 'GET',
'url': '/?foo=bar',
})

self.assertTrue(request.has_query_param('foo'))
self.assertFalse(request.has_query_param('bar'))

def test_prepared_request_has_query_param(self):
request = EscherRequest(requests.Request('GET', 'http://localhost/?foo=bar').prepare())

self.assertTrue(request.has_query_param('foo'))
self.assertFalse(request.has_query_param('bar'))

def test_dict_headers(self):
request = EscherRequest({
'method': 'GET',
'url': '/?foo=bar',
'headers': [['Foo', 'bar']],
})

self.assertListEqual(request.headers(), [['Foo', 'bar']])

def test_prepared_request_headers(self):
request = EscherRequest(requests.Request('GET', 'http://localhost/?foo=bar', headers={'Foo': 'bar'}).prepare())

self.assertListEqual(request.headers(), [['Foo', 'bar']])

def test_dict_has_header(self):
request = EscherRequest({
'method': 'GET',
'url': '/?foo=bar',
'headers': [['Foo', 'bar']],
})

self.assertTrue(request.has_header('foo'))
self.assertFalse(request.has_header('bar'))

def test_prepared_request_has_header(self):
request = EscherRequest(requests.Request('GET', 'http://localhost/?foo=bar', headers={'Foo': 'bar'}).prepare())

self.assertTrue(request.has_header('foo'))
self.assertFalse(request.has_header('bar'))

def test_presigned_url_body(self):
request = EscherRequest({
'method': 'GET',
'url': '/?foo=bar',
})
request.set_presigned_url(True)

self.assertEqual(request.body(), 'UNSIGNED-PAYLOAD')

def test_dict_body(self):
request = EscherRequest({
'method': 'POST',
'host': 'host.foo.com',
'url': '/example/path/?foo=bar&abc=cba',
'headers': [],
'body': 'HELLO WORLD!',
'url': '/?foo=bar',
'body': 'foo',
})
request.add_header('Foo', 'Bar')
self.assertListEqual(request.headers(), [['Foo', 'Bar']])

self.assertEqual(request.body(), 'foo')

def test_prepared_request_body(self):
request = EscherRequest(requests.Request('POST', 'http://localhost/?foo=bar', data='foo').prepare())

self.assertEqual(request.body(), 'foo')

def test_prepared_request_byte_body(self):
request = EscherRequest(requests.Request('POST', 'http://localhost/?foo=bar', data=b'foo').prepare())

self.assertEqual(request.body(), 'foo')

def test_dict_add_header(self):
request = EscherRequest({
'method': 'GET',
'url': '/?foo=bar',
})

self.assertFalse(request.has_header('bar'))
request.add_header('Bar', 'baz')
self.assertTrue(request.has_header('bar'))

def test_prepared_request_add_header(self):
request = EscherRequest(requests.Request('GET', 'http://localhost/?foo=bar').prepare())

self.assertFalse(request.has_header('bar'))
request.add_header('Bar', 'baz')
self.assertTrue(request.has_header('bar'))

0 comments on commit 1d7a1d3

Please sign in to comment.