Skip to content

Commit

Permalink
fix(account): add enforcer to enforce workspace member count
Browse files Browse the repository at this point in the history
  • Loading branch information
rot1024 committed Oct 4, 2023
1 parent bceec51 commit 81f4cb9
Show file tree
Hide file tree
Showing 14 changed files with 311 additions and 236 deletions.
2 changes: 2 additions & 0 deletions account/accountdomain/user/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,5 @@ func (u *User) Clone() *User {
passwordReset: util.CloneRef(u.passwordReset),
}
}

type List []*User
24 changes: 12 additions & 12 deletions account/accountdomain/workspace/workspace_list.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package workspace

type WorkspaceList []*Workspace
type List []*Workspace

func (l WorkspaceList) FilterByID(ids ...ID) WorkspaceList {
func (l List) FilterByID(ids ...ID) List {
if l == nil {
return nil
}

res := make(WorkspaceList, 0, len(l))
res := make(List, 0, len(l))
for _, id := range ids {
var t2 *Workspace
for _, t := range l {
Expand All @@ -23,12 +23,12 @@ func (l WorkspaceList) FilterByID(ids ...ID) WorkspaceList {
return res
}

func (l WorkspaceList) FilterByUserRole(u UserID, r Role) WorkspaceList {
func (l List) FilterByUserRole(u UserID, r Role) List {
if l == nil || u.IsEmpty() || r == "" {
return nil
}

res := make(WorkspaceList, 0, len(l))
res := make(List, 0, len(l))
for _, t := range l {
if m := t.Members().User(u); m != nil && m.Role == r {
res = append(res, t)
Expand All @@ -37,12 +37,12 @@ func (l WorkspaceList) FilterByUserRole(u UserID, r Role) WorkspaceList {
return res
}

func (l WorkspaceList) FilterByIntegrationRole(i IntegrationID, r Role) WorkspaceList {
func (l List) FilterByIntegrationRole(i IntegrationID, r Role) List {
if l == nil || i.IsEmpty() || r == "" {
return nil
}

res := make(WorkspaceList, 0, len(l))
res := make(List, 0, len(l))
for _, t := range l {
if m := t.Members().Integration(i); m != nil && m.Role == r {
res = append(res, t)
Expand All @@ -51,12 +51,12 @@ func (l WorkspaceList) FilterByIntegrationRole(i IntegrationID, r Role) Workspac
return res
}

func (l WorkspaceList) FilterByUserRoleIncluding(u UserID, r Role) WorkspaceList {
func (l List) FilterByUserRoleIncluding(u UserID, r Role) List {
if l == nil || u.IsEmpty() || r == "" {
return nil
}

res := make(WorkspaceList, 0, len(l))
res := make(List, 0, len(l))
for _, t := range l {
if m := t.Members().User(u); m != nil && m.Role.Includes(r) {
res = append(res, t)
Expand All @@ -65,12 +65,12 @@ func (l WorkspaceList) FilterByUserRoleIncluding(u UserID, r Role) WorkspaceList
return res
}

func (l WorkspaceList) FilterByIntegrationRoleIncluding(i IntegrationID, r Role) WorkspaceList {
func (l List) FilterByIntegrationRoleIncluding(i IntegrationID, r Role) List {
if l == nil || i.IsEmpty() || r == "" {
return nil
}

res := make(WorkspaceList, 0, len(l))
res := make(List, 0, len(l))
for _, t := range l {
if m := t.Members().Integration(i); m != nil && m.Role.Includes(r) {
res = append(res, t)
Expand All @@ -79,7 +79,7 @@ func (l WorkspaceList) FilterByIntegrationRoleIncluding(i IntegrationID, r Role)
return res
}

func (l WorkspaceList) IDs() []ID {
func (l List) IDs() []ID {
if l == nil {
return nil
}
Expand Down
48 changes: 24 additions & 24 deletions account/accountdomain/workspace/workspace_list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ func TestWorkspaceList_FilterByID(t *testing.T) {
t1 := &Workspace{id: tid1}
t2 := &Workspace{id: tid2}

assert.Equal(t, WorkspaceList{t1}, WorkspaceList{t1, t2}.FilterByID(tid1))
assert.Equal(t, WorkspaceList{t2}, WorkspaceList{t1, t2}.FilterByID(tid2))
assert.Equal(t, WorkspaceList{t1, t2}, WorkspaceList{t1, t2}.FilterByID(tid1, tid2))
assert.Equal(t, WorkspaceList{}, WorkspaceList{t1, t2}.FilterByID(NewID()))
assert.Equal(t, WorkspaceList(nil), WorkspaceList(nil).FilterByID(tid1))
assert.Equal(t, List{t1}, List{t1, t2}.FilterByID(tid1))
assert.Equal(t, List{t2}, List{t1, t2}.FilterByID(tid2))
assert.Equal(t, List{t1, t2}, List{t1, t2}.FilterByID(tid1, tid2))
assert.Equal(t, List{}, List{t1, t2}.FilterByID(NewID()))
assert.Equal(t, List(nil), List(nil).FilterByID(tid1))
}

func TestWorkspaceList_FilterByUserRole(t *testing.T) {
Expand All @@ -40,10 +40,10 @@ func TestWorkspaceList_FilterByUserRole(t *testing.T) {
},
}

assert.Equal(t, WorkspaceList{t1}, WorkspaceList{t1, t2}.FilterByUserRole(uid, RoleReader))
assert.Equal(t, WorkspaceList{}, WorkspaceList{t1, t2}.FilterByUserRole(uid, RoleWriter))
assert.Equal(t, WorkspaceList{t2}, WorkspaceList{t1, t2}.FilterByUserRole(uid, RoleOwner))
assert.Equal(t, WorkspaceList(nil), WorkspaceList(nil).FilterByUserRole(uid, RoleOwner))
assert.Equal(t, List{t1}, List{t1, t2}.FilterByUserRole(uid, RoleReader))
assert.Equal(t, List{}, List{t1, t2}.FilterByUserRole(uid, RoleWriter))
assert.Equal(t, List{t2}, List{t1, t2}.FilterByUserRole(uid, RoleOwner))
assert.Equal(t, List(nil), List(nil).FilterByUserRole(uid, RoleOwner))
}

func TestWorkspaceList_FilterByIntegrationRole(t *testing.T) {
Expand All @@ -67,10 +67,10 @@ func TestWorkspaceList_FilterByIntegrationRole(t *testing.T) {
},
}

assert.Equal(t, WorkspaceList{t1}, WorkspaceList{t1, t2}.FilterByIntegrationRole(iid, RoleReader))
assert.Equal(t, WorkspaceList{}, WorkspaceList{t1, t2}.FilterByIntegrationRole(iid, RoleOwner))
assert.Equal(t, WorkspaceList{t2}, WorkspaceList{t1, t2}.FilterByIntegrationRole(iid, RoleWriter))
assert.Equal(t, WorkspaceList(nil), WorkspaceList(nil).FilterByIntegrationRole(iid, RoleOwner))
assert.Equal(t, List{t1}, List{t1, t2}.FilterByIntegrationRole(iid, RoleReader))
assert.Equal(t, List{}, List{t1, t2}.FilterByIntegrationRole(iid, RoleOwner))
assert.Equal(t, List{t2}, List{t1, t2}.FilterByIntegrationRole(iid, RoleWriter))
assert.Equal(t, List(nil), List(nil).FilterByIntegrationRole(iid, RoleOwner))
}

func TestWorkspaceList_FilterByUserRoleIncluding(t *testing.T) {
Expand All @@ -94,10 +94,10 @@ func TestWorkspaceList_FilterByUserRoleIncluding(t *testing.T) {
},
}

assert.Equal(t, WorkspaceList{t1, t2}, WorkspaceList{t1, t2}.FilterByUserRoleIncluding(uid, RoleReader))
assert.Equal(t, WorkspaceList{t2}, WorkspaceList{t1, t2}.FilterByUserRoleIncluding(uid, RoleWriter))
assert.Equal(t, WorkspaceList{t2}, WorkspaceList{t1, t2}.FilterByUserRoleIncluding(uid, RoleOwner))
assert.Equal(t, WorkspaceList(nil), WorkspaceList(nil).FilterByUserRoleIncluding(uid, RoleOwner))
assert.Equal(t, List{t1, t2}, List{t1, t2}.FilterByUserRoleIncluding(uid, RoleReader))
assert.Equal(t, List{t2}, List{t1, t2}.FilterByUserRoleIncluding(uid, RoleWriter))
assert.Equal(t, List{t2}, List{t1, t2}.FilterByUserRoleIncluding(uid, RoleOwner))
assert.Equal(t, List(nil), List(nil).FilterByUserRoleIncluding(uid, RoleOwner))
}

func TestWorkspaceList_FilterByIntegrationRoleIncluding(t *testing.T) {
Expand All @@ -121,10 +121,10 @@ func TestWorkspaceList_FilterByIntegrationRoleIncluding(t *testing.T) {
},
}

assert.Equal(t, WorkspaceList{t1, t2}, WorkspaceList{t1, t2}.FilterByIntegrationRoleIncluding(uid, RoleReader))
assert.Equal(t, WorkspaceList{t2}, WorkspaceList{t1, t2}.FilterByIntegrationRoleIncluding(uid, RoleWriter))
assert.Equal(t, WorkspaceList{t2}, WorkspaceList{t1, t2}.FilterByIntegrationRoleIncluding(uid, RoleOwner))
assert.Equal(t, WorkspaceList(nil), WorkspaceList(nil).FilterByIntegrationRoleIncluding(uid, RoleOwner))
assert.Equal(t, List{t1, t2}, List{t1, t2}.FilterByIntegrationRoleIncluding(uid, RoleReader))
assert.Equal(t, List{t2}, List{t1, t2}.FilterByIntegrationRoleIncluding(uid, RoleWriter))
assert.Equal(t, List{t2}, List{t1, t2}.FilterByIntegrationRoleIncluding(uid, RoleOwner))
assert.Equal(t, List(nil), List(nil).FilterByIntegrationRoleIncluding(uid, RoleOwner))
}

func TestWorkspaceList_IDs(t *testing.T) {
Expand All @@ -133,7 +133,7 @@ func TestWorkspaceList_IDs(t *testing.T) {
t1 := &Workspace{id: wid1}
t2 := &Workspace{id: wid2}

assert.Equal(t, []ID{wid1, wid2}, WorkspaceList{t1, t2}.IDs())
assert.Equal(t, []ID{}, WorkspaceList{}.IDs())
assert.Equal(t, []ID(nil), WorkspaceList(nil).IDs())
assert.Equal(t, []ID{wid1, wid2}, List{t1, t2}.IDs())
assert.Equal(t, []ID{}, List{}.IDs())
assert.Equal(t, []ID(nil), List(nil).IDs())
}
29 changes: 14 additions & 15 deletions account/accountinfrastructure/accountmemory/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,20 @@ package accountmemory
import (
"context"

"github.com/reearth/reearthx/account/accountdomain"
"github.com/reearth/reearthx/account/accountdomain/user"
"github.com/reearth/reearthx/account/accountusecase/accountrepo"
"github.com/reearth/reearthx/rerror"
"github.com/reearth/reearthx/util"
)

type User struct {
data *util.SyncMap[accountdomain.UserID, *user.User]
data *util.SyncMap[user.ID, *user.User]
err error
}

func NewUser() *User {
return &User{
data: &util.SyncMap[accountdomain.UserID, *user.User]{},
data: &util.SyncMap[user.ID, *user.User]{},
}
}

Expand All @@ -29,24 +28,24 @@ func NewUserWith(users ...*user.User) *User {
return r
}

func (r *User) FindByIDs(ctx context.Context, ids accountdomain.UserIDList) ([]*user.User, error) {
func (r *User) FindByIDs(ctx context.Context, ids user.IDList) (user.List, error) {
if r.err != nil {
return nil, r.err
}

res := r.data.FindAll(func(key accountdomain.UserID, value *user.User) bool {
res := r.data.FindAll(func(key user.ID, value *user.User) bool {
return ids.Has(key)
})

return res, nil
}

func (r *User) FindByID(ctx context.Context, v accountdomain.UserID) (*user.User, error) {
func (r *User) FindByID(ctx context.Context, v user.ID) (*user.User, error) {
if r.err != nil {
return nil, r.err
}

return rerror.ErrIfNil(r.data.Find(func(key accountdomain.UserID, value *user.User) bool {
return rerror.ErrIfNil(r.data.Find(func(key user.ID, value *user.User) bool {
return key == v
}), rerror.ErrNotFound)
}
Expand All @@ -60,7 +59,7 @@ func (r *User) FindBySub(ctx context.Context, auth0sub string) (*user.User, erro
return nil, rerror.ErrInvalidParams
}

return rerror.ErrIfNil(r.data.Find(func(key accountdomain.UserID, value *user.User) bool {
return rerror.ErrIfNil(r.data.Find(func(key user.ID, value *user.User) bool {
return value.ContainAuth(user.AuthFrom(auth0sub))
}), rerror.ErrNotFound)
}
Expand All @@ -74,7 +73,7 @@ func (r *User) FindByPasswordResetRequest(ctx context.Context, token string) (*u
return nil, rerror.ErrInvalidParams
}

return rerror.ErrIfNil(r.data.Find(func(key accountdomain.UserID, value *user.User) bool {
return rerror.ErrIfNil(r.data.Find(func(key user.ID, value *user.User) bool {
return value.PasswordReset() != nil && value.PasswordReset().Token == token
}), rerror.ErrNotFound)
}
Expand All @@ -88,7 +87,7 @@ func (r *User) FindByEmail(ctx context.Context, email string) (*user.User, error
return nil, rerror.ErrInvalidParams
}

return rerror.ErrIfNil(r.data.Find(func(key accountdomain.UserID, value *user.User) bool {
return rerror.ErrIfNil(r.data.Find(func(key user.ID, value *user.User) bool {
return value.Email() == email
}), rerror.ErrNotFound)
}
Expand All @@ -102,7 +101,7 @@ func (r *User) FindByName(ctx context.Context, name string) (*user.User, error)
return nil, rerror.ErrInvalidParams
}

return rerror.ErrIfNil(r.data.Find(func(key accountdomain.UserID, value *user.User) bool {
return rerror.ErrIfNil(r.data.Find(func(key user.ID, value *user.User) bool {
return value.Name() == name
}), rerror.ErrNotFound)
}
Expand All @@ -116,7 +115,7 @@ func (r *User) FindByNameOrEmail(ctx context.Context, nameOrEmail string) (*user
return nil, rerror.ErrInvalidParams
}

return rerror.ErrIfNil(r.data.Find(func(key accountdomain.UserID, value *user.User) bool {
return rerror.ErrIfNil(r.data.Find(func(key user.ID, value *user.User) bool {
return value.Email() == nameOrEmail || value.Name() == nameOrEmail
}), rerror.ErrNotFound)
}
Expand All @@ -130,7 +129,7 @@ func (r *User) FindByVerification(ctx context.Context, code string) (*user.User,
return nil, rerror.ErrInvalidParams
}

return rerror.ErrIfNil(r.data.Find(func(key accountdomain.UserID, value *user.User) bool {
return rerror.ErrIfNil(r.data.Find(func(key user.ID, value *user.User) bool {
return value.Verification() != nil && value.Verification().Code() == code

}), rerror.ErrNotFound)
Expand All @@ -141,7 +140,7 @@ func (r *User) FindBySubOrCreate(ctx context.Context, u *user.User, sub string)
return nil, r.err
}

u2 := r.data.Find(func(key accountdomain.UserID, value *user.User) bool {
u2 := r.data.Find(func(key user.ID, value *user.User) bool {
return value.ContainAuth(user.AuthFrom(sub))
})
if u2 == nil {
Expand Down Expand Up @@ -174,7 +173,7 @@ func (r *User) Save(ctx context.Context, u *user.User) error {
return nil
}

func (r *User) Remove(ctx context.Context, user accountdomain.UserID) error {
func (r *User) Remove(ctx context.Context, user user.ID) error {
if r.err != nil {
return r.err
}
Expand Down
Loading

0 comments on commit 81f4cb9

Please sign in to comment.