Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move WeChat official accounts to a single space #18

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions example-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,12 @@ bridge:
listen_address: "0.0.0.0:20002"
listen_secret: foobar
# Should the bridge create a space for each logged-in user and add bridged rooms to it?
# Users who logged in before turning this on should run `!wa sync space` to create and fill the space for the first time.
personal_filtering_spaces: false
# Users who logged in before turning this on should run `sync space` to create and fill the space for the first time.
personal_filtering_spaces: true
# Should the bridge create a single separate space for all official accounts?
# If disabled, PMs and official accounts will be added to the same space.
# Only works if personal_filtering_spaces is enabled.
space_for_official_accounts: true
# Whether the bridge should send the message status as a custom com.beeper.message_send_status event.
message_status_events: false
# Whether the bridge should send error notices via m.notice events when a message fails to bridge.
Expand Down
66 changes: 55 additions & 11 deletions internal/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,18 +478,62 @@ func fnSync(ce *WrappedCommandEvent) {
ce.Reply("Personal filtering spaces are not enabled on this instance of the bridge")
return
}
keys := ce.Bridge.DB.Portal.FindPrivateChatsNotInSpace(ce.User.UID)
count := 0
for _, key := range keys {
portal := ce.Bridge.GetPortalByUID(key)
portal.addToSpace(ce.User)
count++
}
plural := "s"
if count == 1 {
plural = ""
if !ce.Bridge.Config.Bridge.SpaceForOfficialAccounts {
count := 0
chatsToAdd := ce.Bridge.DB.Portal.FindAllChatsNotInSpace(ce.User.UID)
for _, key := range chatsToAdd {
portal := ce.Bridge.GetPortalByUID(key)
portal.addToSpace(ce.User)
count++
}
plural := "s"
if count == 1 {
plural = ""
}
println("[DEBUG] Added", plural, "room"+plural+" to space")
ce.Reply("Added %d room%s to space", count, plural)
} else {
privateChatsToAdd := ce.Bridge.DB.Portal.FindPrivateChatsNotInSpace(ce.User.UID)
officialAccountsToAdd := ce.Bridge.DB.Portal.FindOfficialAccountsNotInOASpace(ce.User.UID)
officialAccountsToRemove := ce.Bridge.DB.Portal.FindOfficialAccountsInDefaultSpace(ce.User.UID)
privateAdded := 0
officialAdded := 0
officialRemoved := 0
for _, key := range privateChatsToAdd {
portal := ce.Bridge.GetPortalByUID(key)
portal.addToSpace(ce.User)
privateAdded++
}
for _, key := range officialAccountsToAdd {
portal := ce.Bridge.GetPortalByUID(key)
portal.addToOfficialAccountSpace(ce.User)
officialAdded++
}
for _, key := range officialAccountsToRemove {
portal := ce.Bridge.GetPortalByUID(key)
portal.removeFromSpace(ce.User)
officialRemoved++
}
privateAddedPlural := "s"
officialAddedPlural := "s"
if privateAdded == 1 {
privateAddedPlural = ""
}
if officialAdded == 1 {
officialAddedPlural = ""
}
officialRemovedMessage := ""
if officialRemoved > 0 {
plural := "s"
if officialRemoved == 1 {
plural = ""
}
officialRemovedMessage = ", and removed " + strconv.Itoa(officialRemoved) + " official account" + plural + " from space"
}
println("[DEBUG] Added", privateAdded, "DM room"+privateAddedPlural+" to space"+officialRemovedMessage)
ce.Reply("Added %d DM room%s and %d official account%s to space"+officialRemovedMessage,
privateAdded, privateAddedPlural, officialAdded, officialAddedPlural)
}
ce.Reply("Added %d DM room%s to space", count, plural)
}
if groups {
err := ce.User.ResyncGroups(createPortals)
Expand Down
3 changes: 2 additions & 1 deletion internal/config/bridgeconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ type BridgeConfig struct {
ListenAddress string `yaml:"listen_address"`
ListenSecret string `yaml:"listen_secret"`

PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"`
PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"`
SpaceForOfficialAccounts bool `yaml:"space_for_official_accounts"`

MessageStatusEvents bool `yaml:"message_status_events"`
MessageErrorNotices bool `yaml:"message_error_notices"`
Expand Down
1 change: 1 addition & 0 deletions internal/config/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ func DoUpgrade(helper *up.Helper) {
helper.Copy(up.Str, "bridge", "listen_address")
helper.Copy(up.Str, "bridge", "listen_secret")
helper.Copy(up.Bool, "bridge", "personal_filtering_spaces")
helper.Copy(up.Bool, "bridge", "space_for_official_accounts")
helper.Copy(up.Bool, "bridge", "message_status_events")
helper.Copy(up.Bool, "bridge", "message_error_notices")
helper.Copy(up.Int, "bridge", "portal_message_buffer")
Expand Down
110 changes: 107 additions & 3 deletions internal/database/portalquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,117 @@ func (pq *PortalQuery) FindPrivateChats(receiver types.UID) []*Portal {
return pq.getAll(query, args...)
}

func (pq *PortalQuery) FindAllChatsNotInSpace(receiver types.UID) []PortalKey {
keys := []PortalKey{}

query := `
SELECT portal.uid FROM portal
LEFT JOIN user_portal
ON portal.uid = user_portal.portal_uid
AND portal.receiver = user_portal.portal_receiver
WHERE portal.mxid != ''
AND portal.receiver = $1
AND (user_portal.in_space = false OR user_portal.in_space IS NULL)
`
args := []interface{}{receiver}

rows, err := pq.db.Query(query, args...)
if err != nil || rows == nil {
return keys
}

defer rows.Close()
for rows.Next() {
var key PortalKey
key.Receiver = receiver
err = rows.Scan(&key.UID)
if err == nil {
keys = append(keys, key)
}
}

return keys
}

func (pq *PortalQuery) FindPrivateChatsNotInSpace(receiver types.UID) []PortalKey {
keys := []PortalKey{}

query := `
SELECT uid FROM portal
LEFT JOIN user_portal ON portal.uid=user_portal.portal_uid AND portal.receiver=user_portal.portal_receiver
WHERE mxid<>'' AND receiver=$1 AND (in_space=false OR in_space IS NULL)
SELECT portal.uid FROM portal
LEFT JOIN user_portal
ON portal.uid = user_portal.portal_uid
AND portal.receiver = user_portal.portal_receiver
WHERE portal.mxid != ''
AND portal.uid NOT LIKE 'gh_%'
AND portal.receiver = $1
AND (user_portal.in_space = false OR user_portal.in_space IS NULL)
`
args := []interface{}{receiver}

rows, err := pq.db.Query(query, args...)
if err != nil || rows == nil {
return keys
}

defer rows.Close()
for rows.Next() {
var key PortalKey
key.Receiver = receiver
err = rows.Scan(&key.UID)
if err == nil {
keys = append(keys, key)
}
}

return keys
}

func (pq *PortalQuery) FindOfficialAccountsInDefaultSpace(receiver types.UID) []PortalKey {
keys := []PortalKey{}

query := `
SELECT portal.uid FROM portal
LEFT JOIN user_portal
ON portal.uid = user_portal.portal_uid
AND portal.receiver = user_portal.portal_receiver
WHERE portal.mxid != ''
AND portal.uid LIKE 'gh_%'
AND portal.receiver = $1
AND user_portal.in_space = true
`
args := []interface{}{receiver}

rows, err := pq.db.Query(query, args...)
if err != nil || rows == nil {
return keys
}

defer rows.Close()
for rows.Next() {
var key PortalKey
key.Receiver = receiver
err = rows.Scan(&key.UID)
if err == nil {
keys = append(keys, key)
}
}

return keys
}

func (pq *PortalQuery) FindOfficialAccountsNotInOASpace(receiver types.UID) []PortalKey {
keys := []PortalKey{}

query := `
SELECT portal.uid FROM portal
LEFT JOIN user_portal
ON portal.uid = user_portal.portal_uid
AND portal.receiver = user_portal.portal_receiver
WHERE portal.mxid != ''
AND portal.uid LIKE 'gh_%'
AND portal.receiver = $1
AND (user_portal.in_official_account_space = false
OR user_portal.in_official_account_space IS NULL)
`
args := []interface{}{receiver}

Expand Down
6 changes: 6 additions & 0 deletions internal/database/upgrades/01-add-official-account-space.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- v2: Add official account space room
ALTER TABLE "user" ADD COLUMN official_account_space_room TEXT;
UPDATE "user" SET official_account_space_room = '' WHERE official_account_space_room IS NULL;

ALTER TABLE "user_portal" ADD COLUMN in_official_account_space BOOLEAN NOT NULL DEFAULT false;
UPDATE "user_portal" SET in_official_account_space = false WHERE in_official_account_space IS NULL;
23 changes: 12 additions & 11 deletions internal/database/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ type User struct {
db *Database
log log.Logger

MXID id.UserID
UID types.UID
ManagementRoom id.RoomID
SpaceRoom id.RoomID
MXID id.UserID
UID types.UID
ManagementRoom id.RoomID
SpaceRoom id.RoomID
OfficialAccountSpaceRoom id.RoomID

lastReadCache map[PortalKey]time.Time
lastReadCacheLock sync.Mutex
Expand All @@ -30,7 +31,7 @@ type User struct {

func (u *User) Scan(row dbutil.Scannable) *User {
var uin sql.NullString
err := row.Scan(&u.MXID, &uin, &u.ManagementRoom, &u.SpaceRoom)
err := row.Scan(&u.MXID, &uin, &u.ManagementRoom, &u.SpaceRoom, &u.OfficialAccountSpaceRoom)
if err != nil {
if err != sql.ErrNoRows {
u.log.Errorln("Database scan failed:", err)
Expand All @@ -47,11 +48,11 @@ func (u *User) Scan(row dbutil.Scannable) *User {

func (u *User) Insert() {
query := `
INSERT INTO "user" (mxid, uin, management_room, space_room)
VALUES ($1, $2, $3, $4)
INSERT INTO "user" (mxid, uin, management_room, space_room, official_account_space_room)
VALUES ($1, $2, $3, $4, $5)
`
args := []interface{}{
u.MXID, u.UID.Uin, u.ManagementRoom, u.SpaceRoom,
u.MXID, u.UID.Uin, u.ManagementRoom, u.SpaceRoom, u.OfficialAccountSpaceRoom,
}

_, err := u.db.Exec(query, args...)
Expand All @@ -63,11 +64,11 @@ func (u *User) Insert() {
func (u *User) Update() {
query := `
UPDATE "user"
SET uin=$1, management_room=$2, space_room=$3
WHERE mxid=$4
SET uin=$1, management_room=$2, space_room=$3, official_account_space_room=$4
WHERE mxid=$5
`
args := []interface{}{
u.UID.Uin, u.ManagementRoom, u.SpaceRoom, u.MXID,
u.UID.Uin, u.ManagementRoom, u.SpaceRoom, u.OfficialAccountSpaceRoom, u.MXID,
}
_, err := u.db.Exec(query, args...)
if err != nil {
Expand Down
48 changes: 48 additions & 0 deletions internal/database/userportal.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,51 @@ func (u *User) MarkInSpace(portal PortalKey) {
u.inSpaceCache[portal] = true
}
}

func (u *User) MarkNotInSpace(portal PortalKey) {
u.inSpaceCacheLock.Lock()
defer u.inSpaceCacheLock.Unlock()

query := `
INSERT INTO user_portal
(user_mxid, portal_uid, portal_receiver, in_space)
VALUES ($1, $2, $3, true)
ON CONFLICT (user_mxid, portal_uid, portal_receiver)
DO UPDATE SET
in_space=false
`
args := []interface{}{
u.MXID, portal.UID, portal.Receiver,
}

_, err := u.db.Exec(query, args...)
if err != nil {
u.log.Warnfln("Failed to update in space status: %v", err)
} else {
u.inSpaceCache[portal] = true
}
}

func (u *User) MarkInOfficialAccountSpace(portal PortalKey) {
u.inSpaceCacheLock.Lock()
defer u.inSpaceCacheLock.Unlock()

query := `
INSERT INTO user_portal
(user_mxid, portal_uid, portal_receiver, in_space)
VALUES ($1, $2, $3, true)
ON CONFLICT (user_mxid, portal_uid, portal_receiver)
DO UPDATE SET
in_official_account_space=true
`
args := []interface{}{
u.MXID, portal.UID, portal.Receiver,
}

_, err := u.db.Exec(query, args...)
if err != nil {
u.log.Warnfln("Failed to update in space status: %v", err)
} else {
u.inSpaceCache[portal] = true
}
}
2 changes: 1 addition & 1 deletion internal/database/userquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
log "maunium.net/go/maulogger/v2"
)

const userColumns = "mxid, uin, management_room, space_room"
const userColumns = "mxid, uin, management_room, space_room, official_account_space_room"

type UserQuery struct {
db *Database
Expand Down
Loading