From 4fe21304aadf240f1604603cc74d44e69a6e06de Mon Sep 17 00:00:00 2001 From: Tim Geoghegan Date: Thu, 9 May 2024 13:33:06 -0700 Subject: [PATCH] Configurable target URI rewrites Add a `TARGET_REWRITES` environment variable that allows rewriting the URI in an encapsulated request. See changes to `README.md` for discussion of motivation and utilization. --- README.md | 36 ++++++ gateway_test.go | 330 ++++++++++++++++++++++++++++++++++++++++++++---- handler.go | 36 +++++- main.go | 12 +- 4 files changed, 384 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index 57604ad..c7a318d 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,42 @@ The behavior of the gateway is configurable via a number of environment variable - ALLOWED_TARGET_ORIGINS: This environment variable contains a comma-separated list of target origin names that the gateway is allowed to access. When configured, the gateway will only attempt to resolve requests to target origins in this list. Any other request will yield a HTTP 403 Forbidden return code. - CERT: This environment variable is the name of a file containing the certificate (chain) used to serve TLS connections. - KEY: This environment variable is the name of a file containing the private key used to serve TLS connections. +- TARGET_REWRITES: This environment variable contains a JSON document instructing the gateway to rewrite the target URL found in an encapsulated request to some specified scheme and host. + +### Target URL rewrites + +The `TARGET_REWRITES` configuration option is useful to set up forwarding between the gateway and target when both are on a private network or if they share a loopback interface. For example, suppose that the target is exposed to the internet at `https://example.org`, but also reachable by the gateway at `http://localhost:8080` (note `http` and not `https`). It's more efficient to redirect traffic over `localhost` than back out over the internet, so you could set `TARGET_REWRITES` to: + +```json +{ + "example.org": { "Scheme": "http", "Host": "localhost:8080" } +} +``` + +Then the encapsulated HTTP requests + +```http +POST /some-cool-api HTTP/1.1 +Host: example.org + +some content +``` + +or + +```http +POST https://example.org/some-cool-api HTTP/1.1 + +some content +``` + +...would both be rewritten to: + +```http +POST http://localhost:8080/some-cool-api HTTP/1.1 + +some content +``` ## Custom Application Payloads {#custom-config} diff --git a/gateway_test.go b/gateway_test.go index 082f233..272f4aa 100644 --- a/gateway_test.go +++ b/gateway_test.go @@ -5,11 +5,12 @@ package main import ( "bytes" + "encoding/json" "fmt" "io" - "io/ioutil" "net/http" "net/http/httptest" + "net/url" "strconv" "strings" "testing" @@ -20,11 +21,16 @@ import ( ) var ( - LEGACY_KEY_ID = uint8(0x00) - CURRENT_KEY_ID = uint8(LEGACY_KEY_ID + 1) - FORBIDDEN_TARGET = "forbidden.example" - ALLOWED_TARGET = "allowed.example" - GATEWAY_DEBUG = true + LEGACY_KEY_ID = uint8(0x00) + CURRENT_KEY_ID = uint8(LEGACY_KEY_ID + 1) + FORBIDDEN_TARGET = "forbidden.example" + ALLOWED_TARGET = "allowed.example" + GATEWAY_DEBUG = true + BINARY_HTTP_GATEWAY_ENDPOINT = "/binary-http-gateway" + TARGET_REWRITES = `{ + "original-1.example": { "Scheme": "http", "Host": "localhost:8888" }, + "original-2.example": { "Scheme": "https", "Host": "localhost:9999"} +}` ) func createGateway(t *testing.T) ohttp.Gateway { @@ -71,27 +77,21 @@ func (f *MockMetricsFactory) Create(eventName string) Metrics { return metrics } -type ForbiddenCheckHttpRequestHandler struct { - forbidden string -} - func mustGetMetricsFactory(t *testing.T, gateway gatewayResource) *MockMetricsFactory { factory, ok := gateway.metricsFactory.(*MockMetricsFactory) if !ok { - panic("Failed to get metrics factory") + t.Fatal("Failed to get metrics factory") } return factory } -func (h ForbiddenCheckHttpRequestHandler) Handle(req *http.Request, metrics Metrics) (*http.Response, error) { - if req.Host == h.forbidden { - metrics.Fire(metricsResultTargetRequestForbidden) - return nil, GatewayTargetForbiddenError - } +type MockHTTPRequestHandler struct{} - metrics.Fire(metricsResultSuccess) +func (d MockHTTPRequestHandler) Handle(req *http.Request, metrics Metrics) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, + // Echo the URL back so tests can examine the scheme and host + Body: io.NopCloser(strings.NewReader(req.URL.String())), }, nil } @@ -101,11 +101,33 @@ func createMockEchoGatewayServer(t *testing.T) gatewayResource { gateway: gateway, appHandler: EchoAppHandler{}, } + + allowedOrigins := make(map[string]bool) + allowedOrigins[ALLOWED_TARGET] = true mockProtoHTTPFilterHandler := DefaultEncapsulationHandler{ gateway: gateway, appHandler: ProtoHTTPAppHandler{ - httpHandler: ForbiddenCheckHttpRequestHandler{ - FORBIDDEN_TARGET, + httpHandler: FilteredHttpRequestHandler{ + client: MockHTTPRequestHandler{}, + allowedOrigins: allowedOrigins, + logForbiddenErrors: false, + targetRewrites: nil, + }, + }, + } + + var targetRewrites map[string]TargetRewrite + if err := json.Unmarshal([]byte(TARGET_REWRITES), &targetRewrites); err != nil { + t.Fatal("failed to unmarshal JSON target rewrites") + } + mockBinaryHTTPFilterHandler := DefaultEncapsulationHandler{ + gateway: gateway, + appHandler: BinaryHTTPAppHandler{ + httpHandler: FilteredHttpRequestHandler{ + client: MockHTTPRequestHandler{}, + allowedOrigins: nil, + logForbiddenErrors: false, + targetRewrites: targetRewrites, }, }, } @@ -113,6 +135,7 @@ func createMockEchoGatewayServer(t *testing.T) gatewayResource { encapHandlers := make(map[string]EncapsulationHandler) encapHandlers[defaultEchoEndpoint] = echoEncapHandler encapHandlers[defaultGatewayEndpoint] = mockProtoHTTPFilterHandler + encapHandlers[BINARY_HTTP_GATEWAY_ENDPOINT] = mockBinaryHTTPFilterHandler return gatewayResource{ gateway: gateway, encapsulationHandlers: encapHandlers, @@ -143,7 +166,7 @@ func TestLegacyConfigHandler(t *testing.T) { t.Fatal(fmt.Errorf("Failed request with error code: %d", status)) } - body, err := ioutil.ReadAll(rr.Result().Body) + body, err := io.ReadAll(rr.Result().Body) if err != nil { t.Fatal("Failed to read body:", err) } @@ -187,7 +210,7 @@ func TestConfigHandler(t *testing.T) { t.Fatal(fmt.Errorf("Failed request with error code: %d", status)) } - body, err := ioutil.ReadAll(rr.Result().Body) + body, err := io.ReadAll(rr.Result().Body) if err != nil { t.Fatal("Failed to read body:", err) } @@ -515,7 +538,7 @@ func TestGatewayHandlerProtoHTTPRequestWithForbiddenTarget(t *testing.T) { t.Fatal(fmt.Errorf("Result did not yield %d, got %d instead", http.StatusOK, status)) } - bodyBytes, err := ioutil.ReadAll(rr.Body) + bodyBytes, err := io.ReadAll(rr.Body) if err != nil { t.Fatal(err) } @@ -591,7 +614,7 @@ func TestGatewayHandlerProtoHTTPRequestWithAllowedTarget(t *testing.T) { t.Fatal(fmt.Errorf("Result did not yield %d, got %d instead", http.StatusOK, status)) } - bodyBytes, err := ioutil.ReadAll(rr.Body) + bodyBytes, err := io.ReadAll(rr.Body) if err != nil { t.Fatal(err) } @@ -617,3 +640,264 @@ func TestGatewayHandlerProtoHTTPRequestWithAllowedTarget(t *testing.T) { testMetricsContainsResult(t, mustGetMetricsFactory(t, target), metricsEventGatewayRequest, metricsResultSuccess) } + +func TestGatewayHandlerBinaryHTTPWithTargetRewrite(t *testing.T) { + target := createMockEchoGatewayServer(t) + + handler := http.HandlerFunc(target.gatewayHandler) + + config, err := target.gateway.Config(CURRENT_KEY_ID) + if err != nil { + t.Fatal(err) + } + client := ohttp.NewDefaultClient(config) + + httpRequest, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://%s%s", "original-1.example", BINARY_HTTP_GATEWAY_ENDPOINT), nil) + if err != nil { + t.Fatal(err) + } + + binaryRequest := ohttp.BinaryRequest(*httpRequest) + + encodedRequest, err := binaryRequest.Marshal() + if err != nil { + t.Fatal(err) + } + req, context, err := client.EncapsulateRequest(encodedRequest) + if err != nil { + t.Fatal(err) + } + + reqEnc := req.Marshal() + + request, err := http.NewRequest(http.MethodPost, BINARY_HTTP_GATEWAY_ENDPOINT, bytes.NewReader(reqEnc)) + if err != nil { + t.Fatal(err) + } + request.Header.Add("Content-Type", "message/ohttp-req") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, request) + + if status := rr.Result().StatusCode; status != http.StatusOK { + t.Fatal(fmt.Errorf("Result did not yield %d, got %d instead", http.StatusOK, status)) + } + + bodyBytes, err := io.ReadAll(rr.Body) + if err != nil { + t.Fatal(err) + } + + encapResp, err := ohttp.UnmarshalEncapsulatedResponse(bodyBytes) + if err != nil { + t.Fatal(err) + } + + binaryResp, err := context.DecapsulateResponse(encapResp) + if err != nil { + t.Fatal(err) + } + + resp, err := ohttp.UnmarshalBinaryResponse(binaryResp) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != http.StatusOK { + t.Fatal(fmt.Errorf("Encapsulated result did not yield %d, got %d instead", http.StatusForbidden, resp.StatusCode)) + } + + encapsulatedRespBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + rewrittenURL, err := url.Parse(string(encapsulatedRespBody)) + if err != nil { + t.Fatal(err) + } + + if rewrittenURL.Scheme != "http" { + t.Fatalf("rewritten request URL does not have scheme http: %s", rewrittenURL) + } + + if rewrittenURL.Host != "localhost:8888" { + t.Fatalf("rewritten request URL does not have expected host: %s", rewrittenURL.Host) + } + + testMetricsContainsResult(t, mustGetMetricsFactory(t, target), metricsEventGatewayRequest, metricsResultSuccess) +} + +func TestGatewayHandlerBinaryHTTPWithTargetRewriteChangingScheme(t *testing.T) { + target := createMockEchoGatewayServer(t) + + handler := http.HandlerFunc(target.gatewayHandler) + + config, err := target.gateway.Config(CURRENT_KEY_ID) + if err != nil { + t.Fatal(err) + } + client := ohttp.NewDefaultClient(config) + + httpRequest, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://%s%s", "original-2.example", BINARY_HTTP_GATEWAY_ENDPOINT), nil) + if err != nil { + t.Fatal(err) + } + + binaryRequest := ohttp.BinaryRequest(*httpRequest) + + encodedRequest, err := binaryRequest.Marshal() + if err != nil { + t.Fatal(err) + } + req, context, err := client.EncapsulateRequest(encodedRequest) + if err != nil { + t.Fatal(err) + } + + reqEnc := req.Marshal() + + request, err := http.NewRequest(http.MethodPost, BINARY_HTTP_GATEWAY_ENDPOINT, bytes.NewReader(reqEnc)) + if err != nil { + t.Fatal(err) + } + request.Header.Add("Content-Type", "message/ohttp-req") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, request) + + if status := rr.Result().StatusCode; status != http.StatusOK { + t.Fatal(fmt.Errorf("Result did not yield %d, got %d instead", http.StatusOK, status)) + } + + bodyBytes, err := io.ReadAll(rr.Body) + if err != nil { + t.Fatal(err) + } + + encapResp, err := ohttp.UnmarshalEncapsulatedResponse(bodyBytes) + if err != nil { + t.Fatal(err) + } + + binaryResp, err := context.DecapsulateResponse(encapResp) + if err != nil { + t.Fatal(err) + } + + resp, err := ohttp.UnmarshalBinaryResponse(binaryResp) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != http.StatusOK { + t.Fatal(fmt.Errorf("Encapsulated result did not yield %d, got %d instead", http.StatusForbidden, resp.StatusCode)) + } + + encapsulatedRespBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + rewrittenURL, err := url.Parse(string(encapsulatedRespBody)) + if err != nil { + t.Fatal(err) + } + + if rewrittenURL.Scheme != "https" { + t.Fatalf("rewritten request URL does not have scheme https: %s", rewrittenURL) + } + + if rewrittenURL.Host != "localhost:9999" { + t.Fatalf("rewritten request URL does not have expected host: %s", rewrittenURL.Host) + } + + testMetricsContainsResult(t, mustGetMetricsFactory(t, target), metricsEventGatewayRequest, metricsResultSuccess) +} + +func TestGatewayHandlerBinaryHTTPWithTargetRewriteNoRewrite(t *testing.T) { + target := createMockEchoGatewayServer(t) + + handler := http.HandlerFunc(target.gatewayHandler) + + config, err := target.gateway.Config(CURRENT_KEY_ID) + if err != nil { + t.Fatal(err) + } + client := ohttp.NewDefaultClient(config) + + httpRequest, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://%s%s", "original-3.example", BINARY_HTTP_GATEWAY_ENDPOINT), nil) + if err != nil { + t.Fatal(err) + } + + binaryRequest := ohttp.BinaryRequest(*httpRequest) + + encodedRequest, err := binaryRequest.Marshal() + if err != nil { + t.Fatal(err) + } + req, context, err := client.EncapsulateRequest(encodedRequest) + if err != nil { + t.Fatal(err) + } + + reqEnc := req.Marshal() + + request, err := http.NewRequest(http.MethodPost, BINARY_HTTP_GATEWAY_ENDPOINT, bytes.NewReader(reqEnc)) + if err != nil { + t.Fatal(err) + } + request.Header.Add("Content-Type", "message/ohttp-req") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, request) + + if status := rr.Result().StatusCode; status != http.StatusOK { + t.Fatal(fmt.Errorf("Result did not yield %d, got %d instead", http.StatusOK, status)) + } + + bodyBytes, err := io.ReadAll(rr.Body) + if err != nil { + t.Fatal(err) + } + + encapResp, err := ohttp.UnmarshalEncapsulatedResponse(bodyBytes) + if err != nil { + t.Fatal(err) + } + + binaryResp, err := context.DecapsulateResponse(encapResp) + if err != nil { + t.Fatal(err) + } + + resp, err := ohttp.UnmarshalBinaryResponse(binaryResp) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != http.StatusOK { + t.Fatal(fmt.Errorf("Encapsulated result did not yield %d, got %d instead", http.StatusForbidden, resp.StatusCode)) + } + + encapsulatedRespBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + rewrittenURL, err := url.Parse(string(encapsulatedRespBody)) + if err != nil { + t.Fatal(err) + } + + if rewrittenURL.Scheme != "http" { + t.Fatalf("rewritten request URL does not have scheme http: %s", rewrittenURL) + } + + if rewrittenURL.Host != "original-3.example" { + t.Fatalf("rewritten request URL does not have expected host: %s", rewrittenURL.Host) + } + + testMetricsContainsResult(t, mustGetMetricsFactory(t, target), metricsEventGatewayRequest, metricsResultSuccess) +} diff --git a/handler.go b/handler.go index c2a9d33..bc0f040 100644 --- a/handler.go +++ b/handler.go @@ -180,7 +180,7 @@ func (h EchoAppHandler) Handle(binaryRequest []byte, metrics Metrics) ([]byte, e // ProtoHTTPAppHandler is an AppContentHandler that parses the application request as // a protobuf-based HTTP request for resolution with an HttpRequestHandler. type ProtoHTTPAppHandler struct { - httpHandler HttpRequestHandler + httpHandler HTTPRequestHandler } // returns the same object format as for PayloadSuccess moving error inside successful response @@ -244,7 +244,7 @@ func (h ProtoHTTPAppHandler) Handle(binaryRequest []byte, metrics Metrics) ([]by // BinaryHTTPAppHandler is an AppContentHandler that parses the application request as // a binary HTTP request for resolution with an HttpRequestHandler. type BinaryHTTPAppHandler struct { - httpHandler HttpRequestHandler + httpHandler HTTPRequestHandler } func (h BinaryHTTPAppHandler) wrappedError(e error, metrics Metrics) ([]byte, error) { @@ -291,17 +291,34 @@ func (h BinaryHTTPAppHandler) Handle(binaryRequest []byte, metrics Metrics) ([]b return binaryRespEnc, r } -// HttpRequestHandler handles HTTP requests to produce responses. -type HttpRequestHandler interface { +// TargetRewrite represents a rewritten target request. +type TargetRewrite struct { + Scheme string + Host string +} + +// HTTPRequestHandler handles HTTP requests to produce responses. +type HTTPRequestHandler interface { // Handle takes a http.Request and resolves it to produce a http.Response. Handle(req *http.Request, metrics Metrics) (*http.Response, error) } +// HTTPClientRequestHandler represents a HttpRequestHandler that handles requests by sending them +// with an http.Client. +type HTTPClientRequestHandler struct { + client *http.Client +} + +func (h HTTPClientRequestHandler) Handle(req *http.Request, metrics Metrics) (*http.Response, error) { + return h.client.Do(req) +} + // FilteredHttpRequestHandler represents a HttpRequestHandler that restricts // outbound HTTP requests to an allowed set of targets. type FilteredHttpRequestHandler struct { - client *http.Client + client HTTPRequestHandler allowedOrigins map[string]bool + targetRewrites map[string]TargetRewrite logForbiddenErrors bool } @@ -320,7 +337,14 @@ func (h FilteredHttpRequestHandler) Handle(req *http.Request, metrics Metrics) ( } } - resp, err := h.client.Do(req) + if h.targetRewrites != nil { + if newTarget, ok := h.targetRewrites[req.URL.Host]; ok { + req.URL.Scheme = newTarget.Scheme + req.URL.Host = newTarget.Host + } + } + + resp, err := h.client.Handle(req, metrics) if err != nil { metrics.Fire(metricsResultTargetRequestFailed) return nil, err diff --git a/main.go b/main.go index a661f7a..76f8226 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "bytes" "crypto/rand" "encoding/hex" + "encoding/json" "flag" "fmt" "io" @@ -59,6 +60,7 @@ const ( gatewayDebugEnvironmentVariable = "GATEWAY_DEBUG" gatewayVerboseEnvironmentVariable = "VERBOSE" logSecretsEnvironmentVariable = "LOG_SECRETS" + targetRewritesVariables = "TARGET_REWRITES" ) var versionFlag = flag.Bool("version", false, "print name and version to stdout") @@ -174,6 +176,13 @@ func main() { } } + var targetRewrites map[string]TargetRewrite + if targetRewritesJson := os.Getenv(targetRewritesVariables); targetRewritesJson != "" { + if err := json.Unmarshal([]byte(targetRewritesJson), &targetRewrites); err != nil { + log.Fatalf("Failed to parse target rewrites: %s", err) + } + } + var certFile string if certFile = os.Getenv(certificateEnvironmentVariable); certFile == "" { certFile = "cert.pem" @@ -209,9 +218,10 @@ func main() { // Create the default HTTP handler httpHandler := FilteredHttpRequestHandler{ - client: &http.Client{}, + client: HTTPClientRequestHandler{client: &http.Client{}}, allowedOrigins: allowedOrigins, logForbiddenErrors: verbose, + targetRewrites: targetRewrites, } // Create the default gateway and its request handler chain