diff --git a/README.md b/README.md index 91ca05d..9c23c92 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ func main() { Length: 128, } - fmt.Println(p.VerifyCode()) + fmt.Println(p.VerifyCode()) // This is optional // Output: 45d7820e694481f399e7fb9c444f0cb63301a7254d1401443835d9af2c9a6a5ec5b243c3470feb945336025964ef05c8d2f0e44baf76762ba6136914 fmt.Println(p.ChallengeCode()) @@ -64,4 +64,4 @@ func main() { } ``` -> Note: You have to call `p.VerifyCode()` before `p.ChallengeCode()` to generate a random string and its hash +> Note: Calling `p.VerifyCode()` optional, but calling it after `p.ChallengeCode()` will reset `pkce.RandomString` diff --git a/pkce.go b/pkce.go index 02f6495..09c7c10 100644 --- a/pkce.go +++ b/pkce.go @@ -47,8 +47,15 @@ func (p *Pkce) VerifyCode() (string, error) { // ChallengeCode returns a challenge code as mentioned in https://tools.ietf.org/html/rfc7636#section-4.2. // The code is based on // code_challenge = BASE64URL-ENCODE(SHA256(ASCII(code_verifier))) -func (p *Pkce) ChallengeCode() string { +func (p *Pkce) ChallengeCode() (string, error) { + if p.RandomString == "" { + code, err := p.VerifyCode() + if err != nil { + return "", err + } + p.RandomString = code + } hash := sha256.Sum256([]byte(p.RandomString)) toBase64 := base64.RawURLEncoding.EncodeToString(hash[:]) - return toBase64 + return toBase64, nil } diff --git a/pkce_test.go b/pkce_test.go index 4bae38f..2177134 100644 --- a/pkce_test.go +++ b/pkce_test.go @@ -13,7 +13,30 @@ func TestPkce_ChallengeCode(t *testing.T) { RandomString: "45d7820e694481f399e7fb9c444f0cb63301a7254d1401443835d9af2c9a6a5ec5b243c3470feb945336025964ef05c8d2f0e44baf76762ba6136914", } - assert.Equal(t, "iQoF8w9kq5RnuMdisRXypyOoMCF7FGz-ro7dwHjC28U", p.ChallengeCode()) + code, err := p.ChallengeCode() + if assert.Nil(t, err) { + assert.Equal(t, "iQoF8w9kq5RnuMdisRXypyOoMCF7FGz-ro7dwHjC28U", code) + } +} + +func TestPkce_ChallengeCode2(t *testing.T) { + p := Pkce{ + Length: 50, + } + + code, err := p.ChallengeCode() + if assert.Nil(t, err) { + assert.IsType(t, "", code) + } +} + +func TestPkce_ChallengeCode3(t *testing.T) { + p := Pkce{} + + _, err := p.ChallengeCode() + if assert.Error(t, err) { + assert.Equal(t, errors.New("length should be greater than 0"), err) + } } func TestPkce_VerifyCode(t *testing.T) { @@ -103,7 +126,11 @@ func ExamplePkce_ChallengeCode() { p := Pkce{ RandomString: "45d7820e694481f399e7fb9c444f0cb63301a7254d1401443835d9af2c9a6a5ec5b243c3470feb945336025964ef05c8d2f0e44baf76762ba6136914", } - fmt.Println(p.ChallengeCode()) + code, err := p.ChallengeCode() + if err != nil { + panic(err) + } + fmt.Println(code) // Output: iQoF8w9kq5RnuMdisRXypyOoMCF7FGz-ro7dwHjC28U } @@ -115,5 +142,9 @@ func ExamplePkce_ChallengeCode_generatedString() { if err != nil { panic(err) } - fmt.Println(code, p.ChallengeCode()) + challengeCode, err := p.ChallengeCode() + if err != nil { + panic(err) + } + fmt.Println(code, challengeCode) }