@@ -11,76 +11,73 @@ import com.nimbusds.jwt.JWTClaimsSet
1111import com.nimbusds.jwt.SignedJWT
1212import com.nimbusds.oauth2.sdk.TokenRequest
1313import no.nav.security.mock.oauth2.extensions.clientIdAsString
14+ import no.nav.security.mock.oauth2.extensions.issuerId
1415import okhttp3.HttpUrl
15- import java.security.KeyPair
1616import java.security.KeyPairGenerator
1717import java.security.interfaces.RSAPrivateKey
1818import java.security.interfaces.RSAPublicKey
1919import java.time.Duration
2020import java.time.Instant
2121import java.util.Date
2222import java.util.UUID
23+ import java.util.concurrent.ConcurrentHashMap
2324
2425class OAuth2TokenProvider {
25- private val jwkSet: JWKSet = generateJWKSet(DEFAULT_KEYID )
26- private val rsaKey: RSAKey = jwkSet.getKeyByKeyId(DEFAULT_KEYID ) as RSAKey
26+ private val signingKeys: ConcurrentHashMap <String , RSAKey > = ConcurrentHashMap ()
2727
28- fun publicJwkSet (): JWKSet {
29- return jwkSet.toPublicJWKSet()
28+ @JvmOverloads
29+ fun publicJwkSet (issuerId : String = "default"): JWKSet {
30+ return JWKSet (rsaKey(issuerId)).toPublicJWKSet()
3031 }
3132
3233 fun idToken (
3334 tokenRequest : TokenRequest ,
3435 issuerUrl : HttpUrl ,
3536 oAuth2TokenCallback : OAuth2TokenCallback ,
3637 nonce : String? = null
37- ) = createSignedJWT(
38- defaultClaims(
39- issuerUrl,
40- oAuth2TokenCallback.subject(tokenRequest),
41- listOf (tokenRequest.clientIdAsString()),
42- nonce,
43- oAuth2TokenCallback.addClaims(tokenRequest),
44- oAuth2TokenCallback.tokenExpiry()
45- )
46- )
38+ ) = defaultClaims(
39+ issuerUrl,
40+ oAuth2TokenCallback.subject(tokenRequest),
41+ listOf (tokenRequest.clientIdAsString()),
42+ nonce,
43+ oAuth2TokenCallback.addClaims(tokenRequest),
44+ oAuth2TokenCallback.tokenExpiry()
45+ ).sign(issuerUrl.issuerId())
4746
4847 fun accessToken (
4948 tokenRequest : TokenRequest ,
5049 issuerUrl : HttpUrl ,
5150 oAuth2TokenCallback : OAuth2TokenCallback ,
5251 nonce : String? = null
53- ) = createSignedJWT(
54- defaultClaims(
55- issuerUrl,
56- oAuth2TokenCallback.subject(tokenRequest),
57- oAuth2TokenCallback.audience(tokenRequest),
58- nonce,
59- oAuth2TokenCallback.addClaims(tokenRequest),
60- oAuth2TokenCallback.tokenExpiry()
61- )
62- )
52+ ) = defaultClaims(
53+ issuerUrl,
54+ oAuth2TokenCallback.subject(tokenRequest),
55+ oAuth2TokenCallback.audience(tokenRequest),
56+ nonce,
57+ oAuth2TokenCallback.addClaims(tokenRequest),
58+ oAuth2TokenCallback.tokenExpiry()
59+ ).sign(issuerUrl.issuerId())
6360
6461 fun exchangeAccessToken (
6562 tokenRequest : TokenRequest ,
6663 issuerUrl : HttpUrl ,
6764 claimsSet : JWTClaimsSet ,
6865 oAuth2TokenCallback : OAuth2TokenCallback
6966 ) = Instant .now().let { now ->
70- createSignedJWT(
71- JWTClaimsSet .Builder (claimsSet)
72- .issuer(issuerUrl.toString())
73- .expirationTime(Date .from(now.plusSeconds(oAuth2TokenCallback.tokenExpiry())))
74- .notBeforeTime(Date .from(now))
75- .issueTime(Date .from(now))
76- .jwtID(UUID .randomUUID().toString())
77- .audience(oAuth2TokenCallback.audience(tokenRequest))
78- .addClaims(oAuth2TokenCallback.addClaims(tokenRequest))
79- .build()
80- )
67+ JWTClaimsSet .Builder (claimsSet)
68+ .issuer(issuerUrl.toString())
69+ .expirationTime(Date .from(now.plusSeconds(oAuth2TokenCallback.tokenExpiry())))
70+ .notBeforeTime(Date .from(now))
71+ .issueTime(Date .from(now))
72+ .jwtID(UUID .randomUUID().toString())
73+ .audience(oAuth2TokenCallback.audience(tokenRequest))
74+ .addClaims(oAuth2TokenCallback.addClaims(tokenRequest))
75+ .build()
76+ .sign(issuerUrl.issuerId())
8177 }
8278
83- fun jwt (claims : Map <String , Any >, expiry : Duration = Duration .ofHours(1)): SignedJWT =
79+ @JvmOverloads
80+ fun jwt (claims : Map <String , Any >, expiry : Duration = Duration .ofHours(1), issuerId : String = "default"): SignedJWT =
8481 JWTClaimsSet .Builder ().let { builder ->
8582 val now = Instant .now()
8683 builder
@@ -89,18 +86,20 @@ class OAuth2TokenProvider {
8986 .expirationTime(Date .from(now.plusSeconds(expiry.toSeconds())))
9087 builder.addClaims(claims)
9188 builder.build()
92- }.let {
93- createSignedJWT(it)
94- }
89+ }.sign(issuerId)
90+
91+ private fun rsaKey ( issuerId : String ): RSAKey = signingKeys.computeIfAbsent(issuerId) { generateRSAKey(issuerId) }
9592
96- private fun createSignedJWT (claimsSet : JWTClaimsSet ): SignedJWT {
97- val header = JWSHeader .Builder (JWSAlgorithm .RS256 )
98- .keyID(rsaKey.keyID)
99- .type(JOSEObjectType .JWT )
100- val signedJWT = SignedJWT (header.build(), claimsSet)
101- val signer = RSASSASigner (rsaKey.toPrivateKey())
102- signedJWT.sign(signer)
103- return signedJWT
93+ private fun JWTClaimsSet.sign (issuerId : String ): SignedJWT {
94+ val key = rsaKey(issuerId)
95+ return SignedJWT (
96+ JWSHeader .Builder (JWSAlgorithm .RS256 )
97+ .keyID(key.keyID)
98+ .type(JOSEObjectType .JWT ).build(),
99+ this
100+ ).apply {
101+ sign(RSASSASigner (key.toPrivateKey()))
102+ }
104103 }
105104
106105 private fun JWTClaimsSet.Builder.addClaims (claims : Map <String , Any > = emptyMap()) = apply {
@@ -130,21 +129,16 @@ class OAuth2TokenProvider {
130129 }
131130
132131 companion object {
133- private const val DEFAULT_KEYID = " mock-oauth2-server-key"
134- private fun generateJWKSet (keyId : String ) =
135- JWKSet (createRSAKey(keyId, generateKeyPair()))
136-
137- private fun generateKeyPair (): KeyPair =
132+ private fun generateRSAKey (keyId : String ): RSAKey =
138133 KeyPairGenerator .getInstance(" RSA" ).let {
139134 it.initialize(2048 )
140135 it.generateKeyPair()
136+ }.let {
137+ RSAKey .Builder (it.public as RSAPublicKey )
138+ .privateKey(it.private as RSAPrivateKey )
139+ .keyUse(KeyUse .SIGNATURE )
140+ .keyID(keyId)
141+ .build()
141142 }
142-
143- private fun createRSAKey (keyID : String , keyPair : KeyPair ) =
144- RSAKey .Builder (keyPair.public as RSAPublicKey )
145- .privateKey(keyPair.private as RSAPrivateKey )
146- .keyUse(KeyUse .SIGNATURE )
147- .keyID(keyID)
148- .build()
149143 }
150144}
0 commit comments