diff --git a/.github/workflows/ruff-check.yml b/.github/workflows/ruff-check.yml new file mode 100644 index 000000000..151002a85 --- /dev/null +++ b/.github/workflows/ruff-check.yml @@ -0,0 +1,28 @@ +name: Code Format Checker +on: + push: + branches: + - '*' + pull_request: + branches: + - '*' + +jobs: + ruff: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.12' + + - name: Install Ruff + run: | + python -m pip install --upgrade pip + pip install ruff + - name: Run Ruff Format + run: | + ruff check --output-format=github . diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d18fd3257..d5fab1ae9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -38,7 +38,8 @@ Backend is built using FastAPI and uses SQLAlchemy as the ORM for database opera ### Python Code Formatting To maintain consistency in the codebase, we require all code to be formatted using ```bash -autopep8 --max-line-length 120 +ruff check . +ruff format . ``` ## Frontend diff --git a/Marzban.code-workspace b/Marzban.code-workspace index 1fde2265c..1dda574fa 100644 --- a/Marzban.code-workspace +++ b/Marzban.code-workspace @@ -10,14 +10,13 @@ "python.analysis.inlayHints.pytestParameters": true, "python.analysis.inlayHints.callArgumentNames": "all", "python.analysis.inlayHints.functionReturnTypes": true, + "[python]": { - "editor.defaultFormatter": "ms-python.autopep8", - "editor.formatOnSave": true, + "editor.defaultFormatter": "charliermarsh.ruff", + "editor.formatOnSave": true, }, - "autopep8.args": [ - "--max-line-length", - "120" - ], + "ruff.configuration": "pyproject.toml", + "editor.codeActionsOnSave": { "source.organizeImports": "never" }, @@ -85,7 +84,8 @@ "ms-vscode.vscode-typescript-next", "yoavbls.pretty-ts-errors", "dbaeumer.vscode-eslint", - "esbenp.prettier-vscode" + "esbenp.prettier-vscode", + "charliermarsh.ruff" ] }, } \ No newline at end of file diff --git a/app/__init__.py b/app/__init__.py index 98d310371..ba0e30471 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -20,9 +20,7 @@ redoc_url="/redoc" if DOCS else None, ) -scheduler = BackgroundScheduler( - {"apscheduler.job_defaults.max_instances": 20}, timezone="UTC" -) +scheduler = BackgroundScheduler({"apscheduler.job_defaults.max_instances": 20}, timezone="UTC") logger = logging.getLogger("uvicorn.error") app.add_middleware( @@ -52,9 +50,7 @@ def on_startup(): paths = [f"{r.path}/" for r in app.routes] paths.append("/api/") if f"/{XRAY_SUBSCRIPTION_PATH}/" in paths: - raise ValueError( - f"you can't use /{XRAY_SUBSCRIPTION_PATH}/ as subscription path it reserved for {app.title}" - ) + raise ValueError(f"you can't use /{XRAY_SUBSCRIPTION_PATH}/ as subscription path it reserved for {app.title}") scheduler.start() diff --git a/app/dashboard/__init__.py b/app/dashboard/__init__.py index eeb789b80..da483b085 100644 --- a/app/dashboard/__init__.py +++ b/app/dashboard/__init__.py @@ -7,37 +7,38 @@ from fastapi.staticfiles import StaticFiles base_dir = Path(__file__).parent -build_dir = base_dir / 'build' -statics_dir = build_dir / 'statics' +build_dir = base_dir / "build" +statics_dir = build_dir / "statics" def build_api_interface(): - proc = subprocess.Popen( - ['pnpm', 'run', 'wait-port-gen-api'], - env={**os.environ, 'UVICORN_PORT': str(UVICORN_PORT)}, + subprocess.Popen( + ["pnpm", "run", "wait-port-gen-api"], + env={**os.environ, "UVICORN_PORT": str(UVICORN_PORT)}, cwd=base_dir, - stdout=subprocess.DEVNULL + stdout=subprocess.DEVNULL, ) + def build(): proc = subprocess.Popen( - ['pnpm', 'run', 'build', '--outDir', build_dir, '--assetsDir', 'statics'], - env={**os.environ, 'VITE_BASE_API': VITE_BASE_API}, - cwd=base_dir + ["pnpm", "run", "build", "--outDir", build_dir, "--assetsDir", "statics"], + env={**os.environ, "VITE_BASE_API": VITE_BASE_API}, + cwd=base_dir, ) proc.wait() - with open(build_dir / 'index.html', 'r') as file: + with open(build_dir / "index.html", "r") as file: html = file.read() - with open(build_dir / '404.html', 'w') as file: + with open(build_dir / "404.html", "w") as file: file.write(html) def run_dev(): build_api_interface() proc = subprocess.Popen( - ['pnpm', 'run', 'dev', '--base', os.path.join(DASHBOARD_PATH, '')], - env={**os.environ, 'VITE_BASE_API': VITE_BASE_API, 'DEBUG': 'false'}, - cwd=base_dir + ["pnpm", "run", "dev", "--base", os.path.join(DASHBOARD_PATH, "")], + env={**os.environ, "VITE_BASE_API": VITE_BASE_API, "DEBUG": "false"}, + cwd=base_dir, ) atexit.register(proc.terminate) @@ -47,16 +48,8 @@ def run_build(): if not build_dir.is_dir(): build() - app.mount( - DASHBOARD_PATH, - StaticFiles(directory=build_dir, html=True), - name="dashboard" - ) - app.mount( - '/statics/', - StaticFiles(directory=statics_dir, html=True), - name="statics" - ) + app.mount(DASHBOARD_PATH, StaticFiles(directory=build_dir, html=True), name="dashboard") + app.mount("/statics/", StaticFiles(directory=statics_dir, html=True), name="statics") @app.on_event("startup") diff --git a/app/db/__init__.py b/app/db/__init__.py index 7dfcd2f9f..57ce0cc20 100644 --- a/app/db/__init__.py +++ b/app/db/__init__.py @@ -1,37 +1,37 @@ -from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session -from .base import Base, SessionLocal, engine # noqa - - -class GetDB: # Context Manager - def __init__(self): - self.db = SessionLocal() - - def __enter__(self): - return self.db - - def __exit__(self, exc_type, exc_value, traceback): - if isinstance(exc_value, SQLAlchemyError): - self.db.rollback() # rollback on exception - - self.db.close() - - -def get_db(): # Dependency - with GetDB() as db: - yield db - - -from .crud import (create_admin, create_notification_reminder, # noqa - create_user, delete_notification_reminder, get_admin, - get_admins, get_jwt_secret_key, get_notification_reminder, - get_or_create_inbound, get_system_usage, - get_tls_certificate, get_user, get_user_by_id, get_users, - get_users_count, remove_admin, remove_user, revoke_user_sub, - set_owner, update_admin, update_user, update_user_status, reset_user_by_next, - update_user_sub, start_user_expire, get_admin_by_id, - get_admin_by_telegram_id) +from .base import Base, GetDB, get_db # noqa + + +from .crud import ( + create_admin, + create_notification_reminder, # noqa + create_user, + delete_notification_reminder, + get_admin, + get_admins, + get_jwt_secret_key, + get_notification_reminder, + get_or_create_inbound, + get_system_usage, + get_tls_certificate, + get_user, + get_user_by_id, + get_users, + get_users_count, + remove_admin, + remove_user, + revoke_user_sub, + set_owner, + update_admin, + update_user, + update_user_status, + reset_user_by_next, + update_user_sub, + start_user_expire, + get_admin_by_id, + get_admin_by_telegram_id, +) from .models import JWT, System, User # noqa @@ -60,18 +60,14 @@ def get_db(): # Dependency "get_admins", "get_admin_by_id", "get_admin_by_telegram_id", - "create_notification_reminder", "get_notification_reminder", "delete_notification_reminder", - "GetDB", "get_db", - "User", "System", "JWT", - "Base", "Session", ] diff --git a/app/db/base.py b/app/db/base.py index 411a6ea12..94c7b0a48 100644 --- a/app/db/base.py +++ b/app/db/base.py @@ -1,25 +1,24 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, DeclarativeBase +from sqlalchemy.exc import SQLAlchemyError + from config import ( SQLALCHEMY_DATABASE_URL, SQLALCHEMY_POOL_SIZE, SQLIALCHEMY_MAX_OVERFLOW, ) -IS_SQLITE = SQLALCHEMY_DATABASE_URL.startswith('sqlite') +IS_SQLITE = SQLALCHEMY_DATABASE_URL.startswith("sqlite") if IS_SQLITE: - engine = create_engine( - SQLALCHEMY_DATABASE_URL, - connect_args={"check_same_thread": False} - ) + engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}) else: engine = create_engine( SQLALCHEMY_DATABASE_URL, pool_size=SQLALCHEMY_POOL_SIZE, max_overflow=SQLIALCHEMY_MAX_OVERFLOW, pool_recycle=3600, - pool_timeout=10 + pool_timeout=10, ) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) @@ -27,3 +26,22 @@ class Base(DeclarativeBase): pass + + +class GetDB: # Context Manager + def __init__(self): + self.db = SessionLocal() + + def __enter__(self): + return self.db + + def __exit__(self, exc_type, exc_value, traceback): + if isinstance(exc_value, SQLAlchemyError): + self.db.rollback() # rollback on exception + + self.db.close() + + +def get_db(): # Dependency + with GetDB() as db: + yield db diff --git a/app/db/crud.py b/app/db/crud.py index 25268b51a..3d672b899 100644 --- a/app/db/crud.py +++ b/app/db/crud.py @@ -54,8 +54,7 @@ def add_default_host(db: Session, inbound: ProxyInbound): db (Session): Database session. inbound (ProxyInbound): Proxy inbound to add the default host to. """ - host = ProxyHost( - remark="🚀 Marz ({USERNAME}) [{PROTOCOL} - {TRANSPORT}]", address="{SERVER_IP}", inbound=inbound) + host = ProxyHost(remark="🚀 Marz ({USERNAME}) [{PROTOCOL} - {TRANSPORT}]", address="{SERVER_IP}", inbound=inbound) db.add(host) db.commit() @@ -71,8 +70,7 @@ def get_or_create_inbound(db: Session, inbound_tag: str) -> ProxyInbound: Returns: ProxyInbound: The retrieved or newly created proxy inbound. """ - inbound = db.query(ProxyInbound).filter( - ProxyInbound.tag == inbound_tag).first() + inbound = db.query(ProxyInbound).filter(ProxyInbound.tag == inbound_tag).first() if not inbound: inbound = ProxyInbound(tag=inbound_tag) db.add(inbound) @@ -83,9 +81,9 @@ def get_or_create_inbound(db: Session, inbound_tag: str) -> ProxyInbound: def get_hosts( - db: Session, - offset: Optional[int] = 0, - limit: Optional[int] = 0, + db: Session, + offset: Optional[int] = 0, + limit: Optional[int] = 0, ) -> List[ProxyHost]: """ Retrieves hosts. @@ -241,31 +239,36 @@ def get_user_by_id(db: Session, user_id: int) -> Optional[User]: return get_user_queryset(db).filter(User.id == user_id).first() -UsersSortingOptions = Enum('UsersSortingOptions', { - 'username': User.username.asc(), - 'used_traffic': User.used_traffic.asc(), - 'data_limit': User.data_limit.asc(), - 'expire': User.expire.asc(), - 'created_at': User.created_at.asc(), - '-username': User.username.desc(), - '-used_traffic': User.used_traffic.desc(), - '-data_limit': User.data_limit.desc(), - '-expire': User.expire.desc(), - '-created_at': User.created_at.desc(), -}) - - -def get_users(db: Session, - offset: Optional[int] = None, - limit: Optional[int] = None, - usernames: Optional[List[str]] = None, - search: Optional[str] = None, - status: Optional[Union[UserStatus, list]] = None, - sort: Optional[List[UsersSortingOptions]] = None, - admin: Optional[Admin] = None, - admins: Optional[List[str]] = None, - reset_strategy: Optional[Union[UserDataLimitResetStrategy, list]] = None, - return_with_count: bool = False) -> Union[List[User], Tuple[List[User], int]]: +UsersSortingOptions = Enum( + "UsersSortingOptions", + { + "username": User.username.asc(), + "used_traffic": User.used_traffic.asc(), + "data_limit": User.data_limit.asc(), + "expire": User.expire.asc(), + "created_at": User.created_at.asc(), + "-username": User.username.desc(), + "-used_traffic": User.used_traffic.desc(), + "-data_limit": User.data_limit.desc(), + "-expire": User.expire.desc(), + "-created_at": User.created_at.desc(), + }, +) + + +def get_users( + db: Session, + offset: Optional[int] = None, + limit: Optional[int] = None, + usernames: Optional[List[str]] = None, + search: Optional[str] = None, + status: Optional[Union[UserStatus, list]] = None, + sort: Optional[List[UsersSortingOptions]] = None, + admin: Optional[Admin] = None, + admins: Optional[List[str]] = None, + reset_strategy: Optional[Union[UserDataLimitResetStrategy, list]] = None, + return_with_count: bool = False, +) -> Union[List[User], Tuple[List[User], int]]: """ Retrieves users based on various filters and options. @@ -288,8 +291,7 @@ def get_users(db: Session, query = get_user_queryset(db) if search: - query = query.filter(or_(User.username.ilike( - f"%{search}%"), User.note.ilike(f"%{search}%"))) + query = query.filter(or_(User.username.ilike(f"%{search}%"), User.note.ilike(f"%{search}%"))) if usernames: query = query.filter(User.username.in_(usernames)) @@ -302,11 +304,9 @@ def get_users(db: Session, if reset_strategy: if isinstance(reset_strategy, list): - query = query.filter( - User.data_limit_reset_strategy.in_(reset_strategy)) + query = query.filter(User.data_limit_reset_strategy.in_(reset_strategy)) else: - query = query.filter( - User.data_limit_reset_strategy == reset_strategy) + query = query.filter(User.data_limit_reset_strategy == reset_strategy) if admin: query = query.filter(User.admin == admin) @@ -345,22 +345,16 @@ def get_user_usages(db: Session, dbuser: User, start: datetime, end: datetime) - List[UserUsageResponse]: List of user usage responses. """ - usages = {0: UserUsageResponse( # Main Core - node_id=None, - node_name="Master", - used_traffic=0 - )} + usages = { + 0: UserUsageResponse( # Main Core + node_id=None, node_name="Master", used_traffic=0 + ) + } for node in db.query(Node).all(): - usages[node.id] = UserUsageResponse( - node_id=node.id, - node_name=node.name, - used_traffic=0 - ) + usages[node.id] = UserUsageResponse(node_id=node.id, node_name=node.name, used_traffic=0) - cond = and_(NodeUserUsage.user_id == dbuser.id, - NodeUserUsage.created_at >= start, - NodeUserUsage.created_at <= end) + cond = and_(NodeUserUsage.user_id == dbuser.id, NodeUserUsage.created_at >= start, NodeUserUsage.created_at <= end) for v in db.query(NodeUserUsage).filter(cond): try: @@ -406,13 +400,9 @@ def create_user(db: Session, user: UserCreate, admin: Admin = None) -> User: excluded_inbounds_tags = user.excluded_inbounds proxies = [] for proxy_type, settings in user.proxies.items(): - excluded_inbounds = [ - get_or_create_inbound(db, tag) for tag in excluded_inbounds_tags[proxy_type] - ] + excluded_inbounds = [get_or_create_inbound(db, tag) for tag in excluded_inbounds_tags[proxy_type]] proxies.append( - Proxy(type=proxy_type.value, - settings=settings.dict(no_obj=True), - excluded_inbounds=excluded_inbounds) + Proxy(type=proxy_type.value, settings=settings.dict(no_obj=True), excluded_inbounds=excluded_inbounds) ) dbuser = User( @@ -433,7 +423,9 @@ def create_user(db: Session, user: UserCreate, admin: Admin = None) -> User: expire=user.next_plan.expire, add_remaining_traffic=user.next_plan.add_remaining_traffic, fire_on_either=user.next_plan.fire_on_either, - ) if user.next_plan else None + ) + if user.next_plan + else None, ) db.add(dbuser) db.commit() @@ -486,14 +478,11 @@ def update_user(db: Session, dbuser: User, modify: UserModify) -> User: added_proxies: Dict[ProxyTypes, Proxy] = {} if modify.proxies: for proxy_type, settings in modify.proxies.items(): - dbproxy = db.query(Proxy) \ - .where(Proxy.user == dbuser, Proxy.type == proxy_type) \ - .first() + dbproxy = db.query(Proxy).where(Proxy.user == dbuser, Proxy.type == proxy_type).first() if dbproxy: dbproxy.settings = settings.dict(no_obj=True) else: - new_proxy = Proxy( - type=proxy_type, settings=settings.dict(no_obj=True)) + new_proxy = Proxy(type=proxy_type, settings=settings.dict(no_obj=True)) dbuser.proxies.append(new_proxy) added_proxies.update({proxy_type: new_proxy}) for proxy in dbuser.proxies: @@ -501,28 +490,27 @@ def update_user(db: Session, dbuser: User, modify: UserModify) -> User: db.delete(proxy) if modify.inbounds: for proxy_type, tags in modify.excluded_inbounds.items(): - dbproxy = db.query(Proxy) \ - .where(Proxy.user == dbuser, Proxy.type == proxy_type) \ - .first() or added_proxies.get(proxy_type) + dbproxy = db.query(Proxy).where( + Proxy.user == dbuser, Proxy.type == proxy_type + ).first() or added_proxies.get(proxy_type) if dbproxy: - dbproxy.excluded_inbounds = [ - get_or_create_inbound(db, tag) for tag in tags] + dbproxy.excluded_inbounds = [get_or_create_inbound(db, tag) for tag in tags] if modify.status is not None: dbuser.status = modify.status if modify.data_limit is not None: - dbuser.data_limit = (modify.data_limit or None) + dbuser.data_limit = modify.data_limit or None if dbuser.status not in (UserStatus.expired, UserStatus.disabled): if not dbuser.data_limit or dbuser.used_traffic < dbuser.data_limit: if dbuser.status != UserStatus.on_hold: dbuser.status = UserStatus.active for percent in sorted(NOTIFY_REACHED_USAGE_PERCENT, reverse=True): - if not dbuser.data_limit or (calculate_usage_percent( - dbuser.used_traffic, dbuser.data_limit) < percent): - reminder = get_notification_reminder( - db, dbuser.id, ReminderType.data_usage, threshold=percent) + if not dbuser.data_limit or ( + calculate_usage_percent(dbuser.used_traffic, dbuser.data_limit) < percent + ): + reminder = get_notification_reminder(db, dbuser.id, ReminderType.data_usage, threshold=percent) if reminder: delete_notification_reminder(db, reminder) @@ -540,10 +528,10 @@ def update_user(db: Session, dbuser: User, modify: UserModify) -> User: if not dbuser.expire or dbuser.expire > datetime.utcnow(): dbuser.status = UserStatus.active for days_left in sorted(NOTIFY_DAYS_LEFT): - if not dbuser.expire or (calculate_expiration_days( - dbuser.expire) > days_left): + if not dbuser.expire or (calculate_expiration_days(dbuser.expire) > days_left): reminder = get_notification_reminder( - db, dbuser.id, ReminderType.expiration_date, threshold=days_left) + db, dbuser.id, ReminderType.expiration_date, threshold=days_left + ) if reminder: delete_notification_reminder(db, reminder) else: @@ -638,16 +626,16 @@ def reset_user_by_next(db: Session, dbuser: User) -> User: dbuser.status = UserStatus.active.value if dbuser.next_plan.user_template_id is None: - dbuser.data_limit = dbuser.next_plan.data_limit + \ - (0 if dbuser.next_plan.add_remaining_traffic else dbuser.data_limit or 0 - dbuser.used_traffic) - dbuser.expire = timedelta( - seconds=dbuser.next_plan.expire) + datetime.now(UTC) + dbuser.data_limit = dbuser.next_plan.data_limit + ( + 0 if dbuser.next_plan.add_remaining_traffic else dbuser.data_limit or 0 - dbuser.used_traffic + ) + dbuser.expire = timedelta(seconds=dbuser.next_plan.expire) + datetime.now(UTC) else: dbuser.inbounds = dbuser.next_plan.user_template.inbounds - dbuser.data_limit = dbuser.next_plan.user_template.data_limit + \ - (0 if dbuser.next_plan.add_remaining_traffic else dbuser.data_limit or 0 - dbuser.used_traffic) - dbuser.expire = timedelta( - seconds=dbuser.next_plan.user_template.expire_duration) + datetime.now(UTC) + dbuser.data_limit = dbuser.next_plan.user_template.data_limit + ( + 0 if dbuser.next_plan.add_remaining_traffic else dbuser.data_limit or 0 - dbuser.used_traffic + ) + dbuser.expire = timedelta(seconds=dbuser.next_plan.user_template.expire_duration) + datetime.now(UTC) dbuser.used_traffic = 0 db.delete(dbuser.next_plan) @@ -738,13 +726,13 @@ def disable_all_active_users(db: Session, admin: Optional[Admin] = None): db (Session): Database session. admin (Optional[Admin]): Admin to filter users by, if any. """ - query = db.query(User).filter(User.status.in_( - (UserStatus.active, UserStatus.on_hold))) + query = db.query(User).filter(User.status.in_((UserStatus.active, UserStatus.on_hold))) if admin: query = query.filter(User.admin == admin) - query.update({User.status: UserStatus.disabled, - User.last_status_change: datetime.utcnow()}, synchronize_session=False) + query.update( + {User.status: UserStatus.disabled, User.last_status_change: datetime.utcnow()}, synchronize_session=False + ) db.commit() @@ -757,29 +745,30 @@ def activate_all_disabled_users(db: Session, admin: Optional[Admin] = None): db (Session): Database session. admin (Optional[Admin]): Admin to filter users by, if any. """ - query_for_active_users = db.query(User).filter( - User.status == UserStatus.disabled) + query_for_active_users = db.query(User).filter(User.status == UserStatus.disabled) query_for_on_hold_users = db.query(User).filter( and_( - User.status == UserStatus.disabled, User.expire.is_( - None), User.on_hold_expire_duration.isnot(None), User.online_at.is_(None) - )) + User.status == UserStatus.disabled, + User.expire.is_(None), + User.on_hold_expire_duration.isnot(None), + User.online_at.is_(None), + ) + ) if admin: - query_for_active_users = query_for_active_users.filter( - User.admin == admin) - query_for_on_hold_users = query_for_on_hold_users.filter( - User.admin == admin) + query_for_active_users = query_for_active_users.filter(User.admin == admin) + query_for_on_hold_users = query_for_on_hold_users.filter(User.admin == admin) query_for_on_hold_users.update( - {User.status: UserStatus.on_hold, User.last_status_change: datetime.utcnow()}, synchronize_session=False) + {User.status: UserStatus.on_hold, User.last_status_change: datetime.utcnow()}, synchronize_session=False + ) query_for_active_users.update( - {User.status: UserStatus.active, User.last_status_change: datetime.utcnow()}, synchronize_session=False) + {User.status: UserStatus.active, User.last_status_change: datetime.utcnow()}, synchronize_session=False + ) db.commit() -def autodelete_expired_users(db: Session, - include_limited_users: bool = False) -> List[User]: +def autodelete_expired_users(db: Session, include_limited_users: bool = False) -> List[User]: """ Deletes expired (optionally also limited) users whose auto-delete time has passed. @@ -791,19 +780,21 @@ def autodelete_expired_users(db: Session, Returns: list[User]: List of deleted users. """ - target_status = ( - [UserStatus.expired] if not include_limited_users - else [UserStatus.expired, UserStatus.limited] - ) + target_status = [UserStatus.expired] if not include_limited_users else [UserStatus.expired, UserStatus.limited] auto_delete = coalesce(User.auto_delete_in_days, USERS_AUTODELETE_DAYS) - query = db.query( - User, auto_delete, # Use global auto-delete days as fallback - ).filter( - auto_delete >= 0, # Negative values prevent auto-deletion - User.status.in_(target_status), - ).options(joinedload(User.admin)) + query = ( + db.query( + User, + auto_delete, # Use global auto-delete days as fallback + ) + .filter( + auto_delete >= 0, # Negative values prevent auto-deletion + User.status.in_(target_status), + ) + .options(joinedload(User.admin)) + ) # TODO: Handle time filter in query itself (NOTE: Be careful with sqlite's strange datetime handling) expired_users = [ @@ -818,9 +809,7 @@ def autodelete_expired_users(db: Session, return expired_users -def get_all_users_usages( - db: Session, admin: Admin, start: datetime, end: datetime -) -> List[UserUsageResponse]: +def get_all_users_usages(db: Session, admin: Admin, start: datetime, end: datetime) -> List[UserUsageResponse]: """ Retrieves usage data for all users associated with an admin within a specified time range. @@ -837,25 +826,19 @@ def get_all_users_usages( List[UserUsageResponse]: A list of UserUsageResponse objects, each representing the usage data for a specific node or the main core. """ - usages = {0: UserUsageResponse( # Main Core - node_id=None, - node_name="Master", - used_traffic=0 - )} + usages = { + 0: UserUsageResponse( # Main Core + node_id=None, node_name="Master", used_traffic=0 + ) + } for node in db.query(Node).all(): - usages[node.id] = UserUsageResponse( - node_id=node.id, - node_name=node.name, - used_traffic=0 - ) + usages[node.id] = UserUsageResponse(node_id=node.id, node_name=node.name, used_traffic=0) admin_users = set(user.id for user in get_users(db=db, admins=admin)) cond = and_( - NodeUserUsage.created_at >= start, - NodeUserUsage.created_at <= end, - NodeUserUsage.user_id.in_(admin_users) + NodeUserUsage.created_at >= start, NodeUserUsage.created_at <= end, NodeUserUsage.user_id.in_(admin_users) ) for v in db.query(NodeUserUsage).filter(cond): @@ -915,8 +898,7 @@ def start_user_expire(db: Session, dbuser: User) -> User: Returns: User: The updated user object. """ - dbuser.expire = datetime.now( - timezone.utc) + timedelta(seconds=dbuser.on_hold_expire_duration) + dbuser.expire = datetime.now(timezone.utc) + timedelta(seconds=dbuser.on_hold_expire_duration) dbuser.on_hold_expire_duration = None dbuser.on_hold_timeout = None db.commit() @@ -993,7 +975,7 @@ def create_admin(db: Session, admin: AdminCreate) -> Admin: hashed_password=admin.hashed_password, is_sudo=admin.is_sudo, telegram_id=admin.telegram_id if admin.telegram_id else None, - discord_webhook=admin.discord_webhook if admin.discord_webhook else None + discord_webhook=admin.discord_webhook if admin.discord_webhook else None, ) db.add(dbadmin) db.commit() @@ -1103,10 +1085,9 @@ def get_admin_by_telegram_id(db: Session, telegram_id: int) -> Admin: return db.query(Admin).filter(Admin.telegram_id == telegram_id).first() -def get_admins(db: Session, - offset: Optional[int] = None, - limit: Optional[int] = None, - username: Optional[str] = None) -> List[Admin]: +def get_admins( + db: Session, offset: Optional[int] = None, limit: Optional[int] = None, username: Optional[str] = None +) -> List[Admin]: """ Retrieves a list of admins with optional filters and pagination. @@ -1121,7 +1102,7 @@ def get_admins(db: Session, """ query = db.query(Admin) if username: - query = query.filter(Admin.username.ilike(f'%{username}%')) + query = query.filter(Admin.username.ilike(f"%{username}%")) if offset: query = query.offset(offset) if limit: @@ -1138,13 +1119,10 @@ def reset_admin_usage(db: Session, dbadmin: Admin) -> int: Returns: Admin: The updated admin. """ - if (dbadmin.users_usage == 0): + if dbadmin.users_usage == 0: return dbadmin - usage_log = AdminUsageLogs( - admin=dbadmin, - used_traffic_at_reset=dbadmin.users_usage - ) + usage_log = AdminUsageLogs(admin=dbadmin, used_traffic_at_reset=dbadmin.users_usage) db.add(usage_log) dbadmin.users_usage = 0 @@ -1173,8 +1151,7 @@ def create_user_template(db: Session, user_template: UserTemplateCreate) -> User expire_duration=user_template.expire_duration, username_prefix=user_template.username_prefix, username_suffix=user_template.username_suffix, - inbounds=db.query(ProxyInbound).filter( - ProxyInbound.tag.in_(inbound_tags)).all() + inbounds=db.query(ProxyInbound).filter(ProxyInbound.tag.in_(inbound_tags)).all(), ) db.add(dbuser_template) db.commit() @@ -1183,7 +1160,8 @@ def create_user_template(db: Session, user_template: UserTemplateCreate) -> User def update_user_template( - db: Session, dbuser_template: UserTemplate, modified_user_template: UserTemplateModify) -> UserTemplate: + db: Session, dbuser_template: UserTemplate, modified_user_template: UserTemplateModify +) -> UserTemplate: """ Updates a user template's details. @@ -1210,8 +1188,7 @@ def update_user_template( inbound_tags: List[str] = [] for _, i in modified_user_template.inbounds.items(): inbound_tags.extend(i) - dbuser_template.inbounds = db.query(ProxyInbound).filter( - ProxyInbound.tag.in_(inbound_tags)).all() + dbuser_template.inbounds = db.query(ProxyInbound).filter(ProxyInbound.tag.in_(inbound_tags)).all() db.commit() db.refresh(dbuser_template) @@ -1245,7 +1222,8 @@ def get_user_template(db: Session, user_template_id: int) -> UserTemplate: def get_user_templates( - db: Session, offset: Union[int, None] = None, limit: Union[int, None] = None) -> List[UserTemplate]: + db: Session, offset: Union[int, None] = None, limit: Union[int, None] = None +) -> List[UserTemplate]: """ Retrieves a list of user templates with optional pagination. @@ -1294,9 +1272,7 @@ def get_node_by_id(db: Session, node_id: int) -> Optional[Node]: return db.query(Node).filter(Node.id == node_id).first() -def get_nodes(db: Session, - status: Optional[Union[NodeStatus, list]] = None, - enabled: bool = None) -> List[Node]: +def get_nodes(db: Session, status: Optional[Union[NodeStatus, list]] = None, enabled: bool = None) -> List[Node]: """ Retrieves nodes based on optional status and enabled filters. @@ -1334,20 +1310,14 @@ def get_nodes_usage(db: Session, start: datetime, end: datetime) -> List[NodeUsa Returns: List[NodeUsageResponse]: A list of NodeUsageResponse objects containing usage data. """ - usages = {0: NodeUsageResponse( # Main Core - node_id=None, - node_name="Master", - uplink=0, - downlink=0 - )} + usages = { + 0: NodeUsageResponse( # Main Core + node_id=None, node_name="Master", uplink=0, downlink=0 + ) + } for node in db.query(Node).all(): - usages[node.id] = NodeUsageResponse( - node_id=node.id, - node_name=node.name, - uplink=0, - downlink=0 - ) + usages[node.id] = NodeUsageResponse(node_id=node.id, node_name=node.name, uplink=0, downlink=0) cond = and_(NodeUsage.created_at >= start, NodeUsage.created_at <= end) @@ -1372,10 +1342,7 @@ def create_node(db: Session, node: NodeCreate) -> Node: Returns: Node: The newly created Node object. """ - dbnode = Node(name=node.name, - address=node.address, - port=node.port, - api_port=node.api_port) + dbnode = Node(name=node.name, address=node.address, port=node.port, api_port=node.api_port) db.add(dbnode) db.commit() @@ -1462,7 +1429,8 @@ def update_node_status(db: Session, dbnode: Node, status: NodeStatus, message: s def create_notification_reminder( - db: Session, reminder_type: ReminderType, expires_at: datetime, user_id: int, threshold: Optional[int] = None) -> NotificationReminder: + db: Session, reminder_type: ReminderType, expires_at: datetime, user_id: int, threshold: Optional[int] = None +) -> NotificationReminder: """ Creates a new notification reminder. @@ -1476,8 +1444,7 @@ def create_notification_reminder( Returns: NotificationReminder: The newly created NotificationReminder object. """ - reminder = NotificationReminder( - type=reminder_type, expires_at=expires_at, user_id=user_id) + reminder = NotificationReminder(type=reminder_type, expires_at=expires_at, user_id=user_id) if threshold is not None: reminder.threshold = threshold db.add(reminder) @@ -1487,7 +1454,7 @@ def create_notification_reminder( def get_notification_reminder( - db: Session, user_id: int, reminder_type: ReminderType, threshold: Optional[int] = None + db: Session, user_id: int, reminder_type: ReminderType, threshold: Optional[int] = None ) -> Union[NotificationReminder, None]: """ Retrieves a notification reminder for a user. @@ -1502,8 +1469,7 @@ def get_notification_reminder( Union[NotificationReminder, None]: The NotificationReminder object if found and not expired, None otherwise. """ query = db.query(NotificationReminder).filter( - NotificationReminder.user_id == user_id, - NotificationReminder.type == reminder_type + NotificationReminder.user_id == user_id, NotificationReminder.type == reminder_type ) # If a threshold is provided, filter for reminders with this threshold @@ -1525,7 +1491,7 @@ def get_notification_reminder( def delete_notification_reminder_by_type( - db: Session, user_id: int, reminder_type: ReminderType, threshold: Optional[int] = None + db: Session, user_id: int, reminder_type: ReminderType, threshold: Optional[int] = None ) -> None: """ Deletes a notification reminder for a user based on the reminder type and optional threshold. @@ -1537,8 +1503,7 @@ def delete_notification_reminder_by_type( threshold (Optional[int]): The threshold to delete (e.g., days left or usage percent). If not provided, deletes all reminders of that type. """ stmt = delete(NotificationReminder).where( - NotificationReminder.user_id == user_id, - NotificationReminder.type == reminder_type + NotificationReminder.user_id == user_id, NotificationReminder.type == reminder_type ) # If a threshold is provided, include it in the filter @@ -1564,6 +1529,5 @@ def delete_notification_reminder(db: Session, dbreminder: NotificationReminder) def count_online_users(db: Session, time_delta: timedelta): twenty_four_hours_ago = datetime.utcnow() - time_delta - query = db.query(func.count(User.id)).filter(User.online_at.isnot( - None), User.online_at >= twenty_four_hours_ago) + query = db.query(func.count(User.id)).filter(User.online_at.isnot(None), User.online_at >= twenty_four_hours_ago) return query.scalar() diff --git a/app/db/models.py b/app/db/models.py index e47acad76..0a2c448d3 100644 --- a/app/db/models.py +++ b/app/db/models.py @@ -8,7 +8,6 @@ Column, DateTime, Enum, - False_, Float, ForeignKey, Integer, @@ -42,7 +41,7 @@ class Admin(Base): telegram_id = Column(BigInteger, nullable=True, default=None) discord_webhook = Column(String(1024), nullable=True, default=None) users_usage = Column(BigInteger, nullable=False, default=0) - is_disabled = Column(Boolean, nullable=False, server_default='0', default=False) + is_disabled = Column(Boolean, nullable=False, server_default="0", default=False) usage_logs = relationship("AdminUsageLogs", back_populates="admin") @@ -60,7 +59,7 @@ class User(Base): __tablename__ = "users" id = Column(Integer, primary_key=True) - username = Column(String(34, collation='NOCASE'), unique=True, index=True) + username = Column(String(34, collation="NOCASE"), unique=True, index=True) proxies = relationship("Proxy", back_populates="user", cascade="all, delete-orphan") status = Column(Enum(UserStatus), nullable=False, default=UserStatus.active) used_traffic = Column(BigInteger, default=0) @@ -93,12 +92,7 @@ class User(Base): edit_at = Column(DateTime, nullable=True, default=None) last_status_change = Column(DateTime, default=datetime.utcnow, nullable=True) - next_plan = relationship( - "NextPlan", - uselist=False, - back_populates="user", - cascade="all, delete-orphan" - ) + next_plan = relationship("NextPlan", uselist=False, back_populates="user", cascade="all, delete-orphan") @hybrid_property def reseted_usage(self) -> int: @@ -107,17 +101,14 @@ def reseted_usage(self) -> int: @reseted_usage.expression def reseted_usage(cls): return ( - select(func.sum(UserUsageResetLogs.used_traffic_at_reset)). - where(UserUsageResetLogs.user_id == cls.id). - label('reseted_usage') + select(func.sum(UserUsageResetLogs.used_traffic_at_reset)) + .where(UserUsageResetLogs.user_id == cls.id) + .label("reseted_usage") ) @property def lifetime_used_traffic(self) -> int: - return int( - sum([log.used_traffic_at_reset for log in self.usage_logs]) - + self.used_traffic - ) + return int(sum([log.used_traffic_at_reset for log in self.usage_logs]) + self.used_traffic) @property def last_traffic_reset_time(self): @@ -159,15 +150,15 @@ def inbounds(self): class NextPlan(Base): - __tablename__ = 'next_plans' + __tablename__ = "next_plans" id = Column(Integer, primary_key=True) - user_id = Column(Integer, ForeignKey('users.id'), nullable=False) - user_template_id = Column(Integer, ForeignKey('user_templates.id'), nullable=True) + user_id = Column(Integer, ForeignKey("users.id"), nullable=False) + user_template_id = Column(Integer, ForeignKey("user_templates.id"), nullable=True) data_limit = Column(BigInteger, nullable=False) expire = Column(Integer, nullable=True) - add_remaining_traffic = Column(Boolean, nullable=False, default=False, server_default='0') - fire_on_either = Column(Boolean, nullable=False, default=True, server_default='0') + add_remaining_traffic = Column(Boolean, nullable=False, default=False, server_default="0") + fire_on_either = Column(Boolean, nullable=False, default=True, server_default="0") user = relationship("User", back_populates="next_plan") user_template = relationship("UserTemplate", back_populates="next_plans") @@ -183,15 +174,9 @@ class UserTemplate(Base): username_prefix = Column(String(20), nullable=True) username_suffix = Column(String(20), nullable=True) - inbounds = relationship( - "ProxyInbound", secondary=template_inbounds_association - ) - - next_plans = relationship( - "NextPlan", - back_populates="user_template", - cascade="all, delete-orphan" - ) + inbounds = relationship("ProxyInbound", secondary=template_inbounds_association) + + next_plans = relationship("NextPlan", back_populates="user_template", cascade="all, delete-orphan") class UserUsageResetLogs(Base): @@ -212,9 +197,7 @@ class Proxy(Base): user = relationship("User", back_populates="proxies") type = Column(Enum(ProxyTypes), nullable=False) settings = Column(JSON, nullable=False) - excluded_inbounds = relationship( - "ProxyInbound", secondary=excluded_inbounds_association - ) + excluded_inbounds = relationship("ProxyInbound", secondary=excluded_inbounds_association) class ProxyInbound(Base): @@ -222,9 +205,7 @@ class ProxyInbound(Base): id = Column(Integer, primary_key=True) tag = Column(String(256), unique=True, nullable=False, index=True) - hosts = relationship( - "ProxyHost", back_populates="inbound", cascade="all, delete-orphan" - ) + hosts = relationship("ProxyHost", back_populates="inbound", cascade="all, delete-orphan") class ProxyHost(Base): @@ -251,24 +232,24 @@ class ProxyHost(Base): unique=False, nullable=False, default=ProxyHostSecurity.none, - server_default=ProxyHostSecurity.none.name + server_default=ProxyHostSecurity.none.name, ) fingerprint = Column( Enum(ProxyHostFingerprint), unique=False, nullable=False, default=ProxyHostSecurity.none, - server_default=ProxyHostSecurity.none.name + server_default=ProxyHostSecurity.none.name, ) inbound_tag = Column(String(256), ForeignKey("inbounds.tag"), nullable=False) inbound = relationship("ProxyInbound", back_populates="hosts") allowinsecure = Column(Boolean, nullable=True) is_disabled = Column(Boolean, nullable=True, default=False) - mux_enable = Column(Boolean, nullable=False, default=False, server_default='0') + mux_enable = Column(Boolean, nullable=False, default=False, server_default="0") fragment_setting = Column(String(100), nullable=True) noise_setting = Column(String(2000), nullable=True) - random_user_agent = Column(Boolean, nullable=False, default=False, server_default='0') + random_user_agent = Column(Boolean, nullable=False, default=False, server_default="0") use_sni_as_host = Column(Boolean, nullable=False, default=False, server_default="0") @@ -284,9 +265,7 @@ class JWT(Base): __tablename__ = "jwt" id = Column(Integer, primary_key=True) - secret_key = Column( - String(64), nullable=False, default=lambda: os.urandom(32).hex() - ) + secret_key = Column(String(64), nullable=False, default=lambda: os.urandom(32).hex()) class TLS(Base): @@ -301,7 +280,7 @@ class Node(Base): __tablename__ = "nodes" id = Column(Integer, primary_key=True) - name = Column(String(256, collation='NOCASE'), unique=True) + name = Column(String(256, collation="NOCASE"), unique=True) address = Column(String(256), unique=False, nullable=False) port = Column(Integer, unique=False, nullable=False) api_port = Column(Integer, unique=False, nullable=False) @@ -319,9 +298,7 @@ class Node(Base): class NodeUserUsage(Base): __tablename__ = "node_user_usages" - __table_args__ = ( - UniqueConstraint('created_at', 'user_id', 'node_id'), - ) + __table_args__ = (UniqueConstraint("created_at", "user_id", "node_id"),) id = Column(Integer, primary_key=True) created_at = Column(DateTime, unique=False, nullable=False) # one hour per record @@ -334,9 +311,7 @@ class NodeUserUsage(Base): class NodeUsage(Base): __tablename__ = "node_usages" - __table_args__ = ( - UniqueConstraint('created_at', 'node_id'), - ) + __table_args__ = (UniqueConstraint("created_at", "node_id"),) id = Column(Integer, primary_key=True) created_at = Column(DateTime, unique=False, nullable=False) # one hour per record diff --git a/app/dependencies.py b/app/dependencies.py index 060a07c13..4f25559e4 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -19,7 +19,9 @@ def validate_admin(db: Session, username: str, password: str) -> Optional[AdminV dbadmin = crud.get_admin(db, username) if dbadmin and AdminInDB.model_validate(dbadmin).verify_password(password): - return AdminValidationResult(username=dbadmin.username, is_sudo=dbadmin.is_sudo, is_disabled=dbadmin.is_disabled) + return AdminValidationResult( + username=dbadmin.username, is_sudo=dbadmin.is_sudo, is_disabled=dbadmin.is_disabled + ) return None @@ -44,8 +46,9 @@ def validate_dates(start: Optional[Union[str, datetime]], end: Optional[Union[st """Validate if start and end dates are correct and if end is after start.""" try: if start: - start_date = start if isinstance(start, datetime) else datetime.fromisoformat( - start).astimezone(timezone.utc) + start_date = ( + start if isinstance(start, datetime) else datetime.fromisoformat(start).astimezone(timezone.utc) + ) else: start_date = datetime.now(timezone.utc) - timedelta(days=30) if end: @@ -68,28 +71,23 @@ def get_user_template(template_id: int, db: Session = Depends(get_db)): return dbuser_template -def get_validated_sub( - token: str, - db: Session = Depends(get_db) -) -> UserResponse: +def get_validated_sub(token: str, db: Session = Depends(get_db)) -> UserResponse: sub = get_subscription_payload(token) if not sub: raise HTTPException(status_code=404, detail="Not Found") - dbuser = crud.get_user(db, sub['username']) - if not dbuser or dbuser.created_at > sub['created_at']: + dbuser = crud.get_user(db, sub["username"]) + if not dbuser or dbuser.created_at > sub["created_at"]: raise HTTPException(status_code=404, detail="Not Found") - if dbuser.sub_revoked_at and dbuser.sub_revoked_at > sub['created_at']: + if dbuser.sub_revoked_at and dbuser.sub_revoked_at > sub["created_at"]: raise HTTPException(status_code=404, detail="Not Found") return dbuser def get_validated_user( - username: str, - admin: Admin = Depends(Admin.get_current), - db: Session = Depends(get_db) + username: str, admin: Admin = Depends(Admin.get_current), db: Session = Depends(get_db) ) -> UserResponse: dbuser = crud.get_user(db, username) if not dbuser: @@ -101,20 +99,13 @@ def get_validated_user( return dbuser -def get_expired_users_list(db: Session, admin: Admin, expired_after: datetime = None, - expired_before: datetime = None): - +def get_expired_users_list(db: Session, admin: Admin, expired_after: datetime = None, expired_before: datetime = None): dbadmin = crud.get_admin(db, admin.username) dbusers = crud.get_users( - db=db, - status=[UserStatus.expired, UserStatus.limited], - admin=dbadmin if not admin.is_sudo else None + db=db, status=[UserStatus.expired, UserStatus.limited], admin=dbadmin if not admin.is_sudo else None ) - return [ - u for u in dbusers - if u.expire and expired_after <= u.expire <= expired_before - ] + return [u for u in dbusers if u.expire and expired_after <= u.expire <= expired_before] def get_host(host_id: int, db: Session = Depends(get_db)) -> ProxyHost: @@ -127,5 +118,8 @@ def get_host(host_id: int, db: Session = Depends(get_db)) -> ProxyHost: def get_v2ray_links(user: UserResponse) -> list: return generate_v2ray_links( - user.proxies, user.inbounds, extra_data=user.model_dump(), reverse=False, + user.proxies, + user.inbounds, + extra_data=user.model_dump(), + reverse=False, ) diff --git a/app/discord/__init__.py b/app/discord/__init__.py index 6eb5f3e44..b768b9263 100644 --- a/app/discord/__init__.py +++ b/app/discord/__init__.py @@ -9,7 +9,7 @@ report_user_usage_reset, report_user_data_reset_by_next, report_user_subscription_revoked, - report_login + report_login, ) __all__ = [ @@ -21,5 +21,5 @@ "report_user_usage_reset", "report_user_data_reset_by_next", "report_user_subscription_revoked", - "report_login" + "report_login", ] diff --git a/app/discord/handlers/report.py b/app/discord/handlers/report.py index beedbafce..2b371278d 100644 --- a/app/discord/handlers/report.py +++ b/app/discord/handlers/report.py @@ -9,7 +9,7 @@ from config import DISCORD_WEBHOOK_URL -def send_webhooks(json_data, admin_webhook:str = None): +def send_webhooks(json_data, admin_webhook: str = None): if DISCORD_WEBHOOK_URL: send_webhook(json_data=json_data, webhook=DISCORD_WEBHOOK_URL) if admin_webhook: @@ -29,16 +29,16 @@ def send_webhook(json_data, webhook): def report_status_change(username: str, status: str, admin: Admin = None): _status = { - 'active': '**:white_check_mark: Activated**', - 'disabled': '**:x: Disabled**', - 'limited': '**:low_battery: #Limited**', - 'expired': '**:clock5: #Expired**' + "active": "**:white_check_mark: Activated**", + "disabled": "**:x: Disabled**", + "limited": "**:low_battery: #Limited**", + "expired": "**:clock5: #Expired**", } _status_color = { - 'active': int("9ae6b4", 16), - 'disabled': int("424b59", 16), - 'limited': int("f8a7a8", 16), - 'expired': int("fbd38d", 16) + "active": int("9ae6b4", 16), + "disabled": int("424b59", 16), + "limited": int("f8a7a8", 16), + "expired": int("fbd38d", 16), } statusChange = { "content": "", @@ -46,28 +46,32 @@ def report_status_change(username: str, status: str, admin: Admin = None): { "description": f"{_status[status]}\n----------------------\n**Username:** {username}", "color": _status_color[status], - "footer": { - "text": f"Belongs To: {admin.username if admin else None}" - }, + "footer": {"text": f"Belongs To: {admin.username if admin else None}"}, } ], } send_webhooks( - json_data=statusChange, - admin_webhook=admin.discord_webhook if admin and admin.discord_webhook else None - ) - - -def report_new_user(username: str, by: str, expire_date: int, data_limit: int, proxies: list, has_next_plan: bool, - data_limit_reset_strategy:UserDataLimitResetStrategy, admin: Admin = None): - - data_limit=readable_size(data_limit) if data_limit else "Unlimited" - expire_date=datetime.fromtimestamp(expire_date).strftime("%H:%M:%S %Y-%m-%d") if expire_date else "Never" - proxies="" if not proxies else ", ".join([escape_html(proxy) for proxy in proxies]) + json_data=statusChange, admin_webhook=admin.discord_webhook if admin and admin.discord_webhook else None + ) + + +def report_new_user( + username: str, + by: str, + expire_date: int, + data_limit: int, + proxies: list, + has_next_plan: bool, + data_limit_reset_strategy: UserDataLimitResetStrategy, + admin: Admin = None, +): + data_limit = readable_size(data_limit) if data_limit else "Unlimited" + expire_date = datetime.fromtimestamp(expire_date).strftime("%H:%M:%S %Y-%m-%d") if expire_date else "Never" + proxies = "" if not proxies else ", ".join([escape_html(proxy) for proxy in proxies]) reportNewUser = { - 'content': '', - 'embeds': [ + "content": "", + "embeds": [ { "title": ":new: Created", "description": f""" @@ -77,156 +81,136 @@ def report_new_user(username: str, by: str, expire_date: int, data_limit: int, p **Proxies:** {proxies} **Data Limit Reset Strategy:**{data_limit_reset_strategy} **Has Next Plan:**{has_next_plan}""", - - "footer": { - "text": f"Belongs To: {admin.username if admin else None}\nBy: {by}" - }, - "color": int("00ff00", 16) + "footer": {"text": f"Belongs To: {admin.username if admin else None}\nBy: {by}"}, + "color": int("00ff00", 16), } - ] + ], } send_webhooks( - json_data=reportNewUser, - admin_webhook=admin.discord_webhook if admin and admin.discord_webhook else None - ) - - -def report_user_modification(username: str, expire_date: int, data_limit: int, proxies: list, by: str, has_next_plan: bool, - data_limit_reset_strategy:UserDataLimitResetStrategy, admin: Admin = None): - - data_limit=readable_size(data_limit) if data_limit else "Unlimited" - expire_date=datetime.fromtimestamp(expire_date).strftime("%H:%M:%S %Y-%m-%d") if expire_date else "Never" - proxies="" if not proxies else ", ".join([escape_html(proxy) for proxy in proxies]) - protocols = proxies + json_data=reportNewUser, admin_webhook=admin.discord_webhook if admin and admin.discord_webhook else None + ) + + +def report_user_modification( + username: str, + expire_date: int, + data_limit: int, + proxies: list, + by: str, + has_next_plan: bool, + data_limit_reset_strategy: UserDataLimitResetStrategy, + admin: Admin = None, +): + data_limit = readable_size(data_limit) if data_limit else "Unlimited" + expire_date = datetime.fromtimestamp(expire_date).strftime("%H:%M:%S %Y-%m-%d") if expire_date else "Never" + proxies = "" if not proxies else ", ".join([escape_html(proxy) for proxy in proxies]) reportUserModification = { - 'content': '', - 'embeds': [ + "content": "", + "embeds": [ { - 'title': ':pencil2: Modified', - 'description': f""" + "title": ":pencil2: Modified", + "description": f""" **Username:** {username} **Traffic Limit:** {data_limit} **Expire Date:** {expire_date} **Proxies:** {proxies} **Data Limit Reset Strategy:**{data_limit_reset_strategy} **Has Next Plan:**{has_next_plan}""", - - "footer": { - "text": f"Belongs To: {admin.username if admin else None}\nBy: {by}" - }, - 'color': int("00ffff", 16) + "footer": {"text": f"Belongs To: {admin.username if admin else None}\nBy: {by}"}, + "color": int("00ffff", 16), } - ] + ], } send_webhooks( - reportUserModification, - admin_webhook=admin.discord_webhook if admin and admin.discord_webhook else None - ) + reportUserModification, admin_webhook=admin.discord_webhook if admin and admin.discord_webhook else None + ) def report_user_deletion(username: str, by: str, admin: Admin = None): userDeletion = { - 'content': '', - 'embeds': [ + "content": "", + "embeds": [ { - 'title': ':wastebasket: Deleted', - 'description': f'**Username: **{username}', - "footer": { - "text": f"Belongs To: {admin.username if admin else None}\nBy: {by}" - }, - 'color': int("ff0000", 16) + "title": ":wastebasket: Deleted", + "description": f"**Username: **{username}", + "footer": {"text": f"Belongs To: {admin.username if admin else None}\nBy: {by}"}, + "color": int("ff0000", 16), } - ] + ], } send_webhooks( - json_data=userDeletion, - admin_webhook=admin.discord_webhook if admin and admin.discord_webhook else None - ) + json_data=userDeletion, admin_webhook=admin.discord_webhook if admin and admin.discord_webhook else None + ) def report_user_usage_reset(username: str, by: str, admin: Admin = None): userUsageReset = { - 'content': '', - 'embeds': [ + "content": "", + "embeds": [ { - 'title': ':repeat: Reset', - 'description': f'**Username:** {username}', - "footer": { - "text": f"Belongs To: {admin.username if admin else None}\nBy: {by}" - }, - 'color': int('00ffff', 16) + "title": ":repeat: Reset", + "description": f"**Username:** {username}", + "footer": {"text": f"Belongs To: {admin.username if admin else None}\nBy: {by}"}, + "color": int("00ffff", 16), } - ] + ], } send_webhooks( - json_data=userUsageReset, - admin_webhook=admin.discord_webhook if admin and admin.discord_webhook else None - ) + json_data=userUsageReset, admin_webhook=admin.discord_webhook if admin and admin.discord_webhook else None + ) def report_user_data_reset_by_next(user: User, admin: Admin = None): userUsageReset = { - 'content': '', - 'embeds': [ + "content": "", + "embeds": [ { - 'title': ':repeat: AutoReset', - 'description': f""" + "title": ":repeat: AutoReset", + "description": f""" **Username:** {user.username} **Traffic Limit:** {user.data_limit} **Expire Date:** {user.expire}""", - - "footer": { - "text": f"Belongs To: {admin.username if admin else None}" - }, - 'color': int('00ffff', 16) + "footer": {"text": f"Belongs To: {admin.username if admin else None}"}, + "color": int("00ffff", 16), } - ] + ], } send_webhooks( - json_data=userUsageReset, - admin_webhook=admin.discord_webhook if admin and admin.discord_webhook else None - ) + json_data=userUsageReset, admin_webhook=admin.discord_webhook if admin and admin.discord_webhook else None + ) def report_user_subscription_revoked(username: str, by: str, admin: Admin = None): subscriptionRevoked = { - 'content': '', - 'embeds': [ + "content": "", + "embeds": [ { - 'title': ':repeat: Revoked', - 'description': f'**Username:** {username}', - "footer": { - "text": f"Belongs To: {admin.username if admin else None} \nBy: {by}" - }, - 'color': int('ff0000', 16) + "title": ":repeat: Revoked", + "description": f"**Username:** {username}", + "footer": {"text": f"Belongs To: {admin.username if admin else None} \nBy: {by}"}, + "color": int("ff0000", 16), } - ] + ], } send_webhooks( - json_data=subscriptionRevoked, - admin_webhook=admin.discord_webhook if admin and admin.discord_webhook else None - ) + json_data=subscriptionRevoked, admin_webhook=admin.discord_webhook if admin and admin.discord_webhook else None + ) def report_login(username: str, password: str, client_ip: str, status: str): login = { - 'content': '', - 'embeds': [ + "content": "", + "embeds": [ { - 'title': ':repeat: Login', - 'description': f""" + "title": ":repeat: Login", + "description": f""" **Username:** {username} **Password:** {password} **Client ip**: {client_ip}""", - "footer": { - "text": f"login status: {status}" - }, - 'color': int('ff0000', 16) + "footer": {"text": f"login status: {status}"}, + "color": int("ff0000", 16), } - ] + ], } - send_webhooks( - json_data=login, - admin_webhook=None - ) + send_webhooks(json_data=login, admin_webhook=None) diff --git a/app/jobs/0_xray_core.py b/app/jobs/0_xray_core.py index 1a1ce27ae..ff526166c 100644 --- a/app/jobs/0_xray_core.py +++ b/app/jobs/0_xray_core.py @@ -60,9 +60,9 @@ def start_core(): for node_id in node_ids: xray.operations.connect_node(node_id, config) - scheduler.add_job(core_health_check, 'interval', - seconds=JOB_CORE_HEALTH_CHECK_INTERVAL, - coalesce=True, max_instances=1) + scheduler.add_job( + core_health_check, "interval", seconds=JOB_CORE_HEALTH_CHECK_INTERVAL, coalesce=True, max_instances=1 + ) @app.on_event("shutdown") diff --git a/app/jobs/__init__.py b/app/jobs/__init__.py index 627a8bd8b..bc7f1935d 100644 --- a/app/jobs/__init__.py +++ b/app/jobs/__init__.py @@ -5,8 +5,8 @@ modules = glob.glob(join(dirname(__file__), "*.py")) for file in modules: - name = basename(file).replace('.py', '') - if name.startswith('_'): + name = basename(file).replace(".py", "") + if name.startswith("_"): continue spec = importlib.util.spec_from_file_location(name, file) diff --git a/app/jobs/record_usages.py b/app/jobs/record_usages.py index b37286deb..c5170634a 100644 --- a/app/jobs/record_usages.py +++ b/app/jobs/record_usages.py @@ -5,7 +5,7 @@ from typing import Union from pymysql.err import OperationalError -from sqlalchemy import and_, bindparam, insert, select, update +from sqlalchemy import and_, bindparam, insert, select, update from sqlalchemy.orm import Session from sqlalchemy.sql.dml import Insert @@ -22,9 +22,9 @@ def safe_execute(db: Session, stmt, params=None): - if db.bind.name == 'mysql': + if db.bind.name == "mysql": if isinstance(stmt, Insert): - stmt = stmt.prefix_with('IGNORE') + stmt = stmt.prefix_with("IGNORE") tries = 0 done = False @@ -45,41 +45,44 @@ def safe_execute(db: Session, stmt, params=None): db.commit() -def record_user_stats(params: list, node_id: Union[int, None], - consumption_factor: int = 1): +def record_user_stats(params: list, node_id: Union[int, None], consumption_factor: int = 1): if not params: return - created_at = datetime.fromisoformat(datetime.utcnow().strftime('%Y-%m-%dT%H:00:00')) + created_at = datetime.fromisoformat(datetime.utcnow().strftime("%Y-%m-%dT%H:00:00")) with GetDB() as db: # make user usage row if doesn't exist - select_stmt = select(NodeUserUsage.user_id) \ - .where(and_(NodeUserUsage.node_id == node_id, NodeUserUsage.created_at == created_at)) + select_stmt = select(NodeUserUsage.user_id).where( + and_(NodeUserUsage.node_id == node_id, NodeUserUsage.created_at == created_at) + ) existings = [r[0] for r in db.execute(select_stmt).fetchall()] uids_to_insert = set() for p in params: - uid = int(p['uid']) + uid = int(p["uid"]) if uid in existings: continue uids_to_insert.add(uid) if uids_to_insert: stmt = insert(NodeUserUsage).values( - user_id=bindparam('uid'), - created_at=created_at, - node_id=node_id, - used_traffic=0 + user_id=bindparam("uid"), created_at=created_at, node_id=node_id, used_traffic=0 ) - safe_execute(db, stmt, [{'uid': uid} for uid in uids_to_insert]) + safe_execute(db, stmt, [{"uid": uid} for uid in uids_to_insert]) # record - stmt = update(NodeUserUsage) \ - .values(used_traffic=NodeUserUsage.used_traffic + bindparam('value') * consumption_factor) \ - .where(and_(NodeUserUsage.user_id == bindparam('uid'), - NodeUserUsage.node_id == node_id, - NodeUserUsage.created_at == created_at)) + stmt = ( + update(NodeUserUsage) + .values(used_traffic=NodeUserUsage.used_traffic + bindparam("value") * consumption_factor) + .where( + and_( + NodeUserUsage.user_id == bindparam("uid"), + NodeUserUsage.node_id == node_id, + NodeUserUsage.created_at == created_at, + ) + ) + ) safe_execute(db, stmt, params) @@ -87,22 +90,24 @@ def record_node_stats(params: dict, node_id: Union[int, None]): if not params: return - created_at = datetime.fromisoformat(datetime.utcnow().strftime('%Y-%m-%dT%H:00:00')) + created_at = datetime.fromisoformat(datetime.utcnow().strftime("%Y-%m-%dT%H:00:00")) with GetDB() as db: - # make node usage row if doesn't exist - select_stmt = select(NodeUsage.node_id). \ - where(and_(NodeUsage.node_id == node_id, NodeUsage.created_at == created_at)) + select_stmt = select(NodeUsage.node_id).where( + and_(NodeUsage.node_id == node_id, NodeUsage.created_at == created_at) + ) notfound = db.execute(select_stmt).first() is None if notfound: stmt = insert(NodeUsage).values(created_at=created_at, node_id=node_id, uplink=0, downlink=0) safe_execute(db, stmt) # record - stmt = update(NodeUsage). \ - values(uplink=NodeUsage.uplink + bindparam('up'), downlink=NodeUsage.downlink + bindparam('down')). \ - where(and_(NodeUsage.node_id == node_id, NodeUsage.created_at == created_at)) + stmt = ( + update(NodeUsage) + .values(uplink=NodeUsage.uplink + bindparam("up"), downlink=NodeUsage.downlink + bindparam("down")) + .where(and_(NodeUsage.node_id == node_id, NodeUsage.created_at == created_at)) + ) safe_execute(db, stmt, params) @@ -110,8 +115,8 @@ def record_node_stats(params: dict, node_id: Union[int, None]): def get_users_stats(api: XRayAPI): try: params = defaultdict(int) - for stat in filter(attrgetter('value'), api.get_users_stats(reset=True, timeout=30)): - params[stat.name.split('.', 1)[0]] += stat.value + for stat in filter(attrgetter("value"), api.get_users_stats(reset=True, timeout=30)): + params[stat.name.split(".", 1)[0]] += stat.value params = list({"uid": uid, "value": value} for uid, value in params.items()) return params except xray_exc.XrayError: @@ -120,8 +125,10 @@ def get_users_stats(api: XRayAPI): def get_outbounds_stats(api: XRayAPI): try: - params = [{"up": stat.value, "down": 0} if stat.link == "uplink" else {"up": 0, "down": stat.value} - for stat in filter(attrgetter('value'), api.get_outbounds_stats(reset=True, timeout=10))] + params = [ + {"up": stat.value, "down": 0} if stat.link == "uplink" else {"up": 0, "down": stat.value} + for stat in filter(attrgetter("value"), api.get_outbounds_stats(reset=True, timeout=10)) + ] return params except xray_exc.XrayError: return [] @@ -144,7 +151,7 @@ def record_user_usages(): for node_id, params in api_params.items(): coefficient = usage_coefficient.get(node_id, 1) # get the usage coefficient for the node for param in params: - users_usage[param['uid']] += int(param['value'] * coefficient) # apply the usage coefficient + users_usage[param["uid"]] += int(param["value"] * coefficient) # apply the usage coefficient users_usage = list({"uid": uid, "value": value} for uid, value in users_usage.items()) if not users_usage: return @@ -160,20 +167,21 @@ def record_user_usages(): # record users usage with GetDB() as db: - stmt = update(User). \ - where(User.id == bindparam('uid')). \ - values( - used_traffic=User.used_traffic + bindparam('value'), - online_at=datetime.utcnow() + stmt = ( + update(User) + .where(User.id == bindparam("uid")) + .values(used_traffic=User.used_traffic + bindparam("value"), online_at=datetime.utcnow()) ) safe_execute(db, stmt, users_usage) admin_data = [{"admin_id": admin_id, "value": value} for admin_id, value in admin_usage.items()] if admin_data: - admin_update_stmt = update(Admin). \ - where(Admin.id == bindparam('admin_id')). \ - values(users_usage=Admin.users_usage + bindparam('value')) + admin_update_stmt = ( + update(Admin) + .where(Admin.id == bindparam("admin_id")) + .values(users_usage=Admin.users_usage + bindparam("value")) + ) safe_execute(db, admin_update_stmt, admin_data) if DISABLE_RECORDING_NODE_USAGE: @@ -197,17 +205,14 @@ def record_node_usages(): total_down = 0 for node_id, params in api_params.items(): for param in params: - total_up += param['up'] - total_down += param['down'] + total_up += param["up"] + total_down += param["down"] if not (total_up or total_down): return # record nodes usage with GetDB() as db: - stmt = update(System).values( - uplink=System.uplink + total_up, - downlink=System.downlink + total_down - ) + stmt = update(System).values(uplink=System.uplink + total_up, downlink=System.downlink + total_down) safe_execute(db, stmt) if DISABLE_RECORDING_NODE_USAGE: @@ -217,9 +222,9 @@ def record_node_usages(): record_node_stats(params, node_id) -scheduler.add_job(record_user_usages, 'interval', - seconds=JOB_RECORD_USER_USAGES_INTERVAL, - coalesce=True, max_instances=1) -scheduler.add_job(record_node_usages, 'interval', - seconds=JOB_RECORD_NODE_USAGES_INTERVAL, - coalesce=True, max_instances=1) +scheduler.add_job( + record_user_usages, "interval", seconds=JOB_RECORD_USER_USAGES_INTERVAL, coalesce=True, max_instances=1 +) +scheduler.add_job( + record_node_usages, "interval", seconds=JOB_RECORD_NODE_USAGES_INTERVAL, coalesce=True, max_instances=1 +) diff --git a/app/jobs/remove_expired_users.py b/app/jobs/remove_expired_users.py index eda9640bd..7be0c31a8 100644 --- a/app/jobs/remove_expired_users.py +++ b/app/jobs/remove_expired_users.py @@ -6,7 +6,7 @@ from app.utils import report from config import USER_AUTODELETE_INCLUDE_LIMITED_ACCOUNTS -SYSTEM_ADMIN = Admin(username='system', is_sudo=True, telegram_id=None, discord_webhook=None) +SYSTEM_ADMIN = Admin(username="system", is_sudo=True, telegram_id=None, discord_webhook=None) def remove_expired_users(): @@ -14,10 +14,10 @@ def remove_expired_users(): deleted_users = crud.autodelete_expired_users(db, USER_AUTODELETE_INCLUDE_LIMITED_ACCOUNTS) for user in deleted_users: - report.user_deleted(user.username, SYSTEM_ADMIN, - user_admin=Admin.model_validate(user.admin) if user.admin else None - ) + report.user_deleted( + user.username, SYSTEM_ADMIN, user_admin=Admin.model_validate(user.admin) if user.admin else None + ) logger.log(logging.INFO, "Expired user %s deleted." % user.username) -scheduler.add_job(remove_expired_users, 'interval', coalesce=True, hours=6, max_instances=1) +scheduler.add_job(remove_expired_users, "interval", coalesce=True, hours=6, max_instances=1) diff --git a/app/jobs/reset_user_data_usage.py b/app/jobs/reset_user_data_usage.py index 5fb5f3526..99a743709 100644 --- a/app/jobs/reset_user_data_usage.py +++ b/app/jobs/reset_user_data_usage.py @@ -15,17 +15,16 @@ def reset_user_data_usage(): now = datetime.utcnow() with GetDB() as db: - for user in get_users(db, - status=[ - UserStatus.active, - UserStatus.limited - ], - reset_strategy=[ - UserDataLimitResetStrategy.day.value, - UserDataLimitResetStrategy.week.value, - UserDataLimitResetStrategy.month.value, - UserDataLimitResetStrategy.year.value, - ]): + for user in get_users( + db, + status=[UserStatus.active, UserStatus.limited], + reset_strategy=[ + UserDataLimitResetStrategy.day.value, + UserDataLimitResetStrategy.week.value, + UserDataLimitResetStrategy.month.value, + UserDataLimitResetStrategy.year.value, + ], + ): last_reset_time = user.last_traffic_reset_time num_days_to_reset = reset_strategy_to_days[user.data_limit_reset_strategy] @@ -37,7 +36,7 @@ def reset_user_data_usage(): if user.status == UserStatus.limited: xray.operations.add_user(user) - logger.info(f"User data usage reset for User \"{user.username}\"") + logger.info(f'User data usage reset for User "{user.username}"') -scheduler.add_job(reset_user_data_usage, 'interval', coalesce=True, hours=1) +scheduler.add_job(reset_user_data_usage, "interval", coalesce=True, hours=1) diff --git a/app/jobs/review_users.py b/app/jobs/review_users.py index 41a04f578..844434289 100644 --- a/app/jobs/review_users.py +++ b/app/jobs/review_users.py @@ -34,8 +34,7 @@ def add_notification_reminders(db: Session, user: "User") -> None: if usage_percent >= percent: if not get_notification_reminder(db, user.id, ReminderType.data_usage, threshold=percent): report.data_usage_percent_reached( - db, usage_percent, UserResponse.model_validate(user), - user.id, user.expire, threshold=percent + db, usage_percent, UserResponse.model_validate(user), user.id, user.expire, threshold=percent ) break @@ -46,8 +45,7 @@ def add_notification_reminders(db: Session, user: "User") -> None: if expire_days <= days_left: if not get_notification_reminder(db, user.id, ReminderType.expiration_date, threshold=days_left): report.expire_days_reached( - db, expire_days, UserResponse.model_validate(user), - user.id, user.expire, threshold=days_left + db, expire_days, UserResponse.model_validate(user), user.id, user.expire, threshold=days_left ) break @@ -64,13 +62,11 @@ def review(): now = datetime.now(timezone.utc) with GetDB() as db: for user in get_users(db, status=UserStatus.active): - limited = user.data_limit and user.used_traffic >= user.data_limit expired = user.expire and user.expire.replace(tzinfo=timezone.utc) <= now if (limited or expired) and user.next_plan is not None: if user.next_plan is not None: - if user.next_plan.fire_on_either: reset_user_by_next_report(db, user) continue @@ -91,13 +87,13 @@ def review(): xray.operations.remove_user(user) update_user_status(db, user, status) - report.status_change(username=user.username, status=status, - user=UserResponse.model_validate(user), user_admin=user.admin) + report.status_change( + username=user.username, status=status, user=UserResponse.model_validate(user), user_admin=user.admin + ) - logger.info(f"User \"{user.username}\" status changed to {status.value}") + logger.info(f'User "{user.username}" status changed to {status.value}') for user in get_users(db, status=UserStatus.on_hold): - if user.edit_at: base_time = user.edit_at else: @@ -118,12 +114,9 @@ def review(): start_user_expire(db, user) user = UserResponse.model_validate(user) - report.status_change(username=user.username, status=status, - user=user, user_admin=user.admin) + report.status_change(username=user.username, status=status, user=user, user_admin=user.admin) - logger.info(f"User \"{user.username}\" status changed to {status.value}") + logger.info(f'User "{user.username}" status changed to {status.value}') -scheduler.add_job(review, 'interval', - seconds=JOB_REVIEW_USERS_INTERVAL, - coalesce=True, max_instances=1) +scheduler.add_job(review, "interval", seconds=JOB_REVIEW_USERS_INTERVAL, coalesce=True, max_instances=1) diff --git a/app/jobs/send_notifications.py b/app/jobs/send_notifications.py index e7aaffc7b..d6922c205 100644 --- a/app/jobs/send_notifications.py +++ b/app/jobs/send_notifications.py @@ -9,10 +9,13 @@ from app.db import GetDB from app.db.models import NotificationReminder from app.utils.notification import queue -from config import (JOB_SEND_NOTIFICATIONS_INTERVAL, - NUMBER_OF_RECURRENT_NOTIFICATIONS, - RECURRENT_NOTIFICATIONS_TIMEOUT, WEBHOOK_ADDRESS, - WEBHOOK_SECRET) +from config import ( + JOB_SEND_NOTIFICATIONS_INTERVAL, + NUMBER_OF_RECURRENT_NOTIFICATIONS, + RECURRENT_NOTIFICATIONS_TIMEOUT, + WEBHOOK_ADDRESS, + WEBHOOK_SECRET, +) session = Session() @@ -57,8 +60,8 @@ def send_notifications(): notifications_to_send = list() try: - while (notification := queue.popleft()): - if (notification.tries > NUMBER_OF_RECURRENT_NOTIFICATIONS): + while notification := queue.popleft(): + if notification.tries > NUMBER_OF_RECURRENT_NOTIFICATIONS: continue if notification.send_at > dt.utcnow().timestamp(): queue.append(notification) # add it to the queue again for the next check @@ -75,7 +78,8 @@ def send_notifications(): continue notification.tries += 1 notification.send_at = ( # schedule notification for n seconds later - dt.utcnow() + td(seconds=RECURRENT_NOTIFICATIONS_TIMEOUT)).timestamp() + dt.utcnow() + td(seconds=RECURRENT_NOTIFICATIONS_TIMEOUT) + ).timestamp() queue.append(notification) @@ -86,13 +90,12 @@ def delete_expired_reminders() -> None: if WEBHOOK_ADDRESS: + @app.on_event("shutdown") def app_shutdown(): logger.info("Sending pending notifications before shutdown...") send_notifications() logger.info("Send webhook job started") - scheduler.add_job(send_notifications, "interval", - seconds=JOB_SEND_NOTIFICATIONS_INTERVAL, - replace_existing=True) + scheduler.add_job(send_notifications, "interval", seconds=JOB_SEND_NOTIFICATIONS_INTERVAL, replace_existing=True) scheduler.add_job(delete_expired_reminders, "interval", hours=2, start_date=dt.utcnow() + td(minutes=1)) diff --git a/app/models/admin.py b/app/models/admin.py index b4f69c758..66d4b6102 100644 --- a/app/models/admin.py +++ b/app/models/admin.py @@ -27,7 +27,7 @@ class Admin(BaseModel): is_disabled: bool = False model_config = ConfigDict(from_attributes=True) - @field_validator("users_usage", mode='before') + @field_validator("users_usage", mode="before") def cast_to_int(cls, v): if v is None: # Allow None values return v @@ -43,10 +43,10 @@ def get_admin(cls, token: str, db: Session): if not payload: return - if payload['username'] in SUDOERS and payload['is_sudo'] is True: - return cls(username=payload['username'], is_sudo=True) + if payload["username"] in SUDOERS and payload["is_sudo"] is True: + return cls(username=payload["username"], is_sudo=True) - dbadmin = crud.get_admin(db, payload['username']) + dbadmin = crud.get_admin(db, payload["username"]) if not dbadmin: return @@ -59,9 +59,7 @@ def get_admin(cls, token: str, db: Session): return cls.model_validate(dbadmin) @classmethod - def get_current(cls, - db: Session = Depends(get_db), - token: str = Depends(oauth2_scheme)): + def get_current(cls, db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)): admin = cls.get_admin(token, db) if not admin: raise HTTPException( @@ -79,9 +77,7 @@ def get_current(cls, return admin @classmethod - def check_sudo_admin(cls, - db: Session = Depends(get_db), - token: str = Depends(oauth2_scheme)): + def check_sudo_admin(cls, db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)): admin = cls.get_admin(token, db) if not admin: raise HTTPException( @@ -96,10 +92,7 @@ def check_sudo_admin(cls, headers={"WWW-Authenticate": "Bearer"}, ) if not admin.is_sudo: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="You're not allowed" - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You're not allowed") return admin diff --git a/app/models/host.py b/app/models/host.py index 0f17a8768..4f370a0cd 100644 --- a/app/models/host.py +++ b/app/models/host.py @@ -6,9 +6,10 @@ from app.models.proxy import ProxyTypes from pydantic import BaseModel, ConfigDict, Field, field_validator -FRAGMENT_PATTERN = re.compile(r'^((\d{1,4}-\d{1,4})|(\d{1,4})),((\d{1,3}-\d{1,3})|(\d{1,3})),(tlshello|\d|\d\-\d)$') +FRAGMENT_PATTERN = re.compile(r"^((\d{1,4}-\d{1,4})|(\d{1,4})),((\d{1,3}-\d{1,3})|(\d{1,3})),(tlshello|\d|\d\-\d)$") NOISE_PATTERN = re.compile( - r'^(rand:(\d{1,4}-\d{1,4}|\d{1,4})|str:.+|base64:.+)(,(\d{1,4}-\d{1,4}|\d{1,4}))?(&(rand:(\d{1,4}-\d{1,4}|\d{1,4})|str:.+|base64:.+)(,(\d{1,4}-\d{1,4}|\d{1,4}))?)*$') + r"^(rand:(\d{1,4}-\d{1,4}|\d{1,4})|str:.+|base64:.+)(,(\d{1,4}-\d{1,4}|\d{1,4}))?(&(rand:(\d{1,4}-\d{1,4}|\d{1,4})|str:.+|base64:.+)(,(\d{1,4}-\d{1,4}|\d{1,4}))?)*$" +) class ProxyHostSecurity(str, Enum): @@ -72,18 +73,18 @@ class BaseHost(BaseModel): use_sni_as_host: Union[bool, None] = None model_config = ConfigDict(from_attributes=True) - + class CreateHost(BaseHost): @field_validator("remark", mode="after") def validate_remark(cls, v): try: v.format_map(FormatVariables()) - except ValueError as exc: + except ValueError: raise ValueError("Invalid formatting variables") return v - + @field_validator("inbound_tag", mode="after") def validate_inbound(cls, v): if xray.config.get_inbound(v) is None: @@ -94,7 +95,7 @@ def validate_inbound(cls, v): def validate_address(cls, v): try: v.format_map(FormatVariables()) - except ValueError as exc: + except ValueError: raise ValueError("Invalid formatting variables") return v @@ -103,9 +104,7 @@ def validate_address(cls, v): @classmethod def validate_fragment(cls, v): if v and not FRAGMENT_PATTERN.match(v): - raise ValueError( - "Fragment setting must be like this: length,interval,packet (10-100,100-200,tlshello)." - ) + raise ValueError("Fragment setting must be like this: length,interval,packet (10-100,100-200,tlshello).") return v @field_validator("noise_setting", check_fields=False) @@ -113,13 +112,9 @@ def validate_fragment(cls, v): def validate_noise(cls, v): if v: if not NOISE_PATTERN.match(v): - raise ValueError( - "Noise setting must be like this: packet,delay (rand:10-20,100-200)." - ) + raise ValueError("Noise setting must be like this: packet,delay (rand:10-20,100-200).") if len(v) > 2000: - raise ValueError( - "Noise can't be longer that 2000 character" - ) + raise ValueError("Noise can't be longer that 2000 character") return v diff --git a/app/models/node.py b/app/models/node.py index d45304ce0..42f845b46 100644 --- a/app/models/node.py +++ b/app/models/node.py @@ -25,15 +25,17 @@ class Node(BaseModel): class NodeCreate(Node): - model_config = ConfigDict(json_schema_extra={ - "example": { - "name": "DE node", - "address": "192.168.1.1", - "port": 62050, - "api_port": 62051, - "usage_coefficient": 1 + model_config = ConfigDict( + json_schema_extra={ + "example": { + "name": "DE node", + "address": "192.168.1.1", + "port": 62050, + "api_port": 62051, + "usage_coefficient": 1, + } } - }) + ) class NodeModify(Node): @@ -43,16 +45,18 @@ class NodeModify(Node): api_port: Optional[int] = Field(None, nullable=True) status: Optional[NodeStatus] = Field(None, nullable=True) usage_coefficient: Optional[float] = Field(None, nullable=True) - model_config = ConfigDict(json_schema_extra={ - "example": { - "name": "DE node", - "address": "192.168.1.1", - "port": 62050, - "api_port": 62051, - "status": "disabled", - "usage_coefficient": 1.0 + model_config = ConfigDict( + json_schema_extra={ + "example": { + "name": "DE node", + "address": "192.168.1.1", + "port": 62050, + "api_port": 62051, + "status": "disabled", + "usage_coefficient": 1.0, + } } - }) + ) class NodeResponse(Node): diff --git a/app/models/user.py b/app/models/user.py index 692f11cef..1f1151faa 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -59,12 +59,8 @@ class NextPlanModel(BaseModel): class User(BaseModel): proxies: Dict[ProxyTypes, ProxySettings] = {} expire: datetime | int | None = Field(None, nullable=True) - data_limit: Optional[int] = Field( - ge=0, default=None, description="data_limit can be 0 or greater" - ) - data_limit_reset_strategy: UserDataLimitResetStrategy = ( - UserDataLimitResetStrategy.no_reset - ) + data_limit: Optional[int] = Field(ge=0, default=None, description="data_limit can be 0 or greater") + data_limit_reset_strategy: UserDataLimitResetStrategy = UserDataLimitResetStrategy.no_reset inbounds: Dict[ProxyTypes, List[str]] = {} note: Optional[str] = Field(None, nullable=True) sub_updated_at: Optional[datetime] = Field(None, nullable=True) @@ -77,7 +73,7 @@ class User(BaseModel): next_plan: Optional[NextPlanModel] = Field(None, nullable=True) - @field_validator('data_limit', mode='before') + @field_validator("data_limit", mode="before") def cast_to_int(cls, v): if v is None: # Allow None values return v @@ -91,11 +87,7 @@ def cast_to_int(cls, v): def validate_proxies(cls, v, values, **kwargs): if not v: raise ValueError("Each user needs at least one proxy") - return { - proxy_type: ProxySettings.from_dict( - proxy_type, v.get(proxy_type, {})) - for proxy_type in v - } + return {proxy_type: ProxySettings.from_dict(proxy_type, v.get(proxy_type, {})) for proxy_type in v} @field_validator("username", check_fields=False) @classmethod @@ -143,32 +135,29 @@ def validator_expire(cls, value): class UserCreate(User): username: str status: UserStatusCreate = None - model_config = ConfigDict(json_schema_extra={ - "example": { - "username": "user1234", - "proxies": { - "vmess": {"id": "35e4e39c-7d5c-4f4b-8b71-558e4f37ff53"}, - "vless": {}, - }, - "inbounds": { - "vmess": ["VMess TCP", "VMess Websocket"], - "vless": ["VLESS TCP REALITY", "VLESS GRPC REALITY"], - }, - "next_plan": { - "data_limit": 0, + model_config = ConfigDict( + json_schema_extra={ + "example": { + "username": "user1234", + "proxies": { + "vmess": {"id": "35e4e39c-7d5c-4f4b-8b71-558e4f37ff53"}, + "vless": {}, + }, + "inbounds": { + "vmess": ["VMess TCP", "VMess Websocket"], + "vless": ["VLESS TCP REALITY", "VLESS GRPC REALITY"], + }, + "next_plan": {"data_limit": 0, "expire": 0, "add_remaining_traffic": False, "fire_on_either": True}, "expire": 0, - "add_remaining_traffic": False, - "fire_on_either": True - }, - "expire": 0, - "data_limit": 0, - "data_limit_reset_strategy": "no_reset", - "status": "active", - "note": "", - "on_hold_timeout": "2023-11-03T20:30:00", - "on_hold_expire_duration": 0, + "data_limit": 0, + "data_limit_reset_strategy": "no_reset", + "status": "active", + "note": "", + "on_hold_timeout": "2023-11-03T20:30:00", + "on_hold_expire_duration": 0, + } } - }) + ) @property def excluded_inbounds(self): @@ -176,7 +165,7 @@ def excluded_inbounds(self): for proxy_type in self.proxies: excluded[proxy_type] = [] for inbound in xray.config.inbounds_by_protocol.get(proxy_type, []): - if not inbound["tag"] in self.inbounds.get(proxy_type, []): + if inbound["tag"] not in self.inbounds.get(proxy_type, []): excluded[proxy_type].append(inbound["tag"]) return excluded @@ -203,10 +192,7 @@ def validate_inbounds(cls, inbounds, values, **kwargs): # raise ValueError(f"{proxy_type} inbounds cannot be empty") else: - inbounds[proxy_type] = [ - i["tag"] - for i in xray.config.inbounds_by_protocol.get(proxy_type, []) - ] + inbounds[proxy_type] = [i["tag"] for i in xray.config.inbounds_by_protocol.get(proxy_type, [])] return inbounds @@ -215,7 +201,7 @@ def validate_status(cls, status, values): on_hold_expire = values.data.get("on_hold_expire_duration") expire = values.data.get("expire") if status == UserStatusCreate.on_hold: - if (on_hold_expire == 0 or on_hold_expire is None): + if on_hold_expire == 0 or on_hold_expire is None: raise ValueError("User cannot be on hold without a valid on_hold_expire_duration.") if expire: raise ValueError("User cannot be on hold with specified expire.") @@ -225,31 +211,28 @@ def validate_status(cls, status, values): class UserModify(User): status: UserStatusModify = None data_limit_reset_strategy: UserDataLimitResetStrategy = None - model_config = ConfigDict(json_schema_extra={ - "example": { - "proxies": { - "vmess": {"id": "35e4e39c-7d5c-4f4b-8b71-558e4f37ff53"}, - "vless": {}, - }, - "inbounds": { - "vmess": ["VMess TCP", "VMess Websocket"], - "vless": ["VLESS TCP REALITY", "VLESS GRPC REALITY"], - }, - "next_plan": { - "data_limit": 0, + model_config = ConfigDict( + json_schema_extra={ + "example": { + "proxies": { + "vmess": {"id": "35e4e39c-7d5c-4f4b-8b71-558e4f37ff53"}, + "vless": {}, + }, + "inbounds": { + "vmess": ["VMess TCP", "VMess Websocket"], + "vless": ["VLESS TCP REALITY", "VLESS GRPC REALITY"], + }, + "next_plan": {"data_limit": 0, "expire": 0, "add_remaining_traffic": False, "fire_on_either": True}, "expire": 0, - "add_remaining_traffic": False, - "fire_on_either": True - }, - "expire": 0, - "data_limit": 0, - "data_limit_reset_strategy": "no_reset", - "status": "active", - "note": "", - "on_hold_timeout": "2023-11-03T20:30:00", - "on_hold_expire_duration": 0, + "data_limit": 0, + "data_limit_reset_strategy": "no_reset", + "status": "active", + "note": "", + "on_hold_timeout": "2023-11-03T20:30:00", + "on_hold_expire_duration": 0, + } } - }) + ) @property def excluded_inbounds(self): @@ -257,7 +240,7 @@ def excluded_inbounds(self): for proxy_type in self.inbounds: excluded[proxy_type] = [] for inbound in xray.config.inbounds_by_protocol.get(proxy_type, []): - if not inbound["tag"] in self.inbounds.get(proxy_type, []): + if inbound["tag"] not in self.inbounds.get(proxy_type, []): excluded[proxy_type].append(inbound["tag"]) return excluded @@ -268,7 +251,6 @@ def validate_inbounds(cls, inbounds, values, **kwargs): # so inbounds particularly can be modified if inbounds: for proxy_type, tags in inbounds.items(): - # if not tags: # raise ValueError(f"{proxy_type} inbounds cannot be empty") @@ -280,18 +262,14 @@ def validate_inbounds(cls, inbounds, values, **kwargs): @field_validator("proxies", mode="before") def validate_proxies(cls, v): - return { - proxy_type: ProxySettings.from_dict( - proxy_type, v.get(proxy_type, {})) - for proxy_type in v - } + return {proxy_type: ProxySettings.from_dict(proxy_type, v.get(proxy_type, {})) for proxy_type in v} @field_validator("status", mode="before") def validate_status(cls, status, values): on_hold_expire = values.data.get("on_hold_expire_duration") expire = values.data.get("expire") if status == UserStatusCreate.on_hold: - if (on_hold_expire == 0 or on_hold_expire is None): + if on_hold_expire == 0 or on_hold_expire is None: raise ValueError("User cannot be on hold without a valid on_hold_expire_duration.") if expire: raise ValueError("User cannot be on hold with specified expire.") @@ -315,7 +293,7 @@ class UserResponse(User): def validate_subscription_url(self): if not self.subscription_url: salt = secrets.token_hex(8) - url_prefix = (XRAY_SUBSCRIPTION_URL_PREFIX).replace('*', salt) + url_prefix = (XRAY_SUBSCRIPTION_URL_PREFIX).replace("*", salt) token = create_subscription_token(self.username) self.subscription_url = f"{url_prefix}/{XRAY_SUBSCRIPTION_PATH}/{token}" return self @@ -326,7 +304,7 @@ def validate_proxies(cls, v, values, **kwargs): v = {p.type: p.settings for p in v} return super().validate_proxies(v, values, **kwargs) - @field_validator("used_traffic", "lifetime_used_traffic", mode='before') + @field_validator("used_traffic", "lifetime_used_traffic", mode="before") def cast_to_int(cls, v): if v is None: # Allow None values return v @@ -356,7 +334,7 @@ class UserUsageResponse(BaseModel): node_name: str used_traffic: int - @field_validator("used_traffic", mode='before') + @field_validator("used_traffic", mode="before") def cast_to_int(cls, v): if v is None: # Allow None values return v diff --git a/app/models/user_template.py b/app/models/user_template.py index 12d91b451..701bb3390 100644 --- a/app/models/user_template.py +++ b/app/models/user_template.py @@ -8,9 +8,7 @@ class UserTemplate(BaseModel): name: Optional[str] = Field(None, nullable=True) - data_limit: Optional[int] = Field( - ge=0, default=None, description="data_limit can be 0 or greater" - ) + data_limit: Optional[int] = Field(ge=0, default=None, description="data_limit can be 0 or greater") expire_duration: Optional[int] = Field( ge=0, default=None, description="expire_duration can be 0 or greater in seconds" ) @@ -21,29 +19,33 @@ class UserTemplate(BaseModel): class UserTemplateCreate(UserTemplate): - model_config = ConfigDict(json_schema_extra={ - "example": { - "name": "my template 1", - "username_prefix": None, - "username_suffix": None, - "inbounds": {"vmess": ["VMESS_INBOUND"], "vless": ["VLESS_INBOUND"]}, - "data_limit": 0, - "expire_duration": 0, + model_config = ConfigDict( + json_schema_extra={ + "example": { + "name": "my template 1", + "username_prefix": None, + "username_suffix": None, + "inbounds": {"vmess": ["VMESS_INBOUND"], "vless": ["VLESS_INBOUND"]}, + "data_limit": 0, + "expire_duration": 0, + } } - }) + ) class UserTemplateModify(UserTemplate): - model_config = ConfigDict(json_schema_extra={ - "example": { - "name": "my template 1", - "username_prefix": None, - "username_suffix": None, - "inbounds": {"vmess": ["VMESS_INBOUND"], "vless": ["VLESS_INBOUND"]}, - "data_limit": 0, - "expire_duration": 0, + model_config = ConfigDict( + json_schema_extra={ + "example": { + "name": "my template 1", + "username_prefix": None, + "username_suffix": None, + "inbounds": {"vmess": ["VMESS_INBOUND"], "vless": ["VLESS_INBOUND"]}, + "data_limit": 0, + "expire_duration": 0, + } } - }) + ) class UserTemplateResponse(UserTemplate): @@ -62,4 +64,5 @@ def validate_inbounds(cls, v): else: final[protocol] = [inbound["tag"]] return final + model_config = ConfigDict(from_attributes=True) diff --git a/app/routers/__init__.py b/app/routers/__init__.py index b0640a9b3..e0096df5b 100644 --- a/app/routers/__init__.py +++ b/app/routers/__init__.py @@ -1,11 +1,11 @@ from fastapi import APIRouter from . import ( - admin, - core, - node, - subscription, - system, - user_template, + admin, + core, + node, + subscription, + system, + user_template, user, home, host, @@ -22,10 +22,10 @@ node.router, user.router, subscription.router, - user_template.router + user_template.router, ] for router in routers: api_router.include_router(router) -__all__ = ["api_router"] \ No newline at end of file +__all__ = ["api_router"] diff --git a/app/routers/admin.py b/app/routers/admin.py index 1fa53a645..cd890979f 100644 --- a/app/routers/admin.py +++ b/app/routers/admin.py @@ -144,7 +144,8 @@ def get_admins( @router.post("/admin/{username}/users/disable", responses={403: responses._403, 404: responses._404}) def disable_all_active_users( dbadmin: Admin = Depends(get_admin_by_username), - db: Session = Depends(get_db), admin: Admin = Depends(Admin.check_sudo_admin) + db: Session = Depends(get_db), + admin: Admin = Depends(Admin.check_sudo_admin), ): """Disable all active users under a specific admin""" crud.disable_all_active_users(db=db, admin=dbadmin) @@ -159,7 +160,8 @@ def disable_all_active_users( @router.post("/admin/{username}/users/activate", responses={403: responses._403, 404: responses._404}) def activate_all_disabled_users( dbadmin: Admin = Depends(get_admin_by_username), - db: Session = Depends(get_db), admin: Admin = Depends(Admin.check_sudo_admin) + db: Session = Depends(get_db), + admin: Admin = Depends(Admin.check_sudo_admin), ): """Activate all disabled users under a specific admin""" crud.activate_all_disabled_users(db=db, admin=dbadmin) @@ -179,7 +181,7 @@ def activate_all_disabled_users( def reset_admin_usage( dbadmin: Admin = Depends(get_admin_by_username), db: Session = Depends(get_db), - current_admin: Admin = Depends(Admin.check_sudo_admin) + current_admin: Admin = Depends(Admin.check_sudo_admin), ): """Resets usage of admin.""" return crud.reset_admin_usage(db, dbadmin) @@ -191,8 +193,7 @@ def reset_admin_usage( responses={403: responses._403}, ) def get_admin_usage( - dbadmin: Admin = Depends(get_admin_by_username), - current_admin: Admin = Depends(Admin.check_sudo_admin) + dbadmin: Admin = Depends(get_admin_by_username), current_admin: Admin = Depends(Admin.check_sudo_admin) ): """Retrieve the usage of given admin.""" return dbadmin.users_usage diff --git a/app/routers/core.py b/app/routers/core.py index eda1cd55d..a06a38562 100644 --- a/app/routers/core.py +++ b/app/routers/core.py @@ -19,9 +19,7 @@ @router.websocket("/core/logs") async def core_logs(websocket: WebSocket, db: Session = Depends(get_db)): - token = websocket.query_params.get("token") or websocket.headers.get( - "Authorization", "" - ).removeprefix("Bearer ") + token = websocket.query_params.get("token") or websocket.headers.get("Authorization", "").removeprefix("Bearer ") admin = Admin.get_admin(token, db) if not admin: return await websocket.close(reason="Unauthorized", code=4401) @@ -36,9 +34,7 @@ async def core_logs(websocket: WebSocket, db: Session = Depends(get_db)): except ValueError: return await websocket.close(reason="Invalid interval value", code=4400) if interval > 10: - return await websocket.close( - reason="Interval must be more than 0 and at most 10 seconds", code=4400 - ) + return await websocket.close(reason="Interval must be more than 0 and at most 10 seconds", code=4400) await websocket.accept() @@ -108,9 +104,7 @@ def get_core_config(admin: Admin = Depends(Admin.check_sudo_admin)) -> dict: @router.put("/core/config", responses={403: responses._403}) -def modify_core_config( - payload: dict, admin: Admin = Depends(Admin.check_sudo_admin) -) -> dict: +def modify_core_config(payload: dict, admin: Admin = Depends(Admin.check_sudo_admin)) -> dict: """Modify the core configuration and restart the core.""" try: config = XRayConfig(payload, api_port=xray.config.api_port) diff --git a/app/routers/home.py b/app/routers/home.py index 30c7effdd..9747608a1 100644 --- a/app/routers/home.py +++ b/app/routers/home.py @@ -9,4 +9,4 @@ @router.get("/", response_class=HTMLResponse) def base(): - return render_template(HOME_PAGE_TEMPLATE) \ No newline at end of file + return render_template(HOME_PAGE_TEMPLATE) diff --git a/app/routers/host.py b/app/routers/host.py index 47d455a92..4baad6c96 100644 --- a/app/routers/host.py +++ b/app/routers/host.py @@ -13,11 +13,11 @@ router = APIRouter(tags=["Host"], prefix="/api/host", responses={401: responses._401, 403: responses._403}) -@router.post('/', response_model=HostResponse) +@router.post("/", response_model=HostResponse) def add_host( - new_host: CreateHost, - db: Session = Depends(get_db), - _: Admin = Depends(Admin.check_sudo_admin), + new_host: CreateHost, + db: Session = Depends(get_db), + _: Admin = Depends(Admin.check_sudo_admin), ): """ add a new host @@ -25,19 +25,19 @@ def add_host( **inbound_tag** must be available in one of xray config """ db_host = crud.add_host(db, new_host) - logger.info(f"Host \"{db_host.id}\" added") + logger.info(f'Host "{db_host.id}" added') xray.hosts.update() return db_host -@router.put('/{host_id}', response_model=HostResponse, responses={404: responses._404}) +@router.put("/{host_id}", response_model=HostResponse, responses={404: responses._404}) def modify_host( - modified_host: HostResponse, - db_host: ProxyHost = Depends(get_host), - db: Session = Depends(get_db), - _: Admin = Depends(Admin.check_sudo_admin), + modified_host: HostResponse, + db_host: ProxyHost = Depends(get_host), + db: Session = Depends(get_db), + _: Admin = Depends(Admin.check_sudo_admin), ): """ modify host by **id** @@ -46,35 +46,35 @@ def modify_host( """ db_host = crud.update_host(db, db_host, modified_host) - logger.info(f"Host \"{db_host.id}\" modified") + logger.info(f'Host "{db_host.id}" modified') xray.hosts.update() return db_host -@router.delete('/{host_id}', responses={404: responses._404}) +@router.delete("/{host_id}", responses={404: responses._404}) def remove_host( - db_host: ProxyHost = Depends(get_host), - db: Session = Depends(get_db), - _: Admin = Depends(Admin.check_sudo_admin), + db_host: ProxyHost = Depends(get_host), + db: Session = Depends(get_db), + _: Admin = Depends(Admin.check_sudo_admin), ): """ remove host by **id** """ crud.remove_host(db, db_host) - logger.info(f"Host \"{db_host.id}\" deleted") + logger.info(f'Host "{db_host.id}" deleted') xray.hosts.update() return {} -@router.get('/{host_id}', response_model=HostResponse) +@router.get("/{host_id}", response_model=HostResponse) def get_host( - db_host: HostResponse = Depends(get_host), - _: Admin = Depends(Admin.check_sudo_admin), + db_host: HostResponse = Depends(get_host), + _: Admin = Depends(Admin.check_sudo_admin), ): """ get host by **id** @@ -83,12 +83,12 @@ def get_host( return db_host -@router.get('s', response_model=List[HostResponse]) +@router.get("s", response_model=List[HostResponse]) def get_hosts( - offset: int = 0, - limit: int = 0, - db: Session = Depends(get_db), - _: Admin = Depends(Admin.check_sudo_admin), + offset: int = 0, + limit: int = 0, + db: Session = Depends(get_db), + _: Admin = Depends(Admin.check_sudo_admin), ): """ Get proxy hosts. @@ -98,9 +98,9 @@ def get_hosts( @router.put("s", response_model=List[HostResponse]) def modify_hosts( - modified_hosts: List[HostResponse], - db: Session = Depends(get_db), - _: Admin = Depends(Admin.check_sudo_admin), + modified_hosts: List[HostResponse], + db: Session = Depends(get_db), + _: Admin = Depends(Admin.check_sudo_admin), ): """ Modify proxy hosts and update the configuration. @@ -111,10 +111,10 @@ def modify_hosts( db_host = crud.get_host_by_id(db, host.id) if db_host: crud.update_host(db, db_host, host) - logger.info(f"Host \"{db_host.id}\" modified") + logger.info(f'Host "{db_host.id}" modified') else: db_host = crud.add_host(db, host) - logger.info(f"Host \"{db_host.id}\" added") + logger.info(f'Host "{db_host.id}" added') xray.hosts.update() diff --git a/app/routers/node.py b/app/routers/node.py index 21eb6321f..f95419bbd 100644 --- a/app/routers/node.py +++ b/app/routers/node.py @@ -20,15 +20,11 @@ ) from app.utils import responses -router = APIRouter( - tags=["Node"], prefix="/api", responses={401: responses._401, 403: responses._403} -) +router = APIRouter(tags=["Node"], prefix="/api", responses={401: responses._401, 403: responses._403}) @router.get("/node/settings", response_model=NodeSettings) -def get_node_settings( - db: Session = Depends(get_db), admin: Admin = Depends(Admin.check_sudo_admin) -): +def get_node_settings(db: Session = Depends(get_db), admin: Admin = Depends(Admin.check_sudo_admin)): """Retrieve the current node settings, including TLS certificate.""" tls = crud.get_tls_certificate(db) return NodeSettings(certificate=tls.certificate) @@ -46,9 +42,7 @@ def add_node( dbnode = crud.create_node(db, new_node) except IntegrityError: db.rollback() - raise HTTPException( - status_code=409, detail=f'Node "{new_node.name}" already exists' - ) + raise HTTPException(status_code=409, detail=f'Node "{new_node.name}" already exists') bg.add_task(xray.operations.connect_node, node_id=dbnode.id) @@ -67,9 +61,7 @@ def get_node( @router.websocket("/node/{node_id}/logs") async def node_logs(node_id: int, websocket: WebSocket, db: Session = Depends(get_db)): - token = websocket.query_params.get("token") or websocket.headers.get( - "Authorization", "" - ).removeprefix("Bearer ") + token = websocket.query_params.get("token") or websocket.headers.get("Authorization", "").removeprefix("Bearer ") admin = Admin.get_admin(token, db) if not admin: return await websocket.close(reason="Unauthorized", code=4401) @@ -90,9 +82,7 @@ async def node_logs(node_id: int, websocket: WebSocket, db: Session = Depends(ge except ValueError: return await websocket.close(reason="Invalid interval value", code=4400) if interval > 10: - return await websocket.close( - reason="Interval must be more than 0 and at most 10 seconds", code=4400 - ) + return await websocket.close(reason="Interval must be more than 0 and at most 10 seconds", code=4400) await websocket.accept() @@ -134,9 +124,7 @@ async def node_logs(node_id: int, websocket: WebSocket, db: Session = Depends(ge @router.get("/nodes", response_model=List[NodeResponse]) -def get_nodes( - db: Session = Depends(get_db), _: Admin = Depends(Admin.check_sudo_admin) -): +def get_nodes(db: Session = Depends(get_db), _: Admin = Depends(Admin.check_sudo_admin)): """Retrieve a list of all nodes. Accessible only to sudo admins.""" return crud.get_nodes(db) diff --git a/app/routers/subscription.py b/app/routers/subscription.py index aa96bc2e1..cb704f81b 100644 --- a/app/routers/subscription.py +++ b/app/routers/subscription.py @@ -29,11 +29,15 @@ "clash": {"config_format": "clash", "media_type": "text/yaml", "as_base64": False, "reverse": False}, "v2ray": {"config_format": "v2ray", "media_type": "text/plain", "as_base64": True, "reverse": False}, "outline": {"config_format": "outline", "media_type": "application/json", "as_base64": False, "reverse": False}, - "v2ray-json": {"config_format": "v2ray-json", "media_type": "application/json", "as_base64": False, - "reverse": False} + "v2ray-json": { + "config_format": "v2ray-json", + "media_type": "application/json", + "as_base64": False, + "reverse": False, + }, } -router = APIRouter(tags=['Subscription'], prefix=f'/{XRAY_SUBSCRIPTION_PATH}') +router = APIRouter(tags=["Subscription"], prefix=f"/{XRAY_SUBSCRIPTION_PATH}") def get_subscription_user_info(user: UserResponse) -> dict: @@ -52,7 +56,7 @@ def user_subscription( request: Request, db: Session = Depends(get_db), dbuser: UserResponse = Depends(get_validated_sub), - user_agent: str = Header(default="") + user_agent: str = Header(default=""), ): """Provides a subscription link based on the user agent (Clash, V2Ray, etc.).""" user: UserResponse = UserResponse.model_validate(dbuser) @@ -60,12 +64,7 @@ def user_subscription( accept_header = request.headers.get("Accept", "") if "text/html" in accept_header: links = generate_subscription(user=user, config_format="v2ray", as_base64=False, reverse=False) - return HTMLResponse( - render_template( - SUBSCRIPTION_PAGE_TEMPLATE, - {"user": user, "links": links.split("\n")} - ) - ) + return HTMLResponse(render_template(SUBSCRIPTION_PAGE_TEMPLATE, {"user": user, "links": links.split("\n")})) crud.update_user_sub(db, dbuser, user_agent) response_headers = { @@ -74,30 +73,27 @@ def user_subscription( "support-url": SUB_SUPPORT_URL, "profile-title": encode_title(SUB_PROFILE_TITLE), "profile-update-interval": SUB_UPDATE_INTERVAL, - "subscription-userinfo": "; ".join( - f"{key}={val}" - for key, val in get_subscription_user_info(user).items() - ) + "subscription-userinfo": "; ".join(f"{key}={val}" for key, val in get_subscription_user_info(user).items()), } - if re.match(r'^([Cc]lash-verge|[Cc]lash[-\.]?[Mm]eta|[Ff][Ll][Cc]lash|[Mm]ihomo)', user_agent): + if re.match(r"^([Cc]lash-verge|[Cc]lash[-\.]?[Mm]eta|[Ff][Ll][Cc]lash|[Mm]ihomo)", user_agent): conf = generate_subscription(user=user, config_format="clash-meta", as_base64=False, reverse=False) return Response(content=conf, media_type="text/yaml", headers=response_headers) - elif re.match(r'^([Cc]lash|[Ss]tash)', user_agent): + elif re.match(r"^([Cc]lash|[Ss]tash)", user_agent): conf = generate_subscription(user=user, config_format="clash", as_base64=False, reverse=False) return Response(content=conf, media_type="text/yaml", headers=response_headers) - elif re.match(r'^(SFA|SFI|SFM|SFT|[Kk]aring|[Hh]iddify[Nn]ext)', user_agent): + elif re.match(r"^(SFA|SFI|SFM|SFT|[Kk]aring|[Hh]iddify[Nn]ext)", user_agent): conf = generate_subscription(user=user, config_format="sing-box", as_base64=False, reverse=False) return Response(content=conf, media_type="application/json", headers=response_headers) - elif re.match(r'^(SS|SSR|SSD|SSS|Outline|Shadowsocks|SSconf)', user_agent): + elif re.match(r"^(SS|SSR|SSD|SSS|Outline|Shadowsocks|SSconf)", user_agent): conf = generate_subscription(user=user, config_format="outline", as_base64=False, reverse=False) return Response(content=conf, media_type="application/json", headers=response_headers) - elif (USE_CUSTOM_JSON_DEFAULT or USE_CUSTOM_JSON_FOR_V2RAYN) and re.match(r'^v2rayN/(\d+\.\d+)', user_agent): - version_str = re.match(r'^v2rayN/(\d+\.\d+)', user_agent).group(1) + elif (USE_CUSTOM_JSON_DEFAULT or USE_CUSTOM_JSON_FOR_V2RAYN) and re.match(r"^v2rayN/(\d+\.\d+)", user_agent): + version_str = re.match(r"^v2rayN/(\d+\.\d+)", user_agent).group(1) if LooseVersion(version_str) >= LooseVersion("6.40"): conf = generate_subscription(user=user, config_format="v2ray-json", as_base64=False, reverse=False) return Response(content=conf, media_type="application/json", headers=response_headers) @@ -105,8 +101,8 @@ def user_subscription( conf = generate_subscription(user=user, config_format="v2ray", as_base64=True, reverse=False) return Response(content=conf, media_type="text/plain", headers=response_headers) - elif (USE_CUSTOM_JSON_DEFAULT or USE_CUSTOM_JSON_FOR_V2RAYNG) and re.match(r'^v2rayNG/(\d+\.\d+\.\d+)', user_agent): - version_str = re.match(r'^v2rayNG/(\d+\.\d+\.\d+)', user_agent).group(1) + elif (USE_CUSTOM_JSON_DEFAULT or USE_CUSTOM_JSON_FOR_V2RAYNG) and re.match(r"^v2rayNG/(\d+\.\d+\.\d+)", user_agent): + version_str = re.match(r"^v2rayNG/(\d+\.\d+\.\d+)", user_agent).group(1) if LooseVersion(version_str) >= LooseVersion("1.8.29"): conf = generate_subscription(user=user, config_format="v2ray-json", as_base64=False, reverse=False) return Response(content=conf, media_type="application/json", headers=response_headers) @@ -117,7 +113,7 @@ def user_subscription( conf = generate_subscription(user=user, config_format="v2ray", as_base64=True, reverse=False) return Response(content=conf, media_type="text/plain", headers=response_headers) - elif re.match(r'^[Ss]treisand', user_agent): + elif re.match(r"^[Ss]treisand", user_agent): if USE_CUSTOM_JSON_DEFAULT or USE_CUSTOM_JSON_FOR_STREISAND: conf = generate_subscription(user=user, config_format="v2ray-json", as_base64=False, reverse=False) return Response(content=conf, media_type="application/json", headers=response_headers) @@ -125,8 +121,8 @@ def user_subscription( conf = generate_subscription(user=user, config_format="v2ray", as_base64=True, reverse=False) return Response(content=conf, media_type="text/plain", headers=response_headers) - elif (USE_CUSTOM_JSON_DEFAULT or USE_CUSTOM_JSON_FOR_HAPP) and re.match(r'^Happ/(\d+\.\d+\.\d+)', user_agent): - version_str = re.match(r'^Happ/(\d+\.\d+\.\d+)', user_agent).group(1) + elif (USE_CUSTOM_JSON_DEFAULT or USE_CUSTOM_JSON_FOR_HAPP) and re.match(r"^Happ/(\d+\.\d+\.\d+)", user_agent): + version_str = re.match(r"^Happ/(\d+\.\d+\.\d+)", user_agent).group(1) if LooseVersion(version_str) >= LooseVersion("1.11.0"): conf = generate_subscription(user=user, config_format="v2ray-json", as_base64=False, reverse=False) return Response(content=conf, media_type="application/json", headers=response_headers) @@ -157,10 +153,7 @@ def user_subscription_info( @router.get("/{token}/usage") def user_get_usage( - dbuser: UserResponse = Depends(get_validated_sub), - start: str = "", - end: str = "", - db: Session = Depends(get_db) + dbuser: UserResponse = Depends(get_validated_sub), start: str = "", end: str = "", db: Session = Depends(get_db) ): """Fetches the usage statistics for the user within a specified date range.""" start, end = validate_dates(start, end) @@ -176,7 +169,7 @@ def user_subscription_with_client_type( dbuser: UserResponse = Depends(get_validated_sub), client_type: str = Path(..., regex="sing-box|clash-meta|clash|outline|v2ray|v2ray-json"), db: Session = Depends(get_db), - user_agent: str = Header(default="") + user_agent: str = Header(default=""), ): """Provides a subscription link based on the specified client type (e.g., Clash, V2Ray).""" user: UserResponse = UserResponse.model_validate(dbuser) @@ -187,16 +180,12 @@ def user_subscription_with_client_type( "support-url": SUB_SUPPORT_URL, "profile-title": encode_title(SUB_PROFILE_TITLE), "profile-update-interval": SUB_UPDATE_INTERVAL, - "subscription-userinfo": "; ".join( - f"{key}={val}" - for key, val in get_subscription_user_info(user).items() - ) + "subscription-userinfo": "; ".join(f"{key}={val}" for key, val in get_subscription_user_info(user).items()), } config = client_config.get(client_type) - conf = generate_subscription(user=user, - config_format=config["config_format"], - as_base64=config["as_base64"], - reverse=config["reverse"]) + conf = generate_subscription( + user=user, config_format=config["config_format"], as_base64=config["as_base64"], reverse=config["reverse"] + ) return Response(content=conf, media_type=config["media_type"], headers=response_headers) diff --git a/app/routers/system.py b/app/routers/system.py index 15a8feac2..870a58c20 100644 --- a/app/routers/system.py +++ b/app/routers/system.py @@ -17,9 +17,7 @@ @router.get("/system", response_model=SystemStats) -def get_system_stats( - db: Session = Depends(get_db), admin: Admin = Depends(Admin.get_current) -): +def get_system_stats(db: Session = Depends(get_db), admin: Admin = Depends(Admin.get_current)): """Fetch system stats including memory, CPU, and user metrics.""" mem = memory_usage() cpu = cpu_usage() @@ -27,21 +25,11 @@ def get_system_stats( dbadmin: Union[Admin, None] = crud.get_admin(db, admin.username) total_user = crud.get_users_count(db, admin=dbadmin if not admin.is_sudo else None) - users_active = crud.get_users_count( - db, status=UserStatus.active, admin=dbadmin if not admin.is_sudo else None - ) - users_disabled = crud.get_users_count( - db, status=UserStatus.disabled, admin=dbadmin if not admin.is_sudo else None - ) - users_on_hold = crud.get_users_count( - db, status=UserStatus.on_hold, admin=dbadmin if not admin.is_sudo else None - ) - users_expired = crud.get_users_count( - db, status=UserStatus.expired, admin=dbadmin if not admin.is_sudo else None - ) - users_limited = crud.get_users_count( - db, status=UserStatus.limited, admin=dbadmin if not admin.is_sudo else None - ) + users_active = crud.get_users_count(db, status=UserStatus.active, admin=dbadmin if not admin.is_sudo else None) + users_disabled = crud.get_users_count(db, status=UserStatus.disabled, admin=dbadmin if not admin.is_sudo else None) + users_on_hold = crud.get_users_count(db, status=UserStatus.on_hold, admin=dbadmin if not admin.is_sudo else None) + users_expired = crud.get_users_count(db, status=UserStatus.expired, admin=dbadmin if not admin.is_sudo else None) + users_limited = crud.get_users_count(db, status=UserStatus.limited, admin=dbadmin if not admin.is_sudo else None) online_users = crud.count_online_users(db, timedelta(minutes=2)) realtime_bandwidth_stats = realtime_bandwidth() diff --git a/app/routers/user.py b/app/routers/user.py index 42055e520..9741ff834 100644 --- a/app/routers/user.py +++ b/app/routers/user.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta, timezone +from datetime import datetime from typing import List, Optional, Union from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query @@ -28,7 +28,6 @@ def add_user( bg: BackgroundTasks, db: Session = Depends(get_db), admin: Admin = Depends(Admin.get_current), - ): """ Add a new user @@ -54,14 +53,12 @@ def add_user( status_code=400, detail=f"Protocol {proxy_type} is disabled on your server", ) - - if new_user.next_plan != None and new_user.next_plan.user_template_id != None: + + if new_user.next_plan is not None and new_user.next_plan.user_template_id is not None: get_user_template(new_user.next_plan.user_template_id) try: - dbuser = crud.create_user( - db, new_user, admin=crud.get_admin(db, admin.username) - ) + dbuser = crud.create_user(db, new_user, admin=crud.get_admin(db, admin.username)) except IntegrityError: db.rollback() raise HTTPException(status_code=409, detail="User already exists") @@ -79,7 +76,11 @@ def get_user(dbuser: UserResponse = Depends(get_validated_user)): return dbuser -@router.put("/user/{username}", response_model=UserResponse, responses={400: responses._400, 403: responses._403, 404: responses._404}) +@router.put( + "/user/{username}", + response_model=UserResponse, + responses={400: responses._400, 403: responses._403, 404: responses._404}, +) def modify_user( modified_user: UserModify, bg: BackgroundTasks, @@ -111,8 +112,8 @@ def modify_user( status_code=400, detail=f"Protocol {proxy_type} is disabled on your server", ) - - if modified_user.next_plan != None and modified_user.next_plan.user_template_id != None: + + if modified_user.next_plan is not None and modified_user.next_plan.user_template_id is not None: get_user_template(modified_user.next_plan.user_template_id) old_status = dbuser.status @@ -137,9 +138,7 @@ def modify_user( user_admin=dbuser.admin, by=admin, ) - logger.info( - f'User "{dbuser.username}" status changed from {old_status.value} to {user.status.value}' - ) + logger.info(f'User "{dbuser.username}" status changed from {old_status.value} to {user.status.value}') return user @@ -155,15 +154,15 @@ def remove_user( crud.remove_user(db, dbuser) bg.add_task(xray.operations.remove_user, dbuser=dbuser) - bg.add_task( - report.user_deleted, username=dbuser.username, user_admin=Admin.model_validate(dbuser.admin), by=admin - ) + bg.add_task(report.user_deleted, username=dbuser.username, user_admin=Admin.model_validate(dbuser.admin), by=admin) logger.info(f'User "{dbuser.username}" deleted') return {"detail": "User successfully deleted"} -@router.post("/user/{username}/reset", response_model=UserResponse, responses={403: responses._403, 404: responses._404}) +@router.post( + "/user/{username}/reset", response_model=UserResponse, responses={403: responses._403, 404: responses._404} +) def reset_user_data_usage( bg: BackgroundTasks, db: Session = Depends(get_db), @@ -176,15 +175,15 @@ def reset_user_data_usage( bg.add_task(xray.operations.add_user, dbuser=dbuser) user = UserResponse.model_validate(dbuser) - bg.add_task( - report.user_data_usage_reset, user=user, user_admin=dbuser.admin, by=admin - ) + bg.add_task(report.user_data_usage_reset, user=user, user_admin=dbuser.admin, by=admin) logger.info(f'User "{dbuser.username}"\'s usage was reset') return dbuser -@router.post("/user/{username}/revoke_sub", response_model=UserResponse, responses={403: responses._403, 404: responses._404}) +@router.post( + "/user/{username}/revoke_sub", response_model=UserResponse, responses={403: responses._403, 404: responses._404} +) def revoke_user_subscription( bg: BackgroundTasks, db: Session = Depends(get_db), @@ -197,16 +196,16 @@ def revoke_user_subscription( if dbuser.status in [UserStatus.active, UserStatus.on_hold]: bg.add_task(xray.operations.update_user, dbuser=dbuser) user = UserResponse.model_validate(dbuser) - bg.add_task( - report.user_subscription_revoked, user=user, user_admin=dbuser.admin, by=admin - ) + bg.add_task(report.user_subscription_revoked, user=user, user_admin=dbuser.admin, by=admin) logger.info(f'User "{dbuser.username}" subscription revoked') return user -@router.get("/users", response_model=UsersResponse, responses={400: responses._400, 403: responses._403, 404: responses._404}) +@router.get( + "/users", response_model=UsersResponse, responses={400: responses._400, 403: responses._403, 404: responses._404} +) def get_users( offset: int = None, limit: int = None, @@ -226,9 +225,7 @@ def get_users( try: sort.append(crud.UsersSortingOptions[opt]) except KeyError: - raise HTTPException( - status_code=400, detail=f'"{opt}" is not a valid sort option' - ) + raise HTTPException(status_code=400, detail=f'"{opt}" is not a valid sort option') users, count = crud.get_users( db=db, @@ -246,9 +243,7 @@ def get_users( @router.post("/users/reset", responses={403: responses._403, 404: responses._404}) -def reset_users_data_usage( - db: Session = Depends(get_db), admin: Admin = Depends(Admin.check_sudo_admin) -): +def reset_users_data_usage(db: Session = Depends(get_db), admin: Admin = Depends(Admin.check_sudo_admin)): """Reset all users data usage""" dbadmin = crud.get_admin(db, admin.username) crud.reset_all_users_data_usage(db=db, admin=dbadmin) @@ -260,7 +255,9 @@ def reset_users_data_usage( return {"detail": "Users successfully reset."} -@router.get("/user/{username}/usage", response_model=UserUsagesResponse, responses={403: responses._403, 404: responses._404}) +@router.get( + "/user/{username}/usage", response_model=UserUsagesResponse, responses={403: responses._403, 404: responses._404} +) def get_user_usage( start: str = "", end: str = "", @@ -275,7 +272,9 @@ def get_user_usage( return {"usages": usages, "username": dbuser.username} -@router.post("/user/{username}/active-next", response_model=UserResponse, responses={403: responses._403, 404: responses._404}) +@router.post( + "/user/{username}/active-next", response_model=UserResponse, responses={403: responses._403, 404: responses._404} +) def active_next_plan( bg: BackgroundTasks, db: Session = Depends(get_db), @@ -286,7 +285,7 @@ def active_next_plan( if dbuser is None or dbuser.next_plan is None: raise HTTPException( status_code=404, - detail=f"User doesn't have next plan", + detail="User doesn't have next plan", ) dbuser = crud.reset_user_by_next(db=db, dbuser=dbuser) @@ -296,7 +295,9 @@ def active_next_plan( user = UserResponse.model_validate(dbuser) bg.add_task( - report.user_data_reset_by_next, user=user, user_admin=dbuser.admin, + report.user_data_reset_by_next, + user=user, + user_admin=dbuser.admin, ) logger.info(f'User "{dbuser.username}"\'s usage was reset by next plan') @@ -314,9 +315,7 @@ def get_users_usage( """Get all users usage""" start, end = validate_dates(start, end) - usages = crud.get_all_users_usages( - db=db, start=start, end=end, admin=owner if admin.is_sudo else [admin.username] - ) + usages = crud.get_all_users_usages(db=db, start=start, end=end, admin=owner if admin.is_sudo else [admin.username]) return {"usages": usages} @@ -384,9 +383,7 @@ def delete_expired_users( removed_users = [u.username for u in expired_users] if not removed_users: - raise HTTPException( - status_code=404, detail="No expired users found in the specified date range" - ) + raise HTTPException(status_code=404, detail="No expired users found in the specified date range") crud.remove_users(db, expired_users) @@ -395,9 +392,7 @@ def delete_expired_users( bg.add_task( report.user_deleted, username=removed_user, - user_admin=next( - (u.admin for u in expired_users if u.username == removed_user), None - ), + user_admin=next((u.admin for u in expired_users if u.username == removed_user), None), by=admin, ) diff --git a/app/routers/user_template.py b/app/routers/user_template.py index 7b7b3f4b7..fbcfd912f 100644 --- a/app/routers/user_template.py +++ b/app/routers/user_template.py @@ -5,17 +5,15 @@ from app.db import Session, crud, get_db from app.models.admin import Admin -from app.models.user_template import (UserTemplateCreate, UserTemplateModify, - UserTemplateResponse) +from app.models.user_template import UserTemplateCreate, UserTemplateModify, UserTemplateResponse from app.dependencies import get_user_template -router = APIRouter(tags=['User Template'], prefix='/api') +router = APIRouter(tags=["User Template"], prefix="/api") + @router.post("/user_template", response_model=UserTemplateResponse) def add_user_template( - new_user_template: UserTemplateCreate, - db: Session = Depends(get_db), - admin: Admin = Depends(Admin.check_sudo_admin) + new_user_template: UserTemplateCreate, db: Session = Depends(get_db), admin: Admin = Depends(Admin.check_sudo_admin) ): """ Add a new user template @@ -34,8 +32,8 @@ def add_user_template( @router.get("/user_template/{template_id}", response_model=UserTemplateResponse) def get_user_template_endpoint( - dbuser_template: UserTemplateResponse = Depends(get_user_template), - admin: Admin = Depends(Admin.get_current)): + dbuser_template: UserTemplateResponse = Depends(get_user_template), admin: Admin = Depends(Admin.get_current) +): """Get User Template information with id""" return dbuser_template @@ -45,7 +43,7 @@ def modify_user_template( modify_user_template: UserTemplateModify, db: Session = Depends(get_db), admin: Admin = Depends(Admin.check_sudo_admin), - dbuser_template: UserTemplateResponse = Depends(get_user_template) + dbuser_template: UserTemplateResponse = Depends(get_user_template), ): """ Modify User Template @@ -66,7 +64,7 @@ def modify_user_template( def remove_user_template( db: Session = Depends(get_db), admin: Admin = Depends(Admin.check_sudo_admin), - dbuser_template: UserTemplateResponse = Depends(get_user_template) + dbuser_template: UserTemplateResponse = Depends(get_user_template), ): """Remove a User Template by its ID""" return crud.remove_user_template(db, dbuser_template) @@ -74,10 +72,7 @@ def remove_user_template( @router.get("/user_template", response_model=List[UserTemplateResponse]) def get_user_templates( - offset: int = None, - limit: int = None, - db: Session = Depends(get_db), - admin: Admin = Depends(Admin.get_current) + offset: int = None, limit: int = None, db: Session = Depends(get_db), admin: Admin = Depends(Admin.get_current) ): """Get a list of User Templates with optional pagination""" return crud.get_user_templates(db, offset, limit) diff --git a/app/subscription/clash.py b/app/subscription/clash.py index b987d5eb2..c64a576dc 100644 --- a/app/subscription/clash.py +++ b/app/subscription/clash.py @@ -20,17 +20,17 @@ class ClashConfiguration(object): def __init__(self): self.data = { - 'proxies': [], - 'proxy-groups': [], + "proxies": [], + "proxy-groups": [], # Some clients rely on "rules" option and will fail without it. - 'rules': [] + "rules": [], } self.proxy_remarks = [] self.mux_template = render_template(MUX_TEMPLATE) user_agent_data = json.loads(render_template(USER_AGENT_TEMPLATE)) - if 'list' in user_agent_data and isinstance(user_agent_data['list'], list): - self.user_agent_list = user_agent_data['list'] + if "list" in user_agent_data and isinstance(user_agent_data["list"], list): + self.user_agent_list = user_agent_data["list"] else: self.user_agent_list = [] @@ -43,17 +43,13 @@ def __init__(self): def render(self, reverse=False): if reverse: - self.data['proxies'].reverse() + self.data["proxies"].reverse() yaml.add_representer(UUID, yml_uuid_representer) return yaml.dump( yaml.load( - render_template( - CLASH_SUBSCRIPTION_TEMPLATE, - {"conf": self.data, "proxy_remarks": self.proxy_remarks} - ), - Loader=yaml.SafeLoader - + render_template(CLASH_SUBSCRIPTION_TEMPLATE, {"conf": self.data, "proxy_remarks": self.proxy_remarks}), + Loader=yaml.SafeLoader, ), sort_keys=False, allow_unicode=True, @@ -66,24 +62,22 @@ def __repr__(self) -> str: return self.render() def _remark_validation(self, remark): - if not remark in self.proxy_remarks: + if remark not in self.proxy_remarks: return remark c = 2 while True: - new = f'{remark} ({c})' - if not new in self.proxy_remarks: + new = f"{remark} ({c})" + if new not in self.proxy_remarks: return new c += 1 def http_config( - self, - path="", - host="", - random_user_agent: bool = False, + self, + path="", + host="", + random_user_agent: bool = False, ): - config = copy.deepcopy(self.settings.get("http-opts", { - 'headers': {} - })) + config = copy.deepcopy(self.settings.get("http-opts", {"headers": {}})) if path: config["path"] = [path] @@ -97,13 +91,13 @@ def http_config( return config def ws_config( - self, - path="", - host="", - max_early_data=None, - early_data_header_name="", - is_httpupgrade: bool = False, - random_user_agent: bool = False, + self, + path="", + host="", + max_early_data=None, + early_data_header_name="", + is_httpupgrade: bool = False, + random_user_agent: bool = False, ): config = copy.deepcopy(self.settings.get("ws-opts", {})) if (host or random_user_agent) and "headers" not in config: @@ -150,78 +144,72 @@ def tcp_config(self, path="", host=""): return config - def make_node(self, - name: str, - remark: str, - type: str, - server: str, - port: int, - network: str, - tls: bool, - sni: str, - host: str, - path: str, - headers: str = '', - udp: bool = True, - alpn: str = '', - ais: bool = '', - mux_enable: bool = False, - random_user_agent: bool = False): - + def make_node( + self, + name: str, + remark: str, + type: str, + server: str, + port: int, + network: str, + tls: bool, + sni: str, + host: str, + path: str, + headers: str = "", + udp: bool = True, + alpn: str = "", + ais: bool = "", + mux_enable: bool = False, + random_user_agent: bool = False, + ): if network in ["grpc", "gun"]: path = get_grpc_gun(path) - if type == 'shadowsocks': - type = 'ss' + if type == "shadowsocks": + type = "ss" if network in ("http", "h2", "h3"): network = "h2" - if network in ('tcp', 'raw') and headers == 'http': - network = 'http' - if network == 'httpupgrade': - network = 'ws' + if network in ("tcp", "raw") and headers == "http": + network = "http" + if network == "httpupgrade": + network = "ws" is_httpupgrade = True else: is_httpupgrade = False - node = { - 'name': remark, - 'type': type, - 'server': server, - 'port': port, - 'network': network, - 'udp': udp - } + node = {"name": remark, "type": type, "server": server, "port": port, "network": network, "udp": udp} if "?ed=" in path: path, max_early_data = path.split("?ed=") - max_early_data, = max_early_data.split("/") + (max_early_data,) = max_early_data.split("/") max_early_data = int(max_early_data) early_data_header_name = "Sec-WebSocket-Protocol" else: max_early_data = None early_data_header_name = "" - if type == 'ss': # shadowsocks + if type == "ss": # shadowsocks return node if tls: - node['tls'] = True - if type == 'trojan': - node['sni'] = sni + node["tls"] = True + if type == "trojan": + node["sni"] = sni else: - node['servername'] = sni + node["servername"] = sni if alpn: - node['alpn'] = alpn.split(',') + node["alpn"] = alpn.split(",") if ais: - node['skip-cert-verify'] = ais + node["skip-cert-verify"] = ais - if network == 'http': + if network == "http": net_opts = self.http_config( path=path, host=host, random_user_agent=random_user_agent, ) - elif network == 'ws': + elif network == "ws": net_opts = self.ws_config( path=path, host=host, @@ -231,31 +219,31 @@ def make_node(self, random_user_agent=random_user_agent, ) - elif network == 'grpc' or network == 'gun': + elif network == "grpc" or network == "gun": net_opts = self.grpc_config(path=path) - elif network == 'h2': + elif network == "h2": net_opts = self.h2_config(path=path, host=host) - elif network in ('tcp', 'raw'): + elif network in ("tcp", "raw"): net_opts = self.tcp_config(path=path, host=host) else: net_opts = {} - node[f'{network}-opts'] = net_opts + node[f"{network}-opts"] = net_opts mux_json = json.loads(self.mux_template) mux_config = mux_json["clash"] if mux_enable: - node['smux'] = mux_config + node["smux"] = mux_config return node def add(self, remark: str, address: str, inbound: dict, settings: dict): # not supported by clash - if inbound['network'] in ("kcp", "splithttp", "xhttp"): + if inbound["network"] in ("kcp", "splithttp", "xhttp"): return proxy_remark = self._remark_validation(remark) @@ -263,62 +251,64 @@ def add(self, remark: str, address: str, inbound: dict, settings: dict): node = self.make_node( name=remark, remark=proxy_remark, - type=inbound['protocol'], + type=inbound["protocol"], server=address, - port=inbound['port'], - network=inbound['network'], - tls=(inbound['tls'] == 'tls'), - sni=inbound['sni'], - host=inbound['host'], - path=inbound['path'], - headers=inbound['header_type'], + port=inbound["port"], + network=inbound["network"], + tls=(inbound["tls"] == "tls"), + sni=inbound["sni"], + host=inbound["host"], + path=inbound["path"], + headers=inbound["header_type"], udp=True, - alpn=inbound.get('alpn', ''), - ais=inbound.get('ais', False), - mux_enable=inbound.get('mux_enable', False), - random_user_agent=inbound.get("random_user_agent") + alpn=inbound.get("alpn", ""), + ais=inbound.get("ais", False), + mux_enable=inbound.get("mux_enable", False), + random_user_agent=inbound.get("random_user_agent"), ) - if inbound['protocol'] == 'vmess': - node['uuid'] = settings['id'] - node['alterId'] = 0 - node['cipher'] = 'auto' + if inbound["protocol"] == "vmess": + node["uuid"] = settings["id"] + node["alterId"] = 0 + node["cipher"] = "auto" - elif inbound['protocol'] == 'trojan': - node['password'] = settings['password'] + elif inbound["protocol"] == "trojan": + node["password"] = settings["password"] - elif inbound['protocol'] == 'shadowsocks': - node['password'] = settings['password'] - node['cipher'] = settings['method'] + elif inbound["protocol"] == "shadowsocks": + node["password"] = settings["password"] + node["cipher"] = settings["method"] else: return - self.data['proxies'].append(node) + self.data["proxies"].append(node) self.proxy_remarks.append(proxy_remark) class ClashMetaConfiguration(ClashConfiguration): - def make_node(self, - name: str, - remark: str, - type: str, - server: str, - port: int, - network: str, - tls: bool, - sni: str, - host: str, - path: str, - headers: str = '', - udp: bool = True, - alpn: str = '', - fp: str = '', - pbk: str = '', - sid: str = '', - ais: bool = '', - mux_enable: bool = False, - random_user_agent: bool = False): + def make_node( + self, + name: str, + remark: str, + type: str, + server: str, + port: int, + network: str, + tls: bool, + sni: str, + host: str, + path: str, + headers: str = "", + udp: bool = True, + alpn: str = "", + fp: str = "", + pbk: str = "", + sid: str = "", + ais: bool = "", + mux_enable: bool = False, + random_user_agent: bool = False, + ): node = super().make_node( name=name, remark=remark, @@ -335,18 +325,20 @@ def make_node(self, alpn=alpn, ais=ais, mux_enable=mux_enable, - random_user_agent=random_user_agent + random_user_agent=random_user_agent, ) if fp: - node['client-fingerprint'] = fp + node["client-fingerprint"] = fp if pbk: - node['reality-opts'] = {"public-key": pbk, "short-id": sid} + node["reality-opts"] = {"public-key": pbk, "short-id": sid} return node def add(self, remark: str, address: str, inbound: dict, settings: dict): # not supported by clash-meta - if inbound['network'] in ("kcp", "splithttp", "xhttp") or (inbound['network'] == "quic" and inbound["header_type"] != "none"): + if inbound["network"] in ("kcp", "splithttp", "xhttp") or ( + inbound["network"] == "quic" and inbound["header_type"] != "none" + ): return proxy_remark = self._remark_validation(remark) @@ -354,45 +346,49 @@ def add(self, remark: str, address: str, inbound: dict, settings: dict): node = self.make_node( name=remark, remark=proxy_remark, - type=inbound['protocol'], + type=inbound["protocol"], server=address, - port=inbound['port'], - network=inbound['network'], - tls=(inbound['tls'] in ('tls', 'reality')), - sni=inbound['sni'], - host=inbound['host'], - path=inbound['path'], - headers=inbound['header_type'], + port=inbound["port"], + network=inbound["network"], + tls=(inbound["tls"] in ("tls", "reality")), + sni=inbound["sni"], + host=inbound["host"], + path=inbound["path"], + headers=inbound["header_type"], udp=True, - alpn=inbound.get('alpn', ''), - fp=inbound.get('fp', ''), - pbk=inbound.get('pbk', ''), - sid=inbound.get('sid', ''), - ais=inbound.get('ais', False), - mux_enable=inbound.get('mux_enable', False), - random_user_agent=inbound.get("random_user_agent") + alpn=inbound.get("alpn", ""), + fp=inbound.get("fp", ""), + pbk=inbound.get("pbk", ""), + sid=inbound.get("sid", ""), + ais=inbound.get("ais", False), + mux_enable=inbound.get("mux_enable", False), + random_user_agent=inbound.get("random_user_agent"), ) - if inbound['protocol'] == 'vmess': - node['uuid'] = settings['id'] - node['alterId'] = 0 - node['cipher'] = 'auto' + if inbound["protocol"] == "vmess": + node["uuid"] = settings["id"] + node["alterId"] = 0 + node["cipher"] = "auto" - elif inbound['protocol'] == 'vless': - node['uuid'] = settings['id'] + elif inbound["protocol"] == "vless": + node["uuid"] = settings["id"] - if inbound['network'] in ('tcp', 'raw', 'kcp') and inbound['header_type'] != 'http' and inbound['tls'] != 'none': - node['flow'] = settings.get('flow', '') + if ( + inbound["network"] in ("tcp", "raw", "kcp") + and inbound["header_type"] != "http" + and inbound["tls"] != "none" + ): + node["flow"] = settings.get("flow", "") - elif inbound['protocol'] == 'trojan': - node['password'] = settings['password'] + elif inbound["protocol"] == "trojan": + node["password"] = settings["password"] - elif inbound['protocol'] == 'shadowsocks': - node['password'] = settings['password'] - node['cipher'] = settings['method'] + elif inbound["protocol"] == "shadowsocks": + node["password"] = settings["password"] + node["cipher"] = settings["method"] else: return - self.data['proxies'].append(node) + self.data["proxies"].append(node) self.proxy_remarks.append(proxy_remark) diff --git a/app/subscription/funcs.py b/app/subscription/funcs.py index 9fcc3d0de..e3d0603f2 100644 --- a/app/subscription/funcs.py +++ b/app/subscription/funcs.py @@ -4,17 +4,18 @@ def get_grpc_gun(path: str) -> str: servicename = path.rsplit("/", 1)[0] streamname = path.rsplit("/", 1)[1].split("|")[0] - + if streamname == "Tun": return servicename[1:] - + return "%s%s%s" % (servicename, "/", streamname) + def get_grpc_multi(path: str) -> str: if not path.startswith("/"): return path - + servicename = path.rsplit("/", 1)[0] streamname = path.rsplit("/", 1)[1].split("|")[1] - return "%s%s%s" % (servicename, "/", streamname) \ No newline at end of file + return "%s%s%s" % (servicename, "/", streamname) diff --git a/app/subscription/outline.py b/app/subscription/outline.py index 67b802928..5e61f8cea 100644 --- a/app/subscription/outline.py +++ b/app/subscription/outline.py @@ -15,9 +15,7 @@ def render(self, reverse=False): self.config = dict(items) return json.dumps(self.config, indent=0) - def make_outbound( - self, remark: str, address: str, port: int, password: str, method: str - ): + def make_outbound(self, remark: str, address: str, port: int, password: str, method: str): config = { "method": method, "password": password, @@ -38,4 +36,4 @@ def add(self, remark: str, address: str, inbound: dict, settings: dict): password=settings["password"], method=settings["method"], ) - self.add_directly(outbound) \ No newline at end of file + self.add_directly(outbound) diff --git a/app/subscription/share.py b/app/subscription/share.py index 56c5de406..5a065d4a5 100644 --- a/app/subscription/share.py +++ b/app/subscription/share.py @@ -12,7 +12,14 @@ from app import xray from app.utils.system import get_public_ip, get_public_ipv6, readable_size -from . import * +from . import ( + V2rayShareLink, + V2rayJsonConfig, + SingBoxConfiguration, + ClashConfiguration, + ClashMetaConfiguration, + OutlineConfiguration, +) if TYPE_CHECKING: from app.models.user import UserResponse @@ -52,7 +59,7 @@ def generate_v2ray_links(proxies: dict, inbounds: dict, extra_data: dict, revers def generate_clash_subscription( - proxies: dict, inbounds: dict, extra_data: dict, reverse: bool, is_meta: bool = False + proxies: dict, inbounds: dict, extra_data: dict, reverse: bool, is_meta: bool = False ) -> str: if is_meta is True: conf = ClashMetaConfiguration() @@ -60,49 +67,45 @@ def generate_clash_subscription( conf = ClashConfiguration() format_variables = setup_format_variables(extra_data) - return process_inbounds_and_tags( - inbounds, proxies, format_variables, conf=conf, reverse=reverse - ) + return process_inbounds_and_tags(inbounds, proxies, format_variables, conf=conf, reverse=reverse) -def generate_singbox_subscription( - proxies: dict, inbounds: dict, extra_data: dict, reverse: bool -) -> str: +def generate_singbox_subscription(proxies: dict, inbounds: dict, extra_data: dict, reverse: bool) -> str: conf = SingBoxConfiguration() format_variables = setup_format_variables(extra_data) - return process_inbounds_and_tags( - inbounds, proxies, format_variables, conf=conf, reverse=reverse - ) + return process_inbounds_and_tags(inbounds, proxies, format_variables, conf=conf, reverse=reverse) def generate_outline_subscription( - proxies: dict, inbounds: dict, extra_data: dict, reverse: bool, + proxies: dict, + inbounds: dict, + extra_data: dict, + reverse: bool, ) -> str: conf = OutlineConfiguration() format_variables = setup_format_variables(extra_data) - return process_inbounds_and_tags( - inbounds, proxies, format_variables, conf=conf, reverse=reverse - ) + return process_inbounds_and_tags(inbounds, proxies, format_variables, conf=conf, reverse=reverse) def generate_v2ray_json_subscription( - proxies: dict, inbounds: dict, extra_data: dict, reverse: bool, + proxies: dict, + inbounds: dict, + extra_data: dict, + reverse: bool, ) -> str: conf = V2rayJsonConfig() format_variables = setup_format_variables(extra_data) - return process_inbounds_and_tags( - inbounds, proxies, format_variables, conf=conf, reverse=reverse - ) + return process_inbounds_and_tags(inbounds, proxies, format_variables, conf=conf, reverse=reverse) def generate_subscription( - user: "UserResponse", - config_format: Literal["v2ray", "clash-meta", "clash", "sing-box", "outline", "v2ray-json"], - as_base64: bool, - reverse: bool, + user: "UserResponse", + config_format: Literal["v2ray", "clash-meta", "clash", "sing-box", "outline", "v2ray-json"], + as_base64: bool, + reverse: bool, ) -> str: kwargs = { "proxies": user.proxies, @@ -229,18 +232,18 @@ def setup_format_variables(extra_data: dict) -> dict: def process_inbounds_and_tags( - inbounds: dict, - proxies: dict, - format_variables: dict, - conf: Union[ - V2rayShareLink, - V2rayJsonConfig, - SingBoxConfiguration, - ClashConfiguration, - ClashMetaConfiguration, - OutlineConfiguration - ], - reverse=False, + inbounds: dict, + proxies: dict, + format_variables: dict, + conf: Union[ + V2rayShareLink, + V2rayJsonConfig, + SingBoxConfiguration, + ClashConfiguration, + ClashMetaConfiguration, + OutlineConfiguration, + ], + reverse=False, ) -> Union[List, str]: for _, host in xray.hosts.items(): tag = host["inbound_tag"] @@ -274,9 +277,9 @@ def process_inbounds_and_tags( address = "" address_list = host["address"] - if host['address']: + if host["address"]: salt = secrets.token_hex(8) - address = random.choice(address_list).replace('*', salt) + address = random.choice(address_list).replace("*", salt) if sids := host_inbound.get("sids"): host_inbound["sid"] = random.choice(sids) @@ -298,8 +301,7 @@ def process_inbounds_and_tags( "alpn": host["alpn"] if host["alpn"] else None, "path": path, "fp": host["fingerprint"] or host_inbound.get("fp", ""), - "ais": host["allowinsecure"] - or host_inbound.get("allowinsecure", ""), + "ais": host["allowinsecure"] or host_inbound.get("allowinsecure", ""), "mux_enable": host["mux_enable"], "fragment_setting": host["fragment_setting"], "noise_setting": host["noise_setting"], @@ -311,7 +313,7 @@ def process_inbounds_and_tags( remark=host["remark"].format_map(format_variables), address=address.format_map(format_variables), inbound=host_inbound, - settings=settings.dict(no_obj=True) + settings=settings.dict(no_obj=True), ) return conf.render(reverse=reverse) diff --git a/app/subscription/singbox.py b/app/subscription/singbox.py index c6576d18b..9afaad7cb 100644 --- a/app/subscription/singbox.py +++ b/app/subscription/singbox.py @@ -7,24 +7,18 @@ from app.subscription.funcs import get_grpc_gun from app.templates import render_template -from config import ( - MUX_TEMPLATE, - SINGBOX_SETTINGS_TEMPLATE, - SINGBOX_SUBSCRIPTION_TEMPLATE, - USER_AGENT_TEMPLATE -) +from config import MUX_TEMPLATE, SINGBOX_SETTINGS_TEMPLATE, SINGBOX_SUBSCRIPTION_TEMPLATE, USER_AGENT_TEMPLATE class SingBoxConfiguration(str): - def __init__(self): self.proxy_remarks = [] self.config = json.loads(render_template(SINGBOX_SUBSCRIPTION_TEMPLATE)) self.mux_template = render_template(MUX_TEMPLATE) user_agent_data = json.loads(render_template(USER_AGENT_TEMPLATE)) - if 'list' in user_agent_data and isinstance(user_agent_data['list'], list): - self.user_agent_list = user_agent_data['list'] + if "list" in user_agent_data and isinstance(user_agent_data["list"], list): + self.user_agent_list = user_agent_data["list"] else: self.user_agent_list = [] @@ -36,12 +30,12 @@ def __init__(self): del user_agent_data def _remark_validation(self, remark): - if not remark in self.proxy_remarks: + if remark not in self.proxy_remarks: return remark c = 2 while True: - new = f'{remark} ({c})' - if not new in self.proxy_remarks: + new = f"{remark} ({c})" + if new not in self.proxy_remarks: return new c += 1 @@ -50,11 +44,9 @@ def add_outbound(self, outbound_data): def render(self, reverse=False): urltest_types = ["vmess", "vless", "trojan", "shadowsocks", "hysteria2", "tuic", "http", "ssh"] - urltest_tags = [outbound["tag"] - for outbound in self.config["outbounds"] if outbound["type"] in urltest_types] + urltest_tags = [outbound["tag"] for outbound in self.config["outbounds"] if outbound["type"] in urltest_types] selector_types = ["vmess", "vless", "trojan", "shadowsocks", "hysteria2", "tuic", "http", "ssh", "urltest"] - selector_tags = [outbound["tag"] - for outbound in self.config["outbounds"] if outbound["type"] in selector_types] + selector_tags = [outbound["tag"] for outbound in self.config["outbounds"] if outbound["type"] in selector_types] for outbound in self.config["outbounds"]: if outbound.get("type") == "urltest": @@ -66,23 +58,21 @@ def render(self, reverse=False): if reverse: self.config["outbounds"].reverse() - return json.dumps(self.config, indent=4,cls=UUIDEncoder) + return json.dumps(self.config, indent=4, cls=UUIDEncoder) @staticmethod - def tls_config(sni=None, fp=None, tls=None, pbk=None, - sid=None, alpn=None, ais=None): - + def tls_config(sni=None, fp=None, tls=None, pbk=None, sid=None, alpn=None, ais=None): config = {} - if tls in ['tls', 'reality']: + if tls in ["tls", "reality"]: config["enabled"] = True if sni is not None: config["server_name"] = sni - if tls == 'tls' and ais: - config['insecure'] = ais + if tls == "tls" and ais: + config["insecure"] = ais - if tls == 'reality': + if tls == "reality": config["reality"] = {"enabled": True} if pbk: config["reality"]["public_key"] = pbk @@ -90,23 +80,19 @@ def tls_config(sni=None, fp=None, tls=None, pbk=None, config["reality"]["short_id"] = sid if fp: - config["utls"] = { - "enabled": bool(fp), - "fingerprint": fp - } + config["utls"] = {"enabled": bool(fp), "fingerprint": fp} if alpn: config["alpn"] = [alpn] if not isinstance(alpn, list) else alpn return config - def http_config(self, host='', path='', random_user_agent: bool = False): - config = copy.deepcopy(self.settings.get("httpSettings", { - "idle_timeout": "15s", - "ping_timeout": "15s", - "method": "GET", - "headers": {} - })) + def http_config(self, host="", path="", random_user_agent: bool = False): + config = copy.deepcopy( + self.settings.get( + "httpSettings", {"idle_timeout": "15s", "ping_timeout": "15s", "method": "GET", "headers": {}} + ) + ) if "headers" not in config: config["headers"] = {} @@ -120,11 +106,10 @@ def http_config(self, host='', path='', random_user_agent: bool = False): return config - def ws_config(self, host='', path='', random_user_agent: bool = False, - max_early_data=None, early_data_header_name=None): - config = copy.deepcopy(self.settings.get("wsSettings", { - "headers": {} - })) + def ws_config( + self, host="", path="", random_user_agent: bool = False, max_early_data=None, early_data_header_name=None + ): + config = copy.deepcopy(self.settings.get("wsSettings", {"headers": {}})) if "headers" not in config: config["headers"] = {} @@ -141,7 +126,7 @@ def ws_config(self, host='', path='', random_user_agent: bool = False, return config - def grpc_config(self, path=''): + def grpc_config(self, path=""): config = copy.deepcopy(self.settings.get("grpcSettings", {})) if path: @@ -149,10 +134,8 @@ def grpc_config(self, path=''): return config - def httpupgrade_config(self, host='', path='', random_user_agent: bool = False): - config = copy.deepcopy(self.settings.get("httpupgradeSettings", { - "headers": {} - })) + def httpupgrade_config(self, host="", path="", random_user_agent: bool = False): + config = copy.deepcopy(self.settings.get("httpupgradeSettings", {"headers": {}})) if "headers" not in config: config["headers"] = {} @@ -164,15 +147,15 @@ def httpupgrade_config(self, host='', path='', random_user_agent: bool = False): return config - def transport_config(self, - transport_type='', - host='', - path='', - max_early_data=None, - early_data_header_name=None, - random_user_agent: bool = False, - ): - + def transport_config( + self, + transport_type="", + host="", + path="", + max_early_data=None, + early_data_header_name=None, + random_user_agent: bool = False, + ): transport_config = {} if transport_type: @@ -202,32 +185,32 @@ def transport_config(self, random_user_agent=random_user_agent, ) - transport_config['type'] = transport_type + transport_config["type"] = transport_type return transport_config - def make_outbound(self, - type: str, - remark: str, - address: str, - port: int, - net='', - path='', - host='', - flow='', - tls='', - sni='', - fp='', - alpn='', - pbk='', - sid='', - headers='', - ais='', - mux_enable: bool = False, - random_user_agent: bool = False, - ): - + def make_outbound( + self, + type: str, + remark: str, + address: str, + port: int, + net="", + path="", + host="", + flow="", + tls="", + sni="", + fp="", + alpn="", + pbk="", + sid="", + headers="", + ais="", + mux_enable: bool = False, + random_user_agent: bool = False, + ): if isinstance(port, str): - ports = port.split(',') + ports = port.split(",") port = int(choice(ports)) config = { @@ -237,30 +220,30 @@ def make_outbound(self, "server_port": port, } - if net in ('tcp', 'raw', 'kcp') and headers != 'http' and (tls or tls != 'none'): + if net in ("tcp", "raw", "kcp") and headers != "http" and (tls or tls != "none"): if flow: config["flow"] = flow - if net == 'h2': - net = 'http' - alpn = 'h2' - elif net == 'h3': - net = 'http' - alpn = 'h3' - elif net in ['tcp', 'raw'] and headers == 'http': - net = 'http' + if net == "h2": + net = "http" + alpn = "h2" + elif net == "h3": + net = "http" + alpn = "h3" + elif net in ["tcp", "raw"] and headers == "http": + net = "http" - if net in ['http', 'ws', 'quic', 'grpc', 'httpupgrade']: + if net in ["http", "ws", "quic", "grpc", "httpupgrade"]: max_early_data = None early_data_header_name = None if "?ed=" in path: path, max_early_data = path.split("?ed=") - max_early_data, = max_early_data.split("/") + (max_early_data,) = max_early_data.split("/") max_early_data = int(max_early_data) early_data_header_name = "Sec-WebSocket-Protocol" - config['transport'] = self.transport_config( + config["transport"] = self.transport_config( transport_type=net, host=host, path=path, @@ -269,22 +252,19 @@ def make_outbound(self, random_user_agent=random_user_agent, ) - if tls in ('tls', 'reality'): - config['tls'] = self.tls_config(sni=sni, fp=fp, tls=tls, - pbk=pbk, sid=sid, alpn=alpn, - ais=ais) + if tls in ("tls", "reality"): + config["tls"] = self.tls_config(sni=sni, fp=fp, tls=tls, pbk=pbk, sid=sid, alpn=alpn, ais=ais) mux_json = json.loads(self.mux_template) mux_config = mux_json["sing-box"] - config['multiplex'] = mux_config - if config['multiplex']["enabled"]: - config['multiplex']["enabled"] = mux_enable + config["multiplex"] = mux_config + if config["multiplex"]["enabled"]: + config["multiplex"]["enabled"] = mux_enable return config def add(self, remark: str, address: str, inbound: dict, settings: dict): - net = inbound["network"] path = inbound["path"] @@ -295,42 +275,43 @@ def add(self, remark: str, address: str, inbound: dict, settings: dict): if net in ("grpc", "gun"): path = get_grpc_gun(path) - alpn = inbound.get('alpn', None) + alpn = inbound.get("alpn", None) remark = self._remark_validation(remark) self.proxy_remarks.append(remark) outbound = self.make_outbound( remark=remark, - type=inbound['protocol'], + type=inbound["protocol"], address=address, - port=inbound['port'], + port=inbound["port"], net=net, - tls=(inbound['tls']), - flow=settings.get('flow', ''), - sni=inbound['sni'], - host=inbound['host'], + tls=(inbound["tls"]), + flow=settings.get("flow", ""), + sni=inbound["sni"], + host=inbound["host"], path=path, alpn=alpn.rsplit(sep=",") if alpn else None, - fp=inbound.get('fp', ''), - pbk=inbound.get('pbk', ''), - sid=inbound.get('sid', ''), - headers=inbound['header_type'], - ais=inbound.get('ais', ''), - mux_enable=inbound.get('mux_enable', False), - random_user_agent=inbound.get('random_user_agent', False),) - - if inbound['protocol'] == 'vmess': - outbound['uuid'] = settings['id'] - - elif inbound['protocol'] == 'vless': - outbound['uuid'] = settings['id'] - - elif inbound['protocol'] == 'trojan': - outbound['password'] = settings['password'] - - elif inbound['protocol'] == 'shadowsocks': - outbound['password'] = settings['password'] - outbound['method'] = settings['method'] + fp=inbound.get("fp", ""), + pbk=inbound.get("pbk", ""), + sid=inbound.get("sid", ""), + headers=inbound["header_type"], + ais=inbound.get("ais", ""), + mux_enable=inbound.get("mux_enable", False), + random_user_agent=inbound.get("random_user_agent", False), + ) + + if inbound["protocol"] == "vmess": + outbound["uuid"] = settings["id"] + + elif inbound["protocol"] == "vless": + outbound["uuid"] = settings["id"] + + elif inbound["protocol"] == "trojan": + outbound["password"] = settings["password"] + + elif inbound["protocol"] == "shadowsocks": + outbound["password"] = settings["password"] + outbound["method"] = settings["method"] self.add_outbound(outbound) diff --git a/app/subscription/v2ray.py b/app/subscription/v2ray.py index 9569a4634..c11d5619c 100644 --- a/app/subscription/v2ray.py +++ b/app/subscription/v2ray.py @@ -72,9 +72,9 @@ def add(self, remark: str, address: str, inbound: dict, settings: dict): ais=inbound.get("ais", ""), fs=inbound.get("fragment_setting", ""), multiMode=multi_mode, - sc_max_each_post_bytes=inbound.get('scMaxEachPostBytes'), - sc_max_concurrent_posts=inbound.get('scMaxConcurrentPosts'), - sc_min_posts_interval_ms=inbound.get('scMinPostsIntervalMs'), + sc_max_each_post_bytes=inbound.get("scMaxEachPostBytes"), + sc_max_concurrent_posts=inbound.get("scMaxConcurrentPosts"), + sc_min_posts_interval_ms=inbound.get("scMinPostsIntervalMs"), x_padding_bytes=inbound.get("xPaddingBytes"), mode=inbound.get("mode", ""), noGRPCHeader=inbound.get("noGRPCHeader"), @@ -82,7 +82,7 @@ def add(self, remark: str, address: str, inbound: dict, settings: dict): keepAlivePeriod=inbound.get("keepAlivePeriod", 0), scStreamUpServerSecs=inbound.get("scStreamUpServerSecs"), xmux=inbound.get("xmux", {}), - downloadSettings=inbound.get("downloadSettings", {}) + downloadSettings=inbound.get("downloadSettings", {}), ) elif inbound["protocol"] == "vless": @@ -106,9 +106,9 @@ def add(self, remark: str, address: str, inbound: dict, settings: dict): ais=inbound.get("ais", ""), fs=inbound.get("fragment_setting", ""), multiMode=multi_mode, - sc_max_each_post_bytes=inbound.get('scMaxEachPostBytes'), - sc_max_concurrent_posts=inbound.get('scMaxConcurrentPosts'), - sc_min_posts_interval_ms=inbound.get('scMinPostsIntervalMs'), + sc_max_each_post_bytes=inbound.get("scMaxEachPostBytes"), + sc_max_concurrent_posts=inbound.get("scMaxConcurrentPosts"), + sc_min_posts_interval_ms=inbound.get("scMinPostsIntervalMs"), x_padding_bytes=inbound.get("xPaddingBytes"), mode=inbound.get("mode", ""), noGRPCHeader=inbound.get("noGRPCHeader"), @@ -116,7 +116,7 @@ def add(self, remark: str, address: str, inbound: dict, settings: dict): keepAlivePeriod=inbound.get("keepAlivePeriod", 0), scStreamUpServerSecs=inbound.get("scStreamUpServerSecs"), xmux=inbound.get("xmux", {}), - downloadSettings=inbound.get("downloadSettings", {}) + downloadSettings=inbound.get("downloadSettings", {}), ) elif inbound["protocol"] == "trojan": @@ -140,9 +140,9 @@ def add(self, remark: str, address: str, inbound: dict, settings: dict): ais=inbound.get("ais", ""), fs=inbound.get("fragment_setting", ""), multiMode=multi_mode, - sc_max_each_post_bytes=inbound.get('scMaxEachPostBytes'), - sc_max_concurrent_posts=inbound.get('scMaxConcurrentPosts'), - sc_min_posts_interval_ms=inbound.get('scMinPostsIntervalMs'), + sc_max_each_post_bytes=inbound.get("scMaxEachPostBytes"), + sc_max_concurrent_posts=inbound.get("scMaxConcurrentPosts"), + sc_min_posts_interval_ms=inbound.get("scMinPostsIntervalMs"), x_padding_bytes=inbound.get("xPaddingBytes"), mode=inbound.get("mode", ""), noGRPCHeader=inbound.get("noGRPCHeader"), @@ -150,7 +150,7 @@ def add(self, remark: str, address: str, inbound: dict, settings: dict): keepAlivePeriod=inbound.get("keepAlivePeriod", 0), xmux=inbound.get("xmux", {}), scStreamUpServerSecs=inbound.get("scStreamUpServerSecs"), - downloadSettings=inbound.get("downloadSettings", {}) + downloadSettings=inbound.get("downloadSettings", {}), ) elif inbound["protocol"] == "shadowsocks": @@ -168,36 +168,36 @@ def add(self, remark: str, address: str, inbound: dict, settings: dict): @classmethod def vmess( - cls, - remark: str, - address: str, - port: int, - id: Union[str, UUID], - host="", - net="tcp", - path="", - type="", - tls="none", - sni="", - fp="", - alpn="", - pbk="", - sid="", - spx="", - ais="", - fs="", - multiMode: bool = False, - sc_max_each_post_bytes: int | None = None, - sc_max_concurrent_posts: int | None = None, - sc_min_posts_interval_ms: int | None = None, - x_padding_bytes: str | None = None, - mode: str = "", - noGRPCHeader: bool | None = None, - heartbeatPeriod: int | None = None, - scStreamUpServerSecs: int | None = None, - keepAlivePeriod: int = 0, - xmux: dict = {}, - downloadSettings: dict = {}, + cls, + remark: str, + address: str, + port: int, + id: Union[str, UUID], + host="", + net="tcp", + path="", + type="", + tls="none", + sni="", + fp="", + alpn="", + pbk="", + sid="", + spx="", + ais="", + fs="", + multiMode: bool = False, + sc_max_each_post_bytes: int | None = None, + sc_max_concurrent_posts: int | None = None, + sc_min_posts_interval_ms: int | None = None, + x_padding_bytes: str | None = None, + mode: str = "", + noGRPCHeader: bool | None = None, + heartbeatPeriod: int | None = None, + scStreamUpServerSecs: int | None = None, + keepAlivePeriod: int = 0, + xmux: dict = {}, + downloadSettings: dict = {}, ): payload = { "add": address, @@ -269,65 +269,56 @@ def vmess( if heartbeatPeriod: payload["heartbeatPeriod"] = heartbeatPeriod - return ( - "vmess://" - + base64.b64encode( - json.dumps(payload, sort_keys=True).encode("utf-8") - ).decode() - ) + return "vmess://" + base64.b64encode(json.dumps(payload, sort_keys=True).encode("utf-8")).decode() @classmethod - def vless(cls, - remark: str, - address: str, - port: int, - id: Union[str, UUID], - net='ws', - path='', - host='', - type='', - flow='', - tls='none', - sni='', - fp='', - alpn='', - pbk='', - sid='', - spx='', - ais='', - fs="", - multiMode: bool = False, - sc_max_each_post_bytes: int | None = None, - sc_max_concurrent_posts: int | None = None, - sc_min_posts_interval_ms: int | None = None, - x_padding_bytes: str | None = None, - mode: str = "", - noGRPCHeader: bool | None = None, - heartbeatPeriod: int | None = None, - scStreamUpServerSecs: int | None = None, - keepAlivePeriod: int = 0, - xmux: dict = {}, - downloadSettings: dict = {}, - ): - - payload = { - "security": tls, - "type": net, - "headerType": type - } - if flow and (tls in ('tls', 'reality') and net in ('tcp', 'raw', 'kcp') and type != 'http'): - payload['flow'] = flow + def vless( + cls, + remark: str, + address: str, + port: int, + id: Union[str, UUID], + net="ws", + path="", + host="", + type="", + flow="", + tls="none", + sni="", + fp="", + alpn="", + pbk="", + sid="", + spx="", + ais="", + fs="", + multiMode: bool = False, + sc_max_each_post_bytes: int | None = None, + sc_max_concurrent_posts: int | None = None, + sc_min_posts_interval_ms: int | None = None, + x_padding_bytes: str | None = None, + mode: str = "", + noGRPCHeader: bool | None = None, + heartbeatPeriod: int | None = None, + scStreamUpServerSecs: int | None = None, + keepAlivePeriod: int = 0, + xmux: dict = {}, + downloadSettings: dict = {}, + ): + payload = {"security": tls, "type": net, "headerType": type} + if flow and (tls in ("tls", "reality") and net in ("tcp", "raw", "kcp") and type != "http"): + payload["flow"] = flow - if net == 'grpc': - payload['serviceName'] = path + if net == "grpc": + payload["serviceName"] = path payload["authority"] = host if multiMode: payload["mode"] = "multi" else: payload["mode"] = "gun" - elif net == 'quic': - payload['key'] = path + elif net == "quic": + payload["key"] = path payload["quicSecurity"] = host elif net in ("splithttp", "xhttp"): @@ -356,8 +347,8 @@ def vless(cls, if extra: payload["extra"] = (json.dumps(extra)).replace(" ", "") - elif net == 'kcp': - payload['seed'] = path + elif net == "kcp": + payload["seed"] = path payload["host"] = host elif net == "ws": @@ -388,57 +379,48 @@ def vless(cls, if spx: payload["spx"] = spx - return ( - "vless://" - + f"{id}@{address}:{port}?" - + urlparse.urlencode(payload) - + f"#{(urlparse.quote(remark))}" - ) + return "vless://" + f"{id}@{address}:{port}?" + urlparse.urlencode(payload) + f"#{(urlparse.quote(remark))}" @classmethod - def trojan(cls, - remark: str, - address: str, - port: int, - password: str, - net='tcp', - path='', - host='', - type='', - flow='', - tls='none', - sni='', - fp='', - alpn='', - pbk='', - sid='', - spx='', - ais='', - fs="", - multiMode: bool = False, - sc_max_each_post_bytes: int | None = None, - sc_max_concurrent_posts: int | None = None, - sc_min_posts_interval_ms: int | None = None, - x_padding_bytes: str | None = None, - mode: str = "", - noGRPCHeader: bool | None = None, - heartbeatPeriod: int | None = None, - scStreamUpServerSecs: int | None = None, - keepAlivePeriod: int = 0, - xmux: dict = {}, - downloadSettings: dict = {}, - ): - - payload = { - "security": tls, - "type": net, - "headerType": type - } - if flow and (tls in ('tls', 'reality') and net in ('tcp', 'raw', 'kcp') and type != 'http'): - payload['flow'] = flow + def trojan( + cls, + remark: str, + address: str, + port: int, + password: str, + net="tcp", + path="", + host="", + type="", + flow="", + tls="none", + sni="", + fp="", + alpn="", + pbk="", + sid="", + spx="", + ais="", + fs="", + multiMode: bool = False, + sc_max_each_post_bytes: int | None = None, + sc_max_concurrent_posts: int | None = None, + sc_min_posts_interval_ms: int | None = None, + x_padding_bytes: str | None = None, + mode: str = "", + noGRPCHeader: bool | None = None, + heartbeatPeriod: int | None = None, + scStreamUpServerSecs: int | None = None, + keepAlivePeriod: int = 0, + xmux: dict = {}, + downloadSettings: dict = {}, + ): + payload = {"security": tls, "type": net, "headerType": type} + if flow and (tls in ("tls", "reality") and net in ("tcp", "raw", "kcp") and type != "http"): + payload["flow"] = flow - if net == 'grpc': - payload['serviceName'] = path + if net == "grpc": + payload["serviceName"] = path payload["authority"] = host if multiMode: payload["mode"] = "multi" @@ -471,12 +453,12 @@ def trojan(cls, if extra: payload["extra"] = (json.dumps(extra)).replace(" ", "") - elif net == 'quic': - payload['key'] = path + elif net == "quic": + payload["key"] = path payload["quicSecurity"] = host - elif net == 'kcp': - payload['seed'] = path + elif net == "kcp": + payload["seed"] = path payload["host"] = host elif net == "ws": @@ -514,9 +496,7 @@ def trojan(cls, ) @classmethod - def shadowsocks( - cls, remark: str, address: str, port: int, password: str, method: str - ): + def shadowsocks(cls, remark: str, address: str, port: int, password: str, method: str): return ( "ss://" + base64.b64encode(f"{method}:{password}".encode()).decode() @@ -525,29 +505,26 @@ def shadowsocks( class V2rayJsonConfig(str): - def __init__(self): self.config = [] self.template = render_template(V2RAY_SUBSCRIPTION_TEMPLATE) self.mux_template = render_template(MUX_TEMPLATE) user_agent_data = json.loads(render_template(USER_AGENT_TEMPLATE)) - if 'list' in user_agent_data and isinstance(user_agent_data['list'], list): - self.user_agent_list = user_agent_data['list'] + if "list" in user_agent_data and isinstance(user_agent_data["list"], list): + self.user_agent_list = user_agent_data["list"] else: self.user_agent_list = [] - grpc_user_agent_data = json.loads( - render_template(GRPC_USER_AGENT_TEMPLATE)) + grpc_user_agent_data = json.loads(render_template(GRPC_USER_AGENT_TEMPLATE)) - if 'list' in grpc_user_agent_data and isinstance(grpc_user_agent_data['list'], list): - self.grpc_user_agent_data = grpc_user_agent_data['list'] + if "list" in grpc_user_agent_data and isinstance(grpc_user_agent_data["list"], list): + self.grpc_user_agent_data = grpc_user_agent_data["list"] else: self.grpc_user_agent_data = [] try: - self.settings = json.loads( - render_template(V2RAY_SETTINGS_TEMPLATE)) + self.settings = json.loads(render_template(V2RAY_SETTINGS_TEMPLATE)) except TemplateNotFound: self.settings = {} @@ -566,18 +543,16 @@ def render(self, reverse=False): @staticmethod def tls_config(sni=None, fp=None, alpn=None, ais: bool = False) -> dict: - tlsSettings = {} if sni is not None: tlsSettings["serverName"] = sni - tlsSettings['allowInsecure'] = ais if ais else False + tlsSettings["allowInsecure"] = ais if ais else False if fp: tlsSettings["fingerprint"] = fp if alpn: - tlsSettings["alpn"] = [alpn] if not isinstance( - alpn, list) else alpn + tlsSettings["alpn"] = [alpn] if not isinstance(alpn, list) else alpn tlsSettings["show"] = False @@ -585,7 +560,6 @@ def tls_config(sni=None, fp=None, alpn=None, ais: bool = False) -> dict: @staticmethod def reality_config(sni=None, fp=None, pbk=None, sid=None, spx=None) -> dict: - realitySettings = {} if sni is not None: realitySettings["serverName"] = sni @@ -603,7 +577,9 @@ def reality_config(sni=None, fp=None, pbk=None, sid=None, spx=None) -> dict: return realitySettings - def ws_config(self, path: str = "", host: str = "", random_user_agent: bool = False, heartbeatPeriod: int = 0) -> dict: + def ws_config( + self, path: str = "", host: str = "", random_user_agent: bool = False, heartbeatPeriod: int = 0 + ) -> dict: wsSettings = copy.deepcopy(self.settings.get("wsSettings", {})) if "headers" not in wsSettings: @@ -620,8 +596,7 @@ def ws_config(self, path: str = "", host: str = "", random_user_agent: bool = Fa return wsSettings def httpupgrade_config(self, path: str = "", host: str = "", random_user_agent: bool = False) -> dict: - httpupgradeSettings = copy.deepcopy( - self.settings.get("httpupgradeSettings", {})) + httpupgradeSettings = copy.deepcopy(self.settings.get("httpupgradeSettings", {})) if "headers" not in httpupgradeSettings: httpupgradeSettings["headers"] = {} @@ -630,23 +605,26 @@ def httpupgrade_config(self, path: str = "", host: str = "", random_user_agent: if host: httpupgradeSettings["host"] = host if random_user_agent: - httpupgradeSettings["headers"]["User-Agent"] = choice( - self.user_agent_list) + httpupgradeSettings["headers"]["User-Agent"] = choice(self.user_agent_list) return httpupgradeSettings - def splithttp_config(self, path: str = "", host: str = "", random_user_agent: bool = False, - sc_max_each_post_bytes: int | None = None, - sc_max_concurrent_posts: int | None = None, - sc_min_posts_interval_ms: int | None = None, - x_padding_bytes: str | None = None, - xmux: dict = {}, - downloadSettings: dict = {}, - mode: str = "", - noGRPCHeader: bool | None = None, - scStreamUpServerSecs: int | None = None, - keepAlivePeriod: int = 0, - ) -> dict: + def splithttp_config( + self, + path: str = "", + host: str = "", + random_user_agent: bool = False, + sc_max_each_post_bytes: int | None = None, + sc_max_concurrent_posts: int | None = None, + sc_min_posts_interval_ms: int | None = None, + x_padding_bytes: str | None = None, + xmux: dict = {}, + downloadSettings: dict = {}, + mode: str = "", + noGRPCHeader: bool | None = None, + scStreamUpServerSecs: int | None = None, + keepAlivePeriod: int = 0, + ) -> dict: config = copy.deepcopy(self.settings.get("splithttpSettings", {})) config["mode"] = mode @@ -681,14 +659,20 @@ def splithttp_config(self, path: str = "", host: str = "", random_user_agent: bo return config - def grpc_config(self, path: str = "", host: str = "", multiMode: bool = False, - random_user_agent: bool = False) -> dict: - config = copy.deepcopy(self.settings.get("grpcSettings", { - "idle_timeout": 60, - "health_check_timeout": 20, - "permit_without_stream": False, - "initial_windows_size": 35538 - })) + def grpc_config( + self, path: str = "", host: str = "", multiMode: bool = False, random_user_agent: bool = False + ) -> dict: + config = copy.deepcopy( + self.settings.get( + "grpcSettings", + { + "idle_timeout": 60, + "health_check_timeout": 20, + "permit_without_stream": False, + "initial_windows_size": 35538, + }, + ) + ) config["multiMode"] = multiMode @@ -704,29 +688,28 @@ def grpc_config(self, path: str = "", host: str = "", multiMode: bool = False, def tcp_config(self, headers="none", path: str = "", host: str = "", random_user_agent: bool = False) -> dict: if headers == "http": - config = copy.deepcopy(self.settings.get("tcphttpSettings", { - "header": { - "request": { - "headers": { - "Accept-Encoding": [ - "gzip", "deflate" - ], - "Connection": [ - "keep-alive" - ], - "Pragma": "no-cache" - }, - "method": "GET", - "version": "1.1" - } - } - })) + config = copy.deepcopy( + self.settings.get( + "tcphttpSettings", + { + "header": { + "request": { + "headers": { + "Accept-Encoding": ["gzip", "deflate"], + "Connection": ["keep-alive"], + "Pragma": "no-cache", + }, + "method": "GET", + "version": "1.1", + } + } + }, + ) + ) else: - config = copy.deepcopy(self.settings.get("tcpSettings", self.settings.get("rawSettings", { - "header": { - "type": "none" - } - }))) + config = copy.deepcopy( + self.settings.get("tcpSettings", self.settings.get("rawSettings", {"header": {"type": "none"}})) + ) if "header" not in config: config["header"] = {} @@ -748,24 +731,17 @@ def tcp_config(self, headers="none", path: str = "", host: str = "", random_user config["header"]["request"]["headers"]["Host"] = [host] if random_user_agent: - config["header"]["request"]["headers"]["User-Agent"] = [ - choice(self.user_agent_list)] + config["header"]["request"]["headers"]["User-Agent"] = [choice(self.user_agent_list)] return config def http_config(self, net="http", path: str = "", host: str = "", random_user_agent: bool = False) -> dict: if net == "h2": - config = copy.deepcopy(self.settings.get("h2Settings", { - "header": {} - })) + config = copy.deepcopy(self.settings.get("h2Settings", {"header": {}})) elif net == "h3": - config = copy.deepcopy(self.settings.get("h3Settings", { - "header": {} - })) + config = copy.deepcopy(self.settings.get("h3Settings", {"header": {}})) else: - config = self.settings.get("httpSettings", { - "header": {} - }) + config = self.settings.get("httpSettings", {"header": {}}) if "header" not in config: config["header"] = {} @@ -775,19 +751,14 @@ def http_config(self, net="http", path: str = "", host: str = "", random_user_ag else: config["host"] = [] if random_user_agent: - config["headers"]["User-Agent"] = [ - choice(self.user_agent_list)] + config["headers"]["User-Agent"] = [choice(self.user_agent_list)] return config def quic_config(self, path=None, host=None, header=None) -> dict: - quicSettings = copy.deepcopy(self.settings.get("quicSettings", { - "security": "none", - "header": { - "type": "none" - }, - "key": "" - })) + quicSettings = copy.deepcopy( + self.settings.get("quicSettings", {"security": "none", "header": {"type": "none"}, "key": ""}) + ) if "header" not in quicSettings: quicSettings["header"] = {"type": "none"} @@ -801,18 +772,21 @@ def quic_config(self, path=None, host=None, header=None) -> dict: return quicSettings def kcp_config(self, seed=None, host=None, header=None) -> dict: - kcpSettings = copy.deepcopy(self.settings.get("kcpSettings", { - "header": { - "type": "none" - }, - "mtu": 1350, - "tti": 50, - "uplinkCapacity": 12, - "downlinkCapacity": 100, - "congestion": False, - "readBufferSize": 2, - "writeBufferSize": 2, - })) + kcpSettings = copy.deepcopy( + self.settings.get( + "kcpSettings", + { + "header": {"type": "none"}, + "mtu": 1350, + "tti": 50, + "uplinkCapacity": 12, + "downlinkCapacity": 100, + "congestion": False, + "readBufferSize": 2, + "writeBufferSize": 2, + }, + ) + ) if "header" not in kcpSettings: kcpSettings["header"] = {"type": "none"} @@ -826,10 +800,9 @@ def kcp_config(self, seed=None, host=None, header=None) -> dict: return kcpSettings @staticmethod - def stream_setting_config(network=None, security=None, - network_setting=None, tls_settings=None, - sockopt=None) -> dict: - + def stream_setting_config( + network=None, security=None, network_setting=None, tls_settings=None, sockopt=None + ) -> dict: streamSettings = {"network": network} if security and security != "none": @@ -840,7 +813,7 @@ def stream_setting_config(network=None, security=None, streamSettings[f"{network}Settings"] = network_setting if sockopt: - streamSettings['sockopt'] = sockopt + streamSettings["sockopt"] = sockopt return streamSettings @@ -852,12 +825,7 @@ def vmess_config(address=None, port=None, id=None) -> dict: "address": address, "port": port, "users": [ - { - "id": id, - "alterId": 0, - "email": "https://gozargah.github.io/marzban/", - "security": "auto" - } + {"id": id, "alterId": 0, "email": "https://gozargah.github.io/marzban/", "security": "auto"} ], } ] @@ -877,7 +845,7 @@ def vless_config(address=None, port=None, id=None, flow="") -> dict: "encryption": "none", "email": "https://gozargah.github.io/marzban/", "alterId": 0, - "flow": flow + "flow": flow, } ], } @@ -914,12 +882,8 @@ def shadowsocks_config(address=None, port=None, password=None, method=None) -> d @staticmethod def make_fragment(fragment: str) -> dict: - length, interval, packets = fragment.split(',') - return { - "packets": packets, - "length": length, - "interval": interval - } + length, interval, packets = fragment.split(",") + return {"packets": packets, "length": length, "interval": interval} @staticmethod def make_noises(noises: str) -> list: @@ -927,13 +891,9 @@ def make_noises(noises: str) -> list: noises_settings = [] for n in sn: try: - tp, delay = n.split(',') + tp, delay = n.split(",") _type, packet = tp.split(":") - noises_settings.append({ - "type": _type, - "packet": packet, - "delay": delay - }) + noises_settings.append({"type": _type, "packet": packet, "delay": delay}) except ValueError: pass @@ -943,119 +903,111 @@ def make_noises(noises: str) -> list: def make_dialer_outbound(fragment: str = "", noises: str = "") -> Union[dict, None]: dialer_settings = {} if fragment: - dialer_settings["fragment"] = V2rayJsonConfig.make_fragment( - fragment) + dialer_settings["fragment"] = V2rayJsonConfig.make_fragment(fragment) if noises: dialer_settings["noises"] = V2rayJsonConfig.make_noises(noises) if dialer_settings: - return { - "tag": "dialer", - "protocol": "freedom", - "settings": dialer_settings - } + return {"tag": "dialer", "protocol": "freedom", "settings": dialer_settings} return None - def make_stream_setting(self, - net='', - path='', - host='', - tls='', - sni='', - fp='', - alpn='', - pbk='', - sid='', - spx='', - headers='', - ais='', - dialer_proxy='', - multiMode: bool = False, - random_user_agent: bool = False, - sc_max_each_post_bytes: int | None = None, - sc_max_concurrent_posts: int | None = None, - sc_min_posts_interval_ms: int | None = None, - x_padding_bytes: str | None = None, - xmux: dict = {}, - downloadSettings: dict = {}, - mode: str = "", - noGRPCHeader: bool | None = None, - scStreamUpServerSecs: int | None = None, - heartbeatPeriod: int = 0, - keepAlivePeriod: int = 0, - ) -> dict: - + def make_stream_setting( + self, + net="", + path="", + host="", + tls="", + sni="", + fp="", + alpn="", + pbk="", + sid="", + spx="", + headers="", + ais="", + dialer_proxy="", + multiMode: bool = False, + random_user_agent: bool = False, + sc_max_each_post_bytes: int | None = None, + sc_max_concurrent_posts: int | None = None, + sc_min_posts_interval_ms: int | None = None, + x_padding_bytes: str | None = None, + xmux: dict = {}, + downloadSettings: dict = {}, + mode: str = "", + noGRPCHeader: bool | None = None, + scStreamUpServerSecs: int | None = None, + heartbeatPeriod: int = 0, + keepAlivePeriod: int = 0, + ) -> dict: if net == "ws": network_setting = self.ws_config( - path=path, host=host, random_user_agent=random_user_agent, heartbeatPeriod=heartbeatPeriod) + path=path, host=host, random_user_agent=random_user_agent, heartbeatPeriod=heartbeatPeriod + ) elif net == "grpc": network_setting = self.grpc_config( - path=path, host=host, multiMode=multiMode, random_user_agent=random_user_agent) + path=path, host=host, multiMode=multiMode, random_user_agent=random_user_agent + ) elif net in ("h3", "h2", "http"): - network_setting = self.http_config( - net=net, path=path, host=host, random_user_agent=random_user_agent) + network_setting = self.http_config(net=net, path=path, host=host, random_user_agent=random_user_agent) elif net == "kcp": - network_setting = self.kcp_config( - seed=path, host=host, header=headers) + network_setting = self.kcp_config(seed=path, host=host, header=headers) elif net in ("tcp", "raw") and tls != "reality": network_setting = self.tcp_config( - headers=headers, path=path, host=host, random_user_agent=random_user_agent) + headers=headers, path=path, host=host, random_user_agent=random_user_agent + ) elif net == "quic": - network_setting = self.quic_config( - path=path, host=host, header=headers) + network_setting = self.quic_config(path=path, host=host, header=headers) elif net == "httpupgrade": - network_setting = self.httpupgrade_config( - path=path, host=host, random_user_agent=random_user_agent) + network_setting = self.httpupgrade_config(path=path, host=host, random_user_agent=random_user_agent) elif net in ("splithttp", "xhttp"): - network_setting = self.splithttp_config(path=path, host=host, random_user_agent=random_user_agent, - sc_max_each_post_bytes=sc_max_each_post_bytes, - sc_max_concurrent_posts=sc_max_concurrent_posts, - sc_min_posts_interval_ms=sc_min_posts_interval_ms, - x_padding_bytes=x_padding_bytes, - xmux=xmux, - downloadSettings=downloadSettings, - mode=mode, - noGRPCHeader=noGRPCHeader, - keepAlivePeriod=keepAlivePeriod, - scStreamUpServerSecs=scStreamUpServerSecs, - ) + network_setting = self.splithttp_config( + path=path, + host=host, + random_user_agent=random_user_agent, + sc_max_each_post_bytes=sc_max_each_post_bytes, + sc_max_concurrent_posts=sc_max_concurrent_posts, + sc_min_posts_interval_ms=sc_min_posts_interval_ms, + x_padding_bytes=x_padding_bytes, + xmux=xmux, + downloadSettings=downloadSettings, + mode=mode, + noGRPCHeader=noGRPCHeader, + keepAlivePeriod=keepAlivePeriod, + scStreamUpServerSecs=scStreamUpServerSecs, + ) else: network_setting = {} if tls == "tls": tls_settings = self.tls_config(sni=sni, fp=fp, alpn=alpn, ais=ais) elif tls == "reality": - tls_settings = self.reality_config( - sni=sni, fp=fp, pbk=pbk, sid=sid, spx=spx) + tls_settings = self.reality_config(sni=sni, fp=fp, pbk=pbk, sid=sid, spx=spx) else: tls_settings = None if dialer_proxy: - sockopt = { - "dialerProxy": dialer_proxy - } + sockopt = {"dialerProxy": dialer_proxy} else: sockopt = None - return self.stream_setting_config(network=net, security=tls, - network_setting=network_setting, - tls_settings=tls_settings, - sockopt=sockopt) + return self.stream_setting_config( + network=net, security=tls, network_setting=network_setting, tls_settings=tls_settings, sockopt=sockopt + ) def add(self, remark: str, address: str, inbound: dict, settings: dict): - - net = inbound['network'] - protocol = inbound['protocol'] - port = inbound['port'] + net = inbound["network"] + protocol = inbound["protocol"] + port = inbound["port"] if isinstance(port, str): - ports = port.split(',') + ports = port.split(",") port = int(choice(ports)) - tls = (inbound['tls']) - headers = inbound['header_type'] - fragment = inbound['fragment_setting'] - noise = inbound['noise_setting'] + tls = inbound["tls"] + headers = inbound["header_type"] + fragment = inbound["fragment_setting"] + noise = inbound["noise_setting"] path = inbound["path"] multi_mode = inbound.get("multiMode", False) @@ -1065,65 +1017,54 @@ def add(self, remark: str, address: str, inbound: dict, settings: dict): else: path = get_grpc_gun(path) - outbound = { - "tag": "proxy", - "protocol": protocol - } + outbound = {"tag": "proxy", "protocol": protocol} - if inbound['protocol'] == 'vmess': - outbound["settings"] = self.vmess_config(address=address, - port=port, - id=settings['id']) + if inbound["protocol"] == "vmess": + outbound["settings"] = self.vmess_config(address=address, port=port, id=settings["id"]) - elif inbound['protocol'] == 'vless': - if net in ('tcp', 'raw', 'kcp') and headers != 'http' and tls in ('tls', 'reality'): - flow = settings.get('flow', '') + elif inbound["protocol"] == "vless": + if net in ("tcp", "raw", "kcp") and headers != "http" and tls in ("tls", "reality"): + flow = settings.get("flow", "") else: flow = None - outbound["settings"] = self.vless_config(address=address, - port=port, - id=settings['id'], - flow=flow) + outbound["settings"] = self.vless_config(address=address, port=port, id=settings["id"], flow=flow) - elif inbound['protocol'] == 'trojan': - outbound["settings"] = self.trojan_config(address=address, - port=port, - password=settings['password']) + elif inbound["protocol"] == "trojan": + outbound["settings"] = self.trojan_config(address=address, port=port, password=settings["password"]) - elif inbound['protocol'] == 'shadowsocks': - outbound["settings"] = self.shadowsocks_config(address=address, - port=port, - password=settings['password'], - method=settings['method']) + elif inbound["protocol"] == "shadowsocks": + outbound["settings"] = self.shadowsocks_config( + address=address, port=port, password=settings["password"], method=settings["method"] + ) outbounds = [outbound] - dialer_proxy = '' + dialer_proxy = "" extra_outbound = self.make_dialer_outbound(fragment, noise) if extra_outbound: - dialer_proxy = extra_outbound['tag'] + dialer_proxy = extra_outbound["tag"] outbounds.append(extra_outbound) - alpn = inbound.get('alpn', None) + alpn = inbound.get("alpn", None) outbound["streamSettings"] = self.make_stream_setting( net=net, tls=tls, - sni=inbound['sni'], - host=inbound['host'], + sni=inbound["sni"], + host=inbound["host"], path=path, alpn=alpn.rsplit(sep=",") if alpn else None, - fp=inbound.get('fp', ''), - pbk=inbound.get('pbk', ''), - sid=inbound.get('sid', ''), - spx=inbound.get('spx', ''), + fp=inbound.get("fp", ""), + pbk=inbound.get("pbk", ""), + sid=inbound.get("sid", ""), + spx=inbound.get("spx", ""), headers=headers, - ais=inbound.get('ais', ''), + ais=inbound.get("ais", ""), dialer_proxy=dialer_proxy, multiMode=multi_mode, - random_user_agent=inbound.get('random_user_agent', False), - sc_max_each_post_bytes=inbound.get('scMaxEachPostBytes'), - sc_max_concurrent_posts=inbound.get('scMaxConcurrentPosts'), - sc_min_posts_interval_ms=inbound.get('scMinPostsIntervalMs'), + random_user_agent=inbound.get("random_user_agent", False), + sc_max_each_post_bytes=inbound.get("scMaxEachPostBytes"), + sc_max_concurrent_posts=inbound.get("scMaxConcurrentPosts"), + sc_min_posts_interval_ms=inbound.get("scMinPostsIntervalMs"), x_padding_bytes=inbound.get("xPaddingBytes"), xmux=inbound.get("xmux", {}), downloadSettings=inbound.get("downloadSettings", {}), @@ -1137,7 +1078,7 @@ def add(self, remark: str, address: str, inbound: dict, settings: dict): mux_json = json.loads(self.mux_template) mux_config = mux_json["v2ray"] - if inbound.get('mux_enable', False): + if inbound.get("mux_enable", False): outbound["mux"] = mux_config outbound["mux"]["enabled"] = True diff --git a/app/telegram/__init__.py b/app/telegram/__init__.py index 54ffa5383..b4abddf8f 100644 --- a/app/telegram/__init__.py +++ b/app/telegram/__init__.py @@ -8,11 +8,12 @@ bot = None if TELEGRAM_API_TOKEN: - apihelper.proxy = {'http': TELEGRAM_PROXY_URL, 'https': TELEGRAM_PROXY_URL} + apihelper.proxy = {"http": TELEGRAM_PROXY_URL, "https": TELEGRAM_PROXY_URL} bot = TeleBot(TELEGRAM_API_TOKEN) handler_names = ["admin", "report", "user"] + @app.on_event("startup") def start_bot(): if bot: @@ -21,7 +22,8 @@ def start_bot(): spec = importlib.util.spec_from_file_location(name, f"{handler_dir}{name}.py") spec.loader.exec_module(importlib.util.module_from_spec(spec)) - from app.telegram import utils # setup custom handlers + from app.telegram import utils # setup custom handlers + utils.setup() thread = Thread(target=bot.infinity_polling, daemon=True) @@ -37,7 +39,7 @@ def start_bot(): report_user_usage_reset, report_user_data_reset_by_next, report_user_subscription_revoked, - report_login + report_login, ) __all__ = [ @@ -50,5 +52,5 @@ def start_bot(): "report_user_usage_reset", "report_user_data_reset_by_next", "report_user_subscription_revoked", - "report_login" + "report_login", ] diff --git a/app/telegram/handlers/admin.py b/app/telegram/handlers/admin.py index 7ac252754..9aaaf9325 100644 --- a/app/telegram/handlers/admin.py +++ b/app/telegram/handlers/admin.py @@ -84,7 +84,7 @@ def get_system_info(): onhold_users=onhold_users, deactivate_users=total_users - (active_users + onhold_users), up_speed=readable_size(realtime_bandwidth().outgoing_bytes), - down_speed=readable_size(realtime_bandwidth().incoming_bytes) + down_speed=readable_size(realtime_bandwidth().incoming_bytes), ) @@ -105,51 +105,53 @@ def cleanup_messages(chat_id: int) -> None: mem_store.set(f"{chat_id}:messages_to_delete", []) -@bot.message_handler(commands=['start', 'help'], is_admin=True) +@bot.message_handler(commands=["start", "help"], is_admin=True) def help_command(message: types.Message): cleanup_messages(message.chat.id) bot.clear_step_handler_by_chat_id(message.chat.id) - return bot.reply_to(message, """ + return bot.reply_to( + message, + """ {user_link} Welcome to Marzban Telegram-Bot Admin Panel. Here you can manage your users and proxies. To get started, use the buttons below. Also, You can get and modify users by /user command. -""".format( - user_link=user_link(message.from_user) - ), parse_mode="html", reply_markup=BotKeyboard.main_menu()) +""".format(user_link=user_link(message.from_user)), + parse_mode="html", + reply_markup=BotKeyboard.main_menu(), + ) -@bot.callback_query_handler(cb_query_equals('system'), is_admin=True) +@bot.callback_query_handler(cb_query_equals("system"), is_admin=True) def system_command(call: types.CallbackQuery): return bot.edit_message_text( get_system_info(), call.message.chat.id, call.message.message_id, parse_mode="MarkdownV2", - reply_markup=BotKeyboard.main_menu() + reply_markup=BotKeyboard.main_menu(), ) -@bot.callback_query_handler(cb_query_equals('restart'), is_admin=True) +@bot.callback_query_handler(cb_query_equals("restart"), is_admin=True) def restart_command(call: types.CallbackQuery): bot.edit_message_text( - '⚠️ Are you sure? This will restart Xray core.', + "⚠️ Are you sure? This will restart Xray core.", call.message.chat.id, call.message.message_id, - reply_markup=BotKeyboard.confirm_action(action='restart') + reply_markup=BotKeyboard.confirm_action(action="restart"), ) -@bot.callback_query_handler(cb_query_startswith('delete:'), is_admin=True) +@bot.callback_query_handler(cb_query_startswith("delete:"), is_admin=True) def delete_user_command(call: types.CallbackQuery): - username = call.data.split(':')[1] + username = call.data.split(":")[1] bot.edit_message_text( - f'⚠️ Are you sure? This will delete user `{username}`.', + f"⚠️ Are you sure? This will delete user `{username}`.", call.message.chat.id, call.message.message_id, parse_mode="markdown", - reply_markup=BotKeyboard.confirm_action( - action='delete', username=username) + reply_markup=BotKeyboard.confirm_action(action="delete", username=username), ) @@ -161,8 +163,7 @@ def suspend_user_command(call: types.CallbackQuery): call.message.chat.id, call.message.message_id, parse_mode="markdown", - reply_markup=BotKeyboard.confirm_action( - action="suspend", username=username), + reply_markup=BotKeyboard.confirm_action(action="suspend", username=username), ) @@ -174,8 +175,7 @@ def activate_user_command(call: types.CallbackQuery): call.message.chat.id, call.message.message_id, parse_mode="markdown", - reply_markup=BotKeyboard.confirm_action( - action="activate", username=username), + reply_markup=BotKeyboard.confirm_action(action="activate", username=username), ) @@ -187,12 +187,11 @@ def reset_usage_user_command(call: types.CallbackQuery): call.message.chat.id, call.message.message_id, parse_mode="markdown", - reply_markup=BotKeyboard.confirm_action( - action="reset_usage", username=username), + reply_markup=BotKeyboard.confirm_action(action="reset_usage", username=username), ) -@bot.callback_query_handler(cb_query_equals('edit_all'), is_admin=True) +@bot.callback_query_handler(cb_query_equals("edit_all"), is_admin=True) def edit_all_command(call: types.CallbackQuery): with GetDB() as db: total_users = crud.get_users_count(db) @@ -213,37 +212,40 @@ def edit_all_command(call: types.CallbackQuery): call.message.chat.id, call.message.message_id, parse_mode="markdown", - reply_markup=BotKeyboard.edit_all_menu() + reply_markup=BotKeyboard.edit_all_menu(), ) -@bot.callback_query_handler(cb_query_equals('delete_expired'), is_admin=True) +@bot.callback_query_handler(cb_query_equals("delete_expired"), is_admin=True) def delete_expired_command(call: types.CallbackQuery): bot.edit_message_text( - f"⚠️ Are you sure? This will *DELETE All Expired Users*‼️", + "⚠️ Are you sure? This will *DELETE All Expired Users*‼️", call.message.chat.id, call.message.message_id, parse_mode="markdown", - reply_markup=BotKeyboard.confirm_action(action="delete_expired")) + reply_markup=BotKeyboard.confirm_action(action="delete_expired"), + ) -@bot.callback_query_handler(cb_query_equals('delete_limited'), is_admin=True) +@bot.callback_query_handler(cb_query_equals("delete_limited"), is_admin=True) def delete_limited_command(call: types.CallbackQuery): bot.edit_message_text( - f"⚠️ Are you sure? This will *DELETE All Limited Users*‼️", + "⚠️ Are you sure? This will *DELETE All Limited Users*‼️", call.message.chat.id, call.message.message_id, parse_mode="markdown", - reply_markup=BotKeyboard.confirm_action(action="delete_limited")) + reply_markup=BotKeyboard.confirm_action(action="delete_limited"), + ) -@bot.callback_query_handler(cb_query_equals('add_data'), is_admin=True) +@bot.callback_query_handler(cb_query_equals("add_data"), is_admin=True) def add_data_command(call: types.CallbackQuery): msg = bot.edit_message_text( - f"🔋 Enter Data Limit to increase or decrease (GB):", + "🔋 Enter Data Limit to increase or decrease (GB):", call.message.chat.id, call.message.message_id, - reply_markup=BotKeyboard.inline_cancel_action()) + reply_markup=BotKeyboard.inline_cancel_action(), + ) schedule_delete_message(call.message.chat.id, call.message.id) schedule_delete_message(call.message.chat.id, msg.id) return bot.register_next_step_handler(call.message, add_data_step) @@ -255,7 +257,7 @@ def add_data_step(message): if not data_limit: raise ValueError except ValueError: - wait_msg = bot.send_message(message.chat.id, '❌ Data limit must be a number and not zero.') + wait_msg = bot.send_message(message.chat.id, "❌ Data limit must be a number and not zero.") schedule_delete_message(message.chat.id, wait_msg.message_id) return bot.register_next_step_handler(wait_msg, add_data_step) schedule_delete_message(message.chat.id, message.message_id) @@ -264,18 +266,20 @@ def add_data_step(message): f"⚠️ Are you sure? this will change Data limit of all users according to " f"{'+' if data_limit > 0 else '-'}{readable_size(abs(data_limit * 1024*1024*1024))}", parse_mode="html", - reply_markup=BotKeyboard.confirm_action('add_data', data_limit)) + reply_markup=BotKeyboard.confirm_action("add_data", data_limit), + ) cleanup_messages(message.chat.id) schedule_delete_message(message.chat.id, msg.id) -@bot.callback_query_handler(cb_query_equals('add_time'), is_admin=True) +@bot.callback_query_handler(cb_query_equals("add_time"), is_admin=True) def add_time_command(call: types.CallbackQuery): msg = bot.edit_message_text( - f"📅 Enter Days to increase or decrease expiry:", + "📅 Enter Days to increase or decrease expiry:", call.message.chat.id, call.message.message_id, - reply_markup=BotKeyboard.inline_cancel_action()) + reply_markup=BotKeyboard.inline_cancel_action(), + ) schedule_delete_message(call.message.chat.id, call.message.id) schedule_delete_message(call.message.chat.id, msg.id) return bot.register_next_step_handler(call.message, add_time_step) @@ -287,7 +291,7 @@ def add_time_step(message): if not days: raise ValueError except ValueError: - wait_msg = bot.send_message(message.chat.id, '❌ Days must be as a number and not zero.') + wait_msg = bot.send_message(message.chat.id, "❌ Days must be as a number and not zero.") schedule_delete_message(message.chat.id, wait_msg.message_id) return bot.register_next_step_handler(wait_msg, add_time_step) schedule_delete_message(message.chat.id, message.message_id) @@ -295,7 +299,8 @@ def add_time_step(message): message.chat.id, f"⚠️ Are you sure? this will change Expiry Time of all users according to {days} Days", parse_mode="html", - reply_markup=BotKeyboard.confirm_action('add_time', days)) + reply_markup=BotKeyboard.confirm_action("add_time", days), + ) cleanup_messages(message.chat.id) schedule_delete_message(message.chat.id, msg.id) @@ -307,7 +312,8 @@ def inbound_command(call: types.CallbackQuery): call.message.chat.id, call.message.message_id, parse_mode="markdown", - reply_markup=BotKeyboard.inbounds_menu(call.data, xray.config.inbounds_by_tag)) + reply_markup=BotKeyboard.inbounds_menu(call.data, xray.config.inbounds_by_tag), + ) @bot.callback_query_handler(cb_query_startswith("confirm_inbound"), is_admin=True) @@ -317,7 +323,8 @@ def delete_expired_confirm_command(call: types.CallbackQuery): call.message.chat.id, call.message.message_id, parse_mode="markdown", - reply_markup=BotKeyboard.confirm_action(action=call.data[8:])) + reply_markup=BotKeyboard.confirm_action(action=call.data[8:]), + ) @bot.callback_query_handler(cb_query_startswith("edit:"), is_admin=True) @@ -327,27 +334,23 @@ def edit_command(call: types.CallbackQuery): with GetDB() as db: db_user = crud.get_user(db, username) if not db_user: - return bot.answer_callback_query( - call.id, - '❌ User not found.', - show_alert=True - ) + return bot.answer_callback_query(call.id, "❌ User not found.", show_alert=True) user = UserResponse.model_validate(db_user) - mem_store.set(f'{call.message.chat.id}:username', username) - mem_store.set(f'{call.message.chat.id}:data_limit', db_user.data_limit) + mem_store.set(f"{call.message.chat.id}:username", username) + mem_store.set(f"{call.message.chat.id}:data_limit", db_user.data_limit) # if status is on_hold set expire_date to an integer that is duration else set a datetime if db_user.status == UserStatus.on_hold: - mem_store.set(f'{call.message.chat.id}:expire_date', db_user.on_hold_expire_duration) - mem_store.set(f'{call.message.chat.id}:expire_on_hold_timeout', db_user.on_hold_timeout) + mem_store.set(f"{call.message.chat.id}:expire_date", db_user.on_hold_expire_duration) + mem_store.set(f"{call.message.chat.id}:expire_on_hold_timeout", db_user.on_hold_timeout) expire_date = db_user.on_hold_expire_duration else: - mem_store.set(f'{call.message.chat.id}:expire_date', - db_user.expire if db_user.expire else None) + mem_store.set(f"{call.message.chat.id}:expire_date", db_user.expire if db_user.expire else None) expire_date = db_user.expire if db_user.expire else None mem_store.set( - f'{call.message.chat.id}:protocols', - {protocol.value: inbounds for protocol, inbounds in db_user.inbounds.items()}) + f"{call.message.chat.id}:protocols", + {protocol.value: inbounds for protocol, inbounds in db_user.inbounds.items()}, + ) bot.edit_message_text( f"📝 Editing user `{username}`", call.message.chat.id, @@ -360,21 +363,17 @@ def edit_command(call: types.CallbackQuery): data_limit=db_user.data_limit, expire_date=expire_date, expire_on_hold_duration=expire_date if isinstance(expire_date, int) else None, - expire_on_hold_timeout=mem_store.get(f'{call.message.chat.id}:expire_on_hold_timeout'), - ) + expire_on_hold_timeout=mem_store.get(f"{call.message.chat.id}:expire_on_hold_timeout"), + ), ) -@bot.callback_query_handler(cb_query_equals('help_edit'), is_admin=True) +@bot.callback_query_handler(cb_query_equals("help_edit"), is_admin=True) def help_edit_command(call: types.CallbackQuery): - bot.answer_callback_query( - call.id, - text="Press the (✏️ Edit) button to edit", - show_alert=True - ) + bot.answer_callback_query(call.id, text="Press the (✏️ Edit) button to edit", show_alert=True) -@bot.callback_query_handler(cb_query_equals('cancel'), is_admin=True) +@bot.callback_query_handler(cb_query_equals("cancel"), is_admin=True) def cancel_command(call: types.CallbackQuery): bot.clear_step_handler_by_chat_id(call.message.chat.id) return bot.edit_message_text( @@ -382,11 +381,11 @@ def cancel_command(call: types.CallbackQuery): call.message.chat.id, call.message.message_id, parse_mode="MarkdownV2", - reply_markup=BotKeyboard.main_menu() + reply_markup=BotKeyboard.main_menu(), ) -@bot.callback_query_handler(cb_query_startswith('edit_user:'), is_admin=True) +@bot.callback_query_handler(cb_query_startswith("edit_user:"), is_admin=True) def edit_user_command(call: types.CallbackQuery): _, username, action = call.data.split(":") schedule_delete_message(call.message.chat.id, call.message.id) @@ -395,13 +394,12 @@ def edit_user_command(call: types.CallbackQuery): if action == "data": msg = bot.send_message( call.message.chat.id, - '📶 Enter Data Limit (GB):\n⚠️ Send 0 for unlimited.', - reply_markup=BotKeyboard.inline_cancel_action(f'user:{username}') + "📶 Enter Data Limit (GB):\n⚠️ Send 0 for unlimited.", + reply_markup=BotKeyboard.inline_cancel_action(f"user:{username}"), ) mem_store.set(f"{call.message.chat.id}:edit_msg_text", call.message.text) bot.clear_step_handler_by_chat_id(call.message.chat.id) - bot.register_next_step_handler( - call.message, edit_user_data_limit_step, username) + bot.register_next_step_handler(call.message, edit_user_data_limit_step, username) schedule_delete_message(call.message.chat.id, msg.message_id) elif action == "expire": text = """\ @@ -419,13 +417,13 @@ def edit_user_command(call: types.CallbackQuery): call.message.chat.id, text, parse_mode="markdown", - reply_markup=BotKeyboard.inline_cancel_action(f'user:{username}')) + reply_markup=BotKeyboard.inline_cancel_action(f"user:{username}"), + ) mem_store.set(f"{call.message.chat.id}:edit_msg_text", call.message.text) bot.clear_step_handler_by_chat_id(call.message.chat.id) - bot.register_next_step_handler( - call.message, edit_user_expire_step, username=username) + bot.register_next_step_handler(call.message, edit_user_expire_step, username=username) schedule_delete_message(call.message.chat.id, msg.message_id) - elif action == 'expire_on_hold_timeout': + elif action == "expire_on_hold_timeout": text = """\ 📅 Enter Timeout for on hold `3d` for 3 days @@ -436,7 +434,8 @@ def edit_user_command(call: types.CallbackQuery): call.message.chat.id, text, parse_mode="markdown", - reply_markup=BotKeyboard.inline_cancel_action(f'user:{username}')) + reply_markup=BotKeyboard.inline_cancel_action(f"user:{username}"), + ) bot.clear_step_handler_by_chat_id(call.message.chat.id) bot.register_next_step_handler(call.message, edit_user_expire_on_hold_timeout_step, username=username) schedule_delete_message(call.message.chat.id, msg.message_id) @@ -446,13 +445,13 @@ def edit_user_expire_on_hold_timeout_step(message: types.Message, username: str) try: now = datetime.now() today = datetime(year=now.year, month=now.month, day=now.day, hour=23, minute=59, second=59) - if re.match(r'^[0-9]{1,3}([MmDd])$', message.text): + if re.match(r"^[0-9]{1,3}([MmDd])$", message.text): expire_on_hold_timeout = today - number = int(re.findall(r'^[0-9]{1,3}', message.text)[0]) - symbol = re.findall('[MmDd]$', message.text)[0].upper() - if symbol == 'M': + number = int(re.findall(r"^[0-9]{1,3}", message.text)[0]) + symbol = re.findall("[MmDd]$", message.text)[0].upper() + if symbol == "M": expire_on_hold_timeout = today + relativedelta(months=number) - elif symbol == 'D': + elif symbol == "D": expire_on_hold_timeout = today + relativedelta(days=number) elif not message.text.isnumeric(): expire_on_hold_timeout = datetime.strptime(message.text, "%Y-%m-%d") @@ -461,15 +460,15 @@ def edit_user_expire_on_hold_timeout_step(message: types.Message, username: str) else: raise ValueError if expire_on_hold_timeout and expire_on_hold_timeout < today: - wait_msg = bot.send_message(message.chat.id, '❌ Expire date must be greater than today.') + wait_msg = bot.send_message(message.chat.id, "❌ Expire date must be greater than today.") schedule_delete_message(message.chat.id, wait_msg.message_id) return bot.register_next_step_handler(wait_msg, edit_user_expire_on_hold_timeout_step, username=username) except ValueError: - wait_msg = bot.send_message(message.chat.id, '❌ Date is not in any of valid formats.') + wait_msg = bot.send_message(message.chat.id, "❌ Date is not in any of valid formats.") schedule_delete_message(message.chat.id, wait_msg.message_id) return bot.register_next_step_handler(wait_msg, edit_user_expire_on_hold_timeout_step, username=username) - mem_store.set(f'{message.chat.id}:expire_on_hold_timeout', expire_on_hold_timeout) + mem_store.set(f"{message.chat.id}:expire_on_hold_timeout", expire_on_hold_timeout) expire_date = mem_store.get(f"{message.chat.id}:expire_date") schedule_delete_message(message.chat.id, message.message_id) bot.send_message( @@ -477,11 +476,13 @@ def edit_user_expire_on_hold_timeout_step(message: types.Message, username: str) f"📝 Editing user: {username}", parse_mode="html", reply_markup=BotKeyboard.select_protocols( - mem_store.get(f'{message.chat.id}:protocols'), "edit", - username=username, data_limit=mem_store.get(f'{message.chat.id}:data_limit'), + mem_store.get(f"{message.chat.id}:protocols"), + "edit", + username=username, + data_limit=mem_store.get(f"{message.chat.id}:data_limit"), expire_on_hold_duration=expire_date if isinstance(expire_date, int) else None, - expire_on_hold_timeout=mem_store.get(f'{message.chat.id}:expire_on_hold_timeout') - ) + expire_on_hold_timeout=mem_store.get(f"{message.chat.id}:expire_on_hold_timeout"), + ), ) cleanup_messages(message.chat.id) @@ -489,15 +490,15 @@ def edit_user_expire_on_hold_timeout_step(message: types.Message, username: str) def edit_user_data_limit_step(message: types.Message, username: str): try: if float(message.text) < 0: - wait_msg = bot.send_message(message.chat.id, '❌ Data limit must be greater or equal to 0.') + wait_msg = bot.send_message(message.chat.id, "❌ Data limit must be greater or equal to 0.") schedule_delete_message(message.chat.id, wait_msg.message_id) return bot.register_next_step_handler(wait_msg, edit_user_data_limit_step, username=username) data_limit = float(message.text) * 1024 * 1024 * 1024 except ValueError: - wait_msg = bot.send_message(message.chat.id, '❌ Data limit must be a number.') + wait_msg = bot.send_message(message.chat.id, "❌ Data limit must be a number.") schedule_delete_message(message.chat.id, wait_msg.message_id) return bot.register_next_step_handler(wait_msg, edit_user_data_limit_step, username=username) - mem_store.set(f'{message.chat.id}:data_limit', data_limit) + mem_store.set(f"{message.chat.id}:data_limit", data_limit) schedule_delete_message(message.chat.id, message.message_id) text = mem_store.get(f"{message.chat.id}:edit_msg_text") mem_store.delete(f"{message.chat.id}:edit_msg_text") @@ -506,27 +507,32 @@ def edit_user_data_limit_step(message: types.Message, username: str): text or f"📝 Editing user {username}", parse_mode="html", reply_markup=BotKeyboard.select_protocols( - mem_store.get(f'{message.chat.id}:protocols'), "edit", - username=username, data_limit=data_limit, expire_date=mem_store.get(f'{message.chat.id}:expire_date'))) + mem_store.get(f"{message.chat.id}:protocols"), + "edit", + username=username, + data_limit=data_limit, + expire_date=mem_store.get(f"{message.chat.id}:expire_date"), + ), + ) cleanup_messages(message.chat.id) def edit_user_expire_step(message: types.Message, username: str): - last_expiry = mem_store.get(f'{message.chat.id}:expire_date') + last_expiry = mem_store.get(f"{message.chat.id}:expire_date") try: now = datetime.now() today = datetime(year=now.year, month=now.month, day=now.day, hour=23, minute=59, second=59) - if re.match(r'^[0-9]{1,3}([MmDd])$', message.text): + if re.match(r"^[0-9]{1,3}([MmDd])$", message.text): expire_date = today - number_pattern = r'^[0-9]{1,3}' + number_pattern = r"^[0-9]{1,3}" number = int(re.findall(number_pattern, message.text)[0]) - symbol_pattern = r'[MmDd]$' + symbol_pattern = r"[MmDd]$" symbol = re.findall(symbol_pattern, message.text)[0].upper() - if symbol == 'M': + if symbol == "M": expire_date = today + relativedelta(months=number) if isinstance(last_expiry, int): expire_date = number * 24 * 60 * 60 * 30 - elif symbol == 'D': + elif symbol == "D": expire_date = today + relativedelta(days=number) if isinstance(last_expiry, int): expire_date = number * 24 * 60 * 60 @@ -537,15 +543,15 @@ def edit_user_expire_step(message: types.Message, username: str): else: raise ValueError if expire_date and isinstance(expire_date, datetime) and expire_date < today: - wait_msg = bot.send_message(message.chat.id, '❌ Expire date must be greater than today.') + wait_msg = bot.send_message(message.chat.id, "❌ Expire date must be greater than today.") schedule_delete_message(message.chat.id, wait_msg.message_id) return bot.register_next_step_handler(wait_msg, edit_user_expire_step, username=username) except ValueError: - wait_msg = bot.send_message(message.chat.id, '❌ Date is not in any of valid formats.') + wait_msg = bot.send_message(message.chat.id, "❌ Date is not in any of valid formats.") schedule_delete_message(message.chat.id, wait_msg.message_id) return bot.register_next_step_handler(wait_msg, edit_user_expire_step, username=username) - mem_store.set(f'{message.chat.id}:expire_date', expire_date) + mem_store.set(f"{message.chat.id}:expire_date", expire_date) schedule_delete_message(message.chat.id, message.message_id) text = mem_store.get(f"{message.chat.id}:edit_msg_text") mem_store.delete(f"{message.chat.id}:edit_msg_text") @@ -554,17 +560,21 @@ def edit_user_expire_step(message: types.Message, username: str): text or f"📝 Editing user: {username}", parse_mode="html", reply_markup=BotKeyboard.select_protocols( - mem_store.get(f'{message.chat.id}:protocols'), "edit", - username=username, data_limit=mem_store.get(f'{message.chat.id}:data_limit'), + mem_store.get(f"{message.chat.id}:protocols"), + "edit", + username=username, + data_limit=mem_store.get(f"{message.chat.id}:data_limit"), expire_date=expire_date, expire_on_hold_duration=expire_date if isinstance(expire_date, int) else None, - expire_on_hold_timeout=mem_store.get(f'{message.chat.id}:expire_on_hold_timeout'))) + expire_on_hold_timeout=mem_store.get(f"{message.chat.id}:expire_on_hold_timeout"), + ), + ) cleanup_messages(message.chat.id) -@bot.callback_query_handler(cb_query_startswith('users:'), is_admin=True) +@bot.callback_query_handler(cb_query_startswith("users:"), is_admin=True) def users_command(call: types.CallbackQuery): - page = int(call.data.split(':')[1]) if len(call.data.split(':')) > 1 else 1 + page = int(call.data.split(":")[1]) if len(call.data.split(":")) > 1 else 1 with GetDB() as db: total_pages = math.ceil(crud.get_users_count(db) / 10) users = crud.get_users(db, offset=(page - 1) * 10, limit=10, sort=[crud.UsersSortingOptions["-created_at"]]) @@ -580,50 +590,53 @@ def users_command(call: types.CallbackQuery): call.message.chat.id, call.message.message_id, parse_mode="HTML", - reply_markup=BotKeyboard.user_list( - users, page, total_pages=total_pages) + reply_markup=BotKeyboard.user_list(users, page, total_pages=total_pages), ) -@bot.callback_query_handler(cb_query_startswith('edit_note:'), is_admin=True) +@bot.callback_query_handler(cb_query_startswith("edit_note:"), is_admin=True) def edit_note_command(call: types.CallbackQuery): - username = call.data.split(':')[1] + username = call.data.split(":")[1] with GetDB() as db: db_user = crud.get_user(db, username) if not db_user: - return bot.answer_callback_query(call.id, '❌ User not found.', show_alert=True) + return bot.answer_callback_query(call.id, "❌ User not found.", show_alert=True) schedule_delete_message(call.message.chat.id, call.message.id) cleanup_messages(call.message.chat.id) msg = bot.send_message( call.message.chat.id, - f'📝 Current Note: {db_user.note}\n\nSend new Note for {username}', + f"📝 Current Note: {db_user.note}\n\nSend new Note for {username}", parse_mode="HTML", - reply_markup=BotKeyboard.inline_cancel_action(f'user:{username}')) - mem_store.set(f'{call.message.chat.id}:username', username) + reply_markup=BotKeyboard.inline_cancel_action(f"user:{username}"), + ) + mem_store.set(f"{call.message.chat.id}:username", username) schedule_delete_message(call.message.chat.id, msg.id) bot.register_next_step_handler(msg, edit_note_step) def edit_note_step(message: types.Message): - note = message.text or '' + note = message.text or "" if len(note) > 500: - wait_msg = bot.send_message(message.chat.id, '❌ Note can not be more than 500 characters.') + wait_msg = bot.send_message(message.chat.id, "❌ Note can not be more than 500 characters.") schedule_delete_message(message.chat.id, wait_msg.id) schedule_delete_message(message.chat.id, message.id) return bot.register_next_step_handler(wait_msg, edit_note_step) with GetDB() as db: - username = mem_store.get(f'{message.chat.id}:username') + username = mem_store.get(f"{message.chat.id}:username") if not username: cleanup_messages(message.chat.id) - bot.reply_to(message, '❌ Something went wrong!\n restart bot /start') + bot.reply_to(message, "❌ Something went wrong!\n restart bot /start") db_user = crud.get_user(db, username) last_note = db_user.note modify = UserModify(note=note) db_user = crud.update_user(db, db_user, modify) user = UserResponse.model_validate(db_user) bot.reply_to( - message, get_user_info_text(db_user), parse_mode="html", - reply_markup=BotKeyboard.user_menu(user_info={'status': user.status, 'username': user.username})) + message, + get_user_info_text(db_user), + parse_mode="html", + reply_markup=BotKeyboard.user_menu(user_info={"status": user.status, "username": user.username}), + ) if TELEGRAM_LOGGER_CHANNEL_ID: text = f"""\ 📝 #Edit_Note #From_Bot @@ -634,25 +647,28 @@ def edit_note_step(message: types.Message): ➖➖➖➖➖➖➖➖➖ By : {message.from_user.full_name}""" try: - bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, 'HTML') + bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, "HTML") except ApiTelegramException: pass -@bot.callback_query_handler(cb_query_startswith('user:'), is_admin=True) +@bot.callback_query_handler(cb_query_startswith("user:"), is_admin=True) def user_command(call: types.CallbackQuery): bot.clear_step_handler_by_chat_id(call.message.chat.id) - username = call.data.split(':')[1] - page = int(call.data.split(':')[2]) if len(call.data.split(':')) > 2 else 1 + username = call.data.split(":")[1] + page = int(call.data.split(":")[2]) if len(call.data.split(":")) > 2 else 1 with GetDB() as db: db_user = crud.get_user(db, username) if not db_user: - return bot.answer_callback_query(call.id, '❌ User not found.', show_alert=True) + return bot.answer_callback_query(call.id, "❌ User not found.", show_alert=True) user = UserResponse.model_validate(db_user) bot.edit_message_text( get_user_info_text(db_user), - call.message.chat.id, call.message.message_id, parse_mode="HTML", - reply_markup=BotKeyboard.user_menu({'username': user.username, 'status': user.status}, page=page)) + call.message.chat.id, + call.message.message_id, + parse_mode="HTML", + reply_markup=BotKeyboard.user_menu({"username": user.username, "status": user.status}, page=page), + ) @bot.callback_query_handler(cb_query_startswith("revoke_sub:"), is_admin=True) @@ -663,7 +679,8 @@ def revoke_sub_command(call: types.CallbackQuery): call.message.chat.id, call.message.message_id, parse_mode="markdown", - reply_markup=BotKeyboard.confirm_action(action=call.data)) + reply_markup=BotKeyboard.confirm_action(action=call.data), + ) @bot.callback_query_handler(cb_query_startswith("links:"), is_admin=True) @@ -680,16 +697,16 @@ def links_command(call: types.CallbackQuery): text = f"{user.subscription_url}\n\n\n" for link in get_v2ray_links(user): if len(text) > 4056: - text += '\n\n...' + text += "\n\n..." break - text += f'\n{link}' + text += f"\n{link}" bot.edit_message_text( text, call.message.chat.id, call.message.message_id, parse_mode="HTML", - reply_markup=BotKeyboard.show_links(username) + reply_markup=BotKeyboard.show_links(username), ) @@ -707,19 +724,14 @@ def genqr_command(call: types.CallbackQuery): bot.answer_callback_query(call.id, "Generating QR code...") - if qr_select == 'configs': + if qr_select == "configs": for link in get_v2ray_links(user): f = io.BytesIO() qr = qrcode.QRCode(border=6) qr.add_data(link) qr.make_image().save(f) f.seek(0) - bot.send_photo( - call.message.chat.id, - photo=f, - caption=f"{link}", - parse_mode="HTML" - ) + bot.send_photo(call.message.chat.id, photo=f, caption=f"{link}", parse_mode="HTML") else: data_limit = readable_size(user.data_limit) if user.data_limit else "Unlimited" used_traffic = readable_size(user.used_traffic) if user.used_traffic else "-" @@ -753,7 +765,7 @@ def genqr_command(call: types.CallbackQuery): photo=f, caption=text, parse_mode="HTML", - reply_markup=BotKeyboard.subscription_page(user.subscription_url) + reply_markup=BotKeyboard.subscription_page(user.subscription_url), ) try: bot.delete_message(call.message.chat.id, call.message.message_id) @@ -763,19 +775,14 @@ def genqr_command(call: types.CallbackQuery): text = f"{user.subscription_url}\n\n\n" for link in get_v2ray_links(user): if len(text) > 4056: - text += '\n\n...' + text += "\n\n..." break text += f"{link}\n\n" - bot.send_message( - call.message.chat.id, - text, - "HTML", - reply_markup=BotKeyboard.show_links(username) - ) + bot.send_message(call.message.chat.id, text, "HTML", reply_markup=BotKeyboard.show_links(username)) -@bot.callback_query_handler(cb_query_startswith('template_charge:'), is_admin=True) +@bot.callback_query_handler(cb_query_startswith("template_charge:"), is_admin=True) def template_charge_command(call: types.CallbackQuery): _, template_id, username = call.data.split(":") now = datetime.now() @@ -791,21 +798,28 @@ def template_charge_command(call: types.CallbackQuery): return bot.answer_callback_query(call.id, "User not found!", show_alert=True) user = UserResponse.model_validate(db_user) if (user.data_limit and not user.expire) or (not user.data_limit and user.expire): - expire = (db_user.expire if db_user.expire else today) + expire = db_user.expire if db_user.expire else today expire += relativedelta(seconds=template.expire_duration) db_user.expire = expire.timestamp() - db_user.data_limit = (user.data_limit - user.used_traffic + template.data_limit - ) if user.data_limit else template.data_limit + db_user.data_limit = ( + (user.data_limit - user.used_traffic + template.data_limit) if user.data_limit else template.data_limit + ) db_user.status = UserStatus.active bot.edit_message_text( f"""\ ‼️ If add template Data limit and Time to the user, the user will be this:\n\n\ {get_user_info_text(db_user)}\n\n\ Add template Data limit and Time to user or Reset to Template default⁉️""", - call.message.chat.id, call.message.message_id, parse_mode='html', - reply_markup=BotKeyboard.charge_add_or_reset( - username=username, template_id=template_id)) - elif (not user.data_limit and not user.expire) or (user.used_traffic > user.data_limit) or (now > datetime.fromtimestamp(user.expire)): + call.message.chat.id, + call.message.message_id, + parse_mode="html", + reply_markup=BotKeyboard.charge_add_or_reset(username=username, template_id=template_id), + ) + elif ( + (not user.data_limit and not user.expire) + or (user.used_traffic > user.data_limit) + or (now > datetime.fromtimestamp(user.expire)) + ): crud.reset_user_data_usage(db, db_user) expire_date = None if template.expire_duration: @@ -822,8 +836,9 @@ def template_charge_command(call: types.CallbackQuery): get_user_info_text(db_user), call.message.chat.id, call.message.message_id, - parse_mode='html', - reply_markup=BotKeyboard.user_menu(user_info={'status': 'active', 'username': user.username})) + parse_mode="html", + reply_markup=BotKeyboard.user_menu(user_info={"status": "active", "username": user.username}), + ) if TELEGRAM_LOGGER_CHANNEL_ID: text = f"""\ 🔋 #Charged #Reset #From_Bot @@ -843,27 +858,30 @@ def template_charge_command(call: types.CallbackQuery): ➖➖➖➖➖➖➖➖➖ By : {call.from_user.full_name}""" try: - bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, 'HTML') + bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, "HTML") except ApiTelegramException: pass else: - expire = (db_user.expire if db_user.expire else today) + expire = db_user.expire if db_user.expire else today expire += relativedelta(seconds=template.expire_duration) db_user.expire = expire.timestamp() - db_user.data_limit = (user.data_limit - user.used_traffic + template.data_limit - ) if user.data_limit else template.data_limit + db_user.data_limit = ( + (user.data_limit - user.used_traffic + template.data_limit) if user.data_limit else template.data_limit + ) db_user.status = UserStatus.active bot.edit_message_text( f"""\ ‼️ If add template Data limit and Time to the user, the user will be this:\n\n\ {get_user_info_text(db_user)}\n\n\ Add template Data limit and Time to user or Reset to Template default⁉️""", - call.message.chat.id, call.message.message_id, parse_mode='html', - reply_markup=BotKeyboard.charge_add_or_reset( - username=username, template_id=template_id)) + call.message.chat.id, + call.message.message_id, + parse_mode="html", + reply_markup=BotKeyboard.charge_add_or_reset(username=username, template_id=template_id), + ) -@bot.callback_query_handler(cb_query_startswith('charge:'), is_admin=True) +@bot.callback_query_handler(cb_query_startswith("charge:"), is_admin=True) def charge_command(call: types.CallbackQuery): username = call.data.split(":")[1] with GetDB() as db: @@ -879,16 +897,16 @@ def charge_command(call: types.CallbackQuery): f"{call.message.html_text}\n\n🔢 Select User Template to charge:", call.message.chat.id, call.message.message_id, - parse_mode='html', + parse_mode="html", reply_markup=BotKeyboard.templates_menu( {template.name: template.id for template in templates}, username=username, - ) + ), ) -@bot.callback_query_handler(cb_query_equals('template_add_user'), is_admin=True) -@bot.callback_query_handler(cb_query_equals('template_add_bulk_user'), is_admin=True) +@bot.callback_query_handler(cb_query_equals("template_add_user"), is_admin=True) +@bot.callback_query_handler(cb_query_equals("template_add_bulk_user"), is_admin=True) def add_user_from_template_command(call: types.CallbackQuery): with GetDB() as db: templates = crud.get_user_templates(db) @@ -906,12 +924,12 @@ def add_user_from_template_command(call: types.CallbackQuery): "Select a Template to create user from:", call.message.chat.id, call.message.message_id, - parse_mode='html', - reply_markup=BotKeyboard.templates_menu({template.name: template.id for template in templates}) + parse_mode="html", + reply_markup=BotKeyboard.templates_menu({template.name: template.id for template in templates}), ) -@bot.callback_query_handler(cb_query_startswith('template_add_user:'), is_admin=True) +@bot.callback_query_handler(cb_query_startswith("template_add_user:"), is_admin=True) def add_user_from_template(call: types.CallbackQuery): template_id = int(call.data.split(":")[1]) with GetDB() as db: @@ -927,45 +945,42 @@ def add_user_from_template(call: types.CallbackQuery): text += f"\n⚠️ Username will be suffixed with {template.username_suffix}" mem_store.set(f"{call.message.chat.id}:template_id", template.id) - template_msg = bot.edit_message_text( - text, - call.message.chat.id, - call.message.message_id, - parse_mode="HTML" - ) - text = '👤 Enter username:\n⚠️ Username only can be 3 to 32 characters and contain a-z, A-Z, 0-9, and underscores in between.' + template_msg = bot.edit_message_text(text, call.message.chat.id, call.message.message_id, parse_mode="HTML") + text = "👤 Enter username:\n⚠️ Username only can be 3 to 32 characters and contain a-z, A-Z, 0-9, and underscores in between." msg = bot.send_message( - call.message.chat.id, - text, - parse_mode="HTML", - reply_markup=BotKeyboard.random_username(template_id=template.id) + call.message.chat.id, text, parse_mode="HTML", reply_markup=BotKeyboard.random_username(template_id=template.id) ) schedule_delete_message(call.message.chat.id, template_msg.message_id, msg.id) bot.register_next_step_handler(template_msg, add_user_from_template_username_step) -@bot.callback_query_handler(cb_query_startswith('random'), is_admin=True) +@bot.callback_query_handler(cb_query_startswith("random"), is_admin=True) def random_username(call: types.CallbackQuery): bot.clear_step_handler_by_chat_id(call.message.chat.id) template_id = int(call.data.split(":")[1] or 0) - mem_store.delete(f'{call.message.chat.id}:template_id') + mem_store.delete(f"{call.message.chat.id}:template_id") - username = ''.join([random.choice(string.ascii_letters)] + - random.choices(string.ascii_letters + string.digits, k=7)) + username = "".join( + [random.choice(string.ascii_letters)] + random.choices(string.ascii_letters + string.digits, k=7) + ) schedule_delete_message(call.message.chat.id, call.message.id) cleanup_messages(call.message.chat.id) - if mem_store.get(f"{call.message.chat.id}:is_bulk", False) and not mem_store.get(f"{call.message.chat.id}:is_bulk_from_template", False): - msg = bot.send_message(call.message.chat.id, - 'how many do you want?', - reply_markup=BotKeyboard.inline_cancel_action()) + if mem_store.get(f"{call.message.chat.id}:is_bulk", False) and not mem_store.get( + f"{call.message.chat.id}:is_bulk_from_template", False + ): + msg = bot.send_message( + call.message.chat.id, "how many do you want?", reply_markup=BotKeyboard.inline_cancel_action() + ) schedule_delete_message(call.message.chat.id, msg.id) return bot.register_next_step_handler(msg, add_user_bulk_number_step, username=username) if not template_id: - msg = bot.send_message(call.message.chat.id, - '⬆️ Enter Data Limit (GB):\n⚠️ Send 0 for unlimited.', - reply_markup=BotKeyboard.inline_cancel_action()) + msg = bot.send_message( + call.message.chat.id, + "⬆️ Enter Data Limit (GB):\n⚠️ Send 0 for unlimited.", + reply_markup=BotKeyboard.inline_cancel_action(), + ) schedule_delete_message(call.message.chat.id, msg.id) return bot.register_next_step_handler(call.message, add_user_data_limit_step, username=username) @@ -992,17 +1007,18 @@ def random_username(call: types.CallbackQuery): mem_store.set(f"{call.message.chat.id}:template_info_text", text) if mem_store.get(f"{call.message.chat.id}:is_bulk", False): - msg = bot.send_message(call.message.chat.id, - 'how many do you want?', - reply_markup=BotKeyboard.inline_cancel_action()) + msg = bot.send_message( + call.message.chat.id, "how many do you want?", reply_markup=BotKeyboard.inline_cancel_action() + ) schedule_delete_message(call.message.chat.id, msg.id) return bot.register_next_step_handler(msg, add_user_bulk_number_step, username=username) else: if expire_date: msg = bot.send_message( call.message.chat.id, - '⚡ Select User Status:\nOn Hold: Expiration starts after the first connection\nActive: Expiration starts from now', - reply_markup=BotKeyboard.user_status_select()) + "⚡ Select User Status:\nOn Hold: Expiration starts after the first connection\nActive: Expiration starts from now", + reply_markup=BotKeyboard.user_status_select(), + ) schedule_delete_message(call.message.chat.id, msg.id) else: mem_store.set(f"{call.message.chat.id}:template_info_text", None) @@ -1016,7 +1032,9 @@ def random_username(call: types.CallbackQuery): "create_from_template", username=username, data_limit=template.data_limit, - expire_date=expire_date,)) + expire_date=expire_date, + ), + ) def add_user_from_template_username_step(message: types.Message): @@ -1025,7 +1043,7 @@ def add_user_from_template_username_step(message: types.Message): return bot.send_message(message.chat.id, "An error occurred in the process! try again.") if not message.text: - wait_msg = bot.send_message(message.chat.id, '❌ Username can not be empty.') + wait_msg = bot.send_message(message.chat.id, "❌ Username can not be empty.") schedule_delete_message(message.chat.id, wait_msg.message_id, message.message_id) return bot.register_next_step_handler(wait_msg, add_user_from_template_username_step) @@ -1042,7 +1060,8 @@ def add_user_from_template_username_step(message: types.Message): if not match: wait_msg = bot.send_message( message.chat.id, - '❌ Username only can be 3 to 32 characters and contain a-z, A-Z, 0-9, and underscores in between.') + "❌ Username only can be 3 to 32 characters and contain a-z, A-Z, 0-9, and underscores in between.", + ) schedule_delete_message(message.chat.id, wait_msg.message_id, message.message_id) return bot.register_next_step_handler(wait_msg, add_user_from_template_username_step) @@ -1051,7 +1070,8 @@ def add_user_from_template_username_step(message: types.Message): message.chat.id, f"❌ Username can't be generated because is shorter than 32 characters! username: { username}", - parse_mode="HTML") + parse_mode="HTML", + ) schedule_delete_message(message.chat.id, wait_msg.message_id, message.message_id) return bot.register_next_step_handler(wait_msg, add_user_from_template_username_step) elif len(username) > 32: @@ -1059,12 +1079,13 @@ def add_user_from_template_username_step(message: types.Message): message.chat.id, f"❌ Username can't be generated because is longer than 32 characters! username: { username}", - parse_mode="HTML") + parse_mode="HTML", + ) schedule_delete_message(message.chat.id, wait_msg.message_id, message.message_id) return bot.register_next_step_handler(wait_msg, add_user_from_template_username_step) if crud.get_user(db, username): - wait_msg = bot.send_message(message.chat.id, '❌ Username already exists.') + wait_msg = bot.send_message(message.chat.id, "❌ Username already exists.") schedule_delete_message(message.chat.id, wait_msg.message_id, message.message_id) return bot.register_next_step_handler(wait_msg, add_user_from_template_username_step) template = UserTemplateResponse.model_validate(template) @@ -1083,17 +1104,18 @@ def add_user_from_template_username_step(message: types.Message): mem_store.set(f"{message.chat.id}:template_info_text", text) if mem_store.get(f"{message.chat.id}:is_bulk", False): - msg = bot.send_message(message.chat.id, - 'how many do you want?', - reply_markup=BotKeyboard.inline_cancel_action()) + msg = bot.send_message( + message.chat.id, "how many do you want?", reply_markup=BotKeyboard.inline_cancel_action() + ) schedule_delete_message(message.chat.id, msg.id) return bot.register_next_step_handler(msg, add_user_bulk_number_step, username=username) else: if expire_date: msg = bot.send_message( message.chat.id, - '⚡ Select User Status:\nOn Hold: Expiration starts after the first connection\nActive: Expiration starts from now', - reply_markup=BotKeyboard.user_status_select()) + "⚡ Select User Status:\nOn Hold: Expiration starts after the first connection\nActive: Expiration starts from now", + reply_markup=BotKeyboard.user_status_select(), + ) schedule_delete_message(message.chat.id, msg.id) else: mem_store.set(f"{message.chat.id}:template_info_text", None) @@ -1107,11 +1129,13 @@ def add_user_from_template_username_step(message: types.Message): "create_from_template", username=username, data_limit=template.data_limit, - expire_date=expire_date,)) + expire_date=expire_date, + ), + ) -@bot.callback_query_handler(cb_query_equals('add_bulk_user'), is_admin=True) -@bot.callback_query_handler(cb_query_equals('add_user'), is_admin=True) +@bot.callback_query_handler(cb_query_equals("add_bulk_user"), is_admin=True) +@bot.callback_query_handler(cb_query_equals("add_user"), is_admin=True) def add_user_command(call: types.CallbackQuery): try: bot.delete_message(call.message.chat.id, call.message.message_id) @@ -1127,9 +1151,10 @@ def add_user_command(call: types.CallbackQuery): username_msg = bot.send_message( call.message.chat.id, - '👤 Enter username:\n⚠️Username only can be 3 to 32 characters and contain a-z, A-Z 0-9, and underscores in ' - 'between.', - reply_markup=BotKeyboard.random_username()) + "👤 Enter username:\n⚠️Username only can be 3 to 32 characters and contain a-z, A-Z 0-9, and underscores in " + "between.", + reply_markup=BotKeyboard.random_username(), + ) schedule_delete_message(call.message.chat.id, username_msg.id) bot.register_next_step_handler(username_msg, add_user_username_step) @@ -1137,34 +1162,37 @@ def add_user_command(call: types.CallbackQuery): def add_user_username_step(message: types.Message): username = message.text if not username: - wait_msg = bot.send_message(message.chat.id, '❌ Username can not be empty.') + wait_msg = bot.send_message(message.chat.id, "❌ Username can not be empty.") schedule_delete_message(message.chat.id, wait_msg.id) schedule_delete_message(message.chat.id, message.id) return bot.register_next_step_handler(wait_msg, add_user_username_step) if not re.match(r"^(?=\w{3,32}\b)[a-zA-Z0-9-_@.]+(?:_[a-zA-Z0-9-_@.]+)*$", username): wait_msg = bot.send_message( message.chat.id, - '❌ Username only can be 3 to 32 characters and contain a-z, A-Z, 0-9, and underscores in between.') + "❌ Username only can be 3 to 32 characters and contain a-z, A-Z, 0-9, and underscores in between.", + ) schedule_delete_message(message.chat.id, wait_msg.id) schedule_delete_message(message.chat.id, message.id) return bot.register_next_step_handler(wait_msg, add_user_username_step) with GetDB() as db: if crud.get_user(db, username): - wait_msg = bot.send_message(message.chat.id, '❌ Username already exists.') + wait_msg = bot.send_message(message.chat.id, "❌ Username already exists.") schedule_delete_message(message.chat.id, wait_msg.id) schedule_delete_message(message.chat.id, message.id) return bot.register_next_step_handler(wait_msg, add_user_username_step) schedule_delete_message(message.chat.id, message.id) cleanup_messages(message.chat.id) if mem_store.get(f"{message.chat.id}:is_bulk", False): - msg = bot.send_message(message.chat.id, - 'how many do you want?', - reply_markup=BotKeyboard.inline_cancel_action()) + msg = bot.send_message( + message.chat.id, "how many do you want?", reply_markup=BotKeyboard.inline_cancel_action() + ) schedule_delete_message(message.chat.id, msg.id) return bot.register_next_step_handler(msg, add_user_bulk_number_step, username=username) - msg = bot.send_message(message.chat.id, - '⬆️ Enter Data Limit (GB):\n⚠️ Send 0 for unlimited.', - reply_markup=BotKeyboard.inline_cancel_action()) + msg = bot.send_message( + message.chat.id, + "⬆️ Enter Data Limit (GB):\n⚠️ Send 0 for unlimited.", + reply_markup=BotKeyboard.inline_cancel_action(), + ) schedule_delete_message(message.chat.id, msg.id) bot.register_next_step_handler(msg, add_user_data_limit_step, username=username) @@ -1172,13 +1200,13 @@ def add_user_username_step(message: types.Message): def add_user_bulk_number_step(message: types.Message, username: str): try: if int(message.text) < 1: - wait_msg = bot.send_message(message.chat.id, '❌ Bulk number must be greater or equal to 1.') + wait_msg = bot.send_message(message.chat.id, "❌ Bulk number must be greater or equal to 1.") schedule_delete_message(message.chat.id, wait_msg.id) schedule_delete_message(message.chat.id, message.id) return bot.register_next_step_handler(wait_msg, add_user_bulk_number_step, username=username) - mem_store.set(f'{message.chat.id}:number', int(message.text)) + mem_store.set(f"{message.chat.id}:number", int(message.text)) except ValueError: - wait_msg = bot.send_message(message.chat.id, '❌ bulk must be a number.') + wait_msg = bot.send_message(message.chat.id, "❌ bulk must be a number.") schedule_delete_message(message.chat.id, wait_msg.id) schedule_delete_message(message.chat.id, message.id) return bot.register_next_step_handler(wait_msg, add_user_bulk_number_step, username=username) @@ -1186,20 +1214,21 @@ def add_user_bulk_number_step(message: types.Message, username: str): schedule_delete_message(message.chat.id, message.id) cleanup_messages(message.chat.id) if mem_store.get(f"{message.chat.id}:is_bulk_from_template", False): - expire_date = mem_store.get(f'{message.chat.id}:expire_date') + expire_date = mem_store.get(f"{message.chat.id}:expire_date") if expire_date: msg = bot.send_message( message.chat.id, - '⚡ Select User Status:\nOn Hold: Expiration starts after the first connection\nActive: Expiration starts from now', - reply_markup=BotKeyboard.user_status_select()) + "⚡ Select User Status:\nOn Hold: Expiration starts after the first connection\nActive: Expiration starts from now", + reply_markup=BotKeyboard.user_status_select(), + ) schedule_delete_message(message.chat.id, msg.id) return else: text = mem_store.get(f"{message.chat.id}:template_info_text") mem_store.set(f"{message.chat.id}:template_info_text", None) inbounds = mem_store.get(f"{message.chat.id}:protocols") - mem_store.set(f'{message.chat.id}:user_status', UserStatus.active) - data_limit = mem_store.get(f'{message.chat.id}:data_limit') + mem_store.set(f"{message.chat.id}:user_status", UserStatus.active) + data_limit = mem_store.get(f"{message.chat.id}:data_limit") return bot.send_message( message.chat.id, text, @@ -1209,11 +1238,15 @@ def add_user_bulk_number_step(message: types.Message, username: str): "create_from_template", username=username, data_limit=data_limit, - expire_date=expire_date,)) + expire_date=expire_date, + ), + ) - msg = bot.send_message(message.chat.id, - '⬆️ Enter Data Limit (GB):\n⚠️ Send 0 for unlimited.', - reply_markup=BotKeyboard.inline_cancel_action()) + msg = bot.send_message( + message.chat.id, + "⬆️ Enter Data Limit (GB):\n⚠️ Send 0 for unlimited.", + reply_markup=BotKeyboard.inline_cancel_action(), + ) schedule_delete_message(message.chat.id, msg.id) bot.register_next_step_handler(msg, add_user_data_limit_step, username=username) @@ -1221,13 +1254,13 @@ def add_user_bulk_number_step(message: types.Message, username: str): def add_user_data_limit_step(message: types.Message, username: str): try: if float(message.text) < 0: - wait_msg = bot.send_message(message.chat.id, '❌ Data limit must be greater or equal to 0.') + wait_msg = bot.send_message(message.chat.id, "❌ Data limit must be greater or equal to 0.") schedule_delete_message(message.chat.id, wait_msg.id) schedule_delete_message(message.chat.id, message.id) return bot.register_next_step_handler(wait_msg, add_user_data_limit_step, username=username) data_limit = float(message.text) * 1024 * 1024 * 1024 except ValueError: - wait_msg = bot.send_message(message.chat.id, '❌ Data limit must be a number.') + wait_msg = bot.send_message(message.chat.id, "❌ Data limit must be a number.") schedule_delete_message(message.chat.id, wait_msg.id) schedule_delete_message(message.chat.id, message.id) return bot.register_next_step_handler(wait_msg, add_user_data_limit_step, username=username) @@ -1236,22 +1269,23 @@ def add_user_data_limit_step(message: types.Message, username: str): cleanup_messages(message.chat.id) msg = bot.send_message( message.chat.id, - '⚡ Select User Status:\nOn Hold: Expiration starts after the first connection\nActive: Expiration starts from now', - reply_markup=BotKeyboard.user_status_select()) + "⚡ Select User Status:\nOn Hold: Expiration starts after the first connection\nActive: Expiration starts from now", + reply_markup=BotKeyboard.user_status_select(), + ) schedule_delete_message(message.chat.id, msg.id) - mem_store.set(f'{message.chat.id}:data_limit', data_limit) - mem_store.set(f'{message.chat.id}:username', username) + mem_store.set(f"{message.chat.id}:data_limit", data_limit) + mem_store.set(f"{message.chat.id}:username", username) -@bot.callback_query_handler(cb_query_startswith('status:'), is_admin=True) +@bot.callback_query_handler(cb_query_startswith("status:"), is_admin=True) def add_user_status_step(call: types.CallbackQuery): - user_status = call.data.split(':')[1] - username = mem_store.get(f'{call.message.chat.id}:username') - data_limit = mem_store.get(f'{call.message.chat.id}:data_limit') + user_status = call.data.split(":")[1] + username = mem_store.get(f"{call.message.chat.id}:username") + data_limit = mem_store.get(f"{call.message.chat.id}:data_limit") - if user_status not in ['active', 'onhold']: - return bot.answer_callback_query(call.id, '❌ Invalid status. Please choose Active or OnHold.') + if user_status not in ["active", "onhold"]: + return bot.answer_callback_query(call.id, "❌ Invalid status. Please choose Active or OnHold.") bot.edit_message_reply_markup(call.message.chat.id, call.message.message_id, reply_markup=None) bot.delete_message(call.message.chat.id, call.message.message_id) @@ -1259,10 +1293,10 @@ def add_user_status_step(call: types.CallbackQuery): if text := mem_store.get(f"{call.message.chat.id}:template_info_text"): mem_store.set(f"{call.message.chat.id}:template_info_text", None) inbounds = mem_store.get(f"{call.message.chat.id}:protocols") - expire_date = mem_store.get(f'{call.message.chat.id}:expire_date') - mem_store.set(f'{call.message.chat.id}:user_status', user_status) + expire_date = mem_store.get(f"{call.message.chat.id}:expire_date") + mem_store.set(f"{call.message.chat.id}:user_status", user_status) if user_status == "onhold": - mem_store.set(f'{call.message.chat.id}:onhold_timeout', None) + mem_store.set(f"{call.message.chat.id}:onhold_timeout", None) return bot.send_message( call.message.chat.id, text, @@ -1272,20 +1306,20 @@ def add_user_status_step(call: types.CallbackQuery): "create_from_template", username=username, data_limit=data_limit, - expire_date=expire_date,)) + expire_date=expire_date, + ), + ) - if user_status == 'onhold': - expiry_message = '⬆️ Enter Expire Days\nYou Can Use Regex Symbol: ^[0-9]{1,3}(M|D) :' + if user_status == "onhold": + expiry_message = "⬆️ Enter Expire Days\nYou Can Use Regex Symbol: ^[0-9]{1,3}(M|D) :" else: - expiry_message = '⬆️ Enter Expire Date (YYYY-MM-DD)\nOr You Can Use Regex Symbol: ^[0-9]{1,3}(M|D) :\n⚠️ Send 0 for never expire.' + expiry_message = "⬆️ Enter Expire Date (YYYY-MM-DD)\nOr You Can Use Regex Symbol: ^[0-9]{1,3}(M|D) :\n⚠️ Send 0 for never expire." - msg = bot.send_message( - call.message.chat.id, - expiry_message, - reply_markup=BotKeyboard.inline_cancel_action()) + msg = bot.send_message(call.message.chat.id, expiry_message, reply_markup=BotKeyboard.inline_cancel_action()) schedule_delete_message(call.message.chat.id, msg.id) - bot.register_next_step_handler(msg, add_user_expire_step, username=username, - data_limit=data_limit, user_status=user_status) + bot.register_next_step_handler( + msg, add_user_expire_step, username=username, data_limit=data_limit, user_status=user_status + ) def add_user_expire_step(message: types.Message, username: str, data_limit: int, user_status: str): @@ -1293,27 +1327,27 @@ def add_user_expire_step(message: types.Message, username: str, data_limit: int, now = datetime.now() today = datetime(year=now.year, month=now.month, day=now.day, hour=23, minute=59, second=59) - if re.match(r'^[0-9]{1,3}([MmDd])$', message.text): - number_pattern = r'^[0-9]{1,3}' + if re.match(r"^[0-9]{1,3}([MmDd])$", message.text): + number_pattern = r"^[0-9]{1,3}" number = int(re.findall(number_pattern, message.text)[0]) - symbol_pattern = r'([MmDd])$' + symbol_pattern = r"([MmDd])$" symbol = re.findall(symbol_pattern, message.text)[0].upper() - if user_status == 'onhold': - if symbol == 'M': + if user_status == "onhold": + if symbol == "M": expire_date = number * 30 else: expire_date = number else: # active - if symbol == 'M': + if symbol == "M": expire_date = today + relativedelta(months=number) else: expire_date = today + relativedelta(days=number) - elif message.text == '0': - if user_status == 'onhold': + elif message.text == "0": + if user_status == "onhold": raise ValueError("Expire days is required for an on hold user.") expire_date = None - elif user_status == 'active': + elif user_status == "active": expire_date = datetime.strptime(message.text, "%Y-%m-%d") if expire_date < today: raise ValueError("Expire date must be greater than today.") @@ -1321,41 +1355,44 @@ def add_user_expire_step(message: types.Message, username: str, data_limit: int, raise ValueError("Invalid input for onhold status.") except ValueError as e: error_message = str(e) if str(e) != "Invalid input for onhold status." else "Invalid input. Please try again." - wait_msg = bot.send_message(message.chat.id, f'❌ {error_message}') + wait_msg = bot.send_message(message.chat.id, f"❌ {error_message}") schedule_delete_message(message.chat.id, wait_msg.id) schedule_delete_message(message.chat.id, message.id) return bot.register_next_step_handler( - wait_msg, add_user_expire_step, username=username, data_limit=data_limit, user_status=user_status) + wait_msg, add_user_expire_step, username=username, data_limit=data_limit, user_status=user_status + ) - mem_store.set(f'{message.chat.id}:username', username) - mem_store.set(f'{message.chat.id}:data_limit', data_limit) - mem_store.set(f'{message.chat.id}:user_status', user_status) - mem_store.set(f'{message.chat.id}:expire_date', expire_date) + mem_store.set(f"{message.chat.id}:username", username) + mem_store.set(f"{message.chat.id}:data_limit", data_limit) + mem_store.set(f"{message.chat.id}:user_status", user_status) + mem_store.set(f"{message.chat.id}:expire_date", expire_date) schedule_delete_message(message.chat.id, message.id) cleanup_messages(message.chat.id) if user_status == "onhold": - timeout_message = '⬆️ Enter timeout (YYYY-MM-DD)\nOr You Can Use Regex Symbol: ^[0-9]{1,3}(M|D) :\n⚠️ Send 0 for never timeout.' - msg = bot.send_message( - message.chat.id, - timeout_message, - reply_markup=BotKeyboard.inline_cancel_action() + timeout_message = ( + "⬆️ Enter timeout (YYYY-MM-DD)\nOr You Can Use Regex Symbol: ^[0-9]{1,3}(M|D) :\n⚠️ Send 0 for never timeout." ) + msg = bot.send_message(message.chat.id, timeout_message, reply_markup=BotKeyboard.inline_cancel_action()) schedule_delete_message(message.chat.id, msg.id) return bot.register_next_step_handler(msg, add_on_hold_timeout) bot.send_message( - message.chat.id, 'Select Protocols:\nUsername: {}\nData Limit: {}\nStatus: {}\nExpiry Date: {}'.format( - mem_store.get(f'{message.chat.id}:username'), - readable_size(mem_store.get(f'{message.chat.id}:data_limit')) - if mem_store.get(f'{message.chat.id}:data_limit') else "Unlimited", mem_store.get( - f'{message.chat.id}:user_status'), - mem_store.get(f'{message.chat.id}:expire_date').strftime("%Y-%m-%d") - if isinstance(mem_store.get(f'{message.chat.id}:expire_date'), - datetime) else mem_store.get(f'{message.chat.id}:expire_date') - if mem_store.get(f'{message.chat.id}:expire_date') else 'Never'), - reply_markup=BotKeyboard.select_protocols( - mem_store.get(f'{message.chat.id}:protocols', {}), action="create")) + message.chat.id, + "Select Protocols:\nUsername: {}\nData Limit: {}\nStatus: {}\nExpiry Date: {}".format( + mem_store.get(f"{message.chat.id}:username"), + readable_size(mem_store.get(f"{message.chat.id}:data_limit")) + if mem_store.get(f"{message.chat.id}:data_limit") + else "Unlimited", + mem_store.get(f"{message.chat.id}:user_status"), + mem_store.get(f"{message.chat.id}:expire_date").strftime("%Y-%m-%d") + if isinstance(mem_store.get(f"{message.chat.id}:expire_date"), datetime) + else mem_store.get(f"{message.chat.id}:expire_date") + if mem_store.get(f"{message.chat.id}:expire_date") + else "Never", + ), + reply_markup=BotKeyboard.select_protocols(mem_store.get(f"{message.chat.id}:protocols", {}), action="create"), + ) def add_on_hold_timeout(message: types.Message): @@ -1363,16 +1400,16 @@ def add_on_hold_timeout(message: types.Message): now = datetime.now() today = datetime(year=now.year, month=now.month, day=now.day, hour=23, minute=59, second=59) - if re.match(r'^[0-9]{1,3}([MmDd])$', message.text): - number_pattern = r'^[0-9]{1,3}' + if re.match(r"^[0-9]{1,3}([MmDd])$", message.text): + number_pattern = r"^[0-9]{1,3}" number = int(re.findall(number_pattern, message.text)[0]) - symbol_pattern = r'([MmDd])$' + symbol_pattern = r"([MmDd])$" symbol = re.findall(symbol_pattern, message.text)[0].upper() - if symbol == 'M': + if symbol == "M": onhold_timeout = today + relativedelta(months=number) else: onhold_timeout = today + relativedelta(days=number) - elif message.text == '0': + elif message.text == "0": onhold_timeout = None else: onhold_timeout = datetime.strptime(message.text, "%Y-%m-%d") @@ -1380,48 +1417,52 @@ def add_on_hold_timeout(message: types.Message): raise ValueError("Expire date must be greater than today.") except ValueError as e: error_message = str(e) - wait_msg = bot.send_message(message.chat.id, f'❌ {error_message}') + wait_msg = bot.send_message(message.chat.id, f"❌ {error_message}") schedule_delete_message(message.chat.id, wait_msg.id) schedule_delete_message(message.chat.id, message.id) return bot.register_next_step_handler(wait_msg, add_on_hold_timeout) - mem_store.set(f'{message.chat.id}:onhold_timeout', onhold_timeout) + mem_store.set(f"{message.chat.id}:onhold_timeout", onhold_timeout) schedule_delete_message(message.chat.id, message.id) cleanup_messages(message.chat.id) bot.send_message( - message.chat.id, 'Select Protocols:\nUsername: {}\nData Limit: {}\nStatus: {}\nExpiry Date: {}'.format( - mem_store.get(f'{message.chat.id}:username'), - readable_size(mem_store.get(f'{message.chat.id}:data_limit')) - if mem_store.get(f'{message.chat.id}:data_limit') else "Unlimited", mem_store.get( - f'{message.chat.id}:user_status'), - mem_store.get(f'{message.chat.id}:expire_date').strftime("%Y-%m-%d") - if isinstance(mem_store.get(f'{message.chat.id}:expire_date'), - datetime) else mem_store.get(f'{message.chat.id}:expire_date') - if mem_store.get(f'{message.chat.id}:expire_date') else 'Never'), - reply_markup=BotKeyboard.select_protocols( - mem_store.get(f'{message.chat.id}:protocols', {}), action="create")) + message.chat.id, + "Select Protocols:\nUsername: {}\nData Limit: {}\nStatus: {}\nExpiry Date: {}".format( + mem_store.get(f"{message.chat.id}:username"), + readable_size(mem_store.get(f"{message.chat.id}:data_limit")) + if mem_store.get(f"{message.chat.id}:data_limit") + else "Unlimited", + mem_store.get(f"{message.chat.id}:user_status"), + mem_store.get(f"{message.chat.id}:expire_date").strftime("%Y-%m-%d") + if isinstance(mem_store.get(f"{message.chat.id}:expire_date"), datetime) + else mem_store.get(f"{message.chat.id}:expire_date") + if mem_store.get(f"{message.chat.id}:expire_date") + else "Never", + ), + reply_markup=BotKeyboard.select_protocols(mem_store.get(f"{message.chat.id}:protocols", {}), action="create"), + ) -@bot.callback_query_handler(cb_query_startswith('select_inbound:'), is_admin=True) +@bot.callback_query_handler(cb_query_startswith("select_inbound:"), is_admin=True) def select_inbounds(call: types.CallbackQuery): - if not (username := mem_store.get(f'{call.message.chat.id}:username')): - return bot.answer_callback_query(call.id, '❌ No user selected.', show_alert=True) - protocols: dict[str, list[str]] = mem_store.get(f'{call.message.chat.id}:protocols', {}) - _, inbound, action = call.data.split(':') + if not (username := mem_store.get(f"{call.message.chat.id}:username")): + return bot.answer_callback_query(call.id, "❌ No user selected.", show_alert=True) + protocols: dict[str, list[str]] = mem_store.get(f"{call.message.chat.id}:protocols", {}) + _, inbound, action = call.data.split(":") for protocol, inbounds in xray.config.inbounds_by_protocol.items(): for i in inbounds: - if i['tag'] != inbound: + if i["tag"] != inbound: continue - if not inbound in protocols[protocol]: + if inbound not in protocols[protocol]: protocols[protocol].append(inbound) else: protocols[protocol].remove(inbound) if len(protocols[protocol]) < 1: del protocols[protocol] - mem_store.set(f'{call.message.chat.id}:protocols', protocols) + mem_store.set(f"{call.message.chat.id}:protocols", protocols) if action in ["edit", "create_from_template"]: return bot.edit_message_text( @@ -1433,28 +1474,28 @@ def select_inbounds(call: types.CallbackQuery): "edit", username=username, data_limit=mem_store.get(f"{call.message.chat.id}:data_limit"), - expire_date=mem_store.get(f"{call.message.chat.id}:expire_date")) + expire_date=mem_store.get(f"{call.message.chat.id}:expire_date"), + ), ) bot.edit_message_text( call.message.text, call.message.chat.id, call.message.message_id, - reply_markup=BotKeyboard.select_protocols(protocols, "create") + reply_markup=BotKeyboard.select_protocols(protocols, "create"), ) -@bot.callback_query_handler(cb_query_startswith('select_protocol:'), is_admin=True) +@bot.callback_query_handler(cb_query_startswith("select_protocol:"), is_admin=True) def select_protocols(call: types.CallbackQuery): - if not (username := mem_store.get(f'{call.message.chat.id}:username')): - return bot.answer_callback_query(call.id, '❌ No user selected.', show_alert=True) - protocols: dict[str, list[str]] = mem_store.get(f'{call.message.chat.id}:protocols', {}) - _, protocol, action = call.data.split(':') + if not (username := mem_store.get(f"{call.message.chat.id}:username")): + return bot.answer_callback_query(call.id, "❌ No user selected.", show_alert=True) + protocols: dict[str, list[str]] = mem_store.get(f"{call.message.chat.id}:protocols", {}) + _, protocol, action = call.data.split(":") if protocol in protocols: del protocols[protocol] else: - protocols.update( - {protocol: [inbound['tag'] for inbound in xray.config.inbounds_by_protocol[protocol]]}) - mem_store.set(f'{call.message.chat.id}:protocols', protocols) + protocols.update({protocol: [inbound["tag"] for inbound in xray.config.inbounds_by_protocol[protocol]]}) + mem_store.set(f"{call.message.chat.id}:protocols", protocols) if action in ["edit", "create_from_template"]: return bot.edit_message_text( @@ -1466,41 +1507,33 @@ def select_protocols(call: types.CallbackQuery): "edit", username=username, data_limit=mem_store.get(f"{call.message.chat.id}:data_limit"), - expire_date=mem_store.get(f"{call.message.chat.id}:expire_date")) + expire_date=mem_store.get(f"{call.message.chat.id}:expire_date"), + ), ) bot.edit_message_text( call.message.text, call.message.chat.id, call.message.message_id, - reply_markup=BotKeyboard.select_protocols(protocols, action="create") + reply_markup=BotKeyboard.select_protocols(protocols, action="create"), ) -@bot.callback_query_handler(cb_query_startswith('confirm:'), is_admin=True) +@bot.callback_query_handler(cb_query_startswith("confirm:"), is_admin=True) def confirm_user_command(call: types.CallbackQuery): - data = call.data.split(':')[1] + data = call.data.split(":")[1] chat_id = call.from_user.id full_name = call.from_user.full_name now = datetime.now() - today = datetime( - year=now.year, - month=now.month, - day=now.day, - hour=23, - minute=59, - second=59) - if data == 'delete': - username = call.data.split(':')[2] + today = datetime(year=now.year, month=now.month, day=now.day, hour=23, minute=59, second=59) + if data == "delete": + username = call.data.split(":")[2] with GetDB() as db: db_user = crud.get_user(db, username) crud.remove_user(db, db_user) xray.operations.remove_user(db_user) bot.edit_message_text( - '✅ User deleted.', - call.message.chat.id, - call.message.message_id, - reply_markup=BotKeyboard.main_menu() + "✅ User deleted.", call.message.chat.id, call.message.message_id, reply_markup=BotKeyboard.main_menu() ) if TELEGRAM_LOGGER_CHANNEL_ID: text = f"""\ @@ -1513,22 +1546,22 @@ def confirm_user_command(call: types.CallbackQuery): ➖➖➖➖➖➖➖➖➖ By : {full_name}""" try: - bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, 'HTML') + bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, "HTML") except ApiTelegramException: pass elif data == "suspend": username = call.data.split(":")[2] with GetDB() as db: db_user = crud.get_user(db, username) - crud.update_user(db, db_user, UserModify( - status=UserStatusModify.disabled)) + crud.update_user(db, db_user, UserModify(status=UserStatusModify.disabled)) xray.operations.remove_user(db_user) bot.edit_message_text( get_user_info_text(db_user), call.message.chat.id, call.message.message_id, - parse_mode='HTML', - reply_markup=BotKeyboard.user_menu(user_info={'status': 'disabled', 'username': db_user.username})) + parse_mode="HTML", + reply_markup=BotKeyboard.user_menu(user_info={"status": "disabled", "username": db_user.username}), + ) if TELEGRAM_LOGGER_CHANNEL_ID: text = f"""\ ❌ #Disabled #From_Bot @@ -1537,22 +1570,22 @@ def confirm_user_command(call: types.CallbackQuery): ➖➖➖➖➖➖➖➖➖ By : {full_name}""" try: - bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, 'HTML') + bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, "HTML") except ApiTelegramException: pass elif data == "activate": username = call.data.split(":")[2] with GetDB() as db: db_user = crud.get_user(db, username) - crud.update_user(db, db_user, UserModify( - status=UserStatusModify.active)) + crud.update_user(db, db_user, UserModify(status=UserStatusModify.active)) xray.operations.add_user(db_user) bot.edit_message_text( get_user_info_text(db_user), call.message.chat.id, call.message.message_id, - parse_mode='HTML', - reply_markup=BotKeyboard.user_menu(user_info={'status': 'active', 'username': db_user.username})) + parse_mode="HTML", + reply_markup=BotKeyboard.user_menu(user_info={"status": "active", "username": db_user.username}), + ) if TELEGRAM_LOGGER_CHANNEL_ID: text = f"""\ ✅ #Activated #From_Bot @@ -1561,10 +1594,10 @@ def confirm_user_command(call: types.CallbackQuery): ➖➖➖➖➖➖➖➖➖ By : {full_name}""" try: - bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, 'HTML') + bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, "HTML") except ApiTelegramException: pass - elif data == 'reset_usage': + elif data == "reset_usage": username = call.data.split(":")[2] with GetDB() as db: db_user = crud.get_user(db, username) @@ -1576,8 +1609,9 @@ def confirm_user_command(call: types.CallbackQuery): get_user_info_text(db_user), call.message.chat.id, call.message.message_id, - parse_mode='HTML', - reply_markup=BotKeyboard.user_menu(user_info={'status': user.status, 'username': user.username})) + parse_mode="HTML", + reply_markup=BotKeyboard.user_menu(user_info={"status": user.status, "username": user.username}), + ) if TELEGRAM_LOGGER_CHANNEL_ID: text = f"""\ 🔁 #Reset_usage #From_Bot @@ -1586,24 +1620,21 @@ def confirm_user_command(call: types.CallbackQuery): ➖➖➖➖➖➖➖➖➖ By : {full_name}""" try: - bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, 'HTML') + bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, "HTML") except ApiTelegramException: pass - elif data == 'restart': - m = bot.edit_message_text( - '🔄 Restarting XRay core...', call.message.chat.id, call.message.message_id) + elif data == "restart": + m = bot.edit_message_text("🔄 Restarting XRay core...", call.message.chat.id, call.message.message_id) config = xray.config.include_db_users() xray.core.restart(config) for node_id, node in list(xray.nodes.items()): if node.connected: xray.operations.restart_node(node_id, config) bot.edit_message_text( - '✅ XRay core restarted successfully.', - m.chat.id, m.message_id, - reply_markup=BotKeyboard.main_menu() + "✅ XRay core restarted successfully.", m.chat.id, m.message_id, reply_markup=BotKeyboard.main_menu() ) - elif data in ['charge_add', 'charge_reset']: + elif data in ["charge_add", "charge_reset"]: _, _, username, template_id = call.data.split(":") with GetDB() as db: template = crud.get_user_template(db, template_id) @@ -1626,7 +1657,7 @@ def confirm_user_command(call: types.CallbackQuery): del proxies[protocol] crud.reset_user_data_usage(db, db_user) - if data == 'charge_reset': + if data == "charge_reset": expire_date = None if template.expire_duration: expire_date = today + relativedelta(seconds=template.expire_duration) @@ -1638,8 +1669,9 @@ def confirm_user_command(call: types.CallbackQuery): else: expire_date = None if template.expire_duration: - expire_date = (datetime.fromtimestamp(user.expire) - if user.expire else today) + relativedelta(seconds=template.expire_duration) + expire_date = (datetime.fromtimestamp(user.expire) if user.expire else today) + relativedelta( + seconds=template.expire_duration + ) modify = UserModify( status=UserStatus.active, expire=int(expire_date.timestamp()) if expire_date else 0, @@ -1652,8 +1684,9 @@ def confirm_user_command(call: types.CallbackQuery): get_user_info_text(db_user), call.message.chat.id, call.message.message_id, - parse_mode='html', - reply_markup=BotKeyboard.user_menu(user_info={'status': user.status, 'username': user.username})) + parse_mode="html", + reply_markup=BotKeyboard.user_menu(user_info={"status": user.status, "username": user.username}), + ) if TELEGRAM_LOGGER_CHANNEL_ID: text = f"""\ 🔋 #Charged #{data.split('_')[1].title()} #From_Bot @@ -1674,63 +1707,62 @@ def confirm_user_command(call: types.CallbackQuery): By : {full_name}\ """ try: - bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, 'HTML') + bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, "HTML") except ApiTelegramException: pass - elif data == 'edit_user': - if (username := mem_store.get(f'{call.message.chat.id}:username')) is None: + elif data == "edit_user": + if (username := mem_store.get(f"{call.message.chat.id}:username")) is None: try: - bot.delete_message(call.message.chat.id, - call.message.message_id) + bot.delete_message(call.message.chat.id, call.message.message_id) except Exception: pass return bot.send_message( - call.message.chat.id, - '❌ Bot reload detected. Please start over.', - reply_markup=BotKeyboard.main_menu() + call.message.chat.id, "❌ Bot reload detected. Please start over.", reply_markup=BotKeyboard.main_menu() ) - if not mem_store.get(f'{call.message.chat.id}:protocols'): - return bot.answer_callback_query( - call.id, - '❌ No inbounds selected.', - show_alert=True - ) + if not mem_store.get(f"{call.message.chat.id}:protocols"): + return bot.answer_callback_query(call.id, "❌ No inbounds selected.", show_alert=True) inbounds: dict[str, list[str]] = { - k: v for k, v in mem_store.get(f'{call.message.chat.id}:protocols').items() if v} + k: v for k, v in mem_store.get(f"{call.message.chat.id}:protocols").items() if v + } with GetDB() as db: db_user = crud.get_user(db, username) if not db_user: - return bot.answer_callback_query(call.id, text=f"User not found!", show_alert=True) + return bot.answer_callback_query(call.id, text="User not found!", show_alert=True) proxies = {p.type.value: p.settings for p in db_user.proxies} for protocol in xray.config.inbounds_by_protocol: if protocol in inbounds and protocol not in db_user.inbounds: - proxies.update({protocol: {'flow': TELEGRAM_DEFAULT_VLESS_FLOW} if - TELEGRAM_DEFAULT_VLESS_FLOW and protocol == ProxyTypes.VLESS else {}}) + proxies.update( + { + protocol: {"flow": TELEGRAM_DEFAULT_VLESS_FLOW} + if TELEGRAM_DEFAULT_VLESS_FLOW and protocol == ProxyTypes.VLESS + else {} + } + ) elif protocol in db_user.inbounds and protocol not in inbounds: del proxies[protocol] data_limit = mem_store.get(f"{call.message.chat.id}:data_limit") - expire_date = mem_store.get(f'{call.message.chat.id}:expire_date') + expire_date = mem_store.get(f"{call.message.chat.id}:expire_date") if isinstance(expire_date, int): modify = UserModify( on_hold_expire_duration=expire_date, - on_hold_timeout=mem_store.get(f'{call.message.chat.id}:expire_on_hold_timeout'), + on_hold_timeout=mem_store.get(f"{call.message.chat.id}:expire_on_hold_timeout"), data_limit=data_limit, proxies=proxies, - inbounds=inbounds + inbounds=inbounds, ) else: modify = UserModify( expire=int(expire_date.timestamp()) if expire_date else 0, data_limit=data_limit, proxies=proxies, - inbounds=inbounds + inbounds=inbounds, ) last_user = UserResponse.model_validate(db_user) db_user = crud.update_user(db, db_user, modify) @@ -1746,7 +1778,8 @@ def confirm_user_command(call: types.CallbackQuery): call.message.chat.id, call.message.message_id, parse_mode="HTML", - reply_markup=BotKeyboard.user_menu({'username': db_user.username, 'status': db_user.status})) + reply_markup=BotKeyboard.user_menu({"username": db_user.username, "status": db_user.status}), + ) if TELEGRAM_LOGGER_CHANNEL_ID: tag = f'\n➖➖➖➖➖➖➖➖➖ \nBy : {full_name}' if last_user.data_limit != user.data_limit: @@ -1757,7 +1790,7 @@ def confirm_user_command(call: types.CallbackQuery): Last Traffic Limit : {readable_size(last_user.data_limit) if last_user.data_limit else "Unlimited"} New Traffic Limit : {readable_size(user.data_limit) if user.data_limit else "Unlimited"}{tag}""" try: - bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, 'HTML') + bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, "HTML") except ApiTelegramException: pass if last_user.expire != user.expire: @@ -1770,7 +1803,7 @@ def confirm_user_command(call: types.CallbackQuery): New Expire Date : \ {datetime.fromtimestamp(user.expire).strftime('%H:%M:%S %Y-%m-%d') if user.expire else "Never"}{tag}""" try: - bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, 'HTML') + bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, "HTML") except ApiTelegramException: pass if list(last_user.inbounds.values())[0] != list(user.inbounds.values())[0]: @@ -1781,77 +1814,77 @@ def confirm_user_command(call: types.CallbackQuery): Last Proxies : {", ".join(list(last_user.inbounds.values())[0])} New Proxies : {", ".join(list(user.inbounds.values())[0])}{tag}""" try: - bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, 'HTML') + bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, "HTML") except ApiTelegramException: pass - elif data == 'add_user': - if mem_store.get(f'{call.message.chat.id}:username') is None: + elif data == "add_user": + if mem_store.get(f"{call.message.chat.id}:username") is None: try: bot.delete_message(call.message.chat.id, call.message.message_id) except Exception: pass return bot.send_message( - call.message.chat.id, - '❌ Bot reload detected. Please start over.', - reply_markup=BotKeyboard.main_menu() + call.message.chat.id, "❌ Bot reload detected. Please start over.", reply_markup=BotKeyboard.main_menu() ) - if not mem_store.get(f'{call.message.chat.id}:protocols'): - return bot.answer_callback_query( - call.id, - '❌ No inbounds selected.', - show_alert=True - ) + if not mem_store.get(f"{call.message.chat.id}:protocols"): + return bot.answer_callback_query(call.id, "❌ No inbounds selected.", show_alert=True) inbounds: dict[str, list[str]] = { - k: v for k, v in mem_store.get(f'{call.message.chat.id}:protocols').items() if v} - original_proxies = {p: ({'flow': TELEGRAM_DEFAULT_VLESS_FLOW} if - TELEGRAM_DEFAULT_VLESS_FLOW and p == ProxyTypes.VLESS else {}) for p in inbounds} - - user_status = mem_store.get(f'{call.message.chat.id}:user_status') - number = mem_store.get(f'{call.message.chat.id}:number', 1) + k: v for k, v in mem_store.get(f"{call.message.chat.id}:protocols").items() if v + } + original_proxies = { + p: ({"flow": TELEGRAM_DEFAULT_VLESS_FLOW} if TELEGRAM_DEFAULT_VLESS_FLOW and p == ProxyTypes.VLESS else {}) + for p in inbounds + } + + user_status = mem_store.get(f"{call.message.chat.id}:user_status") + number = mem_store.get(f"{call.message.chat.id}:number", 1) if not mem_store.get(f"{call.message.chat.id}:is_bulk", False): number = 1 for i in range(number): proxies = copy.deepcopy(original_proxies) - username: str = mem_store.get(f'{call.message.chat.id}:username') + username: str = mem_store.get(f"{call.message.chat.id}:username") if mem_store.get(f"{call.message.chat.id}:is_bulk", False): if n := get_number_at_end(username): - username = username.replace(n, str(int(n)+i)) + username = username.replace(n, str(int(n) + i)) else: - username += str(i+1) if i > 0 else "" - if user_status == 'onhold': - expire_days = mem_store.get(f'{call.message.chat.id}:expire_date') - onhold_timeout = mem_store.get(f'{call.message.chat.id}:onhold_timeout') + username += str(i + 1) if i > 0 else "" + if user_status == "onhold": + expire_days = mem_store.get(f"{call.message.chat.id}:expire_date") + onhold_timeout = mem_store.get(f"{call.message.chat.id}:onhold_timeout") if isinstance(expire_days, datetime): expire_days = (expire_days - datetime.now()).days new_user = UserCreate( username=username, - status='on_hold', + status="on_hold", on_hold_expire_duration=int(expire_days) * 24 * 60 * 60, on_hold_timeout=onhold_timeout, - data_limit=mem_store.get(f'{call.message.chat.id}:data_limit') - if mem_store.get(f'{call.message.chat.id}:data_limit') else None, + data_limit=mem_store.get(f"{call.message.chat.id}:data_limit") + if mem_store.get(f"{call.message.chat.id}:data_limit") + else None, proxies=proxies, - inbounds=inbounds) + inbounds=inbounds, + ) else: new_user = UserCreate( username=username, - status='active', - expire=int(mem_store.get(f'{call.message.chat.id}:expire_date').timestamp()) - if mem_store.get(f'{call.message.chat.id}:expire_date') else None, - data_limit=mem_store.get(f'{call.message.chat.id}:data_limit') - if mem_store.get(f'{call.message.chat.id}:data_limit') else None, + status="active", + expire=int(mem_store.get(f"{call.message.chat.id}:expire_date").timestamp()) + if mem_store.get(f"{call.message.chat.id}:expire_date") + else None, + data_limit=mem_store.get(f"{call.message.chat.id}:data_limit") + if mem_store.get(f"{call.message.chat.id}:data_limit") + else None, proxies=proxies, - inbounds=inbounds) + inbounds=inbounds, + ) for proxy_type in new_user.proxies: if not xray.config.inbounds_by_protocol.get(proxy_type): return bot.answer_callback_query( - call.id, - f'❌ Protocol {proxy_type} is disabled on your server', - show_alert=True + call.id, f"❌ Protocol {proxy_type} is disabled on your server", show_alert=True ) try: with GetDB() as db: @@ -1867,7 +1900,8 @@ def confirm_user_command(call: types.CallbackQuery): get_user_info_text(db_user), parse_mode="HTML", reply_markup=BotKeyboard.user_menu( - user_info={'status': user.status, 'username': user.username}) + user_info={"status": user.status, "username": user.username} + ), ) else: bot.edit_message_text( @@ -1875,14 +1909,13 @@ def confirm_user_command(call: types.CallbackQuery): call.message.chat.id, call.message.message_id, parse_mode="HTML", - reply_markup=BotKeyboard.user_menu(user_info={'status': user.status, 'username': user.username})) + reply_markup=BotKeyboard.user_menu( + user_info={"status": user.status, "username": user.username} + ), + ) except sqlalchemy.exc.IntegrityError: db.rollback() - return bot.answer_callback_query( - call.id, - '❌ Username already exists.', - show_alert=True - ) + return bot.answer_callback_query(call.id, "❌ Username already exists.", show_alert=True) if TELEGRAM_LOGGER_CHANNEL_ID: text = f"""\ 🆕 #Created #From_Bot @@ -1891,7 +1924,7 @@ def confirm_user_command(call: types.CallbackQuery): Status : {'Active' if user_status == 'active' else 'On Hold'} Traffic Limit : {readable_size(user.data_limit) if user.data_limit else "Unlimited"} """ - if user_status == 'onhold': + if user_status == "onhold": text += f"""\ On Hold Expire Duration : {new_user.on_hold_expire_duration // (24*60*60)} days On Hold Timeout : {new_user.on_hold_timeout.strftime("%H:%M:%S %Y-%m-%d") if new_user.on_hold_timeout else "-"}""" @@ -1903,22 +1936,21 @@ def confirm_user_command(call: types.CallbackQuery): ➖➖➖➖➖➖➖➖➖ By : {full_name}""" try: - bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, 'HTML') + bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, "HTML") except ApiTelegramException: pass - elif data in ['delete_expired', 'delete_limited']: + elif data in ["delete_expired", "delete_limited"]: bot.edit_message_text( - '⏳ In Progress...', - call.message.chat.id, - call.message.message_id, - parse_mode="HTML") + "⏳ In Progress...", call.message.chat.id, call.message.message_id, parse_mode="HTML" + ) with GetDB() as db: depleted_users = crud.get_users( - db, status=[UserStatus.limited if data == 'delete_limited' else UserStatus.expired]) - file_name = f'{data[8:]}_users_{int(now.timestamp()*1000)}.txt' - with open(file_name, 'w') as f: - f.write('USERNAME\tEXIPRY\tUSAGE/LIMIT\tSTATUS\n') + db, status=[UserStatus.limited if data == "delete_limited" else UserStatus.expired] + ) + file_name = f"{data[8:]}_users_{int(now.timestamp()*1000)}.txt" + with open(file_name, "w") as f: + f.write("USERNAME\tEXIPRY\tUSAGE/LIMIT\tSTATUS\n") deleted = 0 for user in depleted_users: try: @@ -1930,15 +1962,17 @@ def confirm_user_command(call: types.CallbackQuery): \t{user.expire if user.expire else "never"}\ \t{readable_size(user.used_traffic) if user.used_traffic else 0}\ /{readable_size(user.data_limit) if user.data_limit else "Unlimited"}\ -\t{user.status}\n') +\t{user.status}\n' + ) except sqlalchemy.exc.IntegrityError: db.rollback() bot.edit_message_text( - f'✅ {deleted}/{len(depleted_users)} {data[7:].title()} Users Deleted', + f"✅ {deleted}/{len(depleted_users)} {data[7:].title()} Users Deleted", call.message.chat.id, call.message.message_id, parse_mode="HTML", - reply_markup=BotKeyboard.main_menu()) + reply_markup=BotKeyboard.main_menu(), + ) if TELEGRAM_LOGGER_CHANNEL_ID: text = f"""\ 🗑 #Delete #{data[7:].title()} #From_Bot @@ -1947,22 +1981,21 @@ def confirm_user_command(call: types.CallbackQuery): ➖➖➖➖➖➖➖➖➖ By : {full_name}""" try: - bot.send_document(TELEGRAM_LOGGER_CHANNEL_ID, open( - file_name, 'rb'), caption=text, parse_mode='HTML') + bot.send_document( + TELEGRAM_LOGGER_CHANNEL_ID, open(file_name, "rb"), caption=text, parse_mode="HTML" + ) os.remove(file_name) except ApiTelegramException: pass - elif data == 'add_data': - schedule_delete_message( - call.message.chat.id, - bot.send_message(chat_id, '⏳ In Progress...', 'HTML').id) + elif data == "add_data": + schedule_delete_message(call.message.chat.id, bot.send_message(chat_id, "⏳ In Progress...", "HTML").id) data_limit = float(call.data.split(":")[2]) * 1024 * 1024 * 1024 with GetDB() as db: users = crud.get_users(db) counter = 0 - file_name = f'new_data_limit_users_{int(now.timestamp()*1000)}.txt' - with open(file_name, 'w') as f: - f.write('USERNAME\tEXIPRY\tUSAGE/LIMIT\tSTATUS\n') + file_name = f"new_data_limit_users_{int(now.timestamp()*1000)}.txt" + with open(file_name, "w") as f: + f.write("USERNAME\tEXIPRY\tUSAGE/LIMIT\tSTATUS\n") for user in users: try: if user.data_limit and user.status not in [UserStatus.limited, UserStatus.expired]: @@ -1973,7 +2006,8 @@ def confirm_user_command(call: types.CallbackQuery): \t{user.expire if user.expire else "never"}\ \t{readable_size(user.used_traffic) if user.used_traffic else 0}\ /{readable_size(user.data_limit) if user.data_limit else "Unlimited"}\ -\t{user.status}\n') +\t{user.status}\n' + ) except sqlalchemy.exc.IntegrityError: db.rollback() cleanup_messages(chat_id) @@ -1981,8 +2015,9 @@ def confirm_user_command(call: types.CallbackQuery): chat_id, f'✅ {counter}/{len(users)} Users Data Limit according to {"+" if data_limit > 0 else "-"}{readable_size(abs(data_limit))}', - 'HTML', - reply_markup=BotKeyboard.main_menu()) + "HTML", + reply_markup=BotKeyboard.main_menu(), + ) if TELEGRAM_LOGGER_CHANNEL_ID: text = f"""\ 📶 #Traffic_Change #From_Bot @@ -1992,46 +2027,45 @@ def confirm_user_command(call: types.CallbackQuery): ➖➖➖➖➖➖➖➖➖ By : {full_name}""" try: - bot.send_document(TELEGRAM_LOGGER_CHANNEL_ID, open( - file_name, 'rb'), caption=text, parse_mode='HTML') + bot.send_document( + TELEGRAM_LOGGER_CHANNEL_ID, open(file_name, "rb"), caption=text, parse_mode="HTML" + ) os.remove(file_name) except ApiTelegramException: pass - elif data == 'add_time': - schedule_delete_message( - call.message.chat.id, - bot.send_message(chat_id, '⏳ In Progress...', 'HTML').id) + elif data == "add_time": + schedule_delete_message(call.message.chat.id, bot.send_message(chat_id, "⏳ In Progress...", "HTML").id) days = int(call.data.split(":")[2]) with GetDB() as db: users = crud.get_users(db) counter = 0 - file_name = f'new_expiry_users_{int(now.timestamp()*1000)}.txt' - with open(file_name, 'w') as f: - f.write('USERNAME\tEXIPRY\tUSAGE/LIMIT\tSTATUS\n') + file_name = f"new_expiry_users_{int(now.timestamp()*1000)}.txt" + with open(file_name, "w") as f: + f.write("USERNAME\tEXIPRY\tUSAGE/LIMIT\tSTATUS\n") for user in users: try: if user.expire and user.status not in [UserStatus.limited, UserStatus.expired]: user = crud.update_user( - db, user, - UserModify( - expire=int( - (user.expire + relativedelta(days=days)).timestamp()))) + db, user, UserModify(expire=int((user.expire + relativedelta(days=days)).timestamp())) + ) counter += 1 f.write( f'{user.username}\ \t{user.expire if user.expire else "never"}\ \t{readable_size(user.used_traffic) if user.used_traffic else 0}\ /{readable_size(user.data_limit) if user.data_limit else "Unlimited"}\ -\t{user.status}\n') +\t{user.status}\n' + ) except sqlalchemy.exc.IntegrityError: db.rollback() cleanup_messages(chat_id) bot.send_message( chat_id, - f'✅ {counter}/{len(users)} Users Expiry Changes according to {days} Days', - 'HTML', - reply_markup=BotKeyboard.main_menu()) + f"✅ {counter}/{len(users)} Users Expiry Changes according to {days} Days", + "HTML", + reply_markup=BotKeyboard.main_menu(), + ) if TELEGRAM_LOGGER_CHANNEL_ID: text = f"""\ 📅 #Expiry_Change #From_Bot @@ -2041,61 +2075,67 @@ def confirm_user_command(call: types.CallbackQuery): ➖➖➖➖➖➖➖➖➖ By : {full_name}""" try: - bot.send_document(TELEGRAM_LOGGER_CHANNEL_ID, open( - file_name, 'rb'), caption=text, parse_mode='HTML') + bot.send_document( + TELEGRAM_LOGGER_CHANNEL_ID, open(file_name, "rb"), caption=text, parse_mode="HTML" + ) os.remove(file_name) except ApiTelegramException: pass - elif data in ['inbound_add', 'inbound_remove']: + elif data in ["inbound_add", "inbound_remove"]: bot.edit_message_text( - '⏳ In Progress...', - call.message.chat.id, - call.message.message_id, - parse_mode="HTML") + "⏳ In Progress...", call.message.chat.id, call.message.message_id, parse_mode="HTML" + ) inbound = call.data.split(":")[2] with GetDB() as db: users = crud.get_users(db) unsuccessful = 0 for user in users: inbound_tags = [j for i in user.inbounds for j in user.inbounds[i]] - protocol = xray.config.inbounds_by_tag[inbound]['protocol'] + protocol = xray.config.inbounds_by_tag[inbound]["protocol"] new_inbounds = user.inbounds - if data == 'inbound_add': + if data == "inbound_add": if inbound not in inbound_tags: if protocol in list(new_inbounds.keys()): new_inbounds[protocol].append(inbound) else: new_inbounds[protocol] = [inbound] - elif data == 'inbound_remove': + elif data == "inbound_remove": if inbound in inbound_tags: if len(new_inbounds[protocol]) == 1: del new_inbounds[protocol] else: new_inbounds[protocol].remove(inbound) - if (data == 'inbound_remove' and inbound in inbound_tags)\ - or (data == 'inbound_add' and inbound not in inbound_tags): + if (data == "inbound_remove" and inbound in inbound_tags) or ( + data == "inbound_add" and inbound not in inbound_tags + ): proxies = {p.type.value: p.settings for p in user.proxies} for protocol in xray.config.inbounds_by_protocol: if protocol in new_inbounds and protocol not in user.inbounds: - proxies.update({protocol: {'flow': TELEGRAM_DEFAULT_VLESS_FLOW} if - TELEGRAM_DEFAULT_VLESS_FLOW and protocol == ProxyTypes.VLESS else {}}) + proxies.update( + { + protocol: {"flow": TELEGRAM_DEFAULT_VLESS_FLOW} + if TELEGRAM_DEFAULT_VLESS_FLOW and protocol == ProxyTypes.VLESS + else {} + } + ) elif protocol in user.inbounds and protocol not in new_inbounds: del proxies[protocol] try: user = crud.update_user(db, user, UserModify(inbounds=new_inbounds, proxies=proxies)) if user.status == UserStatus.active: xray.operations.update_user(user) - except: + except Exception: db.rollback() unsuccessful += 1 bot.edit_message_text( - f'✅ {data[8:].title()} {inbound} Users Successfully' + - (f'\n Unsuccessful: {unsuccessful}' if unsuccessful else ''), + f"✅ {data[8:].title()} {inbound} Users Successfully" + + (f"\n Unsuccessful: {unsuccessful}" if unsuccessful else ""), call.message.chat.id, call.message.message_id, parse_mode="HTML", - reply_markup=BotKeyboard.main_menu()) + reply_markup=BotKeyboard.main_menu(), + ) if TELEGRAM_LOGGER_CHANNEL_ID: text = f"""\ @@ -2105,16 +2145,16 @@ def confirm_user_command(call: types.CallbackQuery): ➖➖➖➖➖➖➖➖➖ By : {full_name}""" try: - bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, 'HTML') + bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, "HTML") except ApiTelegramException: pass - elif data == 'revoke_sub': + elif data == "revoke_sub": username = call.data.split(":")[2] with GetDB() as db: db_user = crud.get_user(db, username) if not db_user: - return bot.answer_callback_query(call.id, text=f"User not found!", show_alert=True) + return bot.answer_callback_query(call.id, text="User not found!", show_alert=True) db_user = crud.revoke_user_sub(db, db_user) user = UserResponse.model_validate(db_user) bot.answer_callback_query(call.id, "✅ Subscription Successfully Revoked!") @@ -2123,7 +2163,8 @@ def confirm_user_command(call: types.CallbackQuery): call.message.chat.id, call.message.message_id, parse_mode="HTML", - reply_markup=BotKeyboard.user_menu(user_info={'status': user.status, 'username': user.username})) + reply_markup=BotKeyboard.user_menu(user_info={"status": user.status, "username": user.username}), + ) if TELEGRAM_LOGGER_CHANNEL_ID: text = f"""\ @@ -2133,20 +2174,19 @@ def confirm_user_command(call: types.CallbackQuery): ➖➖➖➖➖➖➖➖➖ By : {full_name}""" try: - bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, 'HTML') + bot.send_message(TELEGRAM_LOGGER_CHANNEL_ID, text, "HTML") except ApiTelegramException: pass -@bot.message_handler(commands=['user'], is_admin=True) +@bot.message_handler(commands=["user"], is_admin=True) def search_user(message: types.Message): args = extract_arguments(message.text) if not args: return bot.reply_to( message, - "❌ You must pass some usernames\n\n" - "Usage: /user username1 username2", - parse_mode="HTML" + "❌ You must pass some usernames\n\n" "Usage: /user username1 username2", + parse_mode="HTML", ) usernames = args.split() @@ -2155,11 +2195,12 @@ def search_user(message: types.Message): for username in usernames: db_user = crud.get_user(db, username) if not db_user: - bot.reply_to(message, f'❌ User «{username}» not found.') + bot.reply_to(message, f"❌ User «{username}» not found.") continue user = UserResponse.model_validate(db_user) bot.reply_to( message, get_user_info_text(db_user), parse_mode="html", - reply_markup=BotKeyboard.user_menu(user_info={'status': user.status, 'username': user.username})) + reply_markup=BotKeyboard.user_menu(user_info={"status": user.status, "username": user.username}), + ) diff --git a/app/telegram/handlers/report.py b/app/telegram/handlers/report.py index 6c7575576..a37213e18 100644 --- a/app/telegram/handlers/report.py +++ b/app/telegram/handlers/report.py @@ -1,5 +1,3 @@ -import datetime - from app import logger from app.db.models import User from app.telegram import bot @@ -28,17 +26,17 @@ def report(text: str, chat_id: int = None, parse_mode="html", keyboard=None): def report_new_user( - user_id: int, - username: str, - by: str, - expire_date: int, - data_limit: int, - proxies: list, - has_next_plan: bool, - data_limit_reset_strategy: UserDataLimitResetStrategy, - admin: Admin = None + user_id: int, + username: str, + by: str, + expire_date: int, + data_limit: int, + proxies: list, + has_next_plan: bool, + data_limit_reset_strategy: UserDataLimitResetStrategy, + admin: Admin = None, ): - text = '''\ + text = """\ 🆕 #Created ➖➖➖➖➖➖➖➖➖ Username : {username} @@ -49,7 +47,7 @@ def report_new_user( Has Next Plan : {next_plan} ➖➖➖➖➖➖➖➖➖ Belongs To : {belong_to} -By : #{by}'''.format( +By : #{by}""".format( belong_to=escape_html(admin.username) if admin else None, by=escape_html(by), username=escape_html(username), @@ -63,25 +61,21 @@ def report_new_user( return report( chat_id=admin.telegram_id if admin and admin.telegram_id else None, text=text, - keyboard=BotKeyboard.user_menu({ - 'username': username, - 'id': user_id, - 'status': 'active' - }, with_back=False) + keyboard=BotKeyboard.user_menu({"username": username, "id": user_id, "status": "active"}, with_back=False), ) def report_user_modification( - username: str, - expire_date: int, - data_limit: int, - proxies: list, - has_next_plan: bool, - by: str, - data_limit_reset_strategy: UserDataLimitResetStrategy, - admin: Admin = None + username: str, + expire_date: int, + data_limit: int, + proxies: list, + has_next_plan: bool, + by: str, + data_limit_reset_strategy: UserDataLimitResetStrategy, + admin: Admin = None, ): - text = '''\ + text = """\ ✏️ #Modified ➖➖➖➖➖➖➖➖➖ Username : {username} @@ -93,13 +87,13 @@ def report_user_modification( ➖➖➖➖➖➖➖➖➖ Belongs To : {belong_to} By : #{by}\ - '''.format( + """.format( belong_to=escape_html(admin.username) if admin else None, by=escape_html(by), username=escape_html(username), data_limit=readable_size(data_limit) if data_limit else "Unlimited", expire_date=datetime.fromtimestamp(expire_date).strftime("%H:%M:%S %Y-%m-%d") if expire_date else "Never", - protocols=', '.join([p for p in proxies]), + protocols=", ".join([p for p in proxies]), data_limit_reset_strategy=escape_html(data_limit_reset_strategy), next_plan="True" if has_next_plan else "False", ) @@ -107,41 +101,38 @@ def report_user_modification( return report( chat_id=admin.telegram_id if admin and admin.telegram_id else None, text=text, - keyboard=BotKeyboard.user_menu({'username': username, 'status': 'active'}, with_back=False)) + keyboard=BotKeyboard.user_menu({"username": username, "status": "active"}, with_back=False), + ) def report_user_deletion(username: str, by: str, admin: Admin = None): - text = '''\ + text = """\ 🗑 #Deleted ➖➖➖➖➖➖➖➖➖ Username : {username} ➖➖➖➖➖➖➖➖➖ Belongs To : {belong_to} By : #{by}\ - '''.format( - belong_to=escape_html(admin.username) if admin else None, - by=escape_html(by), - username=escape_html(username) + """.format( + belong_to=escape_html(admin.username) if admin else None, by=escape_html(by), username=escape_html(username) ) return report(chat_id=admin.telegram_id if admin and admin.telegram_id else None, text=text) def report_status_change(username: str, status: str, admin: Admin = None): _status = { - 'active': '✅ #Activated', - 'disabled': '❌ #Disabled', - 'limited': '🪫 #Limited', - 'expired': '🕔 #Expired' + "active": "✅ #Activated", + "disabled": "❌ #Disabled", + "limited": "🪫 #Limited", + "expired": "🕔 #Expired", } - text = '''\ + text = """\ {status} ➖➖➖➖➖➖➖➖➖ Username : {username} Belongs To : {belong_to}\ - '''.format( - belong_to=escape_html(admin.username) if admin else None, - username=escape_html(username), - status=_status[status] + """.format( + belong_to=escape_html(admin.username) if admin else None, username=escape_html(username), status=_status[status] ) return report(chat_id=admin.telegram_id if admin and admin.telegram_id else None, text=text) @@ -155,9 +146,7 @@ def report_user_usage_reset(username: str, by: str, admin: Admin = None): Belongs To : {belong_to} By : #{by}\ """.format( - belong_to=escape_html(admin.username) if admin else None, - by=escape_html(by), - username=escape_html(username) + belong_to=escape_html(admin.username) if admin else None, by=escape_html(by), username=escape_html(username) ) return report(chat_id=admin.telegram_id if admin and admin.telegram_id else None, text=text) @@ -187,9 +176,7 @@ def report_user_subscription_revoked(username: str, by: str, admin: Admin = None Belongs To : {belong_to} By : #{by}\ """.format( - belong_to=escape_html(admin.username) if admin else None, - by=escape_html(by), - username=escape_html(username) + belong_to=escape_html(admin.username) if admin else None, by=escape_html(by), username=escape_html(username) ) return report(chat_id=admin.telegram_id if admin and admin.telegram_id else None, text=text) @@ -207,6 +194,6 @@ def report_login(username: str, password: str, client_ip: str, status: str): username=escape_html(username), password=escape_html(password), status=escape_html(status), - client_ip=escape_html(client_ip) + client_ip=escape_html(client_ip), ) return report(text=text) diff --git a/app/telegram/handlers/user.py b/app/telegram/handlers/user.py index 660e2c063..105cd84ca 100644 --- a/app/telegram/handlers/user.py +++ b/app/telegram/handlers/user.py @@ -10,11 +10,11 @@ bot.add_custom_filter(ChatFilter()) -@bot.message_handler(commands=['usage']) +@bot.message_handler(commands=["usage"]) def usage_command(message): username = extract_arguments(message.text) if not username: - return bot.reply_to(message, 'Usage: `/usage `', parse_mode='MarkdownV2') + return bot.reply_to(message, "Usage: `/usage `", parse_mode="MarkdownV2") with GetDB() as db: dbuser = crud.get_user(db, username) @@ -23,13 +23,9 @@ def usage_command(message): return bot.reply_to(message, "No user found with this username") user = UserResponse.model_validate(dbuser) - statuses = { - 'active': '✅', - 'expired': '🕰', - 'limited': '📵', - 'disabled': '❌'} + statuses = {"active": "✅", "expired": "🕰", "limited": "📵", "disabled": "❌"} - text = f'''\ + text = f"""\ ┌─{statuses[user.status]} Status: {user.status.title()} │ └─Username: {user.username} │ @@ -37,6 +33,6 @@ def usage_command(message): │ └─Data Used: {readable_size(user.used_traffic) if user.used_traffic else "-"} │ └─📅 Expiry Date: {datetime.fromtimestamp(user.expire).date() if user.expire else 'Never'} - └─Days left: {(datetime.fromtimestamp(user.expire or 0) - datetime.now()).days if user.expire else '-'}''' + └─Days left: {(datetime.fromtimestamp(user.expire or 0) - datetime.now()).days if user.expire else '-'}""" - return bot.reply_to(message, text, parse_mode='HTML') + return bot.reply_to(message, text, parse_mode="HTML") diff --git a/app/telegram/utils/__init__.py b/app/telegram/utils/__init__.py index be2712244..6699e5edc 100644 --- a/app/telegram/utils/__init__.py +++ b/app/telegram/utils/__init__.py @@ -2,4 +2,4 @@ def setup() -> None: - custom_filters.setup() \ No newline at end of file + custom_filters.setup() diff --git a/app/telegram/utils/custom_filters.py b/app/telegram/utils/custom_filters.py index 8361fe2d7..83c1b6de0 100644 --- a/app/telegram/utils/custom_filters.py +++ b/app/telegram/utils/custom_filters.py @@ -7,7 +7,7 @@ class IsAdminFilter(AdvancedCustomFilter): - key = 'is_admin' + key = "is_admin" def check(self, message, text): """ @@ -26,6 +26,5 @@ def cb_query_startswith(text: str): return lambda query: query.data.startswith(text) - def setup() -> None: - bot.add_custom_filter(IsAdminFilter()) \ No newline at end of file + bot.add_custom_filter(IsAdminFilter()) diff --git a/app/telegram/utils/keyboard.py b/app/telegram/utils/keyboard.py index a11cc9d2a..e297a7468 100644 --- a/app/telegram/utils/keyboard.py +++ b/app/telegram/utils/keyboard.py @@ -15,47 +15,49 @@ def chunk_dict(data: dict, size: int = 2): class BotKeyboard: - @staticmethod def main_menu(): keyboard = types.InlineKeyboardMarkup() keyboard.add( - types.InlineKeyboardButton(text='🔁 System Info', callback_data='system'), - types.InlineKeyboardButton(text='♻️ Restart Xray', callback_data='restart')) - keyboard.add( - types.InlineKeyboardButton(text='👥 Users', callback_data='users:1'), - types.InlineKeyboardButton(text='✏️ Edit All Users', callback_data='edit_all')) - keyboard.add( - types.InlineKeyboardButton(text='➕ Create User From Template', callback_data='template_add_user')) - keyboard.add( - types.InlineKeyboardButton(text='➕ Bulk User From Template', callback_data='template_add_bulk_user')) + types.InlineKeyboardButton(text="🔁 System Info", callback_data="system"), + types.InlineKeyboardButton(text="♻️ Restart Xray", callback_data="restart"), + ) keyboard.add( - types.InlineKeyboardButton(text='➕ Create User', callback_data='add_user')) + types.InlineKeyboardButton(text="👥 Users", callback_data="users:1"), + types.InlineKeyboardButton(text="✏️ Edit All Users", callback_data="edit_all"), + ) + keyboard.add(types.InlineKeyboardButton(text="➕ Create User From Template", callback_data="template_add_user")) keyboard.add( - types.InlineKeyboardButton(text='➕ Create Bulk User', callback_data='add_bulk_user')) + types.InlineKeyboardButton(text="➕ Bulk User From Template", callback_data="template_add_bulk_user") + ) + keyboard.add(types.InlineKeyboardButton(text="➕ Create User", callback_data="add_user")) + keyboard.add(types.InlineKeyboardButton(text="➕ Create Bulk User", callback_data="add_bulk_user")) return keyboard @staticmethod def edit_all_menu(): keyboard = types.InlineKeyboardMarkup() keyboard.add( - types.InlineKeyboardButton(text='🗑 Delete Expired', callback_data='delete_expired'), - types.InlineKeyboardButton(text='🗑 Delete Limited', callback_data='delete_limited')) + types.InlineKeyboardButton(text="🗑 Delete Expired", callback_data="delete_expired"), + types.InlineKeyboardButton(text="🗑 Delete Limited", callback_data="delete_limited"), + ) keyboard.add( - types.InlineKeyboardButton(text='🔋 Data (➕|➖)', callback_data='add_data'), - types.InlineKeyboardButton(text='📅 Time (➕|➖)', callback_data='add_time')) + types.InlineKeyboardButton(text="🔋 Data (➕|➖)", callback_data="add_data"), + types.InlineKeyboardButton(text="📅 Time (➕|➖)", callback_data="add_time"), + ) keyboard.add( - types.InlineKeyboardButton(text='➕ Add Inbound', callback_data='inbound_add'), - types.InlineKeyboardButton(text='➖ Remove Inbound', callback_data='inbound_remove')) - keyboard.add(types.InlineKeyboardButton(text='🔙 Back', callback_data='cancel')) + types.InlineKeyboardButton(text="➕ Add Inbound", callback_data="inbound_add"), + types.InlineKeyboardButton(text="➖ Remove Inbound", callback_data="inbound_remove"), + ) + keyboard.add(types.InlineKeyboardButton(text="🔙 Back", callback_data="cancel")) return keyboard @staticmethod def inbounds_menu(action, inbounds): keyboard = types.InlineKeyboardMarkup() for inbound in inbounds: - keyboard.add(types.InlineKeyboardButton(text=inbound, callback_data=f'confirm_{action}:{inbound}')) - keyboard.add(types.InlineKeyboardButton(text='🔙 Back', callback_data='cancel')) + keyboard.add(types.InlineKeyboardButton(text=inbound, callback_data=f"confirm_{action}:{inbound}")) + keyboard.add(types.InlineKeyboardButton(text="🔙 Back", callback_data="cancel")) return keyboard @staticmethod @@ -68,25 +70,22 @@ def templates_menu(templates: Dict[str, int], username: str = None): row.append( types.InlineKeyboardButton( text=name, - callback_data=f'template_charge:{_id}:{username}' if username else f"template_add_user:{_id}")) + callback_data=f"template_charge:{_id}:{username}" if username else f"template_add_user:{_id}", + ) + ) keyboard.add(*row) keyboard.add( - types.InlineKeyboardButton( - text='🔙 Back', - callback_data=f'user:{username}' if username else 'cancel')) + types.InlineKeyboardButton(text="🔙 Back", callback_data=f"user:{username}" if username else "cancel") + ) return keyboard @staticmethod - def random_username(template_id: str = ''): + def random_username(template_id: str = ""): keyboard = types.InlineKeyboardMarkup() - keyboard.add(types.InlineKeyboardButton( - text='🔡 Random Username', - callback_data=f'random:{template_id}')) - keyboard.add(types.InlineKeyboardButton( - text='🔙 Cancel', - callback_data='cancel')) + keyboard.add(types.InlineKeyboardButton(text="🔡 Random Username", callback_data=f"random:{template_id}")) + keyboard.add(types.InlineKeyboardButton(text="🔙 Cancel", callback_data="cancel")) return keyboard @staticmethod @@ -94,45 +93,25 @@ def user_menu(user_info, with_back: bool = True, page: int = 1): keyboard = types.InlineKeyboardMarkup() keyboard.add( types.InlineKeyboardButton( - text='❌ Disable' if user_info['status'] == 'active' else '✅ Activate', - callback_data=f"{'suspend' if user_info['status'] == 'active' else 'activate'}:{user_info['username']}" - ), - types.InlineKeyboardButton( - text='🗑 Delete', - callback_data=f"delete:{user_info['username']}" + text="❌ Disable" if user_info["status"] == "active" else "✅ Activate", + callback_data=f"{'suspend' if user_info['status'] == 'active' else 'activate'}:{user_info['username']}", ), + types.InlineKeyboardButton(text="🗑 Delete", callback_data=f"delete:{user_info['username']}"), ) keyboard.add( - types.InlineKeyboardButton( - text='🚫 Revoke Sub', - callback_data=f"revoke_sub:{user_info['username']}"), - types.InlineKeyboardButton( - text='✏️ Edit', - callback_data=f"edit:{user_info['username']}")) + types.InlineKeyboardButton(text="🚫 Revoke Sub", callback_data=f"revoke_sub:{user_info['username']}"), + types.InlineKeyboardButton(text="✏️ Edit", callback_data=f"edit:{user_info['username']}"), + ) keyboard.add( - types.InlineKeyboardButton( - text='📝 Edit Note', - callback_data=f"edit_note:{user_info['username']}"), - types.InlineKeyboardButton( - text='📡 Links', - callback_data=f"links:{user_info['username']}")) + types.InlineKeyboardButton(text="📝 Edit Note", callback_data=f"edit_note:{user_info['username']}"), + types.InlineKeyboardButton(text="📡 Links", callback_data=f"links:{user_info['username']}"), + ) keyboard.add( - types.InlineKeyboardButton( - text='🔁 Reset usage', - callback_data=f"reset_usage:{user_info['username']}" - ), - types.InlineKeyboardButton( - text='🔋 Charge', - callback_data=f"charge:{user_info['username']}" - ) + types.InlineKeyboardButton(text="🔁 Reset usage", callback_data=f"reset_usage:{user_info['username']}"), + types.InlineKeyboardButton(text="🔋 Charge", callback_data=f"charge:{user_info['username']}"), ) if with_back: - keyboard.add( - types.InlineKeyboardButton( - text='🔙 Back', - callback_data=f'users:{page}' - ) - ) + keyboard.add(types.InlineKeyboardButton(text="🔙 Back", callback_data=f"users:{page}")) return keyboard @staticmethod @@ -140,21 +119,10 @@ def user_status_select(): keyboard = types.InlineKeyboardMarkup() keyboard.add( - types.InlineKeyboardButton( - text="🟢 active", - callback_data='status:active' - ), - types.InlineKeyboardButton( - text="🟣 onhold", - callback_data='status:onhold' - ) - ) - keyboard.add( - types.InlineKeyboardButton( - text='🔙 Back', - callback_data='cancel' - ) + types.InlineKeyboardButton(text="🟢 active", callback_data="status:active"), + types.InlineKeyboardButton(text="🟣 onhold", callback_data="status:onhold"), ) + keyboard.add(types.InlineKeyboardButton(text="🔙 Back", callback_data="cancel")) return keyboard @staticmethod @@ -162,44 +130,25 @@ def show_links(username: str): keyboard = types.InlineKeyboardMarkup() keyboard.add( - types.InlineKeyboardButton( - text="🖼 Configs QRcode", - callback_data=f'genqr:configs:{username}' - ), - types.InlineKeyboardButton( - text="🚀 Sub QRcode", - callback_data=f'genqr:sub:{username}' - ) - ) - keyboard.add( - types.InlineKeyboardButton( - text='🔙 Back', - callback_data=f'user:{username}' - ) + types.InlineKeyboardButton(text="🖼 Configs QRcode", callback_data=f"genqr:configs:{username}"), + types.InlineKeyboardButton(text="🚀 Sub QRcode", callback_data=f"genqr:sub:{username}"), ) + keyboard.add(types.InlineKeyboardButton(text="🔙 Back", callback_data=f"user:{username}")) return keyboard @staticmethod def subscription_page(sub_url: str): keyboard = types.InlineKeyboardMarkup() - if sub_url[:4] == 'http': - keyboard.add(types.InlineKeyboardButton( - text='🚀 Subscription Page', - url=sub_url)) + if sub_url[:4] == "http": + keyboard.add(types.InlineKeyboardButton(text="🚀 Subscription Page", url=sub_url)) return keyboard @staticmethod def confirm_action(action: str, username: str = None): keyboard = types.InlineKeyboardMarkup() keyboard.add( - types.InlineKeyboardButton( - text='Yes', - callback_data=f"confirm:{action}:{username}" - ), - types.InlineKeyboardButton( - text='No', - callback_data=f"cancel" - ) + types.InlineKeyboardButton(text="Yes", callback_data=f"confirm:{action}:{username}"), + types.InlineKeyboardButton(text="No", callback_data="cancel"), ) return keyboard @@ -208,30 +157,17 @@ def charge_add_or_reset(username: str, template_id: int): keyboard = types.InlineKeyboardMarkup() keyboard.add( types.InlineKeyboardButton( - text='🔰 Add to current', - callback_data=f"confirm:charge_add:{username}:{template_id}" + text="🔰 Add to current", callback_data=f"confirm:charge_add:{username}:{template_id}" ), - types.InlineKeyboardButton( - text='♻️ Reset', - callback_data=f"confirm:charge_reset:{username}:{template_id}" - )) - keyboard.add( - types.InlineKeyboardButton( - text="Cancel", - callback_data=f'user:{username}' - ) + types.InlineKeyboardButton(text="♻️ Reset", callback_data=f"confirm:charge_reset:{username}:{template_id}"), ) + keyboard.add(types.InlineKeyboardButton(text="Cancel", callback_data=f"user:{username}")) return keyboard @staticmethod def inline_cancel_action(callback_data: str = "cancel"): keyboard = types.InlineKeyboardMarkup() - keyboard.add( - types.InlineKeyboardButton( - text="🔙 Cancel", - callback_data=callback_data - ) - ) + keyboard.add(types.InlineKeyboardButton(text="🔙 Cancel", callback_data=callback_data)) return keyboard @staticmethod @@ -239,107 +175,90 @@ def user_list(users: list, page: int, total_pages: int): keyboard = types.InlineKeyboardMarkup() if len(users) >= 2: users = [p for p in users] - users = [users[i:i + 2] for i in range(0, len(users), 2)] + users = [users[i : i + 2] for i in range(0, len(users), 2)] else: users = [users] for user in users: row = [] for p in user: - status = { - 'active': '✅', - 'expired': '🕰', - 'limited': '📵', - 'disabled': '❌', - 'on_hold': '🔌' - } - row.append(types.InlineKeyboardButton( - text=f"{p.username} ({status[p.status]})", - callback_data=f'user:{p.username}:{page}' - )) + status = {"active": "✅", "expired": "🕰", "limited": "📵", "disabled": "❌", "on_hold": "🔌"} + row.append( + types.InlineKeyboardButton( + text=f"{p.username} ({status[p.status]})", callback_data=f"user:{p.username}:{page}" + ) + ) keyboard.row(*row) # if there is more than one page if total_pages > 1: if page > 1: - keyboard.add( - types.InlineKeyboardButton( - text="⬅️ Previous", - callback_data=f'users:{page - 1}' - ) - ) + keyboard.add(types.InlineKeyboardButton(text="⬅️ Previous", callback_data=f"users:{page - 1}")) if page < total_pages: - keyboard.add( - types.InlineKeyboardButton( - text="➡️ Next", - callback_data=f'users:{page + 1}' - ) - ) - keyboard.add( - types.InlineKeyboardButton( - text='🔙 Back', - callback_data='cancel' - ) - ) + keyboard.add(types.InlineKeyboardButton(text="➡️ Next", callback_data=f"users:{page + 1}")) + keyboard.add(types.InlineKeyboardButton(text="🔙 Back", callback_data="cancel")) return keyboard @staticmethod def select_protocols( - selected_protocols: Dict[str, List[str]], - action: Literal["edit", "create", "create_from_template"], - username: str = None, - data_limit: float = None, - expire_date: dt = None, - expire_on_hold_duration: int = None, - expire_on_hold_timeout: dt = None + selected_protocols: Dict[str, List[str]], + action: Literal["edit", "create", "create_from_template"], + username: str = None, + data_limit: float = None, + expire_date: dt = None, + expire_on_hold_duration: int = None, + expire_on_hold_timeout: dt = None, ): keyboard = types.InlineKeyboardMarkup() if action == "edit": - keyboard.add(types.InlineKeyboardButton(text="⚠️ Data Limit:", callback_data=f"help_edit")) + keyboard.add(types.InlineKeyboardButton(text="⚠️ Data Limit:", callback_data="help_edit")) keyboard.add( types.InlineKeyboardButton( - text=f"{readable_size(data_limit) if data_limit else 'Unlimited'}", - callback_data=f"help_edit" + text=f"{readable_size(data_limit) if data_limit else 'Unlimited'}", callback_data="help_edit" ), - types.InlineKeyboardButton(text="✏️ Edit", callback_data=f"edit_user:{username}:data")) + types.InlineKeyboardButton(text="✏️ Edit", callback_data=f"edit_user:{username}:data"), + ) if expire_on_hold_duration: - keyboard.add(types.InlineKeyboardButton(text="⏳ Duration:", callback_data=f"edit_user:{username}:expire")) + keyboard.add( + types.InlineKeyboardButton(text="⏳ Duration:", callback_data=f"edit_user:{username}:expire") + ) keyboard.add( types.InlineKeyboardButton( text=f"{int(expire_on_hold_duration / 24 / 60 / 60)} روز", - callback_data=f"edit_user:{username}:expire" + callback_data=f"edit_user:{username}:expire", ), - types.InlineKeyboardButton(text="✏️ Edit", callback_data=f"edit_user:{username}:expire")) + types.InlineKeyboardButton(text="✏️ Edit", callback_data=f"edit_user:{username}:expire"), + ) keyboard.add( types.InlineKeyboardButton( - text="🌀 Auto enable at:", - callback_data=f"edit_user:{username}:expire_on_hold_timeout" + text="🌀 Auto enable at:", callback_data=f"edit_user:{username}:expire_on_hold_timeout" ) ) keyboard.add( types.InlineKeyboardButton( text=f"{expire_on_hold_timeout.strftime('%Y-%m-%d') if expire_on_hold_timeout else 'Never'}", - callback_data=f"edit_user:{username}:expire_on_hold_timeout"), + callback_data=f"edit_user:{username}:expire_on_hold_timeout", + ), types.InlineKeyboardButton( - text="✏️ Edit", - callback_data=f"edit_user:{username}:expire_on_hold_timeout" - ) + text="✏️ Edit", callback_data=f"edit_user:{username}:expire_on_hold_timeout" + ), ) else: - keyboard.add(types.InlineKeyboardButton(text="📅 Expire Date:", callback_data=f"help_edit")) + keyboard.add(types.InlineKeyboardButton(text="📅 Expire Date:", callback_data="help_edit")) keyboard.add( types.InlineKeyboardButton( text=f"{expire_date.strftime('%Y-%m-%d') if expire_date else 'Never'}", - callback_data=f"help_edit" + callback_data="help_edit", ), - types.InlineKeyboardButton(text="✏️ Edit", callback_data=f"edit_user:{username}:expire")) + types.InlineKeyboardButton(text="✏️ Edit", callback_data=f"edit_user:{username}:expire"), + ) - if action != 'create_from_template': + if action != "create_from_template": for protocol, inbounds in xray.config.inbounds_by_protocol.items(): keyboard.add( types.InlineKeyboardButton( text=f"🌐 {protocol.upper()} {'✅' if protocol in selected_protocols else '❌'}", - callback_data=f'select_protocol:{protocol}:{action}' + callback_data=f"select_protocol:{protocol}:{action}", ) ) if protocol in selected_protocols: @@ -347,20 +266,18 @@ def select_protocols( keyboard.add( types.InlineKeyboardButton( text=f"«{inbound['tag']}» {'✅' if inbound['tag'] in selected_protocols[protocol] else '❌'}", - callback_data=f'select_inbound:{inbound["tag"]}:{action}' + callback_data=f'select_inbound:{inbound["tag"]}:{action}', ) ) keyboard.add( types.InlineKeyboardButton( - text='Done', - callback_data='confirm:edit_user' if action == "edit" else 'confirm:add_user' + text="Done", callback_data="confirm:edit_user" if action == "edit" else "confirm:add_user" ) ) keyboard.add( types.InlineKeyboardButton( - text='Cancel', - callback_data=f'user:{username}' if action == "edit" else 'cancel' + text="Cancel", callback_data=f"user:{username}" if action == "edit" else "cancel" ) ) diff --git a/app/telegram/utils/shared.py b/app/telegram/utils/shared.py index ac6b089c8..4b7c528e1 100644 --- a/app/telegram/utils/shared.py +++ b/app/telegram/utils/shared.py @@ -52,7 +52,7 @@ def get_user_info_text(db_user: User) -> str: used_traffic = readable_size(user.used_traffic) if user.used_traffic else "-" data_left = readable_size(user.data_limit - user.used_traffic) if user.data_limit else "-" on_hold_timeout = user.on_hold_timeout.strftime("%Y-%m-%d") if user.on_hold_timeout else "-" - on_hold_duration = user.on_hold_expire_duration // (24*60*60) if user.on_hold_expire_duration else None + on_hold_duration = user.on_hold_expire_duration // (24 * 60 * 60) if user.on_hold_expire_duration else None expiry_date = dt.fromtimestamp(user.expire).date() if user.expire else "Never" time_left = time_to_string(dt.fromtimestamp(user.expire)) if user.expire else "-" online_at = time_to_string(user.online_at) if user.online_at else "-" @@ -86,8 +86,11 @@ def get_template_info_text(template: UserTemplate): protocols += f"\n├─ {p.upper()}\n" protocols += "├───" + ", ".join([f"{i}" for i in inbounds]) data_limit = readable_size(template.data_limit) if template.data_limit else "Unlimited" - expire = ((dt.now() + relativedelta(seconds=template.expire_duration)) - .strftime("%Y-%m-%d")) if template.expire_duration else "Never" + expire = ( + ((dt.now() + relativedelta(seconds=template.expire_duration)).strftime("%Y-%m-%d")) + if template.expire_duration + else "Never" + ) text = f""" 📊 Template Info: ID: {template.id} @@ -100,6 +103,6 @@ def get_template_info_text(template: UserTemplate): def get_number_at_end(username: str): - n = re.search(r'(\d+)$', username) + n = re.search(r"(\d+)$", username) if n: return n.group(1) diff --git a/app/templates/__init__.py b/app/templates/__init__.py index 8d0c2ba14..9a8e77ef4 100644 --- a/app/templates/__init__.py +++ b/app/templates/__init__.py @@ -14,7 +14,7 @@ env = jinja2.Environment(loader=jinja2.FileSystemLoader(template_directories)) env.filters.update(CUSTOM_FILTERS) -env.globals['now'] = datetime.utcnow +env.globals["now"] = datetime.utcnow def render_template(template: str, context: Union[dict, None] = None) -> str: diff --git a/app/templates/filters.py b/app/templates/filters.py index 70af740b9..b9b01da2c 100644 --- a/app/templates/filters.py +++ b/app/templates/filters.py @@ -24,7 +24,7 @@ def only_keys(obj, *target_keys): def datetimeformat(dt): if isinstance(dt, int): dt = datetime.fromtimestamp(dt) - formatted_datetime = dt.strftime('%Y-%m-%d %H:%M:%S') + formatted_datetime = dt.strftime("%Y-%m-%d %H:%M:%S") return formatted_datetime @@ -37,5 +37,5 @@ def env_override(value, key): "except": exclude_keys, "only": only_keys, "datetime": datetimeformat, - "bytesformat": readable_size + "bytesformat": readable_size, } diff --git a/app/utils/concurrency.py b/app/utils/concurrency.py index c55c27773..9c8c7a623 100644 --- a/app/utils/concurrency.py +++ b/app/utils/concurrency.py @@ -8,6 +8,7 @@ def threaded_function(func): def wrapper(*args, **kwargs): thread = Thread(target=func, args=args, daemon=True, kwargs=kwargs) thread.start() + return wrapper diff --git a/app/utils/crypto.py b/app/utils/crypto.py index 82d239c53..ed8236bca 100644 --- a/app/utils/crypto.py +++ b/app/utils/crypto.py @@ -20,14 +20,11 @@ def generate_certificate(): cert = crypto.X509() cert.get_subject().CN = "Gozargah" cert.gmtime_adj_notBefore(0) - cert.gmtime_adj_notAfter(100*365*24*60*60) + cert.gmtime_adj_notAfter(100 * 365 * 24 * 60 * 60) cert.set_issuer(cert.get_subject()) cert.set_pubkey(k) - cert.sign(k, 'sha512') + cert.sign(k, "sha512") cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8") key_pem = crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8") - return { - "cert": cert_pem, - "key": key_pem - } + return {"cert": cert_pem, "key": key_pem} diff --git a/app/utils/helpers.py b/app/utils/helpers.py index 4318480f9..b6d388991 100644 --- a/app/utils/helpers.py +++ b/app/utils/helpers.py @@ -12,7 +12,7 @@ def calculate_expiration_days(expire: dt) -> int: def yml_uuid_representer(dumper, data): - return dumper.represent_scalar('tag:yaml.org,2002:str', str(data)) + return dumper.represent_scalar("tag:yaml.org,2002:str", str(data)) class UUIDEncoder(json.JSONEncoder): diff --git a/app/utils/jwt.py b/app/utils/jwt.py index b2202cf94..6d1f93a18 100644 --- a/app/utils/jwt.py +++ b/app/utils/jwt.py @@ -14,6 +14,7 @@ @lru_cache(maxsize=None) def get_secret_key(): from app.db import GetDB, get_jwt_secret_key + with GetDB() as db: return get_jwt_secret_key(db) @@ -32,10 +33,10 @@ def get_admin_payload(token: str) -> Union[dict, None]: payload = jwt.decode(token, get_secret_key(), algorithms=["HS256"]) username: str = payload.get("sub") access: str = payload.get("access") - if not username or access not in ('admin', 'sudo'): + if not username or access not in ("admin", "sudo"): return try: - created_at = datetime.utcfromtimestamp(payload['iat']) + created_at = datetime.utcfromtimestamp(payload["iat"]) except KeyError: created_at = None @@ -45,14 +46,11 @@ def get_admin_payload(token: str) -> Union[dict, None]: def create_subscription_token(username: str) -> str: - data = username + ',' + str(ceil(time.time())) - data_b64_str = b64encode(data.encode('utf-8'), altchars=b'-_').decode('utf-8').rstrip('=') + data = username + "," + str(ceil(time.time())) + data_b64_str = b64encode(data.encode("utf-8"), altchars=b"-_").decode("utf-8").rstrip("=") data_b64_sign = b64encode( - sha256( - (data_b64_str+get_secret_key()).encode('utf-8') - ).digest(), - altchars=b'-_' - ).decode('utf-8')[:10] + sha256((data_b64_str + get_secret_key()).encode("utf-8")).digest(), altchars=b"-_" + ).decode("utf-8")[:10] data_final = data_b64_str + data_b64_sign return data_final @@ -65,7 +63,7 @@ def get_subscription_payload(token: str) -> Union[dict, None]: if token.startswith("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."): payload = jwt.decode(token, get_secret_key(), algorithms=["HS256"]) if payload.get("access") == "subscription": - return {"username": payload['sub'], "created_at": datetime.utcfromtimestamp(payload['iat'])} + return {"username": payload["sub"], "created_at": datetime.utcfromtimestamp(payload["iat"])} else: return else: @@ -73,16 +71,19 @@ def get_subscription_payload(token: str) -> Union[dict, None]: u_signature = token[-10:] try: u_token_dec = b64decode( - (u_token.encode('utf-8') + b'=' * (-len(u_token.encode('utf-8')) % 4)), - altchars=b'-_', validate=True) - u_token_dec_str = u_token_dec.decode('utf-8') - except: + (u_token.encode("utf-8") + b"=" * (-len(u_token.encode("utf-8")) % 4)), + altchars=b"-_", + validate=True, + ) + u_token_dec_str = u_token_dec.decode("utf-8") + except Exception: return - u_token_resign = b64encode(sha256((u_token+get_secret_key()).encode('utf-8') - ).digest(), altchars=b'-_').decode('utf-8')[:10] + u_token_resign = b64encode( + sha256((u_token + get_secret_key()).encode("utf-8")).digest(), altchars=b"-_" + ).decode("utf-8")[:10] if u_signature == u_token_resign: - u_username = u_token_dec_str.split(',')[0] - u_created_at = int(u_token_dec_str.split(',')[1]) + u_username = u_token_dec_str.split(",")[0] + u_created_at = int(u_token_dec_str.split(",")[1]) return {"username": u_username, "created_at": datetime.utcfromtimestamp(u_created_at)} else: return diff --git a/app/utils/report.py b/app/utils/report.py index 8f947f20d..16dbbcf1a 100644 --- a/app/utils/report.py +++ b/app/utils/report.py @@ -2,16 +2,26 @@ from typing import Optional from app import telegram -from app.db import Session, create_notification_reminder, get_admin_by_id, GetDB -from app.db.models import UserStatus, User +from app.db import Session, create_notification_reminder +from app.db.models import UserStatus from app.models.admin import Admin from app.models.user import ReminderType, UserResponse -from app.utils.notification import (Notification, ReachedDaysLeft, - ReachedUsagePercent, UserCreated, UserDataResetByNext, - UserDataUsageReset, UserDeleted, - UserDisabled, UserEnabled, UserExpired, - UserLimited, UserSubscriptionRevoked, - UserUpdated, notify) +from app.utils.notification import ( + Notification, + ReachedDaysLeft, + ReachedUsagePercent, + UserCreated, + UserDataResetByNext, + UserDataUsageReset, + UserDeleted, + UserDisabled, + UserEnabled, + UserExpired, + UserLimited, + UserSubscriptionRevoked, + UserUpdated, + notify, +) from app import discord from config import ( @@ -23,12 +33,13 @@ NOTIFY_USER_SUB_REVOKED, NOTIFY_IF_DATA_USAGE_PERCENT_REACHED, NOTIFY_IF_DAYS_LEFT_REACHED, - NOTIFY_LOGIN + NOTIFY_LOGIN, ) def status_change( - username: str, status: UserStatus, user: UserResponse, user_admin: Admin = None, by: Admin = None) -> None: + username: str, status: UserStatus, user: UserResponse, user_admin: Admin = None, by: Admin = None +) -> None: if NOTIFY_STATUS_CHANGE: try: telegram.report_status_change(username, status, user_admin) @@ -60,7 +71,7 @@ def user_created(user: UserResponse, user_id: int, by: Admin, user_admin: Admin proxies=user.proxies, has_next_plan=user.next_plan is not None, data_limit_reset_strategy=user.data_limit_reset_strategy, - admin=user_admin + admin=user_admin, ) except Exception: pass @@ -74,7 +85,7 @@ def user_created(user: UserResponse, user_id: int, by: Admin, user_admin: Admin proxies=user.proxies, has_next_plan=user.next_plan is not None, data_limit_reset_strategy=user.data_limit_reset_strategy, - admin=user_admin + admin=user_admin, ) except Exception: pass @@ -91,7 +102,7 @@ def user_updated(user: UserResponse, by: Admin, user_admin: Admin = None) -> Non by=by.username, has_next_plan=user.next_plan is not None, data_limit_reset_strategy=user.data_limit_reset_strategy, - admin=user_admin + admin=user_admin, ) except Exception: pass @@ -105,7 +116,7 @@ def user_updated(user: UserResponse, by: Admin, user_admin: Admin = None) -> Non by=by.username, has_next_plan=user.next_plan is not None, data_limit_reset_strategy=user.data_limit_reset_strategy, - admin=user_admin + admin=user_admin, ) except Exception: pass @@ -127,20 +138,12 @@ def user_deleted(username: str, by: Admin, user_admin: Admin = None) -> None: def user_data_usage_reset(user: UserResponse, by: Admin, user_admin: Admin = None) -> None: if NOTIFY_USER_DATA_USED_RESET: try: - telegram.report_user_usage_reset( - username=user.username, - by=by.username, - admin=user_admin - ) + telegram.report_user_usage_reset(username=user.username, by=by.username, admin=user_admin) except Exception: pass notify(UserDataUsageReset(username=user.username, action=Notification.Type.data_usage_reset, by=by, user=user)) try: - discord.report_user_usage_reset( - username=user.username, - by=by.username, - admin=user_admin - ) + discord.report_user_usage_reset(username=user.username, by=by.username, admin=user_admin) except Exception: pass @@ -148,18 +151,12 @@ def user_data_usage_reset(user: UserResponse, by: Admin, user_admin: Admin = Non def user_data_reset_by_next(user: UserResponse, user_admin: Admin = None) -> None: if NOTIFY_USER_DATA_USED_RESET: try: - telegram.report_user_data_reset_by_next( - user=user, - admin=user_admin - ) + telegram.report_user_data_reset_by_next(user=user, admin=user_admin) except Exception: pass notify(UserDataResetByNext(username=user.username, action=Notification.Type.data_reset_by_next, user=user)) try: - discord.report_user_data_reset_by_next( - user=user, - admin=user_admin - ) + discord.report_user_data_reset_by_next(user=user, admin=user_admin) except Exception: pass @@ -167,39 +164,41 @@ def user_data_reset_by_next(user: UserResponse, user_admin: Admin = None) -> Non def user_subscription_revoked(user: UserResponse, by: Admin, user_admin: Admin = None) -> None: if NOTIFY_USER_SUB_REVOKED: try: - telegram.report_user_subscription_revoked( - username=user.username, - by=by.username, - admin=user_admin - ) + telegram.report_user_subscription_revoked(username=user.username, by=by.username, admin=user_admin) except Exception: pass - notify(UserSubscriptionRevoked(username=user.username, - action=Notification.Type.subscription_revoked, by=by, user=user)) - try: - discord.report_user_subscription_revoked( - username=user.username, - by=by.username, - admin=user_admin + notify( + UserSubscriptionRevoked( + username=user.username, action=Notification.Type.subscription_revoked, by=by, user=user ) + ) + try: + discord.report_user_subscription_revoked(username=user.username, by=by.username, admin=user_admin) except Exception: pass def data_usage_percent_reached( - db: Session, percent: float, user: UserResponse, user_id: int, expire: Optional[dt] = None, threshold: Optional[int] = None) -> None: + db: Session, + percent: float, + user: UserResponse, + user_id: int, + expire: Optional[dt] = None, + threshold: Optional[int] = None, +) -> None: if NOTIFY_IF_DATA_USAGE_PERCENT_REACHED: notify(ReachedUsagePercent(username=user.username, user=user, used_percent=percent)) - create_notification_reminder(db, ReminderType.data_usage, - expires_at=expire if expire else None, user_id=user_id, threshold=threshold) + create_notification_reminder( + db, ReminderType.data_usage, expires_at=expire if expire else None, user_id=user_id, threshold=threshold + ) def expire_days_reached(db: Session, days: int, user: UserResponse, user_id: int, expire: dt, threshold=None) -> None: notify(ReachedDaysLeft(username=user.username, user=user, days_left=days)) if NOTIFY_IF_DAYS_LEFT_REACHED: create_notification_reminder( - db, ReminderType.expiration_date, expires_at=expire, - user_id=user_id, threshold=threshold) + db, ReminderType.expiration_date, expires_at=expire, user_id=user_id, threshold=threshold + ) def login(username: str, password: str, client_ip: str, success: bool) -> None: @@ -209,7 +208,7 @@ def login(username: str, password: str, client_ip: str, success: bool) -> None: username=username, password=password, client_ip=client_ip, - status="✅ Success" if success else "❌ Failed" + status="✅ Success" if success else "❌ Failed", ) except Exception: pass @@ -218,7 +217,7 @@ def login(username: str, password: str, client_ip: str, success: bool) -> None: username=username, password=password, client_ip=client_ip, - status="✅ Success" if success else "❌ Failed" + status="✅ Success" if success else "❌ Failed", ) except Exception: pass diff --git a/app/utils/responses.py b/app/utils/responses.py index e624a787c..7611ae9da 100644 --- a/app/utils/responses.py +++ b/app/utils/responses.py @@ -31,9 +31,7 @@ class Conflict(HTTPException): "headers": { "WWW-Authenticate": { "description": "Authentication type", - "schema": { - "type": "string" - }, + "schema": {"type": "string"}, }, }, } diff --git a/app/utils/system.py b/app/utils/system.py index b7a40eacb..1f2aa7362 100644 --- a/app/utils/system.py +++ b/app/utils/system.py @@ -12,14 +12,14 @@ @dataclass -class MemoryStat(): +class MemoryStat: total: int used: int free: int @dataclass -class CPUStat(): +class CPUStat: cores: int percent: float @@ -66,8 +66,7 @@ class RealtimeBandwidthStat: outgoing_packets: int -rt_bw = RealtimeBandwidth( - incoming_bytes=0, outgoing_bytes=0, incoming_packets=0, outgoing_packets=0) +rt_bw = RealtimeBandwidth(incoming_bytes=0, outgoing_bytes=0, incoming_packets=0, outgoing_packets=0) # sample time is 2 seconds, values lower than this may not produce good results @@ -80,8 +79,14 @@ def record_realtime_bandwidth() -> None: sample_time = rt_bw.last_perf_counter - last_perf_counter rt_bw.incoming_bytes, rt_bw.bytes_recv = round((io.bytes_recv - rt_bw.bytes_recv) / sample_time), io.bytes_recv rt_bw.outgoing_bytes, rt_bw.bytes_sent = round((io.bytes_sent - rt_bw.bytes_sent) / sample_time), io.bytes_sent - rt_bw.incoming_packets, rt_bw.packets_recv = round((io.packets_recv - rt_bw.packets_recv) / sample_time), io.packets_recv - rt_bw.outgoing_packets, rt_bw.packets_sent = round((io.packets_sent - rt_bw.packets_sent) / sample_time), io.packets_sent + rt_bw.incoming_packets, rt_bw.packets_recv = ( + round((io.packets_recv - rt_bw.packets_recv) / sample_time), + io.packets_recv, + ) + rt_bw.outgoing_packets, rt_bw.packets_sent = ( + round((io.packets_sent - rt_bw.packets_sent) / sample_time), + io.packets_sent, + ) def realtime_bandwidth() -> RealtimeBandwidthStat: @@ -100,7 +105,7 @@ def random_password() -> str: def check_port(port: int) -> bool: s = socket.socket() try: - s.connect(('127.0.0.1', port)) + s.connect(("127.0.0.1", port)) return True except socket.error: return False @@ -110,22 +115,22 @@ def check_port(port: int) -> bool: def get_public_ip(): try: - resp = requests.get('http://api4.ipify.org/', timeout=5).text.strip() + resp = requests.get("http://api4.ipify.org/", timeout=5).text.strip() if ipaddress.IPv4Address(resp).is_global: return resp - except: + except Exception: pass try: - resp = requests.get('http://ipv4.icanhazip.com/', timeout=5).text.strip() + resp = requests.get("http://ipv4.icanhazip.com/", timeout=5).text.strip() if ipaddress.IPv4Address(resp).is_global: return resp - except: + except Exception: pass try: requests.packages.urllib3.util.connection.HAS_IPV6 = False - resp = requests.get('https://ifconfig.io/ip', timeout=5).text.strip() + resp = requests.get("https://ifconfig.io/ip", timeout=5).text.strip() if ipaddress.IPv4Address(resp).is_global: return resp except requests.exceptions.RequestException: @@ -135,7 +140,7 @@ def get_public_ip(): try: sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - sock.connect(('8.8.8.8', 80)) + sock.connect(("8.8.8.8", 80)) resp = sock.getsockname()[0] if ipaddress.IPv4Address(resp).is_global: return resp @@ -144,25 +149,25 @@ def get_public_ip(): finally: sock.close() - return '127.0.0.1' + return "127.0.0.1" def get_public_ipv6(): try: - resp = requests.get('http://api6.ipify.org/', timeout=5).text.strip() + resp = requests.get("http://api6.ipify.org/", timeout=5).text.strip() if ipaddress.IPv6Address(resp).is_global: - return '[%s]' % resp - except: + return "[%s]" % resp + except Exception: pass try: - resp = requests.get('http://ipv6.icanhazip.com/', timeout=5).text.strip() + resp = requests.get("http://ipv6.icanhazip.com/", timeout=5).text.strip() if ipaddress.IPv6Address(resp).is_global: - return '[%s]' % resp - except: + return "[%s]" % resp + except Exception: pass - return '[::1]' + return "[::1]" def readable_size(size_bytes): @@ -172,4 +177,4 @@ def readable_size(size_bytes): i = int(math.floor(math.log(size_bytes, 1024))) p = math.pow(1024, i) s = round(size_bytes / p, 2) - return f'{s} {size_name[i]}' + return f"{s} {size_name[i]}" diff --git a/app/xray/__init__.py b/app/xray/__init__.py index c33a073a7..f3f79cdf7 100644 --- a/app/xray/__init__.py +++ b/app/xray/__init__.py @@ -45,18 +45,16 @@ def hosts(storage: dict): storage[host.id] = { "remark": host.remark, "inbound_tag": host.inbound_tag, - "address": [i.strip() for i in host.address.split(',')] if host.address else [], + "address": [i.strip() for i in host.address.split(",")] if host.address else [], "port": host.port, "path": host.path if host.path else None, - "sni": [i.strip() for i in host.sni.split(',')] if host.sni else [], - "host": [i.strip() for i in host.host.split(',')] if host.host else [], + "sni": [i.strip() for i in host.sni.split(",")] if host.sni else [], + "host": [i.strip() for i in host.host.split(",")] if host.host else [], "alpn": host.alpn.value, "fingerprint": host.fingerprint.value, # None means the tls is not specified by host itself and # complies with its inbound's settings. - "tls": None - if host.security == ProxyHostSecurity.inbound_default - else host.security.value, + "tls": None if host.security == ProxyHostSecurity.inbound_default else host.security.value, "allowinsecure": host.allowinsecure, "mux_enable": host.mux_enable, "fragment_setting": host.fragment_setting, diff --git a/app/xray/config.py b/app/xray/config.py index b428c2d67..8d0296248 100644 --- a/app/xray/config.py +++ b/app/xray/config.py @@ -27,21 +27,18 @@ def merge_dicts(a, b): # B will override A dictionary key and values class XRayConfig(dict): - def __init__(self, - config: Union[dict, str, PosixPath] = {}, - api_host: str = "127.0.0.1", - api_port: int = 8080): + def __init__(self, config: Union[dict, str, PosixPath] = {}, api_host: str = "127.0.0.1", api_port: int = 8080): if isinstance(config, str): try: # considering string as json config = commentjson.loads(config) except (json.JSONDecodeError, ValueError): # considering string as file path - with open(config, 'r') as file: + with open(config, "r") as file: config = commentjson.loads(file.read()) if isinstance(config, PosixPath): - with open(config, 'r') as file: + with open(config, "r") as file: config = commentjson.loads(file.read()) if isinstance(config, dict): @@ -69,28 +66,16 @@ def _apply_api(self): api_inbound["port"] = self.api_port return - self["api"] = { - "services": [ - "HandlerService", - "StatsService", - "LoggerService" - ], - "tag": "API" - } + self["api"] = {"services": ["HandlerService", "StatsService", "LoggerService"], "tag": "API"} self["stats"] = {} forced_policies = { - "levels": { - "0": { - "statsUserUplink": True, - "statsUserDownlink": True - } - }, + "levels": {"0": {"statsUserUplink": True, "statsUserDownlink": True}}, "system": { "statsInboundDownlink": False, "statsInboundUplink": False, "statsOutboundDownlink": True, - "statsOutboundUplink": True - } + "statsOutboundUplink": True, + }, } if self.get("policy"): self["policy"] = merge_dicts(self.get("policy"), forced_policies) @@ -100,10 +85,8 @@ def _apply_api(self): "listen": self.api_host, "port": self.api_port, "protocol": "dokodemo-door", - "settings": { - "address": self.api_host - }, - "tag": "API_INBOUND" + "settings": {"address": self.api_host}, + "tag": "API_INBOUND", } try: self["inbounds"].insert(0, inbound) @@ -111,13 +94,7 @@ def _apply_api(self): self["inbounds"] = [] self["inbounds"].insert(0, inbound) - rule = { - "inboundTag": [ - "API_INBOUND" - ], - "outboundTag": "API", - "type": "field" - } + rule = {"inboundTag": ["API_INBOUND"], "outboundTag": "API", "type": "field"} try: self["routing"]["rules"].insert(0, rule) except KeyError: @@ -131,227 +108,227 @@ def _validate(self): if not self.get("outbounds"): raise ValueError("config doesn't have outbounds") - for inbound in self['inbounds']: + for inbound in self["inbounds"]: if not inbound.get("tag"): raise ValueError("all inbounds must have a unique tag") - if ',' in inbound.get("tag"): + if "," in inbound.get("tag"): raise ValueError("character «,» is not allowed in inbound tag") - for outbound in self['outbounds']: + for outbound in self["outbounds"]: if not outbound.get("tag"): raise ValueError("all outbounds must have a unique tag") def _resolve_inbounds(self): - for inbound in self['inbounds']: - if not inbound['protocol'] in ProxyTypes._value2member_map_: + for inbound in self["inbounds"]: + if inbound["protocol"] not in ProxyTypes._value2member_map_: continue - if inbound['tag'] in XRAY_EXCLUDE_INBOUND_TAGS: + if inbound["tag"] in XRAY_EXCLUDE_INBOUND_TAGS: continue - if not inbound.get('settings'): - inbound['settings'] = {} - if not inbound['settings'].get('clients'): - inbound['settings']['clients'] = [] + if not inbound.get("settings"): + inbound["settings"] = {} + if not inbound["settings"].get("clients"): + inbound["settings"]["clients"] = [] settings = { "tag": inbound["tag"], "protocol": inbound["protocol"], "port": None, "network": "tcp", - "tls": 'none', + "tls": "none", "sni": [], "host": [], "path": "", "header_type": "", - "is_fallback": False + "is_fallback": False, } # port settings try: - settings['port'] = inbound['port'] + settings["port"] = inbound["port"] except KeyError: if self._fallbacks_inbound: try: - settings['port'] = self._fallbacks_inbound['port'] - settings['is_fallback'] = True + settings["port"] = self._fallbacks_inbound["port"] + settings["is_fallback"] = True except KeyError: raise ValueError("fallbacks inbound doesn't have port") # stream settings - if stream := inbound.get('streamSettings'): - net = stream.get('network', 'tcp') + if stream := inbound.get("streamSettings"): + net = stream.get("network", "tcp") net_settings = stream.get(f"{net}Settings", {}) security = stream.get("security") tls_settings = stream.get(f"{security}Settings") - if settings['is_fallback'] is True: + if settings["is_fallback"] is True: # probably this is a fallback - security = self._fallbacks_inbound.get( - 'streamSettings', {}).get('security') - tls_settings = self._fallbacks_inbound.get( - 'streamSettings', {}).get(f"{security}Settings", {}) + security = self._fallbacks_inbound.get("streamSettings", {}).get("security") + tls_settings = self._fallbacks_inbound.get("streamSettings", {}).get(f"{security}Settings", {}) - settings['network'] = net + settings["network"] = net - if security == 'tls': + if security == "tls": # settings['fp'] # settings['alpn'] - settings['tls'] = 'tls' - for certificate in tls_settings.get('certificates', []): - + settings["tls"] = "tls" + for certificate in tls_settings.get("certificates", []): if certificate.get("certificateFile", None): - with open(certificate['certificateFile'], 'rb') as file: + with open(certificate["certificateFile"], "rb") as file: cert = file.read() - settings['sni'].extend(get_cert_SANs(cert)) + settings["sni"].extend(get_cert_SANs(cert)) if certificate.get("certificate", None): - cert = certificate['certificate'] + cert = certificate["certificate"] if isinstance(cert, list): - cert = '\n'.join(cert) + cert = "\n".join(cert) if isinstance(cert, str): cert = cert.encode() - settings['sni'].extend(get_cert_SANs(cert)) + settings["sni"].extend(get_cert_SANs(cert)) - elif security == 'reality': - settings['fp'] = 'chrome' - settings['tls'] = 'reality' - settings['sni'] = tls_settings.get('serverNames', []) + elif security == "reality": + settings["fp"] = "chrome" + settings["tls"] = "reality" + settings["sni"] = tls_settings.get("serverNames", []) try: - settings['pbk'] = tls_settings['publicKey'] + settings["pbk"] = tls_settings["publicKey"] except KeyError: - pvk = tls_settings.get('privateKey') + pvk = tls_settings.get("privateKey") if not pvk: - raise ValueError( - f"You need to provide privateKey in realitySettings of {inbound['tag']}") + raise ValueError(f"You need to provide privateKey in realitySettings of {inbound['tag']}") try: from app.xray import core + x25519 = core.get_x25519(pvk) - settings['pbk'] = x25519['public_key'] + settings["pbk"] = x25519["public_key"] except ImportError: pass - if not settings.get('pbk'): - raise ValueError( - f"You need to provide publicKey in realitySettings of {inbound['tag']}") + if not settings.get("pbk"): + raise ValueError(f"You need to provide publicKey in realitySettings of {inbound['tag']}") try: - settings['sids'] = tls_settings.get('shortIds') - settings['sids'][0] # check if there is any shortIds + settings["sids"] = tls_settings.get("shortIds") + settings["sids"][0] # check if there is any shortIds except (IndexError, TypeError): raise ValueError( - f"You need to define at least one shortID in realitySettings of {inbound['tag']}") + f"You need to define at least one shortID in realitySettings of {inbound['tag']}" + ) try: - settings['spx'] = tls_settings.get('SpiderX') - except: - settings['spx'] = "" + settings["spx"] = tls_settings.get("SpiderX") + except Exception: + settings["spx"] = "" - if net in ('tcp', 'raw'): - header = net_settings.get('header', {}) - request = header.get('request', {}) - path = request.get('path') - host = request.get('headers', {}).get('Host') + if net in ("tcp", "raw"): + header = net_settings.get("header", {}) + request = header.get("request", {}) + path = request.get("path") + host = request.get("headers", {}).get("Host") - settings['header_type'] = header.get('type', '') + settings["header_type"] = header.get("type", "") if isinstance(path, str) or isinstance(host, str): - raise ValueError(f"Settings of {inbound['tag']} for path and host must be list, not str\n" - "https://xtls.github.io/config/transports/tcp.html#httpheaderobject") + raise ValueError( + f"Settings of {inbound['tag']} for path and host must be list, not str\n" + "https://xtls.github.io/config/transports/tcp.html#httpheaderobject" + ) if path and isinstance(path, list): - settings['path'] = path[0] + settings["path"] = path[0] if host and isinstance(host, list): - settings['host'] = host + settings["host"] = host - elif net == 'ws': - path = net_settings.get('path', '') - host = net_settings.get('host', '') or net_settings.get('headers', {}).get('Host') + elif net == "ws": + path = net_settings.get("path", "") + host = net_settings.get("host", "") or net_settings.get("headers", {}).get("Host") - settings['header_type'] = '' + settings["header_type"] = "" if isinstance(path, list) or isinstance(host, list): - raise ValueError(f"Settings of {inbound['tag']} for path and host must be str, not list\n" - "https://xtls.github.io/config/transports/websocket.html#websocketobject") + raise ValueError( + f"Settings of {inbound['tag']} for path and host must be str, not list\n" + "https://xtls.github.io/config/transports/websocket.html#websocketobject" + ) if isinstance(path, str): - settings['path'] = path + settings["path"] = path if isinstance(host, str): - settings['host'] = [host] - - settings["heartbeatPeriod"] = net_settings.get('heartbeatPeriod', 0) - elif net == 'grpc' or net == 'gun': - settings['header_type'] = '' - settings['path'] = net_settings.get('serviceName', '') - host = net_settings.get('authority', '') - settings['host'] = [host] - settings['multiMode'] = net_settings.get('multiMode', False) - - elif net == 'quic': - settings['header_type'] = net_settings.get('header', {}).get('type', '') - settings['path'] = net_settings.get('key', '') - settings['host'] = [net_settings.get('security', '')] - - elif net == 'httpupgrade': - settings['path'] = net_settings.get('path', '') - host = net_settings.get('host', '') - settings['host'] = [host] - - elif net in ('splithttp', 'xhttp'): - settings['path'] = net_settings.get('path', '') - host = net_settings.get('host', '') - settings['host'] = [host] - settings['scMaxEachPostBytes'] = net_settings.get('scMaxEachPostBytes') - settings['scMaxConcurrentPosts'] = net_settings.get('scMaxConcurrentPosts') - settings['scMinPostsIntervalMs'] = net_settings.get('scMinPostsIntervalMs') - settings['xPaddingBytes'] = net_settings.get('xPaddingBytes') + settings["host"] = [host] + + settings["heartbeatPeriod"] = net_settings.get("heartbeatPeriod", 0) + elif net == "grpc" or net == "gun": + settings["header_type"] = "" + settings["path"] = net_settings.get("serviceName", "") + host = net_settings.get("authority", "") + settings["host"] = [host] + settings["multiMode"] = net_settings.get("multiMode", False) + + elif net == "quic": + settings["header_type"] = net_settings.get("header", {}).get("type", "") + settings["path"] = net_settings.get("key", "") + settings["host"] = [net_settings.get("security", "")] + + elif net == "httpupgrade": + settings["path"] = net_settings.get("path", "") + host = net_settings.get("host", "") + settings["host"] = [host] + + elif net in ("splithttp", "xhttp"): + settings["path"] = net_settings.get("path", "") + host = net_settings.get("host", "") + settings["host"] = [host] + settings["scMaxEachPostBytes"] = net_settings.get("scMaxEachPostBytes") + settings["scMaxConcurrentPosts"] = net_settings.get("scMaxConcurrentPosts") + settings["scMinPostsIntervalMs"] = net_settings.get("scMinPostsIntervalMs") + settings["xPaddingBytes"] = net_settings.get("xPaddingBytes") settings["noGRPCHeader"] = net_settings.get("noGRPCHeader") - settings['xmux'] = net_settings.get('xmux', {}) - settings['downloadSettings'] = net_settings.get('downloadSettings', {}) + settings["xmux"] = net_settings.get("xmux", {}) + settings["downloadSettings"] = net_settings.get("downloadSettings", {}) settings["mode"] = net_settings.get("mode", "auto") settings["keepAlivePeriod"] = net_settings.get("keepAlivePeriod", 0) settings["scStreamUpServerSecs"] = net_settings.get("scStreamUpServerSecs") - elif net == 'kcp': - header = net_settings.get('header', {}) + elif net == "kcp": + header = net_settings.get("header", {}) - settings['header_type'] = header.get('type', '') - settings['host'] = header.get('domain', '') - settings['path'] = net_settings.get('seed', '') + settings["header_type"] = header.get("type", "") + settings["host"] = header.get("domain", "") + settings["path"] = net_settings.get("seed", "") elif net in ("http", "h2", "h3"): net_settings = stream.get("httpSettings", {}) - settings['host'] = net_settings.get('host') or net_settings.get('Host', '') - settings['path'] = net_settings.get('path', '') + settings["host"] = net_settings.get("host") or net_settings.get("Host", "") + settings["path"] = net_settings.get("path", "") else: - settings['path'] = net_settings.get('path', '') - host = net_settings.get( - 'host', {}) or net_settings.get('Host', {}) + settings["path"] = net_settings.get("path", "") + host = net_settings.get("host", {}) or net_settings.get("Host", {}) if host and isinstance(host, str): - settings['host'] = host + settings["host"] = host elif host and isinstance(host, list): - settings['host'] = host[0] + settings["host"] = host[0] self.inbounds.append(settings) - self.inbounds_by_tag[inbound['tag']] = settings + self.inbounds_by_tag[inbound["tag"]] = settings try: - self.inbounds_by_protocol[inbound['protocol']].append(settings) + self.inbounds_by_protocol[inbound["protocol"]].append(settings) except KeyError: - self.inbounds_by_protocol[inbound['protocol']] = [settings] + self.inbounds_by_protocol[inbound["protocol"]] = [settings] def get_inbound(self, tag) -> dict: - for inbound in self['inbounds']: - if inbound['tag'] == tag: + for inbound in self["inbounds"]: + if inbound["tag"] == tag: return inbound def get_outbound(self, tag) -> dict: - for outbound in self['outbounds']: - if outbound['tag'] == tag: + for outbound in self["outbounds"]: + if outbound["tag"] == tag: return outbound def to_json(self, **json_kwargs): @@ -364,75 +341,74 @@ def include_db_users(self) -> XRayConfig: config = self.copy() with GetDB() as db: - query = db.query( - db_models.User.id, - db_models.User.username, - func.lower(db_models.Proxy.type).label('type'), - db_models.Proxy.settings, - func.group_concat(db_models.excluded_inbounds_association.c.inbound_tag).label('excluded_inbound_tags') - ).join( - db_models.Proxy, db_models.User.id == db_models.Proxy.user_id - ).outerjoin( - db_models.excluded_inbounds_association, - db_models.Proxy.id == db_models.excluded_inbounds_association.c.proxy_id - ).filter( - db_models.User.status.in_([UserStatus.active, UserStatus.on_hold]) - ).group_by( - func.lower(db_models.Proxy.type), - db_models.User.id, - db_models.User.username, - db_models.Proxy.settings, + query = ( + db.query( + db_models.User.id, + db_models.User.username, + func.lower(db_models.Proxy.type).label("type"), + db_models.Proxy.settings, + func.group_concat(db_models.excluded_inbounds_association.c.inbound_tag).label( + "excluded_inbound_tags" + ), + ) + .join(db_models.Proxy, db_models.User.id == db_models.Proxy.user_id) + .outerjoin( + db_models.excluded_inbounds_association, + db_models.Proxy.id == db_models.excluded_inbounds_association.c.proxy_id, + ) + .filter(db_models.User.status.in_([UserStatus.active, UserStatus.on_hold])) + .group_by( + func.lower(db_models.Proxy.type), + db_models.User.id, + db_models.User.username, + db_models.Proxy.settings, + ) ) result = query.all() grouped_data = defaultdict(list) for row in result: - grouped_data[row.type].append(( - row.id, - row.username, - row.settings, - [i for i in row.excluded_inbound_tags.split(',') if i] if row.excluded_inbound_tags else None - )) + grouped_data[row.type].append( + ( + row.id, + row.username, + row.settings, + [i for i in row.excluded_inbound_tags.split(",") if i] if row.excluded_inbound_tags else None, + ) + ) for proxy_type, rows in grouped_data.items(): - inbounds = self.inbounds_by_protocol.get(proxy_type) if not inbounds: continue for inbound in inbounds: - clients = config.get_inbound(inbound['tag'])['settings']['clients'] + clients = config.get_inbound(inbound["tag"])["settings"]["clients"] for row in rows: user_id, username, settings, excluded_inbound_tags = row - if excluded_inbound_tags and inbound['tag'] in excluded_inbound_tags: + if excluded_inbound_tags and inbound["tag"] in excluded_inbound_tags: continue - client = { - "email": f"{user_id}.{username}", - **settings - } + client = {"email": f"{user_id}.{username}", **settings} # XTLS currently only supports transmission methods of TCP and mKCP - if client.get('flow') and ( - inbound.get('network', 'tcp') not in ('tcp', 'raw', 'kcp') - or - ( - inbound.get('network', 'tcp') in ('tcp', 'raw', 'kcp') - and - inbound.get('tls') not in ('tls', 'reality') - ) - or - inbound.get('header_type') == 'http' + if client.get("flow") and ( + inbound.get("network", "tcp") not in ("tcp", "raw", "kcp") + or ( + inbound.get("network", "tcp") in ("tcp", "raw", "kcp") + and inbound.get("tls") not in ("tls", "reality") + ) + or inbound.get("header_type") == "http" ): - del client['flow'] + del client["flow"] clients.append(client) if DEBUG: - with open('generated_config-debug.json', 'w') as f: + with open("generated_config-debug.json", "w") as f: f.write(config.to_json(indent=4)) return config diff --git a/app/xray/core.py b/app/xray/core.py index 7b152b13c..259433581 100644 --- a/app/xray/core.py +++ b/app/xray/core.py @@ -11,9 +11,7 @@ class XRayCore: - def __init__(self, - executable_path: str = "/usr/bin/xray", - assets_path: str = "/usr/share/xray"): + def __init__(self, executable_path: str = "/usr/bin/xray", assets_path: str = "/usr/share/xray"): self.executable_path = executable_path self.assets_path = assets_path @@ -25,31 +23,26 @@ def __init__(self, self._temp_log_buffers = {} self._on_start_funcs = [] self._on_stop_funcs = [] - self._env = { - "XRAY_LOCATION_ASSET": assets_path - } + self._env = {"XRAY_LOCATION_ASSET": assets_path} atexit.register(lambda: self.stop() if self.started else None) def get_version(self): cmd = [self.executable_path, "version"] - output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode('utf-8') - m = re.match(r'^Xray (\d+\.\d+\.\d+)', output) + output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8") + m = re.match(r"^Xray (\d+\.\d+\.\d+)", output) if m: return m.groups()[0] def get_x25519(self, private_key: str = None): cmd = [self.executable_path, "x25519"] if private_key: - cmd.extend(['-i', private_key]) - output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode('utf-8') - m = re.match(r'Private key: (.+)\nPublic key: (.+)', output) + cmd.extend(["-i", private_key]) + output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8") + m = re.match(r"Private key: (.+)\nPublic key: (.+)", output) if m: private, public = m.groups() - return { - "private_key": private, - "public_key": public - } + return {"private_key": private, "public_key": public} def __capture_process_logs(self): def capture_and_debug_log(): @@ -107,22 +100,17 @@ def start(self, config: XRayConfig): if self.started is True: raise RuntimeError("Xray is started already") - if config.get('log', {}).get('logLevel') in ('none', 'error'): - config['log']['logLevel'] = 'warning' + if config.get("log", {}).get("logLevel") in ("none", "error"): + config["log"]["logLevel"] = "warning" - cmd = [ - self.executable_path, - "run", - '-config', - 'stdin:' - ] + cmd = [self.executable_path, "run", "-config", "stdin:"] self.process = subprocess.Popen( cmd, env=self._env, stdin=subprocess.PIPE, stderr=subprocess.PIPE, stdout=subprocess.PIPE, - universal_newlines=True + universal_newlines=True, ) self.process.stdin.write(config.to_json()) self.process.stdin.flush() diff --git a/app/xray/node.py b/app/xray/node.py index 1f936a9cf..b714a9cb3 100644 --- a/app/xray/node.py +++ b/app/xray/node.py @@ -20,7 +20,7 @@ def string_to_temp_file(content: str): - file = tempfile.NamedTemporaryFile(mode='w+t') + file = tempfile.NamedTemporaryFile(mode="w+t") file.write(content) file.flush() return file @@ -28,10 +28,7 @@ def string_to_temp_file(content: str): class SANIgnoringAdaptor(HTTPAdapter): def init_poolmanager(self, connections, maxsize, block=False): - self.poolmanager = PoolManager(num_pools=connections, - maxsize=maxsize, - block=block, - assert_hostname=False) + self.poolmanager = PoolManager(num_pools=connections, maxsize=maxsize, block=block, assert_hostname=False) class NodeAPIError(Exception): @@ -41,14 +38,9 @@ def __init__(self, status_code, detail): class ReSTXRayNode: - def __init__(self, - address: str, - port: int, - api_port: int, - ssl_key: str, - ssl_cert: str, - usage_coefficient: float = 1): - + def __init__( + self, address: str, port: int, api_port: int, ssl_key: str, ssl_cert: str, usage_coefficient: float = 1 + ): self.address = address self.port = port self.api_port = api_port @@ -60,7 +52,7 @@ def __init__(self, self._certfile = string_to_temp_file(ssl_cert) self.session = requests.Session() - self.session.mount('https://', SANIgnoringAdaptor()) + self.session.mount("https://", SANIgnoringAdaptor()) self.session.cert = (self._certfile.name, self._keyfile.name) self._session_id = None @@ -84,25 +76,22 @@ def _prepare_config(self, config: XRayConfig): certificates = tlsSettings.get("certificates") or [] for certificate in certificates: if certificate.get("certificateFile"): - with open(certificate['certificateFile']) as file: - certificate['certificate'] = [ - line.strip() for line in file.readlines() - ] - del certificate['certificateFile'] + with open(certificate["certificateFile"]) as file: + certificate["certificate"] = [line.strip() for line in file.readlines()] + del certificate["certificateFile"] if certificate.get("keyFile"): - with open(certificate['keyFile']) as file: - certificate['key'] = [ - line.strip() for line in file.readlines() - ] - del certificate['keyFile'] + with open(certificate["keyFile"]) as file: + certificate["key"] = [line.strip() for line in file.readlines()] + del certificate["keyFile"] return config def make_request(self, path: str, timeout: int, **params): try: - res = self.session.post(self._rest_api_url + path, timeout=timeout, - json={"session_id": self._session_id, **params}) + res = self.session.post( + self._rest_api_url + path, timeout=timeout, json={"session_id": self._session_id, **params} + ) data = res.json() except Exception as e: exc = NodeAPIError(0, str(e)) @@ -111,7 +100,7 @@ def make_request(self, path: str, timeout: int, **params): if res.status_code == 200: return data else: - exc = NodeAPIError(res.status_code, data['detail']) + exc = NodeAPIError(res.status_code, data["detail"]) raise exc @property @@ -127,7 +116,7 @@ def connected(self): @property def started(self): res = self.make_request("/", timeout=3) - return res.get('started', False) + return res.get("started", False) @property def api(self): @@ -140,7 +129,7 @@ def api(self): address=self.address, port=self.api_port, ssl_cert=self._node_cert.encode(), - ssl_target_name="Gozargah" + ssl_target_name="Gozargah", ) else: raise ConnectionError("Node is not started") @@ -153,7 +142,7 @@ def connect(self): self.session.verify = self._node_certfile.name res = self.make_request("/connect", timeout=3) - self._session_id = res['session_id'] + self._session_id = res["session_id"] def disconnect(self): self.make_request("/disconnect", timeout=3) @@ -161,7 +150,7 @@ def disconnect(self): def get_version(self): res = self.make_request("/", timeout=3) - return res.get('core_version') + return res.get("core_version") def start(self, config: XRayConfig): if not self.connected: @@ -173,7 +162,7 @@ def start(self, config: XRayConfig): try: res = self.make_request("/start", timeout=10, config=json_config) except NodeAPIError as exc: - if exc.detail == 'Xray is started already': + if exc.detail == "Xray is started already": return self.restart(config) else: raise exc @@ -181,16 +170,13 @@ def start(self, config: XRayConfig): self._started = True self._api = XRayAPI( - address=self.address, - port=self.api_port, - ssl_cert=self._node_cert.encode(), - ssl_target_name="Gozargah" + address=self.address, port=self.api_port, ssl_cert=self._node_cert.encode(), ssl_target_name="Gozargah" ) try: grpc.channel_ready_future(self._api._channel).result(timeout=5) except grpc.FutureTimeoutError: - raise ConnectionError('Failed to connect to node\'s API') + raise ConnectionError("Failed to connect to node's API") return res @@ -198,7 +184,7 @@ def stop(self): if not self.connected: self.connect() - self.make_request('/stop', timeout=5) + self.make_request("/stop", timeout=5) self._api = None self._started = False @@ -214,16 +200,13 @@ def restart(self, config: XRayConfig): self._started = True self._api = XRayAPI( - address=self.address, - port=self.api_port, - ssl_cert=self._node_cert.encode(), - ssl_target_name="Gozargah" + address=self.address, port=self.api_port, ssl_cert=self._node_cert.encode(), ssl_target_name="Gozargah" ) try: grpc.channel_ready_future(self._api._channel).result(timeout=5) except grpc.FutureTimeoutError: - raise ConnectionError('Failed to connect to node\'s API') + raise ConnectionError("Failed to connect to node's API") return res @@ -272,18 +255,11 @@ def get_logs(self): class RPyCXRayNode: - def __init__(self, - address: str, - port: int, - api_port: int, - ssl_key: str, - ssl_cert: str, - usage_coefficient: float = 1): - + def __init__( + self, address: str, port: int, api_port: int, ssl_key: str, ssl_cert: str, usage_coefficient: float = 1 + ): class Service(rpyc.Service): - def __init__(self, - on_start_funcs: List[callable] = [], - on_stop_funcs: List[callable] = []): + def __init__(self, on_start_funcs: List[callable] = [], on_stop_funcs: List[callable] = []): self.on_start_funcs = on_start_funcs self.on_stop_funcs = on_stop_funcs @@ -337,13 +313,15 @@ def connect(self): tries += 1 self._node_cert = ssl.get_server_certificate((self.address, self.port)) self._node_certfile = string_to_temp_file(self._node_cert) - conn = rpyc.ssl_connect(self.address, - self.port, - service=self._service, - keyfile=self._keyfile.name, - certfile=self._certfile.name, - ca_certs=self._node_certfile.name, - keepalive=True) + conn = rpyc.ssl_connect( + self.address, + self.port, + service=self._service, + keyfile=self._keyfile.name, + certfile=self._certfile.name, + ca_certs=self._node_certfile.name, + keepalive=True, + ) try: conn.ping() self.connection = conn @@ -357,7 +335,7 @@ def connect(self): def connected(self): try: self.connection.ping() - return (not self.connection.closed) + return not self.connection.closed except (AttributeError, EOFError, TimeoutError): self.disconnect() return False @@ -388,18 +366,14 @@ def _prepare_config(self, config: XRayConfig): certificates = tlsSettings.get("certificates") or [] for certificate in certificates: if certificate.get("certificateFile"): - with open(certificate['certificateFile']) as file: - certificate['certificate'] = [ - line.strip() for line in file.readlines() - ] - del certificate['certificateFile'] + with open(certificate["certificateFile"]) as file: + certificate["certificate"] = [line.strip() for line in file.readlines()] + del certificate["certificateFile"] if certificate.get("keyFile"): - with open(certificate['keyFile']) as file: - certificate['key'] = [ - line.strip() for line in file.readlines() - ] - del certificate['keyFile'] + with open(certificate["keyFile"]) as file: + certificate["key"] = [line.strip() for line in file.readlines()] + del certificate["keyFile"] return config @@ -411,30 +385,26 @@ def start(self, config: XRayConfig): # connect to API self._api = XRayAPI( - address=self.address, - port=self.api_port, - ssl_cert=self._node_cert.encode(), - ssl_target_name="Gozargah" + address=self.address, port=self.api_port, ssl_cert=self._node_cert.encode(), ssl_target_name="Gozargah" ) try: grpc.channel_ready_future(self._api._channel).result(timeout=5) except grpc.FutureTimeoutError: - start_time = time.time() end_time = start_time + 3 # check logs for 3 seconds - last_log = '' + last_log = "" with self.get_logs() as logs: while time.time() < end_time: if logs: - last_log = logs[-1].strip().split('\n')[-1] + last_log = logs[-1].strip().split("\n")[-1] time.sleep(0.1) self.disconnect() - if re.search(r'[Ff]ailed', last_log): + if re.search(r"[Ff]ailed", last_log): raise RuntimeError(last_log) - raise ConnectionError('Failed to connect to node\'s API') + raise ConnectionError("Failed to connect to node's API") def stop(self): self.remote.stop() @@ -494,20 +464,15 @@ def on_stop(self, func: callable): class XRayNode: - def __new__(self, - address: str, - port: int, - api_port: int, - ssl_key: str, - ssl_cert: str, - usage_coefficient: float = 1): - + def __new__( + self, address: str, port: int, api_port: int, ssl_key: str, ssl_cert: str, usage_coefficient: float = 1 + ): # trying to detect what's the server of node try: s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.settimeout(1) s.connect((address, port)) - s.send(b'HEAD / HTTP/1.0\r\n\r\n') + s.send(b"HEAD / HTTP/1.0\r\n\r\n") s.recv(1024) s.close() # it might be uvicorn @@ -517,7 +482,7 @@ def __new__(self, api_port=api_port, ssl_key=ssl_key, ssl_cert=ssl_cert, - usage_coefficient=usage_coefficient + usage_coefficient=usage_coefficient, ) except Exception: # if might be rpyc @@ -527,5 +492,5 @@ def __new__(self, api_port=api_port, ssl_key=ssl_key, ssl_cert=ssl_cert, - usage_coefficient=usage_coefficient + usage_coefficient=usage_coefficient, ) diff --git a/app/xray/operations.py b/app/xray/operations.py index 386d83c79..e90adfc3a 100644 --- a/app/xray/operations.py +++ b/app/xray/operations.py @@ -20,12 +20,10 @@ @lru_cache(maxsize=None) def get_tls(): from app.db import GetDB, get_tls_certificate + with GetDB() as db: tls = get_tls_certificate(db) - return { - "key": tls.key, - "certificate": tls.certificate - } + return {"key": tls.key, "certificate": tls.certificate} @threaded_function @@ -71,16 +69,13 @@ def add_user(dbuser: "DBUser"): account = proxy_type.account_model(email=email, **proxy_settings) # XTLS currently only supports transmission methods of TCP and mKCP - if getattr(account, 'flow', None) and ( - inbound.get('network', 'tcp') not in ('tcp', 'raw', 'kcp') - or - ( - inbound.get('network', 'tcp') in ('tcp', 'raw', 'kcp') - and - inbound.get('tls') not in ('tls', 'reality') + if getattr(account, "flow", None) and ( + inbound.get("network", "tcp") not in ("tcp", "raw", "kcp") + or ( + inbound.get("network", "tcp") in ("tcp", "raw", "kcp") + and inbound.get("tls") not in ("tls", "reality") ) - or - inbound.get('header_type') == 'http' + or inbound.get("header_type") == "http" ): account.flow = XTLSFlows.NONE @@ -117,16 +112,10 @@ def update_user(dbuser: "DBUser"): account = proxy_type.account_model(email=email, **proxy_settings) # XTLS currently only supports transmission methods of TCP and mKCP - if getattr(account, 'flow', None) and ( - inbound.get('network', 'tcp') not in ('tcp', 'kcp') - or - ( - inbound.get('network', 'tcp') in ('tcp', 'kcp') - and - inbound.get('tls') not in ('tls', 'reality') - ) - or - inbound.get('header_type') == 'http' + if getattr(account, "flow", None) and ( + inbound.get("network", "tcp") not in ("tcp", "kcp") + or (inbound.get("network", "tcp") in ("tcp", "kcp") and inbound.get("tls") not in ("tls", "reality")) + or inbound.get("header_type") == "http" ): account.flow = XTLSFlows.NONE @@ -162,12 +151,14 @@ def add_node(dbnode: "DBNode"): remove_node(dbnode.id) tls = get_tls() - xray.nodes[dbnode.id] = XRayNode(address=dbnode.address, - port=dbnode.port, - api_port=dbnode.api_port, - ssl_key=tls['key'], - ssl_cert=tls['certificate'], - usage_coefficient=dbnode.usage_coefficient) + xray.nodes[dbnode.id] = XRayNode( + address=dbnode.address, + port=dbnode.port, + api_port=dbnode.api_port, + ssl_key=tls["key"], + ssl_cert=tls["certificate"], + usage_coefficient=dbnode.usage_coefficient, + ) return xray.nodes[dbnode.id] @@ -215,7 +206,7 @@ def connect_node(node_id, config=None): _connecting_nodes[node_id] = True _change_node_status(node_id, NodeStatus.connecting) - logger.info(f"Connecting to \"{dbnode.name}\" node") + logger.info(f'Connecting to "{dbnode.name}" node') if config is None: config = xray.config.include_db_users() @@ -223,11 +214,11 @@ def connect_node(node_id, config=None): node.start(config) version = node.get_version() _change_node_status(node_id, NodeStatus.connected, version=version) - logger.info(f"Connected to \"{dbnode.name}\" node, xray run on v{version}") + logger.info(f'Connected to "{dbnode.name}" node, xray run on v{version}') except Exception as e: _change_node_status(node_id, NodeStatus.error, message=str(e)) - logger.info(f"Unable to connect to \"{dbnode.name}\" node") + logger.info(f'Unable to connect to "{dbnode.name}" node') finally: try: @@ -253,13 +244,13 @@ def restart_node(node_id, config=None): return connect_node(node_id, config) try: - logger.info(f"Restarting Xray core of \"{dbnode.name}\" node") + logger.info(f'Restarting Xray core of "{dbnode.name}" node') if config is None: config = xray.config.include_db_users() node.restart(config) - logger.info(f"Xray core of \"{dbnode.name}\" node restarted") + logger.info(f'Xray core of "{dbnode.name}" node restarted') except Exception as e: _change_node_status(node_id, NodeStatus.error, message=str(e)) logger.info(f"Unable to restart node {node_id}") diff --git a/cli/admin.py b/cli/admin.py index e4ac38523..5b4e18242 100644 --- a/cli/admin.py +++ b/cli/admin.py @@ -58,27 +58,38 @@ def list_admins( with GetDB() as db: admins: list[Admin] = crud.get_admins(db, offset=offset, limit=limit, username=username) utils.print_table( - table=Table("Username", 'Usage', 'Reseted usage', "Users Usage", "Is sudo", "Is disabled", - "Created at", "Telegram ID", "Discord Webhook"), + table=Table( + "Username", + "Usage", + "Reseted usage", + "Users Usage", + "Is sudo", + "Is disabled", + "Created at", + "Telegram ID", + "Discord Webhook", + ), rows=[ - (str(admin.username), - calculate_admin_usage(admin.id), - calculate_admin_reseted_usage(admin.id), - readable_size(admin.users_usage), - "✔️" if admin.is_sudo else "✖️", - "✔️" if admin.is_disabled else "✖️", - utils.readable_datetime(admin.created_at), - str(admin.telegram_id or "✖️"), - str(admin.discord_webhook or "✖️")) + ( + str(admin.username), + calculate_admin_usage(admin.id), + calculate_admin_reseted_usage(admin.id), + readable_size(admin.users_usage), + "✔️" if admin.is_sudo else "✖️", + "✔️" if admin.is_disabled else "✖️", + utils.readable_datetime(admin.created_at), + str(admin.telegram_id or "✖️"), + str(admin.discord_webhook or "✖️"), + ) for admin in admins - ] + ], ) @app.command(name="delete") def delete_admin( username: str = typer.Option(..., *utils.FLAGS["username"], prompt=True), - yes_to_all: bool = typer.Option(False, *utils.FLAGS["yes_to_all"], help="Skips confirmations") + yes_to_all: bool = typer.Option(False, *utils.FLAGS["yes_to_all"], help="Skips confirmations"), ): """ Deletes the specified admin @@ -88,7 +99,7 @@ def delete_admin( with GetDB() as db: admin: Union[Admin, None] = crud.get_admin(db, username=username) if not admin: - utils.error(f"There's no admin with username \"{username}\"!") + utils.error(f'There\'s no admin with username "{username}"!') if yes_to_all or typer.confirm(f'Are you sure about deleting "{username}"?', default=False): crud.remove_admin(db, admin) @@ -101,12 +112,15 @@ def delete_admin( def create_admin( username: str = typer.Option(..., *utils.FLAGS["username"], show_default=False, prompt=True), is_sudo: bool = typer.Option(False, *utils.FLAGS["is_sudo"], prompt=True), - password: str = typer.Option(..., prompt=True, confirmation_prompt=True, - hide_input=True, hidden=True, envvar=utils.PASSWORD_ENVIRON_NAME), - telegram_id: str = typer.Option('', *utils.FLAGS["telegram_id"], prompt="Telegram ID", - show_default=False, callback=validate_telegram_id), - discord_webhook: str = typer.Option('', *utils.FLAGS["discord_webhook"], prompt=True, - show_default=False, callback=validate_discord_webhook), + password: str = typer.Option( + ..., prompt=True, confirmation_prompt=True, hide_input=True, hidden=True, envvar=utils.PASSWORD_ENVIRON_NAME + ), + telegram_id: str = typer.Option( + "", *utils.FLAGS["telegram_id"], prompt="Telegram ID", show_default=False, callback=validate_telegram_id + ), + discord_webhook: str = typer.Option( + "", *utils.FLAGS["discord_webhook"], prompt=True, show_default=False, callback=validate_discord_webhook + ), ): """ Creates an admin @@ -115,11 +129,16 @@ def create_admin( """ with GetDB() as db: try: - crud.create_admin(db, AdminCreate(username=username, - password=password, - is_sudo=is_sudo, - telegram_id=telegram_id, - discord_webhook=discord_webhook)) + crud.create_admin( + db, + AdminCreate( + username=username, + password=password, + is_sudo=is_sudo, + telegram_id=telegram_id, + discord_webhook=discord_webhook, + ), + ) utils.success(f'Admin "{username}" created successfully.') except IntegrityError: utils.error(f'Admin "{username}" already exists!') @@ -134,26 +153,21 @@ def update_admin(username: str = typer.Option(..., *utils.FLAGS["username"], pro """ def _get_modify_model(admin: Admin): - Console().print( - Panel(f'Editing "{username}". Just press "Enter" to leave each field unchanged.') - ) + Console().print(Panel(f'Editing "{username}". Just press "Enter" to leave each field unchanged.')) is_sudo: bool = typer.confirm("Is sudo", default=admin.is_sudo) is_disabled: bool = typer.confirm("Is disabled", default=admin.is_disabled) - new_password: Union[str, None] = typer.prompt( - "New password", - default="", - show_default=False, - confirmation_prompt=True, - hide_input=True - ) or None - - telegram_id: str = typer.prompt("Telegram ID (Enter 0 to clear current value)", - default=admin.telegram_id or "") + new_password: Union[str, None] = ( + typer.prompt("New password", default="", show_default=False, confirmation_prompt=True, hide_input=True) + or None + ) + + telegram_id: str = typer.prompt("Telegram ID (Enter 0 to clear current value)", default=admin.telegram_id or "") telegram_id = validate_telegram_id(telegram_id) - discord_webhook: str = typer.prompt("Discord webhook (Enter 0 to clear current value)", - default=admin.discord_webhook or "") + discord_webhook: str = typer.prompt( + "Discord webhook (Enter 0 to clear current value)", default=admin.discord_webhook or "" + ) discord_webhook = validate_discord_webhook(discord_webhook) return AdminPartialModify( @@ -167,7 +181,7 @@ def _get_modify_model(admin: Admin): with GetDB() as db: admin: Union[Admin, None] = crud.get_admin(db, username=username) if not admin: - utils.error(f"There's no admin with username \"{username}\"!") + utils.error(f'There\'s no admin with username "{username}"!') crud.partial_update_admin(db, admin, _get_modify_model(admin)) utils.success(f'Admin "{username}" updated successfully.') @@ -193,8 +207,9 @@ def import_from_env(yes_to_all: bool = typer.Option(False, *utils.FLAGS["yes_to_ ) if not (username and password): - utils.error("Unable to retrieve username and password.\n" - "Make sure both SUDO_USERNAME and SUDO_PASSWORD are set.") + utils.error( + "Unable to retrieve username and password.\n" "Make sure both SUDO_USERNAME and SUDO_PASSWORD are set." + ) with GetDB() as db: admin: Union[None, Admin] = None @@ -206,18 +221,10 @@ def import_from_env(yes_to_all: bool = typer.Option(False, *utils.FLAGS["yes_to_ ): utils.error("Aborted.") - admin = crud.partial_update_admin( - db, - current_admin, - AdminPartialModify(password=password, is_sudo=True) - ) + admin = crud.partial_update_admin(db, current_admin, AdminPartialModify(password=password, is_sudo=True)) # If env admin does not exist yet else: - admin = crud.create_admin(db, AdminCreate( - username=username, - password=password, - is_sudo=True - )) + admin = crud.create_admin(db, AdminCreate(username=username, password=password, is_sudo=True)) updated_user_count = db.query(User).filter_by(admin_id=None).update({"admin_id": admin.id}) db.commit() @@ -225,5 +232,5 @@ def import_from_env(yes_to_all: bool = typer.Option(False, *utils.FLAGS["yes_to_ utils.success( f'Admin "{username}" imported successfully.\n' f"{updated_user_count} users' admin_id set to the {username}'s id.\n" - 'You must delete SUDO_USERNAME and SUDO_PASSWORD from your env file now.' + "You must delete SUDO_USERNAME and SUDO_PASSWORD from your env file now." ) diff --git a/cli/subscription.py b/cli/subscription.py index 5f26f7728..ae9768cc3 100644 --- a/cli/subscription.py +++ b/cli/subscription.py @@ -19,9 +19,7 @@ class ConfigFormat(str, Enum): @app.command(name="get-link") -def get_link( - username: str = typer.Option(..., *utils.FLAGS["username"], prompt=True) -): +def get_link(username: str = typer.Option(..., *utils.FLAGS["username"], prompt=True)): """ Prints the given user's subscription link. @@ -40,9 +38,7 @@ def get_config( output_file: Optional[str] = typer.Option( None, *utils.FLAGS["output_file"], help="Writes the generated config in the file if provided" ), - as_base64: bool = typer.Option( - False, "--base64", is_flag=True, help="Encodes output in base64 format if present" - ) + as_base64: bool = typer.Option(False, "--base64", is_flag=True, help="Encodes output in base64 format if present"), ): """ Generates a subscription config. @@ -54,9 +50,7 @@ def get_config( """ with GetDB() as db: user: UserResponse = UserResponse.model_validate(utils.get_user(db, username)) - conf: str = generate_subscription( - user=user, config_format=config_format.name, as_base64=as_base64 - ) + conf: str = generate_subscription(user=user, config_format=config_format.name, as_base64=as_base64) if output_file: with open(output_file, "w") as out_file: @@ -68,8 +62,7 @@ def get_config( ) else: utils.success( - 'No output file specified.' - f' using pager for {username}\'s config in "{config_format}" format.', - auto_exit=False + "No output file specified." f' using pager for {username}\'s config in "{config_format}" format.', + auto_exit=False, ) utils.paginate(conf) diff --git a/cli/user.py b/cli/user.py index 062e0f30f..76015a97a 100644 --- a/cli/user.py +++ b/cli/user.py @@ -19,7 +19,7 @@ def list_users( username: Optional[List[str]] = typer.Option(None, *utils.FLAGS["username"], help="Search by username(s)"), search: Optional[str] = typer.Option(None, *utils.FLAGS["search"], help="Search by username/note"), status: Optional[crud.UserStatus] = typer.Option(None, *utils.FLAGS["status"]), - admins: Optional[List[str]] = typer.Option(None, *utils.FLAGS["admin"], help="Search by owner admin's username(s)") + admins: Optional[List[str]] = typer.Option(None, *utils.FLAGS["admin"], help="Search by owner admin's username(s)"), ): """ Displays a table of users @@ -28,15 +28,19 @@ def list_users( """ with GetDB() as db: users: list[User] = crud.get_users( - db=db, offset=offset, limit=limit, - usernames=username, search=search, status=status, - admins=admins + db=db, offset=offset, limit=limit, usernames=username, search=search, status=status, admins=admins ) utils.print_table( table=Table( - "ID", "Username", "Status", "Used traffic", - "Data limit", "Reset strategy", "Expires at", "Owner", + "ID", + "Username", + "Status", + "Used traffic", + "Data limit", + "Reset strategy", + "Expires at", + "Owner", ), rows=[ ( @@ -47,10 +51,10 @@ def list_users( readable_size(user.data_limit) if user.data_limit else "Unlimited", user.data_limit_reset_strategy.value, utils.readable_datetime(user.expire, include_time=False), - user.admin.username if user.admin else '' + user.admin.username if user.admin else "", ) for user in users - ] + ], ) @@ -58,7 +62,7 @@ def list_users( def set_owner( username: str = typer.Option(None, *utils.FLAGS["username"], prompt=True), admin: str = typer.Option(None, "--admin", "--owner", prompt=True, help="Admin's username"), - yes_to_all: bool = typer.Option(False, *utils.FLAGS["yes_to_all"], help="Skips confirmations") + yes_to_all: bool = typer.Option(False, *utils.FLAGS["yes_to_all"], help="Skips confirmations"), ): """ Transfers user's ownership @@ -66,16 +70,18 @@ def set_owner( NOTE: This command needs additional confirmation for users who already have an owner. """ with GetDB() as db: - user: User = utils.raise_if_falsy( - crud.get_user(db, username=username), f'User "{username}" not found.') + user: User = utils.raise_if_falsy(crud.get_user(db, username=username), f'User "{username}" not found.') - dbadmin = utils.raise_if_falsy( - crud.get_admin(db, username=admin), f'Admin "{admin}" not found.') + dbadmin = utils.raise_if_falsy(crud.get_admin(db, username=admin), f'Admin "{admin}" not found.') # Ask for confirmation if user already has an owner - if user.admin and not yes_to_all and not typer.confirm( - f'{username}\'s current owner is "{user.admin.username}".' - f' Are you sure about transferring its ownership to "{admin}"?' + if ( + user.admin + and not yes_to_all + and not typer.confirm( + f'{username}\'s current owner is "{user.admin.username}".' + f' Are you sure about transferring its ownership to "{admin}"?' + ) ): utils.error("Aborted.") diff --git a/cli/utils.py b/cli/utils.py index 9bbc5eb91..74d0809bc 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -54,22 +54,14 @@ def get_user(db, username: str) -> User: return user -def print_table( - table: Table, - rows: Iterable[Iterable[Any]], - console: Optional[Console] = None -): +def print_table(table: Table, rows: Iterable[Iterable[Any]], console: Optional[Console] = None): for row in rows: table.add_row(*row) (console or rich_console).print(table) -def readable_datetime( - date_time: Union[datetime, int, None], - include_date: bool = True, - include_time: bool = True -): +def readable_datetime(date_time: Union[datetime, int, None], include_date: bool = True, include_time: bool = True): def get_datetime_format(): dt_format = "" if include_date: diff --git a/config.py b/config.py index b1babb7ba..c75ea404a 100644 --- a/config.py +++ b/config.py @@ -33,15 +33,15 @@ ) XRAY_EXECUTABLE_PATH = config("XRAY_EXECUTABLE_PATH", default="/usr/local/bin/xray") XRAY_ASSETS_PATH = config("XRAY_ASSETS_PATH", default="/usr/local/share/xray") -XRAY_EXCLUDE_INBOUND_TAGS = config("XRAY_EXCLUDE_INBOUND_TAGS", default='').split() +XRAY_EXCLUDE_INBOUND_TAGS = config("XRAY_EXCLUDE_INBOUND_TAGS", default="").split() XRAY_SUBSCRIPTION_URL_PREFIX = config("XRAY_SUBSCRIPTION_URL_PREFIX", default="").strip("/") XRAY_SUBSCRIPTION_PATH = config("XRAY_SUBSCRIPTION_PATH", default="sub").strip("/") TELEGRAM_API_TOKEN = config("TELEGRAM_API_TOKEN", default="") TELEGRAM_ADMIN_ID = config( - 'TELEGRAM_ADMIN_ID', + "TELEGRAM_ADMIN_ID", default="", - cast=lambda v: [int(i) for i in filter(str.isdigit, (s.strip() for s in v.split(',')))] + cast=lambda v: [int(i) for i in filter(str.isdigit, (s.strip() for s in v.split(",")))], ) TELEGRAM_PROXY_URL = config("TELEGRAM_PROXY_URL", default="") TELEGRAM_LOGGER_CHANNEL_ID = config("TELEGRAM_LOGGER_CHANNEL_ID", cast=int, default=0) @@ -68,8 +68,9 @@ GRPC_USER_AGENT_TEMPLATE = config("GRPC_USER_AGENT_TEMPLATE", default="user_agent/grpc.json") EXTERNAL_CONFIG = config("EXTERNAL_CONFIG", default="", cast=str) -LOGIN_NOTIFY_WHITE_LIST = [ip.strip() for ip in config("LOGIN_NOTIFY_WHITE_LIST", - default="", cast=str).split(",") if ip.strip()] +LOGIN_NOTIFY_WHITE_LIST = [ + ip.strip() for ip in config("LOGIN_NOTIFY_WHITE_LIST", default="", cast=str).split(",") if ip.strip() +] USE_CUSTOM_JSON_DEFAULT = config("USE_CUSTOM_JSON_DEFAULT", default=False, cast=bool) USE_CUSTOM_JSON_FOR_V2RAYN = config("USE_CUSTOM_JSON_FOR_V2RAYN", default=False, cast=bool) @@ -99,15 +100,15 @@ # USERNAME: PASSWORD -SUDOERS = {config("SUDO_USERNAME"): config("SUDO_PASSWORD")} \ - if config("SUDO_USERNAME", default='') and config("SUDO_PASSWORD", default='') \ +SUDOERS = ( + {config("SUDO_USERNAME"): config("SUDO_PASSWORD")} + if config("SUDO_USERNAME", default="") and config("SUDO_PASSWORD", default="") else {} +) WEBHOOK_ADDRESS = config( - 'WEBHOOK_ADDRESS', - default="", - cast=lambda v: [address.strip() for address in v.split(',')] if v else [] + "WEBHOOK_ADDRESS", default="", cast=lambda v: [address.strip() for address in v.split(",")] if v else [] ) WEBHOOK_SECRET = config("WEBHOOK_SECRET", default=None) @@ -120,16 +121,12 @@ # sends a notification when the user uses this much of thier data NOTIFY_REACHED_USAGE_PERCENT = config( - "NOTIFY_REACHED_USAGE_PERCENT", - default="80", - cast=lambda v: [int(p.strip()) for p in v.split(',')] if v else [] + "NOTIFY_REACHED_USAGE_PERCENT", default="80", cast=lambda v: [int(p.strip()) for p in v.split(",")] if v else [] ) # sends a notification when there is n days left of their service NOTIFY_DAYS_LEFT = config( - "NOTIFY_DAYS_LEFT", - default="3", - cast=lambda v: [int(d.strip()) for d in v.split(',')] if v else [] + "NOTIFY_DAYS_LEFT", default="3", cast=lambda v: [int(d.strip()) for d in v.split(",")] if v else [] ) DISABLE_RECORDING_NODE_USAGE = config("DISABLE_RECORDING_NODE_USAGE", cast=bool, default=False) diff --git a/main.py b/main.py index 9ffa41457..d365b02b3 100644 --- a/main.py +++ b/main.py @@ -9,9 +9,8 @@ from cryptography import x509 from cryptography.hazmat.backends import default_backend -from app import app, logger -from config import (DEBUG, UVICORN_HOST, UVICORN_PORT, UVICORN_SSL_CERTFILE, - UVICORN_SSL_KEYFILE, UVICORN_UDS) +from app import app, logger #noqa +from config import DEBUG, UVICORN_HOST, UVICORN_PORT, UVICORN_SSL_CERTFILE, UVICORN_SSL_KEYFILE, UVICORN_UDS def check_and_modify_ip(ip_address: str) -> str: @@ -50,14 +49,13 @@ def check_and_modify_ip(ip_address: str) -> str: else: return "localhost" - except ValueError as e: + except ValueError: return "localhost" def validate_cert_and_key(cert_file_path, key_file_path): if not os.path.isfile(cert_file_path): - raise ValueError( - f"SSL certificate file '{cert_file_path}' does not exist.") + raise ValueError(f"SSL certificate file '{cert_file_path}' does not exist.") if not os.path.isfile(key_file_path): raise ValueError(f"SSL key file '{key_file_path}' does not exist.") @@ -68,13 +66,12 @@ def validate_cert_and_key(cert_file_path, key_file_path): raise ValueError(f"SSL Error: {e}") try: - with open(cert_file_path, 'rb') as cert_file: + with open(cert_file_path, "rb") as cert_file: cert_data = cert_file.read() cert = x509.load_pem_x509_certificate(cert_data, default_backend()) if cert.issuer == cert.subject: - raise ValueError( - "The certificate is self-signed and not issued by a trusted CA.") + raise ValueError("The certificate is self-signed and not issued by a trusted CA.") except Exception as e: raise ValueError(f"Certificate verification failed: {e}") @@ -89,18 +86,18 @@ def validate_cert_and_key(cert_file_path, key_file_path): if UVICORN_SSL_CERTFILE and UVICORN_SSL_KEYFILE: validate_cert_and_key(UVICORN_SSL_CERTFILE, UVICORN_SSL_KEYFILE) - bind_args['ssl_certfile'] = UVICORN_SSL_CERTFILE - bind_args['ssl_keyfile'] = UVICORN_SSL_KEYFILE + bind_args["ssl_certfile"] = UVICORN_SSL_CERTFILE + bind_args["ssl_keyfile"] = UVICORN_SSL_KEYFILE if UVICORN_UDS: - bind_args['uds'] = UVICORN_UDS + bind_args["uds"] = UVICORN_UDS else: - bind_args['host'] = UVICORN_HOST - bind_args['port'] = UVICORN_PORT + bind_args["host"] = UVICORN_HOST + bind_args["port"] = UVICORN_PORT else: if UVICORN_UDS: - bind_args['uds'] = UVICORN_UDS + bind_args["uds"] = UVICORN_UDS else: ip = check_and_modify_ip(UVICORN_HOST) @@ -120,20 +117,16 @@ def validate_cert_and_key(cert_file_path, key_file_path): Then, navigate to {click.style(f'http://{ip}:{UVICORN_PORT}', bold=True)} on your computer. """) - bind_args['host'] = ip - bind_args['port'] = UVICORN_PORT + bind_args["host"] = ip + bind_args["port"] = UVICORN_PORT if DEBUG: - bind_args['uds'] = None - bind_args['host'] = '0.0.0.0' + bind_args["uds"] = None + bind_args["host"] = "0.0.0.0" try: uvicorn.run( - "main:app", - **bind_args, - workers=1, - reload=DEBUG, - log_level=logging.DEBUG if DEBUG else logging.INFO + "main:app", **bind_args, workers=1, reload=DEBUG, log_level=logging.DEBUG if DEBUG else logging.INFO ) except FileNotFoundError: # to prevent error on removing unix sock pass diff --git a/marzban-cli.py b/marzban-cli.py index f177171fd..290bc61a3 100755 --- a/marzban-cli.py +++ b/marzban-cli.py @@ -24,27 +24,29 @@ def get_default_shell() -> Shells: - shell = os.environ.get('SHELL') + shell = os.environ.get("SHELL") if shell: - shell = shell.split('/')[-1] + shell = shell.split("/")[-1] if shell in Shells.__members__: return getattr(Shells, shell) return Shells.bash @app_completion.command(help="Show completion for the specified shell, to copy or customize it.") -def show(ctx: typer.Context, shell: Shells = typer.Option(None, - help="The shell to install completion for.", - case_sensitive=False)) -> None: +def show( + ctx: typer.Context, + shell: Shells = typer.Option(None, help="The shell to install completion for.", case_sensitive=False), +) -> None: if shell is None: shell = get_default_shell() typer.completion.show_callback(ctx, None, shell) @app_completion.command(help="Install completion for the specified shell.") -def install(ctx: typer.Context, shell: Shells = typer.Option(None, - help="The shell to install completion for.", - case_sensitive=False)) -> None: +def install( + ctx: typer.Context, + shell: Shells = typer.Option(None, help="The shell to install completion for.", case_sensitive=False), +) -> None: if shell is None: shell = get_default_shell() typer.completion.install_callback(ctx, None, shell) @@ -52,4 +54,4 @@ def install(ctx: typer.Context, shell: Shells = typer.Option(None, if __name__ == "__main__": typer.completion.completion_init() - app(prog_name=os.environ.get('CLI_PROG_NAME')) + app(prog_name=os.environ.get("CLI_PROG_NAME")) diff --git a/pyproject.toml b/pyproject.toml index c3b9631d5..7a7757d2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,3 +49,15 @@ dependencies = [ "websocket-client==1.7.0", "websockets==12.0" ] + +[tool.ruff] +line-length = 120 +exclude = [ + "xray_api", + "app/db/migrations", +] + +[dependency-groups] +dev = [ + "ruff>=0.9.7", +] diff --git a/uv.lock b/uv.lock index b48b0f6a8..4ff3cb2ef 100644 --- a/uv.lock +++ b/uv.lock @@ -470,6 +470,11 @@ dependencies = [ { name = "websockets" }, ] +[package.dev-dependencies] +dev = [ + { name = "ruff" }, +] + [package.metadata] requires-dist = [ { name = "alembic", specifier = "==1.14.0" }, @@ -517,6 +522,9 @@ requires-dist = [ { name = "websockets", specifier = "==12.0" }, ] +[package.metadata.requires-dev] +dev = [{ name = "ruff", specifier = ">=0.9.7" }] + [[package]] name = "mdurl" version = "0.1.2" @@ -865,6 +873,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/49/97/fa78e3d2f65c02c8e1268b9aba606569fe97f6c8f7c2d74394553347c145/rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7", size = 34315 }, ] +[[package]] +name = "ruff" +version = "0.9.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/39/8b/a86c300359861b186f18359adf4437ac8e4c52e42daa9eedc731ef9d5b53/ruff-0.9.7.tar.gz", hash = "sha256:643757633417907510157b206e490c3aa11cab0c087c912f60e07fbafa87a4c6", size = 3669813 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/f3/3a1d22973291226df4b4e2ff70196b926b6f910c488479adb0eeb42a0d7f/ruff-0.9.7-py3-none-linux_armv6l.whl", hash = "sha256:99d50def47305fe6f233eb8dabfd60047578ca87c9dcb235c9723ab1175180f4", size = 11774588 }, + { url = "https://files.pythonhosted.org/packages/8e/c9/b881f4157b9b884f2994fd08ee92ae3663fb24e34b0372ac3af999aa7fc6/ruff-0.9.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:d59105ae9c44152c3d40a9c40d6331a7acd1cdf5ef404fbe31178a77b174ea66", size = 11746848 }, + { url = "https://files.pythonhosted.org/packages/14/89/2f546c133f73886ed50a3d449e6bf4af27d92d2f960a43a93d89353f0945/ruff-0.9.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:f313b5800483770bd540cddac7c90fc46f895f427b7820f18fe1822697f1fec9", size = 11177525 }, + { url = "https://files.pythonhosted.org/packages/d7/93/6b98f2c12bf28ab9def59c50c9c49508519c5b5cfecca6de871cf01237f6/ruff-0.9.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:042ae32b41343888f59c0a4148f103208bf6b21c90118d51dc93a68366f4e903", size = 11996580 }, + { url = "https://files.pythonhosted.org/packages/8e/3f/b3fcaf4f6d875e679ac2b71a72f6691a8128ea3cb7be07cbb249f477c061/ruff-0.9.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:87862589373b33cc484b10831004e5e5ec47dc10d2b41ba770e837d4f429d721", size = 11525674 }, + { url = "https://files.pythonhosted.org/packages/f0/48/33fbf18defb74d624535d5d22adcb09a64c9bbabfa755bc666189a6b2210/ruff-0.9.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a17e1e01bee0926d351a1ee9bc15c445beae888f90069a6192a07a84af544b6b", size = 12739151 }, + { url = "https://files.pythonhosted.org/packages/63/b5/7e161080c5e19fa69495cbab7c00975ef8a90f3679caa6164921d7f52f4a/ruff-0.9.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:7c1f880ac5b2cbebd58b8ebde57069a374865c73f3bf41f05fe7a179c1c8ef22", size = 13416128 }, + { url = "https://files.pythonhosted.org/packages/4e/c8/b5e7d61fb1c1b26f271ac301ff6d9de5e4d9a9a63f67d732fa8f200f0c88/ruff-0.9.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e63fc20143c291cab2841dbb8260e96bafbe1ba13fd3d60d28be2c71e312da49", size = 12870858 }, + { url = "https://files.pythonhosted.org/packages/da/cb/2a1a8e4e291a54d28259f8fc6a674cd5b8833e93852c7ef5de436d6ed729/ruff-0.9.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:91ff963baed3e9a6a4eba2a02f4ca8eaa6eba1cc0521aec0987da8d62f53cbef", size = 14786046 }, + { url = "https://files.pythonhosted.org/packages/ca/6c/c8f8a313be1943f333f376d79724260da5701426c0905762e3ddb389e3f4/ruff-0.9.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88362e3227c82f63eaebf0b2eff5b88990280fb1ecf7105523883ba8c3aaf6fb", size = 12550834 }, + { url = "https://files.pythonhosted.org/packages/9d/ad/f70cf5e8e7c52a25e166bdc84c082163c9c6f82a073f654c321b4dff9660/ruff-0.9.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0372c5a90349f00212270421fe91874b866fd3626eb3b397ede06cd385f6f7e0", size = 11961307 }, + { url = "https://files.pythonhosted.org/packages/52/d5/4f303ea94a5f4f454daf4d02671b1fbfe2a318b5fcd009f957466f936c50/ruff-0.9.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d76b8ab60e99e6424cd9d3d923274a1324aefce04f8ea537136b8398bbae0a62", size = 11612039 }, + { url = "https://files.pythonhosted.org/packages/eb/c8/bd12a23a75603c704ce86723be0648ba3d4ecc2af07eecd2e9fa112f7e19/ruff-0.9.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:0c439bdfc8983e1336577f00e09a4e7a78944fe01e4ea7fe616d00c3ec69a3d0", size = 12168177 }, + { url = "https://files.pythonhosted.org/packages/cc/57/d648d4f73400fef047d62d464d1a14591f2e6b3d4a15e93e23a53c20705d/ruff-0.9.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:115d1f15e8fdd445a7b4dc9a30abae22de3f6bcabeb503964904471691ef7606", size = 12610122 }, + { url = "https://files.pythonhosted.org/packages/49/79/acbc1edd03ac0e2a04ae2593555dbc9990b34090a9729a0c4c0cf20fb595/ruff-0.9.7-py3-none-win32.whl", hash = "sha256:e9ece95b7de5923cbf38893f066ed2872be2f2f477ba94f826c8defdd6ec6b7d", size = 9988751 }, + { url = "https://files.pythonhosted.org/packages/6d/95/67153a838c6b6ba7a2401241fd8a00cd8c627a8e4a0491b8d853dedeffe0/ruff-0.9.7-py3-none-win_amd64.whl", hash = "sha256:3770fe52b9d691a15f0b87ada29c45324b2ace8f01200fb0c14845e499eb0c2c", size = 11002987 }, + { url = "https://files.pythonhosted.org/packages/63/6a/aca01554949f3a401991dc32fe22837baeaccb8a0d868256cbb26a029778/ruff-0.9.7-py3-none-win_arm64.whl", hash = "sha256:b075a700b2533feb7a01130ff656a4ec0d5f340bb540ad98759b8401c32c2037", size = 10177763 }, +] + [[package]] name = "setuptools" version = "75.8.0"