From 220fd60be128dbfea508fd7fc555fa5114bc75fa Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Tue, 20 Feb 2024 12:48:10 -0500 Subject: [PATCH 1/3] Fix and remove downcasts --- internal/datasets/subjectset_test.go | 8 +- internal/datastore/crdb/pool/balancer.go | 13 +-- internal/datastore/crdb/pool/balancer_test.go | 2 +- internal/services/shared/schema.go | 4 +- internal/services/v1/errors.go | 24 ++--- internal/services/v1/experimental.go | 8 +- internal/services/v1/experimental_test.go | 2 +- internal/services/v1/relationships.go | 23 +++-- internal/services/v1/schema.go | 16 +++- pkg/caveats/types/registration.go | 20 ++++- pkg/genutil/ensure.go | 28 ++++++ pkg/genutil/ensure_test.go | 88 +++++++++++++++++++ pkg/tuple/onrset.go | 4 +- pkg/tuple/onrset_test.go | 26 +++--- 14 files changed, 213 insertions(+), 53 deletions(-) create mode 100644 pkg/genutil/ensure.go create mode 100644 pkg/genutil/ensure_test.go diff --git a/internal/datasets/subjectset_test.go b/internal/datasets/subjectset_test.go index 22701d568b..4af11368fb 100644 --- a/internal/datasets/subjectset_test.go +++ b/internal/datasets/subjectset_test.go @@ -2598,13 +2598,13 @@ func TestIntersectConcreteWithWildcard(t *testing.T) { // it counts in binary and "activates" input funcs that match 1s in the binary representation // it doesn't check for overflow so don't go crazy func allSubsets[T any](objs []T, n int) [][]T { - maxInt := uint(math.Exp2(float64(len(objs)))) - 1 + maxInt := uint64(math.Exp2(float64(len(objs)))) - 1 all := make([][]T, 0) - for i := uint(0); i < maxInt; i++ { + for i := uint64(0); i < maxInt; i++ { set := make([]T, 0, n) - for digit := uint(0); digit < uint(len(objs)); digit++ { - mask := uint(1) << digit + for digit := uint64(0); digit < uint64(len(objs)); digit++ { + mask := uint64(1) << digit if mask&i != 0 { set = append(set, objs[digit]) } diff --git a/internal/datastore/crdb/pool/balancer.go b/internal/datastore/crdb/pool/balancer.go index e88ce5c5de..cf825f1d44 100644 --- a/internal/datastore/crdb/pool/balancer.go +++ b/internal/datastore/crdb/pool/balancer.go @@ -15,6 +15,7 @@ import ( "golang.org/x/sync/semaphore" log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/genutil" ) var ( @@ -108,7 +109,7 @@ func (p *nodeConnectionBalancer[P, C]) Prune(ctx context.Context) { case <-p.ticker.C: if p.sem.TryAcquire(1) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - p.pruneConnections(ctx) + p.mustPruneConnections(ctx) cancel() p.sem.Release(1) } @@ -116,10 +117,10 @@ func (p *nodeConnectionBalancer[P, C]) Prune(ctx context.Context) { } } -// pruneConnections prunes connections to nodes that have more than MaxConns/(# of nodes) +// mustPruneConnections prunes connections to nodes that have more than MaxConns/(# of nodes) // This causes the pool to reconnect, which over time will lead to a balanced number of connections // across each node. -func (p *nodeConnectionBalancer[P, C]) pruneConnections(ctx context.Context) { +func (p *nodeConnectionBalancer[P, C]) mustPruneConnections(ctx context.Context) { start := time.Now() defer func() { pruningTimeHistogram.WithLabelValues(p.pool.ID()).Observe(float64(time.Since(start).Milliseconds())) @@ -224,8 +225,10 @@ func (p *nodeConnectionBalancer[P, C]) pruneConnections(ctx context.Context) { if numToPrune > 1 { numToPrune >>= 1 } - if uint32(len(healthyConns[node])) < numToPrune { - numToPrune = uint32(len(healthyConns[node])) + + healthyNodeCount := genutil.MustEnsureUInt32(len(healthyConns[node])) + if healthyNodeCount < numToPrune { + numToPrune = healthyNodeCount } if numToPrune == 0 { continue diff --git a/internal/datastore/crdb/pool/balancer_test.go b/internal/datastore/crdb/pool/balancer_test.go index 5fefdb9ec6..f977b447a7 100644 --- a/internal/datastore/crdb/pool/balancer_test.go +++ b/internal/datastore/crdb/pool/balancer_test.go @@ -162,7 +162,7 @@ func TestNodeConnectionBalancerPrune(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - p.pruneConnections(ctx) + p.mustPruneConnections(ctx) require.Equal(t, len(tt.expectedGC), len(pool.gc)) gcFromNodes := make([]uint32, 0, len(tt.expectedGC)) for _, n := range pool.gc { diff --git a/internal/services/shared/schema.go b/internal/services/shared/schema.go index 92a84b64b4..9e8f99d216 100644 --- a/internal/services/shared/schema.go +++ b/internal/services/shared/schema.go @@ -75,7 +75,7 @@ func ValidateSchemaChanges(ctx context.Context, compiled *compiler.CompiledSchem type AppliedSchemaChanges struct { // TotalOperationCount holds the total number of "dispatch" operations performed by the schema // being applied. - TotalOperationCount uint32 + TotalOperationCount int // NewObjectDefNames contains the names of the newly added object definitions. NewObjectDefNames []string @@ -229,7 +229,7 @@ func ApplySchemaChangesOverExisting( Msg("completed schema update") return &AppliedSchemaChanges{ - TotalOperationCount: uint32(len(validated.compiled.ObjectDefinitions) + len(validated.compiled.CaveatDefinitions) + removedObjectDefNames.Len() + removedCaveatDefNames.Len()), + TotalOperationCount: len(validated.compiled.ObjectDefinitions) + len(validated.compiled.CaveatDefinitions) + removedObjectDefNames.Len() + removedCaveatDefNames.Len(), NewObjectDefNames: validated.newObjectDefNames.Subtract(existingObjectDefNames).AsSlice(), RemovedObjectDefNames: removedObjectDefNames.AsSlice(), NewCaveatDefNames: validated.newCaveatDefNames.Subtract(existingCaveatDefNames).AsSlice(), diff --git a/internal/services/v1/errors.go b/internal/services/v1/errors.go index 3c5eaa81b3..a1f1333605 100644 --- a/internal/services/v1/errors.go +++ b/internal/services/v1/errors.go @@ -18,13 +18,13 @@ import ( // ErrExceedsMaximumUpdates occurs when too many updates are given to a call. type ErrExceedsMaximumUpdates struct { error - updateCount uint16 - maxCountAllowed uint16 + updateCount uint64 + maxCountAllowed uint64 } // MarshalZerologObject implements zerolog object marshalling. func (err ErrExceedsMaximumUpdates) MarshalZerologObject(e *zerolog.Event) { - e.Err(err.error).Uint16("updateCount", err.updateCount).Uint16("maxCountAllowed", err.maxCountAllowed) + e.Err(err.error).Uint64("updateCount", err.updateCount).Uint64("maxCountAllowed", err.maxCountAllowed) } // GRPCStatus implements retrieving the gRPC status for the error. @@ -35,15 +35,15 @@ func (err ErrExceedsMaximumUpdates) GRPCStatus() *status.Status { spiceerrors.ForReason( v1.ErrorReason_ERROR_REASON_TOO_MANY_UPDATES_IN_REQUEST, map[string]string{ - "update_count": strconv.Itoa(int(err.updateCount)), - "maximum_updates_allowed": strconv.Itoa(int(err.maxCountAllowed)), + "update_count": strconv.FormatUint(err.updateCount, 10), + "maximum_updates_allowed": strconv.FormatUint(err.maxCountAllowed, 10), }, ), ) } // NewExceedsMaximumUpdatesErr creates a new error representing that too many updates were given to a WriteRelationships call. -func NewExceedsMaximumUpdatesErr(updateCount uint16, maxCountAllowed uint16) ErrExceedsMaximumUpdates { +func NewExceedsMaximumUpdatesErr(updateCount uint64, maxCountAllowed uint64) ErrExceedsMaximumUpdates { return ErrExceedsMaximumUpdates{ error: fmt.Errorf("update count of %d is greater than maximum allowed of %d", updateCount, maxCountAllowed), updateCount: updateCount, @@ -54,13 +54,13 @@ func NewExceedsMaximumUpdatesErr(updateCount uint16, maxCountAllowed uint16) Err // ErrExceedsMaximumPreconditions occurs when too many preconditions are given to a call. type ErrExceedsMaximumPreconditions struct { error - preconditionCount uint16 - maxCountAllowed uint16 + preconditionCount uint64 + maxCountAllowed uint64 } // MarshalZerologObject implements zerolog object marshalling. func (err ErrExceedsMaximumPreconditions) MarshalZerologObject(e *zerolog.Event) { - e.Err(err.error).Uint16("preconditionCount", err.preconditionCount).Uint16("maxCountAllowed", err.maxCountAllowed) + e.Err(err.error).Uint64("preconditionCount", err.preconditionCount).Uint64("maxCountAllowed", err.maxCountAllowed) } // GRPCStatus implements retrieving the gRPC status for the error. @@ -71,15 +71,15 @@ func (err ErrExceedsMaximumPreconditions) GRPCStatus() *status.Status { spiceerrors.ForReason( v1.ErrorReason_ERROR_REASON_TOO_MANY_PRECONDITIONS_IN_REQUEST, map[string]string{ - "precondition_count": strconv.Itoa(int(err.preconditionCount)), - "maximum_updates_allowed": strconv.Itoa(int(err.maxCountAllowed)), + "precondition_count": strconv.FormatUint(err.preconditionCount, 10), + "maximum_updates_allowed": strconv.FormatUint(err.maxCountAllowed, 10), }, ), ) } // NewExceedsMaximumPreconditionsErr creates a new error representing that too many preconditions were given to a call. -func NewExceedsMaximumPreconditionsErr(preconditionCount uint16, maxCountAllowed uint16) ErrExceedsMaximumPreconditions { +func NewExceedsMaximumPreconditionsErr(preconditionCount uint64, maxCountAllowed uint64) ErrExceedsMaximumPreconditions { return ErrExceedsMaximumPreconditions{ error: fmt.Errorf( "precondition count of %d is greater than maximum allowed of %d", diff --git a/internal/services/v1/experimental.go b/internal/services/v1/experimental.go index d1a9b811cf..4745d7de52 100644 --- a/internal/services/v1/experimental.go +++ b/internal/services/v1/experimental.go @@ -34,6 +34,7 @@ import ( "github.com/authzed/spicedb/pkg/cursor" "github.com/authzed/spicedb/pkg/datastore" dsoptions "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/genutil" "github.com/authzed/spicedb/pkg/genutil/mapz" "github.com/authzed/spicedb/pkg/genutil/slicez" "github.com/authzed/spicedb/pkg/middleware/consistency" @@ -416,7 +417,12 @@ func (es *experimentalServer) BulkCheckPermission(ctx context.Context, req *v1.B } // Compute a hash for each requested item and record its index(es) for the items, to be used for sorting of results. - itemIndexByHash := mapz.NewMultiMapWithCap[string, int](uint32(len(req.Items))) + itemCount, err := genutil.EnsureUInt32(len(req.Items)) + if err != nil { + return nil, es.rewriteError(ctx, err) + } + + itemIndexByHash := mapz.NewMultiMapWithCap[string, int](itemCount) for index, item := range req.Items { itemHash, err := computeBulkCheckPermissionItemHash(item) if err != nil { diff --git a/internal/services/v1/experimental_test.go b/internal/services/v1/experimental_test.go index 7cdd7d76d4..6c98ce48a0 100644 --- a/internal/services/v1/experimental_test.go +++ b/internal/services/v1/experimental_test.go @@ -204,7 +204,7 @@ func TestBulkExportRelationships(t *testing.T) { } require.NoError(err) - require.LessOrEqual(uint32(len(batch.Relationships)), tc.batchSize) + require.LessOrEqual(uint64(len(batch.Relationships)), uint64(tc.batchSize)) require.NotNil(batch.AfterResultCursor) require.NotEmpty(batch.AfterResultCursor.Token) diff --git a/internal/services/v1/relationships.go b/internal/services/v1/relationships.go index b76e5efd6d..e4db8ff8ba 100644 --- a/internal/services/v1/relationships.go +++ b/internal/services/v1/relationships.go @@ -26,6 +26,7 @@ import ( "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/datastore/pagination" + "github.com/authzed/spicedb/pkg/genutil" "github.com/authzed/spicedb/pkg/genutil/mapz" "github.com/authzed/spicedb/pkg/middleware/consistency" dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" @@ -254,14 +255,14 @@ func (ps *permissionServer) WriteRelationships(ctx context.Context, req *v1.Writ if len(req.Updates) > int(ps.config.MaxUpdatesPerWrite) { return nil, ps.rewriteError( ctx, - NewExceedsMaximumUpdatesErr(uint16(len(req.Updates)), ps.config.MaxUpdatesPerWrite), + NewExceedsMaximumUpdatesErr(uint64(len(req.Updates)), uint64(ps.config.MaxUpdatesPerWrite)), ) } if len(req.OptionalPreconditions) > int(ps.config.MaxPreconditionsCount) { return nil, ps.rewriteError( ctx, - NewExceedsMaximumPreconditionsErr(uint16(len(req.OptionalPreconditions)), ps.config.MaxPreconditionsCount), + NewExceedsMaximumPreconditionsErr(uint64(len(req.OptionalPreconditions)), uint64(ps.config.MaxPreconditionsCount)), ) } @@ -302,9 +303,14 @@ func (ps *permissionServer) WriteRelationships(ctx context.Context, req *v1.Writ return ps.rewriteError(ctx, err) } + dispatchCount, err := genutil.EnsureUInt32(len(req.OptionalPreconditions) + 1) + if err != nil { + return ps.rewriteError(ctx, err) + } + usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{ // One request per precondition and one request for the actual writes. - DispatchCount: uint32(len(req.OptionalPreconditions)) + 1, + DispatchCount: dispatchCount, }) span.AddEvent("preconditions") @@ -338,7 +344,7 @@ func (ps *permissionServer) DeleteRelationships(ctx context.Context, req *v1.Del if len(req.OptionalPreconditions) > int(ps.config.MaxPreconditionsCount) { return nil, ps.rewriteError( ctx, - NewExceedsMaximumPreconditionsErr(uint16(len(req.OptionalPreconditions)), ps.config.MaxPreconditionsCount), + NewExceedsMaximumPreconditionsErr(uint64(len(req.OptionalPreconditions)), uint64(ps.config.MaxPreconditionsCount)), ) } @@ -350,9 +356,14 @@ func (ps *permissionServer) DeleteRelationships(ctx context.Context, req *v1.Del return err } + dispatchCount, err := genutil.EnsureUInt32(len(req.OptionalPreconditions) + 1) + if err != nil { + return ps.rewriteError(ctx, err) + } + usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{ // One request per precondition and one request for the actual delete. - DispatchCount: uint32(len(req.OptionalPreconditions)) + 1, + DispatchCount: dispatchCount, }) if err := checkPreconditions(ctx, rwt, req.OptionalPreconditions); err != nil { @@ -403,7 +414,7 @@ func (ps *permissionServer) DeleteRelationships(ctx context.Context, req *v1.Del } // Otherwise, kick off an unlimited deletion. - _, err := rwt.DeleteRelationships(ctx, req.RelationshipFilter) + _, err = rwt.DeleteRelationships(ctx, req.RelationshipFilter) return err }) if err != nil { diff --git a/internal/services/v1/schema.go b/internal/services/v1/schema.go index 4d0805aed9..25cfc72ba9 100644 --- a/internal/services/v1/schema.go +++ b/internal/services/v1/schema.go @@ -14,6 +14,7 @@ import ( "github.com/authzed/spicedb/internal/middleware/usagemetrics" "github.com/authzed/spicedb/internal/services/shared" "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/genutil" dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" "github.com/authzed/spicedb/pkg/schemadsl/compiler" "github.com/authzed/spicedb/pkg/schemadsl/generator" @@ -87,8 +88,13 @@ func (ss *schemaServer) ReadSchema(ctx context.Context, _ *v1.ReadSchemaRequest) return nil, ss.rewriteError(ctx, err) } + dispatchCount, err := genutil.EnsureUInt32(len(nsDefs) + len(caveatDefs)) + if err != nil { + return nil, ss.rewriteError(ctx, err) + } + usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{ - DispatchCount: uint32(len(nsDefs) + len(caveatDefs)), + DispatchCount: dispatchCount, }) return &v1.ReadSchemaResponse{ @@ -124,8 +130,14 @@ func (ss *schemaServer) WriteSchema(ctx context.Context, in *v1.WriteSchemaReque if err != nil { return err } + + dispatchCount, err := genutil.EnsureUInt32(applied.TotalOperationCount) + if err != nil { + return err + } + usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{ - DispatchCount: applied.TotalOperationCount, + DispatchCount: dispatchCount, }) return nil }) diff --git a/pkg/caveats/types/registration.go b/pkg/caveats/types/registration.go index eae8094720..0e75c1ea7f 100644 --- a/pkg/caveats/types/registration.go +++ b/pkg/caveats/types/registration.go @@ -5,6 +5,8 @@ import ( "github.com/authzed/cel-go/cel" "github.com/authzed/cel-go/common/types/ref" + + "github.com/authzed/spicedb/pkg/genutil" ) var definitions = map[string]typeDefinition{} @@ -27,7 +29,7 @@ type typeDefinition struct { localName string // childTypeCount is the number of generics on the type, if any. - childTypeCount uint + childTypeCount uint8 // asVariableType converts the type definition into a VariableType. asVariableType func(childTypes []VariableType) (*VariableType, error) @@ -55,14 +57,19 @@ func registerBasicType(keyword string, celType *cel.Type, converter typedValueCo // registerGenericType registers a type with at least one generic. func registerGenericType( keyword string, - childTypeCount uint, + childTypeCount uint8, asVariableType func(childTypes []VariableType) VariableType, ) func(childTypes ...VariableType) (VariableType, error) { definitions[keyword] = typeDefinition{ localName: keyword, childTypeCount: childTypeCount, asVariableType: func(childTypes []VariableType) (*VariableType, error) { - if uint(len(childTypes)) != childTypeCount { + childTypeLength, err := genutil.EnsureUInt8(len(childTypes)) + if err != nil { + return nil, err + } + + if childTypeLength != childTypeCount { return nil, fmt.Errorf("type `%s` requires %d generic types; found %d", keyword, childTypeCount, len(childTypes)) } @@ -71,7 +78,12 @@ func registerGenericType( }, } return func(childTypes ...VariableType) (VariableType, error) { - if uint(len(childTypes)) != childTypeCount { + childTypeLength, err := genutil.EnsureUInt8(len(childTypes)) + if err != nil { + return VariableType{}, err + } + + if childTypeLength != childTypeCount { return VariableType{}, fmt.Errorf("invalid number of parameters given to type constructor. expected: %d, found: %d", childTypeCount, len(childTypes)) } diff --git a/pkg/genutil/ensure.go b/pkg/genutil/ensure.go new file mode 100644 index 0000000000..580d17fc02 --- /dev/null +++ b/pkg/genutil/ensure.go @@ -0,0 +1,28 @@ +package genutil + +import "github.com/authzed/spicedb/pkg/spiceerrors" + +// MustEnsureUInt32 is a helper function that calls EnsureUInt32 and panics on error. +func MustEnsureUInt32(value int) uint32 { + ret, err := EnsureUInt32(value) + if err != nil { + panic(err) + } + return ret +} + +// EnsureUInt32 ensures that the specified value can be represented as a uint32. +func EnsureUInt32(value int) (uint32, error) { + if value > int(^uint32(0)) { + return 0, spiceerrors.MustBugf("specified value is too large to fit in a uint32") + } + return uint32(value), nil +} + +// EnsureUInt8 ensures that the specified value can be represented as a uint8. +func EnsureUInt8(value int) (uint8, error) { + if value > int(^uint8(0)) { + return 0, spiceerrors.MustBugf("specified value is too large to fit in a uint8") + } + return uint8(value), nil +} diff --git a/pkg/genutil/ensure_test.go b/pkg/genutil/ensure_test.go new file mode 100644 index 0000000000..7da26a5062 --- /dev/null +++ b/pkg/genutil/ensure_test.go @@ -0,0 +1,88 @@ +package genutil + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEnsureUInt32(t *testing.T) { + tcs := []struct { + name string + value int + want uint32 + err bool + }{ + { + name: "zero", + value: 0, + want: 0, + }, + { + name: "max", + value: int(^uint32(0)), + want: ^uint32(0), + }, + { + name: "overflow", + value: int(^uint32(0)) + 1, + err: true, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + if tc.err { + assert.Panics(t, func() { + _, _ = EnsureUInt32(tc.value) + }, "The code did not panic") + return + } + + got, err := EnsureUInt32(tc.value) + require.NoError(t, err) + require.Equal(t, tc.want, got) + }) + } +} + +func TestEnsureUInt8(t *testing.T) { + tcs := []struct { + name string + value int + want uint8 + err bool + }{ + { + name: "zero", + value: 0, + want: 0, + }, + { + name: "max", + value: int(^uint8(0)), + want: ^uint8(0), + }, + { + name: "overflow", + value: int(^uint8(0)) + 1, + err: true, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + if tc.err { + assert.Panics(t, func() { + _, _ = EnsureUInt8(tc.value) + }, "The code did not panic") + return + } + + got, err := EnsureUInt8(tc.value) + require.NoError(t, err) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/pkg/tuple/onrset.go b/pkg/tuple/onrset.go index 0496d1103d..04a4ad66af 100644 --- a/pkg/tuple/onrset.go +++ b/pkg/tuple/onrset.go @@ -23,8 +23,8 @@ func NewONRSet(onrs ...*core.ObjectAndRelation) *ONRSet { } // Length returns the size of the set. -func (ons *ONRSet) Length() uint32 { - return uint32(len(ons.onrs)) +func (ons *ONRSet) Length() uint64 { + return uint64(len(ons.onrs)) } // IsEmpty returns whether the set is empty. diff --git a/pkg/tuple/onrset_test.go b/pkg/tuple/onrset_test.go index 43417d7e63..e102ca2d8e 100644 --- a/pkg/tuple/onrset_test.go +++ b/pkg/tuple/onrset_test.go @@ -11,15 +11,15 @@ import ( func TestONRSet(t *testing.T) { set := NewONRSet() require.True(t, set.IsEmpty()) - require.Equal(t, uint32(0), set.Length()) + require.Equal(t, uint64(0), set.Length()) require.True(t, set.Add(ParseONR("resource:1#viewer"))) require.False(t, set.IsEmpty()) - require.Equal(t, uint32(1), set.Length()) + require.Equal(t, uint64(1), set.Length()) require.True(t, set.Add(ParseONR("resource:2#viewer"))) require.True(t, set.Add(ParseONR("resource:3#viewer"))) - require.Equal(t, uint32(3), set.Length()) + require.Equal(t, uint64(3), set.Length()) require.False(t, set.Add(ParseONR("resource:1#viewer"))) require.True(t, set.Add(ParseONR("resource:1#editor"))) @@ -40,7 +40,7 @@ func TestONRSetUpdate(t *testing.T) { ParseONR("resource:2#viewer"), ParseONR("resource:3#viewer"), }) - require.Equal(t, uint32(3), set.Length()) + require.Equal(t, uint64(3), set.Length()) set.Update([]*core.ObjectAndRelation{ ParseONR("resource:1#viewer"), @@ -49,7 +49,7 @@ func TestONRSetUpdate(t *testing.T) { ParseONR("resource:1#admin"), ParseONR("resource:1#reader"), }) - require.Equal(t, uint32(7), set.Length()) + require.Equal(t, uint64(7), set.Length()) } func TestONRSetIntersect(t *testing.T) { @@ -70,8 +70,8 @@ func TestONRSetIntersect(t *testing.T) { ParseONR("resource:1#reader"), }) - require.Equal(t, uint32(2), set1.Intersect(set2).Length()) - require.Equal(t, uint32(2), set2.Intersect(set1).Length()) + require.Equal(t, uint64(2), set1.Intersect(set2).Length()) + require.Equal(t, uint64(2), set2.Intersect(set1).Length()) } func TestONRSetSubtract(t *testing.T) { @@ -92,8 +92,8 @@ func TestONRSetSubtract(t *testing.T) { ParseONR("resource:1#reader"), }) - require.Equal(t, uint32(1), set1.Subtract(set2).Length()) - require.Equal(t, uint32(4), set2.Subtract(set1).Length()) + require.Equal(t, uint64(1), set1.Subtract(set2).Length()) + require.Equal(t, uint64(4), set2.Subtract(set1).Length()) } func TestONRSetUnion(t *testing.T) { @@ -114,8 +114,8 @@ func TestONRSetUnion(t *testing.T) { ParseONR("resource:1#reader"), }) - require.Equal(t, uint32(7), set1.Union(set2).Length()) - require.Equal(t, uint32(7), set2.Union(set1).Length()) + require.Equal(t, uint64(7), set1.Union(set2).Length()) + require.Equal(t, uint64(7), set2.Union(set1).Length()) } func TestONRSetWith(t *testing.T) { @@ -127,8 +127,8 @@ func TestONRSetWith(t *testing.T) { }) added := set1.With(ParseONR("resource:1#editor")) - require.Equal(t, uint32(3), set1.Length()) - require.Equal(t, uint32(4), added.Length()) + require.Equal(t, uint64(3), set1.Length()) + require.Equal(t, uint64(4), added.Length()) } func TestONRSetAsSlice(t *testing.T) { From 8dac7e3fdd3c80c98b8f3622cb468e27f4a8ed9c Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Tue, 20 Feb 2024 13:06:02 -0500 Subject: [PATCH 2/3] Add linter --- tools/analyzers/cmd/analyzers/main.go | 2 + .../lendowncastcheck/lendowncastcheck.go | 126 ++++++++++++++++++ .../lendowncastcheck/lendowncastcheck_test.go | 14 ++ .../testdata/src/lenexamples/lenexamples.go | 23 ++++ 4 files changed, 165 insertions(+) create mode 100644 tools/analyzers/lendowncastcheck/lendowncastcheck.go create mode 100644 tools/analyzers/lendowncastcheck/lendowncastcheck_test.go create mode 100644 tools/analyzers/lendowncastcheck/testdata/src/lenexamples/lenexamples.go diff --git a/tools/analyzers/cmd/analyzers/main.go b/tools/analyzers/cmd/analyzers/main.go index 1fbc1f6641..0270299091 100644 --- a/tools/analyzers/cmd/analyzers/main.go +++ b/tools/analyzers/cmd/analyzers/main.go @@ -3,6 +3,7 @@ package main import ( "github.com/authzed/spicedb/tools/analyzers/closeafterusagecheck" "github.com/authzed/spicedb/tools/analyzers/exprstatementcheck" + "github.com/authzed/spicedb/tools/analyzers/lendowncastcheck" "github.com/authzed/spicedb/tools/analyzers/nilvaluecheck" "github.com/authzed/spicedb/tools/analyzers/paniccheck" "golang.org/x/tools/go/analysis/multichecker" @@ -14,5 +15,6 @@ func main() { exprstatementcheck.Analyzer(), closeafterusagecheck.Analyzer(), paniccheck.Analyzer(), + lendowncastcheck.Analyzer(), ) } diff --git a/tools/analyzers/lendowncastcheck/lendowncastcheck.go b/tools/analyzers/lendowncastcheck/lendowncastcheck.go new file mode 100644 index 0000000000..c4a8117090 --- /dev/null +++ b/tools/analyzers/lendowncastcheck/lendowncastcheck.go @@ -0,0 +1,126 @@ +package lendowncastcheck + +import ( + "flag" + "fmt" + "go/ast" + "regexp" + "strings" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/inspector" +) + +func sliceMap(s []string, f func(value string) string) []string { + mapped := make([]string, 0, len(s)) + for _, value := range s { + mapped = append(mapped, f(value)) + } + return mapped +} + +var disallowedDowncastTypes = map[string]bool{ + "int8": true, + "int16": true, + "int32": true, + "int64": true, + "uint": true, + "uint8": true, + "uint16": true, + "uint32": true, + "float32": true, + "float64": true, +} + +func Analyzer() *analysis.Analyzer { + flagSet := flag.NewFlagSet("lendowncastcheck", flag.ExitOnError) + skipPkg := flagSet.String("skip-pkg", "", "package(s) to skip for linting") + skipFiles := flagSet.String("skip-files", "", "patterns of files to skip for linting") + + return &analysis.Analyzer{ + Name: "lendowncastcheck", + Doc: "reports downcasting of len() calls", + Run: func(pass *analysis.Pass) (any, error) { + // Check for a skipped package. + if len(*skipPkg) > 0 { + skipped := sliceMap(strings.Split(*skipPkg, ","), strings.TrimSpace) + for _, s := range skipped { + if strings.Contains(pass.Pkg.Path(), s) { + return nil, nil + } + } + } + + // Check for a skipped file. + skipFilePatterns := make([]string, 0) + if len(*skipFiles) > 0 { + skipFilePatterns = sliceMap(strings.Split(*skipFiles, ","), strings.TrimSpace) + } + for _, pattern := range skipFilePatterns { + _, err := regexp.Compile(pattern) + if err != nil { + return nil, fmt.Errorf("invalid skip-files pattern `%s`: %w", pattern, err) + } + } + + inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + + nodeFilter := []ast.Node{ + (*ast.File)(nil), + (*ast.CallExpr)(nil), + } + + inspect.WithStack(nodeFilter, func(n ast.Node, push bool, stack []ast.Node) bool { + switch s := n.(type) { + case *ast.File: + for _, pattern := range skipFilePatterns { + isMatch, _ := regexp.MatchString(pattern, pass.Fset.Position(s.Package).Filename) + if isMatch { + return false + } + } + return true + + case *ast.CallExpr: + identExpr, ok := s.Fun.(*ast.Ident) + if !ok { + return false + } + + if _, ok := disallowedDowncastTypes[identExpr.Name]; !ok { + return false + } + + if len(s.Args) != 1 { + return false + } + + childExpr, ok := s.Args[0].(*ast.CallExpr) + if !ok { + return false + } + + childIdentExpr, ok := childExpr.Fun.(*ast.Ident) + if !ok { + return false + } + + if childIdentExpr.Name != "len" { + return false + } + + pass.Reportf(s.Pos(), "In package %s: found downcast of `len` call to %s", pass.Pkg.Path(), identExpr.Name) + return false + + default: + return true + } + }) + + return nil, nil + }, + Requires: []*analysis.Analyzer{inspect.Analyzer}, + Flags: *flagSet, + } +} diff --git a/tools/analyzers/lendowncastcheck/lendowncastcheck_test.go b/tools/analyzers/lendowncastcheck/lendowncastcheck_test.go new file mode 100644 index 0000000000..b8671894cb --- /dev/null +++ b/tools/analyzers/lendowncastcheck/lendowncastcheck_test.go @@ -0,0 +1,14 @@ +package lendowncastcheck + +import ( + "testing" + + "golang.org/x/tools/go/analysis/analysistest" +) + +func TestAnalyzer(t *testing.T) { + analyzer := Analyzer() + + testdata := analysistest.TestData() + analysistest.Run(t, testdata, analyzer, "lenexamples") +} diff --git a/tools/analyzers/lendowncastcheck/testdata/src/lenexamples/lenexamples.go b/tools/analyzers/lendowncastcheck/testdata/src/lenexamples/lenexamples.go new file mode 100644 index 0000000000..7d65321b98 --- /dev/null +++ b/tools/analyzers/lendowncastcheck/testdata/src/lenexamples/lenexamples.go @@ -0,0 +1,23 @@ +package lenexamples + +import "fmt" + +func DoSomething(someSlice []string) uint64 { + return uint64(len(someSlice)) +} + +func DoSomethingBad(someSlice []string) uint32 { + v := uint32(len(someSlice)) // want "found downcast of `len` call to uint32" + return v +} + +func DoSomethingBad16(someSlice []string) uint16 { + v := uint16(len(someSlice)) // want "found downcast of `len` call to uint16" + return v +} + +func DoSomeLoop(someSlice []string) { + for i := 0; i < len(someSlice); i++ { + fmt.Println(someSlice[i]) + } +} From 2f533cf1582a9a13ab7dcd2629c66816fe4fb4c6 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Tue, 20 Feb 2024 14:47:34 -0500 Subject: [PATCH 3/3] Set maximum bulk check size --- internal/services/v1/errors.go | 36 ++++++++++++++++++++++++++++ internal/services/v1/experimental.go | 6 +++++ 2 files changed, 42 insertions(+) diff --git a/internal/services/v1/errors.go b/internal/services/v1/errors.go index a1f1333605..f12aa80b36 100644 --- a/internal/services/v1/errors.go +++ b/internal/services/v1/errors.go @@ -15,6 +15,42 @@ import ( "github.com/authzed/spicedb/pkg/tuple" ) +// ErrExceedsMaximumChecks occurs when too many checks are given to a call. +type ErrExceedsMaximumChecks struct { + error + checkCount uint64 + maxCountAllowed uint64 +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err ErrExceedsMaximumChecks) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Uint64("checkCount", err.checkCount).Uint64("maxCountAllowed", err.maxCountAllowed) +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err ErrExceedsMaximumChecks) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_UNSPECIFIED, + map[string]string{ + "check_count": strconv.FormatUint(err.checkCount, 10), + "maximum_checks_allowed": strconv.FormatUint(err.maxCountAllowed, 10), + }, + ), + ) +} + +// NewExceedsMaximumChecksErr creates a new error representing that too many updates were given to a BulkCheckPermissions call. +func NewExceedsMaximumChecksErr(checkCount uint64, maxCountAllowed uint64) ErrExceedsMaximumChecks { + return ErrExceedsMaximumChecks{ + error: fmt.Errorf("check count of %d is greater than maximum allowed of %d", checkCount, maxCountAllowed), + checkCount: checkCount, + maxCountAllowed: maxCountAllowed, + } +} + // ErrExceedsMaximumUpdates occurs when too many updates are given to a call. type ErrExceedsMaximumUpdates struct { error diff --git a/internal/services/v1/experimental.go b/internal/services/v1/experimental.go index 4745d7de52..a98faf899f 100644 --- a/internal/services/v1/experimental.go +++ b/internal/services/v1/experimental.go @@ -410,12 +410,18 @@ func (es *experimentalServer) BulkExportRelationships( return nil } +const maxBulkCheckCount = 10000 + func (es *experimentalServer) BulkCheckPermission(ctx context.Context, req *v1.BulkCheckPermissionRequest) (*v1.BulkCheckPermissionResponse, error) { atRevision, checkedAt, err := consistency.RevisionFromContext(ctx) if err != nil { return nil, es.rewriteError(ctx, err) } + if len(req.Items) > maxBulkCheckCount { + return nil, es.rewriteError(ctx, NewExceedsMaximumChecksErr(uint64(len(req.Items)), maxBulkCheckCount)) + } + // Compute a hash for each requested item and record its index(es) for the items, to be used for sorting of results. itemCount, err := genutil.EnsureUInt32(len(req.Items)) if err != nil {