Skip to content

Commit 6ef5c50

Browse files
committed
feat: propagate context to adapters on WithContext
1 parent df27bef commit 6ef5c50

File tree

17 files changed

+556
-0
lines changed

17 files changed

+556
-0
lines changed

adapters/humabunrouter/humabunrouter.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,15 @@ func (c *bunContext) Version() huma.ProtoVersion {
140140
}
141141
}
142142

143+
func (c *bunContext) WithContext(ctx context.Context) huma.Context {
144+
return &bunContext{
145+
op: c.op,
146+
r: c.r.WithContext(ctx),
147+
w: c.w,
148+
status: c.status,
149+
}
150+
}
151+
143152
// NewContext creates a new Huma context from an HTTP request and response.
144153
func NewContext(op *huma.Operation, r bunrouter.Request, w http.ResponseWriter) huma.Context {
145154
return &bunContext{op: op, r: r, w: w}
@@ -243,6 +252,15 @@ func (c *bunCompatContext) Version() huma.ProtoVersion {
243252
}
244253
}
245254

255+
func (c *bunCompatContext) WithContext(ctx context.Context) huma.Context {
256+
return &bunCompatContext{
257+
op: c.op,
258+
r: c.r.WithContext(ctx),
259+
w: c.w,
260+
status: c.status,
261+
}
262+
}
263+
246264
// NewCompatContext creates a new Huma context from an HTTP request and response.
247265
func NewCompatContext(op *huma.Operation, r *http.Request, w http.ResponseWriter) huma.Context {
248266
return &bunCompatContext{op: op, r: r, w: w}
@@ -310,3 +328,15 @@ func NewCompat(r *bunrouter.CompatRouter, config huma.Config) huma.API {
310328
func New(r *bunrouter.Router, config huma.Config) huma.API {
311329
return huma.NewAPI(config, NewAdapter(r))
312330
}
331+
332+
func middleware(mw bunrouter.MiddlewareFunc) func(ctx huma.Context, next func(huma.Context)) {
333+
return func(ctx huma.Context, next func(huma.Context)) {
334+
r, w := Unwrap(ctx)
335+
f := mw(func(w http.ResponseWriter, r bunrouter.Request) error {
336+
ctx = NewContext(ctx.Operation(), r, w)
337+
next(ctx)
338+
return nil
339+
})
340+
f(w, r)
341+
}
342+
}

adapters/humabunrouter/humabunrouter_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@ import (
1414
"testing"
1515
"time"
1616

17+
"github.com/stretchr/testify/assert"
18+
"github.com/stretchr/testify/require"
1719
"github.com/uptrace/bunrouter"
1820

1921
"github.com/danielgtaylor/huma/v2"
22+
"github.com/danielgtaylor/huma/v2/humatest"
2023
)
2124

2225
var lastModified = time.Now()
@@ -320,3 +323,47 @@ func BenchmarkRawBunRouterFast(b *testing.B) {
320323
r.ServeHTTP(w, req)
321324
}
322325
}
326+
327+
// See https://github.com/danielgtaylor/huma/issues/859
328+
func TestWithValueShouldPropagateContext(t *testing.T) {
329+
r := bunrouter.New()
330+
app := New(r, huma.DefaultConfig("Test", "1.0.0"))
331+
332+
type (
333+
testInput struct{}
334+
testOutput struct{}
335+
ctxKey struct{}
336+
)
337+
338+
ctxValue := "sentinelValue"
339+
340+
huma.Register(app, huma.Operation{
341+
OperationID: "test",
342+
Path: "/test",
343+
Method: http.MethodGet,
344+
Middlewares: huma.Middlewares{
345+
func(ctx huma.Context, next func(huma.Context)) {
346+
ctx = huma.WithValue(ctx, ctxKey{}, ctxValue)
347+
next(ctx)
348+
},
349+
middleware(func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
350+
return func(w http.ResponseWriter, r bunrouter.Request) error {
351+
val, _ := r.Context().Value(ctxKey{}).(string)
352+
_, err := io.WriteString(w, val)
353+
return err
354+
}
355+
}),
356+
},
357+
}, func(ctx context.Context, input *testInput) (*testOutput, error) {
358+
out := &testOutput{}
359+
return out, nil
360+
})
361+
362+
tapi := humatest.Wrap(t, app)
363+
364+
resp := tapi.Get("/test")
365+
assert.Equal(t, http.StatusOK, resp.Code)
366+
out, err := io.ReadAll(resp.Body)
367+
require.NoError(t, err)
368+
assert.Equal(t, ctxValue, string(out))
369+
}

adapters/humachi/humachi.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,15 @@ func (c *chiContext) Version() huma.ProtoVersion {
146146
}
147147
}
148148

