diff --git a/context.go b/context.go index 82220730..f710e56c 100644 --- a/context.go +++ b/context.go @@ -3,6 +3,7 @@ package chi import ( "context" "net/http" + "slices" "strings" ) @@ -95,6 +96,22 @@ func (x *Context) Reset() { x.parentCtx = nil } +// Clone a routing context so that it may be used outside of the request/response lifecycle. +func (c *Context) Clone() *Context { + clone := *c + + clone.URLParams.Keys = slices.Clone(c.URLParams.Keys) + clone.URLParams.Values = slices.Clone(c.URLParams.Values) + + clone.routeParams.Keys = slices.Clone(c.routeParams.Keys) + clone.routeParams.Values = slices.Clone(c.routeParams.Values) + + clone.RoutePatterns = slices.Clone(c.RoutePatterns) + clone.methodsAllowed = slices.Clone(c.methodsAllowed) + + return &clone +} + // URLParam returns the corresponding URL parameter value from the request // routing context. func (x *Context) URLParam(key string) string { diff --git a/context_test.go b/context_test.go index fa432852..5d89913b 100644 --- a/context_test.go +++ b/context_test.go @@ -102,3 +102,47 @@ func TestReplaceWildcardsConsecutive(t *testing.T) { t.Fatalf("unexpected trailing wildcard behavior: %s", p) } } + +func TestContext_Clone(t *testing.T) { + orig := &Context{ + RoutePatterns: []string{"/v1", "/resources/{id}"}, + methodsAllowed: []methodTyp{mHEAD, mGET}, + URLParams: RouteParams{ + Keys: []string{"foo"}, + Values: []string{"bar"}, + }, + routeParams: RouteParams{ + Keys: []string{"id"}, + Values: []string{"123"}, + }, + } + + clone := orig.Clone() + orig.Reset() + + orig.URLParams.Keys = append(orig.URLParams.Keys, "bar") + orig.URLParams.Values = append(orig.URLParams.Values, "baz") + orig.routeParams.Keys = append(orig.routeParams.Keys, "name") + orig.routeParams.Values = append(orig.routeParams.Values, "foxmulder") + orig.RoutePatterns = append(orig.RoutePatterns, "/mutated") + orig.methodsAllowed = append(orig.methodsAllowed, mPOST) + + if got := clone.URLParams.Keys[0]; got != "foo" { + t.Fatalf("clone URLParams.Keys was corrupted, want %q got %q", "foo", got) + } + if got := clone.URLParams.Values[0]; got != "bar" { + t.Fatalf("clone URLParams.Values was corrupted, want %q got %q", "bar", got) + } + if got := clone.routeParams.Keys[0]; got != "id" { + t.Fatalf("clone routeParams.Keys was corrupted, want %q got %q", "id", got) + } + if got := clone.routeParams.Values[0]; got != "123" { + t.Fatalf("clone routeParams.Values was corrupted, want %q got %q", "123", got) + } + if got := clone.RoutePatterns[0]; got != "/v1" { + t.Fatalf("clone RoutePatterns[0] was corrupted, want %q got %q", "/v1", got) + } + if got := clone.methodsAllowed[0]; got != mHEAD { + t.Fatalf("clone methodsAllowed[0] was corrupted, want %d got %d", mHEAD, got) + } +}