Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for Websockets #107

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions cmd/handle_http_traffic.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,27 @@ func handleHttpTraffic(wiretapConfig *shared.WiretapConfiguration, wtService *da
wtService.HandleHttpRequest(requestModel)
}

handleWebsocket := func(w http.ResponseWriter, r *http.Request) {
id, _ := uuid.NewUUID()
requestModel := &model.Request{
Id: &id,
HttpRequest: r,
HttpResponseWriter: w,
}
wtService.HandleWebsocketRequest(requestModel)
}

// create a new mux.
mux := http.NewServeMux()

// handle the index
mux.HandleFunc("/", handleTraffic)

// Handle Websockets
for websocket := range wiretapConfig.WebsocketConfigs {
mux.HandleFunc(websocket, handleWebsocket)
}

pterm.Info.Println(pterm.LightMagenta(fmt.Sprintf("API Gateway UI booting on port %s...", wiretapConfig.Port)))

var httpErr error
Expand Down
24 changes: 21 additions & 3 deletions cmd/root_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,16 @@ var (
printLoadedRedirectAllowList(config.RedirectAllowList)
}

if len(config.WebsocketConfigs) > 0 {
for _, config := range config.WebsocketConfigs {
if config.VerifyCert == nil {
config.VerifyCert = func() *bool { b := true; return &b }()
}
}

printLoadedWebsockets(config.WebsocketConfigs)
}

