diff --git a/client.go b/client.go index aee7d5b..462a39e 100644 --- a/client.go +++ b/client.go @@ -25,14 +25,9 @@ import ( const ( defaultRequestTimeout = time.Second * 10 defaultTokenTTL = time.Hour -) -// Config for creating new Client. -type Config struct { - TeamID string // Your Apple Team ID obtained from Apple Developer Account. - ClientID string // Your Service which enable sign-in-with-apple service. - KeyID string // Your Secret Key ID obtained from Apple Developer Account. -} + defaultBaseURL = "https://appleid.apple.com/auth" +) // Client for interaction with apple-id service. type Client struct { @@ -42,6 +37,7 @@ type Client struct { AESCert interface{} // Your Secret Key Created By X509 package. RedirectURI string // Your RedirectURI config in apple website. TokenTTL int64 + BaseURL string hc *http.Client publicKeys map[string]*rsa.PublicKey @@ -79,6 +75,12 @@ func NewClient(opts ...ClientOption) (*Client, error) { client.RedirectURI = *settings.RedirectURI } + if settings.BaseURL != nil { + client.BaseURL = *settings.BaseURL + } else { + client.BaseURL = defaultBaseURL + } + jwkSet, err := client.FetchPublicKeys() if err != nil { return nil, err @@ -100,7 +102,7 @@ func NewClient(opts ...ClientOption) (*Client, error) { // FetchPublicKeys to verify the ID token signature. // https://developer.apple.com/documentation/sign_in_with_apple/fetch_apple_s_public_key_for_verifying_token_signature func (c *Client) FetchPublicKeys() (*JWKSet, error) { - resp, err := c.hc.Get("https://appleid.apple.com/auth/keys") + resp, err := c.hc.Get(c.BaseURL + "/keys") if err != nil { return nil, err } @@ -121,9 +123,12 @@ func (c *Client) FetchPublicKeys() (*JWKSet, error) { // LoadP8CertByByte use x509.ParsePKCS8PrivateKey to Parse cert file. func (c *Client) LoadP8CertByByte(data []byte) error { block, _ := pem.Decode(data) + if block == nil { + return ErrBadCert + } cert, err := x509.ParsePKCS8PrivateKey(block.Bytes) if err != nil { - return err + return fmt.Errorf("%w: %v", ErrBadCert, err) } c.AESCert = cert @@ -152,7 +157,7 @@ func (c *Client) CreateCallbackURL(state string) string { u.Add("state", state) u.Add("scope", "name email") - return "https://appleid.apple.com/auth/authorize?" + u.Encode() + return c.BaseURL + "/authorize?" + u.Encode() } // Authenticate with auth token. @@ -160,10 +165,6 @@ func (c *Client) CreateCallbackURL(state string) string { // Response: https://developer.apple.com/documentation/sign_in_with_apple/tokenresponse // Error: https://developer.apple.com/documentation/sign_in_with_apple/errorresponse func (c *Client) Authenticate(ctx context.Context, authCode string) (*TokenResponse, error) { - if c.AESCert == nil { - return nil, ErrMissingCert - } - signature, err := c.getSignature() if err != nil { return nil, err @@ -200,10 +201,6 @@ func (c *Client) Authenticate(ctx context.Context, authCode string) (*TokenRespo // Response: https://developer.apple.com/documentation/sign_in_with_apple/tokenresponse // Error: https://developer.apple.com/documentation/sign_in_with_apple/errorresponse func (c *Client) Refresh(ctx context.Context, refreshToken string) (*TokenResponse, error) { - if c.AESCert == nil { - return nil, ErrMissingCert - } - signature, err := c.getSignature() if err != nil { return nil, err @@ -231,7 +228,6 @@ func (c *Client) ParseUserIdentity(t string) (*UserIdentity, error) { userIdentity := UserIdentity{} if err := json.Unmarshal(body, &userIdentity); err != nil { - return nil, err } @@ -258,7 +254,7 @@ func (c *Client) ValidateToken(t string) error { func (c *Client) doRequest(ctx context.Context, v url.Values) (*TokenResponse, error) { body := strings.NewReader(v.Encode()) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://appleid.apple.com/auth/token", body) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.BaseURL+"/token", body) if err != nil { return nil, err } @@ -287,6 +283,10 @@ func (c *Client) doRequest(ctx context.Context, v url.Values) (*TokenResponse, e } func (c *Client) getSignature() (string, error) { + if c.AESCert == nil { + return "", ErrMissingCert + } + token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.StandardClaims{ Issuer: c.TeamID, IssuedAt: time.Now().Unix(), diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..60c6346 --- /dev/null +++ b/client_test.go @@ -0,0 +1,178 @@ +package apple + +import ( + "crypto" + "crypto/ecdsa" + "encoding/json" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "github.com/dgrijalva/jwt-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + KEYS = `{ + "keys": [ + { + "kty": "RSA", + "kid": "86D88Kf", + "use": "sig", + "alg": "RS256", + "n": "iGaLqP6y-SJCCBq5Hv6pGDbG_SQ11MNjH7rWHcCFYz4hGwHC4lcSurTlV8u3avoVNM8jXevG1Iu1SY11qInqUvjJur--hghr1b56OPJu6H1iKulSxGjEIyDP6c5BdE1uwprYyr4IO9th8fOwCPygjLFrh44XEGbDIFeImwvBAGOhmMB2AD1n1KviyNsH0bEB7phQtiLk-ILjv1bORSRl8AK677-1T8isGfHKXGZ_ZGtStDe7Lu0Ihp8zoUt59kx2o9uWpROkzF56ypresiIl4WprClRCjz8x6cPZXU2qNWhu71TQvUFwvIvbkE1oYaJMb0jcOTmBRZA2QuYw-zHLwQ", + "e": "AQAB" + }, + { + "kty": "RSA", + "kid": "eXaunmL", + "use": "sig", + "alg": "RS256", + "n": "4dGQ7bQK8LgILOdLsYzfZjkEAoQeVC_aqyc8GC6RX7dq_KvRAQAWPvkam8VQv4GK5T4ogklEKEvj5ISBamdDNq1n52TpxQwI2EqxSk7I9fKPKhRt4F8-2yETlYvye-2s6NeWJim0KBtOVrk0gWvEDgd6WOqJl_yt5WBISvILNyVg1qAAM8JeX6dRPosahRVDjA52G2X-Tip84wqwyRpUlq2ybzcLh3zyhCitBOebiRWDQfG26EH9lTlJhll-p_Dg8vAXxJLIJ4SNLcqgFeZe4OfHLgdzMvxXZJnPp_VgmkcpUdRotazKZumj6dBPcXI_XID4Z4Z3OM1KrZPJNdUhxw", + "e": "AQAB" + } + ] + }` + + CERT = `-----BEGIN PRIVATE KEY----- +MIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQg+94fs23vSrhBIXNz +OdeRb7+FJkIsVrnTSf7eIYKdf4mgCgYIKoZIzj0DAQehRANCAATyBS3eRgOJ53OQ +LFhGSJw4aiqju7muVwoIWFxCcFJasRwyGcbs0C7vt3xKV/DRJvID4UljaI53wETq +RxlkNCeV +-----END PRIVATE KEY-----` + + BADCERT = `-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQCjcGqTkOq0CR3rTx0ZSQSIdTrDrFAYl29611xN8aVgMQIWtDB/ +lD0W5TpKPuU9iaiG/sSn/VYt6EzN7Sr332jj7cyl2WrrHI6ujRswNy4HojMuqtfa +b5FFDpRmCuvl35fge18OvoQTJELhhJ1EvJ5KUeZiuJ3u3YyMnxxXzLuKbQIDAQAB +AoGAPrNDz7TKtaLBvaIuMaMXgBopHyQd3jFKbT/tg2Fu5kYm3PrnmCoQfZYXFKCo +ZUFIS/G1FBVWWGpD/MQ9tbYZkKpwuH+t2rGndMnLXiTC296/s9uix7gsjnT4Naci +5N6EN9pVUBwQmGrYUTHFc58ThtelSiPARX7LSU2ibtJSv8ECQQDWBRrrAYmbCUN7 +ra0DFT6SppaDtvvuKtb+mUeKbg0B8U4y4wCIK5GH8EyQSwUWcXnNBO05rlUPbifs +DLv/u82lAkEAw39sTJ0KmJJyaChqvqAJ8guulKlgucQJ0Et9ppZyet9iVwNKX/aW +9UlwGBMQdafQ36nd1QMEA8AbAw4D+hw/KQJBANJbHDUGQtk2hrSmZNoV5HXB9Uiq +7v4N71k5ER8XwgM5yVGs2tX8dMM3RhnBEtQXXs9LW1uJZSOQcv7JGXNnhN0CQBZe +nzrJAWxh3XtznHtBfsHWelyCYRIAj4rpCHCmaGUM6IjCVKFUawOYKp5mmAyObkUZ +f8ue87emJLEdynC1CLkCQHduNjP1hemAGWrd6v8BHhE3kKtcK6KHsPvJR5dOfzbd +HAqVePERhISfN6cwZt5p8B3/JUwSR8el66DF7Jm57BM= +-----END RSA PRIVATE KEY-----` + + goodSignatureString = `eyJhbGciOiJFUzI1NiIsImtpZCI6ImlkaWRpZGlkIn0.eyJhdWQiOiJodHRwczovL2FwcGxlaWQuYXBwbGUuY29tIiwiZXhwIjoxNTk2ODAwODcwLCJpYXQiOjE1OTY3OTcyNzAsImlzcyI6IjEyMzQ1Njc4OTAiLCJzdWIiOiJjb20uZXhhbXBsZS5hcHAifQ.fuCSrxP5NzkgLM-zjEnUkKn4b_YR0Tbc7_j6MmCor5O9UsM6vpSa51h0SdbXH-l5RYJmGoiVVY6hyug3t5ZPwA` +) + +func TestLoadCertGetSignature(t *testing.T) { + tests := []struct { + name string + cert string + signature string + wantLoadErr bool + forceLoad bool // to check fail at next stage + wantErr bool + }{ + { + name: "bad key", + cert: "bad_key", + wantLoadErr: true, + wantErr: true, + }, + { + name: "bad key wrong format fail load", + cert: BADCERT, + wantLoadErr: true, + forceLoad: false, + }, + { + name: "bad key wrong format force load", + cert: BADCERT, + wantLoadErr: true, + forceLoad: true, + wantErr: true, + }, + { + name: "good key", + cert: CERT, + wantLoadErr: false, + wantErr: false, + }, + } + + srv := setupMockServer(t, "", nil) + c, err := NewClient( + WithBaseURL(srv.URL), + WithCredentials("1234567890", "com.example.app", "idididid"), + ) + require.NoError(t, err) + + var publicKey crypto.PublicKey + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err = c.LoadP8CertByByte([]byte(tt.cert)) + if tt.wantLoadErr { + assert.Error(t, err) + if !tt.forceLoad { + return + } + } else { + require.NoError(t, err) + + publicKey = (c.AESCert.(*ecdsa.PrivateKey)).Public() + } + + got, err := c.getSignature() + + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err, "expected no error but got %s", err) + require.NotEmpty(t, got, "wanted a secret string returned but got none") + + token, err := jwt.ParseWithClaims( + got, + &jwt.StandardClaims{}, + func(token *jwt.Token) (interface{}, error) { + return publicKey, nil + }) + + require.NoError(t, err, "error while decoding JWT") + require.True(t, token.Valid) + + claims, ok := token.Claims.(*jwt.StandardClaims) + assert.True(t, ok) + + assert.Equal(t, "1234567890", claims.Issuer) + assert.Equal(t, "com.example.app", claims.Subject) + assert.Equal(t, "https://appleid.apple.com", claims.Audience) + assert.Equal(t, c.TokenTTL, claims.ExpiresAt-claims.IssuedAt) + }) + } +} + +// many cases require initial Public Keys response during creating Client or later +func setupMockServer(t *testing.T, expectedRequest string, responseToken *TokenResponse) *httptest.Server { + handler := http.NewServeMux() + + handler.HandleFunc("/keys", func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(KEYS)) + }) + + handler.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + s, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + if expectedRequest != "" { + assert.Equal(t, expectedRequest, string(s)) + } + + b, _ := json.Marshal(responseToken) + w.Write(b) + }) + + srv := httptest.NewServer(handler) + + return srv +} diff --git a/error.go b/error.go index d4add8b..57b03f8 100644 --- a/error.go +++ b/error.go @@ -9,6 +9,9 @@ var ( // ErrMissingCert returned, if certificate is missing. ErrMissingCert = errors.New("cert for client not set") + // ErrBadCert returned, if certificate is of bad format. + ErrBadCert = errors.New("cert for client is not in right format") + // ErrFetchPublicKey returned, if client failed fetching public key. ErrFetchPublicKey = errors.New("can't fetch apple public key") diff --git a/go.mod b/go.mod index 89394c4..8768e54 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/jmind-systems/go-apple-signin go 1.14 -require github.com/dgrijalva/jwt-go v3.2.0+incompatible +require ( + github.com/dgrijalva/jwt-go v3.2.0+incompatible + github.com/stretchr/testify v1.6.1 +) diff --git a/go.sum b/go.sum index f964c6b..668d3cc 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,12 @@ -github.com/dgrijalva/jwt-go v1.0.2 h1:KPldsxuKGsS2FPWsNeg9ZO18aCrGKujPoWXn2yo+KQM= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/option.go b/option.go index 3b88a8d..2b84400 100644 --- a/option.go +++ b/option.go @@ -7,10 +7,11 @@ type ClientSettings struct { HTTPClient *http.Client TokenTTL *int64 RedirectURI *string + BaseURL *string - TeamID string - ClientID string - KeyID string + TeamID string // Your Apple Team ID obtained from Apple Developer Account. + ClientID string // Your Service which enable sign-in-with-apple service. + KeyID string // Your Secret Key ID obtained from Apple Developer Account. } // ClientOption is an interface for applying client options. @@ -53,3 +54,9 @@ func WithCredentials(teamID, clientID, keyID string) ClientOption { settings.KeyID = keyID }) } + +func WithBaseURL(url string) ClientOption { + return ClientOptionFunc(func(settings *ClientSettings) { + settings.BaseURL = &url + }) +}