Skip to content

Commit

Permalink
Skip some RSA tests if test vector file is not found
Browse files Browse the repository at this point in the history
  • Loading branch information
Legrandin committed Jun 16, 2024
1 parent c88f666 commit 963ccbb
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 12 deletions.
2 changes: 1 addition & 1 deletion lib/Crypto/SelfTest/PublicKey/test_import_ECC.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def load_file(file_name, mode="rb"):
results = file_in.read()

except FileNotFoundError:
warnings.warn("Warning: skipping extended tests for ECC",
warnings.warn("Skipping extended tests for ECC",
UserWarning,
stacklevel=2)

Expand Down
20 changes: 9 additions & 11 deletions lib/Crypto/SelfTest/PublicKey/test_import_RSA.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@
import errno
import warnings
import unittest
from unittest import SkipTest

from Crypto.PublicKey import RSA
from Crypto.SelfTest.st_common import a2b_hex, list_test_cases
from Crypto.IO import PEM
from Crypto.Util.py3compat import b, tostr, FileNotFoundError
from Crypto.Util.number import inverse, bytes_to_long
from Crypto.Util.number import inverse
from Crypto.Util import asn1

try:
Expand All @@ -56,10 +57,13 @@ def load_file(file_name, mode="rb"):
results = file_in.read()

except FileNotFoundError:
warnings.warn("Warning: skipping extended tests for RSA",
warnings.warn("Skipping tests for RSA based on %s" % file_name,
UserWarning,
stacklevel=2)

if results is None:
raise SkipTest("Missing %s" % file_name)

return results


Expand Down Expand Up @@ -564,7 +568,7 @@ def test_import_pss(self):
pub_key = RSA.import_key(pub_key_file)

priv_key_file = load_file("rsa2048_pss_private.pem")
priv_key = RSA.import_key(pub_key_file)
priv_key = RSA.import_key(priv_key_file)

self.assertEqual(pub_key.n, priv_key.n)

Expand Down Expand Up @@ -618,11 +622,6 @@ def test_import_pkcs8_private(self):
self.assertEqual(key_ref, key)



if __name__ == '__main__':
unittest.main()


def get_tests(config={}):
tests = []
tests += list_test_cases(ImportKeyTests)
Expand All @@ -632,7 +631,6 @@ def get_tests(config={}):


if __name__ == '__main__':
suite = lambda: unittest.TestSuite(get_tests())
def suite():
return unittest.TestSuite(get_tests())
unittest.main(defaultTest='suite')

# vim:set ts=4 sw=4 sts=4 expandtab:

0 comments on commit 963ccbb

Please sign in to comment.