Skip to content

Commit

Permalink
Merge branch 'main' of github.com:jacobm-splunk/wiretap into jacobm/a…
Browse files Browse the repository at this point in the history
…dd-request-validation-filtering
  • Loading branch information
jacobm-splunk committed Apr 11, 2024
2 parents 8f5e41f + 3c27d94 commit 561212a
Show file tree
Hide file tree
Showing 7 changed files with 489 additions and 40 deletions.
15 changes: 15 additions & 0 deletions cmd/handle_http_traffic.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,27 @@ func handleHttpTraffic(wiretapConfig *shared.WiretapConfiguration, wtService *da
wtService.HandleHttpRequest(requestModel)
}

handleWebsocket := func(w http.ResponseWriter, r *http.Request) {
id, _ := uuid.NewUUID()
requestModel := &model.Request{
Id: &id,
HttpRequest: r,
HttpResponseWriter: w,
}
wtService.HandleWebsocketRequest(requestModel)
}

// create a new mux.
mux := http.NewServeMux()

// handle the index
mux.HandleFunc("/", handleTraffic)

// Handle Websockets
for websocket := range wiretapConfig.WebsocketConfigs {
mux.HandleFunc(websocket, handleWebsocket)
}

pterm.Info.Println(pterm.LightMagenta(fmt.Sprintf("API Gateway UI booting on port %s...", wiretapConfig.Port)))

var httpErr error
Expand Down
24 changes: 21 additions & 3 deletions cmd/root_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,16 @@ var (
printLoadedValidationAllowList(config.ValidationAllowList)
}

if len(config.WebsocketConfigs) > 0 {
for _, config := range config.WebsocketConfigs {
if config.VerifyCert == nil {
config.VerifyCert = func() *bool { b := true; return &b }()
}
}

printLoadedWebsockets(config.WebsocketConfigs)
}

