Skip to content

Commit

Permalink
m2moauthserver: fix security vulnerability
Browse files Browse the repository at this point in the history
  • Loading branch information
jkralik authored and Danielius1922 committed Jul 8, 2024
1 parent 9edccd3 commit 96edc2a
Show file tree
Hide file tree
Showing 9 changed files with 175 additions and 199 deletions.
11 changes: 9 additions & 2 deletions certificate-authority/service/grpc/signCertificate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,17 @@ func TestCertificateAuthorityServerSignCSRWithDifferentPublicKeys(t *testing.T)
HTTP: config.MakeHttpClientConfig(),
})

tearDown := service.SetUp(ctx, t, service.WithCAConfig(cfg))
m2mCfg := m2mOauthTest.MakeConfig(t)
serviceOAuthClient := m2mOauthTest.ServiceOAuthClient
serviceOAuthClient.InsertTokenClaims = map[string]interface{}{
config.OWNER_CLAIM: oauthService.DeviceUserID,
}
m2mCfg.OAuthSigner.Clients[0] = &serviceOAuthClient

tearDown := service.SetUp(ctx, t, service.WithCAConfig(cfg), service.WithM2MOAuthConfig(m2mCfg))
defer tearDown()

ctx = kitNetGrpc.CtxWithToken(ctx, m2mOauthTest.GetDefaultAccessToken(t, m2mOauthTest.WithAccessTokenOwner(oauthService.DeviceUserID)))
ctx = kitNetGrpc.CtxWithToken(ctx, m2mOauthTest.GetDefaultAccessToken(t))

