Skip to content

Commit 76cddfc

Browse files
authored
make raw payload available to ExtractRaw (#6)
Capture the request payload and reset the http request body so that `ExtractRaw` can still read the original payload if desired. This is useful in situations like validating a request payload against a signature.
1 parent d25815f commit 76cddfc

File tree

2 files changed

+44
-33
lines changed

2 files changed

+44
-33
lines changed

apiendpoint/api_endpoint.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package apiendpoint
55

66
import (
7+
"bytes"
78
"context"
89
"encoding/json"
910
"errors"
@@ -146,6 +147,8 @@ func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Requ
146147
return apierror.NewBadRequestf("Error unmarshaling request body: %s.", err)
147148
}
148149
}
150+
151+
r.Body = io.NopCloser(bytes.NewReader(reqData))
149152
}
150153

151154
if rawExtractor, ok := any(&req).(RawExtractor); ok {

apiendpoint/api_endpoint_test.go

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/json"
77
"errors"
88
"fmt"
9+
"io"
910
"log/slog"
1011
"net/http"
1112
"net/http/httptest"
@@ -48,35 +49,35 @@ func TestMountAndServe(t *testing.T) {
4849
}
4950
}
5051

51-
t.Run("GetEndpointAndExtractRaw", func(t *testing.T) {
52+
t.Run("GetEndpoint", func(t *testing.T) {
5253
t.Parallel()
5354

5455
mux, bundle := setup(t)
5556

56-
req := httptest.NewRequest(http.MethodGet, "/api/get-endpoint/Hello.", nil)
57+
req := httptest.NewRequest(http.MethodGet, "/api/get-endpoint", nil)
5758
mux.ServeHTTP(bundle.recorder, req)
5859

59-
requireStatusAndJSONResponse(t, http.StatusOK, &postResponse{Message: "Hello."}, bundle.recorder)
60+
requireStatusAndJSONResponse(t, http.StatusOK, &getResponse{Message: "Hello."}, bundle.recorder)
6061
})
6162

6263
t.Run("BodyIgnoredOnGet", func(t *testing.T) {
6364
t.Parallel()
6465

6566
mux, bundle := setup(t)
6667

67-
req := httptest.NewRequest(http.MethodGet, "/api/get-endpoint/Hello.",
68+
req := httptest.NewRequest(http.MethodGet, "/api/get-endpoint",
6869
bytes.NewBuffer(mustMarshalJSON(t, &getRequest{IgnoredJSONMessage: "Ignored hello."})))
6970
mux.ServeHTTP(bundle.recorder, req)
7071

71-
requireStatusAndJSONResponse(t, http.StatusOK, &postResponse{Message: "Hello."}, bundle.recorder)
72+
requireStatusAndJSONResponse(t, http.StatusOK, &getResponse{Message: "Hello."}, bundle.recorder)
7273
})
7374

7475
t.Run("MethodNotAllowed", func(t *testing.T) {
7576
t.Parallel()
7677

7778
mux, bundle := setup(t)
7879

79-
req := httptest.NewRequest(http.MethodPost, "/api/get-endpoint/Hello.", nil)
80+
req := httptest.NewRequest(http.MethodPost, "/api/get-endpoint", nil)
8081
mux.ServeHTTP(bundle.recorder, req)
8182

8283
// This error comes from net/http.
@@ -91,11 +92,11 @@ func TestMountAndServe(t *testing.T) {
9192
mux := http.NewServeMux()
9293
Mount(mux, &postEndpoint{}, nil)
9394

94-
req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint",
95-
bytes.NewBuffer(mustMarshalJSON(t, &postRequest{Message: "Hello."})))
95+
reqPayload := mustMarshalJSON(t, &postRequest{Message: "Hello."})
96+
req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint/123", bytes.NewBuffer(reqPayload))
9697
mux.ServeHTTP(bundle.recorder, req)
9798

98-
requireStatusAndJSONResponse(t, http.StatusCreated, &postResponse{Message: "Hello."}, bundle.recorder)
99+
requireStatusAndJSONResponse(t, http.StatusCreated, &postResponse{ID: "123", Message: "Hello.", RawPayload: reqPayload}, bundle.recorder)
99100
})
100101

101102
t.Run("OptionsWithCustomLogger", func(t *testing.T) {
@@ -104,33 +105,32 @@ func TestMountAndServe(t *testing.T) {
104105
_, bundle := setup(t)
105106

106107
mux := http.NewServeMux()
107-
Mount(mux, &postEndpoint{}, &MountOpts{Logger: bundle.logger})
108+
Mount(mux, &getEndpoint{}, &MountOpts{Logger: bundle.logger})
108109

109-
req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint",
110-
bytes.NewBuffer(mustMarshalJSON(t, &postRequest{Message: "Hello."})))
110+
req := httptest.NewRequest(http.MethodGet, "/api/get-endpoint", nil)
111111
mux.ServeHTTP(bundle.recorder, req)
112112

113-
requireStatusAndJSONResponse(t, http.StatusCreated, &postResponse{Message: "Hello."}, bundle.recorder)
113+
requireStatusAndJSONResponse(t, http.StatusOK, &getResponse{Message: "Hello."}, bundle.recorder)
114114
})
115115

116-
t.Run("PostEndpoint", func(t *testing.T) {
116+
t.Run("PostEndpointAndExtractRaw", func(t *testing.T) {
117117
t.Parallel()
118118

119119
mux, bundle := setup(t)
120120

121-
req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint",
122-
bytes.NewBuffer(mustMarshalJSON(t, &postRequest{Message: "Hello."})))
121+
reqPayload := mustMarshalJSON(t, &postRequest{Message: "Hello."})
122+
req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint/123", bytes.NewBuffer(reqPayload))
123123
mux.ServeHTTP(bundle.recorder, req)
124124

125-
requireStatusAndJSONResponse(t, http.StatusCreated, &postResponse{Message: "Hello."}, bundle.recorder)
125+
requireStatusAndJSONResponse(t, http.StatusCreated, &postResponse{ID: "123", Message: "Hello.", RawPayload: reqPayload}, bundle.recorder)
126126
})
127127

128128
t.Run("ValidationError", func(t *testing.T) {
129129
t.Parallel()
130130

131131
mux, bundle := setup(t)
132132

133-
req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", nil)
133+
req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint/123", nil)
134134
mux.ServeHTTP(bundle.recorder, req)
135135

136136
requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Field `message` is required."}, bundle.recorder)
@@ -141,7 +141,7 @@ func TestMountAndServe(t *testing.T) {
141141

142142
mux, bundle := setup(t)
143143

144-
req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint",
144+
req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint/123",
145145
bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakeAPIError: true, Message: "Hello."})))
146146
mux.ServeHTTP(bundle.recorder, req)
147147

@@ -153,7 +153,7 @@ func TestMountAndServe(t *testing.T) {
153153

154154
mux, bundle := setup(t)
155155

156-
req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint",
156+
req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint/123",
157157
bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakePostgresError: true, Message: "Hello."})))
158158
mux.ServeHTTP(bundle.recorder, req)
159159