// static headers
if config.Headers != nil && len(config.Headers.DropHeaders) > 0 {
pterm.Info.Printf("Dropping the following %d %s globally:\n", len(config.Headers.DropHeaders),
Expand Down Expand Up @@ -625,8 +635,7 @@ func Execute(version, commit, date string, fs embed.FS) {
rootCmd.Flags().IntP("hard-validation-code", "q", 400, "Set a custom http error code for non-compliant requests when using the hard-error flag")
rootCmd.Flags().IntP("hard-validation-return-code", "y", 502, "Set a custom http error code for non-compliant responses when using the hard-error flag")
rootCmd.Flags().BoolP("mock-mode", "x", false, "Run in mock mode, responses are mocked and no traffic is sent to the target API (requires OpenAPI spec)")
rootCmd.Flags().StringP("config", "c", "",
"Location of wiretap configuration file to use (default is .wiretap in current directory)")
rootCmd.Flags().StringP("config", "c", "", "Location of wiretap configuration file to use (default is .wiretap in current directory)")
rootCmd.Flags().StringP("base", "b", "", "Set a base path to resolve relative file references from, or a overriding base URL to resolve remote references from")
rootCmd.Flags().BoolP("debug", "l", false, "Enable debug logging")
rootCmd.Flags().StringP("har", "z", "", "Load a HAR file instead of sniffing traffic")
Expand Down Expand Up @@ -706,11 +715,20 @@ func printLoadedIgnoreRedirectPaths(ignoreRedirects []string) {
}

func printLoadedRedirectAllowList(allowRedirects []string) {
pterm.Info.Printf("Loaded %d allows listed redirect %s :\n", len(allowRedirects),
pterm.Info.Printf("Loaded %d allows listed redirect %s:\n", len(allowRedirects),
shared.Pluralize(len(allowRedirects), "path", "paths"))

for _, x := range allowRedirects {
pterm.Printf("🐵 Paths matching '%s' will always follow redirects, regardless of ignoreRedirect settings\n", pterm.LightCyan(x))
}
pterm.Println()
}

func printLoadedWebsockets(websockets map[string]*shared.WiretapWebsocketConfig) {
pterm.Info.Printf("Loaded %d %s: \n", len(websockets), shared.Pluralize(len(websockets), "websocket", "websockets"))

for websocket := range websockets {
pterm.Printf("🔌 Paths prefixed '%s' will be managed as a websocket\n", pterm.LightCyan(websocket))
daveshanley marked this conversation as resolved.
Show resolved Hide resolved
}
pterm.Println()
}
231 changes: 205 additions & 26 deletions daemon/handle_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
package daemon

import (
"crypto/tls"
_ "embed"
"fmt"
"github.com/gorilla/websocket"
"io"
"net/http"
"os"
Expand Down Expand Up @@ -99,32 +101,7 @@ func (ws *WiretapService) handleHttpRequest(request *model.Request) {
}
}

var dropHeaders []string
var injectHeaders map[string]string

// add global headers with injection.
if config.Headers != nil {
dropHeaders = config.Headers.DropHeaders
injectHeaders = config.Headers.InjectHeaders
}

// now add path specific headers.
matchedPaths := configModel.FindPaths(request.HttpRequest.URL.Path, config)
auth := ""
if len(matchedPaths) > 0 {
for _, path := range matchedPaths {
auth = path.Auth
if path.Headers != nil {
dropHeaders = append(dropHeaders, path.Headers.DropHeaders...)
newInjectHeaders := path.Headers.InjectHeaders
for key := range injectHeaders {
newInjectHeaders[key] = injectHeaders[key]
}
injectHeaders = newInjectHeaders
}
break
}
}
dropHeaders, injectHeaders, auth := ws.getHeadersAndAuth(config, request)

newReq := CloneExistingRequest(CloneRequest{
Request: request.HttpRequest,
Expand Down Expand Up @@ -238,8 +215,210 @@ func (ws *WiretapService) handleHttpRequest(request *model.Request) {
_, _ = request.HttpResponseWriter.Write(body)
}

var gorillaDropHeaders = []string{
// Gorilla fills in the following headers, and complains if they are already present
"Upgrade",
"Connection",
"Sec-Websocket-Key",
"Sec-Websocket-Version",
"Sec-Websocket-Protocol",
"Sec-Websocket-Extensions",
}

func (ws *WiretapService) handleWebsocketRequest(request *model.Request) {

configStore, _ := ws.controlsStore.Get(shared.ConfigKey)
config := configStore.(*shared.WiretapConfiguration)

// Get the Websocket Configuration
websocketUrl := request.HttpRequest.URL.String()
websocketConfig, ok := config.WebsocketConfigs[websocketUrl]
if !ok {
ws.config.Logger.Error(fmt.Sprintf("Unable to find websocket config for URL: %s", websocketUrl))
}

// There's nothing to do if we're in mock mode
if config.MockMode {
return
}

upgrader := websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}

// Upgrade the connection from the client to open a websocket connection
clientConn, err := upgrader.Upgrade(request.HttpResponseWriter, request.HttpRequest, nil)
if err != nil {
ws.config.Logger.Error("Unable to upgrade websocket connection")
return
}
defer func(clientConn *websocket.Conn) {
_ = clientConn.Close()
}(clientConn)

if config.Headers == nil || len(config.Headers.DropHeaders) == 0 {
config.Headers = &shared.WiretapHeaderConfig{
DropHeaders: []string{},
}
}

// Get the updated headers and auth
dropHeaders, injectHeaders, auth := ws.getHeadersAndAuth(config, request)

dropHeaders = append(dropHeaders, gorillaDropHeaders...)
dropHeaders = append(dropHeaders, websocketConfig.DropHeaders...)

// Determine the correct websocket protocol based on redirect protocol
var protocol string
if config.RedirectProtocol == "https" {
protocol = "wss"
} else if config.RedirectProtocol == "http" {
protocol = "ws"
} else if config.RedirectProtocol != "wss" && config.RedirectProtocol != "ws" {
config.Logger.Error(fmt.Sprintf("Unsupported Redirect Protocol: %s", config.RedirectProtocol))
return
}

// Create a new request, which fills in the URL and other information
newRequest := CloneExistingRequest(CloneRequest{
Request: request.HttpRequest,
Protocol: protocol,
Host: config.RedirectHost,
BasePath: config.RedirectBasePath,
Port: config.RedirectPort,
DropHeaders: dropHeaders,
InjectHeaders: injectHeaders,
Auth: auth,
Variables: config.CompiledVariables,
})

// Open a new websocket connection with the server
dialer := *websocket.DefaultDialer
dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: !*websocketConfig.VerifyCert}
serverConn, _, err := dialer.Dial(newRequest.URL.String(), newRequest.Header)
if err != nil {
ws.config.Logger.Error(fmt.Sprintf("Unable to connect to remote server; websocket connection failed: %s", err))
return
}
defer func(serverConn *websocket.Conn) {
_ = serverConn.Close()
}(serverConn)

// Create sentinel channels
clientSentinel := make(chan struct{})
serverSentinel := make(chan struct{})

// Go-Routine for communication between Client -> Server
go func() {
defer close(clientSentinel)

for {
messageType, message, err := clientConn.ReadMessage()
if err != nil {
closeCode, isUnexpected := getCloseCode(err)
logWebsocketClose(config, closeCode, isUnexpected)
_ = clientConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
return
}

err = serverConn.WriteMessage(messageType, message)
if err != nil {
closeCode, isUnexpected := getCloseCode(err)
logWebsocketClose(config, closeCode, isUnexpected)
return
}
}
}()

// Go-Routine for communication between Server -> Client
go func() {
defer close(serverSentinel)

for {
messageType, message, err := serverConn.ReadMessage()
if err != nil {
closeCode, isUnexpected := getCloseCode(err)
logWebsocketClose(config, closeCode, isUnexpected)
_ = clientConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
return
}

err = clientConn.WriteMessage(messageType, message)
if err != nil {
closeCode, isUnexpected := getCloseCode(err)
logWebsocketClose(config, closeCode, isUnexpected)
return
}
}
}()

// Loop until at least one of our sentinel channels have been closed
for {
select {
case <-clientSentinel:
return
case <-serverSentinel:
return
}
}
}

func setCORSHeaders(headers map[string][]string) {
headers["Access-Control-Allow-Headers"] = []string{"*"}
headers["Access-Control-Allow-Origin"] = []string{"*"}
headers["Access-Control-Allow-Methods"] = []string{"OPTIONS,POST,GET,DELETE,PATCH,PUT"}
}

func getCloseCode(err error) (int, bool) {
unexpectedClose := websocket.IsUnexpectedCloseError(err,
websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived,
websocket.CloseAbnormalClosure,
)

if ce, ok := err.(*websocket.CloseError); ok {
return ce.Code, unexpectedClose
}
return -1, unexpectedClose
}

func logWebsocketClose(config *shared.WiretapConfiguration, closeCode int, isUnexpected bool) {
if isUnexpected {
config.Logger.Warn(fmt.Sprintf("Websocket closed unexepectedly with code: %d", closeCode))
} else {
config.Logger.Info(fmt.Sprintf("Websocket closed expectedly with code: %d", closeCode))
daveshanley marked this conversation as resolved.
Show resolved Hide resolved
}
}

func (ws *WiretapService) getHeadersAndAuth(config *shared.WiretapConfiguration, request *model.Request) ([]string, map[string]string, string) {
var dropHeaders []string
var injectHeaders map[string]string

// add global headers with injection.
if config.Headers != nil {
dropHeaders = config.Headers.DropHeaders
injectHeaders = config.Headers.InjectHeaders
}

// now add path specific headers.
matchedPaths := configModel.FindPaths(request.HttpRequest.URL.Path, config)
auth := ""
if len(matchedPaths) > 0 {
for _, path := range matchedPaths {
auth = path.Auth
if path.Headers != nil {
dropHeaders = append(dropHeaders, path.Headers.DropHeaders...)
newInjectHeaders := path.Headers.InjectHeaders
for key := range injectHeaders {
newInjectHeaders[key] = injectHeaders[key]
}
injectHeaders = newInjectHeaders
}
break
}
}

return dropHeaders, injectHeaders, auth
}
5 changes: 4 additions & 1 deletion daemon/wiretap_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ func (ws *WiretapService) HandleServiceRequest(request *model.Request, core serv
}

func (ws *WiretapService) HandleHttpRequest(request *model.Request) {

ws.handleHttpRequest(request)
}

func (ws *WiretapService) HandleWebsocketRequest(request *model.Request) {
ws.handleWebsocketRequest(request)
}
Loading
Loading