Skip to content

Commit

Permalink
Add support for relationship filtering on watch API
Browse files Browse the repository at this point in the history
  • Loading branch information
josephschorr committed Mar 13, 2024
1 parent 8f28907 commit 7dd68a1
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 5 deletions.
94 changes: 89 additions & 5 deletions internal/commands/watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package commands

import (
"context"
"fmt"
"os"
"os/signal"
"strings"
"syscall"
"time"

Expand All @@ -15,9 +17,10 @@ import (
)

var (
watchObjectTypes []string
watchRevision string
watchTimestamps bool
watchObjectTypes []string
watchRevision string
watchTimestamps bool
watchRelationshipFilters []string
)

func RegisterWatchCmd(rootCmd *cobra.Command) *cobra.Command {
Expand All @@ -34,6 +37,7 @@ func RegisterWatchRelationshipCmd(parentCmd *cobra.Command) *cobra.Command {
watchRelationshipsCmd.Flags().StringSliceVar(&watchObjectTypes, "object_types", nil, "optional object types to watch updates for")
watchRelationshipsCmd.Flags().StringVar(&watchRevision, "revision", "", "optional revision at which to start watching")
watchRelationshipsCmd.Flags().BoolVar(&watchTimestamps, "timestamp", false, "shows timestamp of incoming update events")
watchRelationshipsCmd.Flags().StringSliceVar(&watchRelationshipFilters, "filter", nil, "optional filter(s) for the watch stream. Example: `optional_resource_type:optional_resource_id_or_prefix#optional_relation@optional_subject_filter`")
return watchRelationshipsCmd
}

Expand All @@ -53,15 +57,25 @@ var watchRelationshipsCmd = &cobra.Command{
}

func watchCmdFunc(cmd *cobra.Command, _ []string) error {
console.Errorf("starting watch stream over types %v and revision %v\n", watchObjectTypes, watchRevision)
console.Printf("starting watch stream over types %v and revision %v\n", watchObjectTypes, watchRevision)

cli, err := client.NewClient(cmd)
if err != nil {
return err
}

relFilters := make([]*v1.RelationshipFilter, 0, len(watchRelationshipFilters))
for _, filter := range watchRelationshipFilters {
relFilter, err := parseRelationshipFilter(filter)
if err != nil {
return err
}
relFilters = append(relFilters, relFilter)
}

req := &v1.WatchRequest{
OptionalObjectTypes: watchObjectTypes,
OptionalObjectTypes: watchObjectTypes,
OptionalRelationshipFilters: relFilters,
}
if watchRevision != "" {
req.OptionalStartCursor = &v1.ZedToken{Token: watchRevision}
Expand Down Expand Up @@ -102,3 +116,73 @@ func watchCmdFunc(cmd *cobra.Command, _ []string) error {
}
}
}

func parseRelationshipFilter(relFilterStr string) (*v1.RelationshipFilter, error) {
relFilter := &v1.RelationshipFilter{}
pieces := strings.Split(relFilterStr, "@")
if len(pieces) > 2 {
return nil, fmt.Errorf("invalid relationship filter: %s", relFilterStr)
}

if len(pieces) == 2 {
subjectFilter, err := parseSubjectFilter(pieces[1])
if err != nil {
return nil, err
}
relFilter.OptionalSubjectFilter = subjectFilter
}

if len(pieces) > 0 {
resourcePieces := strings.Split(pieces[0], "#")
if len(resourcePieces) > 2 {
return nil, fmt.Errorf("invalid relationship filter: %s", relFilterStr)
}

if len(resourcePieces) == 2 {
relFilter.OptionalRelation = resourcePieces[1]
}

resourceTypePieces := strings.Split(resourcePieces[0], ":")
if len(resourceTypePieces) > 2 {
return nil, fmt.Errorf("invalid relationship filter: %s", relFilterStr)
}

relFilter.ResourceType = resourceTypePieces[0]
if len(resourceTypePieces) == 2 {
optionalResourceIDOrPrefix := resourceTypePieces[1]
if strings.HasSuffix(optionalResourceIDOrPrefix, "%") {
relFilter.OptionalResourceIdPrefix = strings.TrimSuffix(optionalResourceIDOrPrefix, "%")
} else {
relFilter.OptionalResourceId = optionalResourceIDOrPrefix
}
}
}

return relFilter, nil
}

func parseSubjectFilter(subjectFilterStr string) (*v1.SubjectFilter, error) {
subjectFilter := &v1.SubjectFilter{}
pieces := strings.Split(subjectFilterStr, "#")
if len(pieces) > 2 {
return nil, fmt.Errorf("invalid subject filter: %s", subjectFilterStr)
}

subjectTypePieces := strings.Split(pieces[0], ":")
if len(subjectTypePieces) > 2 {
return nil, fmt.Errorf("invalid subject filter: %s", subjectFilterStr)
}

subjectFilter.SubjectType = subjectTypePieces[0]
if len(subjectTypePieces) == 2 {
subjectFilter.OptionalSubjectId = subjectTypePieces[1]
}

if len(pieces) == 2 {
subjectFilter.OptionalRelation = &v1.SubjectFilter_RelationFilter{
Relation: pieces[1],
}
}

return subjectFilter, nil
}
110 changes: 110 additions & 0 deletions internal/commands/watch_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package commands

import (
"reflect"
"testing"

v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
)

func TestParseRelationshipFilter(t *testing.T) {
tcs := []struct {
input string
expected *v1.RelationshipFilter
}{
{
input: "resourceType:resourceId",
expected: &v1.RelationshipFilter{
ResourceType: "resourceType",
OptionalResourceId: "resourceId",
},
},
{
input: "resourceType:resourceId%",
expected: &v1.RelationshipFilter{
ResourceType: "resourceType",
OptionalResourceIdPrefix: "resourceId",
},
},
{
input: "resourceType:resourceId#relation",
expected: &v1.RelationshipFilter{
ResourceType: "resourceType",
OptionalResourceId: "resourceId",
OptionalRelation: "relation",
},
},
{
input: "resourceType:resourceId#relation@subjectType:subjectId",
expected: &v1.RelationshipFilter{
ResourceType: "resourceType",
OptionalResourceId: "resourceId",
OptionalRelation: "relation",
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: "subjectType",
OptionalSubjectId: "subjectId",
},
},
},
{
input: "#relation",
expected: &v1.RelationshipFilter{
OptionalRelation: "relation",
},
},
{
input: "resourceType#relation",
expected: &v1.RelationshipFilter{
ResourceType: "resourceType",
OptionalRelation: "relation",
},
},
{
input: ":resourceId#relation",
expected: &v1.RelationshipFilter{
OptionalResourceId: "resourceId",
OptionalRelation: "relation",
},
},
{
input: ":resourceId%#relation",
expected: &v1.RelationshipFilter{
OptionalResourceIdPrefix: "resourceId",
OptionalRelation: "relation",
},
},
{
input: "resourceType:resourceId#relation@subjectType:subjectId#somerel",
expected: &v1.RelationshipFilter{
ResourceType: "resourceType",
OptionalResourceId: "resourceId",
OptionalRelation: "relation",
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: "subjectType",
OptionalSubjectId: "subjectId",
OptionalRelation: &v1.SubjectFilter_RelationFilter{Relation: "somerel"},
},
},
},
{
input: "@subjectType:subjectId#somerel",
expected: &v1.RelationshipFilter{
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: "subjectType",
OptionalSubjectId: "subjectId",
OptionalRelation: &v1.SubjectFilter_RelationFilter{Relation: "somerel"},
},
},
},
}

for _, tc := range tcs {
actual, err := parseRelationshipFilter(tc.input)
if err != nil {
t.Errorf("parseRelationshipFilter(%s) returned error: %v", tc.input, err)
}
if !reflect.DeepEqual(actual, tc.expected) {
t.Errorf("parseRelationshipFilter(%s) = %v, expected %v", tc.input, actual, tc.expected)
}
}
}

0 comments on commit 7dd68a1

Please sign in to comment.