Skip to content

Commit e4dd2da

Browse files
author
amlmtl
committed
feat: introduce OAuth2 authentication passthrough for trusted ESP servers in Viya
1 parent d704aaa commit e4dd2da

File tree

3 files changed

+145
-78
lines changed

3 files changed

+145
-78
lines changed

internal/esp/client/espwsclient.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,11 @@ const (
6161
const jsonFormat string = "json"
6262
const cborFormat string = "cbor"
6363

64-
func New(wsConnectionUrl url.URL) *EspWsClient {
64+
func New(wsConnectionUrl url.URL, authorizationHeader *string) *EspWsClient {
6565
socket := gowebsocket.New(wsConnectionUrl.String())
66+
if authorizationHeader != nil {
67+
socket.RequestHeader.Set("Authorization", *authorizationHeader)
68+
}
6669

6770
espWsClient := EspWsClient{
6871
socket: &socket,

internal/plugin/query/query.go

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,26 @@ import (
1818
const CHANNEL_PATH_REGEX_PATTERN string = `^[A-z0-9_\-/=.]*$`
1919

2020
type Query struct {
21-
ServerUrl url.URL
22-
ProjectName string
23-
CqName string
24-
WindowName string
25-
Fields []string
26-
EventInterval uint64
27-
MaxEvents uint64
21+
ServerUrl url.URL
22+
ProjectName string
23+
CqName string
24+
WindowName string
25+
Fields []string
26+
EventInterval uint64
27+
MaxEvents uint64
28+
AuthorizationHeader *string
2829
}
2930

30-
func New(s server.Server, projectName string, cqName string, windowName string, interval uint64, maxEvents uint64, fields []string) *Query {
31+
func New(s server.Server, projectName string, cqName string, windowName string, interval uint64, maxEvents uint64, fields []string, authorizationHeader *string) *Query {
3132
return &Query{
32-
ServerUrl: s.GetUrl(),
33-
ProjectName: projectName,
34-
CqName: cqName,
35-
WindowName: windowName,
36-
EventInterval: interval,
37-
MaxEvents: maxEvents,
38-
Fields: fields,
33+
ServerUrl: s.GetUrl(),
34+
ProjectName: projectName,
35+
CqName: cqName,
36+
WindowName: windowName,
37+
EventInterval: interval,
38+
MaxEvents: maxEvents,
39+
Fields: fields,
40+
AuthorizationHeader: authorizationHeader,
3941
}
4042
}
4143

pkg/plugin/plugin.go

Lines changed: 124 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,20 @@ var (
4848

4949
// NewSampleDatasource creates a new datasource instance.
5050
func NewSampleDatasource(ctx context.Context, settings backend.DataSourceInstanceSettings) (instancemgmt.Instance, error) {
51+
discoveryUrl, err := url.Parse(settings.URL)
52+
if err != nil {
53+
return nil, err
54+
}
55+
5156
opts, err := settings.HTTPClientOptions(ctx)
5257
if err != nil {
53-
return nil, fmt.Errorf("http client options: %w", err)
58+
return nil, err
5459
}
5560
opts.ForwardHTTPHeaders = true
5661

5762
cl, err := httpclient.New(opts)
5863
if err != nil {
59-
return nil, fmt.Errorf("httpclient new: %w", err)
64+
return nil, err
6065
}
6166

6267
var jsonData datasourceJsonData
@@ -68,18 +73,22 @@ func NewSampleDatasource(ctx context.Context, settings backend.DataSourceInstanc
6873
log.DefaultLogger.Debug(fmt.Sprintf("created data source with ForwardHTTPHeaders option set to: %v", opts.ForwardHTTPHeaders))
6974

7075
return &SampleDatasource{
71-
channelQueryMap: syncmap.New[string, query.Query](),
72-
httpClient: cl,
73-
jsonData: jsonData,
76+
httpClient: cl,
77+
discoveryEndpointUrl: *discoveryUrl,
78+
jsonData: jsonData,
79+
channelQueryMap: syncmap.New[string, query.Query](),
80+
serverUrlTrustedMap: syncmap.New[string, bool](),
7481
}, nil
7582
}
7683

7784
// SampleDatasource is an example datasource which can respond to data queries, reports
7885
// its health and has streaming skills.
7986
type SampleDatasource struct {
80-
channelQueryMap *syncmap.SyncMap[string, query.Query]
81-
httpClient *http.Client
82-
jsonData datasourceJsonData
87+
channelQueryMap *syncmap.SyncMap[string, query.Query]
88+
httpClient *http.Client
89+
jsonData datasourceJsonData
90+
serverUrlTrustedMap *syncmap.SyncMap[string, bool]
91+
discoveryEndpointUrl url.URL
8392
}
8493

8594
type datasourceJsonData struct {
@@ -100,30 +109,45 @@ func (d *SampleDatasource) Dispose() {
100109
// The QueryDataResponse contains a map of RefID to the response for each query, and each response
101110
// contains Frames ([]*Frame).
102111
func (d *SampleDatasource) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
103-
// create response struct
104112
response := backend.NewQueryDataResponse()
105113

114+
var jsonData datasourceJsonData
115+
err := json.Unmarshal(req.PluginContext.DataSourceInstanceSettings.JSONData, &jsonData)
116+
if err != nil {
117+
return nil, err
118+
}
119+
var authorizationHeaderPtr *string = nil
120+
if jsonData.OauthPassThru && jsonData.IsViya {
121+
authorizationHeader := req.GetHTTPHeader(backend.OAuthIdentityTokenHeaderName)
122+
authorizationHeaderPtr = &authorizationHeader
123+
}
124+
106125
for _, q := range req.Queries {
107-
res := d.query(ctx, req.PluginContext.DataSourceInstanceSettings.UID, q.JSON)
126+
var qdto querydto.QueryDTO
127+
err := json.Unmarshal(q.JSON, &qdto)
128+
if err != nil {
129+
response.Responses[q.RefID] = handleQueryError("invalid query", err)
130+
continue
131+
}
132+
133+
var authHeaderToBePassed *string = nil
134+
if authorizationHeaderPtr != nil && d.isServerUrlTrusted(qdto.ServerUrl, true, authorizationHeaderPtr) {
135+
authHeaderToBePassed = authorizationHeaderPtr
136+
}
108137

109-
response.Responses[q.RefID] = res
138+
response.Responses[q.RefID] = d.query(ctx, req.PluginContext.DataSourceInstanceSettings.UID, qdto, authHeaderToBePassed)
110139
}
111140

112141
return response, nil
113142
}
114143

115-
func (d *SampleDatasource) query(_ context.Context, datasourceUid string, queryJson json.RawMessage) backend.DataResponse {
116-
var qdto querydto.QueryDTO
117-
err := json.Unmarshal(queryJson, &qdto)
118-
if err != nil {
119-
return handleQueryError("invalid query", err)
120-
}
144+
func (d *SampleDatasource) query(_ context.Context, datasourceUid string, qdto querydto.QueryDTO, authorizationHeader *string) backend.DataResponse {
121145
s, err := server.FromUrlString(qdto.ServerUrl)
122146
if err != nil {
123147
return handleQueryError("invalid server URL", err)
124148
}
125149

126-
q := query.New(*s, qdto.ProjectName, qdto.CqName, qdto.WindowName, qdto.Interval, qdto.MaxDataPoints, qdto.Fields)
150+
q := query.New(*s, qdto.ProjectName, qdto.CqName, qdto.WindowName, qdto.Interval, qdto.MaxDataPoints, qdto.Fields, authorizationHeader)
127151

128152
channelPath, err := q.ToChannelPath()
129153
if err != nil {
@@ -248,7 +272,7 @@ func (d *SampleDatasource) RunStream(ctx context.Context, req *backend.RunStream
248272
return nil
249273
}
250274

251-
espWsClient := client.New(q.ServerUrl)
275+
espWsClient := client.New(q.ServerUrl, q.AuthorizationHeader)
252276
defer espWsClient.Close()
253277

254278
subscribeToQuery := func() {
@@ -352,83 +376,121 @@ func newSerializedCallResourceResponseErrorBody(errorMessage string) []byte {
352376
return errorResponseBody
353377
}
354378

355-
func (d *SampleDatasource) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
356-
var response = &backend.CallResourceResponse{
357-
Status: http.StatusNotFound,
358-
Headers: map[string][]string{},
359-
}
379+
type discoveredServer struct {
380+
Url string `json:"url"`
381+
Trusted bool `json:"trusted"`
382+
}
360383

384+
func (d *SampleDatasource) CallResource(_ context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
385+
var response backend.CallResourceResponse
361386
switch req.Path {
362387
case "servers":
363-
discoveryServiceUrl, err := url.Parse(req.PluginContext.DataSourceInstanceSettings.URL)
364-
if err != nil {
365-
return err
388+
var authHeaderPtr *string
389+
if d.jsonData.OauthPassThru == true {
390+
authHeader := req.GetHTTPHeader(backend.OAuthIdentityTokenHeaderName)
391+
authHeaderPtr = &authHeader
366392
}
367-
368-
var discoveryResponse *http.Response
369-
discoveryResponse, err = callDiscoveryEndpoint(ctx, d.httpClient, *discoveryServiceUrl)
393+
serversData, discoveredServers, err := d.fetchDiscoverableServers(authHeaderPtr)
370394
if err != nil {
371-
errorMessage := "Unable to obtain ESP server schema information."
372-
response.Body = newSerializedCallResourceResponseErrorBody(errorMessage)
373-
response.Status = http.StatusInternalServerError
374-
log.DefaultLogger.Error(errorMessage, "error", err)
395+
log.DefaultLogger.Error(err.Error())
396+
body := newSerializedCallResourceResponseErrorBody("Unable to fetch discoverable ESP servers.")
397+
response := &backend.CallResourceResponse{
398+
Status: http.StatusInternalServerError,
399+
Body: body,
400+
}
375401
return sender.Send(response)
376402
}
377403

378-
servers, err := io.ReadAll(discoveryResponse.Body)
379-
if err != nil {
380-
errorMessage := "Unable to read discovery response."
381-
response.Body = newSerializedCallResourceResponseErrorBody(errorMessage)
382-
response.Status = http.StatusInternalServerError
383-
log.DefaultLogger.Error(errorMessage, "error", err)
384-
return sender.Send(response)
404+
for _, discoveredServer := range *discoveredServers {
405+
d.serverUrlTrustedMap.Set(discoveredServer.Url, &discoveredServer.Trusted)
385406
}
386407

387-
responseBody, err := json.Marshal(callResourceResponseBody{Data: servers})
408+
responseBody, err := json.Marshal(callResourceResponseBody{Data: *serversData})
388409
if err != nil {
389410
errorMessage := "Unable to serialize discovery response."
390-
response.Body = newSerializedCallResourceResponseErrorBody(errorMessage)
391-
response.Status = http.StatusInternalServerError
411+
response := &backend.CallResourceResponse{
412+
Status: http.StatusInternalServerError,
413+
Body: newSerializedCallResourceResponseErrorBody(errorMessage),
414+
}
392415
log.DefaultLogger.Error(errorMessage, "error", err)
393416
return sender.Send(response)
394417
}
395418

396-
response.Status = http.StatusOK
397-
response.Body = responseBody
419+
response = backend.CallResourceResponse{
420+
Status: http.StatusOK,
421+
Body: responseBody,
422+
}
423+
return sender.Send(&response)
398424
default:
425+
response = backend.CallResourceResponse{
426+
Status: http.StatusNotFound,
427+
}
399428
break
400429
}
401430

402-
return sender.Send(response)
431+
return sender.Send(&response)
403432
}
404433

405-
func callDiscoveryEndpoint(ctx context.Context, httpClient *http.Client, discoveryServiceUrl url.URL) (*http.Response, error) {
406-
var discoveryEndpointUrl = discoveryServiceUrl.String() + "/grafana/discovery"
434+
func (d *SampleDatasource) fetchDiscoverableServers(authHeader *string) (*[]byte, *[]discoveredServer, error) {
435+
var discoveryEndpointUrl = d.discoveryEndpointUrl.String() + "/grafana/discovery"
407436
log.DefaultLogger.Debug("Calling discovery endpoint", "discoveryEndpointUrl", discoveryEndpointUrl)
408437

409-
request, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryEndpointUrl, nil)
438+
request, err := http.NewRequest(http.MethodGet, discoveryEndpointUrl, nil)
410439
if err != nil {
411440
log.DefaultLogger.Error("Unable to create discovery request.", "error", err)
412-
return nil, err
441+
return nil, nil, err
413442
}
414443

415-
if request.Header.Get("Accept-Encoding") == "" {
416-
// Grafana (as of version 9.x) contains a bug where it will enable compression by adding an explicit Accept-Encoding header by default if missing,
417-
// but not properly handle decompression of any compressed response.
418-
// The header set explicitly here as a workaround for this bug.
419-
request.Header.Set("Accept-Encoding", "identity")
444+
if authHeader != nil {
445+
request.Header.Set(backend.OAuthIdentityTokenHeaderName, *authHeader)
420446
}
421447

422-
resp, err := httpClient.Do(request)
448+
resp, err := d.httpClient.Do(request)
423449
if err != nil {
424-
log.DefaultLogger.Error("Unable to receive discovery response.", "error", err)
425-
return nil, err
450+
log.DefaultLogger.Error(err.Error())
451+
return nil, nil, fmt.Errorf("unable to receive discovery response")
426452
}
427453

428454
if resp.StatusCode != 200 {
429-
var err = fmt.Errorf("the discovery service sent an unexpected HTTP status code: %d", resp.StatusCode)
430-
return nil, err
455+
return nil, nil, fmt.Errorf("received unexpected HTTP status code: %d", resp.StatusCode)
456+
}
457+
458+
serversData, err := io.ReadAll(resp.Body)
459+
if err != nil {
460+
log.DefaultLogger.Error(err.Error())
461+
return nil, nil, fmt.Errorf("unable to read discovery response")
462+
}
463+
464+
var discoveredServers []discoveredServer
465+
err = json.Unmarshal(serversData, &discoveredServers)
466+
if err != nil {
467+
log.DefaultLogger.Error(err.Error())
468+
return &serversData, nil, fmt.Errorf("unable to unmarshal discovery response")
469+
}
470+
471+
return &serversData, &discoveredServers, nil
472+
}
473+
474+
func (d *SampleDatasource) isServerUrlTrusted(url string, fetchIfMissing bool, authHeader *string) bool {
475+
isServerUrlTrusted, err := d.serverUrlTrustedMap.Get(url)
476+
if err == nil {
477+
return *isServerUrlTrusted
478+
}
479+
480+
if fetchIfMissing {
481+
_, discoveredServers, fetchErr := d.fetchDiscoverableServers(authHeader)
482+
if fetchErr != nil {
483+
log.DefaultLogger.Error("Unable to fetch trusted status of server URL", "url", url, "error", err)
484+
return false
485+
}
486+
487+
for _, discoveredServer := range *discoveredServers {
488+
d.serverUrlTrustedMap.Set(discoveredServer.Url, &discoveredServer.Trusted)
489+
}
490+
491+
return d.isServerUrlTrusted(url, false, nil)
431492
}
432493

433-
return resp, nil
494+
log.DefaultLogger.Error("Unable to determine trusted status of server URL", "url", url, "error", err)
495+
return false
434496
}

0 commit comments

Comments
 (0)