From 81f4cb96b88fd2ec4c5cbd874c69b24ccf2bde3e Mon Sep 17 00:00:00 2001 From: rot1024 Date: Wed, 4 Oct 2023 20:02:18 +0900 Subject: [PATCH] fix(account): add enforcer to enforce workspace member count --- account/accountdomain/user/user.go | 2 + .../accountdomain/workspace/workspace_list.go | 24 +- .../workspace/workspace_list_test.go | 48 ++-- .../accountmemory/user.go | 29 +- .../accountmemory/workspace.go | 26 +- .../accountmemory/workspace_test.go | 4 +- .../accountmongo/user.go | 17 +- .../accountmongo/workspace.go | 13 +- .../accountmongo/workspace_test.go | 14 +- .../accountusecase/accountinteractor/user.go | 4 +- .../accountinteractor/workspace.go | 84 ++++-- .../accountinteractor/workspace_test.go | 259 ++++++++++-------- account/accountusecase/accountrepo/user.go | 7 +- .../accountusecase/accountrepo/workspace.go | 16 +- 14 files changed, 311 insertions(+), 236 deletions(-) diff --git a/account/accountdomain/user/user.go b/account/accountdomain/user/user.go index 2ef069a..003a096 100644 --- a/account/accountdomain/user/user.go +++ b/account/accountdomain/user/user.go @@ -208,3 +208,5 @@ func (u *User) Clone() *User { passwordReset: util.CloneRef(u.passwordReset), } } + +type List []*User diff --git a/account/accountdomain/workspace/workspace_list.go b/account/accountdomain/workspace/workspace_list.go index e3e0732..26798b5 100644 --- a/account/accountdomain/workspace/workspace_list.go +++ b/account/accountdomain/workspace/workspace_list.go @@ -1,13 +1,13 @@ package workspace -type WorkspaceList []*Workspace +type List []*Workspace -func (l WorkspaceList) FilterByID(ids ...ID) WorkspaceList { +func (l List) FilterByID(ids ...ID) List { if l == nil { return nil } - res := make(WorkspaceList, 0, len(l)) + res := make(List, 0, len(l)) for _, id := range ids { var t2 *Workspace for _, t := range l { @@ -23,12 +23,12 @@ func (l WorkspaceList) FilterByID(ids ...ID) WorkspaceList { return res } -func (l WorkspaceList) FilterByUserRole(u UserID, r Role) WorkspaceList { +func (l List) FilterByUserRole(u UserID, r Role) List { if l == nil || u.IsEmpty() || r == "" { return nil } - res := make(WorkspaceList, 0, len(l)) + res := make(List, 0, len(l)) for _, t := range l { if m := t.Members().User(u); m != nil && m.Role == r { res = append(res, t) @@ -37,12 +37,12 @@ func (l WorkspaceList) FilterByUserRole(u UserID, r Role) WorkspaceList { return res } -func (l WorkspaceList) FilterByIntegrationRole(i IntegrationID, r Role) WorkspaceList { +func (l List) FilterByIntegrationRole(i IntegrationID, r Role) List { if l == nil || i.IsEmpty() || r == "" { return nil } - res := make(WorkspaceList, 0, len(l)) + res := make(List, 0, len(l)) for _, t := range l { if m := t.Members().Integration(i); m != nil && m.Role == r { res = append(res, t) @@ -51,12 +51,12 @@ func (l WorkspaceList) FilterByIntegrationRole(i IntegrationID, r Role) Workspac return res } -func (l WorkspaceList) FilterByUserRoleIncluding(u UserID, r Role) WorkspaceList { +func (l List) FilterByUserRoleIncluding(u UserID, r Role) List { if l == nil || u.IsEmpty() || r == "" { return nil } - res := make(WorkspaceList, 0, len(l)) + res := make(List, 0, len(l)) for _, t := range l { if m := t.Members().User(u); m != nil && m.Role.Includes(r) { res = append(res, t) @@ -65,12 +65,12 @@ func (l WorkspaceList) FilterByUserRoleIncluding(u UserID, r Role) WorkspaceList return res } -func (l WorkspaceList) FilterByIntegrationRoleIncluding(i IntegrationID, r Role) WorkspaceList { +func (l List) FilterByIntegrationRoleIncluding(i IntegrationID, r Role) List { if l == nil || i.IsEmpty() || r == "" { return nil } - res := make(WorkspaceList, 0, len(l)) + res := make(List, 0, len(l)) for _, t := range l { if m := t.Members().Integration(i); m != nil && m.Role.Includes(r) { res = append(res, t) @@ -79,7 +79,7 @@ func (l WorkspaceList) FilterByIntegrationRoleIncluding(i IntegrationID, r Role) return res } -func (l WorkspaceList) IDs() []ID { +func (l List) IDs() []ID { if l == nil { return nil } diff --git a/account/accountdomain/workspace/workspace_list_test.go b/account/accountdomain/workspace/workspace_list_test.go index 62c9b6e..f07a6ae 100644 --- a/account/accountdomain/workspace/workspace_list_test.go +++ b/account/accountdomain/workspace/workspace_list_test.go @@ -12,11 +12,11 @@ func TestWorkspaceList_FilterByID(t *testing.T) { t1 := &Workspace{id: tid1} t2 := &Workspace{id: tid2} - assert.Equal(t, WorkspaceList{t1}, WorkspaceList{t1, t2}.FilterByID(tid1)) - assert.Equal(t, WorkspaceList{t2}, WorkspaceList{t1, t2}.FilterByID(tid2)) - assert.Equal(t, WorkspaceList{t1, t2}, WorkspaceList{t1, t2}.FilterByID(tid1, tid2)) - assert.Equal(t, WorkspaceList{}, WorkspaceList{t1, t2}.FilterByID(NewID())) - assert.Equal(t, WorkspaceList(nil), WorkspaceList(nil).FilterByID(tid1)) + assert.Equal(t, List{t1}, List{t1, t2}.FilterByID(tid1)) + assert.Equal(t, List{t2}, List{t1, t2}.FilterByID(tid2)) + assert.Equal(t, List{t1, t2}, List{t1, t2}.FilterByID(tid1, tid2)) + assert.Equal(t, List{}, List{t1, t2}.FilterByID(NewID())) + assert.Equal(t, List(nil), List(nil).FilterByID(tid1)) } func TestWorkspaceList_FilterByUserRole(t *testing.T) { @@ -40,10 +40,10 @@ func TestWorkspaceList_FilterByUserRole(t *testing.T) { }, } - assert.Equal(t, WorkspaceList{t1}, WorkspaceList{t1, t2}.FilterByUserRole(uid, RoleReader)) - assert.Equal(t, WorkspaceList{}, WorkspaceList{t1, t2}.FilterByUserRole(uid, RoleWriter)) - assert.Equal(t, WorkspaceList{t2}, WorkspaceList{t1, t2}.FilterByUserRole(uid, RoleOwner)) - assert.Equal(t, WorkspaceList(nil), WorkspaceList(nil).FilterByUserRole(uid, RoleOwner)) + assert.Equal(t, List{t1}, List{t1, t2}.FilterByUserRole(uid, RoleReader)) + assert.Equal(t, List{}, List{t1, t2}.FilterByUserRole(uid, RoleWriter)) + assert.Equal(t, List{t2}, List{t1, t2}.FilterByUserRole(uid, RoleOwner)) + assert.Equal(t, List(nil), List(nil).FilterByUserRole(uid, RoleOwner)) } func TestWorkspaceList_FilterByIntegrationRole(t *testing.T) { @@ -67,10 +67,10 @@ func TestWorkspaceList_FilterByIntegrationRole(t *testing.T) { }, } - assert.Equal(t, WorkspaceList{t1}, WorkspaceList{t1, t2}.FilterByIntegrationRole(iid, RoleReader)) - assert.Equal(t, WorkspaceList{}, WorkspaceList{t1, t2}.FilterByIntegrationRole(iid, RoleOwner)) - assert.Equal(t, WorkspaceList{t2}, WorkspaceList{t1, t2}.FilterByIntegrationRole(iid, RoleWriter)) - assert.Equal(t, WorkspaceList(nil), WorkspaceList(nil).FilterByIntegrationRole(iid, RoleOwner)) + assert.Equal(t, List{t1}, List{t1, t2}.FilterByIntegrationRole(iid, RoleReader)) + assert.Equal(t, List{}, List{t1, t2}.FilterByIntegrationRole(iid, RoleOwner)) + assert.Equal(t, List{t2}, List{t1, t2}.FilterByIntegrationRole(iid, RoleWriter)) + assert.Equal(t, List(nil), List(nil).FilterByIntegrationRole(iid, RoleOwner)) } func TestWorkspaceList_FilterByUserRoleIncluding(t *testing.T) { @@ -94,10 +94,10 @@ func TestWorkspaceList_FilterByUserRoleIncluding(t *testing.T) { }, } - assert.Equal(t, WorkspaceList{t1, t2}, WorkspaceList{t1, t2}.FilterByUserRoleIncluding(uid, RoleReader)) - assert.Equal(t, WorkspaceList{t2}, WorkspaceList{t1, t2}.FilterByUserRoleIncluding(uid, RoleWriter)) - assert.Equal(t, WorkspaceList{t2}, WorkspaceList{t1, t2}.FilterByUserRoleIncluding(uid, RoleOwner)) - assert.Equal(t, WorkspaceList(nil), WorkspaceList(nil).FilterByUserRoleIncluding(uid, RoleOwner)) + assert.Equal(t, List{t1, t2}, List{t1, t2}.FilterByUserRoleIncluding(uid, RoleReader)) + assert.Equal(t, List{t2}, List{t1, t2}.FilterByUserRoleIncluding(uid, RoleWriter)) + assert.Equal(t, List{t2}, List{t1, t2}.FilterByUserRoleIncluding(uid, RoleOwner)) + assert.Equal(t, List(nil), List(nil).FilterByUserRoleIncluding(uid, RoleOwner)) } func TestWorkspaceList_FilterByIntegrationRoleIncluding(t *testing.T) { @@ -121,10 +121,10 @@ func TestWorkspaceList_FilterByIntegrationRoleIncluding(t *testing.T) { }, } - assert.Equal(t, WorkspaceList{t1, t2}, WorkspaceList{t1, t2}.FilterByIntegrationRoleIncluding(uid, RoleReader)) - assert.Equal(t, WorkspaceList{t2}, WorkspaceList{t1, t2}.FilterByIntegrationRoleIncluding(uid, RoleWriter)) - assert.Equal(t, WorkspaceList{t2}, WorkspaceList{t1, t2}.FilterByIntegrationRoleIncluding(uid, RoleOwner)) - assert.Equal(t, WorkspaceList(nil), WorkspaceList(nil).FilterByIntegrationRoleIncluding(uid, RoleOwner)) + assert.Equal(t, List{t1, t2}, List{t1, t2}.FilterByIntegrationRoleIncluding(uid, RoleReader)) + assert.Equal(t, List{t2}, List{t1, t2}.FilterByIntegrationRoleIncluding(uid, RoleWriter)) + assert.Equal(t, List{t2}, List{t1, t2}.FilterByIntegrationRoleIncluding(uid, RoleOwner)) + assert.Equal(t, List(nil), List(nil).FilterByIntegrationRoleIncluding(uid, RoleOwner)) } func TestWorkspaceList_IDs(t *testing.T) { @@ -133,7 +133,7 @@ func TestWorkspaceList_IDs(t *testing.T) { t1 := &Workspace{id: wid1} t2 := &Workspace{id: wid2} - assert.Equal(t, []ID{wid1, wid2}, WorkspaceList{t1, t2}.IDs()) - assert.Equal(t, []ID{}, WorkspaceList{}.IDs()) - assert.Equal(t, []ID(nil), WorkspaceList(nil).IDs()) + assert.Equal(t, []ID{wid1, wid2}, List{t1, t2}.IDs()) + assert.Equal(t, []ID{}, List{}.IDs()) + assert.Equal(t, []ID(nil), List(nil).IDs()) } diff --git a/account/accountinfrastructure/accountmemory/user.go b/account/accountinfrastructure/accountmemory/user.go index a4c26e8..1e32c25 100644 --- a/account/accountinfrastructure/accountmemory/user.go +++ b/account/accountinfrastructure/accountmemory/user.go @@ -3,7 +3,6 @@ package accountmemory import ( "context" - "github.com/reearth/reearthx/account/accountdomain" "github.com/reearth/reearthx/account/accountdomain/user" "github.com/reearth/reearthx/account/accountusecase/accountrepo" "github.com/reearth/reearthx/rerror" @@ -11,13 +10,13 @@ import ( ) type User struct { - data *util.SyncMap[accountdomain.UserID, *user.User] + data *util.SyncMap[user.ID, *user.User] err error } func NewUser() *User { return &User{ - data: &util.SyncMap[accountdomain.UserID, *user.User]{}, + data: &util.SyncMap[user.ID, *user.User]{}, } } @@ -29,24 +28,24 @@ func NewUserWith(users ...*user.User) *User { return r } -func (r *User) FindByIDs(ctx context.Context, ids accountdomain.UserIDList) ([]*user.User, error) { +func (r *User) FindByIDs(ctx context.Context, ids user.IDList) (user.List, error) { if r.err != nil { return nil, r.err } - res := r.data.FindAll(func(key accountdomain.UserID, value *user.User) bool { + res := r.data.FindAll(func(key user.ID, value *user.User) bool { return ids.Has(key) }) return res, nil } -func (r *User) FindByID(ctx context.Context, v accountdomain.UserID) (*user.User, error) { +func (r *User) FindByID(ctx context.Context, v user.ID) (*user.User, error) { if r.err != nil { return nil, r.err } - return rerror.ErrIfNil(r.data.Find(func(key accountdomain.UserID, value *user.User) bool { + return rerror.ErrIfNil(r.data.Find(func(key user.ID, value *user.User) bool { return key == v }), rerror.ErrNotFound) } @@ -60,7 +59,7 @@ func (r *User) FindBySub(ctx context.Context, auth0sub string) (*user.User, erro return nil, rerror.ErrInvalidParams } - return rerror.ErrIfNil(r.data.Find(func(key accountdomain.UserID, value *user.User) bool { + return rerror.ErrIfNil(r.data.Find(func(key user.ID, value *user.User) bool { return value.ContainAuth(user.AuthFrom(auth0sub)) }), rerror.ErrNotFound) } @@ -74,7 +73,7 @@ func (r *User) FindByPasswordResetRequest(ctx context.Context, token string) (*u return nil, rerror.ErrInvalidParams } - return rerror.ErrIfNil(r.data.Find(func(key accountdomain.UserID, value *user.User) bool { + return rerror.ErrIfNil(r.data.Find(func(key user.ID, value *user.User) bool { return value.PasswordReset() != nil && value.PasswordReset().Token == token }), rerror.ErrNotFound) } @@ -88,7 +87,7 @@ func (r *User) FindByEmail(ctx context.Context, email string) (*user.User, error return nil, rerror.ErrInvalidParams } - return rerror.ErrIfNil(r.data.Find(func(key accountdomain.UserID, value *user.User) bool { + return rerror.ErrIfNil(r.data.Find(func(key user.ID, value *user.User) bool { return value.Email() == email }), rerror.ErrNotFound) } @@ -102,7 +101,7 @@ func (r *User) FindByName(ctx context.Context, name string) (*user.User, error) return nil, rerror.ErrInvalidParams } - return rerror.ErrIfNil(r.data.Find(func(key accountdomain.UserID, value *user.User) bool { + return rerror.ErrIfNil(r.data.Find(func(key user.ID, value *user.User) bool { return value.Name() == name }), rerror.ErrNotFound) } @@ -116,7 +115,7 @@ func (r *User) FindByNameOrEmail(ctx context.Context, nameOrEmail string) (*user return nil, rerror.ErrInvalidParams } - return rerror.ErrIfNil(r.data.Find(func(key accountdomain.UserID, value *user.User) bool { + return rerror.ErrIfNil(r.data.Find(func(key user.ID, value *user.User) bool { return value.Email() == nameOrEmail || value.Name() == nameOrEmail }), rerror.ErrNotFound) } @@ -130,7 +129,7 @@ func (r *User) FindByVerification(ctx context.Context, code string) (*user.User, return nil, rerror.ErrInvalidParams } - return rerror.ErrIfNil(r.data.Find(func(key accountdomain.UserID, value *user.User) bool { + return rerror.ErrIfNil(r.data.Find(func(key user.ID, value *user.User) bool { return value.Verification() != nil && value.Verification().Code() == code }), rerror.ErrNotFound) @@ -141,7 +140,7 @@ func (r *User) FindBySubOrCreate(ctx context.Context, u *user.User, sub string) return nil, r.err } - u2 := r.data.Find(func(key accountdomain.UserID, value *user.User) bool { + u2 := r.data.Find(func(key user.ID, value *user.User) bool { return value.ContainAuth(user.AuthFrom(sub)) }) if u2 == nil { @@ -174,7 +173,7 @@ func (r *User) Save(ctx context.Context, u *user.User) error { return nil } -func (r *User) Remove(ctx context.Context, user accountdomain.UserID) error { +func (r *User) Remove(ctx context.Context, user user.ID) error { if r.err != nil { return r.err } diff --git a/account/accountinfrastructure/accountmemory/workspace.go b/account/accountinfrastructure/accountmemory/workspace.go index 628d62e..b5a2b10 100644 --- a/account/accountinfrastructure/accountmemory/workspace.go +++ b/account/accountinfrastructure/accountmemory/workspace.go @@ -12,13 +12,13 @@ import ( ) type Workspace struct { - data *util.SyncMap[accountdomain.WorkspaceID, *workspace.Workspace] + data *util.SyncMap[workspace.ID, *workspace.Workspace] err error } func NewWorkspace() *Workspace { return &Workspace{ - data: &util.SyncMap[accountdomain.WorkspaceID, *workspace.Workspace]{}, + data: &util.SyncMap[workspace.ID, *workspace.Workspace]{}, } } @@ -30,44 +30,44 @@ func NewWorkspaceWith(workspaces ...*workspace.Workspace) *Workspace { return r } -func (r *Workspace) FindByUser(ctx context.Context, i accountdomain.UserID) (workspace.WorkspaceList, error) { +func (r *Workspace) FindByUser(ctx context.Context, i accountdomain.UserID) (workspace.List, error) { if r.err != nil { return nil, r.err } - return rerror.ErrIfNil(r.data.FindAll(func(key accountdomain.WorkspaceID, value *workspace.Workspace) bool { + return rerror.ErrIfNil(r.data.FindAll(func(key workspace.ID, value *workspace.Workspace) bool { return value.Members().HasUser(i) }), rerror.ErrNotFound) } -func (r *Workspace) FindByIntegration(_ context.Context, i accountdomain.IntegrationID) (workspace.WorkspaceList, error) { +func (r *Workspace) FindByIntegration(_ context.Context, i workspace.IntegrationID) (workspace.List, error) { if r.err != nil { return nil, r.err } - return rerror.ErrIfNil(r.data.FindAll(func(key accountdomain.WorkspaceID, value *workspace.Workspace) bool { + return rerror.ErrIfNil(r.data.FindAll(func(key workspace.ID, value *workspace.Workspace) bool { return value.Members().HasIntegration(i) }), rerror.ErrNotFound) } -func (r *Workspace) FindByIDs(ctx context.Context, ids accountdomain.WorkspaceIDList) (workspace.WorkspaceList, error) { +func (r *Workspace) FindByIDs(ctx context.Context, ids workspace.IDList) (workspace.List, error) { if r.err != nil { return nil, r.err } - res := r.data.FindAll(func(key accountdomain.WorkspaceID, value *workspace.Workspace) bool { + res := r.data.FindAll(func(key workspace.ID, value *workspace.Workspace) bool { return ids.Has(key) }) slices.SortFunc(res, func(a, b *workspace.Workspace) bool { return a.ID().Compare(b.ID()) < 0 }) return res, nil } -func (r *Workspace) FindByID(ctx context.Context, v accountdomain.WorkspaceID) (*workspace.Workspace, error) { +func (r *Workspace) FindByID(ctx context.Context, v workspace.ID) (*workspace.Workspace, error) { if r.err != nil { return nil, r.err } - return rerror.ErrIfNil(r.data.Find(func(key accountdomain.WorkspaceID, value *workspace.Workspace) bool { + return rerror.ErrIfNil(r.data.Find(func(key workspace.ID, value *workspace.Workspace) bool { return key == v }), rerror.ErrNotFound) } @@ -81,7 +81,7 @@ func (r *Workspace) Save(ctx context.Context, t *workspace.Workspace) error { return nil } -func (r *Workspace) SaveAll(ctx context.Context, workspaces []*workspace.Workspace) error { +func (r *Workspace) SaveAll(ctx context.Context, workspaces workspace.List) error { if r.err != nil { return r.err } @@ -92,7 +92,7 @@ func (r *Workspace) SaveAll(ctx context.Context, workspaces []*workspace.Workspa return nil } -func (r *Workspace) Remove(ctx context.Context, wid accountdomain.WorkspaceID) error { +func (r *Workspace) Remove(ctx context.Context, wid workspace.ID) error { if r.err != nil { return r.err } @@ -101,7 +101,7 @@ func (r *Workspace) Remove(ctx context.Context, wid accountdomain.WorkspaceID) e return nil } -func (r *Workspace) RemoveAll(ctx context.Context, ids accountdomain.WorkspaceIDList) error { +func (r *Workspace) RemoveAll(ctx context.Context, ids workspace.IDList) error { if r.err != nil { return r.err } diff --git a/account/accountinfrastructure/accountmemory/workspace_test.go b/account/accountinfrastructure/accountmemory/workspace_test.go index af11ba8..96822e2 100644 --- a/account/accountinfrastructure/accountmemory/workspace_test.go +++ b/account/accountinfrastructure/accountmemory/workspace_test.go @@ -60,7 +60,7 @@ func TestWorkspace_FindByIDs(t *testing.T) { r.data.Store(ws2.ID(), ws2) ids := accountdomain.WorkspaceIDList{ws.ID()} - wsl := workspace.WorkspaceList{ws} + wsl := workspace.List{ws} out, err := r.FindByIDs(ctx, ids) assert.NoError(t, err) assert.Equal(t, wsl, out) @@ -78,7 +78,7 @@ func TestWorkspace_FindByUser(t *testing.T) { data: &util.SyncMap[accountdomain.WorkspaceID, *workspace.Workspace]{}, } r.data.Store(ws.ID(), ws) - wsl := workspace.WorkspaceList{ws} + wsl := workspace.List{ws} out, err := r.FindByUser(ctx, u.ID()) assert.NoError(t, err) assert.Equal(t, wsl, out) diff --git a/account/accountinfrastructure/accountmongo/user.go b/account/accountinfrastructure/accountmongo/user.go index 4766a74..4177450 100644 --- a/account/accountinfrastructure/accountmongo/user.go +++ b/account/accountinfrastructure/accountmongo/user.go @@ -3,7 +3,6 @@ package accountmongo import ( "context" - "github.com/reearth/reearthx/account/accountdomain" "github.com/reearth/reearthx/account/accountdomain/user" "github.com/reearth/reearthx/account/accountinfrastructure/accountmongo/mongodoc" "github.com/reearth/reearthx/account/accountusecase/accountrepo" @@ -31,7 +30,11 @@ func (r *User) Init() error { return createIndexes(context.Background(), r.client, userIndexes, userUniqueIndexes) } -func (r *User) FindByIDs(ctx context.Context, ids accountdomain.UserIDList) ([]*user.User, error) { +func (r *User) FindByID(ctx context.Context, id2 user.ID) (*user.User, error) { + return r.findOne(ctx, bson.M{"id": id2.String()}) +} + +func (r *User) FindByIDs(ctx context.Context, ids user.IDList) (user.List, error) { if len(ids) == 0 { return nil, nil } @@ -45,10 +48,6 @@ func (r *User) FindByIDs(ctx context.Context, ids accountdomain.UserIDList) ([]* return filterUsers(ids, res), nil } -func (r *User) FindByID(ctx context.Context, id2 accountdomain.UserID) (*user.User, error) { - return r.findOne(ctx, bson.M{"id": id2.String()}) -} - func (r *User) FindBySub(ctx context.Context, auth0sub string) (*user.User, error) { return r.findOne(ctx, bson.M{ "$or": []bson.M{ @@ -145,11 +144,11 @@ func (r *User) Save(ctx context.Context, user *user.User) error { return r.client.SaveOne(ctx, id, doc) } -func (r *User) Remove(ctx context.Context, user accountdomain.UserID) error { +func (r *User) Remove(ctx context.Context, user user.ID) error { return r.client.RemoveOne(ctx, bson.M{"id": user.String()}) } -func (r *User) find(ctx context.Context, filter any) ([]*user.User, error) { +func (r *User) find(ctx context.Context, filter any) (user.List, error) { c := mongodoc.NewUserConsumer() if err := r.client.Find(ctx, filter, c); err != nil { return nil, err @@ -165,7 +164,7 @@ func (r *User) findOne(ctx context.Context, filter any) (*user.User, error) { return c.Result[0], nil } -func filterUsers(ids []accountdomain.UserID, rows []*user.User) []*user.User { +func filterUsers(ids []user.ID, rows []*user.User) []*user.User { res := make([]*user.User, 0, len(ids)) for _, id := range ids { var r2 *user.User diff --git a/account/accountinfrastructure/accountmongo/workspace.go b/account/accountinfrastructure/accountmongo/workspace.go index 977f4ac..f77ac36 100644 --- a/account/accountinfrastructure/accountmongo/workspace.go +++ b/account/accountinfrastructure/accountmongo/workspace.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/reearth/reearthx/account/accountdomain" + "github.com/reearth/reearthx/account/accountdomain/user" "github.com/reearth/reearthx/account/accountdomain/workspace" "github.com/reearth/reearthx/account/accountinfrastructure/accountmongo/mongodoc" "github.com/reearth/reearthx/account/accountusecase/accountrepo" @@ -32,7 +33,7 @@ func (r *Workspace) Init() error { return createIndexes(context.Background(), r.client, nil, workspaceUniqueIndexes) } -func (r *Workspace) FindByUser(ctx context.Context, id accountdomain.UserID) (workspace.WorkspaceList, error) { +func (r *Workspace) FindByUser(ctx context.Context, id user.ID) (workspace.List, error) { return r.find(ctx, bson.M{ "members." + strings.Replace(id.String(), ".", "", -1): bson.M{ "$exists": true, @@ -40,7 +41,7 @@ func (r *Workspace) FindByUser(ctx context.Context, id accountdomain.UserID) (wo }) } -func (r *Workspace) FindByIntegration(ctx context.Context, id accountdomain.IntegrationID) (workspace.WorkspaceList, error) { +func (r *Workspace) FindByIntegration(ctx context.Context, id workspace.IntegrationID) (workspace.List, error) { return r.find(ctx, bson.M{ "integrations." + strings.Replace(id.String(), ".", "", -1): bson.M{ "$exists": true, @@ -48,7 +49,7 @@ func (r *Workspace) FindByIntegration(ctx context.Context, id accountdomain.Inte }) } -func (r *Workspace) FindByIDs(ctx context.Context, ids accountdomain.WorkspaceIDList) (workspace.WorkspaceList, error) { +func (r *Workspace) FindByIDs(ctx context.Context, ids accountdomain.WorkspaceIDList) (workspace.List, error) { if len(ids) == 0 { return nil, nil } @@ -70,7 +71,7 @@ func (r *Workspace) Save(ctx context.Context, workspace *workspace.Workspace) er return r.client.SaveOne(ctx, id, doc) } -func (r *Workspace) SaveAll(ctx context.Context, workspaces []*workspace.Workspace) error { +func (r *Workspace) SaveAll(ctx context.Context, workspaces workspace.List) error { if len(workspaces) == 0 { return nil } @@ -95,7 +96,7 @@ func (r *Workspace) RemoveAll(ctx context.Context, ids accountdomain.WorkspaceID }) } -func (r *Workspace) find(ctx context.Context, filter any) (workspace.WorkspaceList, error) { +func (r *Workspace) find(ctx context.Context, filter any) (workspace.List, error) { c := mongodoc.NewWorkspaceConsumer() if err := r.client.Find(ctx, filter, c); err != nil { return nil, err @@ -111,6 +112,6 @@ func (r *Workspace) findOne(ctx context.Context, filter any) (*workspace.Workspa return c.Result[0], nil } -func filterWorkspaces(ids []accountdomain.WorkspaceID, rows workspace.WorkspaceList) workspace.WorkspaceList { +func filterWorkspaces(ids []accountdomain.WorkspaceID, rows workspace.List) workspace.List { return rows.FilterByID(ids...) } diff --git a/account/accountinfrastructure/accountmongo/workspace_test.go b/account/accountinfrastructure/accountmongo/workspace_test.go index d4a77dd..5a4139f 100644 --- a/account/accountinfrastructure/accountmongo/workspace_test.go +++ b/account/accountinfrastructure/accountmongo/workspace_test.go @@ -69,18 +69,18 @@ func TestWorkspace_FindByIDs(t *testing.T) { tests := []struct { Name string Input accountdomain.WorkspaceIDList - RepoData, Expected workspace.WorkspaceList + RepoData, Expected workspace.List }{ { Name: "must find users", - RepoData: workspace.WorkspaceList{ws1, ws2}, + RepoData: workspace.List{ws1, ws2}, Input: accountdomain.WorkspaceIDList{ws1.ID(), ws2.ID()}, - Expected: workspace.WorkspaceList{ws1, ws2}, + Expected: workspace.List{ws1, ws2}, }, { Name: "must not find any user", Input: accountdomain.WorkspaceIDList{ws3.ID()}, - RepoData: workspace.WorkspaceList{ws2, ws1}, + RepoData: workspace.List{ws2, ws1}, }, } @@ -118,13 +118,13 @@ func TestWorkspace_FindByUser(t *testing.T) { Name string Input accountdomain.UserID RepoData *workspace.Workspace - Expected workspace.WorkspaceList + Expected workspace.List }{ { Name: "must find a workspace", Input: u.ID(), RepoData: ws, - Expected: workspace.WorkspaceList{ws}, + Expected: workspace.List{ws}, }, { Name: "must not find any workspace", @@ -184,7 +184,7 @@ func TestWorkspace_RemoveAll(t *testing.T) { repo := NewWorkspace(client) ctx := context.Background() - err := repo.SaveAll(ctx, workspace.WorkspaceList{ws1, ws2}) + err := repo.SaveAll(ctx, workspace.List{ws1, ws2}) assert.NoError(t, err) err = repo.RemoveAll(ctx, accountdomain.WorkspaceIDList{ws1.ID(), ws2.ID()}) diff --git a/account/accountusecase/accountinteractor/user.go b/account/accountusecase/accountinteractor/user.go index 784ff1e..0eae128 100644 --- a/account/accountusecase/accountinteractor/user.go +++ b/account/accountusecase/accountinteractor/user.go @@ -108,6 +108,7 @@ func (i *User) UpdateMe(ctx context.Context, p accountinterfaces.UpdateMeParam, if operator.User == nil { return nil, accountinterfaces.ErrInvalidOperator } + return Run1(ctx, operator, i.repos, Usecase().Transaction(), func(ctx context.Context) (*user.User, error) { if p.Password != nil { if p.PasswordConfirmation == nil || *p.Password != *p.PasswordConfirmation { @@ -193,6 +194,7 @@ func (i *User) RemoveMyAuth(ctx context.Context, authProvider string, operator * if operator.User == nil { return nil, accountinterfaces.ErrInvalidOperator } + return Run1(ctx, operator, i.repos, Usecase().Transaction(), func(ctx context.Context) (*user.User, error) { u, err = i.repos.User.FindByID(ctx, *operator.User) if err != nil { @@ -296,7 +298,6 @@ func (i *User) VerifyUser(ctx context.Context, code string) (*user.User, error) } func (i *User) StartPasswordReset(ctx context.Context, email string) error { return Run0(ctx, nil, i.repos, Usecase().Transaction(), func(ctx context.Context) error { - u, err := i.repos.User.FindByEmail(ctx, email) if err != nil { return err @@ -339,6 +340,7 @@ func (i *User) StartPasswordReset(ctx context.Context, email string) error { return nil }) } + func (i *User) PasswordReset(ctx context.Context, password string, token string) error { return Run0(ctx, nil, i.repos, Usecase().Transaction(), func(ctx context.Context) error { u, err := i.repos.User.FindByPasswordResetRequest(ctx, token) diff --git a/account/accountusecase/accountinteractor/workspace.go b/account/accountusecase/accountinteractor/workspace.go index 640312a..61b1bd1 100644 --- a/account/accountusecase/accountinteractor/workspace.go +++ b/account/accountusecase/accountinteractor/workspace.go @@ -13,36 +13,37 @@ import ( "golang.org/x/exp/maps" ) +type WorkspaceMemberCountEnforcer func(context.Context, *workspace.Workspace, user.List, *accountusecase.Operator) error + type Workspace struct { - repos *accountrepo.Container + repos *accountrepo.Container + enforceMemberCount WorkspaceMemberCountEnforcer } -func NewWorkspace(r *accountrepo.Container) accountinterfaces.Workspace { +func NewWorkspace(r *accountrepo.Container, enforceMemberCount WorkspaceMemberCountEnforcer) accountinterfaces.Workspace { return &Workspace{ - repos: r, + repos: r, + enforceMemberCount: enforceMemberCount, } } func (i *Workspace) Fetch(ctx context.Context, ids accountdomain.WorkspaceIDList, operator *accountusecase.Operator) ([]*workspace.Workspace, error) { - return Run1(ctx, operator, i.repos, Usecase().Transaction(), func(ctx context.Context) ([]*workspace.Workspace, error) { - res, err := i.repos.Workspace.FindByIDs(ctx, ids) - res2, err := accountinterfaces.FilterWorkspaces(res, operator, err, false) - return res2, err - }) + res, err := i.repos.Workspace.FindByIDs(ctx, ids) + res2, err := accountinterfaces.FilterWorkspaces(res, operator, err, false) + return res2, err } func (i *Workspace) FindByUser(ctx context.Context, id accountdomain.UserID, operator *accountusecase.Operator) ([]*workspace.Workspace, error) { - return Run1(ctx, operator, i.repos, Usecase().Transaction(), func(ctx context.Context) ([]*workspace.Workspace, error) { - res, err := i.repos.Workspace.FindByUser(ctx, id) - res2, err := accountinterfaces.FilterWorkspaces(res, operator, err, true) - return res2, err - }) + res, err := i.repos.Workspace.FindByUser(ctx, id) + res2, err := accountinterfaces.FilterWorkspaces(res, operator, err, true) + return res2, err } func (i *Workspace) Create(ctx context.Context, name string, firstUser accountdomain.UserID, operator *accountusecase.Operator) (_ *workspace.Workspace, err error) { if operator.User == nil { return nil, accountinterfaces.ErrInvalidOperator } + return Run1(ctx, operator, i.repos, Usecase().Transaction(), func(ctx context.Context) (*workspace.Workspace, error) { if len(strings.TrimSpace(name)) == 0 { return nil, user.ErrInvalidName @@ -65,6 +66,7 @@ func (i *Workspace) Create(ctx context.Context, name string, firstUser accountdo } operator.AddNewWorkspace(ws.ID()) + i.applyDefaultPolicy(ws, operator) return ws, nil }) } @@ -73,11 +75,13 @@ func (i *Workspace) Update(ctx context.Context, id accountdomain.WorkspaceID, na if operator.User == nil { return nil, accountinterfaces.ErrInvalidOperator } + return Run1(ctx, operator, i.repos, Usecase().Transaction(), func(ctx context.Context) (*workspace.Workspace, error) { ws, err := i.repos.Workspace.FindByID(ctx, id) if err != nil { return nil, err } + if ws.IsPersonal() { return nil, workspace.ErrCannotModifyPersonalWorkspace } @@ -96,6 +100,7 @@ func (i *Workspace) Update(ctx context.Context, id accountdomain.WorkspaceID, na return nil, err } + i.applyDefaultPolicy(ws, operator) return ws, nil }) } @@ -104,11 +109,13 @@ func (i *Workspace) AddUserMember(ctx context.Context, workspaceID accountdomain if operator.User == nil { return nil, accountinterfaces.ErrInvalidOperator } + return Run1(ctx, operator, i.repos, Usecase().Transaction().WithOwnableWorkspaces(workspaceID), func(ctx context.Context) (*workspace.Workspace, error) { ws, err := i.repos.Workspace.FindByID(ctx, workspaceID) if err != nil { return nil, err } + if ws.IsPersonal() { return nil, workspace.ErrCannotModifyPersonalWorkspace } @@ -118,6 +125,12 @@ func (i *Workspace) AddUserMember(ctx context.Context, workspaceID accountdomain return nil, err } + if i.enforceMemberCount != nil { + if err := i.enforceMemberCount(ctx, ws, ul, operator); err != nil { + return nil, err + } + } + for _, m := range ul { err = ws.Members().Join(m.ID(), users[m.ID()], *operator.User) if err != nil { @@ -130,6 +143,7 @@ func (i *Workspace) AddUserMember(ctx context.Context, workspaceID accountdomain return nil, err } + i.applyDefaultPolicy(ws, operator) return ws, nil }) } @@ -138,23 +152,25 @@ func (i *Workspace) AddIntegrationMember(ctx context.Context, wId accountdomain. if operator.User == nil { return nil, accountinterfaces.ErrInvalidOperator } + return Run1(ctx, operator, i.repos, Usecase().Transaction().WithOwnableWorkspaces(wId), func(ctx context.Context) (*workspace.Workspace, error) { - workspace, err := i.repos.Workspace.FindByID(ctx, wId) + ws, err := i.repos.Workspace.FindByID(ctx, wId) if err != nil { return nil, err } - err = workspace.Members().AddIntegration(iId, role, *operator.User) + err = ws.Members().AddIntegration(iId, role, *operator.User) if err != nil { return nil, err } - err = i.repos.Workspace.Save(ctx, workspace) + err = i.repos.Workspace.Save(ctx, ws) if err != nil { return nil, err } - return workspace, nil + i.applyDefaultPolicy(ws, operator) + return ws, nil }) } @@ -162,11 +178,13 @@ func (i *Workspace) RemoveUserMember(ctx context.Context, id accountdomain.Works if operator.User == nil { return nil, accountinterfaces.ErrInvalidOperator } + return Run1(ctx, operator, i.repos, Usecase().Transaction(), func(ctx context.Context) (*workspace.Workspace, error) { ws, err := i.repos.Workspace.FindByID(ctx, id) if err != nil { return nil, err } + if ws.IsPersonal() { return nil, workspace.ErrCannotModifyPersonalWorkspace } @@ -191,6 +209,7 @@ func (i *Workspace) RemoveUserMember(ctx context.Context, id accountdomain.Works return nil, err } + i.applyDefaultPolicy(ws, operator) return ws, nil }) } @@ -199,23 +218,25 @@ func (i *Workspace) RemoveIntegration(ctx context.Context, wId accountdomain.Wor if operator.User == nil { return nil, accountinterfaces.ErrInvalidOperator } + return Run1(ctx, operator, i.repos, Usecase().WithOwnableWorkspaces(wId).Transaction(), func(ctx context.Context) (*workspace.Workspace, error) { - workspace, err := i.repos.Workspace.FindByID(ctx, wId) + ws, err := i.repos.Workspace.FindByID(ctx, wId) if err != nil { return nil, err } - err = workspace.Members().DeleteIntegration(iId) + err = ws.Members().DeleteIntegration(iId) if err != nil { return nil, err } - err = i.repos.Workspace.Save(ctx, workspace) + err = i.repos.Workspace.Save(ctx, ws) if err != nil { return nil, err } - return workspace, nil + i.applyDefaultPolicy(ws, operator) + return ws, nil }) } @@ -223,11 +244,13 @@ func (i *Workspace) UpdateUserMember(ctx context.Context, id accountdomain.Works if operator.User == nil { return nil, accountinterfaces.ErrInvalidOperator } + return Run1(ctx, operator, i.repos, Usecase().Transaction().WithOwnableWorkspaces(id), func(ctx context.Context) (*workspace.Workspace, error) { ws, err := i.repos.Workspace.FindByID(ctx, id) if err != nil { return nil, err } + if ws.IsPersonal() { return nil, workspace.ErrCannotModifyPersonalWorkspace } @@ -246,6 +269,7 @@ func (i *Workspace) UpdateUserMember(ctx context.Context, id accountdomain.Works return nil, err } + i.applyDefaultPolicy(ws, operator) return ws, nil }) } @@ -254,23 +278,25 @@ func (i *Workspace) UpdateIntegration(ctx context.Context, wId accountdomain.Wor if operator.User == nil { return nil, accountinterfaces.ErrInvalidOperator } + return Run1(ctx, operator, i.repos, Usecase().WithOwnableWorkspaces(wId).Transaction(), func(ctx context.Context) (*workspace.Workspace, error) { - workspace, err := i.repos.Workspace.FindByID(ctx, wId) + ws, err := i.repos.Workspace.FindByID(ctx, wId) if err != nil { return nil, err } - err = workspace.Members().UpdateIntegrationRole(iId, role) + err = ws.Members().UpdateIntegrationRole(iId, role) if err != nil { return nil, err } - err = i.repos.Workspace.Save(ctx, workspace) + err = i.repos.Workspace.Save(ctx, ws) if err != nil { return nil, err } - return workspace, nil + i.applyDefaultPolicy(ws, operator) + return ws, nil }) } @@ -278,11 +304,13 @@ func (i *Workspace) Remove(ctx context.Context, id accountdomain.WorkspaceID, op if operator.User == nil { return accountinterfaces.ErrInvalidOperator } + return Run0(ctx, operator, i.repos, Usecase().Transaction().WithOwnableWorkspaces(id), func(ctx context.Context) error { ws, err := i.repos.Workspace.FindByID(ctx, id) if err != nil { return err } + if ws.IsPersonal() { return workspace.ErrCannotModifyPersonalWorkspace } @@ -295,3 +323,9 @@ func (i *Workspace) Remove(ctx context.Context, id accountdomain.WorkspaceID, op return nil }) } + +func (i *Workspace) applyDefaultPolicy(ws *workspace.Workspace, o *accountusecase.Operator) { + if ws.Policy() == nil && o.DefaultPolicy != nil { + ws.SetPolicy(o.DefaultPolicy) + } +} diff --git a/account/accountusecase/accountinteractor/workspace_test.go b/account/accountusecase/accountinteractor/workspace_test.go index 2e40691..4b6bc39 100644 --- a/account/accountusecase/accountinteractor/workspace_test.go +++ b/account/accountusecase/accountinteractor/workspace_test.go @@ -22,22 +22,22 @@ func TestWorkspace_Create(t *testing.T) { db := accountmemory.New() u := user.New().NewID().Name("aaa").Email("aaa@bbb.com").Workspace(accountdomain.NewWorkspaceID()).MustBuild() - workspaceUC := NewWorkspace(db) + workspaceUC := NewWorkspace(db, nil) op := &accountusecase.Operator{User: lo.ToPtr(u.ID())} - workspace, err := workspaceUC.Create(ctx, "workspace name", u.ID(), op) + ws, err := workspaceUC.Create(ctx, "workspace name", u.ID(), op) assert.NoError(t, err) - assert.NotNil(t, workspace) + assert.NotNil(t, ws) - resultWorkspaces, _ := workspaceUC.Fetch(ctx, []accountdomain.WorkspaceID{workspace.ID()}, &accountusecase.Operator{ - ReadableWorkspaces: []accountdomain.WorkspaceID{workspace.ID()}, + resultWorkspaces, _ := workspaceUC.Fetch(ctx, []workspace.ID{ws.ID()}, &accountusecase.Operator{ + ReadableWorkspaces: []workspace.ID{ws.ID()}, }) assert.NotNil(t, resultWorkspaces) assert.NotEmpty(t, resultWorkspaces) - assert.Equal(t, resultWorkspaces[0].ID(), workspace.ID()) + assert.Equal(t, resultWorkspaces[0].ID(), ws.ID()) assert.Equal(t, resultWorkspaces[0].Name(), "workspace name") - assert.Equal(t, accountdomain.WorkspaceIDList{resultWorkspaces[0].ID()}, op.OwningWorkspaces) + assert.Equal(t, workspace.IDList{resultWorkspaces[0].ID()}, op.OwningWorkspaces) // mock workspace error wantErr := errors.New("test") @@ -56,14 +56,14 @@ func TestWorkspace_Fetch(t *testing.T) { u := user.New().NewID().Name("aaa").Email("aaa@bbb.com").Workspace(id1).MustBuild() op := &accountusecase.Operator{ User: lo.ToPtr(u.ID()), - ReadableWorkspaces: []accountdomain.WorkspaceID{id1, id2}, + ReadableWorkspaces: []workspace.ID{id1, id2}, } tests := []struct { name string seeds []*workspace.Workspace args struct { - ids []accountdomain.WorkspaceID + ids []workspace.ID operator *accountusecase.Operator } want []*workspace.Workspace @@ -74,10 +74,10 @@ func TestWorkspace_Fetch(t *testing.T) { name: "Fetch 1 of 2", seeds: []*workspace.Workspace{w1, w2}, args: struct { - ids []accountdomain.WorkspaceID + ids []workspace.ID operator *accountusecase.Operator }{ - ids: []accountdomain.WorkspaceID{id1}, + ids: []workspace.ID{id1}, operator: op, }, want: []*workspace.Workspace{w1}, @@ -87,10 +87,10 @@ func TestWorkspace_Fetch(t *testing.T) { name: "Fetch 2 of 2", seeds: []*workspace.Workspace{w1, w2}, args: struct { - ids []accountdomain.WorkspaceID + ids []workspace.ID operator *accountusecase.Operator }{ - ids: []accountdomain.WorkspaceID{id1, id2}, + ids: []workspace.ID{id1, id2}, operator: op, }, want: []*workspace.Workspace{w1, w2}, @@ -100,10 +100,10 @@ func TestWorkspace_Fetch(t *testing.T) { name: "Fetch 1 of 0", seeds: []*workspace.Workspace{}, args: struct { - ids []accountdomain.WorkspaceID + ids []workspace.ID operator *accountusecase.Operator }{ - ids: []accountdomain.WorkspaceID{id1}, + ids: []workspace.ID{id1}, operator: op, }, want: nil, @@ -113,10 +113,10 @@ func TestWorkspace_Fetch(t *testing.T) { name: "Fetch 2 of 0", seeds: []*workspace.Workspace{}, args: struct { - ids []accountdomain.WorkspaceID + ids []workspace.ID operator *accountusecase.Operator }{ - ids: []accountdomain.WorkspaceID{id1, id2}, + ids: []workspace.ID{id1, id2}, operator: op, }, want: nil, @@ -143,7 +143,7 @@ func TestWorkspace_Fetch(t *testing.T) { err := db.Workspace.Save(ctx, p) assert.NoError(t, err) } - workspaceUC := NewWorkspace(db) + workspaceUC := NewWorkspace(db, nil) got, err := workspaceUC.Fetch(ctx, tc.args.ids, tc.args.operator) if tc.wantErr != nil { @@ -166,14 +166,14 @@ func TestWorkspace_FindByUser(t *testing.T) { u := user.New().NewID().Name("aaa").Email("aaa@bbb.com").Workspace(id1).MustBuild() op := &accountusecase.Operator{ User: lo.ToPtr(u.ID()), - ReadableWorkspaces: []accountdomain.WorkspaceID{id1, id2}, + ReadableWorkspaces: []workspace.ID{id1, id2}, } tests := []struct { name string seeds []*workspace.Workspace args struct { - userID accountdomain.UserID + userID user.ID operator *accountusecase.Operator } want []*workspace.Workspace @@ -184,7 +184,7 @@ func TestWorkspace_FindByUser(t *testing.T) { name: "Fetch 1 of 2", seeds: []*workspace.Workspace{w1, w2}, args: struct { - userID accountdomain.UserID + userID user.ID operator *accountusecase.Operator }{ userID: userID, @@ -197,7 +197,7 @@ func TestWorkspace_FindByUser(t *testing.T) { name: "Fetch 1 of 0", seeds: []*workspace.Workspace{}, args: struct { - userID accountdomain.UserID + userID user.ID operator *accountusecase.Operator }{ userID: userID, @@ -210,7 +210,7 @@ func TestWorkspace_FindByUser(t *testing.T) { name: "Fetch 0 of 1", seeds: []*workspace.Workspace{w2}, args: struct { - userID accountdomain.UserID + userID user.ID operator *accountusecase.Operator }{ userID: userID, @@ -240,7 +240,7 @@ func TestWorkspace_FindByUser(t *testing.T) { err := db.Workspace.Save(ctx, p) assert.NoError(t, err) } - workspaceUC := NewWorkspace(db) + workspaceUC := NewWorkspace(db, nil) got, err := workspaceUC.FindByUser(ctx, tc.args.userID, tc.args.operator) if tc.wantErr != nil { @@ -265,15 +265,15 @@ func TestWorkspace_Update(t *testing.T) { op := &accountusecase.Operator{ User: &userID, - ReadableWorkspaces: []accountdomain.WorkspaceID{id1, id2, id3}, - OwningWorkspaces: []accountdomain.WorkspaceID{id1}, + ReadableWorkspaces: []workspace.ID{id1, id2, id3}, + OwningWorkspaces: []workspace.ID{id1}, } tests := []struct { name string seeds []*workspace.Workspace args struct { - wId accountdomain.WorkspaceID + wId workspace.ID newName string operator *accountusecase.Operator } @@ -285,7 +285,7 @@ func TestWorkspace_Update(t *testing.T) { name: "Update 1", seeds: []*workspace.Workspace{w1, w2}, args: struct { - wId accountdomain.WorkspaceID + wId workspace.ID newName string operator *accountusecase.Operator }{ @@ -300,7 +300,7 @@ func TestWorkspace_Update(t *testing.T) { name: "Update 2", seeds: []*workspace.Workspace{}, args: struct { - wId accountdomain.WorkspaceID + wId workspace.ID newName string operator *accountusecase.Operator }{ @@ -315,7 +315,7 @@ func TestWorkspace_Update(t *testing.T) { name: "Update 3", seeds: []*workspace.Workspace{w3}, args: struct { - wId accountdomain.WorkspaceID + wId workspace.ID newName string operator *accountusecase.Operator }{ @@ -329,7 +329,7 @@ func TestWorkspace_Update(t *testing.T) { { name: "mock error", args: struct { - wId accountdomain.WorkspaceID + wId workspace.ID newName string operator *accountusecase.Operator }{ @@ -354,7 +354,7 @@ func TestWorkspace_Update(t *testing.T) { err := db.Workspace.Save(ctx, p) assert.NoError(t, err) } - workspaceUC := NewWorkspace(db) + workspaceUC := NewWorkspace(db, nil) got, err := workspaceUC.Update(ctx, tc.args.wId, tc.args.newName, tc.args.operator) if tc.wantErr != nil { @@ -387,15 +387,15 @@ func TestWorkspace_Remove(t *testing.T) { op := &accountusecase.Operator{ User: &userID, - ReadableWorkspaces: []accountdomain.WorkspaceID{id1, id2, id3}, - OwningWorkspaces: []accountdomain.WorkspaceID{id1, id4, id5, id6}, + ReadableWorkspaces: []workspace.ID{id1, id2, id3}, + OwningWorkspaces: []workspace.ID{id1, id4, id5, id6}, } tests := []struct { name string seeds []*workspace.Workspace args struct { - wId accountdomain.WorkspaceID + wId workspace.ID operator *accountusecase.Operator } wantErr error @@ -406,7 +406,7 @@ func TestWorkspace_Remove(t *testing.T) { name: "Remove 1", seeds: []*workspace.Workspace{w1, w2}, args: struct { - wId accountdomain.WorkspaceID + wId workspace.ID operator *accountusecase.Operator }{ wId: id1, @@ -419,7 +419,7 @@ func TestWorkspace_Remove(t *testing.T) { name: "Update 2", seeds: []*workspace.Workspace{w1, w2}, args: struct { - wId accountdomain.WorkspaceID + wId workspace.ID operator *accountusecase.Operator }{ wId: id2, @@ -432,7 +432,7 @@ func TestWorkspace_Remove(t *testing.T) { name: "Update 3", seeds: []*workspace.Workspace{w3}, args: struct { - wId accountdomain.WorkspaceID + wId workspace.ID operator *accountusecase.Operator }{ wId: id3, @@ -445,7 +445,7 @@ func TestWorkspace_Remove(t *testing.T) { name: "Remove 4", seeds: []*workspace.Workspace{w4}, args: struct { - wId accountdomain.WorkspaceID + wId workspace.ID operator *accountusecase.Operator }{ wId: id4, @@ -457,7 +457,7 @@ func TestWorkspace_Remove(t *testing.T) { { name: "mock workspace error", args: struct { - wId accountdomain.WorkspaceID + wId workspace.ID operator *accountusecase.Operator }{ wId: id5, @@ -482,7 +482,7 @@ func TestWorkspace_Remove(t *testing.T) { err := db.Workspace.Save(ctx, p) assert.NoError(t, err) } - workspaceUC := NewWorkspace(db) + workspaceUC := NewWorkspace(db, nil) err := workspaceUC.Remove(ctx, tc.args.wId, tc.args.operator) if tc.wantErr != nil { assert.Equal(t, tc.wantErr, err) @@ -516,17 +516,18 @@ func TestWorkspace_AddMember(t *testing.T) { op := &accountusecase.Operator{ User: &userID, - ReadableWorkspaces: []accountdomain.WorkspaceID{id1, id2}, - OwningWorkspaces: []accountdomain.WorkspaceID{id1, id2, id3}, + ReadableWorkspaces: []workspace.ID{id1, id2}, + OwningWorkspaces: []workspace.ID{id1, id2, id3}, } tests := []struct { name string seeds []*workspace.Workspace usersSeeds []*user.User + enforcer WorkspaceMemberCountEnforcer args struct { - wId accountdomain.WorkspaceID - users map[accountdomain.UserID]workspace.Role + wId workspace.ID + users map[user.ID]workspace.Role operator *accountusecase.Operator } wantErr error @@ -534,62 +535,97 @@ func TestWorkspace_AddMember(t *testing.T) { want *workspace.Members }{ { - name: "Add non existing", - seeds: []*workspace.Workspace{w1}, + name: "add a member", + seeds: []*workspace.Workspace{w2}, usersSeeds: []*user.User{u}, args: struct { - wId accountdomain.WorkspaceID - users map[accountdomain.UserID]workspace.Role + wId workspace.ID + users map[user.ID]workspace.Role operator *accountusecase.Operator }{ - wId: id1, - users: map[accountdomain.UserID]workspace.Role{accountdomain.NewUserID(): workspace.RoleReader}, + wId: w2.ID(), + users: map[user.ID]workspace.Role{ + u.ID(): workspace.RoleReader, + }, operator: op, }, - want: workspace.NewMembersWith(map[user.ID]workspace.Member{userID: {Role: workspace.RoleOwner}}, map[accountdomain.IntegrationID]workspace.Member{}, false), + wantErr: nil, + want: workspace.NewMembersWith(map[user.ID]workspace.Member{ + userID: {Role: workspace.RoleOwner}, + u.ID(): {Role: workspace.RoleReader, InvitedBy: userID}, // added + }, map[accountdomain.IntegrationID]workspace.Member{}, false), }, { - name: "Add", - seeds: []*workspace.Workspace{w2}, + name: "add a non existing member", + seeds: []*workspace.Workspace{w1}, usersSeeds: []*user.User{u}, args: struct { - wId accountdomain.WorkspaceID - users map[accountdomain.UserID]workspace.Role + wId workspace.ID + users map[user.ID]workspace.Role operator *accountusecase.Operator }{ - wId: id2, - users: map[accountdomain.UserID]workspace.Role{u.ID(): workspace.RoleReader}, + wId: w1.ID(), + users: map[user.ID]workspace.Role{ + accountdomain.NewUserID(): workspace.RoleReader, + }, operator: op, }, - wantErr: nil, - want: workspace.NewMembersWith(map[user.ID]workspace.Member{userID: {Role: workspace.RoleOwner}, u.ID(): {Role: workspace.RoleReader, InvitedBy: userID}}, map[accountdomain.IntegrationID]workspace.Member{}, false), + want: workspace.NewMembersWith(map[user.ID]workspace.Member{ + userID: {Role: workspace.RoleOwner}, + }, map[accountdomain.IntegrationID]workspace.Member{}, false), }, { - name: "Add to personal workspace", + name: "add a mamber to personal workspace", seeds: []*workspace.Workspace{w3}, usersSeeds: []*user.User{u}, args: struct { - wId accountdomain.WorkspaceID - users map[accountdomain.UserID]workspace.Role + wId workspace.ID + users map[user.ID]workspace.Role operator *accountusecase.Operator }{ - wId: id3, - users: map[accountdomain.UserID]workspace.Role{u.ID(): workspace.RoleReader}, + wId: w3.ID(), + users: map[user.ID]workspace.Role{ + u.ID(): workspace.RoleReader, + }, operator: op, }, wantErr: workspace.ErrCannotModifyPersonalWorkspace, - want: workspace.NewMembersWith(map[user.ID]workspace.Member{userID: {Role: workspace.RoleOwner}}, map[accountdomain.IntegrationID]workspace.Member{}, true), + want: workspace.NewMembersWith(map[user.ID]workspace.Member{ + userID: {Role: workspace.RoleOwner}, + }, map[accountdomain.IntegrationID]workspace.Member{}, true), + }, + { + name: "add member but enforcer rejects", + seeds: []*workspace.Workspace{w2}, + usersSeeds: []*user.User{u}, + enforcer: func(_ context.Context, _ *workspace.Workspace, _ user.List, _ *accountusecase.Operator) error { + return errors.New("test") + }, + args: struct { + wId workspace.ID + users map[user.ID]workspace.Role + operator *accountusecase.Operator + }{ + wId: w2.ID(), + users: map[user.ID]workspace.Role{ + u.ID(): workspace.RoleReader, + }, + operator: op, + }, + wantErr: errors.New("test"), }, { name: "op denied", seeds: []*workspace.Workspace{w4}, args: struct { - wId accountdomain.WorkspaceID - users map[accountdomain.UserID]workspace.Role + wId workspace.ID + users map[user.ID]workspace.Role operator *accountusecase.Operator }{ - wId: id4, - users: map[accountdomain.UserID]workspace.Role{accountdomain.NewUserID(): workspace.RoleReader}, + wId: id4, + users: map[user.ID]workspace.Role{ + accountdomain.NewUserID(): workspace.RoleReader, + }, operator: op, }, wantErr: accountinterfaces.ErrOperationDenied, @@ -598,18 +634,21 @@ func TestWorkspace_AddMember(t *testing.T) { { name: "mock error", args: struct { - wId accountdomain.WorkspaceID - users map[accountdomain.UserID]workspace.Role + wId workspace.ID + users map[user.ID]workspace.Role operator *accountusecase.Operator }{ - wId: id3, - users: map[accountdomain.UserID]workspace.Role{u.ID(): workspace.RoleReader}, + wId: id3, + users: map[user.ID]workspace.Role{ + u.ID(): workspace.RoleReader, + }, operator: op, }, wantErr: errors.New("test"), mockWorkspaceErr: true, }, } + for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { @@ -628,7 +667,7 @@ func TestWorkspace_AddMember(t *testing.T) { err := db.User.Save(ctx, p) assert.NoError(t, err) } - workspaceUC := NewWorkspace(db) + workspaceUC := NewWorkspace(db, tc.enforcer) got, err := workspaceUC.AddUserMember(ctx, tc.args.wId, tc.args.users, tc.args.operator) if tc.wantErr != nil { @@ -659,8 +698,8 @@ func TestWorkspace_AddIntegrationMember(t *testing.T) { op := &accountusecase.Operator{ User: &userID, - ReadableWorkspaces: []accountdomain.WorkspaceID{id1, id2}, - OwningWorkspaces: []accountdomain.WorkspaceID{id1, id2, id3}, + ReadableWorkspaces: []workspace.ID{id1, id2}, + OwningWorkspaces: []workspace.ID{id1, id2, id3}, } iid1 := accountdomain.NewIntegrationID() @@ -670,7 +709,7 @@ func TestWorkspace_AddIntegrationMember(t *testing.T) { seeds []*workspace.Workspace usersSeeds []*user.User args struct { - wId accountdomain.WorkspaceID + wId workspace.ID integrationID accountdomain.IntegrationID role workspace.Role operator *accountusecase.Operator @@ -680,11 +719,11 @@ func TestWorkspace_AddIntegrationMember(t *testing.T) { want []accountdomain.IntegrationID }{ { - name: "Add non existing", + name: "add non existing", seeds: []*workspace.Workspace{w1}, usersSeeds: []*user.User{u}, args: struct { - wId accountdomain.WorkspaceID + wId workspace.ID integrationID accountdomain.IntegrationID role workspace.Role operator *accountusecase.Operator @@ -699,7 +738,7 @@ func TestWorkspace_AddIntegrationMember(t *testing.T) { { name: "mock error", args: struct { - wId accountdomain.WorkspaceID + wId workspace.ID integrationID accountdomain.IntegrationID role workspace.Role operator *accountusecase.Operator @@ -732,7 +771,7 @@ func TestWorkspace_AddIntegrationMember(t *testing.T) { assert.NoError(t, err) } - workspaceUC := NewWorkspace(db) + workspaceUC := NewWorkspace(db, nil) got, err := workspaceUC.AddIntegrationMember(ctx, tc.args.wId, tc.args.integrationID, tc.args.role, tc.args.operator) if tc.wantErr != nil { @@ -759,8 +798,8 @@ func TestWorkspace_RemoveMember(t *testing.T) { op := &accountusecase.Operator{ User: &userID, - ReadableWorkspaces: []accountdomain.WorkspaceID{id1, id2}, - OwningWorkspaces: []accountdomain.WorkspaceID{id1}, + ReadableWorkspaces: []workspace.ID{id1, id2}, + OwningWorkspaces: []workspace.ID{id1}, } tests := []struct { @@ -768,8 +807,8 @@ func TestWorkspace_RemoveMember(t *testing.T) { seeds []*workspace.Workspace usersSeeds []*user.User args struct { - wId accountdomain.WorkspaceID - uId accountdomain.UserID + wId workspace.ID + uId user.ID operator *accountusecase.Operator } wantErr error @@ -781,8 +820,8 @@ func TestWorkspace_RemoveMember(t *testing.T) { seeds: []*workspace.Workspace{w1}, usersSeeds: []*user.User{u}, args: struct { - wId accountdomain.WorkspaceID - uId accountdomain.UserID + wId workspace.ID + uId user.ID operator *accountusecase.Operator }{ wId: id1, @@ -797,8 +836,8 @@ func TestWorkspace_RemoveMember(t *testing.T) { seeds: []*workspace.Workspace{w2}, usersSeeds: []*user.User{u}, args: struct { - wId accountdomain.WorkspaceID - uId accountdomain.UserID + wId workspace.ID + uId user.ID operator *accountusecase.Operator }{ wId: id2, @@ -813,8 +852,8 @@ func TestWorkspace_RemoveMember(t *testing.T) { seeds: []*workspace.Workspace{w3}, usersSeeds: []*user.User{u}, args: struct { - wId accountdomain.WorkspaceID - uId accountdomain.UserID + wId workspace.ID + uId user.ID operator *accountusecase.Operator }{ wId: id3, @@ -829,8 +868,8 @@ func TestWorkspace_RemoveMember(t *testing.T) { seeds: []*workspace.Workspace{w4}, usersSeeds: []*user.User{u}, args: struct { - wId accountdomain.WorkspaceID - uId accountdomain.UserID + wId workspace.ID + uId user.ID operator *accountusecase.Operator }{ wId: id4, @@ -843,8 +882,8 @@ func TestWorkspace_RemoveMember(t *testing.T) { { name: "mock error", args: struct { - wId accountdomain.WorkspaceID - uId accountdomain.UserID + wId workspace.ID + uId user.ID operator *accountusecase.Operator }{operator: op}, wantErr: errors.New("test"), @@ -870,7 +909,7 @@ func TestWorkspace_RemoveMember(t *testing.T) { err := db.User.Save(ctx, p) assert.NoError(t, err) } - workspaceUC := NewWorkspace(db) + workspaceUC := NewWorkspace(db, nil) got, err := workspaceUC.RemoveUserMember(ctx, tc.args.wId, tc.args.uId, tc.args.operator) if tc.wantErr != nil { @@ -903,8 +942,8 @@ func TestWorkspace_UpdateMember(t *testing.T) { op := &accountusecase.Operator{ User: &userID, - ReadableWorkspaces: []accountdomain.WorkspaceID{id1, id2}, - OwningWorkspaces: []accountdomain.WorkspaceID{id1, id2, id3}, + ReadableWorkspaces: []workspace.ID{id1, id2}, + OwningWorkspaces: []workspace.ID{id1, id2, id3}, } tests := []struct { @@ -912,8 +951,8 @@ func TestWorkspace_UpdateMember(t *testing.T) { seeds []*workspace.Workspace usersSeeds []*user.User args struct { - wId accountdomain.WorkspaceID - uId accountdomain.UserID + wId workspace.ID + uId user.ID role workspace.Role operator *accountusecase.Operator } @@ -926,8 +965,8 @@ func TestWorkspace_UpdateMember(t *testing.T) { seeds: []*workspace.Workspace{w1}, usersSeeds: []*user.User{u}, args: struct { - wId accountdomain.WorkspaceID - uId accountdomain.UserID + wId workspace.ID + uId user.ID role workspace.Role operator *accountusecase.Operator }{ @@ -944,8 +983,8 @@ func TestWorkspace_UpdateMember(t *testing.T) { seeds: []*workspace.Workspace{w2}, usersSeeds: []*user.User{u}, args: struct { - wId accountdomain.WorkspaceID - uId accountdomain.UserID + wId workspace.ID + uId user.ID role workspace.Role operator *accountusecase.Operator }{ @@ -962,8 +1001,8 @@ func TestWorkspace_UpdateMember(t *testing.T) { seeds: []*workspace.Workspace{w3}, usersSeeds: []*user.User{u}, args: struct { - wId accountdomain.WorkspaceID - uId accountdomain.UserID + wId workspace.ID + uId user.ID role workspace.Role operator *accountusecase.Operator }{ @@ -978,8 +1017,8 @@ func TestWorkspace_UpdateMember(t *testing.T) { { name: "mock error", args: struct { - wId accountdomain.WorkspaceID - uId accountdomain.UserID + wId workspace.ID + uId user.ID role workspace.Role operator *accountusecase.Operator }{ @@ -1009,7 +1048,7 @@ func TestWorkspace_UpdateMember(t *testing.T) { err := db.User.Save(ctx, p) assert.NoError(t, err) } - workspaceUC := NewWorkspace(db) + workspaceUC := NewWorkspace(db, nil) got, err := workspaceUC.UpdateUserMember(ctx, tc.args.wId, tc.args.uId, tc.args.role, tc.args.operator) if tc.wantErr != nil { diff --git a/account/accountusecase/accountrepo/user.go b/account/accountusecase/accountrepo/user.go index edb969e..3e54750 100644 --- a/account/accountusecase/accountrepo/user.go +++ b/account/accountusecase/accountrepo/user.go @@ -3,7 +3,6 @@ package accountrepo import ( "context" - "github.com/reearth/reearthx/account/accountdomain" "github.com/reearth/reearthx/account/accountdomain/user" "github.com/reearth/reearthx/i18n" "github.com/reearth/reearthx/rerror" @@ -12,8 +11,8 @@ import ( var ErrDuplicatedUser = rerror.NewE(i18n.T("duplicated user")) type User interface { - FindByIDs(context.Context, accountdomain.UserIDList) ([]*user.User, error) - FindByID(context.Context, accountdomain.UserID) (*user.User, error) + FindByID(context.Context, user.ID) (*user.User, error) + FindByIDs(context.Context, user.IDList) (user.List, error) FindBySub(context.Context, string) (*user.User, error) FindByEmail(context.Context, string) (*user.User, error) FindByName(context.Context, string) (*user.User, error) @@ -23,5 +22,5 @@ type User interface { FindBySubOrCreate(context.Context, *user.User, string) (*user.User, error) Create(context.Context, *user.User) error Save(context.Context, *user.User) error - Remove(context.Context, accountdomain.UserID) error + Remove(context.Context, user.ID) error } diff --git a/account/accountusecase/accountrepo/workspace.go b/account/accountusecase/accountrepo/workspace.go index c722859..153e592 100644 --- a/account/accountusecase/accountrepo/workspace.go +++ b/account/accountusecase/accountrepo/workspace.go @@ -3,17 +3,17 @@ package accountrepo import ( "context" - "github.com/reearth/reearthx/account/accountdomain" + "github.com/reearth/reearthx/account/accountdomain/user" "github.com/reearth/reearthx/account/accountdomain/workspace" ) type Workspace interface { - FindByID(context.Context, accountdomain.WorkspaceID) (*workspace.Workspace, error) - FindByIDs(context.Context, accountdomain.WorkspaceIDList) (workspace.WorkspaceList, error) - FindByUser(context.Context, accountdomain.UserID) (workspace.WorkspaceList, error) - FindByIntegration(context.Context, accountdomain.IntegrationID) (workspace.WorkspaceList, error) + FindByID(context.Context, workspace.ID) (*workspace.Workspace, error) + FindByIDs(context.Context, workspace.IDList) (workspace.List, error) + FindByUser(context.Context, user.ID) (workspace.List, error) + FindByIntegration(context.Context, workspace.IntegrationID) (workspace.List, error) Save(context.Context, *workspace.Workspace) error - SaveAll(context.Context, []*workspace.Workspace) error - Remove(context.Context, accountdomain.WorkspaceID) error - RemoveAll(context.Context, accountdomain.WorkspaceIDList) error + SaveAll(context.Context, workspace.List) error + Remove(context.Context, workspace.ID) error + RemoveAll(context.Context, workspace.IDList) error }