From e4dd2da37948bfd0dbb9d898d38b8857969e0548 Mon Sep 17 00:00:00 2001 From: amlmtl <> Date: Fri, 17 Nov 2023 14:26:08 +0000 Subject: [PATCH] feat: introduce OAuth2 authentication passthrough for trusted ESP servers in Viya --- internal/esp/client/espwsclient.go | 5 +- internal/plugin/query/query.go | 32 ++--- pkg/plugin/plugin.go | 186 +++++++++++++++++++---------- 3 files changed, 145 insertions(+), 78 deletions(-) diff --git a/internal/esp/client/espwsclient.go b/internal/esp/client/espwsclient.go index e676d17..52bcbd5 100644 --- a/internal/esp/client/espwsclient.go +++ b/internal/esp/client/espwsclient.go @@ -61,8 +61,11 @@ const ( const jsonFormat string = "json" const cborFormat string = "cbor" -func New(wsConnectionUrl url.URL) *EspWsClient { +func New(wsConnectionUrl url.URL, authorizationHeader *string) *EspWsClient { socket := gowebsocket.New(wsConnectionUrl.String()) + if authorizationHeader != nil { + socket.RequestHeader.Set("Authorization", *authorizationHeader) + } espWsClient := EspWsClient{ socket: &socket, diff --git a/internal/plugin/query/query.go b/internal/plugin/query/query.go index 020ab94..bdad643 100644 --- a/internal/plugin/query/query.go +++ b/internal/plugin/query/query.go @@ -18,24 +18,26 @@ import ( const CHANNEL_PATH_REGEX_PATTERN string = `^[A-z0-9_\-/=.]*$` type Query struct { - ServerUrl url.URL - ProjectName string - CqName string - WindowName string - Fields []string - EventInterval uint64 - MaxEvents uint64 + ServerUrl url.URL + ProjectName string + CqName string + WindowName string + Fields []string + EventInterval uint64 + MaxEvents uint64 + AuthorizationHeader *string } -func New(s server.Server, projectName string, cqName string, windowName string, interval uint64, maxEvents uint64, fields []string) *Query { +func New(s server.Server, projectName string, cqName string, windowName string, interval uint64, maxEvents uint64, fields []string, authorizationHeader *string) *Query { return &Query{ - ServerUrl: s.GetUrl(), - ProjectName: projectName, - CqName: cqName, - WindowName: windowName, - EventInterval: interval, - MaxEvents: maxEvents, - Fields: fields, + ServerUrl: s.GetUrl(), + ProjectName: projectName, + CqName: cqName, + WindowName: windowName, + EventInterval: interval, + MaxEvents: maxEvents, + Fields: fields, + AuthorizationHeader: authorizationHeader, } } diff --git a/pkg/plugin/plugin.go b/pkg/plugin/plugin.go index f18a7ff..2441b40 100644 --- a/pkg/plugin/plugin.go +++ b/pkg/plugin/plugin.go @@ -48,15 +48,20 @@ var ( // NewSampleDatasource creates a new datasource instance. func NewSampleDatasource(ctx context.Context, settings backend.DataSourceInstanceSettings) (instancemgmt.Instance, error) { + discoveryUrl, err := url.Parse(settings.URL) + if err != nil { + return nil, err + } + opts, err := settings.HTTPClientOptions(ctx) if err != nil { - return nil, fmt.Errorf("http client options: %w", err) + return nil, err } opts.ForwardHTTPHeaders = true cl, err := httpclient.New(opts) if err != nil { - return nil, fmt.Errorf("httpclient new: %w", err) + return nil, err } var jsonData datasourceJsonData @@ -68,18 +73,22 @@ func NewSampleDatasource(ctx context.Context, settings backend.DataSourceInstanc log.DefaultLogger.Debug(fmt.Sprintf("created data source with ForwardHTTPHeaders option set to: %v", opts.ForwardHTTPHeaders)) return &SampleDatasource{ - channelQueryMap: syncmap.New[string, query.Query](), - httpClient: cl, - jsonData: jsonData, + httpClient: cl, + discoveryEndpointUrl: *discoveryUrl, + jsonData: jsonData, + channelQueryMap: syncmap.New[string, query.Query](), + serverUrlTrustedMap: syncmap.New[string, bool](), }, nil } // SampleDatasource is an example datasource which can respond to data queries, reports // its health and has streaming skills. type SampleDatasource struct { - channelQueryMap *syncmap.SyncMap[string, query.Query] - httpClient *http.Client - jsonData datasourceJsonData + channelQueryMap *syncmap.SyncMap[string, query.Query] + httpClient *http.Client + jsonData datasourceJsonData + serverUrlTrustedMap *syncmap.SyncMap[string, bool] + discoveryEndpointUrl url.URL } type datasourceJsonData struct { @@ -100,30 +109,45 @@ func (d *SampleDatasource) Dispose() { // The QueryDataResponse contains a map of RefID to the response for each query, and each response // contains Frames ([]*Frame). func (d *SampleDatasource) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { - // create response struct response := backend.NewQueryDataResponse() + var jsonData datasourceJsonData + err := json.Unmarshal(req.PluginContext.DataSourceInstanceSettings.JSONData, &jsonData) + if err != nil { + return nil, err + } + var authorizationHeaderPtr *string = nil + if jsonData.OauthPassThru && jsonData.IsViya { + authorizationHeader := req.GetHTTPHeader(backend.OAuthIdentityTokenHeaderName) + authorizationHeaderPtr = &authorizationHeader + } + for _, q := range req.Queries { - res := d.query(ctx, req.PluginContext.DataSourceInstanceSettings.UID, q.JSON) + var qdto querydto.QueryDTO + err := json.Unmarshal(q.JSON, &qdto) + if err != nil { + response.Responses[q.RefID] = handleQueryError("invalid query", err) + continue + } + + var authHeaderToBePassed *string = nil + if authorizationHeaderPtr != nil && d.isServerUrlTrusted(qdto.ServerUrl, true, authorizationHeaderPtr) { + authHeaderToBePassed = authorizationHeaderPtr + } - response.Responses[q.RefID] = res + response.Responses[q.RefID] = d.query(ctx, req.PluginContext.DataSourceInstanceSettings.UID, qdto, authHeaderToBePassed) } return response, nil } -func (d *SampleDatasource) query(_ context.Context, datasourceUid string, queryJson json.RawMessage) backend.DataResponse { - var qdto querydto.QueryDTO - err := json.Unmarshal(queryJson, &qdto) - if err != nil { - return handleQueryError("invalid query", err) - } +func (d *SampleDatasource) query(_ context.Context, datasourceUid string, qdto querydto.QueryDTO, authorizationHeader *string) backend.DataResponse { s, err := server.FromUrlString(qdto.ServerUrl) if err != nil { return handleQueryError("invalid server URL", err) } - q := query.New(*s, qdto.ProjectName, qdto.CqName, qdto.WindowName, qdto.Interval, qdto.MaxDataPoints, qdto.Fields) + q := query.New(*s, qdto.ProjectName, qdto.CqName, qdto.WindowName, qdto.Interval, qdto.MaxDataPoints, qdto.Fields, authorizationHeader) channelPath, err := q.ToChannelPath() if err != nil { @@ -248,7 +272,7 @@ func (d *SampleDatasource) RunStream(ctx context.Context, req *backend.RunStream return nil } - espWsClient := client.New(q.ServerUrl) + espWsClient := client.New(q.ServerUrl, q.AuthorizationHeader) defer espWsClient.Close() subscribeToQuery := func() { @@ -352,83 +376,121 @@ func newSerializedCallResourceResponseErrorBody(errorMessage string) []byte { return errorResponseBody } -func (d *SampleDatasource) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { - var response = &backend.CallResourceResponse{ - Status: http.StatusNotFound, - Headers: map[string][]string{}, - } +type discoveredServer struct { + Url string `json:"url"` + Trusted bool `json:"trusted"` +} +func (d *SampleDatasource) CallResource(_ context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { + var response backend.CallResourceResponse switch req.Path { case "servers": - discoveryServiceUrl, err := url.Parse(req.PluginContext.DataSourceInstanceSettings.URL) - if err != nil { - return err + var authHeaderPtr *string + if d.jsonData.OauthPassThru == true { + authHeader := req.GetHTTPHeader(backend.OAuthIdentityTokenHeaderName) + authHeaderPtr = &authHeader } - - var discoveryResponse *http.Response - discoveryResponse, err = callDiscoveryEndpoint(ctx, d.httpClient, *discoveryServiceUrl) + serversData, discoveredServers, err := d.fetchDiscoverableServers(authHeaderPtr) if err != nil { - errorMessage := "Unable to obtain ESP server schema information." - response.Body = newSerializedCallResourceResponseErrorBody(errorMessage) - response.Status = http.StatusInternalServerError - log.DefaultLogger.Error(errorMessage, "error", err) + log.DefaultLogger.Error(err.Error()) + body := newSerializedCallResourceResponseErrorBody("Unable to fetch discoverable ESP servers.") + response := &backend.CallResourceResponse{ + Status: http.StatusInternalServerError, + Body: body, + } return sender.Send(response) } - servers, err := io.ReadAll(discoveryResponse.Body) - if err != nil { - errorMessage := "Unable to read discovery response." - response.Body = newSerializedCallResourceResponseErrorBody(errorMessage) - response.Status = http.StatusInternalServerError - log.DefaultLogger.Error(errorMessage, "error", err) - return sender.Send(response) + for _, discoveredServer := range *discoveredServers { + d.serverUrlTrustedMap.Set(discoveredServer.Url, &discoveredServer.Trusted) } - responseBody, err := json.Marshal(callResourceResponseBody{Data: servers}) + responseBody, err := json.Marshal(callResourceResponseBody{Data: *serversData}) if err != nil { errorMessage := "Unable to serialize discovery response." - response.Body = newSerializedCallResourceResponseErrorBody(errorMessage) - response.Status = http.StatusInternalServerError + response := &backend.CallResourceResponse{ + Status: http.StatusInternalServerError, + Body: newSerializedCallResourceResponseErrorBody(errorMessage), + } log.DefaultLogger.Error(errorMessage, "error", err) return sender.Send(response) } - response.Status = http.StatusOK - response.Body = responseBody + response = backend.CallResourceResponse{ + Status: http.StatusOK, + Body: responseBody, + } + return sender.Send(&response) default: + response = backend.CallResourceResponse{ + Status: http.StatusNotFound, + } break } - return sender.Send(response) + return sender.Send(&response) } -func callDiscoveryEndpoint(ctx context.Context, httpClient *http.Client, discoveryServiceUrl url.URL) (*http.Response, error) { - var discoveryEndpointUrl = discoveryServiceUrl.String() + "/grafana/discovery" +func (d *SampleDatasource) fetchDiscoverableServers(authHeader *string) (*[]byte, *[]discoveredServer, error) { + var discoveryEndpointUrl = d.discoveryEndpointUrl.String() + "/grafana/discovery" log.DefaultLogger.Debug("Calling discovery endpoint", "discoveryEndpointUrl", discoveryEndpointUrl) - request, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryEndpointUrl, nil) + request, err := http.NewRequest(http.MethodGet, discoveryEndpointUrl, nil) if err != nil { log.DefaultLogger.Error("Unable to create discovery request.", "error", err) - return nil, err + return nil, nil, err } - if request.Header.Get("Accept-Encoding") == "" { - // Grafana (as of version 9.x) contains a bug where it will enable compression by adding an explicit Accept-Encoding header by default if missing, - // but not properly handle decompression of any compressed response. - // The header set explicitly here as a workaround for this bug. - request.Header.Set("Accept-Encoding", "identity") + if authHeader != nil { + request.Header.Set(backend.OAuthIdentityTokenHeaderName, *authHeader) } - resp, err := httpClient.Do(request) + resp, err := d.httpClient.Do(request) if err != nil { - log.DefaultLogger.Error("Unable to receive discovery response.", "error", err) - return nil, err + log.DefaultLogger.Error(err.Error()) + return nil, nil, fmt.Errorf("unable to receive discovery response") } if resp.StatusCode != 200 { - var err = fmt.Errorf("the discovery service sent an unexpected HTTP status code: %d", resp.StatusCode) - return nil, err + return nil, nil, fmt.Errorf("received unexpected HTTP status code: %d", resp.StatusCode) + } + + serversData, err := io.ReadAll(resp.Body) + if err != nil { + log.DefaultLogger.Error(err.Error()) + return nil, nil, fmt.Errorf("unable to read discovery response") + } + + var discoveredServers []discoveredServer + err = json.Unmarshal(serversData, &discoveredServers) + if err != nil { + log.DefaultLogger.Error(err.Error()) + return &serversData, nil, fmt.Errorf("unable to unmarshal discovery response") + } + + return &serversData, &discoveredServers, nil +} + +func (d *SampleDatasource) isServerUrlTrusted(url string, fetchIfMissing bool, authHeader *string) bool { + isServerUrlTrusted, err := d.serverUrlTrustedMap.Get(url) + if err == nil { + return *isServerUrlTrusted + } + + if fetchIfMissing { + _, discoveredServers, fetchErr := d.fetchDiscoverableServers(authHeader) + if fetchErr != nil { + log.DefaultLogger.Error("Unable to fetch trusted status of server URL", "url", url, "error", err) + return false + } + + for _, discoveredServer := range *discoveredServers { + d.serverUrlTrustedMap.Set(discoveredServer.Url, &discoveredServer.Trusted) + } + + return d.isServerUrlTrusted(url, false, nil) } - return resp, nil + log.DefaultLogger.Error("Unable to determine trusted status of server URL", "url", url, "error", err) + return false }