Skip to content

Commit

Permalink
Fix X-Id-Token auth header forwarding (#1007)
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanahuckova authored Oct 2, 2024
1 parent ec0fb40 commit 9d6533c
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 21 deletions.
5 changes: 5 additions & 0 deletions .changeset/great-paws-marry.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'grafana-infinity-datasource': patch
---

Fix forward oauth for x-id-token header
28 changes: 19 additions & 9 deletions pkg/infinity/headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ const (
const (
headerKeyAccept = "Accept"
headerKeyContentType = "Content-Type"
headerKeyAuthorization = "Authorization"
headerKeyIdToken = "X-ID-Token"
HeaderKeyAuthorization = "Authorization"
HeaderKeyIdToken = "X-Id-Token"
)

func ApplyAcceptHeader(query models.Query, settings models.InfinitySettings, req *http.Request, includeSect bool) *http.Request {
Expand Down Expand Up @@ -103,7 +103,7 @@ func ApplyBasicAuth(settings models.InfinitySettings, req *http.Request, include
if includeSect {
basicAuthHeader = "Basic " + base64.StdEncoding.EncodeToString([]byte(settings.UserName+":"+settings.Password))
}
req.Header.Set(headerKeyAuthorization, basicAuthHeader)
req.Header.Set(HeaderKeyAuthorization, basicAuthHeader)
}
return req
}
Expand All @@ -114,7 +114,7 @@ func ApplyBearerToken(settings models.InfinitySettings, req *http.Request, inclu
if includeSect {
bearerAuthHeader = fmt.Sprintf("Bearer %s", settings.BearerToken)
}
req.Header.Add(headerKeyAuthorization, bearerAuthHeader)
req.Header.Add(HeaderKeyAuthorization, bearerAuthHeader)
}
return req
}
Expand All @@ -137,13 +137,23 @@ func ApplyForwardedOAuthIdentity(requestHeaders map[string]string, settings mode
authHeader := dummyHeader
token := dummyHeader
if includeSect {
authHeader = requestHeaders[headerKeyAuthorization]
token = requestHeaders[headerKeyIdToken]
authHeader = getQueryReqHeader(requestHeaders, HeaderKeyAuthorization)
token = getQueryReqHeader(requestHeaders, HeaderKeyIdToken)
}
req.Header.Add(headerKeyAuthorization, authHeader)
if requestHeaders[headerKeyIdToken] != "" {
req.Header.Add(headerKeyIdToken, token)
req.Header.Add(HeaderKeyAuthorization, authHeader)
if token != "" && token != dummyHeader {
req.Header.Add(HeaderKeyIdToken, token)
}
}
return req
}

func getQueryReqHeader(requestHeaders map[string]string, headerName string) string {
for name, value := range requestHeaders {
if strings.EqualFold(headerName, name) {
return value
}
}

return ""
}
65 changes: 65 additions & 0 deletions pkg/infinity/headers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package infinity

import (
"strings"
"testing"
)

func TestGetQueryReqHeader(t *testing.T) {
tests := []struct {
name string
requestHeaders map[string]string
headerName string
expected string
}{
{
name: "Authorization header exact match",
requestHeaders: map[string]string{
HeaderKeyAuthorization: "Bearer token",
},
headerName: HeaderKeyAuthorization,
expected: "Bearer token",
},
{
name: "Authorization header case insensitive match",
requestHeaders: map[string]string{
strings.ToLower(HeaderKeyAuthorization): "Bearer token",
},
headerName: HeaderKeyAuthorization,
expected: "Bearer token",
},
{
name: "X-Id-Token header exact match",
requestHeaders: map[string]string{
HeaderKeyIdToken: "some-id-token",
},
headerName: HeaderKeyIdToken,
expected: "some-id-token",
},
{
name: "X-Id-Token header case insensitive match",
requestHeaders: map[string]string{
strings.ToLower(HeaderKeyIdToken): "some-id-token",
},
headerName: HeaderKeyIdToken,
expected: "some-id-token",
},
{
name: "X-Id-Token header case with ID capitalization",
requestHeaders: map[string]string{
"X-ID-Token": "some-id-token",
},
headerName: HeaderKeyIdToken,
expected: "some-id-token",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := getQueryReqHeader(tt.requestHeaders, tt.headerName)
if got != tt.expected {
t.Errorf("getQueryReqHeader() = %v, expected %v", got, tt.expected)
}
})
}
}
2 changes: 1 addition & 1 deletion pkg/infinity/meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func ApplyNotices(ctx context.Context, settings models.InfinitySettings, frame *
func GetSecureHeaderWarnings(query models.Query) []data.Notice {
notices := []data.Notice{}
for _, h := range query.URLOptions.Headers {
if strings.EqualFold(h.Key, headerKeyAuthorization) {
if strings.EqualFold(h.Key, HeaderKeyAuthorization) {
notices = append(notices, data.Notice{
Severity: data.NoticeSeverityWarning,
Text: fmt.Sprintf("for security reasons, don't include headers such as %s in the query. Instead, add them in the config where possible", h.Key),
Expand Down
22 changes: 11 additions & 11 deletions pkg/testsuite/handler_querydata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ func TestAuthentication(t *testing.T) {
t.Run("should set basic auth headers when set the username and password", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, "Basic "+base64.StdEncoding.EncodeToString([]byte("infinityUser:myPassword")), r.Header.Get("Authorization"))
assert.Equal(t, "", r.Header.Get("X-ID-Token"))
assert.Equal(t, "Basic "+base64.StdEncoding.EncodeToString([]byte("infinityUser:myPassword")), r.Header.Get(infinity.HeaderKeyAuthorization))
assert.Equal(t, "", r.Header.Get(infinity.HeaderKeyIdToken))
fmt.Fprintf(w, `{ "message" : "OK" }`)
}))
defer server.Close()
Expand Down Expand Up @@ -76,8 +76,8 @@ func TestAuthentication(t *testing.T) {
t.Run("should return error when incorrect credentials set", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, "", r.Header.Get("X-ID-Token"))
if r.Header.Get("Authorization") == "Basic "+base64.StdEncoding.EncodeToString([]byte("infinityUser:myPassword")) {
assert.Equal(t, "", r.Header.Get(infinity.HeaderKeyIdToken))
if r.Header.Get(infinity.HeaderKeyAuthorization) == "Basic "+base64.StdEncoding.EncodeToString([]byte("infinityUser:myPassword")) {
fmt.Fprintf(w, "OK")
return
}
Expand Down Expand Up @@ -113,8 +113,8 @@ func TestAuthentication(t *testing.T) {
t.Run("should forward the oauth headers when forward oauth identity is set", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, "foo", r.Header.Get("Authorization"))
assert.Equal(t, "bar", r.Header.Get("X-ID-Token"))
assert.Equal(t, "foo", r.Header.Get(infinity.HeaderKeyAuthorization))
assert.Equal(t, "bar", r.Header.Get(infinity.HeaderKeyIdToken))
fmt.Fprintf(w, `{ "message" : "OK" }`)
}))
defer server.Close()
Expand All @@ -137,8 +137,8 @@ func TestAuthentication(t *testing.T) {
t.Run("should not forward the oauth headers when forward oauth identity is not set", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, "", r.Header.Get("Authorization"))
assert.Equal(t, "", r.Header.Get("X-ID-Token"))
assert.Equal(t, "", r.Header.Get(infinity.HeaderKeyAuthorization))
assert.Equal(t, "", r.Header.Get(infinity.HeaderKeyIdToken))
fmt.Fprintf(w, `{ "message" : "OK" }`)
}))
defer server.Close()
Expand Down Expand Up @@ -207,7 +207,7 @@ func TestAuthentication(t *testing.T) {
w.WriteHeader(http.StatusUnauthorized)
return
}
if r.Header.Get("Authorization") != "Bearer foo" {
if r.Header.Get(infinity.HeaderKeyAuthorization) != "Bearer foo" {
w.WriteHeader(http.StatusUnauthorized)
return
}
Expand Down Expand Up @@ -249,7 +249,7 @@ func TestAuthentication(t *testing.T) {
t.Run("should error when CA cert verification failed", func(t *testing.T) {
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, "", r.Header.Get("X-ID-Token"))
assert.Equal(t, "", r.Header.Get(infinity.HeaderKeyIdToken))
fmt.Fprintf(w, `{ "message" : "OK" }`)
}))
server.TLS = getServerCertificate(server.URL)
Expand Down Expand Up @@ -281,7 +281,7 @@ func TestAuthentication(t *testing.T) {
t.Run("should honour skip tls verify setting", func(t *testing.T) {
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, "", r.Header.Get("X-ID-Token"))
assert.Equal(t, "", r.Header.Get(infinity.HeaderKeyIdToken))
fmt.Fprintf(w, `{ "message" : "OK" }`)
}))
server.TLS = getServerCertificate(server.URL)
Expand Down

0 comments on commit 9d6533c

Please sign in to comment.