Skip to content

Commit

Permalink
feat: introduce OAuth2 authentication passthrough for trusted ESP ser…
Browse files Browse the repository at this point in the history
…vers in Viya
  • Loading branch information
amlmtl committed Nov 17, 2023
1 parent d704aaa commit e4dd2da
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 78 deletions.
5 changes: 4 additions & 1 deletion internal/esp/client/espwsclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,11 @@ const (
const jsonFormat string = "json"
const cborFormat string = "cbor"

func New(wsConnectionUrl url.URL) *EspWsClient {
func New(wsConnectionUrl url.URL, authorizationHeader *string) *EspWsClient {
socket := gowebsocket.New(wsConnectionUrl.String())
if authorizationHeader != nil {
socket.RequestHeader.Set("Authorization", *authorizationHeader)
}

espWsClient := EspWsClient{
socket: &socket,
Expand Down
32 changes: 17 additions & 15 deletions internal/plugin/query/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,26 @@ import (
const CHANNEL_PATH_REGEX_PATTERN string = `^[A-z0-9_\-/=.]*$`

type Query struct {
ServerUrl url.URL
ProjectName string
CqName string
WindowName string
Fields []string
EventInterval uint64
MaxEvents uint64
ServerUrl url.URL
ProjectName string
CqName string
WindowName string
Fields []string
EventInterval uint64
MaxEvents uint64
AuthorizationHeader *string
}

func New(s server.Server, projectName string, cqName string, windowName string, interval uint64, maxEvents uint64, fields []string) *Query {
func New(s server.Server, projectName string, cqName string, windowName string, interval uint64, maxEvents uint64, fields []string, authorizationHeader *string) *Query {
return &Query{
ServerUrl: s.GetUrl(),
ProjectName: projectName,
CqName: cqName,
WindowName: windowName,
EventInterval: interval,
MaxEvents: maxEvents,
Fields: fields,
ServerUrl: s.GetUrl(),
ProjectName: projectName,
CqName: cqName,
WindowName: windowName,
EventInterval: interval,
MaxEvents: maxEvents,
Fields: fields,
AuthorizationHeader: authorizationHeader,
}
}

Expand Down
186 changes: 124 additions & 62 deletions pkg/plugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,20 @@ var (

// NewSampleDatasource creates a new datasource instance.
func NewSampleDatasource(ctx context.Context, settings backend.DataSourceInstanceSettings) (instancemgmt.Instance, error) {
discoveryUrl, err := url.Parse(settings.URL)
if err != nil {
return nil, err
}

opts, err := settings.HTTPClientOptions(ctx)
if err != nil {
return nil, fmt.Errorf("http client options: %w", err)
return nil, err
}
opts.ForwardHTTPHeaders = true

cl, err := httpclient.New(opts)
if err != nil {
return nil, fmt.Errorf("httpclient new: %w", err)
return nil, err
}

var jsonData datasourceJsonData
Expand All @@ -68,18 +73,22 @@ func NewSampleDatasource(ctx context.Context, settings backend.DataSourceInstanc
log.DefaultLogger.Debug(fmt.Sprintf("created data source with ForwardHTTPHeaders option set to: %v", opts.ForwardHTTPHeaders))

return &SampleDatasource{
channelQueryMap: syncmap.New[string, query.Query](),
httpClient: cl,
jsonData: jsonData,
httpClient: cl,
discoveryEndpointUrl: *discoveryUrl,
jsonData: jsonData,
channelQueryMap: syncmap.New[string, query.Query](),
serverUrlTrustedMap: syncmap.New[string, bool](),
}, nil
}

// SampleDatasource is an example datasource which can respond to data queries, reports
// its health and has streaming skills.
type SampleDatasource struct {
channelQueryMap *syncmap.SyncMap[string, query.Query]
httpClient *http.Client
jsonData datasourceJsonData
channelQueryMap *syncmap.SyncMap[string, query.Query]
httpClient *http.Client
jsonData datasourceJsonData
serverUrlTrustedMap *syncmap.SyncMap[string, bool]
discoveryEndpointUrl url.URL
}

type datasourceJsonData struct {
Expand All @@ -100,30 +109,45 @@ func (d *SampleDatasource) Dispose() {
// The QueryDataResponse contains a map of RefID to the response for each query, and each response
// contains Frames ([]*Frame).
func (d *SampleDatasource) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
// create response struct
response := backend.NewQueryDataResponse()

var jsonData datasourceJsonData
err := json.Unmarshal(req.PluginContext.DataSourceInstanceSettings.JSONData, &jsonData)
if err != nil {
return nil, err
}
var authorizationHeaderPtr *string = nil
if jsonData.OauthPassThru && jsonData.IsViya {
authorizationHeader := req.GetHTTPHeader(backend.OAuthIdentityTokenHeaderName)
authorizationHeaderPtr = &authorizationHeader
}

for _, q := range req.Queries {
res := d.query(ctx, req.PluginContext.DataSourceInstanceSettings.UID, q.JSON)
var qdto querydto.QueryDTO
err := json.Unmarshal(q.JSON, &qdto)
if err != nil {
response.Responses[q.RefID] = handleQueryError("invalid query", err)
continue
}

var authHeaderToBePassed *string = nil
if authorizationHeaderPtr != nil && d.isServerUrlTrusted(qdto.ServerUrl, true, authorizationHeaderPtr) {
authHeaderToBePassed = authorizationHeaderPtr
}

response.Responses[q.RefID] = res
response.Responses[q.RefID] = d.query(ctx, req.PluginContext.DataSourceInstanceSettings.UID, qdto, authHeaderToBePassed)
}

return response, nil
}

func (d *SampleDatasource) query(_ context.Context, datasourceUid string, queryJson json.RawMessage) backend.DataResponse {
var qdto querydto.QueryDTO
err := json.Unmarshal(queryJson, &qdto)
if err != nil {
return handleQueryError("invalid query", err)
}
func (d *SampleDatasource) query(_ context.Context, datasourceUid string, qdto querydto.QueryDTO, authorizationHeader *string) backend.DataResponse {
s, err := server.FromUrlString(qdto.ServerUrl)
if err != nil {
return handleQueryError("invalid server URL", err)
}

q := query.New(*s, qdto.ProjectName, qdto.CqName, qdto.WindowName, qdto.Interval, qdto.MaxDataPoints, qdto.Fields)
q := query.New(*s, qdto.ProjectName, qdto.CqName, qdto.WindowName, qdto.Interval, qdto.MaxDataPoints, qdto.Fields, authorizationHeader)

channelPath, err := q.ToChannelPath()
if err != nil {
Expand Down Expand Up @@ -248,7 +272,7 @@ func (d *SampleDatasource) RunStream(ctx context.Context, req *backend.RunStream
return nil
}

espWsClient := client.New(q.ServerUrl)
espWsClient := client.New(q.ServerUrl, q.AuthorizationHeader)
defer espWsClient.Close()

subscribeToQuery := func() {
Expand Down Expand Up @@ -352,83 +376,121 @@ func newSerializedCallResourceResponseErrorBody(errorMessage string) []byte {
return errorResponseBody
}

func (d *SampleDatasource) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
var response = &backend.CallResourceResponse{
Status: http.StatusNotFound,
Headers: map[string][]string{},
}
type discoveredServer struct {
Url string `json:"url"`
Trusted bool `json:"trusted"`
}

func (d *SampleDatasource) CallResource(_ context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
var response backend.CallResourceResponse
switch req.Path {
case "servers":
discoveryServiceUrl, err := url.Parse(req.PluginContext.DataSourceInstanceSettings.URL)
if err != nil {
return err
var authHeaderPtr *string
if d.jsonData.OauthPassThru == true {
authHeader := req.GetHTTPHeader(backend.OAuthIdentityTokenHeaderName)
authHeaderPtr = &authHeader
}

var discoveryResponse *http.Response
discoveryResponse, err = callDiscoveryEndpoint(ctx, d.httpClient, *discoveryServiceUrl)
serversData, discoveredServers, err := d.fetchDiscoverableServers(authHeaderPtr)
if err != nil {
errorMessage := "Unable to obtain ESP server schema information."
response.Body = newSerializedCallResourceResponseErrorBody(errorMessage)
response.Status = http.StatusInternalServerError
log.DefaultLogger.Error(errorMessage, "error", err)
log.DefaultLogger.Error(err.Error())
body := newSerializedCallResourceResponseErrorBody("Unable to fetch discoverable ESP servers.")
response := &backend.CallResourceResponse{
Status: http.StatusInternalServerError,
Body: body,
}
return sender.Send(response)
}

servers, err := io.ReadAll(discoveryResponse.Body)
if err != nil {
errorMessage := "Unable to read discovery response."
response.Body = newSerializedCallResourceResponseErrorBody(errorMessage)
response.Status = http.StatusInternalServerError
log.DefaultLogger.Error(errorMessage, "error", err)
return sender.Send(response)
for _, discoveredServer := range *discoveredServers {
d.serverUrlTrustedMap.Set(discoveredServer.Url, &discoveredServer.Trusted)
}

responseBody, err := json.Marshal(callResourceResponseBody{Data: servers})
responseBody, err := json.Marshal(callResourceResponseBody{Data: *serversData})
if err != nil {
errorMessage := "Unable to serialize discovery response."
response.Body = newSerializedCallResourceResponseErrorBody(errorMessage)
response.Status = http.StatusInternalServerError
response := &backend.CallResourceResponse{
Status: http.StatusInternalServerError,
Body: newSerializedCallResourceResponseErrorBody(errorMessage),
}
log.DefaultLogger.Error(errorMessage, "error", err)
return sender.Send(response)
}

response.Status = http.StatusOK
response.Body = responseBody
response = backend.CallResourceResponse{
Status: http.StatusOK,
Body: responseBody,
}
return sender.Send(&response)
default:
response = backend.CallResourceResponse{
Status: http.StatusNotFound,
}
break
}

return sender.Send(response)
return sender.Send(&response)
}

func callDiscoveryEndpoint(ctx context.Context, httpClient *http.Client, discoveryServiceUrl url.URL) (*http.Response, error) {
var discoveryEndpointUrl = discoveryServiceUrl.String() + "/grafana/discovery"
func (d *SampleDatasource) fetchDiscoverableServers(authHeader *string) (*[]byte, *[]discoveredServer, error) {
var discoveryEndpointUrl = d.discoveryEndpointUrl.String() + "/grafana/discovery"
log.DefaultLogger.Debug("Calling discovery endpoint", "discoveryEndpointUrl", discoveryEndpointUrl)

request, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryEndpointUrl, nil)
request, err := http.NewRequest(http.MethodGet, discoveryEndpointUrl, nil)
if err != nil {
log.DefaultLogger.Error("Unable to create discovery request.", "error", err)
return nil, err
return nil, nil, err
}

if request.Header.Get("Accept-Encoding") == "" {
// Grafana (as of version 9.x) contains a bug where it will enable compression by adding an explicit Accept-Encoding header by default if missing,
// but not properly handle decompression of any compressed response.
// The header set explicitly here as a workaround for this bug.
request.Header.Set("Accept-Encoding", "identity")
if authHeader != nil {
request.Header.Set(backend.OAuthIdentityTokenHeaderName, *authHeader)
}

resp, err := httpClient.Do(request)
resp, err := d.httpClient.Do(request)
if err != nil {
log.DefaultLogger.Error("Unable to receive discovery response.", "error", err)
return nil, err
log.DefaultLogger.Error(err.Error())
return nil, nil, fmt.Errorf("unable to receive discovery response")
}

if resp.StatusCode != 200 {
var err = fmt.Errorf("the discovery service sent an unexpected HTTP status code: %d", resp.StatusCode)
return nil, err
return nil, nil, fmt.Errorf("received unexpected HTTP status code: %d", resp.StatusCode)
}

serversData, err := io.ReadAll(resp.Body)
if err != nil {
log.DefaultLogger.Error(err.Error())
return nil, nil, fmt.Errorf("unable to read discovery response")
}

var discoveredServers []discoveredServer
err = json.Unmarshal(serversData, &discoveredServers)
if err != nil {
log.DefaultLogger.Error(err.Error())
return &serversData, nil, fmt.Errorf("unable to unmarshal discovery response")
}

return &serversData, &discoveredServers, nil
}

func (d *SampleDatasource) isServerUrlTrusted(url string, fetchIfMissing bool, authHeader *string) bool {
isServerUrlTrusted, err := d.serverUrlTrustedMap.Get(url)
if err == nil {
return *isServerUrlTrusted
}

if fetchIfMissing {
_, discoveredServers, fetchErr := d.fetchDiscoverableServers(authHeader)
if fetchErr != nil {
log.DefaultLogger.Error("Unable to fetch trusted status of server URL", "url", url, "error", err)
return false
}

for _, discoveredServer := range *discoveredServers {
d.serverUrlTrustedMap.Set(discoveredServer.Url, &discoveredServer.Trusted)
}

return d.isServerUrlTrusted(url, false, nil)
}

return resp, nil
log.DefaultLogger.Error("Unable to determine trusted status of server URL", "url", url, "error", err)
return false
}

0 comments on commit e4dd2da

Please sign in to comment.