diff --git a/cmd/handle_http_traffic.go b/cmd/handle_http_traffic.go index f45b300..502fc19 100644 --- a/cmd/handle_http_traffic.go +++ b/cmd/handle_http_traffic.go @@ -27,12 +27,27 @@ func handleHttpTraffic(wiretapConfig *shared.WiretapConfiguration, wtService *da wtService.HandleHttpRequest(requestModel) } + handleWebsocket := func(w http.ResponseWriter, r *http.Request) { + id, _ := uuid.NewUUID() + requestModel := &model.Request{ + Id: &id, + HttpRequest: r, + HttpResponseWriter: w, + } + wtService.HandleWebsocketRequest(requestModel) + } + // create a new mux. mux := http.NewServeMux() // handle the index mux.HandleFunc("/", handleTraffic) + // Handle Websockets + for websocket := range wiretapConfig.WebsocketConfigs { + mux.HandleFunc(websocket, handleWebsocket) + } + pterm.Info.Println(pterm.LightMagenta(fmt.Sprintf("API Gateway UI booting on port %s...", wiretapConfig.Port))) var httpErr error diff --git a/cmd/root_command.go b/cmd/root_command.go index 76a57fe..ab2c853 100644 --- a/cmd/root_command.go +++ b/cmd/root_command.go @@ -349,6 +349,16 @@ var ( printLoadedRedirectAllowList(config.RedirectAllowList) } + if len(config.WebsocketConfigs) > 0 { + for _, config := range config.WebsocketConfigs { + if config.VerifyCert == nil { + config.VerifyCert = func() *bool { b := true; return &b }() + } + } + + printLoadedWebsockets(config.WebsocketConfigs) + } + // static headers if config.Headers != nil && len(config.Headers.DropHeaders) > 0 { pterm.Info.Printf("Dropping the following %d %s globally:\n", len(config.Headers.DropHeaders), @@ -625,8 +635,7 @@ func Execute(version, commit, date string, fs embed.FS) { rootCmd.Flags().IntP("hard-validation-code", "q", 400, "Set a custom http error code for non-compliant requests when using the hard-error flag") rootCmd.Flags().IntP("hard-validation-return-code", "y", 502, "Set a custom http error code for non-compliant responses when using the hard-error flag") rootCmd.Flags().BoolP("mock-mode", "x", false, "Run in mock mode, responses are mocked and no traffic is sent to the target API (requires OpenAPI spec)") - rootCmd.Flags().StringP("config", "c", "", - "Location of wiretap configuration file to use (default is .wiretap in current directory)") + rootCmd.Flags().StringP("config", "c", "", "Location of wiretap configuration file to use (default is .wiretap in current directory)") rootCmd.Flags().StringP("base", "b", "", "Set a base path to resolve relative file references from, or a overriding base URL to resolve remote references from") rootCmd.Flags().BoolP("debug", "l", false, "Enable debug logging") rootCmd.Flags().StringP("har", "z", "", "Load a HAR file instead of sniffing traffic") @@ -706,7 +715,7 @@ func printLoadedIgnoreRedirectPaths(ignoreRedirects []string) { } func printLoadedRedirectAllowList(allowRedirects []string) { - pterm.Info.Printf("Loaded %d allows listed redirect %s :\n", len(allowRedirects), + pterm.Info.Printf("Loaded %d allows listed redirect %s:\n", len(allowRedirects), shared.Pluralize(len(allowRedirects), "path", "paths")) for _, x := range allowRedirects { @@ -714,3 +723,12 @@ func printLoadedRedirectAllowList(allowRedirects []string) { } pterm.Println() } + +func printLoadedWebsockets(websockets map[string]*shared.WiretapWebsocketConfig) { + pterm.Info.Printf("Loaded %d %s: \n", len(websockets), shared.Pluralize(len(websockets), "websocket", "websockets")) + + for websocket := range websockets { + pterm.Printf("🔌 Paths prefixed '%s' will be managed as a websocket\n", pterm.LightCyan(websocket)) + } + pterm.Println() +} diff --git a/daemon/handle_request.go b/daemon/handle_request.go index 6a0239b..d10278b 100644 --- a/daemon/handle_request.go +++ b/daemon/handle_request.go @@ -4,8 +4,10 @@ package daemon import ( + "crypto/tls" _ "embed" "fmt" + "github.com/gorilla/websocket" "io" "net/http" "os" @@ -99,32 +101,7 @@ func (ws *WiretapService) handleHttpRequest(request *model.Request) { } } - var dropHeaders []string - var injectHeaders map[string]string - - // add global headers with injection. - if config.Headers != nil { - dropHeaders = config.Headers.DropHeaders - injectHeaders = config.Headers.InjectHeaders - } - - // now add path specific headers. - matchedPaths := configModel.FindPaths(request.HttpRequest.URL.Path, config) - auth := "" - if len(matchedPaths) > 0 { - for _, path := range matchedPaths { - auth = path.Auth - if path.Headers != nil { - dropHeaders = append(dropHeaders, path.Headers.DropHeaders...) - newInjectHeaders := path.Headers.InjectHeaders - for key := range injectHeaders { - newInjectHeaders[key] = injectHeaders[key] - } - injectHeaders = newInjectHeaders - } - break - } - } + dropHeaders, injectHeaders, auth := ws.getHeadersAndAuth(config, request) newReq := CloneExistingRequest(CloneRequest{ Request: request.HttpRequest, @@ -238,8 +215,210 @@ func (ws *WiretapService) handleHttpRequest(request *model.Request) { _, _ = request.HttpResponseWriter.Write(body) } +var gorillaDropHeaders = []string{ + // Gorilla fills in the following headers, and complains if they are already present + "Upgrade", + "Connection", + "Sec-Websocket-Key", + "Sec-Websocket-Version", + "Sec-Websocket-Protocol", + "Sec-Websocket-Extensions", +} + +func (ws *WiretapService) handleWebsocketRequest(request *model.Request) { + + configStore, _ := ws.controlsStore.Get(shared.ConfigKey) + config := configStore.(*shared.WiretapConfiguration) + + // Get the Websocket Configuration + websocketUrl := request.HttpRequest.URL.String() + websocketConfig, ok := config.WebsocketConfigs[websocketUrl] + if !ok { + ws.config.Logger.Error(fmt.Sprintf("Unable to find websocket config for URL: %s", websocketUrl)) + } + + // There's nothing to do if we're in mock mode + if config.MockMode { + return + } + + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + // Upgrade the connection from the client to open a websocket connection + clientConn, err := upgrader.Upgrade(request.HttpResponseWriter, request.HttpRequest, nil) + if err != nil { + ws.config.Logger.Error("Unable to upgrade websocket connection") + return + } + defer func(clientConn *websocket.Conn) { + _ = clientConn.Close() + }(clientConn) + + if config.Headers == nil || len(config.Headers.DropHeaders) == 0 { + config.Headers = &shared.WiretapHeaderConfig{ + DropHeaders: []string{}, + } + } + + // Get the updated headers and auth + dropHeaders, injectHeaders, auth := ws.getHeadersAndAuth(config, request) + + dropHeaders = append(dropHeaders, gorillaDropHeaders...) + dropHeaders = append(dropHeaders, websocketConfig.DropHeaders...) + + // Determine the correct websocket protocol based on redirect protocol + var protocol string + if config.RedirectProtocol == "https" { + protocol = "wss" + } else if config.RedirectProtocol == "http" { + protocol = "ws" + } else if config.RedirectProtocol != "wss" && config.RedirectProtocol != "ws" { + config.Logger.Error(fmt.Sprintf("Unsupported Redirect Protocol: %s", config.RedirectProtocol)) + return + } + + // Create a new request, which fills in the URL and other information + newRequest := CloneExistingRequest(CloneRequest{ + Request: request.HttpRequest, + Protocol: protocol, + Host: config.RedirectHost, + BasePath: config.RedirectBasePath, + Port: config.RedirectPort, + DropHeaders: dropHeaders, + InjectHeaders: injectHeaders, + Auth: auth, + Variables: config.CompiledVariables, + }) + + // Open a new websocket connection with the server + dialer := *websocket.DefaultDialer + dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: !*websocketConfig.VerifyCert} + serverConn, _, err := dialer.Dial(newRequest.URL.String(), newRequest.Header) + if err != nil { + ws.config.Logger.Error(fmt.Sprintf("Unable to connect to remote server; websocket connection failed: %s", err)) + return + } + defer func(serverConn *websocket.Conn) { + _ = serverConn.Close() + }(serverConn) + + // Create sentinel channels + clientSentinel := make(chan struct{}) + serverSentinel := make(chan struct{}) + + // Go-Routine for communication between Client -> Server + go func() { + defer close(clientSentinel) + + for { + messageType, message, err := clientConn.ReadMessage() + if err != nil { + closeCode, isUnexpected := getCloseCode(err) + logWebsocketClose(config, closeCode, isUnexpected) + _ = clientConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + return + } + + err = serverConn.WriteMessage(messageType, message) + if err != nil { + closeCode, isUnexpected := getCloseCode(err) + logWebsocketClose(config, closeCode, isUnexpected) + return + } + } + }() + + // Go-Routine for communication between Server -> Client + go func() { + defer close(serverSentinel) + + for { + messageType, message, err := serverConn.ReadMessage() + if err != nil { + closeCode, isUnexpected := getCloseCode(err) + logWebsocketClose(config, closeCode, isUnexpected) + _ = clientConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + return + } + + err = clientConn.WriteMessage(messageType, message) + if err != nil { + closeCode, isUnexpected := getCloseCode(err) + logWebsocketClose(config, closeCode, isUnexpected) + return + } + } + }() + + // Loop until at least one of our sentinel channels have been closed + for { + select { + case <-clientSentinel: + return + case <-serverSentinel: + return + } + } +} + func setCORSHeaders(headers map[string][]string) { headers["Access-Control-Allow-Headers"] = []string{"*"} headers["Access-Control-Allow-Origin"] = []string{"*"} headers["Access-Control-Allow-Methods"] = []string{"OPTIONS,POST,GET,DELETE,PATCH,PUT"} } + +func getCloseCode(err error) (int, bool) { + unexpectedClose := websocket.IsUnexpectedCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + websocket.CloseNoStatusReceived, + websocket.CloseAbnormalClosure, + ) + + if ce, ok := err.(*websocket.CloseError); ok { + return ce.Code, unexpectedClose + } + return -1, unexpectedClose +} + +func logWebsocketClose(config *shared.WiretapConfiguration, closeCode int, isUnexpected bool) { + if isUnexpected { + config.Logger.Warn(fmt.Sprintf("Websocket closed unexepectedly with code: %d", closeCode)) + } else { + config.Logger.Info(fmt.Sprintf("Websocket closed expectedly with code: %d", closeCode)) + } +} + +func (ws *WiretapService) getHeadersAndAuth(config *shared.WiretapConfiguration, request *model.Request) ([]string, map[string]string, string) { + var dropHeaders []string + var injectHeaders map[string]string + + // add global headers with injection. + if config.Headers != nil { + dropHeaders = config.Headers.DropHeaders + injectHeaders = config.Headers.InjectHeaders + } + + // now add path specific headers. + matchedPaths := configModel.FindPaths(request.HttpRequest.URL.Path, config) + auth := "" + if len(matchedPaths) > 0 { + for _, path := range matchedPaths { + auth = path.Auth + if path.Headers != nil { + dropHeaders = append(dropHeaders, path.Headers.DropHeaders...) + newInjectHeaders := path.Headers.InjectHeaders + for key := range injectHeaders { + newInjectHeaders[key] = injectHeaders[key] + } + injectHeaders = newInjectHeaders + } + break + } + } + + return dropHeaders, injectHeaders, auth +} diff --git a/daemon/wiretap_service.go b/daemon/wiretap_service.go index 416bfe0..1b1d7d0 100644 --- a/daemon/wiretap_service.go +++ b/daemon/wiretap_service.go @@ -95,6 +95,9 @@ func (ws *WiretapService) HandleServiceRequest(request *model.Request, core serv } func (ws *WiretapService) HandleHttpRequest(request *model.Request) { - ws.handleHttpRequest(request) } + +func (ws *WiretapService) HandleWebsocketRequest(request *model.Request) { + ws.handleWebsocketRequest(request) +} diff --git a/shared/config.go b/shared/config.go index 05a9a47..e7552cb 100644 --- a/shared/config.go +++ b/shared/config.go @@ -13,49 +13,50 @@ import ( ) type WiretapConfiguration struct { - Contract string `json:"-" yaml:"-"` - RedirectHost string `json:"redirectHost,omitempty" yaml:"redirectHost,omitempty"` - RedirectPort string `json:"redirectPort,omitempty" yaml:"redirectPort,omitempty"` - RedirectBasePath string `json:"redirectBasePath,omitempty" yaml:"redirectBasePath,omitempty"` - RedirectProtocol string `json:"redirectProtocol,omitempty" yaml:"redirectProtocol,omitempty"` - RedirectURL string `json:"redirectURL,omitempty" yaml:"redirectURL,omitempty"` - Port string `json:"port,omitempty" yaml:"port,omitempty"` - MonitorPort string `json:"monitorPort,omitempty" yaml:"monitorPort,omitempty"` - WebSocketHost string `json:"webSocketHost,omitempty" yaml:"webSocketHost,omitempty"` - WebSocketPort string `json:"webSocketPort,omitempty" yaml:"webSocketPort,omitempty"` - GlobalAPIDelay int `json:"globalAPIDelay,omitempty" yaml:"globalAPIDelay,omitempty"` - StaticDir string `json:"staticDir,omitempty" yaml:"staticDir,omitempty"` - StaticIndex string `json:"staticIndex,omitempty" yaml:"staticIndex,omitempty"` - PathConfigurations map[string]*WiretapPathConfig `json:"paths,omitempty" yaml:"paths,omitempty"` - Headers *WiretapHeaderConfig `json:"headers,omitempty" yaml:"headers,omitempty"` - StaticPaths []string `json:"staticPaths,omitempty" yaml:"staticPaths,omitempty"` - Variables map[string]string `json:"variables,omitempty" yaml:"variables,omitempty"` - Spec string `json:"contract,omitempty" yaml:"contract,omitempty"` - Certificate string `json:"certificate,omitempty" yaml:"certificate,omitempty"` - CertificateKey string `json:"certificateKey,omitempty" yaml:"certificateKey,omitempty"` - HardErrors bool `json:"hardValidation,omitempty" yaml:"hardValidation,omitempty"` - HardErrorCode int `json:"hardValidationCode,omitempty" yaml:"hardValidationCode,omitempty"` - HardErrorReturnCode int `json:"hardValidationReturnCode,omitempty" yaml:"hardValidationReturnCode,omitempty"` - PathDelays map[string]int `json:"pathDelays,omitempty" yaml:"pathDelays,omitempty"` - MockMode bool `json:"mockMode,omitempty" yaml:"mockMode,omitempty"` - MockModePretty bool `json:"mockModePretty,omitempty" yaml:"mockModePretty,omitempty"` - Base string `json:"base,omitempty" yaml:"base,omitempty"` - HAR string `json:"har,omitempty" yaml:"har,omitempty"` - HARValidate bool `json:"harValidate,omitempty" yaml:"harValidate,omitempty"` - HARPathAllowList []string `json:"harPathAllowList,omitempty" yaml:"harPathAllowList,omitempty"` - StreamReport bool `json:"streamReport,omitempty" yaml:"streamReport,omitempty"` - ReportFile string `json:"reportFilename,omitempty" yaml:"reportFilename,omitempty"` - IgnoreRedirects []string `json:"ignoreRedirects,omitempty" yaml:"ignoreRedirects,omitempty"` - RedirectAllowList []string `json:"redirectAllowList,omitempty" yaml:"redirectAllowList,omitempty"` - HARFile *harhar.HAR `json:"-" yaml:"-"` - CompiledPathDelays map[string]*CompiledPathDelay `json:"-" yaml:"-"` - CompiledVariables map[string]*CompiledVariable `json:"-" yaml:"-"` - Version string `json:"-" yaml:"-"` - StaticPathsCompiled []glob.Glob `json:"-" yaml:"-"` - CompiledPaths map[string]*CompiledPath `json:"-"` - CompiledIgnoreRedirects []*CompiledRedirect `json:"-" yaml:"-"` - CompiledRedirectAllowList []*CompiledRedirect `json:"-" yaml:"-"` - FS embed.FS `json:"-"` + Contract string `json:"-" yaml:"-"` + RedirectHost string `json:"redirectHost,omitempty" yaml:"redirectHost,omitempty"` + RedirectPort string `json:"redirectPort,omitempty" yaml:"redirectPort,omitempty"` + RedirectBasePath string `json:"redirectBasePath,omitempty" yaml:"redirectBasePath,omitempty"` + RedirectProtocol string `json:"redirectProtocol,omitempty" yaml:"redirectProtocol,omitempty"` + RedirectURL string `json:"redirectURL,omitempty" yaml:"redirectURL,omitempty"` + Port string `json:"port,omitempty" yaml:"port,omitempty"` + MonitorPort string `json:"monitorPort,omitempty" yaml:"monitorPort,omitempty"` + WebSocketHost string `json:"webSocketHost,omitempty" yaml:"webSocketHost,omitempty"` + WebSocketPort string `json:"webSocketPort,omitempty" yaml:"webSocketPort,omitempty"` + GlobalAPIDelay int `json:"globalAPIDelay,omitempty" yaml:"globalAPIDelay,omitempty"` + StaticDir string `json:"staticDir,omitempty" yaml:"staticDir,omitempty"` + StaticIndex string `json:"staticIndex,omitempty" yaml:"staticIndex,omitempty"` + PathConfigurations map[string]*WiretapPathConfig `json:"paths,omitempty" yaml:"paths,omitempty"` + Headers *WiretapHeaderConfig `json:"headers,omitempty" yaml:"headers,omitempty"` + StaticPaths []string `json:"staticPaths,omitempty" yaml:"staticPaths,omitempty"` + Variables map[string]string `json:"variables,omitempty" yaml:"variables,omitempty"` + Spec string `json:"contract,omitempty" yaml:"contract,omitempty"` + Certificate string `json:"certificate,omitempty" yaml:"certificate,omitempty"` + CertificateKey string `json:"certificateKey,omitempty" yaml:"certificateKey,omitempty"` + HardErrors bool `json:"hardValidation,omitempty" yaml:"hardValidation,omitempty"` + HardErrorCode int `json:"hardValidationCode,omitempty" yaml:"hardValidationCode,omitempty"` + HardErrorReturnCode int `json:"hardValidationReturnCode,omitempty" yaml:"hardValidationReturnCode,omitempty"` + PathDelays map[string]int `json:"pathDelays,omitempty" yaml:"pathDelays,omitempty"` + MockMode bool `json:"mockMode,omitempty" yaml:"mockMode,omitempty"` + MockModePretty bool `json:"mockModePretty,omitempty" yaml:"mockModePretty,omitempty"` + Base string `json:"base,omitempty" yaml:"base,omitempty"` + HAR string `json:"har,omitempty" yaml:"har,omitempty"` + HARValidate bool `json:"harValidate,omitempty" yaml:"harValidate,omitempty"` + HARPathAllowList []string `json:"harPathAllowList,omitempty" yaml:"harPathAllowList,omitempty"` + StreamReport bool `json:"streamReport,omitempty" yaml:"streamReport,omitempty"` + ReportFile string `json:"reportFilename,omitempty" yaml:"reportFilename,omitempty"` + IgnoreRedirects []string `json:"ignoreRedirects,omitempty" yaml:"ignoreRedirects,omitempty"` + RedirectAllowList []string `json:"redirectAllowList,omitempty" yaml:"redirectAllowList,omitempty"` + WebsocketConfigs map[string]*WiretapWebsocketConfig `json:"websockets" yaml:"websockets"` + HARFile *harhar.HAR `json:"-" yaml:"-"` + CompiledPathDelays map[string]*CompiledPathDelay `json:"-" yaml:"-"` + CompiledVariables map[string]*CompiledVariable `json:"-" yaml:"-"` + Version string `json:"-" yaml:"-"` + StaticPathsCompiled []glob.Glob `json:"-" yaml:"-"` + CompiledPaths map[string]*CompiledPath `json:"-"` + CompiledIgnoreRedirects []*CompiledRedirect `json:"-" yaml:"-"` + CompiledRedirectAllowList []*CompiledRedirect `json:"-" yaml:"-"` + FS embed.FS `json:"-"` Logger *slog.Logger } @@ -125,6 +126,11 @@ func (wtc *WiretapConfiguration) ReplaceWithVariables(input string) string { return input } +type WiretapWebsocketConfig struct { + VerifyCert *bool `json:"verifyCert" yaml:"verifyCert"` + DropHeaders []string `json:"dropHeaders" yaml:"dropHeaders"` +} + type WiretapPathConfig struct { Target string `json:"target,omitempty" yaml:"target,omitempty"` PathRewrite map[string]string `json:"pathRewrite,omitempty" yaml:"pathRewrite,omitempty"`