149+
func (c *chiContext) WithContext(ctx context.Context) huma.Context {
150+
return &chiContext{
151+
op: c.op,
152+
r: c.r.WithContext(ctx),
153+
w: c.w,
154+
status: c.status,
155+
}
156+
}
157+
149158
// NewContext creates a new Huma context from an HTTP request and response.
150159
func NewContext(op *huma.Operation, r *http.Request, w http.ResponseWriter) huma.Context {
151160
return &chiContext{op: op, r: r, w: w}
@@ -174,3 +183,13 @@ func NewAdapter(r chi.Router) huma.Adapter {
174183
func New(r chi.Router, config huma.Config) huma.API {
175184
return huma.NewAPI(config, &chiAdapter{router: r})
176185
}
186+
187+
func middleware(mw func(http.Handler) http.Handler) func(ctx huma.Context, next func(huma.Context)) {
188+
return func(ctx huma.Context, next func(huma.Context)) {
189+
r, w := Unwrap(ctx)
190+
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
191+
ctx = NewContext(ctx.Operation(), r, w)
192+
next(ctx)
193+
})).ServeHTTP(w, r)
194+
}
195+
}

adapters/humachi/humachi_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,3 +438,46 @@ func TestPathParamDecoding(t *testing.T) {
438438
// app.ServeHTTP(w, req)
439439
// }
440440
// }
441+
442+
// See https://github.com/danielgtaylor/huma/issues/859
443+
func TestWithValueShouldPropagateContext(t *testing.T) {
444+
r := chi.NewMux()
445+
app := New(r, huma.DefaultConfig("Test", "1.0.0"))
446+
447+
type (
448+
testInput struct{}
449+
testOutput struct{}
450+
ctxKey struct{}
451+
)
452+
453+
ctxValue := "sentinelValue"
454+
455+
huma.Register(app, huma.Operation{
456+
OperationID: "test",
457+
Path: "/test",
458+
Method: http.MethodGet,
459+
Middlewares: huma.Middlewares{
460+
func(ctx huma.Context, next func(huma.Context)) {
461+
ctx = huma.WithValue(ctx, ctxKey{}, ctxValue)
462+
next(ctx)
463+
},
464+
middleware(func(h http.Handler) http.Handler {
465+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
466+
val, _ := r.Context().Value(ctxKey{}).(string)
467+
io.WriteString(w, val)
468+
})
469+
}),
470+
},
471+
}, func(ctx context.Context, input *testInput) (*testOutput, error) {
472+
out := &testOutput{}
473+
return out, nil
474+
})
475+
476+
tapi := humatest.Wrap(t, app)
477+
478+
resp := tapi.Get("/test")
479+
assert.Equal(t, http.StatusOK, resp.Code)
480+
out, err := io.ReadAll(resp.Body)
481+
require.NoError(t, err)
482+
assert.Equal(t, ctxValue, string(out))
483+
}

adapters/humaecho/humaecho.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,16 @@ func (c *echoCtx) Version() huma.ProtoVersion {
138138
}
139139
}
140140

141+
func (c *echoCtx) WithContext(ctx context.Context) huma.Context {
142+
new := c.orig
143+
new.SetRequest(new.Request().WithContext(ctx))
144+
return &echoCtx{
145+
op: c.op,
146+
orig: new,
147+
status: c.status,
148+
}
149+
}
150+
141151
type router interface {
142152
Add(method, path string, handler echo.HandlerFunc, middlewares ...echo.MiddlewareFunc) *echo.Route
143153
}
@@ -170,3 +180,15 @@ func New(r *echo.Echo, config huma.Config) huma.API {
170180
func NewWithGroup(r *echo.Echo, g *echo.Group, config huma.Config) huma.API {
171181
return huma.NewAPI(config, &echoAdapter{Handler: r, router: g})
172182
}
183+
184+
func middleware(mw echo.MiddlewareFunc) func(ctx huma.Context, next func(huma.Context)) {
185+
return func(ctx huma.Context, next func(huma.Context)) {
186+
eCtx := Unwrap(ctx)
187+
f := mw(func(c echo.Context) error {
188+
ctx = &echoCtx{op: ctx.Operation(), orig: eCtx}
189+
next(ctx)
190+
return nil
191+
})
192+
f(eCtx)
193+
}
194+
}

adapters/humaecho/humaecho_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ import (
1313
"time"
1414

1515
"github.com/danielgtaylor/huma/v2"
16+
"github.com/danielgtaylor/huma/v2/humatest"
1617
"github.com/labstack/echo/v4"
18+
"github.com/stretchr/testify/assert"
19+
"github.com/stretchr/testify/require"
1720
)
1821

1922
var lastModified = time.Now()
@@ -240,3 +243,47 @@ func BenchmarkRawEchoFast(b *testing.B) {
240243
r.ServeHTTP(w, req)
241244
}
242245
}
246+
247+
// See https://github.com/danielgtaylor/huma/issues/859
248+
func TestWithValueShouldPropagateContext(t *testing.T) {
249+
r := echo.New()
250+
app := New(r, huma.DefaultConfig("Test", "1.0.0"))
251+
252+
type (
253+
testInput struct{}
254+
testOutput struct{}
255+
ctxKey struct{}
256+
)
257+
258+
ctxValue := "sentinelValue"
259+
260+
huma.Register(app, huma.Operation{
261+
OperationID: "test",
262+
Path: "/test",
263+
Method: http.MethodGet,
264+
Middlewares: huma.Middlewares{
265+
func(ctx huma.Context, next func(huma.Context)) {
266+
ctx = huma.WithValue(ctx, ctxKey{}, ctxValue)
267+
next(ctx)
268+
},
269+
middleware(func(next echo.HandlerFunc) echo.HandlerFunc {
270+
return func(c echo.Context) error {
271+
val, _ := c.Request().Context().Value(ctxKey{}).(string)
272+
_, err := io.WriteString(c.Response().Writer, val)
273+
return err
274+
}
275+
}),
276+
},
277+
}, func(ctx context.Context, input *testInput) (*testOutput, error) {
278+
out := &testOutput{}
279+
return out, nil
280+
})
281+
282+
tapi := humatest.Wrap(t, app)
283+
284+
resp := tapi.Get("/test")
285+
assert.Equal(t, http.StatusOK, resp.Code)
286+
out, err := io.ReadAll(resp.Body)
287+
require.NoError(t, err)
288+
assert.Equal(t, ctxValue, string(out))
289+
}

