diff --git a/docs/docs/features/response-errors.md b/docs/docs/features/response-errors.md index d05ecb6a..bb087e8f 100644 --- a/docs/docs/features/response-errors.md +++ b/docs/docs/features/response-errors.md @@ -90,6 +90,23 @@ flowchart TD This means it is possible to, for example, get an HTTP `408 Request Timeout` response that _also_ contains an error detail with a validation error for one of the input headers. Since request timeout has higher priority, that will be the response status code that is returned. +## Error Headers + +Middleware can be used to add headers to all responses, e.g. for cache control, rate limiting, etc. For headers specific to errors or specific handler error responses, you can wrap the error with additional headers as needed: + +```go title="code.go" hl_lines="1-3" +return nil, huma.ErrorWithHeaders( + huma.Error404NotFound("thing not found"), + http.Header{ + "Cache-Control": {"no-store"}, + }, +) +``` + +It is safe to call `huma.ErrorWithHeaders` multiple times, and all the passed headers will be appended to any existing ones. + +Any error which satisfies the `huma.HeadersError` interface will have the headers added to the response. + ## Custom Errors It is possible to provide your own error model and have the built-in error utility functions use that model instead of the default one. This is useful if you want to provide more information in your error responses or your organization has requirements around the error response structure. @@ -147,6 +164,7 @@ To change the default content type that is returned, you can also implement the - [`huma.ErrorModel`](https://pkg.go.dev/github.com/danielgtaylor/huma/v2#ErrorModel) the default error model - [`huma.ErrorDetail`](https://pkg.go.dev/github.com/danielgtaylor/huma/v2#ErrorDetail) describes location & value of an error - [`huma.StatusError`](https://pkg.go.dev/github.com/danielgtaylor/huma/v2#StatusError) interface for custom errors + - [`huma.HeadersError`](https://pkg.go.dev/github.com/danielgtaylor/huma/v2#HeadersError) interface for errors with headers - [`huma.ContentTypeFilter`](https://pkg.go.dev/github.com/danielgtaylor/huma/v2#ContentTypeFilter) interface for custom content types - External Links - [HTTP Status Codes](https://developer.mozilla.org/en-US/docs/Web/HTTP/Status) diff --git a/error.go b/error.go index 1380ad2f..67017c29 100644 --- a/error.go +++ b/error.go @@ -1,6 +1,7 @@ package huma import ( + "errors" "fmt" "net/http" "strconv" @@ -149,6 +150,50 @@ type StatusError interface { Error() string } +// HeadersError is an error that has HTTP headers. When returned from an +// operation handler, these headers are set on the response before sending it +// to the client. Use `ErrorWithHeaders` to wrap an error like +// `huma.Error400BadRequest` with additional headers. +type HeadersError interface { + GetHeaders() http.Header + Error() string +} + +type errWithHeaders struct { + err error + headers http.Header +} + +func (e *errWithHeaders) Error() string { + return e.err.Error() +} + +func (e *errWithHeaders) Unwrap() error { + return e.err +} + +func (e *errWithHeaders) GetHeaders() http.Header { + return e.headers +} + +// ErrorWithHeaders wraps an error with additional headers to be sent to the +// client. This is useful for e.g. caching, rate limiting, or other metadata. +func ErrorWithHeaders(err error, headers http.Header) error { + var he HeadersError + if errors.As(err, &he) { + // There is already a headers error, so we need to merge the headers. This + // lets you chain multiple calls together and have all the headers set. + orig := he.GetHeaders() + for k, values := range headers { + for _, v := range values { + orig.Add(k, v) + } + } + return err + } + return &errWithHeaders{err: err, headers: headers} +} + // NewError creates a new instance of an error model with the given status code, // message, and optional error details. If the error details implement the // `ErrorDetailer` interface, the error details will be used. Otherwise, the diff --git a/error_test.go b/error_test.go index a28f7fa9..43611af8 100644 --- a/error_test.go +++ b/error_test.go @@ -1,6 +1,7 @@ package huma_test import ( + "context" "errors" "fmt" "net/http" @@ -109,3 +110,30 @@ func TestErrorAs(t *testing.T) { require.ErrorAs(t, err, &e) assert.Equal(t, 400, e.GetStatus()) } + +func TestErrorWithHeaders(t *testing.T) { + _, api := humatest.New(t, huma.DefaultConfig("Test API", "1.0.0")) + huma.Get(api, "/test", func(ctx context.Context, input *struct{}) (*struct{}, error) { + err := huma.ErrorWithHeaders( + huma.Error400BadRequest("test"), + http.Header{ + "My-Header": {"bar"}, + }, + ) + + assert.Equal(t, "test", err.Error()) + + // Call again and have all the headers merged + err = huma.ErrorWithHeaders(err, http.Header{ + "Another": {"bar"}, + }) + + return nil, fmt.Errorf("wrapped: %w", err) + }) + + resp := api.Get("/test") + assert.Equal(t, 400, resp.Code) + assert.Equal(t, "bar", resp.Header().Get("My-Header")) + assert.Equal(t, "bar", resp.Header().Get("Another")) + assert.Contains(t, resp.Body.String(), "test") +} diff --git a/huma.go b/huma.go index 899c7820..20aaa9c0 100644 --- a/huma.go +++ b/huma.go @@ -1275,10 +1275,20 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) output, err := handler(ctx.Context(), &input) if err != nil { + var he HeadersError + if errors.As(err, &he) { + for k, values := range he.GetHeaders() { + for _, v := range values { + ctx.AppendHeader(k, v) + } + } + } + status := http.StatusInternalServerError var se StatusError if errors.As(err, &se) { status = se.GetStatus() + err = se } else { err = NewError(http.StatusInternalServerError, err.Error()) }