diff --git a/aws_google_auth/amazon.py b/aws_google_auth/amazon.py index fa6a0cc..75ef8ef 100644 --- a/aws_google_auth/amazon.py +++ b/aws_google_auth/amazon.py @@ -80,7 +80,8 @@ def roles(self): roles = {} for x in doc.xpath('//*[@Name = "https://aws.amazon.com/SAML/Attributes/Role"]//text()'): if "arn:aws:iam:" in x or "arn:aws-us-gov:iam:" in x: - res = x.split(',') + res = sorted([s.strip() for s in x.split(',')], + key=lambda s: ':role/' in s, reverse=True) roles[res[0]] = res[1] return roles diff --git a/aws_google_auth/tests/test_amazon.py b/aws_google_auth/tests/test_amazon.py index 5ca6021..7b066f5 100644 --- a/aws_google_auth/tests/test_amazon.py +++ b/aws_google_auth/tests/test_amazon.py @@ -49,6 +49,36 @@ def test_role_extraction_too_many_commas(self): "arn:aws:iam::123456789012:role/test"] self.assertEqual(sorted(list(a.roles.keys())), sorted(list_of_testing_roles)) + def test_role_extraction_role_saml_provider_wrong_order(self): + saml_xml = self.read_local_file('valid-response.xml') + provider = 'arn:aws:iam::123456789012:saml-provider/GoogleApps' + list_of_testing_roles = [ + "arn:aws:iam::123456789012:role/admin", + "arn:aws:iam::123456789012:role/read-only", + "arn:aws:iam::123456789012:role/test"] + for role in list_of_testing_roles: + saml_xml = saml_xml.replace(f'{role},{provider}'.encode('utf-8'), + f'{provider},{role}'.encode('utf-8')) + a = amazon.Amazon(self.valid_config, saml_xml) + self.assertIsInstance(a.roles, dict) + self.assertEqual(sorted(list(a.roles.keys())), sorted(list_of_testing_roles)) + self.assertEqual(set(a.roles.values()), set((provider,))) + + def test_role_extraction_whitespace(self): + saml_xml = self.read_local_file('valid-response.xml') + provider = 'arn:aws:iam::123456789012:saml-provider/GoogleApps' + list_of_testing_roles = [ + "arn:aws:iam::123456789012:role/admin", + "arn:aws:iam::123456789012:role/read-only", + "arn:aws:iam::123456789012:role/test"] + for role in list_of_testing_roles: + saml_xml = saml_xml.replace(f'{role},{provider}'.encode('utf-8'), + f'{role}, {provider}'.encode('utf-8')) + a = amazon.Amazon(self.valid_config, saml_xml) + self.assertIsInstance(a.roles, dict) + self.assertEqual(sorted(list(a.roles.keys())), sorted(list_of_testing_roles)) + self.assertEqual(set(a.roles.values()), set((provider,))) + def test_invalid_saml_too_soon(self): saml_xml = self.read_local_file('saml-response-too-soon.xml') self.assertFalse(amazon.Amazon.is_valid_saml_assertion(saml_xml))