conn, err := grpc.NewClient(config.CERTIFICATE_AUTHORITY_HOST, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
RootCAs: test.GetRootCertificatePool(t),
Expand Down
5 changes: 3 additions & 2 deletions charts/plgd-hub/templates/m2m-oauth-server/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ data:
{{- if $createClient}}
- id: {{ .id }}
secretFile: {{ include "plgd-hub.resolveTemplateString" (list $ .secretFile) }}
requireDeviceID: {{ .requireDeviceID }}
requireOwner: {{ .requireOwner }}
accessTokenLifetime: {{ .accessTokenLifetime }}
allowedGrantTypes:
{{- range .allowedGrantTypes }}
Expand All @@ -63,6 +61,9 @@ data:
{{- range .allowedScopes }}
- {{ . | quote }}
{{- end }}
{{- if .insertTokenClaims }}
insertTokenClaims: {{ .insertTokenClaims | toYaml | nindent 12 }}
{{- end }}
{{- if .jwtPrivateKey }}
{{- if .jwtPrivateKey.enabled }}
jwtPrivateKey:
Expand Down
4 changes: 0 additions & 4 deletions charts/plgd-hub/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2863,8 +2863,6 @@ m2moauthserver:
deviceIDClaim:
clients:
- id: "jwt-private-key"
requireOwner: true
requireDeviceID: false
accessTokenLifetime: 0s
allowedGrantTypes:
- client_credentials
Expand All @@ -2877,8 +2875,6 @@ m2moauthserver:
endpoints:
- id: "service"
secretFile: "{{ include \"plgd-hub.m2moauthserver.getClientServiceSecretFile\" . }}"
requireOwner: false
requireDeviceID: false
accessTokenLifetime: 0s
allowedGrantTypes:
- client_credentials
Expand Down
29 changes: 11 additions & 18 deletions m2m-oauth-server/service/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,20 @@ func (c *PrivateKeyJWTConfig) Validate() error {
}

type Client struct {
ID string `yaml:"id"`
SecretFile urischeme.URIScheme `yaml:"secretFile"`
RequireDeviceID bool `yaml:"requireDeviceID"`
RequireOwner bool `yaml:"requireOwner"`
AccessTokenLifetime time.Duration `yaml:"accessTokenLifetime"`
AllowedGrantTypes []GrantType `yaml:"allowedGrantTypes"`
AllowedAudiences []string `yaml:"allowedAudiences"`
AllowedScopes []string `yaml:"allowedScopes"`
JWTPrivateKey PrivateKeyJWTConfig `yaml:"jwtPrivateKey"`
ID string `yaml:"id"`
SecretFile urischeme.URIScheme `yaml:"secretFile"`
AccessTokenLifetime time.Duration `yaml:"accessTokenLifetime"`
AllowedGrantTypes []GrantType `yaml:"allowedGrantTypes"`
AllowedAudiences []string `yaml:"allowedAudiences"`
AllowedScopes []string `yaml:"allowedScopes"`
JWTPrivateKey PrivateKeyJWTConfig `yaml:"jwtPrivateKey"`
InsertTokenClaims map[string]interface{} `yaml:"insertTokenClaims"`

// runtime
secret string `yaml:"-"`
}

func (c *Client) Validate(ownerClaim, deviceIDClaim string) error {
func (c *Client) Validate() error {
if c.ID == "" {
return fmt.Errorf("id('%v')", c.ID)
}
Expand All @@ -73,7 +72,7 @@ func (c *Client) Validate(ownerClaim, deviceIDClaim string) error {
c.secret = string(data)
}
if len(c.AllowedGrantTypes) == 0 {
return fmt.Errorf("allowedGrantTypes('%v')", c.AllowedGrantTypes)
return fmt.Errorf("allowedGrantTypes('%v') - is empty", c.AllowedGrantTypes)
}
for _, gt := range c.AllowedGrantTypes {
switch gt {
Expand All @@ -82,12 +81,6 @@ func (c *Client) Validate(ownerClaim, deviceIDClaim string) error {
return fmt.Errorf("allowedGrantTypes('%v') - only %v is supported", c.AllowedGrantTypes, GrantTypeClientCredentials)
}
}
if c.RequireDeviceID && deviceIDClaim == "" {
return fmt.Errorf("requireDeviceID('%v') - oauthSigner.deviceIDClaim('%v') is empty", c.RequireDeviceID, deviceIDClaim)
}
if c.RequireOwner && ownerClaim == "" {
return fmt.Errorf("requireOwner('%v') - oauthSigner.ownerClaim('%v') is empty", c.RequireOwner, ownerClaim)
}
if err := c.JWTPrivateKey.Validate(); err != nil {
return fmt.Errorf("privateKeyJWT.%w", err)
}
Expand Down Expand Up @@ -180,7 +173,7 @@ func (c *OAuthSignerConfig) Validate() error {
return fmt.Errorf("clients('%v')", c.Clients)
}
for idx, client := range c.Clients {
if err := client.Validate(c.OwnerClaim, c.DeviceIDClaim); err != nil {
if err := client.Validate(); err != nil {
return fmt.Errorf("clients[%v].%w", idx, err)
}
}
Expand Down
116 changes: 63 additions & 53 deletions m2m-oauth-server/service/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"
"time"

goJwt "github.com/golang-jwt/jwt/v5"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
Expand All @@ -21,6 +22,10 @@ func setKeyError(key string, err error) error {
return fmt.Errorf("failed to set %v: %w", key, err)
}

func setKeyErrorExt(key, info interface{}, err error) error {
return fmt.Errorf("failed to set %v('%v'): %w", key, info, err)
}

func makeAccessToken(clientCfg *Client, tokenReq tokenRequest, issuedAt, expires time.Time) (jwt.Token, error) {
token := jwt.New()

Expand Down Expand Up @@ -54,30 +59,49 @@ func makeAccessToken(clientCfg *Client, tokenReq tokenRequest, issuedAt, expires
if err := setOwnerClaim(token, tokenReq); err != nil {
return nil, err
}
if err := setOriginTokenClaims(token, tokenReq); err != nil {
return nil, err
}

for k, v := range clientCfg.InsertTokenClaims {
if _, ok := token.Get(k); ok {
continue
}
if err := token.Set(k, v); err != nil {
return nil, setKeyErrorExt(k, v, err)
}
}

return token, nil
}

func getSubject(clientCfg *Client, tokenReq tokenRequest) string {
if tokenReq.Subject != "" {
return tokenReq.Subject
if tokenReq.subject != "" {
return tokenReq.subject
}
if tokenReq.Owner != "" {
return tokenReq.Owner
if tokenReq.owner != "" {
return tokenReq.owner
}
return clientCfg.ID
}

func setDeviceIDClaim(token jwt.Token, tokenReq tokenRequest) error {
if tokenReq.DeviceID != "" && tokenReq.deviceIDClaim != "" {
return token.Set(tokenReq.deviceIDClaim, tokenReq.DeviceID)
if tokenReq.deviceID != "" && tokenReq.deviceIDClaim != "" {
return token.Set(tokenReq.deviceIDClaim, tokenReq.deviceID)
}
return nil
}

func setOwnerClaim(token jwt.Token, tokenReq tokenRequest) error {
if tokenReq.Owner != "" && tokenReq.ownerClaim != "" {
return token.Set(tokenReq.ownerClaim, tokenReq.Owner)
if tokenReq.owner != "" && tokenReq.ownerClaim != "" {
return token.Set(tokenReq.ownerClaim, tokenReq.owner)
}
return nil
}

func setOriginTokenClaims(token jwt.Token, tokenReq tokenRequest) error {
if len(tokenReq.originalTokenClaims) > 0 {
return token.Set(uri.OriginalTokenClaims, tokenReq.originalTokenClaims)
}
return nil
}
Expand Down Expand Up @@ -123,21 +147,21 @@ func generateAccessToken(clientCfg *Client, tokenReq tokenRequest, key interface

type tokenRequest struct {
ClientID string `json:"client_id"`
GrantType GrantType `json:"grant_type"`
Username string `json:"username"`
Password string `json:"password"`
Secret string `json:"client_secret"`
Audience string `json:"audience"`
DeviceID string `json:"https://plgd.dev/deviceId"`
Owner string `json:"https://plgd.dev/owner"`
Subject string `json:"sub"`
GrantType GrantType `json:"grant_type"`
ClientAssertionType string `json:"client_assertion_type"`
ClientAssertion string `json:"client_assertion"`

host string
scopes string
ownerClaim string
deviceIDClaim string
tokenType AccessTokenType
deviceID string `json:"-"`
owner string `json:"-"`
subject string `json:"-"`
host string `json:"-"`
scopes string `json:"-"`
ownerClaim string `json:"-"`
deviceIDClaim string `json:"-"`
tokenType AccessTokenType `json:"-"`
originalTokenClaims goJwt.MapClaims `json:"-"`
}

func (requestHandler *RequestHandler) getDomain() string {
Expand All @@ -158,12 +182,8 @@ func (requestHandler *RequestHandler) postToken(w http.ResponseWriter, r *http.R
}
tokenReq.GrantType = GrantType(r.PostFormValue(uri.GrantTypeKey))
tokenReq.ClientID = r.PostFormValue(uri.ClientIDKey)
tokenReq.Username = r.PostFormValue(uri.UsernameKey)
tokenReq.Password = r.PostFormValue(uri.PasswordKey)
tokenReq.Audience = r.PostFormValue(uri.AudienceKey)
tokenReq.Owner = r.PostFormValue(uri.OwnerKey)
tokenReq.DeviceID = r.PostFormValue(uri.DeviceIDKey)
tokenReq.Subject = r.PostFormValue(uri.SubjectKey)
tokenReq.Secret = r.PostFormValue(uri.ClientSecretKey)
tokenReq.ClientAssertionType = r.PostFormValue(uri.ClientAssertionTypeKey)
tokenReq.ClientAssertion = r.PostFormValue(uri.ClientAssertionKey)
} else {
Expand All @@ -173,17 +193,18 @@ func (requestHandler *RequestHandler) postToken(w http.ResponseWriter, r *http.R
return
}
}
clientID, password, ok := r.BasicAuth()
clientID, secret, ok := r.BasicAuth()
if ok {
tokenReq.ClientID = clientID
tokenReq.Password = password
tokenReq.Secret = secret
}
requestHandler.processResponse(r.Context(), w, tokenReq)
}

func sliceContains[T comparable](s []T, sub []T) bool {
if len(s) == 0 {
return true
// sub must be non-empty
if len(s) > 0 && len(sub) == 0 {
return false
}
check := make(map[T]struct{}, len(sub))
for _, e := range sub {
Expand Down Expand Up @@ -211,12 +232,6 @@ func (requestHandler *RequestHandler) validateTokenRequest(ctx context.Context,
if err := validateAudience(clientCfg, tokenReq); err != nil {
return err
}
if err := validateDeviceID(clientCfg, tokenReq); err != nil {
return err
}
if err := validateOwner(clientCfg, tokenReq); err != nil {
return err
}

return nil
}
Expand All @@ -225,36 +240,27 @@ func validateClient(clientCfg *Client, tokenReq *tokenRequest) error {
if clientCfg == nil {
return fmt.Errorf("client(%v) not found", tokenReq.ClientID)
}
if clientCfg.secret != "" && !clientCfg.JWTPrivateKey.Enabled && clientCfg.secret != tokenReq.Password {
if clientCfg.secret != "" && !clientCfg.JWTPrivateKey.Enabled && clientCfg.secret != tokenReq.Secret {
return errors.New("invalid client secret")
}
return nil
}

func validateGrantType(clientCfg *Client, tokenReq *tokenRequest) error {
// clientCfg.AllowedGrantTypes is always non-empty
if !sliceContains(clientCfg.AllowedGrantTypes, []GrantType{tokenReq.GrantType}) {
return fmt.Errorf("invalid grant type(%v)", tokenReq.GrantType)
}
return nil
}

func validateAudience(clientCfg *Client, tokenReq *tokenRequest) error {
if !sliceContains(clientCfg.AllowedAudiences, []string{tokenReq.Audience}) {
return fmt.Errorf("invalid audience(%v)", tokenReq.Audience)
}
return nil
}

func validateDeviceID(clientCfg *Client, tokenReq *tokenRequest) error {
if clientCfg.RequireDeviceID && tokenReq.DeviceID == "" {
return errors.New("deviceID is required")
var audiences []string
if tokenReq.Audience != "" {
audiences = []string{tokenReq.Audience}
}
return nil
}

func validateOwner(clientCfg *Client, tokenReq *tokenRequest) error {
if clientCfg.RequireOwner && tokenReq.Owner == "" {
return errors.New("owner is required")
if !sliceContains(clientCfg.AllowedAudiences, audiences) {
return fmt.Errorf("invalid audience(%v)", tokenReq.Audience)
}
return nil
}
Expand All @@ -278,20 +284,24 @@ func (requestHandler *RequestHandler) validateClientAssertion(ctx context.Contex
if err != nil {
return fmt.Errorf("invalid client assertion: %w", err)
}
tokenReq.originalTokenClaims = token
claims := pkgJwt.Claims(token)
owner, err := claims.GetOwner(requestHandler.config.OAuthSigner.OwnerClaim)
if err != nil {
return fmt.Errorf("invalid client assertion - claim owner: %w", err)
}
tokenReq.Owner = owner
tokenReq.owner = owner
sub, err := claims.GetSubject()
if err != nil {
return fmt.Errorf("invalid client assertion - claim sub: %w", err)
}
tokenReq.Subject = sub
tokenReq.subject = sub
if requestHandler.config.OAuthSigner.DeviceIDClaim == "" {
return nil
}
deviceID, err := claims.GetDeviceID(requestHandler.config.OAuthSigner.DeviceIDClaim)
if err == nil {
tokenReq.DeviceID = deviceID
tokenReq.deviceID = deviceID
}
return nil
}
Expand Down
Loading

0 comments on commit 96edc2a

Please sign in to comment.