From 338653c8567820524595d11bb1d7b7e86275eca2 Mon Sep 17 00:00:00 2001 From: Louis Duchemin Date: Thu, 16 May 2024 15:00:28 +0200 Subject: [PATCH] fix: unmarshaling of request body when using struct embedding --- huma.go | 19 ++++++++++++------- huma_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/huma.go b/huma.go index 8d0a0a98..9736f80e 100644 --- a/huma.go +++ b/huma.go @@ -569,9 +569,11 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) panic("input must be a struct") } inputParams := findParams(registry, &op, inputType) - inputBodyIndex := -1 + inputBodyIndex := make([]int, 0) + hasInputBody := false if f, ok := inputType.FieldByName("Body"); ok { - inputBodyIndex = f.Index[0] + hasInputBody = true + inputBodyIndex = f.Index if op.RequestBody == nil { required := f.Type.Kind() != reflect.Ptr && f.Type.Kind() != reflect.Interface if f.Tag.Get("required") == "true" { @@ -762,7 +764,7 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) } } - if len(op.Errors) > 0 && (len(inputParams.Paths) > 0 || inputBodyIndex >= -1) { + if len(op.Errors) > 0 && (len(inputParams.Paths) > 0 || hasInputBody) { op.Errors = append(op.Errors, http.StatusUnprocessableEntity) } if len(op.Errors) > 0 { @@ -1124,7 +1126,7 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) }) // Read input body if defined. - if inputBodyIndex != -1 || rawBodyIndex != -1 { + if hasInputBody || rawBodyIndex != -1 { if op.BodyReadTimeout > 0 { ctx.SetReadDeadline(time.Now().Add(op.BodyReadTimeout)) } else if op.BodyReadTimeout < 0 { @@ -1192,7 +1194,7 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) } } else { parseErrCount := 0 - if inputBodyIndex != -1 && !op.SkipValidateBody { + if hasInputBody && !op.SkipValidateBody { // Validate the input. First, parse the body into []any or map[string]any // or equivalent, which can be easily validated. Then, convert to the // expected struct type to call the handler. @@ -1220,13 +1222,16 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) } } - if inputBodyIndex != -1 { + if hasInputBody { // We need to get the body into the correct type now that it has been // validated. Benchmarks on Go 1.20 show that using `json.Unmarshal` a // second time is faster than `mapstructure.Decode` or any of the other // common reflection-based approaches when using real-world medium-sized // JSON payloads with lots of strings. - f := v.Field(inputBodyIndex) + f := v + for _, index := range inputBodyIndex { + f = f.Field(index) + } if err := api.Unmarshal(ctx.Header("Content-Type"), body, f.Addr().Interface()); err != nil { if parseErrCount == 0 { // Hmm, this should have worked... validator missed something? diff --git a/huma_test.go b/huma_test.go index 15333c7e..4cf3c6d3 100644 --- a/huma_test.go +++ b/huma_test.go @@ -60,6 +60,13 @@ func (UUID) Schema(r huma.Registry) *huma.Schema { return &huma.Schema{Type: huma.TypeString, Format: "uuid"} } +// BodyContainer is an embed request body struct to test request body unmarshalling +type BodyContainer struct { + Body struct { + Name string `json:"name"` + } +} + func TestFeatures(t *testing.T) { for _, feature := range []struct { Name string @@ -624,6 +631,23 @@ func TestFeatures(t *testing.T) { URL: "/body", Body: `{"name": "Name"}`, }, + { + Name: "request-body-embed-struct", + Register: func(t *testing.T, api huma.API) { + huma.Register(api, huma.Operation{ + Method: http.MethodPost, + Path: "/body", + }, func(ctx context.Context, input *struct { + BodyContainer + }) (*struct{}, error) { + assert.Equal(t, "Name", input.Body.Name) + return nil, nil + }) + }, + Method: http.MethodPost, + URL: "/body", + Body: `{"name": "Name"}`, + }, { Name: "request-ptr-body-required", Register: func(t *testing.T, api huma.API) {