@@ -168,7 +168,7 @@ func TestMountAndServe(t *testing.T) {
168168
ctx, cancel := context.WithDeadline(ctx, time.Now())
169169
t.Cleanup(cancel)
170170

171-
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/api/post-endpoint",
171+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/api/post-endpoint/123",
172172
bytes.NewBuffer(mustMarshalJSON(t, &postRequest{Message: "Hello."})))
173173
require.NoError(t, err)
174174
mux.ServeHTTP(bundle.recorder, req)
@@ -181,7 +181,7 @@ func TestMountAndServe(t *testing.T) {
181181

182182
mux, bundle := setup(t)
183183

184-
req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint",
184+
req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint/123",
185185
bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakeInternalError: true, Message: "Hello."})))
186186
mux.ServeHTTP(bundle.recorder, req)
187187

@@ -274,19 +274,13 @@ type getEndpoint struct {
274274

275275
func (*getEndpoint) Meta() *EndpointMeta {
276276
return &EndpointMeta{
277-
Pattern: "GET /api/get-endpoint/{message}",
277+
Pattern: "GET /api/get-endpoint",
278278
StatusCode: http.StatusOK,
279279
}
280280
}
281281

282282
type getRequest struct {
283283
IgnoredJSONMessage string `json:"ignored_json" validate:"-"`
284-
Message string `json:"-" validate:"required"`
285-
}
286-
287-
func (req *getRequest) ExtractRaw(r *http.Request) error {
288-
req.Message = r.PathValue("message")
289-
return nil
290284
}
291285

292286
type getResponse struct {
@@ -299,7 +293,7 @@ func (a *getEndpoint) Execute(_ context.Context, req *getRequest) (*getResponse,
299293
return &getResponse{Message: req.IgnoredJSONMessage}, nil
300294
}
301295

302-
return &getResponse{Message: req.Message}, nil
296+
return &getResponse{Message: "Hello."}, nil
303297
}
304298

305299
//
@@ -312,20 +306,34 @@ type postEndpoint struct {
312306

313307
func (*postEndpoint) Meta() *EndpointMeta {
314308
return &EndpointMeta{
315-
Pattern: "POST /api/post-endpoint",
309+
Pattern: "POST /api/post-endpoint/{id}",
316310
StatusCode: http.StatusCreated,
317311
}
318312
}
319313

320314
type postRequest struct {
315+
ID string `json:"-" validate:"-"`
321316
MakeAPIError bool `json:"make_api_error" validate:"-"`
322317
MakeInternalError bool `json:"make_internal_error" validate:"-"`
323318
MakePostgresError bool `json:"make_postgres_error" validate:"-"`
324319
Message string `json:"message" validate:"required"`
320+
RawPayload []byte `json:"-" validate:"-"`
321+
}
322+
323+
func (req *postRequest) ExtractRaw(r *http.Request) error {
324+
var err error
325+
if req.RawPayload, err = io.ReadAll(r.Body); err != nil {
326+
return err
327+
}
328+
329+
req.ID = r.PathValue("id")
330+
return nil
325331
}
326332

327333
type postResponse struct {
328-
Message string `json:"message"`
334+
ID string `json:"id"`
335+
Message string `json:"message"`
336+
RawPayload json.RawMessage `json:"raw_payload"`
329337
}
330338

331339
func (a *postEndpoint) Execute(ctx context.Context, req *postRequest) (*postResponse, error) {
@@ -346,5 +354,5 @@ func (a *postEndpoint) Execute(ctx context.Context, req *postRequest) (*postResp
346354
return nil, fmt.Errorf("error running Postgres query: %w", &pgconn.PgError{Code: pgerrcode.InsufficientPrivilege})
347355
}
348356

349-
return &postResponse{Message: req.Message}, nil
357+
return &postResponse{ID: req.ID, Message: req.Message, RawPayload: req.RawPayload}, nil
350358
}

0 commit comments

Comments
 (0)