diff --git a/jwe.go b/jwe.go index a03ff53..3e5dbfb 100644 --- a/jwe.go +++ b/jwe.go @@ -19,17 +19,26 @@ func NewJWE(alg KeyAlgorithm, key interface{}, method EncryptionType, plaintext } // Generate a random Content Encryption Key (CEK). - cek, err := generateKey(chipher.keySize) + var encrypter Encrypter + switch alg { + case KeyAlgorithmRSAOAEP: + encrypter, err = createRsaEncrypter(key) + case KeyAlgorithmDIR: + encrypter, err = createDirEncrypter(key) + default: + return nil, ErrUnsupportedKeyAlgorithm + } if err != nil { return nil, err } - // Encrypt the CEK with the recipient's public key to produce the JWE Encrypted Key. - jwe.protected.Alg = alg - encrypter, err := createEncrypter(key) + // Generate a Content Encryption Key (CEK). + cek, err := encrypter.GenerateCEK(method) if err != nil { return nil, err } + + jwe.protected.Alg = alg jwe.recipientKey, err = encrypter.Encrypt(cek, alg) if err != nil { return nil, err diff --git a/jwe_decrypt.go b/jwe_decrypt.go index 0d42f9c..7ee17e3 100644 --- a/jwe_decrypt.go +++ b/jwe_decrypt.go @@ -13,7 +13,6 @@ var ( // Decrypt decrypts JWE ciphertext with the key func (jwe jwe) Decrypt(key interface{}) ([]byte, error) { - method := jwe.protected.Enc if len(method) == 0 { return nil, ErrMissingEncHeader @@ -27,7 +26,16 @@ func (jwe jwe) Decrypt(key interface{}) ([]byte, error) { if len(alg) == 0 { return nil, ErrMissingAlgHeader } - decrypter, err := createDecrypter(key) + + var decrypter Decrypter + switch alg { + case KeyAlgorithmRSAOAEP: + decrypter, err = createRsaDecrypter(key) + case KeyAlgorithmDIR: + decrypter, err = createDirDecrypter(key) + default: + return nil, ErrUnsupportedKeyAlgorithm + } if err != nil { return nil, err } diff --git a/jwe_test.go b/jwe_test.go index bf88645..d1541fa 100644 --- a/jwe_test.go +++ b/jwe_test.go @@ -6,11 +6,12 @@ import ( "crypto/rsa" "encoding/base64" "fmt" - "github.com/golang-jwt/jwe" "math/big" "os" "strings" "testing" + + "github.com/golang-jwt/jwe" ) func TestParseEncrypted(t *testing.T) { @@ -101,26 +102,49 @@ func TestRFC7516_A1(t *testing.T) { } func TestDecrypt(t *testing.T) { - keyData, _ := os.ReadFile("test/sample_key.pub") - pk, _ := jwe.ParseRSAPublicKeyFromPEM(keyData) + t.Run("RSA-OAEP", func(t *testing.T) { + keyData, _ := os.ReadFile("test/sample_key.pub") + pk, _ := jwe.ParseRSAPublicKeyFromPEM(keyData) + + originalText := "The true sign of intelligence is not knowledge but imagination." + token, err := jwe.NewJWE(jwe.KeyAlgorithmRSAOAEP, pk, jwe.EncryptionTypeA256GCM, []byte(originalText)) + if err != nil { + t.Error(err) + return + } - originalText := "The true sign of intelligence is not knowledge but imagination." - token, err := jwe.NewJWE(jwe.KeyAlgorithmRSAOAEP, pk, jwe.EncryptionTypeA256GCM, []byte(originalText)) - if err != nil { - t.Error(err) - return - } + keyData, _ = os.ReadFile("test/sample_key") + k, _ := jwe.ParseRSAPrivateKeyFromPEM(keyData) - keyData, _ = os.ReadFile("test/sample_key") - k, _ := jwe.ParseRSAPrivateKeyFromPEM(keyData) + decrypted, err := token.Decrypt(k) + if err != nil { + t.Error(err) + return + } - decrypted, err := token.Decrypt(k) - if err != nil { - t.Error(err) - return - } + if string(decrypted) != originalText { + t.Errorf("%s != %s", decrypted, originalText) + } + }) - if string(decrypted) != originalText { - t.Errorf("%s != %s", decrypted, originalText) - } + t.Run("DIR", func(t *testing.T) { + cek := []byte("a-secret-that-is-32-bytes-long-x") + + originalText := "The true sign of intelligence is not knowledge but imagination." + token, err := jwe.NewJWE(jwe.KeyAlgorithmDIR, cek, jwe.EncryptionTypeA256GCM, []byte(originalText)) + if err != nil { + t.Error(err) + return + } + + decrypted, err := token.Decrypt(cek) + if err != nil { + t.Error(err) + return + } + + if string(decrypted) != originalText { + t.Errorf("%s != %s", decrypted, originalText) + } + }) } diff --git a/keycrypter.go b/keycrypter.go index 61cba9e..85a92e0 100644 --- a/keycrypter.go +++ b/keycrypter.go @@ -1,9 +1,6 @@ package jwe import ( - "crypto/rand" - "crypto/rsa" - "crypto/sha1" "errors" ) @@ -14,52 +11,16 @@ var ( type KeyAlgorithm string -var KeyAlgorithmRSAOAEP = KeyAlgorithm("RSA-OAEP") - -type rsaEncrypter struct { - key *rsa.PublicKey -} - -func (r *rsaEncrypter) Encrypt(cek []byte, alg KeyAlgorithm) ([]byte, error) { - switch alg { - case KeyAlgorithmRSAOAEP: - return rsa.EncryptOAEP(sha1.New(), rand.Reader, r.key, cek, []byte{}) - default: - return nil, ErrUnsupportedKeyAlgorithm - } -} - -type rsaDecrypter struct { - key *rsa.PrivateKey -} - -func (r *rsaDecrypter) Decrypt(encryptedKey []byte, alg KeyAlgorithm) ([]byte, error) { - switch alg { - case KeyAlgorithmRSAOAEP: - return rsa.DecryptOAEP(sha1.New(), rand.Reader, r.key, encryptedKey, []byte{}) - default: - return nil, ErrUnsupportedKeyAlgorithm - } -} +var ( + KeyAlgorithmRSAOAEP = KeyAlgorithm("RSA-OAEP") + KeyAlgorithmDIR = KeyAlgorithm("dir") +) -func createEncrypter(key interface{}) (*rsaEncrypter, error) { - switch pbk := key.(type) { - case *rsa.PublicKey: - return &rsaEncrypter{ - key: pbk, - }, nil - default: - return nil, ErrUnsupportedKeyType - } +type Encrypter interface { + GenerateCEK(method EncryptionType) ([]byte, error) + Encrypt(cek []byte, alg KeyAlgorithm) ([]byte, error) } -func createDecrypter(key interface{}) (*rsaDecrypter, error) { - switch pk := key.(type) { - case *rsa.PrivateKey: - return &rsaDecrypter{ - key: pk, - }, nil - default: - return nil, ErrUnsupportedKeyType - } +type Decrypter interface { + Decrypt(encryptedKey []byte, alg KeyAlgorithm) ([]byte, error) } diff --git a/keycrypter_dir.go b/keycrypter_dir.go new file mode 100644 index 0000000..9a79f2f --- /dev/null +++ b/keycrypter_dir.go @@ -0,0 +1,50 @@ +package jwe + +import "errors" + +var ( + ErrInvalidEncryptedKey = errors.New("invalid encrypted key") +) + +type dirEncrypter struct { + key []byte +} + +func (r *dirEncrypter) GenerateCEK(method EncryptionType) ([]byte, error) { + return r.key, nil +} + +func (r *dirEncrypter) Encrypt(cek []byte, alg KeyAlgorithm) ([]byte, error) { + return nil, nil +} + +type dirDecrypter struct { + key []byte +} + +func (r *dirDecrypter) Decrypt(encryptedKey []byte, alg KeyAlgorithm) ([]byte, error) { + if len(encryptedKey) > 0 { + return nil, ErrInvalidEncryptedKey + } + return r.key, nil +} + +func createDirEncrypter(key interface{}) (*dirEncrypter, error) { + if key, ok := key.([]byte); ok { + return &dirEncrypter{ + key: key, + }, nil + } else { + return nil, ErrUnsupportedKeyType + } +} + +func createDirDecrypter(key interface{}) (*dirDecrypter, error) { + if key, ok := key.([]byte); ok { + return &dirDecrypter{ + key: key, + }, nil + } else { + return nil, ErrUnsupportedKeyType + } +} diff --git a/keycrypter_rsa.go b/keycrypter_rsa.go new file mode 100644 index 0000000..0255c3b --- /dev/null +++ b/keycrypter_rsa.go @@ -0,0 +1,63 @@ +package jwe + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha1" +) + +type rsaEncrypter struct { + key *rsa.PublicKey +} + +func (r *rsaEncrypter) GenerateCEK(method EncryptionType) ([]byte, error) { + chipher, err := getCipher(method) + if err != nil { + return nil, err + } + // Generate a random Content Encryption Key (CEK). + cek, err := generateKey(chipher.keySize) + if err != nil { + return nil, err + } + + return cek, nil +} + +func (r *rsaEncrypter) Encrypt(cek []byte, alg KeyAlgorithm) ([]byte, error) { + if alg != KeyAlgorithmRSAOAEP { + return nil, ErrUnsupportedKeyAlgorithm + } + return rsa.EncryptOAEP(sha1.New(), rand.Reader, r.key, cek, []byte{}) +} + +type rsaDecrypter struct { + key *rsa.PrivateKey +} + +func (r *rsaDecrypter) Decrypt(encryptedKey []byte, alg KeyAlgorithm) ([]byte, error) { + if alg != KeyAlgorithmRSAOAEP { + return nil, ErrUnsupportedKeyAlgorithm + } + return rsa.DecryptOAEP(sha1.New(), rand.Reader, r.key, encryptedKey, []byte{}) +} + +func createRsaEncrypter(key interface{}) (*rsaEncrypter, error) { + if key, ok := key.(*rsa.PublicKey); ok { + return &rsaEncrypter{ + key: key, + }, nil + } else { + return nil, ErrUnsupportedKeyType + } +} + +func createRsaDecrypter(key interface{}) (*rsaDecrypter, error) { + if key, ok := key.(*rsa.PrivateKey); ok { + return &rsaDecrypter{ + key: key, + }, nil + } else { + return nil, ErrUnsupportedKeyType + } +}