Skip to content

Commit

Permalink
Add POST /api/v1/admin/merge_users (#3475)
Browse files Browse the repository at this point in the history
First backend support for #3246 : Adds a new `/api/v1/admin/merge_users`
endpoint to merge one or more source user accounts into an existing
destination user account. The source user accounts are deleted in the
process and all objects that belonged to the source users are
transferred to the destination user. Rows belonging to the source
accounts in the `user_stats` and `troll_stats` tables are also deleted.
  • Loading branch information
andreaskoepf authored Jun 14, 2023
1 parent aa272cc commit 7885202
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 4 deletions.
28 changes: 25 additions & 3 deletions backend/oasst_backend/api/v1/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from oasst_backend.api import deps
from oasst_backend.config import Settings, settings
from oasst_backend.models import ApiClient, User
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.prompt_repository import PromptRepository, UserRepository
from oasst_backend.tree_manager import TreeManager
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
from oasst_shared import utils
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas.protocol import PageResult, SystemStats
from oasst_shared.utils import ScopeTimer, unaware_to_utc
from oasst_shared.utils import ScopeTimer, log_timing, unaware_to_utc
from starlette.status import HTTP_204_NO_CONTENT

router = APIRouter()

Expand Down Expand Up @@ -263,7 +264,7 @@ async def get_flagged_messages(
return resp


@router.post("/admin/flagged_messages/{message_id}/processed", response_model=FlaggedMessageResponse)
@router.post("/flagged_messages/{message_id}/processed", response_model=FlaggedMessageResponse)
async def process_flagged_messages(
message_id: UUID,
session: deps.Session = Depends(deps.get_db),
Expand All @@ -275,3 +276,24 @@ async def process_flagged_messages(
flagged_msg = pr.process_flagged_message(message_id=message_id)
resp = FlaggedMessageResponse(**flagged_msg.__dict__)
return resp


class MergeUsersRequest(pydantic.BaseModel):
destination_user_id: UUID
source_user_ids: list[UUID]


@log_timing(level="INFO")
@router.post("/merge_users", response_model=None, status_code=HTTP_204_NO_CONTENT)
def merge_users(
request: MergeUsersRequest,
api_client: ApiClient = Depends(deps.get_trusted_api_client),
) -> None:
@managed_tx_function(CommitMode.COMMIT)
def merge_users_tx(session: deps.Session):
ur = UserRepository(session, api_client)
ur.merge_users(destination_user_id=request.destination_user_id, source_user_ids=request.source_user_ids)

merge_users_tx()

logger.info(f"Merged users: {request=}")
36 changes: 35 additions & 1 deletion backend/oasst_backend/user_repository.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional
from uuid import UUID

import oasst_backend.models as models
from oasst_backend.config import settings
from oasst_backend.models import ApiClient, User
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
Expand All @@ -9,7 +10,7 @@
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.utils import utcnow
from sqlalchemy.exc import IntegrityError
from sqlmodel import Session, and_, or_
from sqlmodel import Session, and_, delete, or_, update
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND


Expand Down Expand Up @@ -346,3 +347,36 @@ def update_user_last_activity(self, user: User, update_streak: bool = False) ->
user.streak_days = (current_time - user.streak_last_day_date).days

self.db.add(user)

@managed_tx_method(CommitMode.FLUSH)
def merge_users(self, destination_user_id: UUID, source_user_ids: list[UUID]) -> None:
source_user_ids = list(filter(lambda x: x != destination_user_id, source_user_ids))
if not source_user_ids:
return

# ensure the destination user exists
self.get_user(id=destination_user_id)

# update rows in tables that have affected users_ids as FK
models_to_update = [
models.Message,
models.MessageRevision,
models.MessageReaction,
models.MessageEmoji,
models.TextLabels,
models.Task,
models.Journal,
]
for table in models_to_update:
qry = update(table).where(table.user_id.in_(source_user_ids)).values(user_id=destination_user_id)
self.db.execute(qry)

# delete rows in user stats tables
models_to_delete = [models.UserStats, models.TrollStats]
for table in models_to_delete:
qry = delete(table).where(table.user_id.in_(source_user_ids))
self.db.execute(qry)

# finally delete source users from main user table
qry = delete(User).where(User.id.in_(source_user_ids))
self.db.execute(qry)

0 comments on commit 7885202

Please sign in to comment.