From 3dc943fcb4e9ef489d0be910d0d7d7c375c579ad Mon Sep 17 00:00:00 2001 From: Macy Date: Sat, 30 Mar 2024 14:51:23 -0700 Subject: [PATCH 1/6] preferred header takes precedence over lowest 2xx --- mock/mock_engine.go | 45 +++++++++++++++++++-- mock/mock_engine_test.go | 86 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 125 insertions(+), 6 deletions(-) diff --git a/mock/mock_engine.go b/mock/mock_engine.go index 226cd40..b389476 100644 --- a/mock/mock_engine.go +++ b/mock/mock_engine.go @@ -297,11 +297,23 @@ func (rme *ResponseMockEngine) runWorkflow(request *http.Request) ([]byte, int, } - // get the lowest success code + preferred := rme.extractPreferred(request) lo := rme.findLowestSuccessCode(operation) - // find the lowest success code. - mt, noMT := rme.lookForResponseCodes(operation, request, []string{lo}) + var mt *v3.MediaType + var noMT bool = true + + if preferred != "" { + // If an explicit preferred header is present, let it have a chance to take precedence + // This can lead to a preferred header leading to a 3xx, 4xx, or 5xx example response. + mt, lo, noMT = rme.findMediaTypeContainingNamedExample(operation, request, preferred) + } + + if (noMT) { + // find the lowest success code. + mt, noMT = rme.lookForResponseCodes(operation, request, []string{lo}) + } + if mt == nil && noMT { mtString := rme.extractMediaTypeHeader(request) return rme.buildError( @@ -326,6 +338,33 @@ func (rme *ResponseMockEngine) runWorkflow(request *http.Request) ([]byte, int, return mock, c, nil } +func (rme *ResponseMockEngine) findMediaTypeContainingNamedExample( + operation *v3.Operation, + request *http.Request, + preferredExample string) (*v3.MediaType, string, bool) { + + mediaTypeString := rme.extractMediaTypeHeader(request) + + for codePairs := operation.Responses.Codes.First(); codePairs != nil; codePairs = codePairs.Next() { + resp := codePairs.Value() + + if resp.Content != nil { + responseBody := resp.Content.GetOrZero(mediaTypeString) + if responseBody == nil { + responseBody = resp.Content.GetOrZero("application/json") + } + + _, present := responseBody.Examples.Get(preferredExample) + + if present { + return responseBody, codePairs.Key(), false + } + } + } + + return nil, "", true +} + func (rme *ResponseMockEngine) findLowestSuccessCode(operation *v3.Operation) string { var lowestCode = 299 diff --git a/mock/mock_engine_test.go b/mock/mock_engine_test.go index 7641f34..29671f1 100644 --- a/mock/mock_engine_test.go +++ b/mock/mock_engine_test.go @@ -6,13 +6,14 @@ package mock import ( "bytes" "encoding/json" + "io" + "net/http" + "testing" + "github.com/pb33f/libopenapi" "github.com/pb33f/libopenapi-validator/helpers" v3 "github.com/pb33f/libopenapi/datamodel/high/v3" "github.com/stretchr/testify/assert" - "io" - "net/http" - "testing" ) // var doc *v3.Document @@ -843,6 +844,85 @@ components: } +// https://github.com/pb33f/wiretap/issues/84 +func TestNewMockEngine_UseExamples_Preferred_From_400(t *testing.T) { + + spec := `openapi: 3.1.0 +paths: + /test: + get: + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/Thing' + examples: + happyDays: + value: + name: happy days + description: a terrible show from a time that never existed. + robocop: + value: + name: robocop + description: perhaps the best cyberpunk movie ever made. + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorThing' + examples: + sadErrorDays: + value: + name: sad error days + description: a sad error prone show + sadcop: + value: + name: sad cop + description: perhaps the saddest cyberpunk movie ever made. +components: + schemas: + Thing: + type: object + properties: + name: + type: string + example: nameExample + description: + type: string + example: descriptionExample + ErrorThing: + type: object + properties: + name: + type: string + example: errorNameExample + description: + type: string + example: errorDescriptionExample +` + + d, _ := libopenapi.NewDocument([]byte(spec)) + doc, _ := d.BuildV3Model() + + me := NewMockEngine(&doc.Model, false) + + request, _ := http.NewRequest(http.MethodGet, "https://api.pb33f.io/test", nil) + request.Header.Set(helpers.Preferred, "sadcop") + + b, status, err := me.GenerateResponse(request) + + assert.NoError(t, err) + assert.Equal(t, 400, status) + + var decoded map[string]any + _ = json.Unmarshal(b, &decoded) + + assert.Equal(t, "sad cop", decoded["name"]) + assert.Equal(t, "perhaps the saddest cyberpunk movie ever made.", decoded["description"]) + +} + // https://github.com/pb33f/wiretap/issues/84 func TestNewMockEngine_UseExamples_FromSchema(t *testing.T) { From e66c5032e2ac8a4493379e2d056802732bca72d5 Mon Sep 17 00:00:00 2001 From: Macy Date: Sat, 30 Mar 2024 14:57:17 -0700 Subject: [PATCH 2/6] clean --- mock/mock_engine.go | 3 +- mock/mock_engine_test.go | 101 +++++++++++++++++++-------------------- 2 files changed, 51 insertions(+), 53 deletions(-) diff --git a/mock/mock_engine.go b/mock/mock_engine.go index b389476..d0e1fb8 100644 --- a/mock/mock_engine.go +++ b/mock/mock_engine.go @@ -305,7 +305,8 @@ func (rme *ResponseMockEngine) runWorkflow(request *http.Request) ([]byte, int, if preferred != "" { // If an explicit preferred header is present, let it have a chance to take precedence - // This can lead to a preferred header leading to a 3xx, 4xx, or 5xx example response. + // This allows a developer to cause a 3xx, 4xx, or 5xx mocked response by passing + // the appropriate example header value. mt, lo, noMT = rme.findMediaTypeContainingNamedExample(operation, request, preferred) } diff --git a/mock/mock_engine_test.go b/mock/mock_engine_test.go index 29671f1..3cb59df 100644 --- a/mock/mock_engine_test.go +++ b/mock/mock_engine_test.go @@ -6,14 +6,13 @@ package mock import ( "bytes" "encoding/json" - "io" - "net/http" - "testing" - "github.com/pb33f/libopenapi" "github.com/pb33f/libopenapi-validator/helpers" v3 "github.com/pb33f/libopenapi/datamodel/high/v3" "github.com/stretchr/testify/assert" + "io" + "net/http" + "testing" ) // var doc *v3.Document @@ -845,7 +844,7 @@ components: } // https://github.com/pb33f/wiretap/issues/84 -func TestNewMockEngine_UseExamples_Preferred_From_400(t *testing.T) { +func TestNewMockEngine_UseExamples_FromSchema(t *testing.T) { spec := `openapi: 3.1.0 paths: @@ -857,29 +856,6 @@ paths: application/json: schema: $ref: '#/components/schemas/Thing' - examples: - happyDays: - value: - name: happy days - description: a terrible show from a time that never existed. - robocop: - value: - name: robocop - description: perhaps the best cyberpunk movie ever made. - '400': - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorThing' - examples: - sadErrorDays: - value: - name: sad error days - description: a sad error prone show - sadcop: - value: - name: sad cop - description: perhaps the saddest cyberpunk movie ever made. components: schemas: Thing: @@ -891,15 +867,6 @@ components: description: type: string example: descriptionExample - ErrorThing: - type: object - properties: - name: - type: string - example: errorNameExample - description: - type: string - example: errorDescriptionExample ` d, _ := libopenapi.NewDocument([]byte(spec)) @@ -908,23 +875,22 @@ components: me := NewMockEngine(&doc.Model, false) request, _ := http.NewRequest(http.MethodGet, "https://api.pb33f.io/test", nil) - request.Header.Set(helpers.Preferred, "sadcop") b, status, err := me.GenerateResponse(request) assert.NoError(t, err) - assert.Equal(t, 400, status) + assert.Equal(t, 200, status) var decoded map[string]any _ = json.Unmarshal(b, &decoded) - assert.Equal(t, "sad cop", decoded["name"]) - assert.Equal(t, "perhaps the saddest cyberpunk movie ever made.", decoded["description"]) + assert.Equal(t, "nameExample", decoded["name"]) + assert.Equal(t, "descriptionExample", decoded["description"]) } // https://github.com/pb33f/wiretap/issues/84 -func TestNewMockEngine_UseExamples_FromSchema(t *testing.T) { +func TestNewMockEngine_UseExamples_FromSchema_Generated(t *testing.T) { spec := `openapi: 3.1.0 paths: @@ -943,10 +909,8 @@ components: properties: name: type: string - example: nameExample description: type: string - example: descriptionExample ` d, _ := libopenapi.NewDocument([]byte(spec)) @@ -964,13 +928,12 @@ components: var decoded map[string]any _ = json.Unmarshal(b, &decoded) - assert.Equal(t, "nameExample", decoded["name"]) - assert.Equal(t, "descriptionExample", decoded["description"]) + assert.NotEmpty(t, decoded["name"]) + assert.NotEmpty(t, decoded["description"]) } -// https://github.com/pb33f/wiretap/issues/84 -func TestNewMockEngine_UseExamples_FromSchema_Generated(t *testing.T) { +func TestNewMockEngine_UseExamples_Preferred_From_400(t *testing.T) { spec := `openapi: 3.1.0 paths: @@ -982,6 +945,29 @@ paths: application/json: schema: $ref: '#/components/schemas/Thing' + examples: + happyDays: + value: + name: happy days + description: a terrible show from a time that never existed. + robocop: + value: + name: robocop + description: perhaps the best cyberpunk movie ever made. + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorThing' + examples: + sadErrorDays: + value: + name: sad error days + description: a sad error prone show + sadcop: + value: + name: sad cop + description: perhaps the saddest cyberpunk movie ever made. components: schemas: Thing: @@ -989,8 +975,19 @@ components: properties: name: type: string + example: nameExample description: type: string + example: descriptionExample + ErrorThing: + type: object + properties: + name: + type: string + example: errorNameExample + description: + type: string + example: errorDescriptionExample ` d, _ := libopenapi.NewDocument([]byte(spec)) @@ -999,16 +996,16 @@ components: me := NewMockEngine(&doc.Model, false) request, _ := http.NewRequest(http.MethodGet, "https://api.pb33f.io/test", nil) + request.Header.Set(helpers.Preferred, "sadcop") b, status, err := me.GenerateResponse(request) assert.NoError(t, err) - assert.Equal(t, 200, status) + assert.Equal(t, 400, status) var decoded map[string]any _ = json.Unmarshal(b, &decoded) - assert.NotEmpty(t, decoded["name"]) - assert.NotEmpty(t, decoded["description"]) - + assert.Equal(t, "sad cop", decoded["name"]) + assert.Equal(t, "perhaps the saddest cyberpunk movie ever made.", decoded["description"]) } From f1f8680b9b77079f95f6eb271ede87647224ba23 Mon Sep 17 00:00:00 2001 From: Macy Date: Mon, 1 Apr 2024 09:18:45 -0700 Subject: [PATCH 3/6] Fix failing test. Add another test to catch possible panic. --- mock/mock_engine.go | 24 ++++++--- mock/mock_engine_test.go | 103 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 8 deletions(-) diff --git a/mock/mock_engine.go b/mock/mock_engine.go index d0e1fb8..7c9c132 100644 --- a/mock/mock_engine.go +++ b/mock/mock_engine.go @@ -248,7 +248,7 @@ func (rme *ResponseMockEngine) runWorkflow(request *http.Request) ([]byte, int, // check the request is valid against security requirements. err = rme.ValidateSecurity(request, operation) if err != nil { - mt, _ := rme.lookForResponseCodes(operation, request, []string{"401"}) + mt, _ := rme.findBestMediaTypeMatch(operation, request, []string{"401"}) if mt != nil { mock, mockErr := rme.mockEngine.GenerateMock(mt, rme.extractPreferred(request)) if mockErr != nil { @@ -275,7 +275,7 @@ func (rme *ResponseMockEngine) runWorkflow(request *http.Request) ([]byte, int, // validate the request against the document. _, validationErrors := rme.validator.ValidateHttpRequest(request) if len(validationErrors) > 0 { - mt, _ := rme.lookForResponseCodes(operation, request, []string{"422", "400"}) + mt, _ := rme.findBestMediaTypeMatch(operation, request, []string{"422", "400"}) if mt == nil { // no default, no valid response, inform use with a 500 return rme.buildErrorWithPayload( @@ -298,8 +298,8 @@ func (rme *ResponseMockEngine) runWorkflow(request *http.Request) ([]byte, int, } preferred := rme.extractPreferred(request) - lo := rme.findLowestSuccessCode(operation) + var lo string var mt *v3.MediaType var noMT bool = true @@ -311,8 +311,9 @@ func (rme *ResponseMockEngine) runWorkflow(request *http.Request) ([]byte, int, } if (noMT) { - // find the lowest success code. - mt, noMT = rme.lookForResponseCodes(operation, request, []string{lo}) + // When no preferred header is passed, or preferred header did not match a named example + lo = rme.findLowestSuccessCode(operation) + mt, noMT = rme.findBestMediaTypeMatch(operation, request, []string{lo}) } if mt == nil && noMT { @@ -325,7 +326,7 @@ func (rme *ResponseMockEngine) runWorkflow(request *http.Request) ([]byte, int, ), 415, nil } - mock, mockErr := rme.mockEngine.GenerateMock(mt, rme.extractPreferred(request)) + mock, mockErr := rme.mockEngine.GenerateMock(mt, preferred) if mockErr != nil { return rme.buildError( 422, @@ -355,6 +356,10 @@ func (rme *ResponseMockEngine) findMediaTypeContainingNamedExample( responseBody = resp.Content.GetOrZero("application/json") } + if responseBody == nil { + continue; + } + _, present := responseBody.Examples.Get(preferredExample) if present { @@ -381,14 +386,15 @@ func (rme *ResponseMockEngine) findLowestSuccessCode(operation *v3.Operation) st return fmt.Sprintf("%d", lowestCode) } -func (rme *ResponseMockEngine) lookForResponseCodes( +func (rme *ResponseMockEngine) findBestMediaTypeMatch( op *v3.Operation, request *http.Request, resultCodes []string) (*v3.MediaType, bool) { mediaTypeString := rme.extractMediaTypeHeader(request) - // check if the media type exists in the response. + // Try to find a matching media type in responses matching + // parameterized result codes for _, code := range resultCodes { resp := op.Responses.Codes.GetOrZero(code) @@ -410,6 +416,8 @@ func (rme *ResponseMockEngine) lookForResponseCodes( } } + // As a last resort, check if a default response is specified and attempt + // to use that if op.Responses.Default != nil && op.Responses.Default.Content != nil { if op.Responses.Default.Content.GetOrZero(mediaTypeString) != nil { return op.Responses.Default.Content.GetOrZero(mediaTypeString), false diff --git a/mock/mock_engine_test.go b/mock/mock_engine_test.go index 3cb59df..45df3e7 100644 --- a/mock/mock_engine_test.go +++ b/mock/mock_engine_test.go @@ -1009,3 +1009,106 @@ components: assert.Equal(t, "sad cop", decoded["name"]) assert.Equal(t, "perhaps the saddest cyberpunk movie ever made.", decoded["description"]) } + +func TestNewMockEngine_UseExamples_Preferred_First_200_Has_Hideous_Media_Type(t *testing.T) { +// A little far-fetched for an API to behave this way, +// where lowest 2xx response is html and second is json, +// including the test case to catch a panic case + spec := `openapi: 3.1.0 +paths: + /test: + get: + responses: + '200': + content: + text/html: + schema: + $ref: '#/components/schemas/HtmlThing' + examples: + happyHtmlDays: + value:

