Skip to content

Commit

Permalink
style: ran 'find -name *.py -not -path ./src/diracx/client/* -exec py…
Browse files Browse the repository at this point in the history
…upgrade --py311-plus {} +'
  • Loading branch information
fstagni committed Sep 5, 2023
1 parent 258f633 commit 1092836
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 64 deletions.
13 changes: 4 additions & 9 deletions src/diracx/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import json
import os
from datetime import datetime, timedelta, timezone
from typing import Optional

from typer import Option

Expand All @@ -22,10 +21,8 @@
@app.async_command()
async def login(
vo: str,
group: Optional[str] = None,
property: Optional[list[str]] = Option(
None, help="Override the default(s) with one or more properties"
),
group: str | None = None,
property: list[str] | None = Option(None, help="Override the default(s) with one or more properties"),
):
scopes = [f"vo:{vo}"]
if group:
Expand Down Expand Up @@ -61,9 +58,7 @@ async def login(
raise RuntimeError("Device authorization flow expired")

CREDENTIALS_PATH.parent.mkdir(parents=True, exist_ok=True)
expires = datetime.now(tz=timezone.utc) + timedelta(
seconds=response.expires_in - EXPIRES_GRACE_SECONDS
)
expires = datetime.now(tz=timezone.utc) + timedelta(seconds=response.expires_in - EXPIRES_GRACE_SECONDS)
credential_data = {
"access_token": response.access_token,
# TODO: "refresh_token":
Expand All @@ -82,7 +77,7 @@ async def logout():


@app.callback()
def callback(output_format: Optional[str] = None):
def callback(output_format: str | None = None):
if "DIRACX_OUTPUT_FORMAT" not in os.environ:
output_format = output_format or "rich"
if output_format is not None:
Expand Down
12 changes: 2 additions & 10 deletions src/diracx/cli/internal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import absolute_import

import json
from pathlib import Path

Expand Down Expand Up @@ -47,11 +45,7 @@ def generate_cs(
IdP=IdpConfig(URL=idp_url, ClientID=idp_client_id),
DefaultGroup=user_group,
Users={},
Groups={
user_group: GroupConfig(
JobShare=None, Properties=["NormalUser"], Quota=None, Users=[]
)
},
Groups={user_group: GroupConfig(JobShare=None, Properties=["NormalUser"], Quota=None, Users=[])},
)
config = Config(
Registry={vo: registry},
Expand Down Expand Up @@ -105,7 +99,5 @@ def add_user(
config_data = json.loads(config.json(exclude_unset=True))
yaml_path.write_text(yaml.safe_dump(config_data))
repo.index.add([yaml_path.relative_to(repo_path)])
repo.index.commit(
f"Added user {sub} ({preferred_username}) to vo {vo} and user_group {user_group}"
)
repo.index.commit(f"Added user {sub} ({preferred_username}) to vo {vo} and user_group {user_group}")
typer.echo(f"Successfully added user to {config_repo}", err=True)
12 changes: 5 additions & 7 deletions src/diracx/core/config/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
from datetime import datetime
from typing import Any, Optional
from typing import Any

from pydantic import BaseModel as _BaseModel
from pydantic import EmailStr, PrivateAttr, root_validator
Expand All @@ -22,9 +22,7 @@ def legacy_adaptor(cls, v):
# though ideally we should parse the type hints properly.
for field, hint in cls.__annotations__.items():
# Convert comma separated lists to actual lists
if hint in {"list[str]", "list[SecurityProperty]"} and isinstance(
v.get(field), str
):
if hint in {"list[str]", "list[SecurityProperty]"} and isinstance(v.get(field), str):
v[field] = [x.strip() for x in v[field].split(",") if x.strip()]
# If the field is optional and the value is "None" convert it to None
if "| None" in hint and field in v:
Expand All @@ -49,12 +47,12 @@ class GroupConfig(BaseModel):
AutoAddVOMS: bool = False
AutoUploadPilotProxy: bool = False
AutoUploadProxy: bool = False
JobShare: Optional[int]
JobShare: int | None
Properties: list[SecurityProperty]
Quota: Optional[int]
Quota: int | None
Users: list[str]
AllowBackgroundTQs: bool = False
VOMSRole: Optional[str]
VOMSRole: str | None
AutoSyncVOMS: bool = False


Expand Down
8 changes: 2 additions & 6 deletions src/diracx/core/extensions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from __future__ import absolute_import

__all__ = ("select_from_extension",)

import os
from collections import defaultdict
from collections.abc import Iterator
from importlib.metadata import EntryPoint, entry_points
from importlib.util import find_spec
from typing import Iterator


def extensions_by_priority() -> Iterator[str]:
Expand All @@ -17,9 +15,7 @@ def extensions_by_priority() -> Iterator[str]:
yield module_name


def select_from_extension(
*, group: str, name: str | None = None
) -> Iterator[EntryPoint]:
def select_from_extension(*, group: str, name: str | None = None) -> Iterator[EntryPoint]:
"""Select entry points by group and name, in order of priority.
Similar to ``importlib.metadata.entry_points.select`` except only modules
Expand Down
18 changes: 5 additions & 13 deletions src/diracx/core/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import inspect
import operator
from typing import Callable
from collections.abc import Callable

from diracx.core.extensions import select_from_extension

Expand All @@ -14,9 +14,7 @@ class SecurityProperty(str):
@classmethod
def available_properties(cls) -> set[SecurityProperty]:
properties = set()
for entry_point in select_from_extension(
group="diracx", name="properties_module"
):
for entry_point in select_from_extension(group="diracx", name="properties_module"):
properties_module = entry_point.load()
for _, obj in inspect.getmembers(properties_module):
if isinstance(obj, SecurityProperty):
Expand All @@ -26,23 +24,17 @@ def available_properties(cls) -> set[SecurityProperty]:
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self})"

def __and__(
self, value: SecurityProperty | UnevaluatedProperty
) -> UnevaluatedExpression:
def __and__(self, value: SecurityProperty | UnevaluatedProperty) -> UnevaluatedExpression:
if not isinstance(value, UnevaluatedProperty):
value = UnevaluatedProperty(value)
return UnevaluatedProperty(self) & value

def __or__(
self, value: SecurityProperty | UnevaluatedProperty
) -> UnevaluatedExpression:
def __or__(self, value: SecurityProperty | UnevaluatedProperty) -> UnevaluatedExpression:
if not isinstance(value, UnevaluatedProperty):
value = UnevaluatedProperty(value)
return UnevaluatedProperty(self) | value

def __xor__(
self, value: SecurityProperty | UnevaluatedProperty
) -> UnevaluatedExpression:
def __xor__(self, value: SecurityProperty | UnevaluatedProperty) -> UnevaluatedExpression:
if not isinstance(value, UnevaluatedProperty):
value = UnevaluatedProperty(value)
return UnevaluatedProperty(self) ^ value
Expand Down
3 changes: 2 additions & 1 deletion src/diracx/db/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import contextlib
import os
from abc import ABCMeta
from collections.abc import AsyncIterator
from datetime import datetime, timedelta, timezone
from functools import partial
from typing import TYPE_CHECKING, AsyncIterator, Self
from typing import TYPE_CHECKING, Self

from pydantic import parse_obj_as
from sqlalchemy import Column as RawColumn
Expand Down
26 changes: 8 additions & 18 deletions src/diracx/routers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import inspect
import logging
import os
from collections.abc import AsyncGenerator, Iterable
from functools import partial
from typing import AsyncContextManager, AsyncGenerator, Iterable, TypeVar
from typing import AsyncContextManager, TypeVar

import dotenv
from fastapi import APIRouter, Depends, Request
Expand Down Expand Up @@ -59,8 +60,7 @@ def create_app_inner(
available_db_classes: set[type[BaseDB]] = set()
for db_name, db_url in database_urls.items():
db_classes: list[type[BaseDB]] = [
entry_point.load()
for entry_point in select_from_extension(group="diracx.dbs", name=db_name)
entry_point.load() for entry_point in select_from_extension(group="diracx.dbs", name=db_name)
]
assert db_classes, f"Could not find {db_name=}"
# The first DB is the highest priority one
Expand All @@ -79,9 +79,7 @@ def create_app_inner(
# Without this AutoREST generates different client sources for each ordering
for system_name in sorted(enabled_systems):
assert system_name not in routers
for entry_point in select_from_extension(
group="diracx.services", name=system_name
):
for entry_point in select_from_extension(group="diracx.services", name=system_name):
routers[system_name] = entry_point.load()
break
else:
Expand All @@ -92,16 +90,12 @@ def create_app_inner(
# Ensure required settings are available
for cls in find_dependents(router, ServiceSettingsBase):
if cls not in available_settings_classes:
raise NotImplementedError(
f"Cannot enable {system_name=} as it requires {cls=}"
)
raise NotImplementedError(f"Cannot enable {system_name=} as it requires {cls=}")

# Ensure required DBs are available
missing_dbs = set(find_dependents(router, BaseDB)) - available_db_classes
if missing_dbs:
raise NotImplementedError(
f"Cannot enable {system_name=} as it requires {missing_dbs=}"
)
raise NotImplementedError(f"Cannot enable {system_name=} as it requires {missing_dbs=}")

# Add the router to the application
dependencies = []
Expand Down Expand Up @@ -155,18 +149,14 @@ def create_app() -> DiracFastAPI:


def dirac_error_handler(request: Request, exc: DiracError) -> Response:
return JSONResponse(
status_code=exc.http_status_code, content={"detail": exc.detail}
)
return JSONResponse(status_code=exc.http_status_code, content={"detail": exc.detail})


def http_response_handler(request: Request, exc: DiracHttpResponse) -> Response:
return JSONResponse(status_code=exc.status_code, content=exc.data)


def find_dependents(
obj: APIRouter | Iterable[Dependant], cls: type[T]
) -> Iterable[type[T]]:
def find_dependents(obj: APIRouter | Iterable[Dependant], cls: type[T]) -> Iterable[type[T]]:
if isinstance(obj, APIRouter):
# TODO: Support dependencies of the router itself
# yield from find_dependents(obj.dependencies, cls)
Expand Down

0 comments on commit 1092836

Please sign in to comment.