diff --git a/tlslite/constants.py b/tlslite/constants.py index 002b03089..682802ba6 100644 --- a/tlslite/constants.py +++ b/tlslite/constants.py @@ -119,6 +119,7 @@ class ExtensionType: # RFC 6066 / 4366 srp = 12 # RFC 5054 signature_algorithms = 13 # RFC 5246 alpn = 16 # RFC 7301 + signed_certificate_timestamp = 18 # RFC 6962 client_hello_padding = 21 # RFC 7685 encrypt_then_mac = 22 # RFC 7366 extended_master_secret = 23 # RFC 7627 diff --git a/tlslite/extensions.py b/tlslite/extensions.py index 23a2c192e..d502eb411 100644 --- a/tlslite/extensions.py +++ b/tlslite/extensions.py @@ -1248,6 +1248,78 @@ def parse(self, parser): return self +class SCTExtension(TLSExtension): + """ + Client and Server Hello extension from Certificate Transparency. + + Extension containing a list of serialised SignedCertificateTimestamp + objects. + + See RFC 6962 + """ + + def __init__(self): + """Create instance of class""" + extType = ExtensionType.signed_certificate_timestamp + super(SCTExtension, self).__init__(extType=extType) + self.sct_list = None + + def create(self, sct_list): + """ + Set the list of signed certificate timestamps + + @type sct_list: list of bytearrays + @param sct_list: list of serialised certificate time stamps + """ + self.sct_list = sct_list + return self + + @property + def extData(self): + """ + Return raw encoding of the extension + + @rtype: bytearray + """ + if self.sct_list is None: + return bytearray(0) + + writer = Writer() + # elements have 2 byte header lengths + for sct in self.sct_list: + writer.add(len(sct), 2) + writer.bytes += sct + + writer2 = Writer() + writer2.add(len(writer.bytes), 2) + return writer2.bytes + writer.bytes + + def parse(self, parser): + """ + Deserialise extension from on the wire data. + + @type parser: L{tlslite.util.codec.Parser} + @param parser: data to be parsed + + @rtype: L{SCTExtension} + """ + if parser.getRemainingLength() == 0: + self.sct_list = None + return self + + self.sct_list = [] + + parser.startLengthCheck(2) + while not parser.atLengthCheck(): + self.sct_list.append(parser.getVarBytes(2)) + parser.stopLengthCheck() + + if parser.getRemainingLength() != 0: + raise SyntaxError("Trailing data in SCTExtension") + + return self + + TLSExtension._universalExtensions = \ { ExtensionType.server_name: SNIExtension, @@ -1257,6 +1329,7 @@ def parse(self, parser): ExtensionType.srp: SRPExtension, ExtensionType.signature_algorithms: SignatureAlgorithmsExtension, ExtensionType.alpn: ALPNExtension, + ExtensionType.signed_certificate_timestamp: SCTExtension, ExtensionType.supports_npn: NPNExtension, ExtensionType.client_hello_padding: PaddingExtension, ExtensionType.renegotiation_info: RenegotiationInfoExtension} diff --git a/unit_tests/test_tlslite_extensions.py b/unit_tests/test_tlslite_extensions.py index d4773cd48..47bd41227 100644 --- a/unit_tests/test_tlslite_extensions.py +++ b/unit_tests/test_tlslite_extensions.py @@ -12,7 +12,7 @@ SRPExtension, ClientCertTypeExtension, ServerCertTypeExtension,\ TACKExtension, SupportedGroupsExtension, ECPointFormatsExtension,\ SignatureAlgorithmsExtension, PaddingExtension, VarListExtension, \ - RenegotiationInfoExtension, ALPNExtension + RenegotiationInfoExtension, ALPNExtension, SCTExtension from tlslite.utils.codec import Parser from tlslite.constants import NameType, ExtensionType, GroupName,\ ECPointFormat, HashAlgorithm, SignatureAlgorithm @@ -1587,5 +1587,89 @@ def test_parse_from_TLSExtension(self): bytearray(b'spdy/1')]) +class TestSCTExtension(unittest.TestCase): + def setUp(self): + self.ext = SCTExtension() + + def test___int__(self): + self.assertIsNotNone(self.ext) + self.assertEqual(self.ext.extType, 18) + self.assertEqual(self.ext.extData, bytearray()) + self.assertIsNone(self.ext.sct_list) + + def test_create(self): + ext2 = self.ext.create([bytearray(b'SCT number 1'), + bytearray(b'SCT number 2')]) + + self.assertIs(self.ext, ext2) + self.assertEqual(self.ext.sct_list, [bytearray(b'SCT number 1'), + bytearray(b'SCT number 2')]) + + def test_extData_with_empty_array(self): + self.ext.create([]) + + self.assertEqual(self.ext.extData, bytearray(b'\x00\x00')) + + def test_extData_with_empty_SCTs(self): + self.ext.create([bytearray(), bytearray()]) + + self.assertEqual(self.ext.extData, bytearray(b'\x00\x04' + b'\x00\x00' + b'\x00\x00')) + + def test_extData(self): + self.ext.create([bytearray(b'test'), bytearray(b'example')]) + + self.assertEqual(self.ext.extData, bytearray(b'\x00\x0f' + b'\x00\x04test' + b'\x00\x07example')) + + def test_parse_with_empty_data(self): + parser = Parser(bytearray(b'')) + + ret = self.ext.parse(parser) + + self.assertIs(ret, self.ext) + self.assertIsNone(self.ext.sct_list) + + def test_parse_with_empty_array(self): + parser = Parser(bytearray(b'\x00\x00')) + + ret = self.ext.parse(parser) + + self.assertIs(ret, self.ext) + self.assertEqual(self.ext.sct_list, []) + + def test_parse_with_empty_elements(self): + parser = Parser(bytearray(b'\x00\x04\x00\x00\x00\x00')) + + self.ext.parse(parser) + + self.assertEqual(self.ext.sct_list, [bytearray(), bytearray()]) + + def test_parse_with_value(self): + parser = Parser(bytearray(b'\x00\x06\x00\x04test')) + + self.ext.parse(parser) + + self.assertEqual(self.ext.sct_list, [bytearray(b'test')]) + + def test_parse_with_overflowing_data(self): + parser = Parser(bytearray(b'\x00\x00test')) + + with self.assertRaises(SyntaxError): + self.ext.parse(parser) + + def test_parse_from_TLSExtension(self): + ext = TLSExtension() + + parser = Parser(bytearray(b'\x00\x12\x00\x08' + b'\x00\x06\x00\x04test')) + + ret = ext.parse(parser) + self.assertIsInstance(ret, SCTExtension) + self.assertEqual(ret.sct_list, [bytearray(b'test')]) + + if __name__ == '__main__': unittest.main()