Happy Days + robocopInHtml: + value:

Robo cop + '202': + content: + application/json: + schema: + $ref: '#/components/schemas/Thing' + examples: + happyDays: + value: + name: happy days + description: a terrible show from a time that never existed. + robocop: + value: + name: robocop + description: perhaps the best cyberpunk movie ever made. + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorThing' + examples: + sadErrorDays: + value: + name: sad error days + description: a sad error prone show + sadcop: + value: + name: sad cop + description: perhaps the saddest cyberpunk movie ever made. +components: + schemas: + Thing: + type: object + properties: + name: + type: string + example: nameExample + description: + type: string + example: descriptionExample + HtmlThing: + type: string + ErrorThing: + type: object + properties: + name: + type: string + example: errorNameExample + description: + type: string + example: errorDescriptionExample +` + + d, _ := libopenapi.NewDocument([]byte(spec)) + doc, _ := d.BuildV3Model() + + me := NewMockEngine(&doc.Model, false) + + // Check that we don't panic if first 2xx does not match media type + request, _ := http.NewRequest(http.MethodGet, "https://api.pb33f.io/test", nil) + request.Header.Set(helpers.Preferred, "robocop") + + b, status, err := me.GenerateResponse(request) + + assert.NoError(t, err) + assert.Equal(t, 202, status) + + var decoded map[string]any + _ = json.Unmarshal(b, &decoded) + + assert.Equal(t, "robocop", decoded["name"]) + assert.Equal(t, "perhaps the best cyberpunk movie ever made.", decoded["description"]) + + // Now see if html will work + request, _ = http.NewRequest(http.MethodGet, "https://api.pb33f.io/test", nil) + request.Header.Set(helpers.Preferred, "happyHtmlDays") + request.Header.Set("Content-Type", "text/html") + + b, status, err = me.GenerateResponse(request) + + assert.NoError(t, err) + assert.Equal(t, 200, status) + assert.Equal(t, "