// static headers
if config.Headers != nil && len(config.Headers.DropHeaders) > 0 {
pterm.Info.Printf("Dropping the following %d %s globally:\n", len(config.Headers.DropHeaders),
Expand Down Expand Up @@ -635,8 +645,7 @@ func Execute(version, commit, date string, fs embed.FS) {
rootCmd.Flags().IntP("hard-validation-code", "q", 400, "Set a custom http error code for non-compliant requests when using the hard-error flag")
rootCmd.Flags().IntP("hard-validation-return-code", "y", 502, "Set a custom http error code for non-compliant responses when using the hard-error flag")
rootCmd.Flags().BoolP("mock-mode", "x", false, "Run in mock mode, responses are mocked and no traffic is sent to the target API (requires OpenAPI spec)")
rootCmd.Flags().StringP("config", "c", "",
"Location of wiretap configuration file to use (default is .wiretap in current directory)")
rootCmd.Flags().StringP("config", "c", "", "Location of wiretap configuration file to use (default is .wiretap in current directory)")
rootCmd.Flags().StringP("base", "b", "", "Set a base path to resolve relative file references from, or a overriding base URL to resolve remote references from")
rootCmd.Flags().BoolP("debug", "l", false, "Enable debug logging")
rootCmd.Flags().StringP("har", "z", "", "Load a HAR file instead of sniffing traffic")
Expand Down Expand Up @@ -716,7 +725,7 @@ func printLoadedIgnoreRedirectPaths(ignoreRedirects []string) {
}

func printLoadedRedirectAllowList(allowRedirects []string) {
pterm.Info.Printf("Loaded %d allow listed redirect %s :\n", len(allowRedirects),
pterm.Info.Printf("Loaded %d allow listed redirect %s:\n", len(allowRedirects),
shared.Pluralize(len(allowRedirects), "path", "paths"))

for _, x := range allowRedirects {
Expand All @@ -725,6 +734,15 @@ func printLoadedRedirectAllowList(allowRedirects []string) {
pterm.Println()
}

func printLoadedWebsockets(websockets map[string]*shared.WiretapWebsocketConfig) {
pterm.Info.Printf("Loaded %d %s: \n", len(websockets), shared.Pluralize(len(websockets), "websocket", "websockets"))

for websocket := range websockets {
pterm.Printf("🔌 Paths prefixed '%s' will be managed as a websocket\n", pterm.LightCyan(websocket))
}
pterm.Println()
}

func printLoadedIgnoreValidationPaths(ignoreValidations []string) {
pterm.Info.Printf("Loaded %d %s to ignore validation:\n", len(ignoreValidations),
shared.Pluralize(len(ignoreValidations), "path", "paths"))
Expand Down
231 changes: 205 additions & 26 deletions daemon/handle_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
package daemon

import (
"crypto/tls"
_ "embed"
"fmt"
"github.com/gorilla/websocket"
"io"
"net/http"
"os"
Expand Down Expand Up @@ -99,32 +101,7 @@ func (ws *WiretapService) handleHttpRequest(request *model.Request) {
}
}

var dropHeaders []string
var injectHeaders map[string]string

// add global headers with injection.
if config.Headers != nil {
dropHeaders = config.Headers.DropHeaders
injectHeaders = config.Headers.InjectHeaders
}

// now add path specific headers.
matchedPaths := configModel.FindPaths(request.HttpRequest.URL.Path, config)
auth := ""
if len(matchedPaths) > 0 {
for _, path := range matchedPaths {
auth = path.Auth
if path.Headers != nil {
dropHeaders = append(dropHeaders, path.Headers.DropHeaders...)
newInjectHeaders := path.Headers.InjectHeaders
for key := range injectHeaders {
newInjectHeaders[key] = injectHeaders[key]
}
injectHeaders = newInjectHeaders
}
break
}
}
dropHeaders, injectHeaders, auth := ws.getHeadersAndAuth(config, request)

newReq := CloneExistingRequest(CloneRequest{
Request: request.HttpRequest,
Expand Down Expand Up @@ -235,8 +212,210 @@ func (ws *WiretapService) handleHttpRequest(request *model.Request) {
_, _ = request.HttpResponseWriter.Write(body)
}

var gorillaDropHeaders = []string{
// Gorilla fills in the following headers, and complains if they are already present
"Upgrade",
"Connection",
"Sec-Websocket-Key",
"Sec-Websocket-Version",
"Sec-Websocket-Protocol",
"Sec-Websocket-Extensions",
}

func (ws *WiretapService) handleWebsocketRequest(request *model.Request) {

configStore, _ := ws.controlsStore.Get(shared.ConfigKey)
config := configStore.(*shared.WiretapConfiguration)

// Get the Websocket Configuration
websocketUrl := request.HttpRequest.URL.String()
websocketConfig, ok := config.WebsocketConfigs[websocketUrl]
if !ok {
ws.config.Logger.Error(fmt.Sprintf("Unable to find websocket config for URL: %s", websocketUrl))
}

// There's nothing to do if we're in mock mode
if config.MockMode {
return
}

upgrader := websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}

// Upgrade the connection from the client to open a websocket connection
clientConn, err := upgrader.Upgrade(request.HttpResponseWriter, request.HttpRequest, nil)
if err != nil {
ws.config.Logger.Error("Unable to upgrade websocket connection")
return
}
defer func(clientConn *websocket.Conn) {
_ = clientConn.Close()
}(clientConn)

if config.Headers == nil || len(config.Headers.DropHeaders) == 0 {
config.Headers = &shared.WiretapHeaderConfig{
DropHeaders: []string{},
}
}

// Get the updated headers and auth
dropHeaders, injectHeaders, auth := ws.getHeadersAndAuth(config, request)

dropHeaders = append(dropHeaders, gorillaDropHeaders...)
dropHeaders = append(dropHeaders, websocketConfig.DropHeaders...)

// Determine the correct websocket protocol based on redirect protocol
var protocol string
if config.RedirectProtocol == "https" {
protocol = "wss"
} else if config.RedirectProtocol == "http" {
protocol = "ws"
} else if config.RedirectProtocol != "wss" && config.RedirectProtocol != "ws" {
config.Logger.Error(fmt.Sprintf("Unsupported Redirect Protocol: %s", config.RedirectProtocol))
return
}

// Create a new request, which fills in the URL and other information
newRequest := CloneExistingRequest(CloneRequest{
Request: request.HttpRequest,
Protocol: protocol,
Host: config.RedirectHost,
BasePath: config.RedirectBasePath,
Port: config.RedirectPort,
DropHeaders: dropHeaders,
InjectHeaders: injectHeaders,
Auth: auth,
Variables: config.CompiledVariables,
})

// Open a new websocket connection with the server
dialer := *websocket.DefaultDialer
dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: !*websocketConfig.VerifyCert}
serverConn, _, err := dialer.Dial(newRequest.URL.String(), newRequest.Header)
if err != nil {
ws.config.Logger.Error(fmt.Sprintf("Unable to connect to remote server; websocket connection failed: %s", err))
return
}
defer func(serverConn *websocket.Conn) {
_ = serverConn.Close()
}(serverConn)

// Create sentinel channels
clientSentinel := make(chan struct{})
serverSentinel := make(chan struct{})

// Go-Routine for communication between Client -> Server
go func() {
defer close(clientSentinel)

for {
messageType, message, err := clientConn.ReadMessage()
if err != nil {
closeCode, isUnexpected := getCloseCode(err)
logWebsocketClose(config, closeCode, isUnexpected)
_ = clientConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
return
}

err = serverConn.WriteMessage(messageType, message)
if err != nil {
closeCode, isUnexpected := getCloseCode(err)
logWebsocketClose(config, closeCode, isUnexpected)
return
}
}
}()

// Go-Routine for communication between Server -> Client
go func() {
defer close(serverSentinel)

for {
messageType, message, err := serverConn.ReadMessage()
if err != nil {
closeCode, isUnexpected := getCloseCode(err)
logWebsocketClose(config, closeCode, isUnexpected)
_ = clientConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
return
}

err = clientConn.WriteMessage(messageType, message)
if err != nil {
closeCode, isUnexpected := getCloseCode(err)
logWebsocketClose(config, closeCode, isUnexpected)
return
}
}
}()

// Loop until at least one of our sentinel channels have been closed
for {
select {
case <-clientSentinel:
return
case <-serverSentinel:
return
}
}
}

func setCORSHeaders(headers map[string][]string) {
headers["Access-Control-Allow-Headers"] = []string{"*"}
headers["Access-Control-Allow-Origin"] = []string{"*"}
headers["Access-Control-Allow-Methods"] = []string{"OPTIONS,POST,GET,DELETE,PATCH,PUT"}
}

func getCloseCode(err error) (int, bool) {
unexpectedClose := websocket.IsUnexpectedCloseError(err,
websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived,
websocket.CloseAbnormalClosure,
)

if ce, ok := err.(*websocket.CloseError); ok {
return ce.Code, unexpectedClose
}
return -1, unexpectedClose
}

func logWebsocketClose(config *shared.WiretapConfiguration, closeCode int, isUnexpected bool) {
if isUnexpected {
config.Logger.Warn(fmt.Sprintf("Websocket closed unexepectedly with code: %d", closeCode))
} else {
config.Logger.Info(fmt.Sprintf("Websocket closed expectedly with code: %d", closeCode))
}
}

func (ws *WiretapService) getHeadersAndAuth(config *shared.WiretapConfiguration, request *model.Request) ([]string, map[string]string, string) {
var dropHeaders []string
var injectHeaders map[string]string

// add global headers with injection.
if config.Headers != nil {
dropHeaders = config.Headers.DropHeaders
injectHeaders = config.Headers.InjectHeaders
}

// now add path specific headers.
matchedPaths := configModel.FindPaths(request.HttpRequest.URL.Path, config)
auth := ""
if len(matchedPaths) > 0 {
for _, path := range matchedPaths {
auth = path.Auth
if path.Headers != nil {
dropHeaders = append(dropHeaders, path.Headers.DropHeaders...)
newInjectHeaders := path.Headers.InjectHeaders
for key := range injectHeaders {
newInjectHeaders[key] = injectHeaders[key]
}
injectHeaders = newInjectHeaders
}
break
}
}

return dropHeaders, injectHeaders, auth
}
5 changes: 4 additions & 1 deletion daemon/wiretap_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ func (ws *WiretapService) HandleServiceRequest(request *model.Request, core serv
}

func (ws *WiretapService) HandleHttpRequest(request *model.Request) {

ws.handleHttpRequest(request)
}

func (ws *WiretapService) HandleWebsocketRequest(request *model.Request) {
ws.handleWebsocketRequest(request)
}
Loading

0 comments on commit 561212a

Please sign in to comment.