Skip to content

Commit

Permalink
Trying to make the auth elements configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
kellrott committed Jan 3, 2025
1 parent e201659 commit d6626b2
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 76 deletions.
69 changes: 39 additions & 30 deletions gripgraphql/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"net/http"
"os"
"strconv"
"sync"

//"encoding/json"
Expand Down Expand Up @@ -44,6 +45,7 @@ type GraphQLJS struct {
gjHandler *handler.Handler
Pool sync.Pool
cw *JSClientWrapper
auth bool
//Once sync.Once
}

Expand Down Expand Up @@ -151,7 +153,7 @@ func (e *Endpoint) Add(x map[string]any) {
objField, err := parseField(name, schemaA)
if err == nil {
objField.Resolve = func(params graphql.ResolveParams) (interface{}, error) {
log.Infof("Calling resolver \n")
log.Debug("Calling resolver")
uArgs := map[string]any{}
for k, v := range defaults {
uArgs[k] = v
Expand All @@ -174,9 +176,10 @@ func (e *Endpoint) Add(x map[string]any) {
args := goja.FunctionCall{
Arguments: []goja.Value{e.cw.toValue(), vArgs},
}

log.Infof("Calling user function")
val := jHandler(args)
out := jsExport(val)
log.Infof("User function returned : %#v", out)
return out, nil
}

Expand Down Expand Up @@ -257,6 +260,10 @@ func NewHTTPHandler(client gripql.Client, config map[string]string) (http.Handle
if c, ok := config["graph"]; ok {
graph = c
}
auth := false
if c, ok := config["auth"]; ok {
auth, _ = strconv.ParseBool(c)
}
file, err := os.Open(configPath)
if err != nil {
return nil, err
Expand All @@ -272,7 +279,7 @@ func NewHTTPHandler(client gripql.Client, config map[string]string) (http.Handle
New: func() any {
vm := goja.New()
vm.SetFieldNameMapper(JSRenamer{})
jsClient, err := GetJSClient(graph, client, vm)
jsClient, err := GetJSClient(graph, client, vm, auth)
if err != nil {
log.Infof("js error: %s\n", err)
}
Expand All @@ -286,9 +293,11 @@ func NewHTTPHandler(client gripql.Client, config map[string]string) (http.Handle
"Boolean": "Boolean",
})

vm.Set("print", fmt.Printf) //Adding print statement for debugging. This may need to be removed/updated

_, err = vm.RunString(string(data))
if err != nil {
log.Errorf("Error running data config", err)
log.Errorf("Error running data config %s", err)
}

schema, err := e.Build()
Expand All @@ -298,15 +307,15 @@ func NewHTTPHandler(client gripql.Client, config map[string]string) (http.Handle
hnd = handler.New(&handler.Config{
Schema: schema,
})
gh := &GraphQLJS{client: client, gjHandler: hnd, cw: jsClient}
gh := &GraphQLJS{client: client, gjHandler: hnd, cw: jsClient, auth: auth}
return gh
},
}
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
log.Infof("Getting graph handler from Sync Pool +++++++++++++++++++++++++++++++++++++++++++++++++++++")
log.Debug("Getting graph handler from Sync Pool +++++++++++++++++++++++++++++++++++++++++++++++++++++")
gh := Pool.Get().(*GraphQLJS)
defer func() {
log.Infof("Putting graph handler back to Pool ---------------------------------------------------")
log.Debug("Putting graph handler back to Pool ---------------------------------------------------")
Pool.Put(gh)
}()
gh.ServeHTTP(writer, request)
Expand All @@ -322,35 +331,35 @@ func (gh *GraphQLJS) ServeHTTP(writer http.ResponseWriter, request *http.Request
if request.URL.Path == "/api" || request.URL.Path == "api" {
requestHeaders := request.Header
ctx := context.WithValue(context.Background(), "Header", requestHeaders)

var jwtHandler middleware.JWTHandler = &middleware.ProdJWTHandler{}
if gh.cw.graph == "TEST" {
jwtHandler = &middleware.MockJWTHandler{}
}
//fmt.Println("REQUEST HEADERS:::: +++++++++++++++++++", requestHeaders)
if val, ok := requestHeaders["Authorization"]; ok {
Token := val[0]
resourceList, err := jwtHandler.HandleJWTToken(Token, "read")
//resourceList := []any{"/programs/cbds/projects/demo", "/programs/cbds/projects/welcome", "/programs/synthea/projects/test"}
if err != nil {
middleware.HandleError(err, writer)
return err
if gh.auth {
var jwtHandler middleware.JWTHandler = &middleware.ProdJWTHandler{}
if gh.cw.graph == "TEST" {
jwtHandler = &middleware.MockJWTHandler{}
}
//fmt.Println("REQUEST HEADERS:::: +++++++++++++++++++", requestHeaders)
if val, ok := requestHeaders["Authorization"]; ok {
Token := val[0]
resourceList, err := jwtHandler.HandleJWTToken(Token, "read")
//resourceList := []any{"/programs/cbds/projects/demo", "/programs/cbds/projects/welcome", "/programs/synthea/projects/test"}
if err != nil {
middleware.HandleError(err, writer)
return err
}

if len(resourceList) == 0 || err != nil {
if len(resourceList) == 0 {
err = &middleware.ServerError{StatusCode: http.StatusForbidden, Message: "User does not have access to any projects"}
if len(resourceList) == 0 || err != nil {
if len(resourceList) == 0 {
err = &middleware.ServerError{StatusCode: http.StatusForbidden, Message: "User does not have access to any projects"}
}
middleware.HandleError(err, writer)
return err
}
middleware.HandleError(err, writer)
ctx = context.WithValue(ctx, "ResourceList", resourceList)
} else {
err := middleware.HandleError(&middleware.ServerError{StatusCode: http.StatusUnauthorized, Message: "No authorization header provided."}, writer)
log.Infoln("ERR: ", err)
return err
}
ctx = context.WithValue(ctx, "ResourceList", resourceList)
} else {
err := middleware.HandleError(&middleware.ServerError{StatusCode: http.StatusUnauthorized, Message: "No authorization header provided."}, writer)
log.Infoln("ERR: ", err)
return err
}

gh.gjHandler.ServeHTTP(writer, request.WithContext(ctx))
}

Expand Down
106 changes: 60 additions & 46 deletions gripgraphql/js_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ type JSClientWrapper struct {
client gripql.Client
query goja.Callable
graph string

auth bool
}

type JSRenamer struct{}
Expand Down Expand Up @@ -77,62 +79,73 @@ func (cw *JSClientWrapper) ToList(args goja.Value) goja.Value {
log.Infof("Error: %s\n", err)
return nil
}
ResourceList := cw.vm.Get("ResourceList").Export()
Header := cw.vm.Get("Header").Export().(any)
ctx := context.WithValue(context.Background(), "Header", Header)
ctx = context.WithValue(ctx, "ResourceList", ResourceList)

query := gripql.GraphQuery{}
err = protojson.Unmarshal(queryJSON, &query)

sValue, _ := structpb.NewValue(ResourceList)
Has_Statement := &gripql.GraphStatement{Statement: &gripql.GraphStatement_Has{
Has: &gripql.HasExpression{Expression: &gripql.HasExpression_Condition{
Condition: &gripql.HasCondition{
Condition: gripql.Condition_WITHIN,
Key: "auth_resource_path",
Value: sValue,
},
}},
}}

query.Graph = cw.graph
steps := inspect.PipelineSteps(query.Query)

var ctx context.Context
FilteredGS, CachedGS, RemainingGS := []*gripql.GraphStatement{}, []*gripql.GraphStatement{}, []*gripql.GraphStatement{}
for i, v := range query.Query {
steps_index, err := strconv.Atoi(steps[i])

if cw.auth {
ResourceList := cw.vm.Get("ResourceList").Export()
Header := cw.vm.Get("Header").Export().(any)
ctx := context.WithValue(context.Background(), "Header", Header)
ctx = context.WithValue(ctx, "ResourceList", ResourceList)

sValue, _ := structpb.NewValue(ResourceList)
Has_Statement := &gripql.GraphStatement{Statement: &gripql.GraphStatement_Has{
Has: &gripql.HasExpression{Expression: &gripql.HasExpression_Condition{
Condition: &gripql.HasCondition{
Condition: gripql.Condition_WITHIN,
Key: "auth_resource_path",
Value: sValue,
},
}},
}}
steps := inspect.PipelineSteps(query.Query)
for i, v := range query.Query {
steps_index, err := strconv.Atoi(steps[i])
if err != nil {
log.Infof("Error: %s\n", err)
return nil
}
if i > steps_index {
RemainingGS = append(RemainingGS, v)
}

if i == steps_index {
FilteredGS = append(FilteredGS, v, Has_Statement)
CachedGS = append(CachedGS, v, Has_Statement)
} else {
if i == 0 {
CachedGS = append(CachedGS, v)
}
FilteredGS = append(FilteredGS, v)
}
}
query.Query = FilteredGS
} else {
ctx = context.Background()
}

/*
log.Infof("Getting cached job")
resultChan, err := cw.GetCachedJob(query, CachedGS, RemainingGS)
if err != nil {
log.Infof("Error: %s\n", err)
return nil
}
if i > steps_index {
RemainingGS = append(RemainingGS, v)
}

if i == steps_index {
FilteredGS = append(FilteredGS, v, Has_Statement)
CachedGS = append(CachedGS, v, Has_Statement)
} else {
if i == 0 {
CachedGS = append(CachedGS, v)
if resultChan != nil {
cachedOut := []any{}
for row := range resultChan {
cachedOut = append(cachedOut, cw.vm.ToValue(toInterface(row)))
}
FilteredGS = append(FilteredGS, v)
return cw.vm.ToValue(cachedOut)
}
}
*/

query.Query = FilteredGS
resultChan, err := cw.GetCachedJob(query, CachedGS, RemainingGS)
if err != nil {
log.Infof("Error: %s\n", err)
return nil
}
if resultChan != nil {
cachedOut := []any{}
for row := range resultChan {
cachedOut = append(cachedOut, cw.vm.ToValue(toInterface(row)))
}
return cw.vm.ToValue(cachedOut)
}
log.Infof("Doing traversal")

res, err := cw.client.Traversal(ctx, &query)
if err != nil {
Expand All @@ -144,6 +157,7 @@ func (cw *JSClientWrapper) ToList(args goja.Value) goja.Value {
for row := range res {
out = append(out, cw.vm.ToValue(toInterface(row)))
}
//log.Infof("Returning value: %s\n", out)

return cw.vm.ToValue(out)
}
Expand All @@ -166,13 +180,13 @@ func (cw *JSClientWrapper) toValue() goja.Value {
return cw.vm.ToValue(cw)
}

func GetJSClient(graph string, client gripql.Client, vm *goja.Runtime) (*JSClientWrapper, error) { // ctx context.Context
func GetJSClient(graph string, client gripql.Client, vm *goja.Runtime, auth bool) (*JSClientWrapper, error) { // ctx context.Context
gripqljs, _ := gripqljs.Asset("gripql.js")
vm.RunString(string(gripqljs))

qVal := vm.Get("query")
query, _ := goja.AssertFunction(qVal)

myWrapper := &JSClientWrapper{vm, client, query, graph}
myWrapper := &JSClientWrapper{vm, client, query, graph, auth}
return myWrapper, nil
}

0 comments on commit d6626b2

Please sign in to comment.