Happy Days", string(b[:])) +} \ No newline at end of file From a7d6a8c2c08878570161ca17676bb3dc89bb9ebb Mon Sep 17 00:00:00 2001 From: Macy Date: Mon, 1 Apr 2024 09:23:12 -0700 Subject: [PATCH 4/6] rename test --- mock/mock_engine_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mock/mock_engine_test.go b/mock/mock_engine_test.go index 45df3e7..41bf3f9 100644 --- a/mock/mock_engine_test.go +++ b/mock/mock_engine_test.go @@ -1010,10 +1010,10 @@ components: assert.Equal(t, "perhaps the saddest cyberpunk movie ever made.", decoded["description"]) } -func TestNewMockEngine_UseExamples_Preferred_First_200_Has_Hideous_Media_Type(t *testing.T) { +func TestNewMockEngine_UseExamples_Preferred_200_Not_Json(t *testing.T) { // A little far-fetched for an API to behave this way, // where lowest 2xx response is html and second is json, -// including the test case to catch a panic case +// including the test case just in case spec := `openapi: 3.1.0 paths: /test: From 5e9102a5a9ab83d8537a54c948bef13b9684049c Mon Sep 17 00:00:00 2001 From: Jacob Moore Date: Wed, 3 Apr 2024 14:18:32 -0400 Subject: [PATCH 5/6] Added support for Websockets --- cmd/handle_http_traffic.go | 15 +++ cmd/root_command.go | 24 +++- daemon/handle_request.go | 231 ++++++++++++++++++++++++++++++++----- daemon/wiretap_service.go | 5 +- shared/config.go | 92 ++++++++------- 5 files changed, 294 insertions(+), 73 deletions(-) diff --git a/cmd/handle_http_traffic.go b/cmd/handle_http_traffic.go index f45b300..502fc19 100644 --- a/cmd/handle_http_traffic.go +++ b/cmd/handle_http_traffic.go @@ -27,12 +27,27 @@ func handleHttpTraffic(wiretapConfig *shared.WiretapConfiguration, wtService *da wtService.HandleHttpRequest(requestModel) } + handleWebsocket := func(w http.ResponseWriter, r *http.Request) { + id, _ := uuid.NewUUID() + requestModel := &model.Request{ + Id: &id, + HttpRequest: r, + HttpResponseWriter: w, + } + wtService.HandleWebsocketRequest(requestModel) + } + // create a new mux. mux := http.NewServeMux() // handle the index mux.HandleFunc("/", handleTraffic) + // Handle Websockets + for websocket := range wiretapConfig.WebsocketConfigs { + mux.HandleFunc(websocket, handleWebsocket) + } + pterm.Info.Println(pterm.LightMagenta(fmt.Sprintf("API Gateway UI booting on port %s...", wiretapConfig.Port))) var httpErr error diff --git a/cmd/root_command.go b/cmd/root_command.go index 76a57fe..ab2c853 100644 --- a/cmd/root_command.go +++ b/cmd/root_command.go @@ -349,6 +349,16 @@ var ( printLoadedRedirectAllowList(config.RedirectAllowList) } + if len(config.WebsocketConfigs) > 0 { + for _, config := range config.WebsocketConfigs { + if config.VerifyCert == nil { + config.VerifyCert = func() *bool { b := true; return &b }() + } + } + + printLoadedWebsockets(config.WebsocketConfigs) + } + // static headers if config.Headers != nil && len(config.Headers.DropHeaders) > 0 { pterm.Info.Printf("Dropping the following %d %s globally:\n", len(config.Headers.DropHeaders), @@ -625,8 +635,7 @@ func Execute(version, commit, date string, fs embed.FS) { rootCmd.Flags().IntP("hard-validation-code", "q", 400, "Set a custom http error code for non-compliant requests when using the hard-error flag") rootCmd.Flags().IntP("hard-validation-return-code", "y", 502, "Set a custom http error code for non-compliant responses when using the hard-error flag") rootCmd.Flags().BoolP("mock-mode", "x", false, "Run in mock mode, responses are mocked and no traffic is sent to the target API (requires OpenAPI spec)") - rootCmd.Flags().StringP("config", "c", "", - "Location of wiretap configuration file to use (default is .wiretap in current directory)") + rootCmd.Flags().StringP("config", "c", "", "Location of wiretap configuration file to use (default is .wiretap in current directory)") rootCmd.Flags().StringP("base", "b", "", "Set a base path to resolve relative file references from, or a overriding base URL to resolve remote references from") rootCmd.Flags().BoolP("debug", "l", false, "Enable debug logging") rootCmd.Flags().StringP("har", "z", "", "Load a HAR file instead of sniffing traffic") @@ -706,7 +715,7 @@ func printLoadedIgnoreRedirectPaths(ignoreRedirects []string) { } func printLoadedRedirectAllowList(allowRedirects []string) { - pterm.Info.Printf("Loaded %d allows listed redirect %s :\n", len(allowRedirects), + pterm.Info.Printf("Loaded %d allows listed redirect %s:\n", len(allowRedirects), shared.Pluralize(len(allowRedirects), "path", "paths")) for _, x := range allowRedirects { @@ -714,3 +723,12 @@ func printLoadedRedirectAllowList(allowRedirects []string) { } pterm.Println() } + +func printLoadedWebsockets(websockets map[string]*shared.WiretapWebsocketConfig) { + pterm.Info.Printf("Loaded %d %s: \n", len(websockets), shared.Pluralize(len(websockets), "websocket", "websockets")) + + for websocket := range websockets { + pterm.Printf("🔌 Paths prefixed '%s' will be managed as a websocket\n", pterm.LightCyan(websocket)) + } + pterm.Println() +} diff --git a/daemon/handle_request.go b/daemon/handle_request.go index 6a0239b..b1acfe7 100644 --- a/daemon/handle_request.go +++ b/daemon/handle_request.go @@ -4,8 +4,10 @@ package daemon import ( + "crypto/tls" _ "embed" "fmt" + "github.com/gorilla/websocket" "io" "net/http" "os" @@ -99,32 +101,7 @@ func (ws *WiretapService) handleHttpRequest(request *model.Request) { } } - var dropHeaders []string - var injectHeaders map[string]string - - // add global headers with injection. - if config.Headers != nil { - dropHeaders = config.Headers.DropHeaders - injectHeaders = config.Headers.InjectHeaders - } - - // now add path specific headers. - matchedPaths := configModel.FindPaths(request.HttpRequest.URL.Path, config) - auth := "" - if len(matchedPaths) > 0 { - for _, path := range matchedPaths { - auth = path.Auth - if path.Headers != nil { - dropHeaders = append(dropHeaders, path.Headers.DropHeaders...) - newInjectHeaders := path.Headers.InjectHeaders - for key := range injectHeaders { - newInjectHeaders[key] = injectHeaders[key] - } - injectHeaders = newInjectHeaders - } - break - } - } + dropHeaders, injectHeaders, auth := ws.getHeadersAndAuth(config, request) newReq := CloneExistingRequest(CloneRequest{ Request: request.HttpRequest, @@ -238,8 +215,210 @@ func (ws *WiretapService) handleHttpRequest(request *model.Request) { _, _ = request.HttpResponseWriter.Write(body) } +var gorillaDropHeaders = []string{ + // Gorilla fills in the following headers, and complains if they are already present + "Upgrade", + "Connection", + "Sec-Websocket-Key", + "Sec-Websocket-Version", + "Sec-Websocket-Protocol", + "Sec-Websocket-Extensions", +} + +func (ws *WiretapService) handleWebsocketRequest(request *model.Request) { + + configStore, _ := ws.controlsStore.Get(shared.ConfigKey) + config := configStore.(*shared.WiretapConfiguration) + + // Get the Websocket Configuration + websocketUrl := request.HttpRequest.URL.String() + websocketConfig, ok := config.WebsocketConfigs[websocketUrl] + if !ok { + ws.config.Logger.Error(fmt.Sprintf("Unable to find websocket config for URL: %s", websocketUrl)) + } + + // There's nothing to do if we're in mock mode + if config.MockMode { + return + } + + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + // Upgrade the connection from the client to open a websocket connection + clientConn, err := upgrader.Upgrade(request.HttpResponseWriter, request.HttpRequest, nil) + if err != nil { + ws.config.Logger.Error("Unable to upgrade websocket connection") + return + } + defer func(clientConn *websocket.Conn) { + _ = clientConn.Close() + }(clientConn) + + if config.Headers == nil || len(config.Headers.DropHeaders) == 0 { + config.Headers = &shared.WiretapHeaderConfig{ + DropHeaders: []string{}, + } + } + + // Get the updated headers and auth + dropHeaders, injectHeaders, auth := ws.getHeadersAndAuth(config, request) + + dropHeaders = append(dropHeaders, gorillaDropHeaders...) + dropHeaders = append(dropHeaders, websocketConfig.DropHeaders...) + + // Determine the correct websocket protocol based on redirect protocol + var protocol string + if config.RedirectProtocol == "https" { + protocol = "wss" + } else if config.RedirectProtocol == "http" { + protocol = "ws" + } else if config.RedirectProtocol != "wss" && config.RedirectProtocol != "ws" { + config.Logger.Error(fmt.Sprintf("Unsupported Redirect Protocol: %s", config.RedirectProtocol)) + return + } + + // Create a new request, which fills in the URL and other information + newRequest := CloneExistingRequest(CloneRequest{ + Request: request.HttpRequest, + Protocol: protocol, + Host: config.RedirectHost, + BasePath: config.RedirectBasePath, + Port: config.RedirectPort, + DropHeaders: dropHeaders, + InjectHeaders: injectHeaders, + Auth: auth, + Variables: config.CompiledVariables, + }) + + // Open a new websocket connection with the server + dialer := *websocket.DefaultDialer + dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: !*websocketConfig.VerifyCert} + serverConn, _, err := dialer.Dial(newRequest.URL.String(), newRequest.Header) + if err != nil { + ws.config.Logger.Error("Unable to create server connection") + return + } + defer func(serverConn *websocket.Conn) { + _ = serverConn.Close() + }(serverConn) + + // Create sentinel channels + clientSentinel := make(chan struct{}) + serverSentinel := make(chan struct{}) + + // Go-Routine for communication between Client -> Server + go func() { + defer close(clientSentinel) + + for { + messageType, message, err := clientConn.ReadMessage() + if err != nil { + closeCode, isUnexpected := getCloseCode(err) + logWebsocketClose(config, closeCode, isUnexpected) + _ = clientConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + return + } + + err = serverConn.WriteMessage(messageType, message) + if err != nil { + closeCode, isUnexpected := getCloseCode(err) + logWebsocketClose(config, closeCode, isUnexpected) + return + } + } + }() + + // Go-Routine for communication between Server -> Client + go func() { + defer close(serverSentinel) + + for { + messageType, message, err := serverConn.ReadMessage() + if err != nil { + closeCode, isUnexpected := getCloseCode(err) + logWebsocketClose(config, closeCode, isUnexpected) + _ = clientConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + return + } + + err = clientConn.WriteMessage(messageType, message) + if err != nil { + closeCode, isUnexpected := getCloseCode(err) + logWebsocketClose(config, closeCode, isUnexpected) + return + } + } + }() + + // Loop until at least one of our sentinel channels have been closed + for { + select { + case <-clientSentinel: + return + case <-serverSentinel: + return + } + } +} + func setCORSHeaders(headers map[string][]string) { headers["Access-Control-Allow-Headers"] = []string{"*"} headers["Access-Control-Allow-Origin"] = []string{"*"} headers["Access-Control-Allow-Methods"] = []string{"OPTIONS,POST,GET,DELETE,PATCH,PUT"} } + +func getCloseCode(err error) (int, bool) { + unexpectedClose := websocket.IsUnexpectedCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + websocket.CloseNoStatusReceived, + websocket.CloseAbnormalClosure, + ) + + if ce, ok := err.(*websocket.CloseError); ok { + return ce.Code, unexpectedClose + } + return -1, unexpectedClose +} + +func logWebsocketClose(config *shared.WiretapConfiguration, closeCode int, isUnexpected bool) { + if isUnexpected { + config.Logger.Error(fmt.Sprintf("Websocket closed unexepectedly with code: %d", closeCode)) + } else { + config.Logger.Info(fmt.Sprintf("Websocket closed expectedly with code: %d", closeCode)) + } +} + +func (ws *WiretapService) getHeadersAndAuth(config *shared.WiretapConfiguration, request *model.Request) ([]string, map[string]string, string) { + var dropHeaders []string + var injectHeaders map[string]string + + // add global headers with injection. + if config.Headers != nil { + dropHeaders = config.Headers.DropHeaders + injectHeaders = config.Headers.InjectHeaders + } + + // now add path specific headers. + matchedPaths := configModel.FindPaths(request.HttpRequest.URL.Path, config) + auth := "" + if len(matchedPaths) > 0 { + for _, path := range matchedPaths { + auth = path.Auth + if path.Headers != nil { + dropHeaders = append(dropHeaders, path.Headers.DropHeaders...) + newInjectHeaders := path.Headers.InjectHeaders + for key := range injectHeaders { + newInjectHeaders[key] = injectHeaders[key] + } + injectHeaders = newInjectHeaders + } + break + } + } + + return dropHeaders, injectHeaders, auth +} diff --git a/daemon/wiretap_service.go b/daemon/wiretap_service.go index 416bfe0..1b1d7d0 100644 --- a/daemon/wiretap_service.go +++ b/daemon/wiretap_service.go @@ -95,6 +95,9 @@ func (ws *WiretapService) HandleServiceRequest(request *model.Request, core serv } func (ws *WiretapService) HandleHttpRequest(request *model.Request) { - ws.handleHttpRequest(request) } + +func (ws *WiretapService) HandleWebsocketRequest(request *model.Request) { + ws.handleWebsocketRequest(request) +} diff --git a/shared/config.go b/shared/config.go index 05a9a47..e7552cb 100644 --- a/shared/config.go +++ b/shared/config.go @@ -13,49 +13,50 @@ import ( ) type WiretapConfiguration struct { - Contract string `json:"-" yaml:"-"` - RedirectHost string `json:"redirectHost,omitempty" yaml:"redirectHost,omitempty"` - RedirectPort string `json:"redirectPort,omitempty" yaml:"redirectPort,omitempty"` - RedirectBasePath string `json:"redirectBasePath,omitempty" yaml:"redirectBasePath,omitempty"` - RedirectProtocol string `json:"redirectProtocol,omitempty" yaml:"redirectProtocol,omitempty"` - RedirectURL string `json:"redirectURL,omitempty" yaml:"redirectURL,omitempty"` - Port string `json:"port,omitempty" yaml:"port,omitempty"` - MonitorPort string `json:"monitorPort,omitempty" yaml:"monitorPort,omitempty"` - WebSocketHost string `json:"webSocketHost,omitempty" yaml:"webSocketHost,omitempty"` - WebSocketPort string `json:"webSocketPort,omitempty" yaml:"webSocketPort,omitempty"` - GlobalAPIDelay int `json:"globalAPIDelay,omitempty" yaml:"globalAPIDelay,omitempty"` - StaticDir string `json:"staticDir,omitempty" yaml:"staticDir,omitempty"` - StaticIndex string `json:"staticIndex,omitempty" yaml:"staticIndex,omitempty"` - PathConfigurations map[string]*WiretapPathConfig `json:"paths,omitempty" yaml:"paths,omitempty"` - Headers *WiretapHeaderConfig `json:"headers,omitempty" yaml:"headers,omitempty"` - StaticPaths []string `json:"staticPaths,omitempty" yaml:"staticPaths,omitempty"` - Variables map[string]string `json:"variables,omitempty" yaml:"variables,omitempty"` - Spec string `json:"contract,omitempty" yaml:"contract,omitempty"` - Certificate string `json:"certificate,omitempty" yaml:"certificate,omitempty"` - CertificateKey string `json:"certificateKey,omitempty" yaml:"certificateKey,omitempty"` - HardErrors bool `json:"hardValidation,omitempty" yaml:"hardValidation,omitempty"` - HardErrorCode int `json:"hardValidationCode,omitempty" yaml:"hardValidationCode,omitempty"` - HardErrorReturnCode int `json:"hardValidationReturnCode,omitempty" yaml:"hardValidationReturnCode,omitempty"` - PathDelays map[string]int `json:"pathDelays,omitempty" yaml:"pathDelays,omitempty"` - MockMode bool `json:"mockMode,omitempty" yaml:"mockMode,omitempty"` - MockModePretty bool `json:"mockModePretty,omitempty" yaml:"mockModePretty,omitempty"` - Base string `json:"base,omitempty" yaml:"base,omitempty"` - HAR string `json:"har,omitempty" yaml:"har,omitempty"` - HARValidate bool `json:"harValidate,omitempty" yaml:"harValidate,omitempty"` - HARPathAllowList []string `json:"harPathAllowList,omitempty" yaml:"harPathAllowList,omitempty"` - StreamReport bool `json:"streamReport,omitempty" yaml:"streamReport,omitempty"` - ReportFile string `json:"reportFilename,omitempty" yaml:"reportFilename,omitempty"` - IgnoreRedirects []string `json:"ignoreRedirects,omitempty" yaml:"ignoreRedirects,omitempty"` - RedirectAllowList []string `json:"redirectAllowList,omitempty" yaml:"redirectAllowList,omitempty"` - HARFile *harhar.HAR `json:"-" yaml:"-"` - CompiledPathDelays map[string]*CompiledPathDelay `json:"-" yaml:"-"` - CompiledVariables map[string]*CompiledVariable `json:"-" yaml:"-"` - Version string `json:"-" yaml:"-"` - StaticPathsCompiled []glob.Glob `json:"-" yaml:"-"` - CompiledPaths map[string]*CompiledPath `json:"-"` - CompiledIgnoreRedirects []*CompiledRedirect `json:"-" yaml:"-"` - CompiledRedirectAllowList []*CompiledRedirect `json:"-" yaml:"-"` - FS embed.FS `json:"-"` + Contract string `json:"-" yaml:"-"` + RedirectHost string `json:"redirectHost,omitempty" yaml:"redirectHost,omitempty"` + RedirectPort string `json:"redirectPort,omitempty" yaml:"redirectPort,omitempty"` + RedirectBasePath string `json:"redirectBasePath,omitempty" yaml:"redirectBasePath,omitempty"` + RedirectProtocol string `json:"redirectProtocol,omitempty" yaml:"redirectProtocol,omitempty"` + RedirectURL string `json:"redirectURL,omitempty" yaml:"redirectURL,omitempty"` + Port string `json:"port,omitempty" yaml:"port,omitempty"` + MonitorPort string `json:"monitorPort,omitempty" yaml:"monitorPort,omitempty"` + WebSocketHost string `json:"webSocketHost,omitempty" yaml:"webSocketHost,omitempty"` + WebSocketPort string `json:"webSocketPort,omitempty" yaml:"webSocketPort,omitempty"` + GlobalAPIDelay int `json:"globalAPIDelay,omitempty" yaml:"globalAPIDelay,omitempty"` + StaticDir string `json:"staticDir,omitempty" yaml:"staticDir,omitempty"` + StaticIndex string `json:"staticIndex,omitempty" yaml:"staticIndex,omitempty"` + PathConfigurations map[string]*WiretapPathConfig `json:"paths,omitempty" yaml:"paths,omitempty"` + Headers *WiretapHeaderConfig `json:"headers,omitempty" yaml:"headers,omitempty"` + StaticPaths []string `json:"staticPaths,omitempty" yaml:"staticPaths,omitempty"` + Variables map[string]string `json:"variables,omitempty" yaml:"variables,omitempty"` + Spec string `json:"contract,omitempty" yaml:"contract,omitempty"` + Certificate string `json:"certificate,omitempty" yaml:"certificate,omitempty"` + CertificateKey string `json:"certificateKey,omitempty" yaml:"certificateKey,omitempty"` + HardErrors bool `json:"hardValidation,omitempty" yaml:"hardValidation,omitempty"` + HardErrorCode int `json:"hardValidationCode,omitempty" yaml:"hardValidationCode,omitempty"` + HardErrorReturnCode int `json:"hardValidationReturnCode,omitempty" yaml:"hardValidationReturnCode,omitempty"` + PathDelays map[string]int `json:"pathDelays,omitempty" yaml:"pathDelays,omitempty"` + MockMode bool `json:"mockMode,omitempty" yaml:"mockMode,omitempty"` + MockModePretty bool `json:"mockModePretty,omitempty" yaml:"mockModePretty,omitempty"` + Base string `json:"base,omitempty" yaml:"base,omitempty"` + HAR string `json:"har,omitempty" yaml:"har,omitempty"` + HARValidate bool `json:"harValidate,omitempty" yaml:"harValidate,omitempty"` + HARPathAllowList []string `json:"harPathAllowList,omitempty" yaml:"harPathAllowList,omitempty"` + StreamReport bool `json:"streamReport,omitempty" yaml:"streamReport,omitempty"` + ReportFile string `json:"reportFilename,omitempty" yaml:"reportFilename,omitempty"` + IgnoreRedirects []string `json:"ignoreRedirects,omitempty" yaml:"ignoreRedirects,omitempty"` + RedirectAllowList []string `json:"redirectAllowList,omitempty" yaml:"redirectAllowList,omitempty"` + WebsocketConfigs map[string]*WiretapWebsocketConfig `json:"websockets" yaml:"websockets"` + HARFile *harhar.HAR `json:"-" yaml:"-"` + CompiledPathDelays map[string]*CompiledPathDelay `json:"-" yaml:"-"` + CompiledVariables map[string]*CompiledVariable `json:"-" yaml:"-"` + Version string `json:"-" yaml:"-"` + StaticPathsCompiled []glob.Glob `json:"-" yaml:"-"` + CompiledPaths map[string]*CompiledPath `json:"-"` + CompiledIgnoreRedirects []*CompiledRedirect `json:"-" yaml:"-"` + CompiledRedirectAllowList []*CompiledRedirect `json:"-" yaml:"-"` + FS embed.FS `json:"-"` Logger *slog.Logger } @@ -125,6 +126,11 @@ func (wtc *WiretapConfiguration) ReplaceWithVariables(input string) string { return input } +type WiretapWebsocketConfig struct { + VerifyCert *bool `json:"verifyCert" yaml:"verifyCert"` + DropHeaders []string `json:"dropHeaders" yaml:"dropHeaders"` +} + type WiretapPathConfig struct { Target string `json:"target,omitempty" yaml:"target,omitempty"` PathRewrite map[string]string `json:"pathRewrite,omitempty" yaml:"pathRewrite,omitempty"` From 3c27d947966f2952fe1b1a70c9d9b74a5c46b009 Mon Sep 17 00:00:00 2001 From: Jacob Moore Date: Thu, 4 Apr 2024 16:18:38 -0400 Subject: [PATCH 6/6] OPENAPI: addressed MR Comments --- daemon/handle_request.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/daemon/handle_request.go b/daemon/handle_request.go index b1acfe7..d10278b 100644 --- a/daemon/handle_request.go +++ b/daemon/handle_request.go @@ -298,7 +298,7 @@ func (ws *WiretapService) handleWebsocketRequest(request *model.Request) { dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: !*websocketConfig.VerifyCert} serverConn, _, err := dialer.Dial(newRequest.URL.String(), newRequest.Header) if err != nil { - ws.config.Logger.Error("Unable to create server connection") + ws.config.Logger.Error(fmt.Sprintf("Unable to connect to remote server; websocket connection failed: %s", err)) return } defer func(serverConn *websocket.Conn) { @@ -386,7 +386,7 @@ func getCloseCode(err error) (int, bool) { func logWebsocketClose(config *shared.WiretapConfiguration, closeCode int, isUnexpected bool) { if isUnexpected { - config.Logger.Error(fmt.Sprintf("Websocket closed unexepectedly with code: %d", closeCode)) + config.Logger.Warn(fmt.Sprintf("Websocket closed unexepectedly with code: %d", closeCode)) } else { config.Logger.Info(fmt.Sprintf("Websocket closed expectedly with code: %d", closeCode)) }