diff --git a/encoder.go b/encoder.go new file mode 100644 index 00000000..6f49ec28 --- /dev/null +++ b/encoder.go @@ -0,0 +1,35 @@ +package jwt + +import "io" + +// Base64Encoding represents an object that can encode and decode base64. A +// common example is [encoding/base64.Encoding]. +type Base64Encoding interface { + EncodeToString(src []byte) string + DecodeString(s string) ([]byte, error) +} + +type StrictFunc[T Base64Encoding] func() T + +type Stricter[T Base64Encoding] interface { + Strict() T +} + +func DoStrict[S Base64Encoding, T Stricter[S]](x T) Base64Encoding { + return x.Strict() +} + +// JSONMarshalFunc is a function type that allows to implement custom JSON +// encoding algorithms. +type JSONMarshalFunc func(v any) ([]byte, error) + +// JSONUnmarshalFunc is a function type that allows to implement custom JSON +// unmarshal algorithms. +type JSONUnmarshalFunc func(data []byte, v any) error + +type JSONDecoder interface { + UseNumber() + Decode(v any) error +} + +type JSONNewDecoderFunc[T JSONDecoder] func(r io.Reader) T diff --git a/errors.go b/errors.go index 23bb616d..a8fe9be9 100644 --- a/errors.go +++ b/errors.go @@ -22,6 +22,7 @@ var ( ErrTokenInvalidId = errors.New("token has invalid id") ErrTokenInvalidClaims = errors.New("token has invalid claims") ErrInvalidType = errors.New("invalid type for claim") + ErrUnsupported = errors.New("operation is unsupported") ) // joinedError is an error type that works similar to what [errors.Join] diff --git a/example_test.go b/example_test.go index 651841de..abd9e38d 100644 --- a/example_test.go +++ b/example_test.go @@ -1,6 +1,8 @@ package jwt_test import ( + "encoding/base64" + "encoding/json" "errors" "fmt" "log" @@ -9,6 +11,21 @@ import ( "github.com/golang-jwt/jwt/v5" ) +// Example creating a token by passing jwt.WithJSONEncoder or jwt.WithBase64Encoder to +// options to specify the custom encoders when sign the token to string. +// You can try other encoders when you get tired of the standard library. +func ExampleNew_customEncoder() { + mySigningKey := []byte("AllYourBase") + + customJSONEncoderFunc := json.Marshal + customBase64Encoder := base64.RawURLEncoding + token := jwt.New(jwt.SigningMethodHS256, jwt.WithJSONEncoder(customJSONEncoderFunc), jwt.WithBase64Encoder(customBase64Encoder)) + + ss, err := token.SignedString(mySigningKey) + fmt.Println(ss, err) + // Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.E9f4bo8SFbMyEfLEOEXEO2RGcO9cQhznYfSKqTjWwrM +} + // Example (atypical) using the RegisteredClaims type by itself to parse a token. // The RegisteredClaims type is designed to be embedded into your custom types // to provide standard validation features. You can use it alone, but there's @@ -161,6 +178,35 @@ func ExampleParseWithClaims_customValidation() { // Output: bar test } +// Example parsing a string to a token with using a custom decoders. +// It's convenient to use the jwt.WithJSONDecoder or jwt.WithBase64Decoder options when create a parser +// to parse string to token by using your favorite JSON or Base64 decoders. +func ExampleParseWithClaims_customDecoder() { + tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA" + + customJSONUnmarshalFunc := json.Unmarshal + customNewJSONDecoderFunc := json.NewDecoder + + customBase64RawUrlEncoder := base64.RawURLEncoding + customBase64UrlEncoder := base64.URLEncoding + + jwtParser := jwt.NewParser(jwt.WithJSONDecoder(customJSONUnmarshalFunc, customNewJSONDecoderFunc), jwt.WithBase64Decoder(customBase64RawUrlEncoder, customBase64UrlEncoder)) + + token, err := jwtParser.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + return []byte("AllYourBase"), nil + }) + if err != nil { + log.Fatal(err) + } + if !token.Valid { + log.Fatal("invalid") + } else { + fmt.Println("valid") + } + + // Output: valid +} + // An example of parsing the error types using errors.Is. func ExampleParse_errorChecking() { // Token from another example. This token is expired diff --git a/parser.go b/parser.go index ecf99af7..5b774e1f 100644 --- a/parser.go +++ b/parser.go @@ -12,16 +12,26 @@ type Parser struct { // If populated, only these methods will be considered valid. validMethods []string - // Use JSON Number format in JSON decoder. - useJSONNumber bool - // Skip claims validation during token parsing. skipClaimsValidation bool validator *Validator - decodeStrict bool + decoding +} + +type decoding struct { + jsonUnmarshal JSONUnmarshalFunc + jsonNewDecoder JSONNewDecoderFunc[JSONDecoder] + + rawUrlBase64Encoding Base64Encoding + urlBase64Encoding Base64Encoding + strict StrictFunc[Base64Encoding] + // Use JSON Number format in JSON decoder. + useJSONNumber bool + + decodeStrict bool decodePaddingAllowed bool } @@ -148,7 +158,18 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke if headerBytes, err = p.DecodeSegment(parts[0]); err != nil { return token, parts, newError("could not base64 decode header", ErrTokenMalformed, err) } - if err = json.Unmarshal(headerBytes, &token.Header); err != nil { + + // Choose our JSON decoder. If no custom function is supplied, we use the standard library. + var unmarshal JSONUnmarshalFunc + if p.jsonUnmarshal != nil { + unmarshal = p.jsonUnmarshal + } else { + unmarshal = json.Unmarshal + } + + // JSON Unmarshal the header + err = unmarshal(headerBytes, &token.Header) + if err != nil { return token, parts, newError("could not JSON decode header", ErrTokenMalformed, err) } @@ -160,25 +181,31 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke return token, parts, newError("could not base64 decode claim", ErrTokenMalformed, err) } - // If `useJSONNumber` is enabled then we must use *json.Decoder to decode - // the claims. However, this comes with a performance penalty so only use - // it if we must and, otherwise, simple use json.Unmarshal. - if !p.useJSONNumber { - // JSON Unmarshal. Special case for map type to avoid weird pointer behavior. - if c, ok := token.Claims.(MapClaims); ok { - err = json.Unmarshal(claimBytes, &c) - } else { - err = json.Unmarshal(claimBytes, &claims) + // If `useJSONNumber` is enabled, then we must use a dedicated JSONDecoder + // to decode the claims. However, this comes with a performance penalty so + // only use it if we must and, otherwise, simple use our existing unmarshal + // function. + if p.useJSONNumber { + unmarshal = func(data []byte, v any) error { + buffer := bytes.NewBuffer(claimBytes) + + var decoder JSONDecoder + if p.jsonNewDecoder != nil { + decoder = p.jsonNewDecoder(buffer) + } else { + decoder = json.NewDecoder(buffer) + } + decoder.UseNumber() + return decoder.Decode(v) } + } + + // JSON Unmarshal the claims. Special case for map type to avoid weird + // pointer behavior. + if c, ok := token.Claims.(MapClaims); ok { + err = unmarshal(claimBytes, &c) } else { - dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) - dec.UseNumber() - // JSON Decode. Special case for map type to avoid weird pointer behavior. - if c, ok := token.Claims.(MapClaims); ok { - err = dec.Decode(&c) - } else { - err = dec.Decode(&claims) - } + err = unmarshal(claimBytes, &claims) } if err != nil { return token, parts, newError("could not JSON decode claim", ErrTokenMalformed, err) @@ -200,18 +227,37 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke // take into account whether the [Parser] is configured with additional options, // such as [WithStrictDecoding] or [WithPaddingAllowed]. func (p *Parser) DecodeSegment(seg string) ([]byte, error) { - encoding := base64.RawURLEncoding + var encoding Base64Encoding + if p.rawUrlBase64Encoding != nil { + encoding = p.rawUrlBase64Encoding + } else { + encoding = base64.RawURLEncoding + } if p.decodePaddingAllowed { if l := len(seg) % 4; l > 0 { seg += strings.Repeat("=", 4-l) } - encoding = base64.URLEncoding + + if p.urlBase64Encoding != nil { + encoding = p.urlBase64Encoding + } else { + encoding = base64.URLEncoding + } } if p.decodeStrict { - encoding = encoding.Strict() + if p.strict != nil { + encoding = p.strict() + } else { + stricter, ok := encoding.(Stricter[*base64.Encoding]) + if !ok { + return nil, newError("WithStrictDecoding() was enabled but supplied base64 encoder does not support strict mode", ErrUnsupported) + } + encoding = stricter.Strict() + } } + return encoding.DecodeString(seg) } diff --git a/parser_option.go b/parser_option.go index 88a780fb..5cf5cd99 100644 --- a/parser_option.go +++ b/parser_option.go @@ -1,6 +1,9 @@ package jwt -import "time" +import ( + "io" + "time" +) // ParserOption is used to implement functional-style options that modify the // behavior of the parser. To add new options, just create a function (ideally @@ -121,8 +124,80 @@ func WithPaddingAllowed() ParserOption { // WithStrictDecoding will switch the codec used for decoding JWTs into strict // mode. In this mode, the decoder requires that trailing padding bits are zero, // as described in RFC 4648 section 3.5. +// +// Note: This is only supported when using [encoding/base64.Encoding], but not +// by any other decoder specified with [WithBase64Decoder]. func WithStrictDecoding() ParserOption { return func(p *Parser) { p.decodeStrict = true } } + +// WithJSONDecoder supports a custom JSON decoder to use in parsing the JWT. +// There are two functions that can be supplied: +// - jsonUnmarshal is a [JSONUnmarshalFunc] that is used for the +// un-marshalling the header and claims when no other options are specified +// - jsonNewDecoder is a [JSONNewDecoderFunc] that is used to create an object +// satisfying the [JSONDecoder] interface. +// +// The latter is used when the [WithJSONNumber] option is used. +// +// If any of the supplied functions is set to nil, the defaults from the Go +// standard library, [encoding/json.Unmarshal] and [encoding/json.NewDecoder] +// are used. +// +// Example using the https://github.com/bytedance/sonic library. +// +// import ( +// "github.com/bytedance/sonic" +// ) +// +// var parser = jwt.NewParser(jwt.WithJSONDecoder(sonic.Unmarshal, sonic.ConfigDefault.NewDecoder)) +func WithJSONDecoder[T JSONDecoder](jsonUnmarshal JSONUnmarshalFunc, jsonNewDecoder JSONNewDecoderFunc[T]) ParserOption { + return func(p *Parser) { + p.jsonUnmarshal = jsonUnmarshal + // This seems to be necessary, since we don't want to store the specific + // JSONDecoder type in our parser, but need it in the function + // interface. + p.jsonNewDecoder = func(r io.Reader) JSONDecoder { + return jsonNewDecoder(r) + } + } +} + +// WithBase64Decoder supports a custom Base64 when decoding a base64 encoded +// token. Two encoding can be specified: +// - rawURL needs to contain a [Base64Encoding] that is based on base64url +// without padding. This is used for parsing tokens with the default +// options. +// - url needs to contain a [Base64Encoding] based on base64url with padding. +// The sole use of this to decode tokens when [WithPaddingAllowed] is +// enabled. +// +// If any of the supplied encodings are set to nil, the defaults from the Go +// standard library, [encoding/base64.RawURLEncoding] and +// [encoding/base64.URLEncoding] are used. +// +// Example using the https://github.com/segmentio/asm library. +// +// import ( +// asmbase64 "github.com/segmentio/asm/base64" +// ) +// +// var parser = jwt.NewParser(jwt.WithBase64Decoder(asmbase64.RawURLEncoding, asmbase64.URLEncoding)) +func WithBase64Decoder[T Base64Encoding](rawURL Base64Encoding, url T) ParserOption { + return func(p *Parser) { + p.rawUrlBase64Encoding = rawURL + p.urlBase64Encoding = url + + // Check, whether the library supports the Strict() function + stricter, ok := rawURL.(Stricter[T]) + if ok { + // We need to get rid of the type parameter T, so we need to wrap it + // here + p.strict = func() Base64Encoding { + return stricter.Strict() + } + } + } +} diff --git a/parser_test.go b/parser_test.go index c0f81711..3319c979 100644 --- a/parser_test.go +++ b/parser_test.go @@ -3,6 +3,7 @@ package jwt_test import ( "crypto" "crypto/rsa" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -423,6 +424,39 @@ var jwtTestData = []struct { jwt.NewParser(jwt.WithLeeway(2 * time.Minute)), jwt.SigningMethodRS256, }, + { + "custom json encoder", + "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + defaultKeyFunc, + jwt.MapClaims{"foo": "bar"}, + true, + nil, + jwt.NewParser(jwt.WithJSONDecoder(json.Unmarshal, json.NewDecoder)), + jwt.SigningMethodRS256, + }, + { + "custom json encoder - use numbers", + "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + defaultKeyFunc, + jwt.MapClaims{"foo": "bar"}, + true, + nil, + jwt.NewParser( + jwt.WithJSONDecoder(json.Unmarshal, json.NewDecoder), + jwt.WithJSONNumber(), + ), + jwt.SigningMethodRS256, + }, + { + "custom base64 encoder", + "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + defaultKeyFunc, + jwt.MapClaims{"foo": "bar"}, + true, + nil, + jwt.NewParser(jwt.WithBase64Decoder(base64.RawURLEncoding, base64.URLEncoding)), + jwt.SigningMethodRS256, + }, { "rejects if exp is required but missing", "", // autogen diff --git a/token.go b/token.go index 352873a2..93b87a36 100644 --- a/token.go +++ b/token.go @@ -34,6 +34,13 @@ type Token struct { Claims Claims // Claims is the second segment of the token in decoded form Signature []byte // Signature is the third segment of the token in decoded form. Populated when you Parse a token Valid bool // Valid specifies if the token is valid. Populated when you Parse/Verify a token + + encoders +} + +type encoders struct { + jsonMarshal JSONMarshalFunc // jsonEncoder is the custom json encoder/decoder + base64Encoding Base64Encoding // base64Encoder is the custom base64 encoding } // New creates a new [Token] with the specified signing method and an empty map @@ -45,7 +52,7 @@ func New(method SigningMethod, opts ...TokenOption) *Token { // NewWithClaims creates a new [Token] with the specified signing method and // claims. Additional options can be specified, but are currently unused. func NewWithClaims(method SigningMethod, claims Claims, opts ...TokenOption) *Token { - return &Token{ + t := &Token{ Header: map[string]interface{}{ "typ": "JWT", "alg": method.Alg(), @@ -53,6 +60,10 @@ func NewWithClaims(method SigningMethod, claims Claims, opts ...TokenOption) *To Claims: claims, Method: method, } + for _, opt := range opts { + opt(t) + } + return t } // SignedString creates and returns a complete, signed JWT. The token is signed @@ -78,12 +89,19 @@ func (t *Token) SignedString(key interface{}) (string, error) { // of the whole deal. Unless you need this for something special, just go // straight for the SignedString. func (t *Token) SigningString() (string, error) { - h, err := json.Marshal(t.Header) + var marshal JSONMarshalFunc + if t.jsonMarshal != nil { + marshal = t.jsonMarshal + } else { + marshal = json.Marshal + } + + h, err := marshal(t.Header) if err != nil { return "", err } - c, err := json.Marshal(t.Claims) + c, err := marshal(t.Claims) if err != nil { return "", err } @@ -95,6 +113,13 @@ func (t *Token) SigningString() (string, error) { // stripped. In the future, this function might take into account a // [TokenOption]. Therefore, this function exists as a method of [Token], rather // than a global function. -func (*Token) EncodeSegment(seg []byte) string { - return base64.RawURLEncoding.EncodeToString(seg) +func (t *Token) EncodeSegment(seg []byte) string { + var enc Base64Encoding + if t.base64Encoding != nil { + enc = t.base64Encoding + } else { + enc = base64.RawURLEncoding + } + + return enc.EncodeToString(seg) } diff --git a/token_option.go b/token_option.go index b4ae3bad..0fab6a37 100644 --- a/token_option.go +++ b/token_option.go @@ -3,3 +3,15 @@ package jwt // TokenOption is a reserved type, which provides some forward compatibility, // if we ever want to introduce token creation-related options. type TokenOption func(*Token) + +func WithJSONEncoder(f JSONMarshalFunc) TokenOption { + return func(token *Token) { + token.jsonMarshal = f + } +} + +func WithBase64Encoder(enc Base64Encoding) TokenOption { + return func(token *Token) { + token.base64Encoding = enc + } +} diff --git a/token_test.go b/token_test.go index f18329e0..ff9eefe5 100644 --- a/token_test.go +++ b/token_test.go @@ -1,6 +1,8 @@ package jwt_test import ( + "encoding/base64" + "encoding/json" "testing" "github.com/golang-jwt/jwt/v5" @@ -14,6 +16,7 @@ func TestToken_SigningString(t1 *testing.T) { Claims jwt.Claims Signature []byte Valid bool + Options []jwt.TokenOption } tests := []struct { name string @@ -23,6 +26,22 @@ func TestToken_SigningString(t1 *testing.T) { }{ { name: "", + fields: fields{ + Raw: "", + Method: jwt.SigningMethodHS256, + Header: map[string]interface{}{ + "typ": "JWT", + "alg": jwt.SigningMethodHS256.Alg(), + }, + Claims: jwt.RegisteredClaims{}, + Valid: false, + Options: nil, + }, + want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30", + wantErr: false, + }, + { + name: "encode with custom json and base64 encoder", fields: fields{ Raw: "", Method: jwt.SigningMethodHS256, @@ -32,6 +51,10 @@ func TestToken_SigningString(t1 *testing.T) { }, Claims: jwt.RegisteredClaims{}, Valid: false, + Options: []jwt.TokenOption{ + jwt.WithJSONEncoder(json.Marshal), + jwt.WithBase64Encoder(base64.StdEncoding), + }, }, want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30", wantErr: false,