adapters/humafiber/humafiber.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,17 @@ func (c *fiberWrapper) Version() huma.ProtoVersion {
153153
}
154154
}
155155

156+
func (c *fiberWrapper) WithContext(ctx context.Context) huma.Context {
157+
new := c.orig
158+
new.SetUserContext(ctx)
159+
return &fiberWrapper{
160+
op: c.op,
161+
status: c.status,
162+
orig: new,
163+
ctx: ctx,
164+
}
165+
}
166+
156167
type router interface {
157168
Add(method, path string, handlers ...fiber.Handler) fiber.Router
158169
}
@@ -242,3 +253,14 @@ func New(r *fiber.App, config huma.Config) huma.API {
242253
func NewWithGroup(r *fiber.App, g fiber.Router, config huma.Config) huma.API {
243254
return huma.NewAPI(config, &fiberAdapter{tester: r, router: g})
244255
}
256+
257+
func middleware(mw func(next fiber.Handler) fiber.Handler) func(ctx huma.Context, next func(huma.Context)) {
258+
return func(ctx huma.Context, next func(huma.Context)) {
259+
fCtx := Unwrap(ctx)
260+
h := mw(func(c *fiber.Ctx) error {
261+
next(ctx)
262+
return nil
263+
})
264+
h(fCtx)
265+
}
266+
}

adapters/humafiber/humafiber_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@ package humafiber
22

33
import (
44
"context"
5+
"io"
56
"net/http"
67
"testing"
78

89
"github.com/danielgtaylor/huma/v2"
10+
"github.com/danielgtaylor/huma/v2/humatest"
911
"github.com/gofiber/fiber/v2"
12+
"github.com/stretchr/testify/assert"
13+
"github.com/stretchr/testify/require"
1014
)
1115

1216
func BenchmarkHumaFiber(b *testing.B) {
@@ -59,3 +63,46 @@ func BenchmarkNotHuma(b *testing.B) {
5963
r.Test(req)
6064
}
6165
}
66+
67+
func TestWithValueShouldPropagateContext(t *testing.T) {
68+
r := fiber.New()
69+
app := New(r, huma.DefaultConfig("Test", "1.0.0"))
70+
71+
type (
72+
testInput struct{}
73+
testOutput struct{}
74+
ctxKey struct{}
75+
)
76+
77+
ctxValue := "sentinelValue"
78+
79+
huma.Register(app, huma.Operation{
80+
OperationID: "test",
81+
Path: "/test",
82+
Method: http.MethodGet,
83+
Middlewares: huma.Middlewares{
84+
func(ctx huma.Context, next func(huma.Context)) {
85+
ctx = huma.WithValue(ctx, ctxKey{}, ctxValue)
86+
next(ctx)
87+
},
88+
middleware(func(next fiber.Handler) fiber.Handler {
89+
return func(c *fiber.Ctx) error {
90+
val, _ := c.UserContext().Value(ctxKey{}).(string)
91+
_, err := c.WriteString(val)
92+
return err
93+
}
94+
}),
95+
},
96+
}, func(ctx context.Context, input *testInput) (*testOutput, error) {
97+
out := &testOutput{}
98+
return out, nil
99+
})
100+
101+
tapi := humatest.Wrap(t, app)
102+
103+
resp := tapi.Get("/test")
104+
assert.Equal(t, http.StatusOK, resp.Code)
105+
out, err := io.ReadAll(resp.Body)
106+
require.NoError(t, err)
107+
assert.Equal(t, ctxValue, string(out))
108+
}

0 commit comments

Comments
 (0)