diff --git a/internal/services/v1/errors.go b/internal/services/v1/errors.go index a1f1333605..993d923bdd 100644 --- a/internal/services/v1/errors.go +++ b/internal/services/v1/errors.go @@ -15,6 +15,42 @@ import ( "github.com/authzed/spicedb/pkg/tuple" ) +// ErrExceedsMaximumChecks occurs when too many checks are given to a call. +type ErrExceedsMaximumChecks struct { + error + checkCount uint64 + maxCountAllowed uint64 +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err ErrExceedsMaximumChecks) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Uint64("checkCount", err.checkCount).Uint64("maxCountAllowed", err.maxCountAllowed) +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err ErrExceedsMaximumChecks) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_UNSPECIFIED, + map[string]string{ + "check_count": strconv.FormatUint(err.checkCount, 10), + "maximum_checks_allowed": strconv.FormatUint(err.maxCountAllowed, 10), + }, + ), + ) +} + +// NewExceedsMaximumChecksErr creates a new error representing that too many updates were given to a WriteRelationships call. +func NewExceedsMaximumChecksErr(checkCount uint64, maxCountAllowed uint64) ErrExceedsMaximumChecks { + return ErrExceedsMaximumChecks{ + error: fmt.Errorf("check count of %d is greater than maximum allowed of %d", checkCount, maxCountAllowed), + checkCount: checkCount, + maxCountAllowed: maxCountAllowed, + } +} + // ErrExceedsMaximumUpdates occurs when too many updates are given to a call. type ErrExceedsMaximumUpdates struct { error diff --git a/internal/services/v1/experimental.go b/internal/services/v1/experimental.go index 4745d7de52..a98faf899f 100644 --- a/internal/services/v1/experimental.go +++ b/internal/services/v1/experimental.go @@ -410,12 +410,18 @@ func (es *experimentalServer) BulkExportRelationships( return nil } +const maxBulkCheckCount = 10000 + func (es *experimentalServer) BulkCheckPermission(ctx context.Context, req *v1.BulkCheckPermissionRequest) (*v1.BulkCheckPermissionResponse, error) { atRevision, checkedAt, err := consistency.RevisionFromContext(ctx) if err != nil { return nil, es.rewriteError(ctx, err) } + if len(req.Items) > maxBulkCheckCount { + return nil, es.rewriteError(ctx, NewExceedsMaximumChecksErr(uint64(len(req.Items)), maxBulkCheckCount)) + } + // Compute a hash for each requested item and record its index(es) for the items, to be used for sorting of results. itemCount, err := genutil.EnsureUInt32(len(req.Items)) if err != nil {