diff --git a/src/spake2/spake2.py b/src/spake2/spake2.py index f9d37a7..b8aaf20 100644 --- a/src/spake2/spake2.py +++ b/src/spake2/spake2.py @@ -36,24 +36,25 @@ class ReflectionThwarted(SPAKEError): # Y = scalarmult(g, y) # Y* = Y + scalarmult(N, int(pw)) # KA = scalarmult(Y* + scalarmult(N, -int(pw)), x) -# key = H(H(idA), H(idB), X*, Y*, KA) +# key = H(H(pw) + H(idA) + H(idB) + X* + Y* + KA) # KB = scalarmult(X* + scalarmult(M, -int(pw)), y) -# key = H(H(idA), H(idB), X*, Y*, KB) +# key = H(H(pw) + H(idA) + H(idB) + X* + Y* + KB) # to serialize intermediate state, just remember x and A-vs-B. And U/V. def finalize_SPAKE2(idA, idB, X_msg, Y_msg, K_bytes, pw): - transcript = b"".join([sha256(idA).digest(), sha256(idB).digest(), - X_msg, Y_msg, K_bytes, pw]) + transcript = b"".join([sha256(pw).digest(), + sha256(idA).digest(), sha256(idB).digest(), + X_msg, Y_msg, K_bytes]) key = sha256(transcript).digest() return key def finalize_SPAKE2_symmetric(idSymmetric, msg1, msg2, K_bytes, pw): # since we don't know which side is which, we must sort the messages first_msg, second_msg = sorted([msg1, msg2]) - transcript = b"".join([sha256(idSymmetric).digest(), - first_msg, second_msg, K_bytes, - pw]) + transcript = b"".join([sha256(pw).digest(), + sha256(idSymmetric).digest(), + first_msg, second_msg, K_bytes]) key = sha256(transcript).digest() return key diff --git a/src/spake2/test/test_compat.py b/src/spake2/test/test_compat.py index b3b5378..cfa70cf 100644 --- a/src/spake2/test/test_compat.py +++ b/src/spake2/test/test_compat.py @@ -30,7 +30,7 @@ def test_asymmetric(self): kA,kB = sA.finish(m1B), sB.finish(m1A) self.assertEqual(hexlify(kA), - b"9134475e92119062ca9026db6f1a11127f51b8b77133b0b488ac3f328cc9bbdf") + b"a480bca13fa04464bb644f10e340125e96c9494f7399fef7c2bda67eb0fdf06d") self.assertEqual(hexlify(kA), hexlify(kB)) self.assertEqual(len(kA), len(sha256().digest())) @@ -46,7 +46,7 @@ def test_symmetric(self): k1,k2 = s1.finish(m12), s2.finish(m11) self.assertEqual(hexlify(k1), - b"8a69eb6c4b6ad7b871a64f2bde5b8c1fa12268526ee478ef8b53aad44687e1e9") + b"9c4fccaa3f0740615cee6fd10ed5d3a311b91b5bdc65f53e4ea7cb2fe8aa96eb") self.assertEqual(hexlify(k1), hexlify(k2)) self.assertEqual(len(k1), len(sha256().digest())) @@ -219,13 +219,13 @@ class Finalize(unittest.TestCase): def test_asymmetric(self): key = finalize_SPAKE2(b"idA", b"idB", b"X_msg", b"Y_msg", b"K_bytes", b"pw") - self.assertEqual(hexlify(key), b"b90002522d29f405fbd5de17741c45c96dec0a4d48c44b05ad53c374c5a48a30") + self.assertEqual(hexlify(key), b"aa02a627537543399bb1b4b430646480b6d36ab5c44842e738c8f78694d8afac") def test_symmetric(self): key1 = finalize_SPAKE2_symmetric(b"idSymmetric", b"X_msg", b"Y_msg", b"K_bytes", b"pw") - self.assertEqual(hexlify(key1), b"8a3738cdf3d99390d8b4d2e581b88184d7ab59125767f5b5a84d5643dbab1cb7") + self.assertEqual(hexlify(key1), b"330a7ce7bb010fea7dae7e15b2261315403ab5dc269e461f6eb1cc6566620790") key2 = finalize_SPAKE2_symmetric(b"idSymmetric", b"Y_msg", b"X_msg", b"K_bytes", b"pw")