diff --git a/src/diracx/cli/__init__.py b/src/diracx/cli/__init__.py index e545229e..c77a7a1b 100644 --- a/src/diracx/cli/__init__.py +++ b/src/diracx/cli/__init__.py @@ -4,7 +4,6 @@ import json import os from datetime import datetime, timedelta, timezone -from typing import Optional from typer import Option @@ -22,10 +21,8 @@ @app.async_command() async def login( vo: str, - group: Optional[str] = None, - property: Optional[list[str]] = Option( - None, help="Override the default(s) with one or more properties" - ), + group: str | None = None, + property: list[str] | None = Option(None, help="Override the default(s) with one or more properties"), ): scopes = [f"vo:{vo}"] if group: @@ -61,9 +58,7 @@ async def login( raise RuntimeError("Device authorization flow expired") CREDENTIALS_PATH.parent.mkdir(parents=True, exist_ok=True) - expires = datetime.now(tz=timezone.utc) + timedelta( - seconds=response.expires_in - EXPIRES_GRACE_SECONDS - ) + expires = datetime.now(tz=timezone.utc) + timedelta(seconds=response.expires_in - EXPIRES_GRACE_SECONDS) credential_data = { "access_token": response.access_token, # TODO: "refresh_token": @@ -82,7 +77,7 @@ async def logout(): @app.callback() -def callback(output_format: Optional[str] = None): +def callback(output_format: str | None = None): if "DIRACX_OUTPUT_FORMAT" not in os.environ: output_format = output_format or "rich" if output_format is not None: diff --git a/src/diracx/cli/internal.py b/src/diracx/cli/internal.py index 75b9ab0c..a05f65e5 100644 --- a/src/diracx/cli/internal.py +++ b/src/diracx/cli/internal.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import json from pathlib import Path @@ -47,11 +45,7 @@ def generate_cs( IdP=IdpConfig(URL=idp_url, ClientID=idp_client_id), DefaultGroup=user_group, Users={}, - Groups={ - user_group: GroupConfig( - JobShare=None, Properties=["NormalUser"], Quota=None, Users=[] - ) - }, + Groups={user_group: GroupConfig(JobShare=None, Properties=["NormalUser"], Quota=None, Users=[])}, ) config = Config( Registry={vo: registry}, @@ -105,7 +99,5 @@ def add_user( config_data = json.loads(config.json(exclude_unset=True)) yaml_path.write_text(yaml.safe_dump(config_data)) repo.index.add([yaml_path.relative_to(repo_path)]) - repo.index.commit( - f"Added user {sub} ({preferred_username}) to vo {vo} and user_group {user_group}" - ) + repo.index.commit(f"Added user {sub} ({preferred_username}) to vo {vo} and user_group {user_group}") typer.echo(f"Successfully added user to {config_repo}", err=True) diff --git a/src/diracx/core/config/schema.py b/src/diracx/core/config/schema.py index 8f3470e0..3f2a628c 100644 --- a/src/diracx/core/config/schema.py +++ b/src/diracx/core/config/schema.py @@ -2,7 +2,7 @@ import os from datetime import datetime -from typing import Any, Optional +from typing import Any from pydantic import BaseModel as _BaseModel from pydantic import EmailStr, PrivateAttr, root_validator @@ -22,9 +22,7 @@ def legacy_adaptor(cls, v): # though ideally we should parse the type hints properly. for field, hint in cls.__annotations__.items(): # Convert comma separated lists to actual lists - if hint in {"list[str]", "list[SecurityProperty]"} and isinstance( - v.get(field), str - ): + if hint in {"list[str]", "list[SecurityProperty]"} and isinstance(v.get(field), str): v[field] = [x.strip() for x in v[field].split(",") if x.strip()] # If the field is optional and the value is "None" convert it to None if "| None" in hint and field in v: @@ -49,12 +47,12 @@ class GroupConfig(BaseModel): AutoAddVOMS: bool = False AutoUploadPilotProxy: bool = False AutoUploadProxy: bool = False - JobShare: Optional[int] + JobShare: int | None Properties: list[SecurityProperty] - Quota: Optional[int] + Quota: int | None Users: list[str] AllowBackgroundTQs: bool = False - VOMSRole: Optional[str] + VOMSRole: str | None AutoSyncVOMS: bool = False diff --git a/src/diracx/core/extensions.py b/src/diracx/core/extensions.py index 46786a96..27bb6396 100644 --- a/src/diracx/core/extensions.py +++ b/src/diracx/core/extensions.py @@ -1,12 +1,10 @@ -from __future__ import absolute_import - __all__ = ("select_from_extension",) import os from collections import defaultdict +from collections.abc import Iterator from importlib.metadata import EntryPoint, entry_points from importlib.util import find_spec -from typing import Iterator def extensions_by_priority() -> Iterator[str]: @@ -17,9 +15,7 @@ def extensions_by_priority() -> Iterator[str]: yield module_name -def select_from_extension( - *, group: str, name: str | None = None -) -> Iterator[EntryPoint]: +def select_from_extension(*, group: str, name: str | None = None) -> Iterator[EntryPoint]: """Select entry points by group and name, in order of priority. Similar to ``importlib.metadata.entry_points.select`` except only modules diff --git a/src/diracx/core/properties.py b/src/diracx/core/properties.py index 8bf80201..4403afc8 100644 --- a/src/diracx/core/properties.py +++ b/src/diracx/core/properties.py @@ -5,7 +5,7 @@ import inspect import operator -from typing import Callable +from collections.abc import Callable from diracx.core.extensions import select_from_extension @@ -14,9 +14,7 @@ class SecurityProperty(str): @classmethod def available_properties(cls) -> set[SecurityProperty]: properties = set() - for entry_point in select_from_extension( - group="diracx", name="properties_module" - ): + for entry_point in select_from_extension(group="diracx", name="properties_module"): properties_module = entry_point.load() for _, obj in inspect.getmembers(properties_module): if isinstance(obj, SecurityProperty): @@ -26,23 +24,17 @@ def available_properties(cls) -> set[SecurityProperty]: def __repr__(self) -> str: return f"{self.__class__.__name__}({self})" - def __and__( - self, value: SecurityProperty | UnevaluatedProperty - ) -> UnevaluatedExpression: + def __and__(self, value: SecurityProperty | UnevaluatedProperty) -> UnevaluatedExpression: if not isinstance(value, UnevaluatedProperty): value = UnevaluatedProperty(value) return UnevaluatedProperty(self) & value - def __or__( - self, value: SecurityProperty | UnevaluatedProperty - ) -> UnevaluatedExpression: + def __or__(self, value: SecurityProperty | UnevaluatedProperty) -> UnevaluatedExpression: if not isinstance(value, UnevaluatedProperty): value = UnevaluatedProperty(value) return UnevaluatedProperty(self) | value - def __xor__( - self, value: SecurityProperty | UnevaluatedProperty - ) -> UnevaluatedExpression: + def __xor__(self, value: SecurityProperty | UnevaluatedProperty) -> UnevaluatedExpression: if not isinstance(value, UnevaluatedProperty): value = UnevaluatedProperty(value) return UnevaluatedProperty(self) ^ value diff --git a/src/diracx/db/utils.py b/src/diracx/db/utils.py index 4efc58d5..750a54e5 100644 --- a/src/diracx/db/utils.py +++ b/src/diracx/db/utils.py @@ -5,9 +5,10 @@ import contextlib import os from abc import ABCMeta +from collections.abc import AsyncIterator from datetime import datetime, timedelta, timezone from functools import partial -from typing import TYPE_CHECKING, AsyncIterator, Self +from typing import TYPE_CHECKING, Self from pydantic import parse_obj_as from sqlalchemy import Column as RawColumn diff --git a/src/diracx/routers/__init__.py b/src/diracx/routers/__init__.py index fd8710b9..43710cf3 100644 --- a/src/diracx/routers/__init__.py +++ b/src/diracx/routers/__init__.py @@ -3,8 +3,9 @@ import inspect import logging import os +from collections.abc import AsyncGenerator, Iterable from functools import partial -from typing import AsyncContextManager, AsyncGenerator, Iterable, TypeVar +from typing import AsyncContextManager, TypeVar import dotenv from fastapi import APIRouter, Depends, Request @@ -59,8 +60,7 @@ def create_app_inner( available_db_classes: set[type[BaseDB]] = set() for db_name, db_url in database_urls.items(): db_classes: list[type[BaseDB]] = [ - entry_point.load() - for entry_point in select_from_extension(group="diracx.dbs", name=db_name) + entry_point.load() for entry_point in select_from_extension(group="diracx.dbs", name=db_name) ] assert db_classes, f"Could not find {db_name=}" # The first DB is the highest priority one @@ -79,9 +79,7 @@ def create_app_inner( # Without this AutoREST generates different client sources for each ordering for system_name in sorted(enabled_systems): assert system_name not in routers - for entry_point in select_from_extension( - group="diracx.services", name=system_name - ): + for entry_point in select_from_extension(group="diracx.services", name=system_name): routers[system_name] = entry_point.load() break else: @@ -92,16 +90,12 @@ def create_app_inner( # Ensure required settings are available for cls in find_dependents(router, ServiceSettingsBase): if cls not in available_settings_classes: - raise NotImplementedError( - f"Cannot enable {system_name=} as it requires {cls=}" - ) + raise NotImplementedError(f"Cannot enable {system_name=} as it requires {cls=}") # Ensure required DBs are available missing_dbs = set(find_dependents(router, BaseDB)) - available_db_classes if missing_dbs: - raise NotImplementedError( - f"Cannot enable {system_name=} as it requires {missing_dbs=}" - ) + raise NotImplementedError(f"Cannot enable {system_name=} as it requires {missing_dbs=}") # Add the router to the application dependencies = [] @@ -155,18 +149,14 @@ def create_app() -> DiracFastAPI: def dirac_error_handler(request: Request, exc: DiracError) -> Response: - return JSONResponse( - status_code=exc.http_status_code, content={"detail": exc.detail} - ) + return JSONResponse(status_code=exc.http_status_code, content={"detail": exc.detail}) def http_response_handler(request: Request, exc: DiracHttpResponse) -> Response: return JSONResponse(status_code=exc.status_code, content=exc.data) -def find_dependents( - obj: APIRouter | Iterable[Dependant], cls: type[T] -) -> Iterable[type[T]]: +def find_dependents(obj: APIRouter | Iterable[Dependant], cls: type[T]) -> Iterable[type[T]]: if isinstance(obj, APIRouter): # TODO: Support dependencies of the router itself # yield from find_dependents(obj.dependencies, cls)