Skip to content

Commit

Permalink
supporting strict() for all libraries
Browse files Browse the repository at this point in the history
  • Loading branch information
oxisto committed Sep 16, 2023
1 parent 22b0855 commit 441e06d
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 14 deletions.
6 changes: 6 additions & 0 deletions encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@ type Base64Encoding interface {
DecodeString(s string) ([]byte, error)
}

type StrictFunc[T Base64Encoding] func() T

type Stricter[T Base64Encoding] interface {
Strict() T
}

func DoStrict[S Base64Encoding, T Stricter[S]](x T) Base64Encoding {
return x.Strict()
}

// JSONMarshalFunc is an function type that allows to implement custom JSON
// encoding algorithms.
type JSONMarshalFunc func(v any) ([]byte, error)
Expand Down
25 changes: 14 additions & 11 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,24 @@ type Parser struct {
// If populated, only these methods will be considered valid.
validMethods []string

// Use JSON Number format in JSON decoder.
useJSONNumber bool

// Skip claims validation during token parsing.
skipClaimsValidation bool

validator *validator

decoders
decoding
}

type decoders struct {
type decoding struct {
jsonUnmarshal JSONUnmarshalFunc
jsonNewDecoder JSONNewDecoderFunc[JSONDecoder]

rawUrlBase64Encoding Base64Encoding
urlBase64Encoding Base64Encoding
strict StrictFunc[Base64Encoding]

// Use JSON Number format in JSON decoder.
useJSONNumber bool

decodeStrict bool
decodePaddingAllowed bool
Expand Down Expand Up @@ -246,13 +247,15 @@ func (p *Parser) DecodeSegment(seg string) ([]byte, error) {
}

if p.decodeStrict {
// For now we can only support the standard library here because of the
// current state of the type parameter system
stricter, ok := encoding.(Stricter[*base64.Encoding])
if !ok {
return nil, newError("strict mode is only supported in encoding/base64", ErrUnsupported)
if p.strict != nil {
encoding = p.strict()
} else {
stricter, ok := encoding.(Stricter[*base64.Encoding])
if !ok {
return nil, newError("WithStrictDecoding() was enabled but supplied base64 encoder does not support strict mode", ErrUnsupported)
}
encoding = stricter.Strict()
}
encoding = stricter.Strict()
}

return encoding.DecodeString(seg)
Expand Down
16 changes: 13 additions & 3 deletions parser_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func WithStrictDecoding() ParserOption {
// "github.com/bytedance/sonic"
// )
//
// var parser = NewParser(WithJSONDecoder(sonic.Unmarshal, sonic.ConfigDefault.NewDecoder))
// var parser = jwt.NewParser(jwt.WithJSONDecoder(sonic.Unmarshal, sonic.ConfigDefault.NewDecoder))
func WithJSONDecoder[T JSONDecoder](jsonUnmarshal JSONUnmarshalFunc, jsonNewDecoder JSONNewDecoderFunc[T]) ParserOption {
return func(p *Parser) {
p.jsonUnmarshal = jsonUnmarshal
Expand Down Expand Up @@ -176,10 +176,20 @@ func WithJSONDecoder[T JSONDecoder](jsonUnmarshal JSONUnmarshalFunc, jsonNewDeco
// asmbase64 "github.com/segmentio/asm/base64"
// )
//
// var parser = NewParser(WithBase64Decoder(asmbase64.RawURLEncoding, asmbase64.URLEncoding))
func WithBase64Decoder(rawURL Base64Encoding, url Base64Encoding) ParserOption {
// var parser = jwt.NewParser(jwt.WithBase64Decoder(asmbase64.RawURLEncoding, asmbase64.URLEncoding))
func WithBase64Decoder[T Base64Encoding](rawURL Base64Encoding, url T) ParserOption {
return func(p *Parser) {
p.rawUrlBase64Encoding = rawURL
p.urlBase64Encoding = url

// Check, whether the library supports the Strict() function
stricter, ok := rawURL.(Stricter[T])
if ok {
// We need to get rid of the type parameter T, so we need to wrap it
// here
p.strict = func() Base64Encoding {
return stricter.Strict()
}
}
}
}

0 comments on commit 441e06d

Please sign in to comment.