diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
deleted file mode 100644
index cc1c1c8b..00000000
--- a/.gitlab-ci.yml
+++ /dev/null
@@ -1,45 +0,0 @@
-.defaults:
- only:
- - merge_requests
- - master@chaen/chrissquare-hack-a-ton
-
-.defaults-micromamba:
- extends: .defaults
- image: registry.cern.ch/docker.io/mambaorg/micromamba
- before_script:
- - micromamba env create --file environment.yml --name test-env
- - eval "$(micromamba shell hook --shell=bash)"
- - micromamba activate test-env
- - pip install git+https://github.com/chaen/DIRAC.git@chris-hack-a-ton
- - pip install .
-
-pre-commit:
- extends: .defaults
- image: registry.cern.ch/docker.io/library/python:3.11
- variables:
- PRE_COMMIT_HOME: ${CI_PROJECT_DIR}/.cache/pre-commit
- cache:
- paths:
- - ${PRE_COMMIT_HOME}
- before_script:
- - pip install pre-commit
- script:
- - pre-commit run --all-files
-
-pytest:
- extends: .defaults-micromamba
- script:
- - pytest . --cov-report=xml:coverage.xml --junitxml=report.xml
- coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/'
- artifacts:
- when: always
- reports:
- junit: report.xml
- coverage_report:
- coverage_format: cobertura
- path: coverage.xml
-
-mypy:
- extends: .defaults-micromamba
- script:
- - mypy .
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b3535615..82841c6e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -17,6 +17,7 @@ repos:
hooks:
- id: ruff
args: ["--fix"]
+ exclude: ^(src/diracx/client/)
- repo: https://github.com/psf/black
rev: 23.7.0
diff --git a/environment.yml b/environment.yml
index bbd754d0..8a9a4ef6 100644
--- a/environment.yml
+++ b/environment.yml
@@ -30,7 +30,7 @@ dependencies:
- httpx
- isodate
- mypy
- - pydantic =1.10.10
+ - pydantic =1.10.12
- pytest
- pytest-asyncio
- pytest-cov
diff --git a/pyproject.toml b/pyproject.toml
index aa503323..840dd447 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,15 @@ build-backend = "setuptools.build_meta"
[tool.setuptools_scm]
[tool.ruff]
-select = ["E", "F", "B", "I", "PLE"]
+select = [
+ "E", # pycodestyle errors
+ "F", # pyflakes
+ "B", # flake8-bugbear
+ "I", # isort
+ "PLE", # pylint errors
+ "UP", # pyupgrade
+ "FLY", # flynt
+]
ignore = ["B905", "B008", "B006"]
line-length = 120
src = ["src", "tests"]
@@ -17,6 +25,12 @@ exclude = ["src/diracx/client/"]
# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`.
extend-immutable-calls = ["fastapi.Depends", "fastapi.Query", "fastapi.Path", "fastapi.Body", "fastapi.Header"]
+[tool.black]
+line-length = 120
+
+[tool.isort]
+profile = "black"
+
[tool.mypy]
plugins = ["sqlalchemy.ext.mypy.plugin", "pydantic.mypy"]
exclude = ["^src/diracx/client", "^tests/", "^build/"]
diff --git a/src/diracx/cli/__init__.py b/src/diracx/cli/__init__.py
index e545229e..1d050fb4 100644
--- a/src/diracx/cli/__init__.py
+++ b/src/diracx/cli/__init__.py
@@ -1,3 +1,5 @@
+# ruff: noqa: UP007 # because of https://github.com/tiangolo/typer/issues/533
+
from __future__ import annotations
import asyncio
@@ -23,9 +25,7 @@
async def login(
vo: str,
group: Optional[str] = None,
- property: Optional[list[str]] = Option(
- None, help="Override the default(s) with one or more properties"
- ),
+ property: Optional[list[str]] = Option(None, help="Override the default(s) with one or more properties"),
):
scopes = [f"vo:{vo}"]
if group:
@@ -61,9 +61,7 @@ async def login(
raise RuntimeError("Device authorization flow expired")
CREDENTIALS_PATH.parent.mkdir(parents=True, exist_ok=True)
- expires = datetime.now(tz=timezone.utc) + timedelta(
- seconds=response.expires_in - EXPIRES_GRACE_SECONDS
- )
+ expires = datetime.now(tz=timezone.utc) + timedelta(seconds=response.expires_in - EXPIRES_GRACE_SECONDS)
credential_data = {
"access_token": response.access_token,
# TODO: "refresh_token":
diff --git a/src/diracx/cli/internal.py b/src/diracx/cli/internal.py
index 75b9ab0c..a05f65e5 100644
--- a/src/diracx/cli/internal.py
+++ b/src/diracx/cli/internal.py
@@ -1,5 +1,3 @@
-from __future__ import absolute_import
-
import json
from pathlib import Path
@@ -47,11 +45,7 @@ def generate_cs(
IdP=IdpConfig(URL=idp_url, ClientID=idp_client_id),
DefaultGroup=user_group,
Users={},
- Groups={
- user_group: GroupConfig(
- JobShare=None, Properties=["NormalUser"], Quota=None, Users=[]
- )
- },
+ Groups={user_group: GroupConfig(JobShare=None, Properties=["NormalUser"], Quota=None, Users=[])},
)
config = Config(
Registry={vo: registry},
@@ -105,7 +99,5 @@ def add_user(
config_data = json.loads(config.json(exclude_unset=True))
yaml_path.write_text(yaml.safe_dump(config_data))
repo.index.add([yaml_path.relative_to(repo_path)])
- repo.index.commit(
- f"Added user {sub} ({preferred_username}) to vo {vo} and user_group {user_group}"
- )
+ repo.index.commit(f"Added user {sub} ({preferred_username}) to vo {vo} and user_group {user_group}")
typer.echo(f"Successfully added user to {config_repo}", err=True)
diff --git a/src/diracx/cli/jobs.py b/src/diracx/cli/jobs.py
index 6dcf6eb2..332bee65 100644
--- a/src/diracx/cli/jobs.py
+++ b/src/diracx/cli/jobs.py
@@ -104,9 +104,5 @@ def display_rich(data, unit: str) -> None:
@app.async_command()
async def submit(jdl: list[FileText]):
async with Dirac(endpoint="http://localhost:8000") as api:
- jobs = await api.jobs.submit_bulk_jobs(
- [x.read() for x in jdl], headers=get_auth_headers()
- )
- print(
- f"Inserted {len(jobs)} jobs with ids: {','.join(map(str, (job.job_id for job in jobs)))}"
- )
+ jobs = await api.jobs.submit_bulk_jobs([x.read() for x in jdl], headers=get_auth_headers())
+ print(f"Inserted {len(jobs)} jobs with ids: {','.join(map(str, (job.job_id for job in jobs)))}")
diff --git a/src/diracx/client/_client.py b/src/diracx/client/_client.py
index 186d43f0..a5f0910f 100644
--- a/src/diracx/client/_client.py
+++ b/src/diracx/client/_client.py
@@ -40,28 +40,18 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential
self, *, endpoint: str = "", **kwargs: Any
) -> None:
self._config = DiracConfiguration(**kwargs)
- self._client: PipelineClient = PipelineClient(
- base_url=endpoint, config=self._config, **kwargs
- )
+ self._client: PipelineClient = PipelineClient(base_url=endpoint, config=self._config, **kwargs)
- client_models = {
- k: v for k, v in _models.__dict__.items() if isinstance(v, type)
- }
+ client_models = {k: v for k, v in _models.__dict__.items() if isinstance(v, type)}
self._serialize = Serializer(client_models)
self._deserialize = Deserializer(client_models)
self._serialize.client_side_validation = False
- self.well_known = WellKnownOperations(
- self._client, self._config, self._serialize, self._deserialize
- )
+ self.well_known = WellKnownOperations(self._client, self._config, self._serialize, self._deserialize)
self.auth = AuthOperations( # pylint: disable=abstract-class-instantiated
self._client, self._config, self._serialize, self._deserialize
)
- self.config = ConfigOperations(
- self._client, self._config, self._serialize, self._deserialize
- )
- self.jobs = JobsOperations(
- self._client, self._config, self._serialize, self._deserialize
- )
+ self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize)
+ self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize)
def send_request(self, request: HttpRequest, **kwargs: Any) -> HttpResponse:
"""Runs the network request through the client's chained policies.
diff --git a/src/diracx/client/_configuration.py b/src/diracx/client/_configuration.py
index b5106fed..0519a33c 100644
--- a/src/diracx/client/_configuration.py
+++ b/src/diracx/client/_configuration.py
@@ -26,24 +26,12 @@ def __init__(self, **kwargs: Any) -> None:
self._configure(**kwargs)
def _configure(self, **kwargs: Any) -> None:
- self.user_agent_policy = kwargs.get(
- "user_agent_policy"
- ) or policies.UserAgentPolicy(**kwargs)
- self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(
- **kwargs
- )
+ self.user_agent_policy = kwargs.get("user_agent_policy") or policies.UserAgentPolicy(**kwargs)
+ self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(**kwargs)
self.proxy_policy = kwargs.get("proxy_policy") or policies.ProxyPolicy(**kwargs)
- self.logging_policy = kwargs.get(
- "logging_policy"
- ) or policies.NetworkTraceLoggingPolicy(**kwargs)
- self.http_logging_policy = kwargs.get(
- "http_logging_policy"
- ) or policies.HttpLoggingPolicy(**kwargs)
+ self.logging_policy = kwargs.get("logging_policy") or policies.NetworkTraceLoggingPolicy(**kwargs)
+ self.http_logging_policy = kwargs.get("http_logging_policy") or policies.HttpLoggingPolicy(**kwargs)
self.retry_policy = kwargs.get("retry_policy") or policies.RetryPolicy(**kwargs)
- self.custom_hook_policy = kwargs.get(
- "custom_hook_policy"
- ) or policies.CustomHookPolicy(**kwargs)
- self.redirect_policy = kwargs.get("redirect_policy") or policies.RedirectPolicy(
- **kwargs
- )
+ self.custom_hook_policy = kwargs.get("custom_hook_policy") or policies.CustomHookPolicy(**kwargs)
+ self.redirect_policy = kwargs.get("redirect_policy") or policies.RedirectPolicy(**kwargs)
self.authentication_policy = kwargs.get("authentication_policy")
diff --git a/src/diracx/client/_patch.py b/src/diracx/client/_patch.py
index d400d2d1..f7dd3251 100644
--- a/src/diracx/client/_patch.py
+++ b/src/diracx/client/_patch.py
@@ -8,9 +8,7 @@
"""
from typing import List
-__all__: List[
- str
-] = [] # Add all objects you want publicly available to users at this package level
+__all__: List[str] = [] # Add all objects you want publicly available to users at this package level
def patch_sdk():
diff --git a/src/diracx/client/_serialization.py b/src/diracx/client/_serialization.py
index 615a169a..1e7a11b1 100644
--- a/src/diracx/client/_serialization.py
+++ b/src/diracx/client/_serialization.py
@@ -84,9 +84,7 @@ class RawDeserializer:
CONTEXT_NAME = "deserialized_data"
@classmethod
- def deserialize_from_text(
- cls, data: Optional[Union[AnyStr, IO]], content_type: Optional[str] = None
- ) -> Any:
+ def deserialize_from_text(cls, data: Optional[Union[AnyStr, IO]], content_type: Optional[str] = None) -> Any:
"""Decode data according to content-type.
Accept a stream of data as well, but will be load at once in memory for now.
@@ -148,14 +146,10 @@ def _json_attemp(data):
# context otherwise.
_LOGGER.critical("Wasn't XML not JSON, failing")
raise_with_traceback(DeserializationError, "XML is invalid")
- raise DeserializationError(
- "Cannot deserialize content-type: {}".format(content_type)
- )
+ raise DeserializationError("Cannot deserialize content-type: {}".format(content_type))
@classmethod
- def deserialize_from_http_generics(
- cls, body_bytes: Optional[Union[AnyStr, IO]], headers: Mapping
- ) -> Any:
+ def deserialize_from_http_generics(cls, body_bytes: Optional[Union[AnyStr, IO]], headers: Mapping) -> Any:
"""Deserialize from HTTP response.
Use bytes and headers to NOT use any requests/aiohttp or whatever
@@ -376,9 +370,7 @@ def serialize(self, keep_readonly: bool = False, **kwargs: Any) -> JSON:
def as_dict(
self,
keep_readonly: bool = True,
- key_transformer: Callable[
- [str, Dict[str, Any], Any], Any
- ] = attribute_transformer,
+ key_transformer: Callable[[str, Dict[str, Any], Any], Any] = attribute_transformer,
**kwargs: Any
) -> JSON:
"""Return a dict that can be serialized using json.dump.
@@ -412,18 +404,14 @@ def my_key_transformer(key, attr_desc, value):
:rtype: dict
"""
serializer = Serializer(self._infer_class_models())
- return serializer._serialize(
- self, key_transformer=key_transformer, keep_readonly=keep_readonly, **kwargs
- )
+ return serializer._serialize(self, key_transformer=key_transformer, keep_readonly=keep_readonly, **kwargs)
@classmethod
def _infer_class_models(cls):
try:
str_models = cls.__module__.rsplit(".", 1)[0]
models = sys.modules[str_models]
- client_models = {
- k: v for k, v in models.__dict__.items() if isinstance(v, type)
- }
+ client_models = {k: v for k, v in models.__dict__.items() if isinstance(v, type)}
if cls.__name__ not in client_models:
raise ValueError("Not Autorest generated code")
except Exception:
@@ -432,9 +420,7 @@ def _infer_class_models(cls):
return client_models
@classmethod
- def deserialize(
- cls: Type[ModelType], data: Any, content_type: Optional[str] = None
- ) -> ModelType:
+ def deserialize(cls: Type[ModelType], data: Any, content_type: Optional[str] = None) -> ModelType:
"""Parse a str using the RestAPI syntax and return a model.
:param str data: A str using RestAPI structure. JSON by default.
@@ -495,13 +481,9 @@ def _classify(cls, response, objects):
if not isinstance(response, ET.Element):
rest_api_response_key = cls._get_rest_key_parts(subtype_key)[-1]
- subtype_value = response.pop(
- rest_api_response_key, None
- ) or response.pop(subtype_key, None)
+ subtype_value = response.pop(rest_api_response_key, None) or response.pop(subtype_key, None)
else:
- subtype_value = xml_key_extractor(
- subtype_key, cls._attribute_map[subtype_key], response
- )
+ subtype_value = xml_key_extractor(subtype_key, cls._attribute_map[subtype_key], response)
if subtype_value:
# Try to match base class. Can be class name only
# (bug to fix in Autorest to support x-ms-discriminator-name)
@@ -629,9 +611,7 @@ def _serialize(self, target_obj, data_type=None, **kwargs):
try:
is_xml_model_serialization = kwargs["is_xml"]
except KeyError:
- is_xml_model_serialization = kwargs.setdefault(
- "is_xml", target_obj.is_xml_model()
- )
+ is_xml_model_serialization = kwargs.setdefault("is_xml", target_obj.is_xml_model())
serialized = {}
if is_xml_model_serialization:
@@ -640,9 +620,7 @@ def _serialize(self, target_obj, data_type=None, **kwargs):
attributes = target_obj._attribute_map
for attr, attr_desc in attributes.items():
attr_name = attr
- if not keep_readonly and target_obj._validation.get(attr_name, {}).get(
- "readonly", False
- ):
+ if not keep_readonly and target_obj._validation.get(attr_name, {}).get("readonly", False):
continue
if attr_name == "additional_properties" and attr_desc["key"] == "":
@@ -654,15 +632,11 @@ def _serialize(self, target_obj, data_type=None, **kwargs):
if is_xml_model_serialization:
pass # Don't provide "transformer" for XML for now. Keep "orig_attr"
else: # JSON
- keys, orig_attr = key_transformer(
- attr, attr_desc.copy(), orig_attr
- )
+ keys, orig_attr = key_transformer(attr, attr_desc.copy(), orig_attr)
keys = keys if isinstance(keys, list) else [keys]
kwargs["serialization_ctxt"] = attr_desc
- new_attr = self.serialize_data(
- orig_attr, attr_desc["type"], **kwargs
- )
+ new_attr = self.serialize_data(orig_attr, attr_desc["type"], **kwargs)
if is_xml_model_serialization:
xml_desc = attr_desc.get("xml", {})
@@ -709,9 +683,7 @@ def _serialize(self, target_obj, data_type=None, **kwargs):
continue
except (AttributeError, KeyError, TypeError) as err:
- msg = "Attribute {} in object {} cannot be serialized.\n{}".format(
- attr_name, class_name, str(target_obj)
- )
+ msg = "Attribute {} in object {} cannot be serialized.\n{}".format(attr_name, class_name, str(target_obj))
raise_with_traceback(SerializationError, msg, err)
else:
return serialized
@@ -733,9 +705,7 @@ def body(self, data, data_type, **kwargs):
is_xml_model_serialization = kwargs["is_xml"]
except KeyError:
if internal_data_type and issubclass(internal_data_type, Model):
- is_xml_model_serialization = kwargs.setdefault(
- "is_xml", internal_data_type.is_xml_model()
- )
+ is_xml_model_serialization = kwargs.setdefault("is_xml", internal_data_type.is_xml_model())
else:
is_xml_model_serialization = False
if internal_data_type and not isinstance(internal_data_type, Enum):
@@ -756,9 +726,7 @@ def body(self, data, data_type, **kwargs):
]
data = deserializer._deserialize(data_type, data)
except DeserializationError as err:
- raise_with_traceback(
- SerializationError, "Unable to build a model: " + str(err), err
- )
+ raise_with_traceback(SerializationError, "Unable to build a model: " + str(err), err)
return self._serialize(data, data_type, **kwargs)
@@ -798,12 +766,7 @@ def query(self, name, data, data_type, **kwargs):
# Treat the list aside, since we don't want to encode the div separator
if data_type.startswith("["):
internal_data_type = data_type[1:-1]
- data = [
- self.serialize_data(d, internal_data_type, **kwargs)
- if d is not None
- else ""
- for d in data
- ]
+ data = [self.serialize_data(d, internal_data_type, **kwargs) if d is not None else "" for d in data]
if not kwargs.get("skip_quote", False):
data = [quote(str(d), safe="") for d in data]
return str(self.serialize_iter(data, internal_data_type, **kwargs))
@@ -975,9 +938,7 @@ def serialize_iter(self, data, iter_type, div=None, **kwargs):
is_wrapped = xml_desc.get("wrapped", False)
node_name = xml_desc.get("itemsName", xml_name)
if is_wrapped:
- final_result = _create_xml_node(
- xml_name, xml_desc.get("prefix", None), xml_desc.get("ns", None)
- )
+ final_result = _create_xml_node(xml_name, xml_desc.get("prefix", None), xml_desc.get("ns", None))
else:
final_result = []
# All list elements to "local_node"
@@ -1009,9 +970,7 @@ def serialize_dict(self, attr, dict_type, **kwargs):
serialized = {}
for key, value in attr.items():
try:
- serialized[self.serialize_unicode(key)] = self.serialize_data(
- value, dict_type, **kwargs
- )
+ serialized[self.serialize_unicode(key)] = self.serialize_data(value, dict_type, **kwargs)
except ValueError:
serialized[self.serialize_unicode(key)] = None
@@ -1020,9 +979,7 @@ def serialize_dict(self, attr, dict_type, **kwargs):
xml_desc = serialization_ctxt["xml"]
xml_name = xml_desc["name"]
- final_result = _create_xml_node(
- xml_name, xml_desc.get("prefix", None), xml_desc.get("ns", None)
- )
+ final_result = _create_xml_node(xml_name, xml_desc.get("prefix", None), xml_desc.get("ns", None))
for key, value in serialized.items():
ET.SubElement(final_result, key).text = value
return final_result
@@ -1068,9 +1025,7 @@ def serialize_object(self, attr, **kwargs):
serialized = {}
for key, value in attr.items():
try:
- serialized[self.serialize_unicode(key)] = self.serialize_object(
- value, **kwargs
- )
+ serialized[self.serialize_unicode(key)] = self.serialize_object(value, **kwargs)
except ValueError:
serialized[self.serialize_unicode(key)] = None
return serialized
@@ -1287,9 +1242,7 @@ def rest_key_case_insensitive_extractor(attr, attr_desc, data):
key = _decode_attribute_map_key(dict_keys[0])
break
working_key = _decode_attribute_map_key(dict_keys[0])
- working_data = attribute_key_case_insensitive_extractor(
- working_key, None, working_data
- )
+ working_data = attribute_key_case_insensitive_extractor(working_key, None, working_data)
if working_data is None:
# If at any point while following flatten JSON path see None, it means
# that all properties under are None as well
@@ -1382,10 +1335,7 @@ def xml_key_extractor(attr, attr_desc, data):
# - Wrapped node
# - Internal type is an enum (considered basic types)
# - Internal type has no XML/Name node
- if is_wrapped or (
- internal_type
- and (issubclass(internal_type, Enum) or "name" not in internal_type_xml_map)
- ):
+ if is_wrapped or (internal_type and (issubclass(internal_type, Enum) or "name" not in internal_type_xml_map)):
children = data.findall(xml_name)
# If internal type has a local name and it's not a list, I use that name
elif not is_iter_type and internal_type and "name" in internal_type_xml_map:
@@ -1393,9 +1343,7 @@ def xml_key_extractor(attr, attr_desc, data):
children = data.findall(xml_name)
# That's an array
else:
- if (
- internal_type
- ): # Complex type, ignore itemsName and use the complex type name
+ if internal_type: # Complex type, ignore itemsName and use the complex type name
items_name = _extract_name_from_internal_type(internal_type)
else:
items_name = xml_desc.get("itemsName", xml_name)
@@ -1424,9 +1372,7 @@ def xml_key_extractor(attr, attr_desc, data):
# Here it's not a itertype, we should have found one element only or empty
if len(children) > 1:
- raise DeserializationError(
- "Find several XML '{}' where it was not expected".format(xml_name)
- )
+ raise DeserializationError("Find several XML '{}' where it was not expected".format(xml_name))
return children[0]
@@ -1439,9 +1385,7 @@ class Deserializer(object):
basic_types = {str: "str", int: "int", bool: "bool", float: "float"}
- valid_date = re.compile(
- r"\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}" r"\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?"
- )
+ valid_date = re.compile(r"\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}" r"\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?")
def __init__(self, classes: Optional[Mapping[str, Type[ModelType]]] = None):
self.deserialize_type = {
@@ -1497,11 +1441,7 @@ def _deserialize(self, target_obj, data):
"""
# This is already a model, go recursive just in case
if hasattr(data, "_attribute_map"):
- constants = [
- name
- for name, config in getattr(data, "_validation", {}).items()
- if config.get("constant")
- ]
+ constants = [name for name, config in getattr(data, "_validation", {}).items() if config.get("constant")]
try:
for attr, mapconfig in data._attribute_map.items():
if attr in constants:
@@ -1511,9 +1451,7 @@ def _deserialize(self, target_obj, data):
continue
local_type = mapconfig["type"]
internal_data_type = local_type.strip("[]{}")
- if internal_data_type not in self.dependencies or isinstance(
- internal_data_type, Enum
- ):
+ if internal_data_type not in self.dependencies or isinstance(internal_data_type, Enum):
continue
setattr(data, attr, self._deserialize(local_type, value))
return data
@@ -1567,10 +1505,7 @@ def _deserialize(self, target_obj, data):
def _build_additional_properties(self, attribute_map, data):
if not self.additional_properties_detection:
return None
- if (
- "additional_properties" in attribute_map
- and attribute_map.get("additional_properties", {}).get("key") != ""
- ):
+ if "additional_properties" in attribute_map and attribute_map.get("additional_properties", {}).get("key") != "":
# Check empty string. If it's not empty, someone has a real "additionalProperties"
return None
if isinstance(data, ET.Element):
@@ -1650,21 +1585,15 @@ def _unpack_content(raw_data, content_type=None):
if context:
if RawDeserializer.CONTEXT_NAME in context:
return context[RawDeserializer.CONTEXT_NAME]
- raise ValueError(
- "This pipeline didn't have the RawDeserializer policy; can't deserialize"
- )
+ raise ValueError("This pipeline didn't have the RawDeserializer policy; can't deserialize")
# Assume this is enough to recognize universal_http.ClientResponse without importing it
if hasattr(raw_data, "body"):
- return RawDeserializer.deserialize_from_http_generics(
- raw_data.text(), raw_data.headers
- )
+ return RawDeserializer.deserialize_from_http_generics(raw_data.text(), raw_data.headers)
# Assume this enough to recognize requests.Response without importing it.
if hasattr(raw_data, "_content_consumed"):
- return RawDeserializer.deserialize_from_http_generics(
- raw_data.text, raw_data.headers
- )
+ return RawDeserializer.deserialize_from_http_generics(raw_data.text, raw_data.headers)
if isinstance(raw_data, (basestring, bytes)) or hasattr(raw_data, "read"):
return RawDeserializer.deserialize_from_text(raw_data, content_type) # type: ignore
@@ -1679,17 +1608,9 @@ def _instantiate_model(self, response, attrs, additional_properties=None):
if callable(response):
subtype = getattr(response, "_subtype_map", {})
try:
- readonly = [
- k for k, v in response._validation.items() if v.get("readonly")
- ]
- const = [
- k for k, v in response._validation.items() if v.get("constant")
- ]
- kwargs = {
- k: v
- for k, v in attrs.items()
- if k not in subtype and k not in readonly + const
- }
+ readonly = [k for k, v in response._validation.items() if v.get("readonly")]
+ const = [k for k, v in response._validation.items() if v.get("constant")]
+ kwargs = {k: v for k, v in attrs.items() if k not in subtype and k not in readonly + const}
response_obj = response(**kwargs)
for attr in readonly:
setattr(response_obj, attr, attrs.get(attr))
@@ -1726,17 +1647,11 @@ def deserialize_data(self, data, data_type):
if data_type in self.basic_types.values():
return self.deserialize_basic(data, data_type)
if data_type in self.deserialize_type:
- if isinstance(
- data, self.deserialize_expected_types.get(data_type, tuple())
- ):
+ if isinstance(data, self.deserialize_expected_types.get(data_type, tuple())):
return data
is_a_text_parsing_type = lambda x: x not in ["object", "[]", r"{}"]
- if (
- isinstance(data, ET.Element)
- and is_a_text_parsing_type(data_type)
- and not data.text
- ):
+ if isinstance(data, ET.Element) and is_a_text_parsing_type(data_type) and not data.text:
return None
data_val = self.deserialize_type[data_type](data)
return data_val
@@ -1767,16 +1682,10 @@ def deserialize_iter(self, attr, iter_type):
"""
if attr is None:
return None
- if isinstance(
- attr, ET.Element
- ): # If I receive an element here, get the children
+ if isinstance(attr, ET.Element): # If I receive an element here, get the children
attr = list(attr)
if not isinstance(attr, (list, set)):
- raise DeserializationError(
- "Cannot deserialize as [{}] an object of type {}".format(
- iter_type, type(attr)
- )
- )
+ raise DeserializationError("Cannot deserialize as [{}] an object of type {}".format(iter_type, type(attr)))
return [self.deserialize_data(a, iter_type) for a in attr]
def deserialize_dict(self, attr, dict_type):
@@ -1788,9 +1697,7 @@ def deserialize_dict(self, attr, dict_type):
:rtype: dict
"""
if isinstance(attr, list):
- return {
- x["key"]: self.deserialize_data(x["value"], dict_type) for x in attr
- }
+ return {x["key"]: self.deserialize_data(x["value"], dict_type) for x in attr}
if isinstance(attr, ET.Element):
# Transform value into {"Key": "value"}
@@ -2022,9 +1929,7 @@ def deserialize_date(attr):
if isinstance(attr, ET.Element):
attr = attr.text
if re.search(r"[^\W\d_]", attr, re.I + re.U): # type: ignore
- raise DeserializationError(
- "Date must have only digits and -. Received: %s" % attr
- )
+ raise DeserializationError("Date must have only digits and -. Received: %s" % attr)
# This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception.
return isodate.parse_date(attr, defaultmonth=None, defaultday=None)
@@ -2039,9 +1944,7 @@ def deserialize_time(attr):
if isinstance(attr, ET.Element):
attr = attr.text
if re.search(r"[^\W\d_]", attr, re.I + re.U): # type: ignore
- raise DeserializationError(
- "Date must have only digits and -. Received: %s" % attr
- )
+ raise DeserializationError("Date must have only digits and -. Received: %s" % attr)
return isodate.parse_time(attr)
@staticmethod
@@ -2057,10 +1960,7 @@ def deserialize_rfc(attr):
try:
parsed_date = email.utils.parsedate_tz(attr) # type: ignore
date_obj = datetime.datetime(
- *parsed_date[:6],
- tzinfo=_FixedOffset(
- datetime.timedelta(minutes=(parsed_date[9] or 0) / 60)
- )
+ *parsed_date[:6], tzinfo=_FixedOffset(datetime.timedelta(minutes=(parsed_date[9] or 0) / 60))
)
if not date_obj.tzinfo:
date_obj = date_obj.astimezone(tz=TZ_UTC)
diff --git a/src/diracx/client/_vendor.py b/src/diracx/client/_vendor.py
index c5b16f59..7b62dcec 100644
--- a/src/diracx/client/_vendor.py
+++ b/src/diracx/client/_vendor.py
@@ -14,16 +14,12 @@ def _format_url_section(template, **kwargs):
except KeyError as key:
# Need the cast, as for some reasons "split" is typed as list[str | Any]
formatted_components = cast(List[str], template.split("/"))
- components = [
- c for c in formatted_components if "{}".format(key.args[0]) not in c
- ]
+ components = [c for c in formatted_components if "{}".format(key.args[0]) not in c]
template = "/".join(components)
def raise_if_not_implemented(cls, abstract_methods):
- not_implemented = [
- f for f in abstract_methods if not callable(getattr(cls, f, None))
- ]
+ not_implemented = [f for f in abstract_methods if not callable(getattr(cls, f, None))]
if not_implemented:
raise NotImplementedError(
"The following methods on operation group '{}' are not implemented: '{}'."
diff --git a/src/diracx/client/aio/_client.py b/src/diracx/client/aio/_client.py
index 32f0e07d..684725e4 100644
--- a/src/diracx/client/aio/_client.py
+++ b/src/diracx/client/aio/_client.py
@@ -40,32 +40,20 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential
self, *, endpoint: str = "", **kwargs: Any
) -> None:
self._config = DiracConfiguration(**kwargs)
- self._client: AsyncPipelineClient = AsyncPipelineClient(
- base_url=endpoint, config=self._config, **kwargs
- )
+ self._client: AsyncPipelineClient = AsyncPipelineClient(base_url=endpoint, config=self._config, **kwargs)
- client_models = {
- k: v for k, v in _models.__dict__.items() if isinstance(v, type)
- }
+ client_models = {k: v for k, v in _models.__dict__.items() if isinstance(v, type)}
self._serialize = Serializer(client_models)
self._deserialize = Deserializer(client_models)
self._serialize.client_side_validation = False
- self.well_known = WellKnownOperations(
- self._client, self._config, self._serialize, self._deserialize
- )
+ self.well_known = WellKnownOperations(self._client, self._config, self._serialize, self._deserialize)
self.auth = AuthOperations( # pylint: disable=abstract-class-instantiated
self._client, self._config, self._serialize, self._deserialize
)
- self.config = ConfigOperations(
- self._client, self._config, self._serialize, self._deserialize
- )
- self.jobs = JobsOperations(
- self._client, self._config, self._serialize, self._deserialize
- )
+ self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize)
+ self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize)
- def send_request(
- self, request: HttpRequest, **kwargs: Any
- ) -> Awaitable[AsyncHttpResponse]:
+ def send_request(self, request: HttpRequest, **kwargs: Any) -> Awaitable[AsyncHttpResponse]:
"""Runs the network request through the client's chained policies.
>>> from azure.core.rest import HttpRequest
diff --git a/src/diracx/client/aio/_configuration.py b/src/diracx/client/aio/_configuration.py
index 10d4b962..07f6f94a 100644
--- a/src/diracx/client/aio/_configuration.py
+++ b/src/diracx/client/aio/_configuration.py
@@ -26,26 +26,12 @@ def __init__(self, **kwargs: Any) -> None:
self._configure(**kwargs)
def _configure(self, **kwargs: Any) -> None:
- self.user_agent_policy = kwargs.get(
- "user_agent_policy"
- ) or policies.UserAgentPolicy(**kwargs)
- self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(
- **kwargs
- )
+ self.user_agent_policy = kwargs.get("user_agent_policy") or policies.UserAgentPolicy(**kwargs)
+ self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(**kwargs)
self.proxy_policy = kwargs.get("proxy_policy") or policies.ProxyPolicy(**kwargs)
- self.logging_policy = kwargs.get(
- "logging_policy"
- ) or policies.NetworkTraceLoggingPolicy(**kwargs)
- self.http_logging_policy = kwargs.get(
- "http_logging_policy"
- ) or policies.HttpLoggingPolicy(**kwargs)
- self.retry_policy = kwargs.get("retry_policy") or policies.AsyncRetryPolicy(
- **kwargs
- )
- self.custom_hook_policy = kwargs.get(
- "custom_hook_policy"
- ) or policies.CustomHookPolicy(**kwargs)
- self.redirect_policy = kwargs.get(
- "redirect_policy"
- ) or policies.AsyncRedirectPolicy(**kwargs)
+ self.logging_policy = kwargs.get("logging_policy") or policies.NetworkTraceLoggingPolicy(**kwargs)
+ self.http_logging_policy = kwargs.get("http_logging_policy") or policies.HttpLoggingPolicy(**kwargs)
+ self.retry_policy = kwargs.get("retry_policy") or policies.AsyncRetryPolicy(**kwargs)
+ self.custom_hook_policy = kwargs.get("custom_hook_policy") or policies.CustomHookPolicy(**kwargs)
+ self.redirect_policy = kwargs.get("redirect_policy") or policies.AsyncRedirectPolicy(**kwargs)
self.authentication_policy = kwargs.get("authentication_policy")
diff --git a/src/diracx/client/aio/_patch.py b/src/diracx/client/aio/_patch.py
index d400d2d1..f7dd3251 100644
--- a/src/diracx/client/aio/_patch.py
+++ b/src/diracx/client/aio/_patch.py
@@ -8,9 +8,7 @@
"""
from typing import List
-__all__: List[
- str
-] = [] # Add all objects you want publicly available to users at this package level
+__all__: List[str] = [] # Add all objects you want publicly available to users at this package level
def patch_sdk():
diff --git a/src/diracx/client/aio/_vendor.py b/src/diracx/client/aio/_vendor.py
index 9889eb3a..3ada94d2 100644
--- a/src/diracx/client/aio/_vendor.py
+++ b/src/diracx/client/aio/_vendor.py
@@ -5,9 +5,7 @@
def raise_if_not_implemented(cls, abstract_methods):
- not_implemented = [
- f for f in abstract_methods if not callable(getattr(cls, f, None))
- ]
+ not_implemented = [f for f in abstract_methods if not callable(getattr(cls, f, None))]
if not_implemented:
raise NotImplementedError(
"The following methods on operation group '{}' are not implemented: '{}'."
diff --git a/src/diracx/client/aio/operations/_operations.py b/src/diracx/client/aio/operations/_operations.py
index a481b679..06da8589 100644
--- a/src/diracx/client/aio/operations/_operations.py
+++ b/src/diracx/client/aio/operations/_operations.py
@@ -52,9 +52,7 @@
else:
from typing import MutableMapping # type: ignore # pylint: disable=ungrouped-imports
T = TypeVar("T")
-ClsType = Optional[
- Callable[[PipelineResponse[HttpRequest, AsyncHttpResponse], T, Dict[str, Any]], Any]
-]
+ClsType = Optional[Callable[[PipelineResponse[HttpRequest, AsyncHttpResponse], T, Dict[str, Any]], Any]]
JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object
@@ -75,9 +73,7 @@ def __init__(self, *args, **kwargs) -> None:
self._client = input_args.pop(0) if input_args else kwargs.pop("client")
self._config = input_args.pop(0) if input_args else kwargs.pop("config")
self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer")
- self._deserialize = (
- input_args.pop(0) if input_args else kwargs.pop("deserializer")
- )
+ self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer")
@distributed_trace_async
async def openid_configuration(self, **kwargs: Any) -> Any:
@@ -109,18 +105,14 @@ async def openid_configuration(self, **kwargs: Any) -> Any:
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -148,9 +140,7 @@ def __init__(self, *args, **kwargs) -> None:
self._client = input_args.pop(0) if input_args else kwargs.pop("client")
self._config = input_args.pop(0) if input_args else kwargs.pop("config")
self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer")
- self._deserialize = (
- input_args.pop(0) if input_args else kwargs.pop("deserializer")
- )
+ self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer")
raise_if_not_implemented(
self.__class__,
[
@@ -198,18 +188,14 @@ async def do_device_flow(self, *, user_code: str, **kwargs: Any) -> Any:
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -266,23 +252,17 @@ async def initiate_device_flow(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
- deserialized = self._deserialize(
- "InitiateDeviceFlowResponse", pipeline_response
- )
+ deserialized = self._deserialize("InitiateDeviceFlowResponse", pipeline_response)
if cls:
return cls(pipeline_response, deserialized, {})
@@ -329,18 +309,14 @@ async def finish_device_flow(self, *, code: str, state: str, **kwargs: Any) -> A
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -380,18 +356,14 @@ async def finished(self, **kwargs: Any) -> Any:
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -463,18 +435,14 @@ async def authorization_flow(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -485,9 +453,7 @@ async def authorization_flow(
return deserialized
@distributed_trace_async
- async def authorization_flow_complete(
- self, *, code: str, state: str, **kwargs: Any
- ) -> Any:
+ async def authorization_flow_complete(self, *, code: str, state: str, **kwargs: Any) -> Any:
"""Authorization Flow Complete.
Authorization Flow Complete.
@@ -522,18 +488,14 @@ async def authorization_flow_complete(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -561,18 +523,11 @@ def __init__(self, *args, **kwargs) -> None:
self._client = input_args.pop(0) if input_args else kwargs.pop("client")
self._config = input_args.pop(0) if input_args else kwargs.pop("config")
self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer")
- self._deserialize = (
- input_args.pop(0) if input_args else kwargs.pop("deserializer")
- )
+ self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer")
@distributed_trace_async
async def serve_config(
- self,
- vo: str,
- *,
- if_none_match: Optional[str] = None,
- if_modified_since: Optional[str] = None,
- **kwargs: Any
+ self, vo: str, *, if_none_match: Optional[str] = None, if_modified_since: Optional[str] = None, **kwargs: Any
) -> Any:
"""Serve Config.
@@ -617,18 +572,14 @@ async def serve_config(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -656,9 +607,7 @@ def __init__(self, *args, **kwargs) -> None:
self._client = input_args.pop(0) if input_args else kwargs.pop("client")
self._config = input_args.pop(0) if input_args else kwargs.pop("config")
self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer")
- self._deserialize = (
- input_args.pop(0) if input_args else kwargs.pop("deserializer")
- )
+ self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer")
@overload
async def submit_bulk_jobs(
@@ -697,9 +646,7 @@ async def submit_bulk_jobs(
"""
@distributed_trace_async
- async def submit_bulk_jobs(
- self, body: Union[List[str], IO], **kwargs: Any
- ) -> List[_models.InsertedJob]:
+ async def submit_bulk_jobs(self, body: Union[List[str], IO], **kwargs: Any) -> List[_models.InsertedJob]:
"""Submit Bulk Jobs.
Submit Bulk Jobs.
@@ -724,9 +671,7 @@ async def submit_bulk_jobs(
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
_params = kwargs.pop("params", {}) or {}
- content_type: Optional[str] = kwargs.pop(
- "content_type", _headers.pop("Content-Type", None)
- )
+ content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
cls: ClsType[List[_models.InsertedJob]] = kwargs.pop("cls", None)
content_type = content_type or "application/json"
@@ -747,18 +692,14 @@ async def submit_bulk_jobs(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("[InsertedJob]", pipeline_response)
@@ -801,18 +742,14 @@ async def delete_bulk_jobs(self, *, job_ids: List[int], **kwargs: Any) -> Any:
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -855,18 +792,14 @@ async def get_single_job(self, job_id: int, **kwargs: Any) -> Any:
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -909,18 +842,14 @@ async def delete_single_job(self, job_id: int, **kwargs: Any) -> Any:
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -963,18 +892,14 @@ async def kill_single_job(self, job_id: int, **kwargs: Any) -> Any:
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -985,9 +910,7 @@ async def kill_single_job(self, job_id: int, **kwargs: Any) -> Any:
return deserialized
@distributed_trace_async
- async def get_single_job_status(
- self, job_id: int, **kwargs: Any
- ) -> Union[str, _models.JobStatus]:
+ async def get_single_job_status(self, job_id: int, **kwargs: Any) -> Union[str, _models.JobStatus]:
"""Get Single Job Status.
Get Single Job Status.
@@ -1019,18 +942,14 @@ async def get_single_job_status(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("str", pipeline_response)
@@ -1041,9 +960,7 @@ async def get_single_job_status(
return deserialized
@distributed_trace_async
- async def set_single_job_status(
- self, job_id: int, *, status: Union[str, _models.JobStatus], **kwargs: Any
- ) -> Any:
+ async def set_single_job_status(self, job_id: int, *, status: Union[str, _models.JobStatus], **kwargs: Any) -> Any:
"""Set Single Job Status.
Set Single Job Status.
@@ -1079,18 +996,14 @@ async def set_single_job_status(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -1133,18 +1046,14 @@ async def kill_bulk_jobs(self, *, job_ids: List[int], **kwargs: Any) -> Any:
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -1187,18 +1096,14 @@ async def get_bulk_job_status(self, *, job_ids: List[int], **kwargs: Any) -> Any
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -1210,11 +1115,7 @@ async def get_bulk_job_status(self, *, job_ids: List[int], **kwargs: Any) -> Any
@overload
async def set_status_bulk(
- self,
- body: List[_models.JobStatusUpdate],
- *,
- content_type: str = "application/json",
- **kwargs: Any
+ self, body: List[_models.JobStatusUpdate], *, content_type: str = "application/json", **kwargs: Any
) -> List[_models.JobStatusReturn]:
"""Set Status Bulk.
@@ -1276,9 +1177,7 @@ async def set_status_bulk(
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
_params = kwargs.pop("params", {}) or {}
- content_type: Optional[str] = kwargs.pop(
- "content_type", _headers.pop("Content-Type", None)
- )
+ content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
cls: ClsType[List[_models.JobStatusReturn]] = kwargs.pop("cls", None)
content_type = content_type or "application/json"
@@ -1299,18 +1198,14 @@ async def set_status_bulk(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("[JobStatusReturn]", pipeline_response)
@@ -1419,9 +1314,7 @@ async def search(
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
_params = kwargs.pop("params", {}) or {}
- content_type: Optional[str] = kwargs.pop(
- "content_type", _headers.pop("Content-Type", None)
- )
+ content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
cls: ClsType[List[JSON]] = kwargs.pop("cls", None)
content_type = content_type or "application/json"
@@ -1447,18 +1340,14 @@ async def search(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("[object]", pipeline_response)
@@ -1470,11 +1359,7 @@ async def search(
@overload
async def summary(
- self,
- body: _models.JobSummaryParams,
- *,
- content_type: str = "application/json",
- **kwargs: Any
+ self, body: _models.JobSummaryParams, *, content_type: str = "application/json", **kwargs: Any
) -> Any:
"""Summary.
@@ -1491,9 +1376,7 @@ async def summary(
"""
@overload
- async def summary(
- self, body: IO, *, content_type: str = "application/json", **kwargs: Any
- ) -> Any:
+ async def summary(self, body: IO, *, content_type: str = "application/json", **kwargs: Any) -> Any:
"""Summary.
Show information suitable for plotting.
@@ -1509,9 +1392,7 @@ async def summary(
"""
@distributed_trace_async
- async def summary(
- self, body: Union[_models.JobSummaryParams, IO], **kwargs: Any
- ) -> Any:
+ async def summary(self, body: Union[_models.JobSummaryParams, IO], **kwargs: Any) -> Any:
"""Summary.
Show information suitable for plotting.
@@ -1536,9 +1417,7 @@ async def summary(
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
_params = kwargs.pop("params", {}) or {}
- content_type: Optional[str] = kwargs.pop(
- "content_type", _headers.pop("Content-Type", None)
- )
+ content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
cls: ClsType[Any] = kwargs.pop("cls", None)
content_type = content_type or "application/json"
@@ -1559,18 +1438,14 @@ async def summary(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
diff --git a/src/diracx/client/aio/operations/_patch.py b/src/diracx/client/aio/operations/_patch.py
index 0fd71c5d..23c31718 100644
--- a/src/diracx/client/aio/operations/_patch.py
+++ b/src/diracx/client/aio/operations/_patch.py
@@ -69,10 +69,8 @@ async def token(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- await self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
diff --git a/src/diracx/client/models/_enums.py b/src/diracx/client/models/_enums.py
index e782e7e6..825e0464 100644
--- a/src/diracx/client/models/_enums.py
+++ b/src/diracx/client/models/_enums.py
@@ -17,9 +17,7 @@ class Enum0(str, Enum, metaclass=CaseInsensitiveEnumMeta):
class Enum1(str, Enum, metaclass=CaseInsensitiveEnumMeta):
"""Enum1."""
- URN_IETF_PARAMS_OAUTH_GRANT_TYPE_DEVICE_CODE = (
- "urn:ietf:params:oauth:grant-type:device_code"
- )
+ URN_IETF_PARAMS_OAUTH_GRANT_TYPE_DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code"
class Enum2(str, Enum, metaclass=CaseInsensitiveEnumMeta):
diff --git a/src/diracx/client/models/_models.py b/src/diracx/client/models/_models.py
index 7a776f88..0a7c9c81 100644
--- a/src/diracx/client/models/_models.py
+++ b/src/diracx/client/models/_models.py
@@ -105,9 +105,7 @@ class HTTPValidationError(_serialization.Model):
"detail": {"key": "detail", "type": "[ValidationError]"},
}
- def __init__(
- self, *, detail: Optional[List["_models.ValidationError"]] = None, **kwargs: Any
- ) -> None:
+ def __init__(self, *, detail: Optional[List["_models.ValidationError"]] = None, **kwargs: Any) -> None:
"""
:keyword detail: Detail.
:paramtype detail: list[~client.models.ValidationError]
@@ -212,13 +210,7 @@ class InsertedJob(_serialization.Model):
}
def __init__(
- self,
- *,
- job_id: int,
- status: str,
- minor_status: str,
- time_stamp: datetime.datetime,
- **kwargs: Any
+ self, *, job_id: int, status: str, minor_status: str, time_stamp: datetime.datetime, **kwargs: Any
) -> None:
"""
:keyword job_id: Jobid. Required.
@@ -308,9 +300,7 @@ class JobStatusReturn(_serialization.Model):
"status": {"key": "status", "type": "str"},
}
- def __init__(
- self, *, job_id: int, status: Union[str, "_models.JobStatus"], **kwargs: Any
- ) -> None:
+ def __init__(self, *, job_id: int, status: Union[str, "_models.JobStatus"], **kwargs: Any) -> None:
"""
:keyword job_id: Job Id. Required.
:paramtype job_id: int
@@ -345,9 +335,7 @@ class JobStatusUpdate(_serialization.Model):
"status": {"key": "status", "type": "str"},
}
- def __init__(
- self, *, job_id: int, status: Union[str, "_models.JobStatus"], **kwargs: Any
- ) -> None:
+ def __init__(self, *, job_id: int, status: Union[str, "_models.JobStatus"], **kwargs: Any) -> None:
"""
:keyword job_id: Job Id. Required.
:paramtype job_id: int
@@ -381,11 +369,7 @@ class JobSummaryParams(_serialization.Model):
}
def __init__(
- self,
- *,
- grouping: List[str],
- search: List["_models.JobSummaryParamsSearchItem"] = [],
- **kwargs: Any
+ self, *, grouping: List[str], search: List["_models.JobSummaryParamsSearchItem"] = [], **kwargs: Any
) -> None:
"""
:keyword grouping: Grouping. Required.
@@ -435,12 +419,7 @@ class ScalarSearchSpec(_serialization.Model):
}
def __init__(
- self,
- *,
- parameter: str,
- operator: Union[str, "_models.ScalarSearchOperator"],
- value: str,
- **kwargs: Any
+ self, *, parameter: str, operator: Union[str, "_models.ScalarSearchOperator"], value: str, **kwargs: Any
) -> None:
"""
:keyword parameter: Parameter. Required.
@@ -478,9 +457,7 @@ class SortSpec(_serialization.Model):
"direction": {"key": "direction", "type": "SortSpecDirection"},
}
- def __init__(
- self, *, parameter: str, direction: "_models.SortSpecDirection", **kwargs: Any
- ) -> None:
+ def __init__(self, *, parameter: str, direction: "_models.SortSpecDirection", **kwargs: Any) -> None:
"""
:keyword parameter: Parameter. Required.
:paramtype parameter: str
@@ -527,9 +504,7 @@ class TokenResponse(_serialization.Model):
"state": {"key": "state", "type": "str"},
}
- def __init__(
- self, *, access_token: str, expires_in: int, state: str, **kwargs: Any
- ) -> None:
+ def __init__(self, *, access_token: str, expires_in: int, state: str, **kwargs: Any) -> None:
"""
:keyword access_token: Access Token. Required.
:paramtype access_token: str
@@ -569,14 +544,7 @@ class ValidationError(_serialization.Model):
"type": {"key": "type", "type": "str"},
}
- def __init__(
- self,
- *,
- loc: List["_models.ValidationErrorLocItem"],
- msg: str,
- type: str,
- **kwargs: Any
- ) -> None:
+ def __init__(self, *, loc: List["_models.ValidationErrorLocItem"], msg: str, type: str, **kwargs: Any) -> None:
"""
:keyword loc: Location. Required.
:paramtype loc: list[~client.models.ValidationErrorLocItem]
@@ -627,12 +595,7 @@ class VectorSearchSpec(_serialization.Model):
}
def __init__(
- self,
- *,
- parameter: str,
- operator: Union[str, "_models.VectorSearchOperator"],
- values: List[str],
- **kwargs: Any
+ self, *, parameter: str, operator: Union[str, "_models.VectorSearchOperator"], values: List[str], **kwargs: Any
) -> None:
"""
:keyword parameter: Parameter. Required.
diff --git a/src/diracx/client/operations/_operations.py b/src/diracx/client/operations/_operations.py
index 68033e18..5dba2bbf 100644
--- a/src/diracx/client/operations/_operations.py
+++ b/src/diracx/client/operations/_operations.py
@@ -31,9 +31,7 @@
else:
from typing import MutableMapping # type: ignore # pylint: disable=ungrouped-imports
T = TypeVar("T")
-ClsType = Optional[
- Callable[[PipelineResponse[HttpRequest, HttpResponse], T, Dict[str, Any]], Any]
-]
+ClsType = Optional[Callable[[PipelineResponse[HttpRequest, HttpResponse], T, Dict[str, Any]], Any]]
JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object
_SERIALIZER = Serializer()
@@ -71,14 +69,10 @@ def build_auth_do_device_flow_request(*, user_code: str, **kwargs: Any) -> HttpR
# Construct headers
_headers["Accept"] = _SERIALIZER.header("accept", accept, "str")
- return HttpRequest(
- method="GET", url=_url, params=_params, headers=_headers, **kwargs
- )
+ return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs)
-def build_auth_initiate_device_flow_request(
- *, client_id: str, scope: str, audience: str, **kwargs: Any
-) -> HttpRequest:
+def build_auth_initiate_device_flow_request(*, client_id: str, scope: str, audience: str, **kwargs: Any) -> HttpRequest:
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
_params = case_insensitive_dict(kwargs.pop("params", {}) or {})
@@ -95,14 +89,10 @@ def build_auth_initiate_device_flow_request(
# Construct headers
_headers["Accept"] = _SERIALIZER.header("accept", accept, "str")
- return HttpRequest(
- method="POST", url=_url, params=_params, headers=_headers, **kwargs
- )
+ return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs)
-def build_auth_finish_device_flow_request(
- *, code: str, state: str, **kwargs: Any
-) -> HttpRequest:
+def build_auth_finish_device_flow_request(*, code: str, state: str, **kwargs: Any) -> HttpRequest:
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
_params = case_insensitive_dict(kwargs.pop("params", {}) or {})
@@ -118,9 +108,7 @@ def build_auth_finish_device_flow_request(
# Construct headers
_headers["Accept"] = _SERIALIZER.header("accept", accept, "str")
- return HttpRequest(
- method="GET", url=_url, params=_params, headers=_headers, **kwargs
- )
+ return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs)
def build_auth_finished_request(**kwargs: Any) -> HttpRequest:
@@ -158,12 +146,8 @@ def build_auth_authorization_flow_request(
# Construct parameters
_params["response_type"] = _SERIALIZER.query("response_type", response_type, "str")
- _params["code_challenge"] = _SERIALIZER.query(
- "code_challenge", code_challenge, "str"
- )
- _params["code_challenge_method"] = _SERIALIZER.query(
- "code_challenge_method", code_challenge_method, "str"
- )
+ _params["code_challenge"] = _SERIALIZER.query("code_challenge", code_challenge, "str")
+ _params["code_challenge_method"] = _SERIALIZER.query("code_challenge_method", code_challenge_method, "str")
_params["client_id"] = _SERIALIZER.query("client_id", client_id, "str")
_params["redirect_uri"] = _SERIALIZER.query("redirect_uri", redirect_uri, "str")
_params["scope"] = _SERIALIZER.query("scope", scope, "str")
@@ -172,9 +156,7 @@ def build_auth_authorization_flow_request(
# Construct headers
_headers["Accept"] = _SERIALIZER.header("accept", accept, "str")
- return HttpRequest(
- method="GET", url=_url, params=_params, headers=_headers, **kwargs
- )
+ return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs)
def build_auth_authorization_flow_complete_request( # pylint: disable=name-too-long
@@ -195,9 +177,7 @@ def build_auth_authorization_flow_complete_request( # pylint: disable=name-too-
# Construct headers
_headers["Accept"] = _SERIALIZER.header("accept", accept, "str")
- return HttpRequest(
- method="GET", url=_url, params=_params, headers=_headers, **kwargs
- )
+ return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs)
def build_config_serve_config_request(
@@ -221,13 +201,9 @@ def build_config_serve_config_request(
# Construct headers
if if_none_match is not None:
- _headers["if-none-match"] = _SERIALIZER.header(
- "if_none_match", if_none_match, "str"
- )
+ _headers["if-none-match"] = _SERIALIZER.header("if_none_match", if_none_match, "str")
if if_modified_since is not None:
- _headers["if-modified-since"] = _SERIALIZER.header(
- "if_modified_since", if_modified_since, "str"
- )
+ _headers["if-modified-since"] = _SERIALIZER.header("if_modified_since", if_modified_since, "str")
_headers["Accept"] = _SERIALIZER.header("accept", accept, "str")
return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs)
@@ -236,9 +212,7 @@ def build_config_serve_config_request(
def build_jobs_submit_bulk_jobs_request(**kwargs: Any) -> HttpRequest:
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
- content_type: Optional[str] = kwargs.pop(
- "content_type", _headers.pop("Content-Type", None)
- )
+ content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
accept = _headers.pop("Accept", "application/json")
# Construct URL
@@ -246,17 +220,13 @@ def build_jobs_submit_bulk_jobs_request(**kwargs: Any) -> HttpRequest:
# Construct headers
if content_type is not None:
- _headers["Content-Type"] = _SERIALIZER.header(
- "content_type", content_type, "str"
- )
+ _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str")
_headers["Accept"] = _SERIALIZER.header("accept", accept, "str")
return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs)
-def build_jobs_delete_bulk_jobs_request(
- *, job_ids: List[int], **kwargs: Any
-) -> HttpRequest:
+def build_jobs_delete_bulk_jobs_request(*, job_ids: List[int], **kwargs: Any) -> HttpRequest:
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
_params = case_insensitive_dict(kwargs.pop("params", {}) or {})
@@ -271,9 +241,7 @@ def build_jobs_delete_bulk_jobs_request(
# Construct headers
_headers["Accept"] = _SERIALIZER.header("accept", accept, "str")
- return HttpRequest(
- method="DELETE", url=_url, params=_params, headers=_headers, **kwargs
- )
+ return HttpRequest(method="DELETE", url=_url, params=_params, headers=_headers, **kwargs)
def build_jobs_get_single_job_request(job_id: int, **kwargs: Any) -> HttpRequest:
@@ -374,14 +342,10 @@ def build_jobs_set_single_job_status_request(
# Construct headers
_headers["Accept"] = _SERIALIZER.header("accept", accept, "str")
- return HttpRequest(
- method="POST", url=_url, params=_params, headers=_headers, **kwargs
- )
+ return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs)
-def build_jobs_kill_bulk_jobs_request(
- *, job_ids: List[int], **kwargs: Any
-) -> HttpRequest:
+def build_jobs_kill_bulk_jobs_request(*, job_ids: List[int], **kwargs: Any) -> HttpRequest:
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
_params = case_insensitive_dict(kwargs.pop("params", {}) or {})
@@ -396,14 +360,10 @@ def build_jobs_kill_bulk_jobs_request(
# Construct headers
_headers["Accept"] = _SERIALIZER.header("accept", accept, "str")
- return HttpRequest(
- method="POST", url=_url, params=_params, headers=_headers, **kwargs
- )
+ return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs)
-def build_jobs_get_bulk_job_status_request(
- *, job_ids: List[int], **kwargs: Any
-) -> HttpRequest:
+def build_jobs_get_bulk_job_status_request(*, job_ids: List[int], **kwargs: Any) -> HttpRequest:
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
_params = case_insensitive_dict(kwargs.pop("params", {}) or {})
@@ -418,17 +378,13 @@ def build_jobs_get_bulk_job_status_request(
# Construct headers
_headers["Accept"] = _SERIALIZER.header("accept", accept, "str")
- return HttpRequest(
- method="GET", url=_url, params=_params, headers=_headers, **kwargs
- )
+ return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs)
def build_jobs_set_status_bulk_request(**kwargs: Any) -> HttpRequest:
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
- content_type: Optional[str] = kwargs.pop(
- "content_type", _headers.pop("Content-Type", None)
- )
+ content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
accept = _headers.pop("Accept", "application/json")
# Construct URL
@@ -436,23 +392,17 @@ def build_jobs_set_status_bulk_request(**kwargs: Any) -> HttpRequest:
# Construct headers
if content_type is not None:
- _headers["Content-Type"] = _SERIALIZER.header(
- "content_type", content_type, "str"
- )
+ _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str")
_headers["Accept"] = _SERIALIZER.header("accept", accept, "str")
return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs)
-def build_jobs_search_request(
- *, page: int = 0, per_page: int = 100, **kwargs: Any
-) -> HttpRequest:
+def build_jobs_search_request(*, page: int = 0, per_page: int = 100, **kwargs: Any) -> HttpRequest:
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
_params = case_insensitive_dict(kwargs.pop("params", {}) or {})
- content_type: Optional[str] = kwargs.pop(
- "content_type", _headers.pop("Content-Type", None)
- )
+ content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
accept = _headers.pop("Accept", "application/json")
# Construct URL
@@ -466,22 +416,16 @@ def build_jobs_search_request(
# Construct headers
if content_type is not None:
- _headers["Content-Type"] = _SERIALIZER.header(
- "content_type", content_type, "str"
- )
+ _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str")
_headers["Accept"] = _SERIALIZER.header("accept", accept, "str")
- return HttpRequest(
- method="POST", url=_url, params=_params, headers=_headers, **kwargs
- )
+ return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs)
def build_jobs_summary_request(**kwargs: Any) -> HttpRequest:
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
- content_type: Optional[str] = kwargs.pop(
- "content_type", _headers.pop("Content-Type", None)
- )
+ content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
accept = _headers.pop("Accept", "application/json")
# Construct URL
@@ -489,9 +433,7 @@ def build_jobs_summary_request(**kwargs: Any) -> HttpRequest:
# Construct headers
if content_type is not None:
- _headers["Content-Type"] = _SERIALIZER.header(
- "content_type", content_type, "str"
- )
+ _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str")
_headers["Accept"] = _SERIALIZER.header("accept", accept, "str")
return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs)
@@ -514,9 +456,7 @@ def __init__(self, *args, **kwargs):
self._client = input_args.pop(0) if input_args else kwargs.pop("client")
self._config = input_args.pop(0) if input_args else kwargs.pop("config")
self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer")
- self._deserialize = (
- input_args.pop(0) if input_args else kwargs.pop("deserializer")
- )
+ self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer")
@distributed_trace
def openid_configuration(self, **kwargs: Any) -> Any:
@@ -548,18 +488,14 @@ def openid_configuration(self, **kwargs: Any) -> Any:
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -587,9 +523,7 @@ def __init__(self, *args, **kwargs):
self._client = input_args.pop(0) if input_args else kwargs.pop("client")
self._config = input_args.pop(0) if input_args else kwargs.pop("config")
self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer")
- self._deserialize = (
- input_args.pop(0) if input_args else kwargs.pop("deserializer")
- )
+ self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer")
raise_if_not_implemented(
self.__class__,
[
@@ -637,18 +571,14 @@ def do_device_flow(self, *, user_code: str, **kwargs: Any) -> Any:
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -705,23 +635,17 @@ def initiate_device_flow(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
- deserialized = self._deserialize(
- "InitiateDeviceFlowResponse", pipeline_response
- )
+ deserialized = self._deserialize("InitiateDeviceFlowResponse", pipeline_response)
if cls:
return cls(pipeline_response, deserialized, {})
@@ -768,18 +692,14 @@ def finish_device_flow(self, *, code: str, state: str, **kwargs: Any) -> Any:
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -819,18 +739,14 @@ def finished(self, **kwargs: Any) -> Any:
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -902,18 +818,14 @@ def authorization_flow(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -924,9 +836,7 @@ def authorization_flow(
return deserialized
@distributed_trace
- def authorization_flow_complete(
- self, *, code: str, state: str, **kwargs: Any
- ) -> Any:
+ def authorization_flow_complete(self, *, code: str, state: str, **kwargs: Any) -> Any:
"""Authorization Flow Complete.
Authorization Flow Complete.
@@ -961,18 +871,14 @@ def authorization_flow_complete(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -1000,9 +906,7 @@ def __init__(self, *args, **kwargs):
self._client = input_args.pop(0) if input_args else kwargs.pop("client")
self._config = input_args.pop(0) if input_args else kwargs.pop("config")
self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer")
- self._deserialize = (
- input_args.pop(0) if input_args else kwargs.pop("deserializer")
- )
+ self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer")
@distributed_trace
def serve_config(
@@ -1056,18 +960,14 @@ def serve_config(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -1095,9 +995,7 @@ def __init__(self, *args, **kwargs):
self._client = input_args.pop(0) if input_args else kwargs.pop("client")
self._config = input_args.pop(0) if input_args else kwargs.pop("config")
self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer")
- self._deserialize = (
- input_args.pop(0) if input_args else kwargs.pop("deserializer")
- )
+ self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer")
@overload
def submit_bulk_jobs(
@@ -1136,9 +1034,7 @@ def submit_bulk_jobs(
"""
@distributed_trace
- def submit_bulk_jobs(
- self, body: Union[List[str], IO], **kwargs: Any
- ) -> List[_models.InsertedJob]:
+ def submit_bulk_jobs(self, body: Union[List[str], IO], **kwargs: Any) -> List[_models.InsertedJob]:
"""Submit Bulk Jobs.
Submit Bulk Jobs.
@@ -1163,9 +1059,7 @@ def submit_bulk_jobs(
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
_params = kwargs.pop("params", {}) or {}
- content_type: Optional[str] = kwargs.pop(
- "content_type", _headers.pop("Content-Type", None)
- )
+ content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
cls: ClsType[List[_models.InsertedJob]] = kwargs.pop("cls", None)
content_type = content_type or "application/json"
@@ -1186,18 +1080,14 @@ def submit_bulk_jobs(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("[InsertedJob]", pipeline_response)
@@ -1240,18 +1130,14 @@ def delete_bulk_jobs(self, *, job_ids: List[int], **kwargs: Any) -> Any:
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -1294,18 +1180,14 @@ def get_single_job(self, job_id: int, **kwargs: Any) -> Any:
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -1348,18 +1230,14 @@ def delete_single_job(self, job_id: int, **kwargs: Any) -> Any:
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -1402,18 +1280,14 @@ def kill_single_job(self, job_id: int, **kwargs: Any) -> Any:
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -1424,9 +1298,7 @@ def kill_single_job(self, job_id: int, **kwargs: Any) -> Any:
return deserialized
@distributed_trace
- def get_single_job_status(
- self, job_id: int, **kwargs: Any
- ) -> Union[str, _models.JobStatus]:
+ def get_single_job_status(self, job_id: int, **kwargs: Any) -> Union[str, _models.JobStatus]:
"""Get Single Job Status.
Get Single Job Status.
@@ -1458,18 +1330,14 @@ def get_single_job_status(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("str", pipeline_response)
@@ -1480,9 +1348,7 @@ def get_single_job_status(
return deserialized
@distributed_trace
- def set_single_job_status(
- self, job_id: int, *, status: Union[str, _models.JobStatus], **kwargs: Any
- ) -> Any:
+ def set_single_job_status(self, job_id: int, *, status: Union[str, _models.JobStatus], **kwargs: Any) -> Any:
"""Set Single Job Status.
Set Single Job Status.
@@ -1518,18 +1384,14 @@ def set_single_job_status(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -1572,18 +1434,14 @@ def kill_bulk_jobs(self, *, job_ids: List[int], **kwargs: Any) -> Any:
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -1626,18 +1484,14 @@ def get_bulk_job_status(self, *, job_ids: List[int], **kwargs: Any) -> Any:
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
@@ -1715,9 +1569,7 @@ def set_status_bulk(
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
_params = kwargs.pop("params", {}) or {}
- content_type: Optional[str] = kwargs.pop(
- "content_type", _headers.pop("Content-Type", None)
- )
+ content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
cls: ClsType[List[_models.JobStatusReturn]] = kwargs.pop("cls", None)
content_type = content_type or "application/json"
@@ -1738,18 +1590,14 @@ def set_status_bulk(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("[JobStatusReturn]", pipeline_response)
@@ -1858,9 +1706,7 @@ def search(
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
_params = kwargs.pop("params", {}) or {}
- content_type: Optional[str] = kwargs.pop(
- "content_type", _headers.pop("Content-Type", None)
- )
+ content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
cls: ClsType[List[JSON]] = kwargs.pop("cls", None)
content_type = content_type or "application/json"
@@ -1886,18 +1732,14 @@ def search(
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("[object]", pipeline_response)
@@ -1930,9 +1772,7 @@ def summary(
"""
@overload
- def summary(
- self, body: IO, *, content_type: str = "application/json", **kwargs: Any
- ) -> Any:
+ def summary(self, body: IO, *, content_type: str = "application/json", **kwargs: Any) -> Any:
"""Summary.
Show information suitable for plotting.
@@ -1973,9 +1813,7 @@ def summary(self, body: Union[_models.JobSummaryParams, IO], **kwargs: Any) -> A
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
_params = kwargs.pop("params", {}) or {}
- content_type: Optional[str] = kwargs.pop(
- "content_type", _headers.pop("Content-Type", None)
- )
+ content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
cls: ClsType[Any] = kwargs.pop("cls", None)
content_type = content_type or "application/json"
@@ -1996,18 +1834,14 @@ def summary(self, body: Union[_models.JobSummaryParams, IO], **kwargs: Any) -> A
request.url = self._client.format_url(request.url)
_stream = False
- pipeline_response: PipelineResponse = (
- self._client._pipeline.run( # pylint: disable=protected-access
- request, stream=_stream, **kwargs
- )
+ pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
+ request, stream=_stream, **kwargs
)
response = pipeline_response.http_response
if response.status_code not in [200]:
- map_error(
- status_code=response.status_code, response=response, error_map=error_map
- )
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)
deserialized = self._deserialize("object", pipeline_response)
diff --git a/src/diracx/client/operations/_patch.py b/src/diracx/client/operations/_patch.py
index b2f097ba..2bdd9763 100644
--- a/src/diracx/client/operations/_patch.py
+++ b/src/diracx/client/operations/_patch.py
@@ -15,9 +15,7 @@
from .. import models as _models
from ._operations import AuthOperations as AuthOperationsGenerated
-__all__: List[str] = [
- "AuthOperations"
-] # Add all objects you want publicly available to users at this package level
+__all__: List[str] = ["AuthOperations"] # Add all objects you want publicly available to users at this package level
def patch_sdk():
diff --git a/src/diracx/core/config/__init__.py b/src/diracx/core/config/__init__.py
index a09784e3..ef10e5ae 100644
--- a/src/diracx/core/config/__init__.py
+++ b/src/diracx/core/config/__init__.py
@@ -96,9 +96,7 @@ def __init__(self, *, backend_url: ConfigSourceUrl) -> None:
repo_location = Path(backend_url.path)
self.repo_location = repo_location
self.repo = git.Repo(repo_location)
- self._latest_revision_cache: Cache = TTLCache(
- MAX_CS_CACHED_VERSIONS, DEFAULT_CS_CACHE_TTL
- )
+ self._latest_revision_cache: Cache = TTLCache(MAX_CS_CACHED_VERSIONS, DEFAULT_CS_CACHE_TTL)
self._read_raw_cache: Cache = LRUCache(MAX_CS_CACHED_VERSIONS)
def __hash__(self):
@@ -115,9 +113,7 @@ def latest_revision(self) -> tuple[str, datetime]:
except git.exc.ODBError as e: # type: ignore
raise BadConfigurationVersion(f"Error parsing latest revision: {e}") from e
modified = rev.committed_datetime.astimezone(timezone.utc)
- logger.debug(
- "Latest revision for %s is %s with mtime %s", self, rev.hexsha, modified
- )
+ logger.debug("Latest revision for %s is %s with mtime %s", self, rev.hexsha, modified)
return rev.hexsha, modified
@cachedmethod(lambda self: self._read_raw_cache)
diff --git a/src/diracx/core/config/schema.py b/src/diracx/core/config/schema.py
index 8f3470e0..3f2a628c 100644
--- a/src/diracx/core/config/schema.py
+++ b/src/diracx/core/config/schema.py
@@ -2,7 +2,7 @@
import os
from datetime import datetime
-from typing import Any, Optional
+from typing import Any
from pydantic import BaseModel as _BaseModel
from pydantic import EmailStr, PrivateAttr, root_validator
@@ -22,9 +22,7 @@ def legacy_adaptor(cls, v):
# though ideally we should parse the type hints properly.
for field, hint in cls.__annotations__.items():
# Convert comma separated lists to actual lists
- if hint in {"list[str]", "list[SecurityProperty]"} and isinstance(
- v.get(field), str
- ):
+ if hint in {"list[str]", "list[SecurityProperty]"} and isinstance(v.get(field), str):
v[field] = [x.strip() for x in v[field].split(",") if x.strip()]
# If the field is optional and the value is "None" convert it to None
if "| None" in hint and field in v:
@@ -49,12 +47,12 @@ class GroupConfig(BaseModel):
AutoAddVOMS: bool = False
AutoUploadPilotProxy: bool = False
AutoUploadProxy: bool = False
- JobShare: Optional[int]
+ JobShare: int | None
Properties: list[SecurityProperty]
- Quota: Optional[int]
+ Quota: int | None
Users: list[str]
AllowBackgroundTQs: bool = False
- VOMSRole: Optional[str]
+ VOMSRole: str | None
AutoSyncVOMS: bool = False
diff --git a/src/diracx/core/extensions.py b/src/diracx/core/extensions.py
index 46786a96..27bb6396 100644
--- a/src/diracx/core/extensions.py
+++ b/src/diracx/core/extensions.py
@@ -1,12 +1,10 @@
-from __future__ import absolute_import
-
__all__ = ("select_from_extension",)
import os
from collections import defaultdict
+from collections.abc import Iterator
from importlib.metadata import EntryPoint, entry_points
from importlib.util import find_spec
-from typing import Iterator
def extensions_by_priority() -> Iterator[str]:
@@ -17,9 +15,7 @@ def extensions_by_priority() -> Iterator[str]:
yield module_name
-def select_from_extension(
- *, group: str, name: str | None = None
-) -> Iterator[EntryPoint]:
+def select_from_extension(*, group: str, name: str | None = None) -> Iterator[EntryPoint]:
"""Select entry points by group and name, in order of priority.
Similar to ``importlib.metadata.entry_points.select`` except only modules
diff --git a/src/diracx/core/properties.py b/src/diracx/core/properties.py
index 8bf80201..4403afc8 100644
--- a/src/diracx/core/properties.py
+++ b/src/diracx/core/properties.py
@@ -5,7 +5,7 @@
import inspect
import operator
-from typing import Callable
+from collections.abc import Callable
from diracx.core.extensions import select_from_extension
@@ -14,9 +14,7 @@ class SecurityProperty(str):
@classmethod
def available_properties(cls) -> set[SecurityProperty]:
properties = set()
- for entry_point in select_from_extension(
- group="diracx", name="properties_module"
- ):
+ for entry_point in select_from_extension(group="diracx", name="properties_module"):
properties_module = entry_point.load()
for _, obj in inspect.getmembers(properties_module):
if isinstance(obj, SecurityProperty):
@@ -26,23 +24,17 @@ def available_properties(cls) -> set[SecurityProperty]:
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self})"
- def __and__(
- self, value: SecurityProperty | UnevaluatedProperty
- ) -> UnevaluatedExpression:
+ def __and__(self, value: SecurityProperty | UnevaluatedProperty) -> UnevaluatedExpression:
if not isinstance(value, UnevaluatedProperty):
value = UnevaluatedProperty(value)
return UnevaluatedProperty(self) & value
- def __or__(
- self, value: SecurityProperty | UnevaluatedProperty
- ) -> UnevaluatedExpression:
+ def __or__(self, value: SecurityProperty | UnevaluatedProperty) -> UnevaluatedExpression:
if not isinstance(value, UnevaluatedProperty):
value = UnevaluatedProperty(value)
return UnevaluatedProperty(self) | value
- def __xor__(
- self, value: SecurityProperty | UnevaluatedProperty
- ) -> UnevaluatedExpression:
+ def __xor__(self, value: SecurityProperty | UnevaluatedProperty) -> UnevaluatedExpression:
if not isinstance(value, UnevaluatedProperty):
value = UnevaluatedProperty(value)
return UnevaluatedProperty(self) ^ value
diff --git a/src/diracx/db/auth/db.py b/src/diracx/db/auth/db.py
index 687e8aa9..30454eb4 100644
--- a/src/diracx/db/auth/db.py
+++ b/src/diracx/db/auth/db.py
@@ -25,9 +25,7 @@ class AuthDB(BaseDB):
# This needs to be here for the BaseDB to create the engine
metadata = AuthDBBase.metadata
- async def device_flow_validate_user_code(
- self, user_code: str, max_validity: int
- ) -> str:
+ async def device_flow_validate_user_code(self, user_code: str, max_validity: int) -> str:
"""Validate that the user_code can be used (Pending status, not expired)
Returns the scope field for the given user_code
@@ -52,9 +50,7 @@ async def get_device_flow(self, device_code: str, max_validity: int):
# multiple time concurrently
stmt = select(
DeviceFlows,
- (DeviceFlows.creation_time < substract_date(seconds=max_validity)).label(
- "is_expired"
- ),
+ (DeviceFlows.creation_time < substract_date(seconds=max_validity)).label("is_expired"),
).with_for_update()
stmt = stmt.where(
DeviceFlows.device_code == device_code,
@@ -67,9 +63,7 @@ async def get_device_flow(self, device_code: str, max_validity: int):
if res["status"] == FlowStatus.READY:
# Update the status to Done before returning
await self.conn.execute(
- update(DeviceFlows)
- .where(DeviceFlows.device_code == device_code)
- .values(status=FlowStatus.DONE)
+ update(DeviceFlows).where(DeviceFlows.device_code == device_code).values(status=FlowStatus.DONE)
)
return res
@@ -82,9 +76,7 @@ async def get_device_flow(self, device_code: str, max_validity: int):
raise AuthorizationError("Bad state in device flow")
- async def device_flow_insert_id_token(
- self, user_code: str, id_token: dict[str, str], max_validity: int
- ) -> None:
+ async def device_flow_insert_id_token(self, user_code: str, id_token: dict[str, str], max_validity: int) -> None:
"""
:raises: AuthorizationError if no such code or status not pending
"""
@@ -97,9 +89,7 @@ async def device_flow_insert_id_token(
stmt = stmt.values(id_token=id_token, status=FlowStatus.READY)
res = await self.conn.execute(stmt)
if res.rowcount != 1:
- raise AuthorizationError(
- f"{res.rowcount} rows matched user_code {user_code}"
- )
+ raise AuthorizationError(f"{res.rowcount} rows matched user_code {user_code}")
async def insert_device_flow(
self,
@@ -109,8 +99,7 @@ async def insert_device_flow(
) -> tuple[str, str]:
for _ in range(MAX_RETRY):
user_code = "".join(
- secrets.choice(USER_CODE_ALPHABET)
- for _ in range(DeviceFlows.user_code.type.length) # type: ignore
+ secrets.choice(USER_CODE_ALPHABET) for _ in range(DeviceFlows.user_code.type.length) # type: ignore
)
# user_code = "2QRKPY"
device_code = secrets.token_urlsafe()
@@ -128,9 +117,7 @@ async def insert_device_flow(
continue
return user_code, device_code
- raise NotImplementedError(
- f"Could not insert new device flow after {MAX_RETRY} retries"
- )
+ raise NotImplementedError(f"Could not insert new device flow after {MAX_RETRY} retries")
async def insert_authorization_flow(
self,
@@ -200,9 +187,7 @@ async def get_authorization_flow(self, code: str, max_validity: int):
if res["status"] == FlowStatus.READY:
# Update the status to Done before returning
await self.conn.execute(
- update(AuthorizationFlows)
- .where(AuthorizationFlows.code == code)
- .values(status=FlowStatus.DONE)
+ update(AuthorizationFlows).where(AuthorizationFlows.code == code).values(status=FlowStatus.DONE)
)
return res
diff --git a/src/diracx/db/dummy/db.py b/src/diracx/db/dummy/db.py
index a4456219..06cf3483 100644
--- a/src/diracx/db/dummy/db.py
+++ b/src/diracx/db/dummy/db.py
@@ -30,11 +30,7 @@ async def summary(self, group_by, search) -> list[dict[str, str | int]]:
stmt = stmt.group_by(*columns)
# Execute the query
- return [
- dict(row._mapping)
- async for row in (await self.conn.stream(stmt))
- if row.count > 0 # type: ignore
- ]
+ return [dict(row._mapping) async for row in (await self.conn.stream(stmt)) if row.count > 0] # type: ignore
async def insert_owner(self, name: str) -> int:
stmt = insert(Owners).values(name=name)
@@ -43,9 +39,7 @@ async def insert_owner(self, name: str) -> int:
return result.lastrowid
async def insert_car(self, license_plate: UUID, model: str, owner_id: int) -> int:
- stmt = insert(Cars).values(
- licensePlate=license_plate, model=model, ownerID=owner_id
- )
+ stmt = insert(Cars).values(licensePlate=license_plate, model=model, ownerID=owner_id)
result = await self.conn.execute(stmt)
# await self.engine.commit()
diff --git a/src/diracx/db/jobs/db.py b/src/diracx/db/jobs/db.py
index a933cfef..0533dc3d 100644
--- a/src/diracx/db/jobs/db.py
+++ b/src/diracx/db/jobs/db.py
@@ -30,11 +30,7 @@ async def summary(self, group_by, search) -> list[dict[str, str | int]]:
stmt = stmt.group_by(*columns)
# Execute the query
- return [
- dict(row._mapping)
- async for row in (await self.conn.stream(stmt))
- if row.count > 0 # type: ignore
- ]
+ return [dict(row._mapping) async for row in (await self.conn.stream(stmt)) if row.count > 0] # type: ignore
async def search(
self, parameters, search, sorts, *, per_page: int = 100, page: int | None = None
@@ -42,12 +38,8 @@ async def search(
# Find which columns to select
columns = [x for x in Jobs.__table__.columns]
if parameters:
- if unrecognised_parameters := set(parameters) - set(
- Jobs.__table__.columns.keys()
- ):
- raise InvalidQueryError(
- f"Unrecognised parameters requested {unrecognised_parameters}"
- )
+ if unrecognised_parameters := set(parameters) - set(Jobs.__table__.columns.keys()):
+ raise InvalidQueryError(f"Unrecognised parameters requested {unrecognised_parameters}")
columns = [c for c in columns if c.name in parameters]
stmt = select(*columns)
@@ -73,9 +65,7 @@ async def search(
async def _insertNewJDL(self, jdl) -> int:
from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL
- stmt = insert(JobJDLs).values(
- JDL="", JobRequirements="", OriginalJDL=compressJDL(jdl)
- )
+ stmt = insert(JobJDLs).values(JDL="", JobRequirements="", OriginalJDL=compressJDL(jdl))
result = await self.conn.execute(stmt)
# await self.engine.commit()
return result.lastrowid
@@ -140,9 +130,7 @@ async def _checkAndPrepareJob(
async def setJobJDL(self, job_id, jdl):
from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL
- stmt = (
- update(JobJDLs).where(JobJDLs.JobID == job_id).values(JDL=compressJDL(jdl))
- )
+ stmt = update(JobJDLs).where(JobJDLs.JobID == job_id).values(JDL=compressJDL(jdl))
await self.conn.execute(stmt)
async def insert(
@@ -173,9 +161,7 @@ async def insert(
"DIRACSetup": dirac_setup,
}
- jobManifest = returnValueOrRaise(
- checkAndAddOwner(jdl, owner, owner_dn, owner_group, dirac_setup)
- )
+ jobManifest = returnValueOrRaise(checkAndAddOwner(jdl, owner, owner_dn, owner_group, dirac_setup))
jdl = fixJDL(jdl)
diff --git a/src/diracx/db/jobs/schema.py b/src/diracx/db/jobs/schema.py
index 532f3a83..a00ab0bc 100644
--- a/src/diracx/db/jobs/schema.py
+++ b/src/diracx/db/jobs/schema.py
@@ -59,9 +59,7 @@ class Jobs(Base):
JobType = Column("JobType", String(32), default="user")
DIRACSetup = Column("DIRACSetup", String(32), default="test")
JobGroup = Column("JobGroup", String(32), default="00000000")
- JobSplitType = Column(
- "JobSplitType", Enum("Single", "Master", "Subjob", "DAGNode"), default="Single"
- )
+ JobSplitType = Column("JobSplitType", Enum("Single", "Master", "Subjob", "DAGNode"), default="Single")
MasterJobID = Column("MasterJobID", Integer, default=0)
Site = Column("Site", String(100), default="ANY")
JobName = Column("JobName", String(128), default="Unknown")
@@ -89,9 +87,7 @@ class Jobs(Base):
OSandboxReadyFlag = Column("OSandboxReadyFlag", EnumBackedBool(), default=False)
RetrievedFlag = Column("RetrievedFlag", EnumBackedBool(), default=False)
# TODO: Should this be True/False/"Failed"? Or True/False/Null?
- AccountedFlag = Column(
- "AccountedFlag", Enum("True", "False", "Failed"), default="False"
- )
+ AccountedFlag = Column("AccountedFlag", Enum("True", "False", "Failed"), default="False")
__table_args__ = (
ForeignKeyConstraint(["JobID"], ["JobJDLs.JobID"]),
diff --git a/src/diracx/db/utils.py b/src/diracx/db/utils.py
index 4efc58d5..750a54e5 100644
--- a/src/diracx/db/utils.py
+++ b/src/diracx/db/utils.py
@@ -5,9 +5,10 @@
import contextlib
import os
from abc import ABCMeta
+from collections.abc import AsyncIterator
from datetime import datetime, timedelta, timezone
from functools import partial
-from typing import TYPE_CHECKING, AsyncIterator, Self
+from typing import TYPE_CHECKING, Self
from pydantic import parse_obj_as
from sqlalchemy import Column as RawColumn
diff --git a/src/diracx/routers/__init__.py b/src/diracx/routers/__init__.py
index fd8710b9..43710cf3 100644
--- a/src/diracx/routers/__init__.py
+++ b/src/diracx/routers/__init__.py
@@ -3,8 +3,9 @@
import inspect
import logging
import os
+from collections.abc import AsyncGenerator, Iterable
from functools import partial
-from typing import AsyncContextManager, AsyncGenerator, Iterable, TypeVar
+from typing import AsyncContextManager, TypeVar
import dotenv
from fastapi import APIRouter, Depends, Request
@@ -59,8 +60,7 @@ def create_app_inner(
available_db_classes: set[type[BaseDB]] = set()
for db_name, db_url in database_urls.items():
db_classes: list[type[BaseDB]] = [
- entry_point.load()
- for entry_point in select_from_extension(group="diracx.dbs", name=db_name)
+ entry_point.load() for entry_point in select_from_extension(group="diracx.dbs", name=db_name)
]
assert db_classes, f"Could not find {db_name=}"
# The first DB is the highest priority one
@@ -79,9 +79,7 @@ def create_app_inner(
# Without this AutoREST generates different client sources for each ordering
for system_name in sorted(enabled_systems):
assert system_name not in routers
- for entry_point in select_from_extension(
- group="diracx.services", name=system_name
- ):
+ for entry_point in select_from_extension(group="diracx.services", name=system_name):
routers[system_name] = entry_point.load()
break
else:
@@ -92,16 +90,12 @@ def create_app_inner(
# Ensure required settings are available
for cls in find_dependents(router, ServiceSettingsBase):
if cls not in available_settings_classes:
- raise NotImplementedError(
- f"Cannot enable {system_name=} as it requires {cls=}"
- )
+ raise NotImplementedError(f"Cannot enable {system_name=} as it requires {cls=}")
# Ensure required DBs are available
missing_dbs = set(find_dependents(router, BaseDB)) - available_db_classes
if missing_dbs:
- raise NotImplementedError(
- f"Cannot enable {system_name=} as it requires {missing_dbs=}"
- )
+ raise NotImplementedError(f"Cannot enable {system_name=} as it requires {missing_dbs=}")
# Add the router to the application
dependencies = []
@@ -155,18 +149,14 @@ def create_app() -> DiracFastAPI:
def dirac_error_handler(request: Request, exc: DiracError) -> Response:
- return JSONResponse(
- status_code=exc.http_status_code, content={"detail": exc.detail}
- )
+ return JSONResponse(status_code=exc.http_status_code, content={"detail": exc.detail})
def http_response_handler(request: Request, exc: DiracHttpResponse) -> Response:
return JSONResponse(status_code=exc.status_code, content=exc.data)
-def find_dependents(
- obj: APIRouter | Iterable[Dependant], cls: type[T]
-) -> Iterable[type[T]]:
+def find_dependents(obj: APIRouter | Iterable[Dependant], cls: type[T]) -> Iterable[type[T]]:
if isinstance(obj, APIRouter):
# TODO: Support dependencies of the router itself
# yield from find_dependents(obj.dependencies, cls)
diff --git a/src/diracx/routers/auth.py b/src/diracx/routers/auth.py
index ac88c4ad..898ec365 100644
--- a/src/diracx/routers/auth.py
+++ b/src/diracx/routers/auth.py
@@ -63,17 +63,11 @@ class AuthSettings(ServiceSettingsBase, env_prefix="DIRACX_SERVICE_AUTH_"):
access_token_expire_minutes: int = 3000
refresh_token_expire_minutes: int = 3000
- available_properties: set[SecurityProperty] = Field(
- default_factory=SecurityProperty.available_properties
- )
+ available_properties: set[SecurityProperty] = Field(default_factory=SecurityProperty.available_properties)
def has_properties(expression: UnevaluatedProperty | SecurityProperty):
- evaluator = (
- expression
- if isinstance(expression, UnevaluatedProperty)
- else UnevaluatedProperty(expression)
- )
+ evaluator = expression if isinstance(expression, UnevaluatedProperty) else UnevaluatedProperty(expression)
async def require_property(user: Annotated[UserInfo, Depends(verify_dirac_token)]):
if not evaluator(user.properties):
@@ -127,9 +121,7 @@ async def fetch_jwk_set(url: str):
async def parse_id_token(config, vo, raw_id_token: str, audience: str):
- server_metadata = await get_server_metadata(
- config.Registry[vo].IdP.server_metadata_url
- )
+ server_metadata = await get_server_metadata(config.Registry[vo].IdP.server_metadata_url)
alg_values = server_metadata.get("id_token_signing_alg_values_supported", ["RS256"])
jwk_set = await fetch_jwk_set(config.Registry[vo].IdP.server_metadata_url)
@@ -209,22 +201,16 @@ async def verify_dirac_token(
)
-def create_access_token(
- payload: dict, settings: AuthSettings, expires_delta: timedelta | None = None
-) -> str:
+def create_access_token(payload: dict, settings: AuthSettings, expires_delta: timedelta | None = None) -> str:
to_encode = payload.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
- expire = datetime.utcnow() + timedelta(
- minutes=settings.access_token_expire_minutes
- )
+ expire = datetime.utcnow() + timedelta(minutes=settings.access_token_expire_minutes)
to_encode.update({"exp": expire})
jwt = JsonWebToken(settings.token_algorithm)
- encoded_jwt = jwt.encode(
- {"alg": settings.token_algorithm}, payload, settings.token_key.jwk
- )
+ encoded_jwt = jwt.encode({"alg": settings.token_algorithm}, payload, settings.token_key.jwk)
return encoded_jwt.decode("ascii")
@@ -240,9 +226,7 @@ async def exchange_token(
preferred_username = id_token.get("preferred_username", sub)
if sub not in config.Registry[vo].Groups[dirac_group].Users:
- raise ValueError(
- f"User is not a member of the requested group ({preferred_username}, {dirac_group})"
- )
+ raise ValueError(f"User is not a member of the requested group ({preferred_username}, {dirac_group})")
payload = {
"sub": f"{vo}:{sub}",
@@ -290,9 +274,7 @@ async def initiate_device_flow(
`auth//device?user_code=XYZ`
"""
if settings.dirac_client_id != client_id:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognised client ID"
- )
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognised client ID")
try:
parse_and_validate_scope(scope, config, available_properties)
@@ -302,9 +284,7 @@ async def initiate_device_flow(
detail=e.args[0],
) from e
- user_code, device_code = await auth_db.insert_device_flow(
- client_id, scope, audience
- )
+ user_code, device_code = await auth_db.insert_device_flow(client_id, scope, audience)
verification_uri = str(request.url.replace(query={}))
@@ -317,22 +297,14 @@ async def initiate_device_flow(
}
-async def initiate_authorization_flow_with_iam(
- config, vo: str, redirect_uri: str, state: dict[str, str]
-):
+async def initiate_authorization_flow_with_iam(config, vo: str, redirect_uri: str, state: dict[str, str]):
# code_verifier: https://www.rfc-editor.org/rfc/rfc7636#section-4.1
code_verifier = secrets.token_hex()
# code_challenge: https://www.rfc-editor.org/rfc/rfc7636#section-4.2
- code_challenge = (
- base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
- .decode()
- .replace("=", "")
- )
+ code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).decode().replace("=", "")
- server_metadata = await get_server_metadata(
- config.Registry[vo].IdP.server_metadata_url
- )
+ server_metadata = await get_server_metadata(config.Registry[vo].IdP.server_metadata_url)
# Take these two from CS/.well-known
authorization_endpoint = server_metadata["authorization_endpoint"]
@@ -355,12 +327,8 @@ async def initiate_authorization_flow_with_iam(
return authorization_flow_url
-async def get_token_from_iam(
- config, vo: str, code: str, state: dict[str, str], redirect_uri: str
-) -> dict[str, str]:
- server_metadata = await get_server_metadata(
- config.Registry[vo].IdP.server_metadata_url
- )
+async def get_token_from_iam(config, vo: str, code: str, state: dict[str, str], redirect_uri: str) -> dict[str, str]:
+ server_metadata = await get_server_metadata(config.Registry[vo].IdP.server_metadata_url)
# Take these two from CS/.well-known
token_endpoint = server_metadata["token_endpoint"]
@@ -379,9 +347,7 @@ async def get_token_from_iam(
data=data,
)
if res.status_code >= 500:
- raise HTTPException(
- status.HTTP_502_BAD_GATEWAY, "Failed to contact token endpoint"
- )
+ raise HTTPException(status.HTTP_502_BAD_GATEWAY, "Failed to contact token endpoint")
elif res.status_code >= 400:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid code")
@@ -421,9 +387,7 @@ async def do_device_flow(
"""
# Here we make sure the user_code actualy exists
- scope = await auth_db.device_flow_validate_user_code(
- user_code, settings.device_flow_expiration_seconds
- )
+ scope = await auth_db.device_flow_validate_user_code(user_code, settings.device_flow_expiration_seconds)
parsed_scope = parse_and_validate_scope(scope, config, available_properties)
redirect_uri = f"{request.url.replace(query='')}/complete"
@@ -444,9 +408,7 @@ def decrypt_state(state):
# TODO: There have been better schemes like rot13
return json.loads(base64.urlsafe_b64decode(state).decode())
except Exception as e:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid state"
- ) from e
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid state") from e
@router.get("/device/complete")
@@ -466,9 +428,7 @@ async def finish_device_flow(
in the cookie/session
"""
decrypted_state = decrypt_state(state)
- assert (
- decrypted_state["grant_type"] == "urn:ietf:params:oauth:grant-type:device_code"
- )
+ assert decrypted_state["grant_type"] == "urn:ietf:params:oauth:grant-type:device_code"
id_token = await get_token_from_iam(
config,
@@ -504,9 +464,7 @@ class ScopeInfoDict(TypedDict):
vo: str
-def parse_and_validate_scope(
- scope: str, config: Config, available_properties: set[SecurityProperty]
-) -> ScopeInfoDict:
+def parse_and_validate_scope(scope: str, config: Config, available_properties: set[SecurityProperty]) -> ScopeInfoDict:
"""
Check:
* At most one VO
@@ -538,10 +496,7 @@ def parse_and_validate_scope(
if not vos:
available_vo_scopes = [repr(f"vo:{vo}") for vo in config.Registry]
- raise ValueError(
- "No vo scope requested, available values: "
- f"{' '.join(available_vo_scopes)}"
- )
+ raise ValueError("No vo scope requested, available values: " f"{' '.join(available_vo_scopes)}")
elif len(vos) > 1:
raise ValueError(f"Only one vo is allowed but got {vos}")
else:
@@ -564,9 +519,7 @@ def parse_and_validate_scope(
properties = [str(p) for p in config.Registry[vo].Groups[group].Properties]
if not set(properties).issubset(available_properties):
- raise ValueError(
- f"{set(properties)-set(available_properties)} are not valid properties"
- )
+ raise ValueError(f"{set(properties)-set(available_properties)} are not valid properties")
return {
"group": group,
@@ -623,8 +576,7 @@ def parse_and_validate_scope(
@router.post("/token")
async def token(
grant_type: Annotated[
- Literal["authorization_code"]
- | Literal["urn:ietf:params:oauth:grant-type:device_code"],
+ Literal["authorization_code"] | Literal["urn:ietf:params:oauth:grant-type:device_code"],
Form(description="OAuth2 Grant type"),
],
client_id: Annotated[str, Form(description="OAuth2 client id")],
@@ -632,21 +584,15 @@ async def token(
config: Config,
settings: AuthSettings,
available_properties: AvailableSecurityProperties,
- device_code: Annotated[
- str | None, Form(description="device code for OAuth2 device flow")
- ] = None,
- code: Annotated[
- str | None, Form(description="Code for OAuth2 authorization code flow")
- ] = None,
+ device_code: Annotated[str | None, Form(description="device code for OAuth2 device flow")] = None,
+ code: Annotated[str | None, Form(description="Code for OAuth2 authorization code flow")] = None,
redirect_uri: Annotated[
str | None,
Form(description="redirect_uri used with OAuth2 authorization code flow"),
] = None,
code_verifier: Annotated[
str | None,
- Form(
- description="Verifier for the code challenge for the OAuth2 authorization flow with PKCE"
- ),
+ Form(description="Verifier for the code challenge for the OAuth2 authorization flow with PKCE"),
] = None,
) -> TokenResponse:
""" " Token endpoint to retrieve the token at the end of a flow.
@@ -656,17 +602,11 @@ async def token(
if grant_type == "urn:ietf:params:oauth:grant-type:device_code":
assert device_code is not None
try:
- info = await auth_db.get_device_flow(
- device_code, settings.device_flow_expiration_seconds
- )
+ info = await auth_db.get_device_flow(device_code, settings.device_flow_expiration_seconds)
except PendingAuthorizationError as e:
- raise DiracHttpResponse(
- status.HTTP_400_BAD_REQUEST, {"error": "authorization_pending"}
- ) from e
+ raise DiracHttpResponse(status.HTTP_400_BAD_REQUEST, {"error": "authorization_pending"}) from e
except ExpiredFlowError as e:
- raise DiracHttpResponse(
- status.HTTP_400_BAD_REQUEST, {"error": "expired_token"}
- ) from e
+ raise DiracHttpResponse(status.HTTP_400_BAD_REQUEST, {"error": "expired_token"}) from e
# raise DiracHttpResponse(status.HTTP_400_BAD_REQUEST, {"error": "slow_down"})
# raise DiracHttpResponse(status.HTTP_400_BAD_REQUEST, {"error": "expired_token"})
@@ -678,9 +618,7 @@ async def token(
elif grant_type == "authorization_code":
assert code is not None
- info = await auth_db.get_authorization_flow(
- code, settings.authorization_flow_expiration_seconds
- )
+ info = await auth_db.get_authorization_flow(code, settings.authorization_flow_expiration_seconds)
if redirect_uri != info["redirect_uri"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -690,11 +628,7 @@ async def token(
try:
assert code_verifier is not None
code_challenge = (
- base64.urlsafe_b64encode(
- hashlib.sha256(code_verifier.encode()).digest()
- )
- .decode()
- .strip("=")
+ base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).decode().strip("=")
)
except Exception as e:
raise HTTPException(
@@ -744,13 +678,9 @@ async def authorization_flow(
settings: AuthSettings,
):
if settings.dirac_client_id != client_id:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognised client ID"
- )
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognised client ID")
if redirect_uri not in settings.allowed_redirects:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognised redirect_uri"
- )
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognised redirect_uri")
try:
parsed_scope = parse_and_validate_scope(scope, config, available_properties)
@@ -811,6 +741,4 @@ async def authorization_flow_complete(
settings.authorization_flow_expiration_seconds,
)
- return responses.RedirectResponse(
- f"{redirect_uri}?code={code}&state={decrypted_state['external_state']}"
- )
+ return responses.RedirectResponse(f"{redirect_uri}?code={code}&state={decrypted_state['external_state']}")
diff --git a/src/diracx/routers/configuration.py b/src/diracx/routers/configuration.py
index 97a80be9..020ab5ea 100644
--- a/src/diracx/routers/configuration.py
+++ b/src/diracx/routers/configuration.py
@@ -47,16 +47,12 @@ async def serve_config(
# a server gets out of sync with disk
if if_modified_since:
try:
- not_before = datetime.strptime(
- if_modified_since, LAST_MODIFIED_FORMAT
- ).astimezone(timezone.utc)
+ not_before = datetime.strptime(if_modified_since, LAST_MODIFIED_FORMAT).astimezone(timezone.utc)
except ValueError:
pass
else:
if not_before > config._modified:
- raise HTTPException(
- status_code=status.HTTP_304_NOT_MODIFIED, headers=headers
- )
+ raise HTTPException(status_code=status.HTTP_304_NOT_MODIFIED, headers=headers)
response.headers.update(headers)
diff --git a/src/diracx/routers/dependencies.py b/src/diracx/routers/dependencies.py
index 9df412a1..1aeef9a8 100644
--- a/src/diracx/routers/dependencies.py
+++ b/src/diracx/routers/dependencies.py
@@ -32,6 +32,4 @@ def add_settings_annotation(cls: T) -> T:
# Miscellaneous
Config = Annotated[_Config, Depends(ConfigSource.create)]
-AvailableSecurityProperties = Annotated[
- set[SecurityProperty], Depends(SecurityProperty.available_properties)
-]
+AvailableSecurityProperties = Annotated[set[SecurityProperty], Depends(SecurityProperty.available_properties)]
diff --git a/src/diracx/routers/fastapi_classes.py b/src/diracx/routers/fastapi_classes.py
index 33d7d863..5e5e3f6e 100644
--- a/src/diracx/routers/fastapi_classes.py
+++ b/src/diracx/routers/fastapi_classes.py
@@ -19,9 +19,7 @@ def __init__(self):
@contextlib.asynccontextmanager
async def lifespan(app: DiracFastAPI):
async with contextlib.AsyncExitStack() as stack:
- await asyncio.gather(
- *(stack.enter_async_context(f()) for f in app.lifetime_functions)
- )
+ await asyncio.gather(*(stack.enter_async_context(f()) for f in app.lifetime_functions))
yield
self.lifetime_functions = []
diff --git a/src/diracx/routers/job_manager/__init__.py b/src/diracx/routers/job_manager/__init__.py
index 0092a4fa..92c04066 100644
--- a/src/diracx/routers/job_manager/__init__.py
+++ b/src/diracx/routers/job_manager/__init__.py
@@ -152,8 +152,7 @@ async def submit_bulk_jobs(
if nJobs > MAX_PARAMETRIC_JOBS:
raise NotImplementedError(
EWMSJDL,
- "Number of parametric jobs exceeds the limit of %d"
- % MAX_PARAMETRIC_JOBS,
+ "Number of parametric jobs exceeds the limit of %d" % MAX_PARAMETRIC_JOBS,
)
result = generateParametricJobs(jobClassAd)
if not result["OK"]:
@@ -172,11 +171,7 @@ async def submit_bulk_jobs(
initialStatus = JobStatus.RECEIVED
initialMinorStatus = "Job accepted"
- for (
- jobDescription
- ) in (
- jobDescList
- ): # jobDescList because there might be a list generated by a parametric job
+ for jobDescription in jobDescList: # jobDescList because there might be a list generated by a parametric job
job_id = await job_db.insert(
jobDescription,
user_info.sub,
@@ -188,9 +183,7 @@ async def submit_bulk_jobs(
user_info.vo,
)
- logging.debug(
- f'Job added to the JobDB", "{job_id} for {fixme_ownerDN}/{fixme_ownerGroup}'
- )
+ logging.debug(f'Job added to the JobDB", "{job_id} for {fixme_ownerDN}/{fixme_ownerGroup}')
# TODO comment out for test just now
# self.jobLoggingDB.addLoggingRecord(
@@ -210,9 +203,7 @@ async def submit_bulk_jobs(
# self.__sendJobsToOptimizationMind(jobIDList)
# return result
- return await asyncio.gather(
- *(job_db.insert(j.owner, j.group, j.vo) for j in job_definitions)
- )
+ return await asyncio.gather(*(job_db.insert(j.owner, j.group, j.vo) for j in job_definitions))
@router.delete("/")
@@ -276,9 +267,7 @@ async def set_status_bulk(job_update: list[JobStatusUpdate]) -> list[JobStatusRe
"description": "Get only job statuses for specific jobs, ordered by status",
"value": {
"parameters": ["JobID", "Status"],
- "search": [
- {"parameter": "JobID", "operator": "in", "values": ["6", "2", "3"]}
- ],
+ "search": [{"parameter": "JobID", "operator": "in", "values": ["6", "2", "3"]}],
"sort": [{"parameter": "JobID", "direction": "asc"}],
},
},
@@ -324,9 +313,7 @@ async def search(
user_info: Annotated[UserInfo, Depends(verify_dirac_token)],
page: int = 0,
per_page: int = 100,
- body: Annotated[
- JobSearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES)
- ] = None,
+ body: Annotated[JobSearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES)] = None,
) -> list[dict[str, Any]]:
"""Retrieve information about jobs.
@@ -344,9 +331,7 @@ async def search(
}
)
# TODO: Pagination
- return await job_db.search(
- body.parameters, body.search, body.sort, page=page, per_page=per_page
- )
+ return await job_db.search(body.parameters, body.search, body.sort, page=page, per_page=per_page)
@router.post("/summary")
diff --git a/tests/client/test_regenerate.py b/tests/client/test_regenerate.py
index 2f0712d7..6c71d61e 100644
--- a/tests/client/test_regenerate.py
+++ b/tests/client/test_regenerate.py
@@ -29,9 +29,7 @@ def test_regenerate_client(test_client, tmp_path):
assert (repo_root / ".git").is_dir()
repo = git.Repo(repo_root)
if repo.is_dirty(path=repo_root / "src" / "diracx" / "client"):
- raise AssertionError(
- "Client is currently in a modified state, skipping regeneration"
- )
+ raise AssertionError("Client is currently in a modified state, skipping regeneration")
cmd = [
"autorest",
diff --git a/tests/conftest.py b/tests/conftest.py
index b53bec1d..212079e4 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -75,9 +75,7 @@ def with_app(test_auth_settings, with_config_repo):
"JobDB": "sqlite+aiosqlite:///:memory:",
"AuthDB": "sqlite+aiosqlite:///:memory:",
},
- config_source=ConfigSource.create_from_url(
- backend_url=f"git+file://{with_config_repo}"
- ),
+ config_source=ConfigSource.create_from_url(backend_url=f"git+file://{with_config_repo}"),
)
diff --git a/tests/core/test_secrets.py b/tests/core/test_secrets.py
index 0a521dca..46ae2d23 100644
--- a/tests/core/test_secrets.py
+++ b/tests/core/test_secrets.py
@@ -23,9 +23,7 @@ def test_token_signing_key(tmp_path):
key_file.write_text(private_key_pem)
# Test that we can load a key from a file
- compare_keys(
- parse_obj_as(TokenSigningKey, f"{key_file}").jwk.get_private_key(), private_key
- )
+ compare_keys(parse_obj_as(TokenSigningKey, f"{key_file}").jwk.get_private_key(), private_key)
compare_keys(
parse_obj_as(TokenSigningKey, f"file://{key_file}").jwk.get_private_key(),
private_key,
diff --git a/tests/db/auth/test_authorization_flow.py b/tests/db/auth/test_authorization_flow.py
index 22ffeca3..9bff404e 100644
--- a/tests/db/auth/test_authorization_flow.py
+++ b/tests/db/auth/test_authorization_flow.py
@@ -28,20 +28,14 @@ async def test_insert_id_token(auth_db: AuthDB):
async with auth_db as auth_db:
with pytest.raises(AuthorizationError):
- code, redirect_uri = await auth_db.authorization_flow_insert_id_token(
- uuid, id_token, EXPIRED
- )
- code, redirect_uri = await auth_db.authorization_flow_insert_id_token(
- uuid, id_token, MAX_VALIDITY
- )
+ code, redirect_uri = await auth_db.authorization_flow_insert_id_token(uuid, id_token, EXPIRED)
+ code, redirect_uri = await auth_db.authorization_flow_insert_id_token(uuid, id_token, MAX_VALIDITY)
assert redirect_uri == "redirect_uri"
# Cannot add a id_token a second time
async with auth_db as auth_db:
with pytest.raises(AuthorizationError):
- await auth_db.authorization_flow_insert_id_token(
- uuid, id_token, MAX_VALIDITY
- )
+ await auth_db.authorization_flow_insert_id_token(uuid, id_token, MAX_VALIDITY)
async with auth_db as auth_db:
with pytest.raises(NoResultFound):
@@ -52,9 +46,7 @@ async def test_insert_id_token(auth_db: AuthDB):
# Cannot add a id_token after finishing the flow
async with auth_db as auth_db:
with pytest.raises(AuthorizationError):
- await auth_db.authorization_flow_insert_id_token(
- uuid, id_token, MAX_VALIDITY
- )
+ await auth_db.authorization_flow_insert_id_token(uuid, id_token, MAX_VALIDITY)
# We shouldn't be able to retrieve it twice
async with auth_db as auth_db:
diff --git a/tests/db/auth/test_device_flow.py b/tests/db/auth/test_device_flow.py
index 05d703e2..3b81c165 100644
--- a/tests/db/auth/test_device_flow.py
+++ b/tests/db/auth/test_device_flow.py
@@ -25,9 +25,7 @@ async def test_device_user_code_collision(auth_db: AuthDB, monkeypatch):
# First insert should work
async with auth_db as auth_db:
- code, device = await auth_db.insert_device_flow(
- "client_id", "scope", "audience"
- )
+ code, device = await auth_db.insert_device_flow("client_id", "scope", "audience")
assert code == "A" * USER_CODE_LENGTH
assert device
@@ -38,9 +36,7 @@ async def test_device_user_code_collision(auth_db: AuthDB, monkeypatch):
monkeypatch.setattr(secrets, "choice", lambda _: "B")
async with auth_db as auth_db:
- code, device = await auth_db.insert_device_flow(
- "client_id", "scope", "audience"
- )
+ code, device = await auth_db.insert_device_flow("client_id", "scope", "audience")
assert code == "B" * USER_CODE_LENGTH
assert device
@@ -56,12 +52,8 @@ async def test_device_flow_lookup(auth_db: AuthDB, monkeypatch):
# First insert
async with auth_db as auth_db:
- user_code1, device_code1 = await auth_db.insert_device_flow(
- "client_id1", "scope1", "audience1"
- )
- user_code2, device_code2 = await auth_db.insert_device_flow(
- "client_id2", "scope2", "audience2"
- )
+ user_code1, device_code1 = await auth_db.insert_device_flow("client_id1", "scope1", "audience1")
+ user_code2, device_code2 = await auth_db.insert_device_flow("client_id2", "scope2", "audience2")
assert user_code1 != user_code2
@@ -83,19 +75,13 @@ async def test_device_flow_lookup(auth_db: AuthDB, monkeypatch):
async with auth_db as auth_db:
with pytest.raises(AuthorizationError):
- await auth_db.device_flow_insert_id_token(
- user_code1, {"token": "mytoken"}, EXPIRED
- )
+ await auth_db.device_flow_insert_id_token(user_code1, {"token": "mytoken"}, EXPIRED)
- await auth_db.device_flow_insert_id_token(
- user_code1, {"token": "mytoken"}, MAX_VALIDITY
- )
+ await auth_db.device_flow_insert_id_token(user_code1, {"token": "mytoken"}, MAX_VALIDITY)
# We should not be able to insert a id_token a second time
with pytest.raises(AuthorizationError):
- await auth_db.device_flow_insert_id_token(
- user_code1, {"token": "mytoken2"}, MAX_VALIDITY
- )
+ await auth_db.device_flow_insert_id_token(user_code1, {"token": "mytoken2"}, MAX_VALIDITY)
with pytest.raises(ExpiredFlowError):
await auth_db.get_device_flow(device_code1, EXPIRED)
@@ -112,17 +98,13 @@ async def test_device_flow_lookup(auth_db: AuthDB, monkeypatch):
# Re-adding a token should not work after it's been minted
async with auth_db as auth_db:
with pytest.raises(AuthorizationError):
- await auth_db.device_flow_insert_id_token(
- user_code1, {"token": "mytoken"}, MAX_VALIDITY
- )
+ await auth_db.device_flow_insert_id_token(user_code1, {"token": "mytoken"}, MAX_VALIDITY)
async def test_device_flow_insert_id_token(auth_db: AuthDB):
# First insert
async with auth_db as auth_db:
- user_code, device_code = await auth_db.insert_device_flow(
- "client_id", "scope", "audience"
- )
+ user_code, device_code = await auth_db.insert_device_flow("client_id", "scope", "audience")
# Make sure it exists, and is Pending
async with auth_db as auth_db:
diff --git a/tests/db/test_dummyDB.py b/tests/db/test_dummyDB.py
index 38c75792..ae322edc 100644
--- a/tests/db/test_dummyDB.py
+++ b/tests/db/test_dummyDB.py
@@ -34,9 +34,7 @@ async def test_insert_and_summary(dummy_db: DummyDB):
assert owner_id
# Add cars, belonging to the same guy
- result = await asyncio.gather(
- *(dummy_db.insert_car(uuid4(), f"model_{i}", owner_id) for i in range(10))
- )
+ result = await asyncio.gather(*(dummy_db.insert_car(uuid4(), f"model_{i}", owner_id) for i in range(10)))
assert result
# Check that there are now 10 cars assigned to a single driver
@@ -47,9 +45,7 @@ async def test_insert_and_summary(dummy_db: DummyDB):
# Test the selection
async with dummy_db as dummy_db:
- result = await dummy_db.summary(
- ["ownerID"], [{"parameter": "model", "operator": "eq", "value": "model_1"}]
- )
+ result = await dummy_db.summary(["ownerID"], [{"parameter": "model", "operator": "eq", "value": "model_1"}])
assert result[0]["count"] == 1
diff --git a/tests/routers/test_auth.py b/tests/routers/test_auth.py
index 8b40cbba..200405a6 100644
--- a/tests/routers/test_auth.py
+++ b/tests/routers/test_auth.py
@@ -68,11 +68,7 @@ async def fake_parse_id_token(raw_id_token: str, audience: str, *args, **kwargs)
async def test_authorization_flow(test_client, auth_httpx_mock: HTTPXMock):
code_verifier = secrets.token_hex()
- code_challenge = (
- base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
- .decode()
- .replace("=", "")
- )
+ code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).decode().replace("=", "")
r = test_client.get(
"/auth/authorize",
@@ -97,9 +93,7 @@ async def test_authorization_flow(test_client, auth_httpx_mock: HTTPXMock):
assert r.status_code == 401, r.text
# Check that an invalid state returns an error
- r = test_client.get(
- redirect_uri, params={"code": "invalid-code", "state": "invalid-state"}
- )
+ r = test_client.get(redirect_uri, params={"code": "invalid-code", "state": "invalid-state"})
assert r.status_code == 400, r.text
assert "Invalid state" in r.text
@@ -173,9 +167,7 @@ async def test_device_flow(test_client, auth_httpx_mock: HTTPXMock):
assert r.status_code == 401, r.text
# Check that an invalid state returns an error
- r = test_client.get(
- redirect_uri, params={"code": "invalid-code", "state": "invalid-state"}
- )
+ r = test_client.get(redirect_uri, params={"code": "invalid-code", "state": "invalid-state"})
assert r.status_code == 400, r.text
assert "Invalid state" in r.text
@@ -237,10 +229,7 @@ def test_parse_scopes(vos, groups, scope, expected):
"DefaultGroup": "lhcb_user",
"IdP": {"URL": "https://idp.invalid", "ClientID": "test-idp"},
"Users": {},
- "Groups": {
- group: {"Properties": ["NormalUser"], "Users": []}
- for group in groups
- },
+ "Groups": {group: {"Properties": ["NormalUser"], "Users": []} for group in groups},
}
for vo in vos
},
@@ -272,10 +261,7 @@ def test_parse_scopes_invalid(vos, groups, scope, expected_error):
"DefaultGroup": "lhcb_user",
"IdP": {"URL": "https://idp.invalid", "ClientID": "test-idp"},
"Users": {},
- "Groups": {
- group: {"Properties": ["NormalUser"], "Users": []}
- for group in groups
- },
+ "Groups": {group: {"Properties": ["NormalUser"], "Users": []} for group in groups},
}
for vo in vos
},
diff --git a/tests/routers/test_job_manager.py b/tests/routers/test_job_manager.py
index f9e9c8ed..7888b646 100644
--- a/tests/routers/test_job_manager.py
+++ b/tests/routers/test_job_manager.py
@@ -115,25 +115,17 @@ def test_insert_and_search(normal_user_client):
r = normal_user_client.post(
"/jobs/search",
- json={
- "search": [{"parameter": "Status", "operator": "eq", "value": "RECEIVED"}]
- },
+ json={"search": [{"parameter": "Status", "operator": "eq", "value": "RECEIVED"}]},
)
assert r.status_code == 200, r.json()
assert [x["JobID"] for x in r.json()] == submitted_job_ids
- r = normal_user_client.post(
- "/jobs/search", json={"parameters": ["JobID", "Status"]}
- )
+ r = normal_user_client.post("/jobs/search", json={"parameters": ["JobID", "Status"]})
assert r.status_code == 200, r.json()
- assert r.json() == [
- {"JobID": jid, "Status": "RECEIVED"} for jid in submitted_job_ids
- ]
+ assert r.json() == [{"JobID": jid, "Status": "RECEIVED"} for jid in submitted_job_ids]
# Test /jobs/summary
- r = normal_user_client.post(
- "/jobs/summary", json={"grouping": ["Status", "OwnerDN"]}
- )
+ r = normal_user_client.post("/jobs/summary", json={"grouping": ["Status", "OwnerDN"]})
assert r.status_code == 200, r.json()
assert r.json() == [{"Status": "RECEIVED", "OwnerDN": "ownerDN", "count": 1}]