diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml deleted file mode 100644 index cc1c1c8b..00000000 --- a/.gitlab-ci.yml +++ /dev/null @@ -1,45 +0,0 @@ -.defaults: - only: - - merge_requests - - master@chaen/chrissquare-hack-a-ton - -.defaults-micromamba: - extends: .defaults - image: registry.cern.ch/docker.io/mambaorg/micromamba - before_script: - - micromamba env create --file environment.yml --name test-env - - eval "$(micromamba shell hook --shell=bash)" - - micromamba activate test-env - - pip install git+https://github.com/chaen/DIRAC.git@chris-hack-a-ton - - pip install . - -pre-commit: - extends: .defaults - image: registry.cern.ch/docker.io/library/python:3.11 - variables: - PRE_COMMIT_HOME: ${CI_PROJECT_DIR}/.cache/pre-commit - cache: - paths: - - ${PRE_COMMIT_HOME} - before_script: - - pip install pre-commit - script: - - pre-commit run --all-files - -pytest: - extends: .defaults-micromamba - script: - - pytest . --cov-report=xml:coverage.xml --junitxml=report.xml - coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/' - artifacts: - when: always - reports: - junit: report.xml - coverage_report: - coverage_format: cobertura - path: coverage.xml - -mypy: - extends: .defaults-micromamba - script: - - mypy . diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b3535615..82841c6e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,6 +17,7 @@ repos: hooks: - id: ruff args: ["--fix"] + exclude: ^(src/diracx/client/) - repo: https://github.com/psf/black rev: 23.7.0 diff --git a/environment.yml b/environment.yml index bbd754d0..8a9a4ef6 100644 --- a/environment.yml +++ b/environment.yml @@ -30,7 +30,7 @@ dependencies: - httpx - isodate - mypy - - pydantic =1.10.10 + - pydantic =1.10.12 - pytest - pytest-asyncio - pytest-cov diff --git a/pyproject.toml b/pyproject.toml index aa503323..840dd447 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,15 @@ build-backend = "setuptools.build_meta" [tool.setuptools_scm] [tool.ruff] -select = ["E", "F", "B", "I", "PLE"] +select = [ + "E", # pycodestyle errrors + "F", # pyflakes + "B", # flake8-bugbear + "I", # isort + "PLE", # pylint errors + "UP", # pyUpgrade + "FLY", # flynt +] ignore = ["B905", "B008", "B006"] line-length = 120 src = ["src", "tests"] @@ -17,6 +25,12 @@ exclude = ["src/diracx/client/"] # Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`. extend-immutable-calls = ["fastapi.Depends", "fastapi.Query", "fastapi.Path", "fastapi.Body", "fastapi.Header"] +[tool.black] +line-length = 120 + +[tool.isort] +profile = "black" + [tool.mypy] plugins = ["sqlalchemy.ext.mypy.plugin", "pydantic.mypy"] exclude = ["^src/diracx/client", "^tests/", "^build/"] diff --git a/src/diracx/cli/__init__.py b/src/diracx/cli/__init__.py index e545229e..1d050fb4 100644 --- a/src/diracx/cli/__init__.py +++ b/src/diracx/cli/__init__.py @@ -1,3 +1,5 @@ +# ruff: noqa: UP007 # because of https://github.com/tiangolo/typer/issues/533 + from __future__ import annotations import asyncio @@ -23,9 +25,7 @@ async def login( vo: str, group: Optional[str] = None, - property: Optional[list[str]] = Option( - None, help="Override the default(s) with one or more properties" - ), + property: Optional[list[str]] = Option(None, help="Override the default(s) with one or more properties"), ): scopes = [f"vo:{vo}"] if group: @@ -61,9 +61,7 @@ async def login( raise RuntimeError("Device authorization flow expired") CREDENTIALS_PATH.parent.mkdir(parents=True, exist_ok=True) - expires = datetime.now(tz=timezone.utc) + timedelta( - seconds=response.expires_in - EXPIRES_GRACE_SECONDS - ) + expires = datetime.now(tz=timezone.utc) + timedelta(seconds=response.expires_in - EXPIRES_GRACE_SECONDS) credential_data = { "access_token": response.access_token, # TODO: "refresh_token": diff --git a/src/diracx/cli/internal.py b/src/diracx/cli/internal.py index 75b9ab0c..a05f65e5 100644 --- a/src/diracx/cli/internal.py +++ b/src/diracx/cli/internal.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import json from pathlib import Path @@ -47,11 +45,7 @@ def generate_cs( IdP=IdpConfig(URL=idp_url, ClientID=idp_client_id), DefaultGroup=user_group, Users={}, - Groups={ - user_group: GroupConfig( - JobShare=None, Properties=["NormalUser"], Quota=None, Users=[] - ) - }, + Groups={user_group: GroupConfig(JobShare=None, Properties=["NormalUser"], Quota=None, Users=[])}, ) config = Config( Registry={vo: registry}, @@ -105,7 +99,5 @@ def add_user( config_data = json.loads(config.json(exclude_unset=True)) yaml_path.write_text(yaml.safe_dump(config_data)) repo.index.add([yaml_path.relative_to(repo_path)]) - repo.index.commit( - f"Added user {sub} ({preferred_username}) to vo {vo} and user_group {user_group}" - ) + repo.index.commit(f"Added user {sub} ({preferred_username}) to vo {vo} and user_group {user_group}") typer.echo(f"Successfully added user to {config_repo}", err=True) diff --git a/src/diracx/cli/jobs.py b/src/diracx/cli/jobs.py index 6dcf6eb2..332bee65 100644 --- a/src/diracx/cli/jobs.py +++ b/src/diracx/cli/jobs.py @@ -104,9 +104,5 @@ def display_rich(data, unit: str) -> None: @app.async_command() async def submit(jdl: list[FileText]): async with Dirac(endpoint="http://localhost:8000") as api: - jobs = await api.jobs.submit_bulk_jobs( - [x.read() for x in jdl], headers=get_auth_headers() - ) - print( - f"Inserted {len(jobs)} jobs with ids: {','.join(map(str, (job.job_id for job in jobs)))}" - ) + jobs = await api.jobs.submit_bulk_jobs([x.read() for x in jdl], headers=get_auth_headers()) + print(f"Inserted {len(jobs)} jobs with ids: {','.join(map(str, (job.job_id for job in jobs)))}") diff --git a/src/diracx/client/_client.py b/src/diracx/client/_client.py index 186d43f0..a5f0910f 100644 --- a/src/diracx/client/_client.py +++ b/src/diracx/client/_client.py @@ -40,28 +40,18 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self, *, endpoint: str = "", **kwargs: Any ) -> None: self._config = DiracConfiguration(**kwargs) - self._client: PipelineClient = PipelineClient( - base_url=endpoint, config=self._config, **kwargs - ) + self._client: PipelineClient = PipelineClient(base_url=endpoint, config=self._config, **kwargs) - client_models = { - k: v for k, v in _models.__dict__.items() if isinstance(v, type) - } + client_models = {k: v for k, v in _models.__dict__.items() if isinstance(v, type)} self._serialize = Serializer(client_models) self._deserialize = Deserializer(client_models) self._serialize.client_side_validation = False - self.well_known = WellKnownOperations( - self._client, self._config, self._serialize, self._deserialize - ) + self.well_known = WellKnownOperations(self._client, self._config, self._serialize, self._deserialize) self.auth = AuthOperations( # pylint: disable=abstract-class-instantiated self._client, self._config, self._serialize, self._deserialize ) - self.config = ConfigOperations( - self._client, self._config, self._serialize, self._deserialize - ) - self.jobs = JobsOperations( - self._client, self._config, self._serialize, self._deserialize - ) + self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) + self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) def send_request(self, request: HttpRequest, **kwargs: Any) -> HttpResponse: """Runs the network request through the client's chained policies. diff --git a/src/diracx/client/_configuration.py b/src/diracx/client/_configuration.py index b5106fed..0519a33c 100644 --- a/src/diracx/client/_configuration.py +++ b/src/diracx/client/_configuration.py @@ -26,24 +26,12 @@ def __init__(self, **kwargs: Any) -> None: self._configure(**kwargs) def _configure(self, **kwargs: Any) -> None: - self.user_agent_policy = kwargs.get( - "user_agent_policy" - ) or policies.UserAgentPolicy(**kwargs) - self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy( - **kwargs - ) + self.user_agent_policy = kwargs.get("user_agent_policy") or policies.UserAgentPolicy(**kwargs) + self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(**kwargs) self.proxy_policy = kwargs.get("proxy_policy") or policies.ProxyPolicy(**kwargs) - self.logging_policy = kwargs.get( - "logging_policy" - ) or policies.NetworkTraceLoggingPolicy(**kwargs) - self.http_logging_policy = kwargs.get( - "http_logging_policy" - ) or policies.HttpLoggingPolicy(**kwargs) + self.logging_policy = kwargs.get("logging_policy") or policies.NetworkTraceLoggingPolicy(**kwargs) + self.http_logging_policy = kwargs.get("http_logging_policy") or policies.HttpLoggingPolicy(**kwargs) self.retry_policy = kwargs.get("retry_policy") or policies.RetryPolicy(**kwargs) - self.custom_hook_policy = kwargs.get( - "custom_hook_policy" - ) or policies.CustomHookPolicy(**kwargs) - self.redirect_policy = kwargs.get("redirect_policy") or policies.RedirectPolicy( - **kwargs - ) + self.custom_hook_policy = kwargs.get("custom_hook_policy") or policies.CustomHookPolicy(**kwargs) + self.redirect_policy = kwargs.get("redirect_policy") or policies.RedirectPolicy(**kwargs) self.authentication_policy = kwargs.get("authentication_policy") diff --git a/src/diracx/client/_patch.py b/src/diracx/client/_patch.py index d400d2d1..f7dd3251 100644 --- a/src/diracx/client/_patch.py +++ b/src/diracx/client/_patch.py @@ -8,9 +8,7 @@ """ from typing import List -__all__: List[ - str -] = [] # Add all objects you want publicly available to users at this package level +__all__: List[str] = [] # Add all objects you want publicly available to users at this package level def patch_sdk(): diff --git a/src/diracx/client/_serialization.py b/src/diracx/client/_serialization.py index 615a169a..1e7a11b1 100644 --- a/src/diracx/client/_serialization.py +++ b/src/diracx/client/_serialization.py @@ -84,9 +84,7 @@ class RawDeserializer: CONTEXT_NAME = "deserialized_data" @classmethod - def deserialize_from_text( - cls, data: Optional[Union[AnyStr, IO]], content_type: Optional[str] = None - ) -> Any: + def deserialize_from_text(cls, data: Optional[Union[AnyStr, IO]], content_type: Optional[str] = None) -> Any: """Decode data according to content-type. Accept a stream of data as well, but will be load at once in memory for now. @@ -148,14 +146,10 @@ def _json_attemp(data): # context otherwise. _LOGGER.critical("Wasn't XML not JSON, failing") raise_with_traceback(DeserializationError, "XML is invalid") - raise DeserializationError( - "Cannot deserialize content-type: {}".format(content_type) - ) + raise DeserializationError("Cannot deserialize content-type: {}".format(content_type)) @classmethod - def deserialize_from_http_generics( - cls, body_bytes: Optional[Union[AnyStr, IO]], headers: Mapping - ) -> Any: + def deserialize_from_http_generics(cls, body_bytes: Optional[Union[AnyStr, IO]], headers: Mapping) -> Any: """Deserialize from HTTP response. Use bytes and headers to NOT use any requests/aiohttp or whatever @@ -376,9 +370,7 @@ def serialize(self, keep_readonly: bool = False, **kwargs: Any) -> JSON: def as_dict( self, keep_readonly: bool = True, - key_transformer: Callable[ - [str, Dict[str, Any], Any], Any - ] = attribute_transformer, + key_transformer: Callable[[str, Dict[str, Any], Any], Any] = attribute_transformer, **kwargs: Any ) -> JSON: """Return a dict that can be serialized using json.dump. @@ -412,18 +404,14 @@ def my_key_transformer(key, attr_desc, value): :rtype: dict """ serializer = Serializer(self._infer_class_models()) - return serializer._serialize( - self, key_transformer=key_transformer, keep_readonly=keep_readonly, **kwargs - ) + return serializer._serialize(self, key_transformer=key_transformer, keep_readonly=keep_readonly, **kwargs) @classmethod def _infer_class_models(cls): try: str_models = cls.__module__.rsplit(".", 1)[0] models = sys.modules[str_models] - client_models = { - k: v for k, v in models.__dict__.items() if isinstance(v, type) - } + client_models = {k: v for k, v in models.__dict__.items() if isinstance(v, type)} if cls.__name__ not in client_models: raise ValueError("Not Autorest generated code") except Exception: @@ -432,9 +420,7 @@ def _infer_class_models(cls): return client_models @classmethod - def deserialize( - cls: Type[ModelType], data: Any, content_type: Optional[str] = None - ) -> ModelType: + def deserialize(cls: Type[ModelType], data: Any, content_type: Optional[str] = None) -> ModelType: """Parse a str using the RestAPI syntax and return a model. :param str data: A str using RestAPI structure. JSON by default. @@ -495,13 +481,9 @@ def _classify(cls, response, objects): if not isinstance(response, ET.Element): rest_api_response_key = cls._get_rest_key_parts(subtype_key)[-1] - subtype_value = response.pop( - rest_api_response_key, None - ) or response.pop(subtype_key, None) + subtype_value = response.pop(rest_api_response_key, None) or response.pop(subtype_key, None) else: - subtype_value = xml_key_extractor( - subtype_key, cls._attribute_map[subtype_key], response - ) + subtype_value = xml_key_extractor(subtype_key, cls._attribute_map[subtype_key], response) if subtype_value: # Try to match base class. Can be class name only # (bug to fix in Autorest to support x-ms-discriminator-name) @@ -629,9 +611,7 @@ def _serialize(self, target_obj, data_type=None, **kwargs): try: is_xml_model_serialization = kwargs["is_xml"] except KeyError: - is_xml_model_serialization = kwargs.setdefault( - "is_xml", target_obj.is_xml_model() - ) + is_xml_model_serialization = kwargs.setdefault("is_xml", target_obj.is_xml_model()) serialized = {} if is_xml_model_serialization: @@ -640,9 +620,7 @@ def _serialize(self, target_obj, data_type=None, **kwargs): attributes = target_obj._attribute_map for attr, attr_desc in attributes.items(): attr_name = attr - if not keep_readonly and target_obj._validation.get(attr_name, {}).get( - "readonly", False - ): + if not keep_readonly and target_obj._validation.get(attr_name, {}).get("readonly", False): continue if attr_name == "additional_properties" and attr_desc["key"] == "": @@ -654,15 +632,11 @@ def _serialize(self, target_obj, data_type=None, **kwargs): if is_xml_model_serialization: pass # Don't provide "transformer" for XML for now. Keep "orig_attr" else: # JSON - keys, orig_attr = key_transformer( - attr, attr_desc.copy(), orig_attr - ) + keys, orig_attr = key_transformer(attr, attr_desc.copy(), orig_attr) keys = keys if isinstance(keys, list) else [keys] kwargs["serialization_ctxt"] = attr_desc - new_attr = self.serialize_data( - orig_attr, attr_desc["type"], **kwargs - ) + new_attr = self.serialize_data(orig_attr, attr_desc["type"], **kwargs) if is_xml_model_serialization: xml_desc = attr_desc.get("xml", {}) @@ -709,9 +683,7 @@ def _serialize(self, target_obj, data_type=None, **kwargs): continue except (AttributeError, KeyError, TypeError) as err: - msg = "Attribute {} in object {} cannot be serialized.\n{}".format( - attr_name, class_name, str(target_obj) - ) + msg = "Attribute {} in object {} cannot be serialized.\n{}".format(attr_name, class_name, str(target_obj)) raise_with_traceback(SerializationError, msg, err) else: return serialized @@ -733,9 +705,7 @@ def body(self, data, data_type, **kwargs): is_xml_model_serialization = kwargs["is_xml"] except KeyError: if internal_data_type and issubclass(internal_data_type, Model): - is_xml_model_serialization = kwargs.setdefault( - "is_xml", internal_data_type.is_xml_model() - ) + is_xml_model_serialization = kwargs.setdefault("is_xml", internal_data_type.is_xml_model()) else: is_xml_model_serialization = False if internal_data_type and not isinstance(internal_data_type, Enum): @@ -756,9 +726,7 @@ def body(self, data, data_type, **kwargs): ] data = deserializer._deserialize(data_type, data) except DeserializationError as err: - raise_with_traceback( - SerializationError, "Unable to build a model: " + str(err), err - ) + raise_with_traceback(SerializationError, "Unable to build a model: " + str(err), err) return self._serialize(data, data_type, **kwargs) @@ -798,12 +766,7 @@ def query(self, name, data, data_type, **kwargs): # Treat the list aside, since we don't want to encode the div separator if data_type.startswith("["): internal_data_type = data_type[1:-1] - data = [ - self.serialize_data(d, internal_data_type, **kwargs) - if d is not None - else "" - for d in data - ] + data = [self.serialize_data(d, internal_data_type, **kwargs) if d is not None else "" for d in data] if not kwargs.get("skip_quote", False): data = [quote(str(d), safe="") for d in data] return str(self.serialize_iter(data, internal_data_type, **kwargs)) @@ -975,9 +938,7 @@ def serialize_iter(self, data, iter_type, div=None, **kwargs): is_wrapped = xml_desc.get("wrapped", False) node_name = xml_desc.get("itemsName", xml_name) if is_wrapped: - final_result = _create_xml_node( - xml_name, xml_desc.get("prefix", None), xml_desc.get("ns", None) - ) + final_result = _create_xml_node(xml_name, xml_desc.get("prefix", None), xml_desc.get("ns", None)) else: final_result = [] # All list elements to "local_node" @@ -1009,9 +970,7 @@ def serialize_dict(self, attr, dict_type, **kwargs): serialized = {} for key, value in attr.items(): try: - serialized[self.serialize_unicode(key)] = self.serialize_data( - value, dict_type, **kwargs - ) + serialized[self.serialize_unicode(key)] = self.serialize_data(value, dict_type, **kwargs) except ValueError: serialized[self.serialize_unicode(key)] = None @@ -1020,9 +979,7 @@ def serialize_dict(self, attr, dict_type, **kwargs): xml_desc = serialization_ctxt["xml"] xml_name = xml_desc["name"] - final_result = _create_xml_node( - xml_name, xml_desc.get("prefix", None), xml_desc.get("ns", None) - ) + final_result = _create_xml_node(xml_name, xml_desc.get("prefix", None), xml_desc.get("ns", None)) for key, value in serialized.items(): ET.SubElement(final_result, key).text = value return final_result @@ -1068,9 +1025,7 @@ def serialize_object(self, attr, **kwargs): serialized = {} for key, value in attr.items(): try: - serialized[self.serialize_unicode(key)] = self.serialize_object( - value, **kwargs - ) + serialized[self.serialize_unicode(key)] = self.serialize_object(value, **kwargs) except ValueError: serialized[self.serialize_unicode(key)] = None return serialized @@ -1287,9 +1242,7 @@ def rest_key_case_insensitive_extractor(attr, attr_desc, data): key = _decode_attribute_map_key(dict_keys[0]) break working_key = _decode_attribute_map_key(dict_keys[0]) - working_data = attribute_key_case_insensitive_extractor( - working_key, None, working_data - ) + working_data = attribute_key_case_insensitive_extractor(working_key, None, working_data) if working_data is None: # If at any point while following flatten JSON path see None, it means # that all properties under are None as well @@ -1382,10 +1335,7 @@ def xml_key_extractor(attr, attr_desc, data): # - Wrapped node # - Internal type is an enum (considered basic types) # - Internal type has no XML/Name node - if is_wrapped or ( - internal_type - and (issubclass(internal_type, Enum) or "name" not in internal_type_xml_map) - ): + if is_wrapped or (internal_type and (issubclass(internal_type, Enum) or "name" not in internal_type_xml_map)): children = data.findall(xml_name) # If internal type has a local name and it's not a list, I use that name elif not is_iter_type and internal_type and "name" in internal_type_xml_map: @@ -1393,9 +1343,7 @@ def xml_key_extractor(attr, attr_desc, data): children = data.findall(xml_name) # That's an array else: - if ( - internal_type - ): # Complex type, ignore itemsName and use the complex type name + if internal_type: # Complex type, ignore itemsName and use the complex type name items_name = _extract_name_from_internal_type(internal_type) else: items_name = xml_desc.get("itemsName", xml_name) @@ -1424,9 +1372,7 @@ def xml_key_extractor(attr, attr_desc, data): # Here it's not a itertype, we should have found one element only or empty if len(children) > 1: - raise DeserializationError( - "Find several XML '{}' where it was not expected".format(xml_name) - ) + raise DeserializationError("Find several XML '{}' where it was not expected".format(xml_name)) return children[0] @@ -1439,9 +1385,7 @@ class Deserializer(object): basic_types = {str: "str", int: "int", bool: "bool", float: "float"} - valid_date = re.compile( - r"\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}" r"\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?" - ) + valid_date = re.compile(r"\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}" r"\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?") def __init__(self, classes: Optional[Mapping[str, Type[ModelType]]] = None): self.deserialize_type = { @@ -1497,11 +1441,7 @@ def _deserialize(self, target_obj, data): """ # This is already a model, go recursive just in case if hasattr(data, "_attribute_map"): - constants = [ - name - for name, config in getattr(data, "_validation", {}).items() - if config.get("constant") - ] + constants = [name for name, config in getattr(data, "_validation", {}).items() if config.get("constant")] try: for attr, mapconfig in data._attribute_map.items(): if attr in constants: @@ -1511,9 +1451,7 @@ def _deserialize(self, target_obj, data): continue local_type = mapconfig["type"] internal_data_type = local_type.strip("[]{}") - if internal_data_type not in self.dependencies or isinstance( - internal_data_type, Enum - ): + if internal_data_type not in self.dependencies or isinstance(internal_data_type, Enum): continue setattr(data, attr, self._deserialize(local_type, value)) return data @@ -1567,10 +1505,7 @@ def _deserialize(self, target_obj, data): def _build_additional_properties(self, attribute_map, data): if not self.additional_properties_detection: return None - if ( - "additional_properties" in attribute_map - and attribute_map.get("additional_properties", {}).get("key") != "" - ): + if "additional_properties" in attribute_map and attribute_map.get("additional_properties", {}).get("key") != "": # Check empty string. If it's not empty, someone has a real "additionalProperties" return None if isinstance(data, ET.Element): @@ -1650,21 +1585,15 @@ def _unpack_content(raw_data, content_type=None): if context: if RawDeserializer.CONTEXT_NAME in context: return context[RawDeserializer.CONTEXT_NAME] - raise ValueError( - "This pipeline didn't have the RawDeserializer policy; can't deserialize" - ) + raise ValueError("This pipeline didn't have the RawDeserializer policy; can't deserialize") # Assume this is enough to recognize universal_http.ClientResponse without importing it if hasattr(raw_data, "body"): - return RawDeserializer.deserialize_from_http_generics( - raw_data.text(), raw_data.headers - ) + return RawDeserializer.deserialize_from_http_generics(raw_data.text(), raw_data.headers) # Assume this enough to recognize requests.Response without importing it. if hasattr(raw_data, "_content_consumed"): - return RawDeserializer.deserialize_from_http_generics( - raw_data.text, raw_data.headers - ) + return RawDeserializer.deserialize_from_http_generics(raw_data.text, raw_data.headers) if isinstance(raw_data, (basestring, bytes)) or hasattr(raw_data, "read"): return RawDeserializer.deserialize_from_text(raw_data, content_type) # type: ignore @@ -1679,17 +1608,9 @@ def _instantiate_model(self, response, attrs, additional_properties=None): if callable(response): subtype = getattr(response, "_subtype_map", {}) try: - readonly = [ - k for k, v in response._validation.items() if v.get("readonly") - ] - const = [ - k for k, v in response._validation.items() if v.get("constant") - ] - kwargs = { - k: v - for k, v in attrs.items() - if k not in subtype and k not in readonly + const - } + readonly = [k for k, v in response._validation.items() if v.get("readonly")] + const = [k for k, v in response._validation.items() if v.get("constant")] + kwargs = {k: v for k, v in attrs.items() if k not in subtype and k not in readonly + const} response_obj = response(**kwargs) for attr in readonly: setattr(response_obj, attr, attrs.get(attr)) @@ -1726,17 +1647,11 @@ def deserialize_data(self, data, data_type): if data_type in self.basic_types.values(): return self.deserialize_basic(data, data_type) if data_type in self.deserialize_type: - if isinstance( - data, self.deserialize_expected_types.get(data_type, tuple()) - ): + if isinstance(data, self.deserialize_expected_types.get(data_type, tuple())): return data is_a_text_parsing_type = lambda x: x not in ["object", "[]", r"{}"] - if ( - isinstance(data, ET.Element) - and is_a_text_parsing_type(data_type) - and not data.text - ): + if isinstance(data, ET.Element) and is_a_text_parsing_type(data_type) and not data.text: return None data_val = self.deserialize_type[data_type](data) return data_val @@ -1767,16 +1682,10 @@ def deserialize_iter(self, attr, iter_type): """ if attr is None: return None - if isinstance( - attr, ET.Element - ): # If I receive an element here, get the children + if isinstance(attr, ET.Element): # If I receive an element here, get the children attr = list(attr) if not isinstance(attr, (list, set)): - raise DeserializationError( - "Cannot deserialize as [{}] an object of type {}".format( - iter_type, type(attr) - ) - ) + raise DeserializationError("Cannot deserialize as [{}] an object of type {}".format(iter_type, type(attr))) return [self.deserialize_data(a, iter_type) for a in attr] def deserialize_dict(self, attr, dict_type): @@ -1788,9 +1697,7 @@ def deserialize_dict(self, attr, dict_type): :rtype: dict """ if isinstance(attr, list): - return { - x["key"]: self.deserialize_data(x["value"], dict_type) for x in attr - } + return {x["key"]: self.deserialize_data(x["value"], dict_type) for x in attr} if isinstance(attr, ET.Element): # Transform value into {"Key": "value"} @@ -2022,9 +1929,7 @@ def deserialize_date(attr): if isinstance(attr, ET.Element): attr = attr.text if re.search(r"[^\W\d_]", attr, re.I + re.U): # type: ignore - raise DeserializationError( - "Date must have only digits and -. Received: %s" % attr - ) + raise DeserializationError("Date must have only digits and -. Received: %s" % attr) # This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception. return isodate.parse_date(attr, defaultmonth=None, defaultday=None) @@ -2039,9 +1944,7 @@ def deserialize_time(attr): if isinstance(attr, ET.Element): attr = attr.text if re.search(r"[^\W\d_]", attr, re.I + re.U): # type: ignore - raise DeserializationError( - "Date must have only digits and -. Received: %s" % attr - ) + raise DeserializationError("Date must have only digits and -. Received: %s" % attr) return isodate.parse_time(attr) @staticmethod @@ -2057,10 +1960,7 @@ def deserialize_rfc(attr): try: parsed_date = email.utils.parsedate_tz(attr) # type: ignore date_obj = datetime.datetime( - *parsed_date[:6], - tzinfo=_FixedOffset( - datetime.timedelta(minutes=(parsed_date[9] or 0) / 60) - ) + *parsed_date[:6], tzinfo=_FixedOffset(datetime.timedelta(minutes=(parsed_date[9] or 0) / 60)) ) if not date_obj.tzinfo: date_obj = date_obj.astimezone(tz=TZ_UTC) diff --git a/src/diracx/client/_vendor.py b/src/diracx/client/_vendor.py index c5b16f59..7b62dcec 100644 --- a/src/diracx/client/_vendor.py +++ b/src/diracx/client/_vendor.py @@ -14,16 +14,12 @@ def _format_url_section(template, **kwargs): except KeyError as key: # Need the cast, as for some reasons "split" is typed as list[str | Any] formatted_components = cast(List[str], template.split("/")) - components = [ - c for c in formatted_components if "{}".format(key.args[0]) not in c - ] + components = [c for c in formatted_components if "{}".format(key.args[0]) not in c] template = "/".join(components) def raise_if_not_implemented(cls, abstract_methods): - not_implemented = [ - f for f in abstract_methods if not callable(getattr(cls, f, None)) - ] + not_implemented = [f for f in abstract_methods if not callable(getattr(cls, f, None))] if not_implemented: raise NotImplementedError( "The following methods on operation group '{}' are not implemented: '{}'." diff --git a/src/diracx/client/aio/_client.py b/src/diracx/client/aio/_client.py index 32f0e07d..684725e4 100644 --- a/src/diracx/client/aio/_client.py +++ b/src/diracx/client/aio/_client.py @@ -40,32 +40,20 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self, *, endpoint: str = "", **kwargs: Any ) -> None: self._config = DiracConfiguration(**kwargs) - self._client: AsyncPipelineClient = AsyncPipelineClient( - base_url=endpoint, config=self._config, **kwargs - ) + self._client: AsyncPipelineClient = AsyncPipelineClient(base_url=endpoint, config=self._config, **kwargs) - client_models = { - k: v for k, v in _models.__dict__.items() if isinstance(v, type) - } + client_models = {k: v for k, v in _models.__dict__.items() if isinstance(v, type)} self._serialize = Serializer(client_models) self._deserialize = Deserializer(client_models) self._serialize.client_side_validation = False - self.well_known = WellKnownOperations( - self._client, self._config, self._serialize, self._deserialize - ) + self.well_known = WellKnownOperations(self._client, self._config, self._serialize, self._deserialize) self.auth = AuthOperations( # pylint: disable=abstract-class-instantiated self._client, self._config, self._serialize, self._deserialize ) - self.config = ConfigOperations( - self._client, self._config, self._serialize, self._deserialize - ) - self.jobs = JobsOperations( - self._client, self._config, self._serialize, self._deserialize - ) + self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) + self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) - def send_request( - self, request: HttpRequest, **kwargs: Any - ) -> Awaitable[AsyncHttpResponse]: + def send_request(self, request: HttpRequest, **kwargs: Any) -> Awaitable[AsyncHttpResponse]: """Runs the network request through the client's chained policies. >>> from azure.core.rest import HttpRequest diff --git a/src/diracx/client/aio/_configuration.py b/src/diracx/client/aio/_configuration.py index 10d4b962..07f6f94a 100644 --- a/src/diracx/client/aio/_configuration.py +++ b/src/diracx/client/aio/_configuration.py @@ -26,26 +26,12 @@ def __init__(self, **kwargs: Any) -> None: self._configure(**kwargs) def _configure(self, **kwargs: Any) -> None: - self.user_agent_policy = kwargs.get( - "user_agent_policy" - ) or policies.UserAgentPolicy(**kwargs) - self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy( - **kwargs - ) + self.user_agent_policy = kwargs.get("user_agent_policy") or policies.UserAgentPolicy(**kwargs) + self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(**kwargs) self.proxy_policy = kwargs.get("proxy_policy") or policies.ProxyPolicy(**kwargs) - self.logging_policy = kwargs.get( - "logging_policy" - ) or policies.NetworkTraceLoggingPolicy(**kwargs) - self.http_logging_policy = kwargs.get( - "http_logging_policy" - ) or policies.HttpLoggingPolicy(**kwargs) - self.retry_policy = kwargs.get("retry_policy") or policies.AsyncRetryPolicy( - **kwargs - ) - self.custom_hook_policy = kwargs.get( - "custom_hook_policy" - ) or policies.CustomHookPolicy(**kwargs) - self.redirect_policy = kwargs.get( - "redirect_policy" - ) or policies.AsyncRedirectPolicy(**kwargs) + self.logging_policy = kwargs.get("logging_policy") or policies.NetworkTraceLoggingPolicy(**kwargs) + self.http_logging_policy = kwargs.get("http_logging_policy") or policies.HttpLoggingPolicy(**kwargs) + self.retry_policy = kwargs.get("retry_policy") or policies.AsyncRetryPolicy(**kwargs) + self.custom_hook_policy = kwargs.get("custom_hook_policy") or policies.CustomHookPolicy(**kwargs) + self.redirect_policy = kwargs.get("redirect_policy") or policies.AsyncRedirectPolicy(**kwargs) self.authentication_policy = kwargs.get("authentication_policy") diff --git a/src/diracx/client/aio/_patch.py b/src/diracx/client/aio/_patch.py index d400d2d1..f7dd3251 100644 --- a/src/diracx/client/aio/_patch.py +++ b/src/diracx/client/aio/_patch.py @@ -8,9 +8,7 @@ """ from typing import List -__all__: List[ - str -] = [] # Add all objects you want publicly available to users at this package level +__all__: List[str] = [] # Add all objects you want publicly available to users at this package level def patch_sdk(): diff --git a/src/diracx/client/aio/_vendor.py b/src/diracx/client/aio/_vendor.py index 9889eb3a..3ada94d2 100644 --- a/src/diracx/client/aio/_vendor.py +++ b/src/diracx/client/aio/_vendor.py @@ -5,9 +5,7 @@ def raise_if_not_implemented(cls, abstract_methods): - not_implemented = [ - f for f in abstract_methods if not callable(getattr(cls, f, None)) - ] + not_implemented = [f for f in abstract_methods if not callable(getattr(cls, f, None))] if not_implemented: raise NotImplementedError( "The following methods on operation group '{}' are not implemented: '{}'." diff --git a/src/diracx/client/aio/operations/_operations.py b/src/diracx/client/aio/operations/_operations.py index a481b679..06da8589 100644 --- a/src/diracx/client/aio/operations/_operations.py +++ b/src/diracx/client/aio/operations/_operations.py @@ -52,9 +52,7 @@ else: from typing import MutableMapping # type: ignore # pylint: disable=ungrouped-imports T = TypeVar("T") -ClsType = Optional[ - Callable[[PipelineResponse[HttpRequest, AsyncHttpResponse], T, Dict[str, Any]], Any] -] +ClsType = Optional[Callable[[PipelineResponse[HttpRequest, AsyncHttpResponse], T, Dict[str, Any]], Any]] JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object @@ -75,9 +73,7 @@ def __init__(self, *args, **kwargs) -> None: self._client = input_args.pop(0) if input_args else kwargs.pop("client") self._config = input_args.pop(0) if input_args else kwargs.pop("config") self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer") - self._deserialize = ( - input_args.pop(0) if input_args else kwargs.pop("deserializer") - ) + self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer") @distributed_trace_async async def openid_configuration(self, **kwargs: Any) -> Any: @@ -109,18 +105,14 @@ async def openid_configuration(self, **kwargs: Any) -> Any: request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -148,9 +140,7 @@ def __init__(self, *args, **kwargs) -> None: self._client = input_args.pop(0) if input_args else kwargs.pop("client") self._config = input_args.pop(0) if input_args else kwargs.pop("config") self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer") - self._deserialize = ( - input_args.pop(0) if input_args else kwargs.pop("deserializer") - ) + self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer") raise_if_not_implemented( self.__class__, [ @@ -198,18 +188,14 @@ async def do_device_flow(self, *, user_code: str, **kwargs: Any) -> Any: request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -266,23 +252,17 @@ async def initiate_device_flow( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) - deserialized = self._deserialize( - "InitiateDeviceFlowResponse", pipeline_response - ) + deserialized = self._deserialize("InitiateDeviceFlowResponse", pipeline_response) if cls: return cls(pipeline_response, deserialized, {}) @@ -329,18 +309,14 @@ async def finish_device_flow(self, *, code: str, state: str, **kwargs: Any) -> A request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -380,18 +356,14 @@ async def finished(self, **kwargs: Any) -> Any: request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -463,18 +435,14 @@ async def authorization_flow( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -485,9 +453,7 @@ async def authorization_flow( return deserialized @distributed_trace_async - async def authorization_flow_complete( - self, *, code: str, state: str, **kwargs: Any - ) -> Any: + async def authorization_flow_complete(self, *, code: str, state: str, **kwargs: Any) -> Any: """Authorization Flow Complete. Authorization Flow Complete. @@ -522,18 +488,14 @@ async def authorization_flow_complete( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -561,18 +523,11 @@ def __init__(self, *args, **kwargs) -> None: self._client = input_args.pop(0) if input_args else kwargs.pop("client") self._config = input_args.pop(0) if input_args else kwargs.pop("config") self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer") - self._deserialize = ( - input_args.pop(0) if input_args else kwargs.pop("deserializer") - ) + self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer") @distributed_trace_async async def serve_config( - self, - vo: str, - *, - if_none_match: Optional[str] = None, - if_modified_since: Optional[str] = None, - **kwargs: Any + self, vo: str, *, if_none_match: Optional[str] = None, if_modified_since: Optional[str] = None, **kwargs: Any ) -> Any: """Serve Config. @@ -617,18 +572,14 @@ async def serve_config( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -656,9 +607,7 @@ def __init__(self, *args, **kwargs) -> None: self._client = input_args.pop(0) if input_args else kwargs.pop("client") self._config = input_args.pop(0) if input_args else kwargs.pop("config") self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer") - self._deserialize = ( - input_args.pop(0) if input_args else kwargs.pop("deserializer") - ) + self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer") @overload async def submit_bulk_jobs( @@ -697,9 +646,7 @@ async def submit_bulk_jobs( """ @distributed_trace_async - async def submit_bulk_jobs( - self, body: Union[List[str], IO], **kwargs: Any - ) -> List[_models.InsertedJob]: + async def submit_bulk_jobs(self, body: Union[List[str], IO], **kwargs: Any) -> List[_models.InsertedJob]: """Submit Bulk Jobs. Submit Bulk Jobs. @@ -724,9 +671,7 @@ async def submit_bulk_jobs( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = kwargs.pop("params", {}) or {} - content_type: Optional[str] = kwargs.pop( - "content_type", _headers.pop("Content-Type", None) - ) + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) cls: ClsType[List[_models.InsertedJob]] = kwargs.pop("cls", None) content_type = content_type or "application/json" @@ -747,18 +692,14 @@ async def submit_bulk_jobs( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("[InsertedJob]", pipeline_response) @@ -801,18 +742,14 @@ async def delete_bulk_jobs(self, *, job_ids: List[int], **kwargs: Any) -> Any: request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -855,18 +792,14 @@ async def get_single_job(self, job_id: int, **kwargs: Any) -> Any: request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -909,18 +842,14 @@ async def delete_single_job(self, job_id: int, **kwargs: Any) -> Any: request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -963,18 +892,14 @@ async def kill_single_job(self, job_id: int, **kwargs: Any) -> Any: request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -985,9 +910,7 @@ async def kill_single_job(self, job_id: int, **kwargs: Any) -> Any: return deserialized @distributed_trace_async - async def get_single_job_status( - self, job_id: int, **kwargs: Any - ) -> Union[str, _models.JobStatus]: + async def get_single_job_status(self, job_id: int, **kwargs: Any) -> Union[str, _models.JobStatus]: """Get Single Job Status. Get Single Job Status. @@ -1019,18 +942,14 @@ async def get_single_job_status( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("str", pipeline_response) @@ -1041,9 +960,7 @@ async def get_single_job_status( return deserialized @distributed_trace_async - async def set_single_job_status( - self, job_id: int, *, status: Union[str, _models.JobStatus], **kwargs: Any - ) -> Any: + async def set_single_job_status(self, job_id: int, *, status: Union[str, _models.JobStatus], **kwargs: Any) -> Any: """Set Single Job Status. Set Single Job Status. @@ -1079,18 +996,14 @@ async def set_single_job_status( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -1133,18 +1046,14 @@ async def kill_bulk_jobs(self, *, job_ids: List[int], **kwargs: Any) -> Any: request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -1187,18 +1096,14 @@ async def get_bulk_job_status(self, *, job_ids: List[int], **kwargs: Any) -> Any request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -1210,11 +1115,7 @@ async def get_bulk_job_status(self, *, job_ids: List[int], **kwargs: Any) -> Any @overload async def set_status_bulk( - self, - body: List[_models.JobStatusUpdate], - *, - content_type: str = "application/json", - **kwargs: Any + self, body: List[_models.JobStatusUpdate], *, content_type: str = "application/json", **kwargs: Any ) -> List[_models.JobStatusReturn]: """Set Status Bulk. @@ -1276,9 +1177,7 @@ async def set_status_bulk( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = kwargs.pop("params", {}) or {} - content_type: Optional[str] = kwargs.pop( - "content_type", _headers.pop("Content-Type", None) - ) + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) cls: ClsType[List[_models.JobStatusReturn]] = kwargs.pop("cls", None) content_type = content_type or "application/json" @@ -1299,18 +1198,14 @@ async def set_status_bulk( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("[JobStatusReturn]", pipeline_response) @@ -1419,9 +1314,7 @@ async def search( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = kwargs.pop("params", {}) or {} - content_type: Optional[str] = kwargs.pop( - "content_type", _headers.pop("Content-Type", None) - ) + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) cls: ClsType[List[JSON]] = kwargs.pop("cls", None) content_type = content_type or "application/json" @@ -1447,18 +1340,14 @@ async def search( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("[object]", pipeline_response) @@ -1470,11 +1359,7 @@ async def search( @overload async def summary( - self, - body: _models.JobSummaryParams, - *, - content_type: str = "application/json", - **kwargs: Any + self, body: _models.JobSummaryParams, *, content_type: str = "application/json", **kwargs: Any ) -> Any: """Summary. @@ -1491,9 +1376,7 @@ async def summary( """ @overload - async def summary( - self, body: IO, *, content_type: str = "application/json", **kwargs: Any - ) -> Any: + async def summary(self, body: IO, *, content_type: str = "application/json", **kwargs: Any) -> Any: """Summary. Show information suitable for plotting. @@ -1509,9 +1392,7 @@ async def summary( """ @distributed_trace_async - async def summary( - self, body: Union[_models.JobSummaryParams, IO], **kwargs: Any - ) -> Any: + async def summary(self, body: Union[_models.JobSummaryParams, IO], **kwargs: Any) -> Any: """Summary. Show information suitable for plotting. @@ -1536,9 +1417,7 @@ async def summary( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = kwargs.pop("params", {}) or {} - content_type: Optional[str] = kwargs.pop( - "content_type", _headers.pop("Content-Type", None) - ) + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) cls: ClsType[Any] = kwargs.pop("cls", None) content_type = content_type or "application/json" @@ -1559,18 +1438,14 @@ async def summary( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) diff --git a/src/diracx/client/aio/operations/_patch.py b/src/diracx/client/aio/operations/_patch.py index 0fd71c5d..23c31718 100644 --- a/src/diracx/client/aio/operations/_patch.py +++ b/src/diracx/client/aio/operations/_patch.py @@ -69,10 +69,8 @@ async def token( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - await self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response diff --git a/src/diracx/client/models/_enums.py b/src/diracx/client/models/_enums.py index e782e7e6..825e0464 100644 --- a/src/diracx/client/models/_enums.py +++ b/src/diracx/client/models/_enums.py @@ -17,9 +17,7 @@ class Enum0(str, Enum, metaclass=CaseInsensitiveEnumMeta): class Enum1(str, Enum, metaclass=CaseInsensitiveEnumMeta): """Enum1.""" - URN_IETF_PARAMS_OAUTH_GRANT_TYPE_DEVICE_CODE = ( - "urn:ietf:params:oauth:grant-type:device_code" - ) + URN_IETF_PARAMS_OAUTH_GRANT_TYPE_DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code" class Enum2(str, Enum, metaclass=CaseInsensitiveEnumMeta): diff --git a/src/diracx/client/models/_models.py b/src/diracx/client/models/_models.py index 7a776f88..0a7c9c81 100644 --- a/src/diracx/client/models/_models.py +++ b/src/diracx/client/models/_models.py @@ -105,9 +105,7 @@ class HTTPValidationError(_serialization.Model): "detail": {"key": "detail", "type": "[ValidationError]"}, } - def __init__( - self, *, detail: Optional[List["_models.ValidationError"]] = None, **kwargs: Any - ) -> None: + def __init__(self, *, detail: Optional[List["_models.ValidationError"]] = None, **kwargs: Any) -> None: """ :keyword detail: Detail. :paramtype detail: list[~client.models.ValidationError] @@ -212,13 +210,7 @@ class InsertedJob(_serialization.Model): } def __init__( - self, - *, - job_id: int, - status: str, - minor_status: str, - time_stamp: datetime.datetime, - **kwargs: Any + self, *, job_id: int, status: str, minor_status: str, time_stamp: datetime.datetime, **kwargs: Any ) -> None: """ :keyword job_id: Jobid. Required. @@ -308,9 +300,7 @@ class JobStatusReturn(_serialization.Model): "status": {"key": "status", "type": "str"}, } - def __init__( - self, *, job_id: int, status: Union[str, "_models.JobStatus"], **kwargs: Any - ) -> None: + def __init__(self, *, job_id: int, status: Union[str, "_models.JobStatus"], **kwargs: Any) -> None: """ :keyword job_id: Job Id. Required. :paramtype job_id: int @@ -345,9 +335,7 @@ class JobStatusUpdate(_serialization.Model): "status": {"key": "status", "type": "str"}, } - def __init__( - self, *, job_id: int, status: Union[str, "_models.JobStatus"], **kwargs: Any - ) -> None: + def __init__(self, *, job_id: int, status: Union[str, "_models.JobStatus"], **kwargs: Any) -> None: """ :keyword job_id: Job Id. Required. :paramtype job_id: int @@ -381,11 +369,7 @@ class JobSummaryParams(_serialization.Model): } def __init__( - self, - *, - grouping: List[str], - search: List["_models.JobSummaryParamsSearchItem"] = [], - **kwargs: Any + self, *, grouping: List[str], search: List["_models.JobSummaryParamsSearchItem"] = [], **kwargs: Any ) -> None: """ :keyword grouping: Grouping. Required. @@ -435,12 +419,7 @@ class ScalarSearchSpec(_serialization.Model): } def __init__( - self, - *, - parameter: str, - operator: Union[str, "_models.ScalarSearchOperator"], - value: str, - **kwargs: Any + self, *, parameter: str, operator: Union[str, "_models.ScalarSearchOperator"], value: str, **kwargs: Any ) -> None: """ :keyword parameter: Parameter. Required. @@ -478,9 +457,7 @@ class SortSpec(_serialization.Model): "direction": {"key": "direction", "type": "SortSpecDirection"}, } - def __init__( - self, *, parameter: str, direction: "_models.SortSpecDirection", **kwargs: Any - ) -> None: + def __init__(self, *, parameter: str, direction: "_models.SortSpecDirection", **kwargs: Any) -> None: """ :keyword parameter: Parameter. Required. :paramtype parameter: str @@ -527,9 +504,7 @@ class TokenResponse(_serialization.Model): "state": {"key": "state", "type": "str"}, } - def __init__( - self, *, access_token: str, expires_in: int, state: str, **kwargs: Any - ) -> None: + def __init__(self, *, access_token: str, expires_in: int, state: str, **kwargs: Any) -> None: """ :keyword access_token: Access Token. Required. :paramtype access_token: str @@ -569,14 +544,7 @@ class ValidationError(_serialization.Model): "type": {"key": "type", "type": "str"}, } - def __init__( - self, - *, - loc: List["_models.ValidationErrorLocItem"], - msg: str, - type: str, - **kwargs: Any - ) -> None: + def __init__(self, *, loc: List["_models.ValidationErrorLocItem"], msg: str, type: str, **kwargs: Any) -> None: """ :keyword loc: Location. Required. :paramtype loc: list[~client.models.ValidationErrorLocItem] @@ -627,12 +595,7 @@ class VectorSearchSpec(_serialization.Model): } def __init__( - self, - *, - parameter: str, - operator: Union[str, "_models.VectorSearchOperator"], - values: List[str], - **kwargs: Any + self, *, parameter: str, operator: Union[str, "_models.VectorSearchOperator"], values: List[str], **kwargs: Any ) -> None: """ :keyword parameter: Parameter. Required. diff --git a/src/diracx/client/operations/_operations.py b/src/diracx/client/operations/_operations.py index 68033e18..5dba2bbf 100644 --- a/src/diracx/client/operations/_operations.py +++ b/src/diracx/client/operations/_operations.py @@ -31,9 +31,7 @@ else: from typing import MutableMapping # type: ignore # pylint: disable=ungrouped-imports T = TypeVar("T") -ClsType = Optional[ - Callable[[PipelineResponse[HttpRequest, HttpResponse], T, Dict[str, Any]], Any] -] +ClsType = Optional[Callable[[PipelineResponse[HttpRequest, HttpResponse], T, Dict[str, Any]], Any]] JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object _SERIALIZER = Serializer() @@ -71,14 +69,10 @@ def build_auth_do_device_flow_request(*, user_code: str, **kwargs: Any) -> HttpR # Construct headers _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") - return HttpRequest( - method="GET", url=_url, params=_params, headers=_headers, **kwargs - ) + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) -def build_auth_initiate_device_flow_request( - *, client_id: str, scope: str, audience: str, **kwargs: Any -) -> HttpRequest: +def build_auth_initiate_device_flow_request(*, client_id: str, scope: str, audience: str, **kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) @@ -95,14 +89,10 @@ def build_auth_initiate_device_flow_request( # Construct headers _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") - return HttpRequest( - method="POST", url=_url, params=_params, headers=_headers, **kwargs - ) + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) -def build_auth_finish_device_flow_request( - *, code: str, state: str, **kwargs: Any -) -> HttpRequest: +def build_auth_finish_device_flow_request(*, code: str, state: str, **kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) @@ -118,9 +108,7 @@ def build_auth_finish_device_flow_request( # Construct headers _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") - return HttpRequest( - method="GET", url=_url, params=_params, headers=_headers, **kwargs - ) + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) def build_auth_finished_request(**kwargs: Any) -> HttpRequest: @@ -158,12 +146,8 @@ def build_auth_authorization_flow_request( # Construct parameters _params["response_type"] = _SERIALIZER.query("response_type", response_type, "str") - _params["code_challenge"] = _SERIALIZER.query( - "code_challenge", code_challenge, "str" - ) - _params["code_challenge_method"] = _SERIALIZER.query( - "code_challenge_method", code_challenge_method, "str" - ) + _params["code_challenge"] = _SERIALIZER.query("code_challenge", code_challenge, "str") + _params["code_challenge_method"] = _SERIALIZER.query("code_challenge_method", code_challenge_method, "str") _params["client_id"] = _SERIALIZER.query("client_id", client_id, "str") _params["redirect_uri"] = _SERIALIZER.query("redirect_uri", redirect_uri, "str") _params["scope"] = _SERIALIZER.query("scope", scope, "str") @@ -172,9 +156,7 @@ def build_auth_authorization_flow_request( # Construct headers _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") - return HttpRequest( - method="GET", url=_url, params=_params, headers=_headers, **kwargs - ) + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) def build_auth_authorization_flow_complete_request( # pylint: disable=name-too-long @@ -195,9 +177,7 @@ def build_auth_authorization_flow_complete_request( # pylint: disable=name-too- # Construct headers _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") - return HttpRequest( - method="GET", url=_url, params=_params, headers=_headers, **kwargs - ) + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) def build_config_serve_config_request( @@ -221,13 +201,9 @@ def build_config_serve_config_request( # Construct headers if if_none_match is not None: - _headers["if-none-match"] = _SERIALIZER.header( - "if_none_match", if_none_match, "str" - ) + _headers["if-none-match"] = _SERIALIZER.header("if_none_match", if_none_match, "str") if if_modified_since is not None: - _headers["if-modified-since"] = _SERIALIZER.header( - "if_modified_since", if_modified_since, "str" - ) + _headers["if-modified-since"] = _SERIALIZER.header("if_modified_since", if_modified_since, "str") _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) @@ -236,9 +212,7 @@ def build_config_serve_config_request( def build_jobs_submit_bulk_jobs_request(**kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - content_type: Optional[str] = kwargs.pop( - "content_type", _headers.pop("Content-Type", None) - ) + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -246,17 +220,13 @@ def build_jobs_submit_bulk_jobs_request(**kwargs: Any) -> HttpRequest: # Construct headers if content_type is not None: - _headers["Content-Type"] = _SERIALIZER.header( - "content_type", content_type, "str" - ) + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) -def build_jobs_delete_bulk_jobs_request( - *, job_ids: List[int], **kwargs: Any -) -> HttpRequest: +def build_jobs_delete_bulk_jobs_request(*, job_ids: List[int], **kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) @@ -271,9 +241,7 @@ def build_jobs_delete_bulk_jobs_request( # Construct headers _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") - return HttpRequest( - method="DELETE", url=_url, params=_params, headers=_headers, **kwargs - ) + return HttpRequest(method="DELETE", url=_url, params=_params, headers=_headers, **kwargs) def build_jobs_get_single_job_request(job_id: int, **kwargs: Any) -> HttpRequest: @@ -374,14 +342,10 @@ def build_jobs_set_single_job_status_request( # Construct headers _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") - return HttpRequest( - method="POST", url=_url, params=_params, headers=_headers, **kwargs - ) + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) -def build_jobs_kill_bulk_jobs_request( - *, job_ids: List[int], **kwargs: Any -) -> HttpRequest: +def build_jobs_kill_bulk_jobs_request(*, job_ids: List[int], **kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) @@ -396,14 +360,10 @@ def build_jobs_kill_bulk_jobs_request( # Construct headers _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") - return HttpRequest( - method="POST", url=_url, params=_params, headers=_headers, **kwargs - ) + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) -def build_jobs_get_bulk_job_status_request( - *, job_ids: List[int], **kwargs: Any -) -> HttpRequest: +def build_jobs_get_bulk_job_status_request(*, job_ids: List[int], **kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) @@ -418,17 +378,13 @@ def build_jobs_get_bulk_job_status_request( # Construct headers _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") - return HttpRequest( - method="GET", url=_url, params=_params, headers=_headers, **kwargs - ) + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) def build_jobs_set_status_bulk_request(**kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - content_type: Optional[str] = kwargs.pop( - "content_type", _headers.pop("Content-Type", None) - ) + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -436,23 +392,17 @@ def build_jobs_set_status_bulk_request(**kwargs: Any) -> HttpRequest: # Construct headers if content_type is not None: - _headers["Content-Type"] = _SERIALIZER.header( - "content_type", content_type, "str" - ) + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) -def build_jobs_search_request( - *, page: int = 0, per_page: int = 100, **kwargs: Any -) -> HttpRequest: +def build_jobs_search_request(*, page: int = 0, per_page: int = 100, **kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - content_type: Optional[str] = kwargs.pop( - "content_type", _headers.pop("Content-Type", None) - ) + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -466,22 +416,16 @@ def build_jobs_search_request( # Construct headers if content_type is not None: - _headers["Content-Type"] = _SERIALIZER.header( - "content_type", content_type, "str" - ) + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") - return HttpRequest( - method="POST", url=_url, params=_params, headers=_headers, **kwargs - ) + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) def build_jobs_summary_request(**kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - content_type: Optional[str] = kwargs.pop( - "content_type", _headers.pop("Content-Type", None) - ) + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -489,9 +433,7 @@ def build_jobs_summary_request(**kwargs: Any) -> HttpRequest: # Construct headers if content_type is not None: - _headers["Content-Type"] = _SERIALIZER.header( - "content_type", content_type, "str" - ) + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) @@ -514,9 +456,7 @@ def __init__(self, *args, **kwargs): self._client = input_args.pop(0) if input_args else kwargs.pop("client") self._config = input_args.pop(0) if input_args else kwargs.pop("config") self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer") - self._deserialize = ( - input_args.pop(0) if input_args else kwargs.pop("deserializer") - ) + self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer") @distributed_trace def openid_configuration(self, **kwargs: Any) -> Any: @@ -548,18 +488,14 @@ def openid_configuration(self, **kwargs: Any) -> Any: request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -587,9 +523,7 @@ def __init__(self, *args, **kwargs): self._client = input_args.pop(0) if input_args else kwargs.pop("client") self._config = input_args.pop(0) if input_args else kwargs.pop("config") self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer") - self._deserialize = ( - input_args.pop(0) if input_args else kwargs.pop("deserializer") - ) + self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer") raise_if_not_implemented( self.__class__, [ @@ -637,18 +571,14 @@ def do_device_flow(self, *, user_code: str, **kwargs: Any) -> Any: request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -705,23 +635,17 @@ def initiate_device_flow( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) - deserialized = self._deserialize( - "InitiateDeviceFlowResponse", pipeline_response - ) + deserialized = self._deserialize("InitiateDeviceFlowResponse", pipeline_response) if cls: return cls(pipeline_response, deserialized, {}) @@ -768,18 +692,14 @@ def finish_device_flow(self, *, code: str, state: str, **kwargs: Any) -> Any: request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -819,18 +739,14 @@ def finished(self, **kwargs: Any) -> Any: request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -902,18 +818,14 @@ def authorization_flow( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -924,9 +836,7 @@ def authorization_flow( return deserialized @distributed_trace - def authorization_flow_complete( - self, *, code: str, state: str, **kwargs: Any - ) -> Any: + def authorization_flow_complete(self, *, code: str, state: str, **kwargs: Any) -> Any: """Authorization Flow Complete. Authorization Flow Complete. @@ -961,18 +871,14 @@ def authorization_flow_complete( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -1000,9 +906,7 @@ def __init__(self, *args, **kwargs): self._client = input_args.pop(0) if input_args else kwargs.pop("client") self._config = input_args.pop(0) if input_args else kwargs.pop("config") self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer") - self._deserialize = ( - input_args.pop(0) if input_args else kwargs.pop("deserializer") - ) + self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer") @distributed_trace def serve_config( @@ -1056,18 +960,14 @@ def serve_config( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -1095,9 +995,7 @@ def __init__(self, *args, **kwargs): self._client = input_args.pop(0) if input_args else kwargs.pop("client") self._config = input_args.pop(0) if input_args else kwargs.pop("config") self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer") - self._deserialize = ( - input_args.pop(0) if input_args else kwargs.pop("deserializer") - ) + self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer") @overload def submit_bulk_jobs( @@ -1136,9 +1034,7 @@ def submit_bulk_jobs( """ @distributed_trace - def submit_bulk_jobs( - self, body: Union[List[str], IO], **kwargs: Any - ) -> List[_models.InsertedJob]: + def submit_bulk_jobs(self, body: Union[List[str], IO], **kwargs: Any) -> List[_models.InsertedJob]: """Submit Bulk Jobs. Submit Bulk Jobs. @@ -1163,9 +1059,7 @@ def submit_bulk_jobs( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = kwargs.pop("params", {}) or {} - content_type: Optional[str] = kwargs.pop( - "content_type", _headers.pop("Content-Type", None) - ) + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) cls: ClsType[List[_models.InsertedJob]] = kwargs.pop("cls", None) content_type = content_type or "application/json" @@ -1186,18 +1080,14 @@ def submit_bulk_jobs( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("[InsertedJob]", pipeline_response) @@ -1240,18 +1130,14 @@ def delete_bulk_jobs(self, *, job_ids: List[int], **kwargs: Any) -> Any: request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -1294,18 +1180,14 @@ def get_single_job(self, job_id: int, **kwargs: Any) -> Any: request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -1348,18 +1230,14 @@ def delete_single_job(self, job_id: int, **kwargs: Any) -> Any: request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -1402,18 +1280,14 @@ def kill_single_job(self, job_id: int, **kwargs: Any) -> Any: request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -1424,9 +1298,7 @@ def kill_single_job(self, job_id: int, **kwargs: Any) -> Any: return deserialized @distributed_trace - def get_single_job_status( - self, job_id: int, **kwargs: Any - ) -> Union[str, _models.JobStatus]: + def get_single_job_status(self, job_id: int, **kwargs: Any) -> Union[str, _models.JobStatus]: """Get Single Job Status. Get Single Job Status. @@ -1458,18 +1330,14 @@ def get_single_job_status( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("str", pipeline_response) @@ -1480,9 +1348,7 @@ def get_single_job_status( return deserialized @distributed_trace - def set_single_job_status( - self, job_id: int, *, status: Union[str, _models.JobStatus], **kwargs: Any - ) -> Any: + def set_single_job_status(self, job_id: int, *, status: Union[str, _models.JobStatus], **kwargs: Any) -> Any: """Set Single Job Status. Set Single Job Status. @@ -1518,18 +1384,14 @@ def set_single_job_status( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -1572,18 +1434,14 @@ def kill_bulk_jobs(self, *, job_ids: List[int], **kwargs: Any) -> Any: request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -1626,18 +1484,14 @@ def get_bulk_job_status(self, *, job_ids: List[int], **kwargs: Any) -> Any: request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) @@ -1715,9 +1569,7 @@ def set_status_bulk( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = kwargs.pop("params", {}) or {} - content_type: Optional[str] = kwargs.pop( - "content_type", _headers.pop("Content-Type", None) - ) + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) cls: ClsType[List[_models.JobStatusReturn]] = kwargs.pop("cls", None) content_type = content_type or "application/json" @@ -1738,18 +1590,14 @@ def set_status_bulk( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("[JobStatusReturn]", pipeline_response) @@ -1858,9 +1706,7 @@ def search( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = kwargs.pop("params", {}) or {} - content_type: Optional[str] = kwargs.pop( - "content_type", _headers.pop("Content-Type", None) - ) + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) cls: ClsType[List[JSON]] = kwargs.pop("cls", None) content_type = content_type or "application/json" @@ -1886,18 +1732,14 @@ def search( request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("[object]", pipeline_response) @@ -1930,9 +1772,7 @@ def summary( """ @overload - def summary( - self, body: IO, *, content_type: str = "application/json", **kwargs: Any - ) -> Any: + def summary(self, body: IO, *, content_type: str = "application/json", **kwargs: Any) -> Any: """Summary. Show information suitable for plotting. @@ -1973,9 +1813,7 @@ def summary(self, body: Union[_models.JobSummaryParams, IO], **kwargs: Any) -> A _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = kwargs.pop("params", {}) or {} - content_type: Optional[str] = kwargs.pop( - "content_type", _headers.pop("Content-Type", None) - ) + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) cls: ClsType[Any] = kwargs.pop("cls", None) content_type = content_type or "application/json" @@ -1996,18 +1834,14 @@ def summary(self, body: Union[_models.JobSummaryParams, IO], **kwargs: Any) -> A request.url = self._client.format_url(request.url) _stream = False - pipeline_response: PipelineResponse = ( - self._client._pipeline.run( # pylint: disable=protected-access - request, stream=_stream, **kwargs - ) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs ) response = pipeline_response.http_response if response.status_code not in [200]: - map_error( - status_code=response.status_code, response=response, error_map=error_map - ) + map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response) deserialized = self._deserialize("object", pipeline_response) diff --git a/src/diracx/client/operations/_patch.py b/src/diracx/client/operations/_patch.py index b2f097ba..2bdd9763 100644 --- a/src/diracx/client/operations/_patch.py +++ b/src/diracx/client/operations/_patch.py @@ -15,9 +15,7 @@ from .. import models as _models from ._operations import AuthOperations as AuthOperationsGenerated -__all__: List[str] = [ - "AuthOperations" -] # Add all objects you want publicly available to users at this package level +__all__: List[str] = ["AuthOperations"] # Add all objects you want publicly available to users at this package level def patch_sdk(): diff --git a/src/diracx/core/config/__init__.py b/src/diracx/core/config/__init__.py index a09784e3..ef10e5ae 100644 --- a/src/diracx/core/config/__init__.py +++ b/src/diracx/core/config/__init__.py @@ -96,9 +96,7 @@ def __init__(self, *, backend_url: ConfigSourceUrl) -> None: repo_location = Path(backend_url.path) self.repo_location = repo_location self.repo = git.Repo(repo_location) - self._latest_revision_cache: Cache = TTLCache( - MAX_CS_CACHED_VERSIONS, DEFAULT_CS_CACHE_TTL - ) + self._latest_revision_cache: Cache = TTLCache(MAX_CS_CACHED_VERSIONS, DEFAULT_CS_CACHE_TTL) self._read_raw_cache: Cache = LRUCache(MAX_CS_CACHED_VERSIONS) def __hash__(self): @@ -115,9 +113,7 @@ def latest_revision(self) -> tuple[str, datetime]: except git.exc.ODBError as e: # type: ignore raise BadConfigurationVersion(f"Error parsing latest revision: {e}") from e modified = rev.committed_datetime.astimezone(timezone.utc) - logger.debug( - "Latest revision for %s is %s with mtime %s", self, rev.hexsha, modified - ) + logger.debug("Latest revision for %s is %s with mtime %s", self, rev.hexsha, modified) return rev.hexsha, modified @cachedmethod(lambda self: self._read_raw_cache) diff --git a/src/diracx/core/config/schema.py b/src/diracx/core/config/schema.py index 8f3470e0..3f2a628c 100644 --- a/src/diracx/core/config/schema.py +++ b/src/diracx/core/config/schema.py @@ -2,7 +2,7 @@ import os from datetime import datetime -from typing import Any, Optional +from typing import Any from pydantic import BaseModel as _BaseModel from pydantic import EmailStr, PrivateAttr, root_validator @@ -22,9 +22,7 @@ def legacy_adaptor(cls, v): # though ideally we should parse the type hints properly. for field, hint in cls.__annotations__.items(): # Convert comma separated lists to actual lists - if hint in {"list[str]", "list[SecurityProperty]"} and isinstance( - v.get(field), str - ): + if hint in {"list[str]", "list[SecurityProperty]"} and isinstance(v.get(field), str): v[field] = [x.strip() for x in v[field].split(",") if x.strip()] # If the field is optional and the value is "None" convert it to None if "| None" in hint and field in v: @@ -49,12 +47,12 @@ class GroupConfig(BaseModel): AutoAddVOMS: bool = False AutoUploadPilotProxy: bool = False AutoUploadProxy: bool = False - JobShare: Optional[int] + JobShare: int | None Properties: list[SecurityProperty] - Quota: Optional[int] + Quota: int | None Users: list[str] AllowBackgroundTQs: bool = False - VOMSRole: Optional[str] + VOMSRole: str | None AutoSyncVOMS: bool = False diff --git a/src/diracx/core/extensions.py b/src/diracx/core/extensions.py index 46786a96..27bb6396 100644 --- a/src/diracx/core/extensions.py +++ b/src/diracx/core/extensions.py @@ -1,12 +1,10 @@ -from __future__ import absolute_import - __all__ = ("select_from_extension",) import os from collections import defaultdict +from collections.abc import Iterator from importlib.metadata import EntryPoint, entry_points from importlib.util import find_spec -from typing import Iterator def extensions_by_priority() -> Iterator[str]: @@ -17,9 +15,7 @@ def extensions_by_priority() -> Iterator[str]: yield module_name -def select_from_extension( - *, group: str, name: str | None = None -) -> Iterator[EntryPoint]: +def select_from_extension(*, group: str, name: str | None = None) -> Iterator[EntryPoint]: """Select entry points by group and name, in order of priority. Similar to ``importlib.metadata.entry_points.select`` except only modules diff --git a/src/diracx/core/properties.py b/src/diracx/core/properties.py index 8bf80201..4403afc8 100644 --- a/src/diracx/core/properties.py +++ b/src/diracx/core/properties.py @@ -5,7 +5,7 @@ import inspect import operator -from typing import Callable +from collections.abc import Callable from diracx.core.extensions import select_from_extension @@ -14,9 +14,7 @@ class SecurityProperty(str): @classmethod def available_properties(cls) -> set[SecurityProperty]: properties = set() - for entry_point in select_from_extension( - group="diracx", name="properties_module" - ): + for entry_point in select_from_extension(group="diracx", name="properties_module"): properties_module = entry_point.load() for _, obj in inspect.getmembers(properties_module): if isinstance(obj, SecurityProperty): @@ -26,23 +24,17 @@ def available_properties(cls) -> set[SecurityProperty]: def __repr__(self) -> str: return f"{self.__class__.__name__}({self})" - def __and__( - self, value: SecurityProperty | UnevaluatedProperty - ) -> UnevaluatedExpression: + def __and__(self, value: SecurityProperty | UnevaluatedProperty) -> UnevaluatedExpression: if not isinstance(value, UnevaluatedProperty): value = UnevaluatedProperty(value) return UnevaluatedProperty(self) & value - def __or__( - self, value: SecurityProperty | UnevaluatedProperty - ) -> UnevaluatedExpression: + def __or__(self, value: SecurityProperty | UnevaluatedProperty) -> UnevaluatedExpression: if not isinstance(value, UnevaluatedProperty): value = UnevaluatedProperty(value) return UnevaluatedProperty(self) | value - def __xor__( - self, value: SecurityProperty | UnevaluatedProperty - ) -> UnevaluatedExpression: + def __xor__(self, value: SecurityProperty | UnevaluatedProperty) -> UnevaluatedExpression: if not isinstance(value, UnevaluatedProperty): value = UnevaluatedProperty(value) return UnevaluatedProperty(self) ^ value diff --git a/src/diracx/db/auth/db.py b/src/diracx/db/auth/db.py index 687e8aa9..30454eb4 100644 --- a/src/diracx/db/auth/db.py +++ b/src/diracx/db/auth/db.py @@ -25,9 +25,7 @@ class AuthDB(BaseDB): # This needs to be here for the BaseDB to create the engine metadata = AuthDBBase.metadata - async def device_flow_validate_user_code( - self, user_code: str, max_validity: int - ) -> str: + async def device_flow_validate_user_code(self, user_code: str, max_validity: int) -> str: """Validate that the user_code can be used (Pending status, not expired) Returns the scope field for the given user_code @@ -52,9 +50,7 @@ async def get_device_flow(self, device_code: str, max_validity: int): # multiple time concurrently stmt = select( DeviceFlows, - (DeviceFlows.creation_time < substract_date(seconds=max_validity)).label( - "is_expired" - ), + (DeviceFlows.creation_time < substract_date(seconds=max_validity)).label("is_expired"), ).with_for_update() stmt = stmt.where( DeviceFlows.device_code == device_code, @@ -67,9 +63,7 @@ async def get_device_flow(self, device_code: str, max_validity: int): if res["status"] == FlowStatus.READY: # Update the status to Done before returning await self.conn.execute( - update(DeviceFlows) - .where(DeviceFlows.device_code == device_code) - .values(status=FlowStatus.DONE) + update(DeviceFlows).where(DeviceFlows.device_code == device_code).values(status=FlowStatus.DONE) ) return res @@ -82,9 +76,7 @@ async def get_device_flow(self, device_code: str, max_validity: int): raise AuthorizationError("Bad state in device flow") - async def device_flow_insert_id_token( - self, user_code: str, id_token: dict[str, str], max_validity: int - ) -> None: + async def device_flow_insert_id_token(self, user_code: str, id_token: dict[str, str], max_validity: int) -> None: """ :raises: AuthorizationError if no such code or status not pending """ @@ -97,9 +89,7 @@ async def device_flow_insert_id_token( stmt = stmt.values(id_token=id_token, status=FlowStatus.READY) res = await self.conn.execute(stmt) if res.rowcount != 1: - raise AuthorizationError( - f"{res.rowcount} rows matched user_code {user_code}" - ) + raise AuthorizationError(f"{res.rowcount} rows matched user_code {user_code}") async def insert_device_flow( self, @@ -109,8 +99,7 @@ async def insert_device_flow( ) -> tuple[str, str]: for _ in range(MAX_RETRY): user_code = "".join( - secrets.choice(USER_CODE_ALPHABET) - for _ in range(DeviceFlows.user_code.type.length) # type: ignore + secrets.choice(USER_CODE_ALPHABET) for _ in range(DeviceFlows.user_code.type.length) # type: ignore ) # user_code = "2QRKPY" device_code = secrets.token_urlsafe() @@ -128,9 +117,7 @@ async def insert_device_flow( continue return user_code, device_code - raise NotImplementedError( - f"Could not insert new device flow after {MAX_RETRY} retries" - ) + raise NotImplementedError(f"Could not insert new device flow after {MAX_RETRY} retries") async def insert_authorization_flow( self, @@ -200,9 +187,7 @@ async def get_authorization_flow(self, code: str, max_validity: int): if res["status"] == FlowStatus.READY: # Update the status to Done before returning await self.conn.execute( - update(AuthorizationFlows) - .where(AuthorizationFlows.code == code) - .values(status=FlowStatus.DONE) + update(AuthorizationFlows).where(AuthorizationFlows.code == code).values(status=FlowStatus.DONE) ) return res diff --git a/src/diracx/db/dummy/db.py b/src/diracx/db/dummy/db.py index a4456219..06cf3483 100644 --- a/src/diracx/db/dummy/db.py +++ b/src/diracx/db/dummy/db.py @@ -30,11 +30,7 @@ async def summary(self, group_by, search) -> list[dict[str, str | int]]: stmt = stmt.group_by(*columns) # Execute the query - return [ - dict(row._mapping) - async for row in (await self.conn.stream(stmt)) - if row.count > 0 # type: ignore - ] + return [dict(row._mapping) async for row in (await self.conn.stream(stmt)) if row.count > 0] # type: ignore async def insert_owner(self, name: str) -> int: stmt = insert(Owners).values(name=name) @@ -43,9 +39,7 @@ async def insert_owner(self, name: str) -> int: return result.lastrowid async def insert_car(self, license_plate: UUID, model: str, owner_id: int) -> int: - stmt = insert(Cars).values( - licensePlate=license_plate, model=model, ownerID=owner_id - ) + stmt = insert(Cars).values(licensePlate=license_plate, model=model, ownerID=owner_id) result = await self.conn.execute(stmt) # await self.engine.commit() diff --git a/src/diracx/db/jobs/db.py b/src/diracx/db/jobs/db.py index a933cfef..0533dc3d 100644 --- a/src/diracx/db/jobs/db.py +++ b/src/diracx/db/jobs/db.py @@ -30,11 +30,7 @@ async def summary(self, group_by, search) -> list[dict[str, str | int]]: stmt = stmt.group_by(*columns) # Execute the query - return [ - dict(row._mapping) - async for row in (await self.conn.stream(stmt)) - if row.count > 0 # type: ignore - ] + return [dict(row._mapping) async for row in (await self.conn.stream(stmt)) if row.count > 0] # type: ignore async def search( self, parameters, search, sorts, *, per_page: int = 100, page: int | None = None @@ -42,12 +38,8 @@ async def search( # Find which columns to select columns = [x for x in Jobs.__table__.columns] if parameters: - if unrecognised_parameters := set(parameters) - set( - Jobs.__table__.columns.keys() - ): - raise InvalidQueryError( - f"Unrecognised parameters requested {unrecognised_parameters}" - ) + if unrecognised_parameters := set(parameters) - set(Jobs.__table__.columns.keys()): + raise InvalidQueryError(f"Unrecognised parameters requested {unrecognised_parameters}") columns = [c for c in columns if c.name in parameters] stmt = select(*columns) @@ -73,9 +65,7 @@ async def search( async def _insertNewJDL(self, jdl) -> int: from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL - stmt = insert(JobJDLs).values( - JDL="", JobRequirements="", OriginalJDL=compressJDL(jdl) - ) + stmt = insert(JobJDLs).values(JDL="", JobRequirements="", OriginalJDL=compressJDL(jdl)) result = await self.conn.execute(stmt) # await self.engine.commit() return result.lastrowid @@ -140,9 +130,7 @@ async def _checkAndPrepareJob( async def setJobJDL(self, job_id, jdl): from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL - stmt = ( - update(JobJDLs).where(JobJDLs.JobID == job_id).values(JDL=compressJDL(jdl)) - ) + stmt = update(JobJDLs).where(JobJDLs.JobID == job_id).values(JDL=compressJDL(jdl)) await self.conn.execute(stmt) async def insert( @@ -173,9 +161,7 @@ async def insert( "DIRACSetup": dirac_setup, } - jobManifest = returnValueOrRaise( - checkAndAddOwner(jdl, owner, owner_dn, owner_group, dirac_setup) - ) + jobManifest = returnValueOrRaise(checkAndAddOwner(jdl, owner, owner_dn, owner_group, dirac_setup)) jdl = fixJDL(jdl) diff --git a/src/diracx/db/jobs/schema.py b/src/diracx/db/jobs/schema.py index 532f3a83..a00ab0bc 100644 --- a/src/diracx/db/jobs/schema.py +++ b/src/diracx/db/jobs/schema.py @@ -59,9 +59,7 @@ class Jobs(Base): JobType = Column("JobType", String(32), default="user") DIRACSetup = Column("DIRACSetup", String(32), default="test") JobGroup = Column("JobGroup", String(32), default="00000000") - JobSplitType = Column( - "JobSplitType", Enum("Single", "Master", "Subjob", "DAGNode"), default="Single" - ) + JobSplitType = Column("JobSplitType", Enum("Single", "Master", "Subjob", "DAGNode"), default="Single") MasterJobID = Column("MasterJobID", Integer, default=0) Site = Column("Site", String(100), default="ANY") JobName = Column("JobName", String(128), default="Unknown") @@ -89,9 +87,7 @@ class Jobs(Base): OSandboxReadyFlag = Column("OSandboxReadyFlag", EnumBackedBool(), default=False) RetrievedFlag = Column("RetrievedFlag", EnumBackedBool(), default=False) # TODO: Should this be True/False/"Failed"? Or True/False/Null? - AccountedFlag = Column( - "AccountedFlag", Enum("True", "False", "Failed"), default="False" - ) + AccountedFlag = Column("AccountedFlag", Enum("True", "False", "Failed"), default="False") __table_args__ = ( ForeignKeyConstraint(["JobID"], ["JobJDLs.JobID"]), diff --git a/src/diracx/db/utils.py b/src/diracx/db/utils.py index 4efc58d5..750a54e5 100644 --- a/src/diracx/db/utils.py +++ b/src/diracx/db/utils.py @@ -5,9 +5,10 @@ import contextlib import os from abc import ABCMeta +from collections.abc import AsyncIterator from datetime import datetime, timedelta, timezone from functools import partial -from typing import TYPE_CHECKING, AsyncIterator, Self +from typing import TYPE_CHECKING, Self from pydantic import parse_obj_as from sqlalchemy import Column as RawColumn diff --git a/src/diracx/routers/__init__.py b/src/diracx/routers/__init__.py index fd8710b9..43710cf3 100644 --- a/src/diracx/routers/__init__.py +++ b/src/diracx/routers/__init__.py @@ -3,8 +3,9 @@ import inspect import logging import os +from collections.abc import AsyncGenerator, Iterable from functools import partial -from typing import AsyncContextManager, AsyncGenerator, Iterable, TypeVar +from typing import AsyncContextManager, TypeVar import dotenv from fastapi import APIRouter, Depends, Request @@ -59,8 +60,7 @@ def create_app_inner( available_db_classes: set[type[BaseDB]] = set() for db_name, db_url in database_urls.items(): db_classes: list[type[BaseDB]] = [ - entry_point.load() - for entry_point in select_from_extension(group="diracx.dbs", name=db_name) + entry_point.load() for entry_point in select_from_extension(group="diracx.dbs", name=db_name) ] assert db_classes, f"Could not find {db_name=}" # The first DB is the highest priority one @@ -79,9 +79,7 @@ def create_app_inner( # Without this AutoREST generates different client sources for each ordering for system_name in sorted(enabled_systems): assert system_name not in routers - for entry_point in select_from_extension( - group="diracx.services", name=system_name - ): + for entry_point in select_from_extension(group="diracx.services", name=system_name): routers[system_name] = entry_point.load() break else: @@ -92,16 +90,12 @@ def create_app_inner( # Ensure required settings are available for cls in find_dependents(router, ServiceSettingsBase): if cls not in available_settings_classes: - raise NotImplementedError( - f"Cannot enable {system_name=} as it requires {cls=}" - ) + raise NotImplementedError(f"Cannot enable {system_name=} as it requires {cls=}") # Ensure required DBs are available missing_dbs = set(find_dependents(router, BaseDB)) - available_db_classes if missing_dbs: - raise NotImplementedError( - f"Cannot enable {system_name=} as it requires {missing_dbs=}" - ) + raise NotImplementedError(f"Cannot enable {system_name=} as it requires {missing_dbs=}") # Add the router to the application dependencies = [] @@ -155,18 +149,14 @@ def create_app() -> DiracFastAPI: def dirac_error_handler(request: Request, exc: DiracError) -> Response: - return JSONResponse( - status_code=exc.http_status_code, content={"detail": exc.detail} - ) + return JSONResponse(status_code=exc.http_status_code, content={"detail": exc.detail}) def http_response_handler(request: Request, exc: DiracHttpResponse) -> Response: return JSONResponse(status_code=exc.status_code, content=exc.data) -def find_dependents( - obj: APIRouter | Iterable[Dependant], cls: type[T] -) -> Iterable[type[T]]: +def find_dependents(obj: APIRouter | Iterable[Dependant], cls: type[T]) -> Iterable[type[T]]: if isinstance(obj, APIRouter): # TODO: Support dependencies of the router itself # yield from find_dependents(obj.dependencies, cls) diff --git a/src/diracx/routers/auth.py b/src/diracx/routers/auth.py index ac88c4ad..898ec365 100644 --- a/src/diracx/routers/auth.py +++ b/src/diracx/routers/auth.py @@ -63,17 +63,11 @@ class AuthSettings(ServiceSettingsBase, env_prefix="DIRACX_SERVICE_AUTH_"): access_token_expire_minutes: int = 3000 refresh_token_expire_minutes: int = 3000 - available_properties: set[SecurityProperty] = Field( - default_factory=SecurityProperty.available_properties - ) + available_properties: set[SecurityProperty] = Field(default_factory=SecurityProperty.available_properties) def has_properties(expression: UnevaluatedProperty | SecurityProperty): - evaluator = ( - expression - if isinstance(expression, UnevaluatedProperty) - else UnevaluatedProperty(expression) - ) + evaluator = expression if isinstance(expression, UnevaluatedProperty) else UnevaluatedProperty(expression) async def require_property(user: Annotated[UserInfo, Depends(verify_dirac_token)]): if not evaluator(user.properties): @@ -127,9 +121,7 @@ async def fetch_jwk_set(url: str): async def parse_id_token(config, vo, raw_id_token: str, audience: str): - server_metadata = await get_server_metadata( - config.Registry[vo].IdP.server_metadata_url - ) + server_metadata = await get_server_metadata(config.Registry[vo].IdP.server_metadata_url) alg_values = server_metadata.get("id_token_signing_alg_values_supported", ["RS256"]) jwk_set = await fetch_jwk_set(config.Registry[vo].IdP.server_metadata_url) @@ -209,22 +201,16 @@ async def verify_dirac_token( ) -def create_access_token( - payload: dict, settings: AuthSettings, expires_delta: timedelta | None = None -) -> str: +def create_access_token(payload: dict, settings: AuthSettings, expires_delta: timedelta | None = None) -> str: to_encode = payload.copy() if expires_delta: expire = datetime.utcnow() + expires_delta else: - expire = datetime.utcnow() + timedelta( - minutes=settings.access_token_expire_minutes - ) + expire = datetime.utcnow() + timedelta(minutes=settings.access_token_expire_minutes) to_encode.update({"exp": expire}) jwt = JsonWebToken(settings.token_algorithm) - encoded_jwt = jwt.encode( - {"alg": settings.token_algorithm}, payload, settings.token_key.jwk - ) + encoded_jwt = jwt.encode({"alg": settings.token_algorithm}, payload, settings.token_key.jwk) return encoded_jwt.decode("ascii") @@ -240,9 +226,7 @@ async def exchange_token( preferred_username = id_token.get("preferred_username", sub) if sub not in config.Registry[vo].Groups[dirac_group].Users: - raise ValueError( - f"User is not a member of the requested group ({preferred_username}, {dirac_group})" - ) + raise ValueError(f"User is not a member of the requested group ({preferred_username}, {dirac_group})") payload = { "sub": f"{vo}:{sub}", @@ -290,9 +274,7 @@ async def initiate_device_flow( `auth//device?user_code=XYZ` """ if settings.dirac_client_id != client_id: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognised client ID" - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognised client ID") try: parse_and_validate_scope(scope, config, available_properties) @@ -302,9 +284,7 @@ async def initiate_device_flow( detail=e.args[0], ) from e - user_code, device_code = await auth_db.insert_device_flow( - client_id, scope, audience - ) + user_code, device_code = await auth_db.insert_device_flow(client_id, scope, audience) verification_uri = str(request.url.replace(query={})) @@ -317,22 +297,14 @@ async def initiate_device_flow( } -async def initiate_authorization_flow_with_iam( - config, vo: str, redirect_uri: str, state: dict[str, str] -): +async def initiate_authorization_flow_with_iam(config, vo: str, redirect_uri: str, state: dict[str, str]): # code_verifier: https://www.rfc-editor.org/rfc/rfc7636#section-4.1 code_verifier = secrets.token_hex() # code_challenge: https://www.rfc-editor.org/rfc/rfc7636#section-4.2 - code_challenge = ( - base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) - .decode() - .replace("=", "") - ) + code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).decode().replace("=", "") - server_metadata = await get_server_metadata( - config.Registry[vo].IdP.server_metadata_url - ) + server_metadata = await get_server_metadata(config.Registry[vo].IdP.server_metadata_url) # Take these two from CS/.well-known authorization_endpoint = server_metadata["authorization_endpoint"] @@ -355,12 +327,8 @@ async def initiate_authorization_flow_with_iam( return authorization_flow_url -async def get_token_from_iam( - config, vo: str, code: str, state: dict[str, str], redirect_uri: str -) -> dict[str, str]: - server_metadata = await get_server_metadata( - config.Registry[vo].IdP.server_metadata_url - ) +async def get_token_from_iam(config, vo: str, code: str, state: dict[str, str], redirect_uri: str) -> dict[str, str]: + server_metadata = await get_server_metadata(config.Registry[vo].IdP.server_metadata_url) # Take these two from CS/.well-known token_endpoint = server_metadata["token_endpoint"] @@ -379,9 +347,7 @@ async def get_token_from_iam( data=data, ) if res.status_code >= 500: - raise HTTPException( - status.HTTP_502_BAD_GATEWAY, "Failed to contact token endpoint" - ) + raise HTTPException(status.HTTP_502_BAD_GATEWAY, "Failed to contact token endpoint") elif res.status_code >= 400: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid code") @@ -421,9 +387,7 @@ async def do_device_flow( """ # Here we make sure the user_code actualy exists - scope = await auth_db.device_flow_validate_user_code( - user_code, settings.device_flow_expiration_seconds - ) + scope = await auth_db.device_flow_validate_user_code(user_code, settings.device_flow_expiration_seconds) parsed_scope = parse_and_validate_scope(scope, config, available_properties) redirect_uri = f"{request.url.replace(query='')}/complete" @@ -444,9 +408,7 @@ def decrypt_state(state): # TODO: There have been better schemes like rot13 return json.loads(base64.urlsafe_b64decode(state).decode()) except Exception as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid state" - ) from e + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid state") from e @router.get("/device/complete") @@ -466,9 +428,7 @@ async def finish_device_flow( in the cookie/session """ decrypted_state = decrypt_state(state) - assert ( - decrypted_state["grant_type"] == "urn:ietf:params:oauth:grant-type:device_code" - ) + assert decrypted_state["grant_type"] == "urn:ietf:params:oauth:grant-type:device_code" id_token = await get_token_from_iam( config, @@ -504,9 +464,7 @@ class ScopeInfoDict(TypedDict): vo: str -def parse_and_validate_scope( - scope: str, config: Config, available_properties: set[SecurityProperty] -) -> ScopeInfoDict: +def parse_and_validate_scope(scope: str, config: Config, available_properties: set[SecurityProperty]) -> ScopeInfoDict: """ Check: * At most one VO @@ -538,10 +496,7 @@ def parse_and_validate_scope( if not vos: available_vo_scopes = [repr(f"vo:{vo}") for vo in config.Registry] - raise ValueError( - "No vo scope requested, available values: " - f"{' '.join(available_vo_scopes)}" - ) + raise ValueError("No vo scope requested, available values: " f"{' '.join(available_vo_scopes)}") elif len(vos) > 1: raise ValueError(f"Only one vo is allowed but got {vos}") else: @@ -564,9 +519,7 @@ def parse_and_validate_scope( properties = [str(p) for p in config.Registry[vo].Groups[group].Properties] if not set(properties).issubset(available_properties): - raise ValueError( - f"{set(properties)-set(available_properties)} are not valid properties" - ) + raise ValueError(f"{set(properties)-set(available_properties)} are not valid properties") return { "group": group, @@ -623,8 +576,7 @@ def parse_and_validate_scope( @router.post("/token") async def token( grant_type: Annotated[ - Literal["authorization_code"] - | Literal["urn:ietf:params:oauth:grant-type:device_code"], + Literal["authorization_code"] | Literal["urn:ietf:params:oauth:grant-type:device_code"], Form(description="OAuth2 Grant type"), ], client_id: Annotated[str, Form(description="OAuth2 client id")], @@ -632,21 +584,15 @@ async def token( config: Config, settings: AuthSettings, available_properties: AvailableSecurityProperties, - device_code: Annotated[ - str | None, Form(description="device code for OAuth2 device flow") - ] = None, - code: Annotated[ - str | None, Form(description="Code for OAuth2 authorization code flow") - ] = None, + device_code: Annotated[str | None, Form(description="device code for OAuth2 device flow")] = None, + code: Annotated[str | None, Form(description="Code for OAuth2 authorization code flow")] = None, redirect_uri: Annotated[ str | None, Form(description="redirect_uri used with OAuth2 authorization code flow"), ] = None, code_verifier: Annotated[ str | None, - Form( - description="Verifier for the code challenge for the OAuth2 authorization flow with PKCE" - ), + Form(description="Verifier for the code challenge for the OAuth2 authorization flow with PKCE"), ] = None, ) -> TokenResponse: """ " Token endpoint to retrieve the token at the end of a flow. @@ -656,17 +602,11 @@ async def token( if grant_type == "urn:ietf:params:oauth:grant-type:device_code": assert device_code is not None try: - info = await auth_db.get_device_flow( - device_code, settings.device_flow_expiration_seconds - ) + info = await auth_db.get_device_flow(device_code, settings.device_flow_expiration_seconds) except PendingAuthorizationError as e: - raise DiracHttpResponse( - status.HTTP_400_BAD_REQUEST, {"error": "authorization_pending"} - ) from e + raise DiracHttpResponse(status.HTTP_400_BAD_REQUEST, {"error": "authorization_pending"}) from e except ExpiredFlowError as e: - raise DiracHttpResponse( - status.HTTP_400_BAD_REQUEST, {"error": "expired_token"} - ) from e + raise DiracHttpResponse(status.HTTP_400_BAD_REQUEST, {"error": "expired_token"}) from e # raise DiracHttpResponse(status.HTTP_400_BAD_REQUEST, {"error": "slow_down"}) # raise DiracHttpResponse(status.HTTP_400_BAD_REQUEST, {"error": "expired_token"}) @@ -678,9 +618,7 @@ async def token( elif grant_type == "authorization_code": assert code is not None - info = await auth_db.get_authorization_flow( - code, settings.authorization_flow_expiration_seconds - ) + info = await auth_db.get_authorization_flow(code, settings.authorization_flow_expiration_seconds) if redirect_uri != info["redirect_uri"]: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -690,11 +628,7 @@ async def token( try: assert code_verifier is not None code_challenge = ( - base64.urlsafe_b64encode( - hashlib.sha256(code_verifier.encode()).digest() - ) - .decode() - .strip("=") + base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).decode().strip("=") ) except Exception as e: raise HTTPException( @@ -744,13 +678,9 @@ async def authorization_flow( settings: AuthSettings, ): if settings.dirac_client_id != client_id: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognised client ID" - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognised client ID") if redirect_uri not in settings.allowed_redirects: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognised redirect_uri" - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognised redirect_uri") try: parsed_scope = parse_and_validate_scope(scope, config, available_properties) @@ -811,6 +741,4 @@ async def authorization_flow_complete( settings.authorization_flow_expiration_seconds, ) - return responses.RedirectResponse( - f"{redirect_uri}?code={code}&state={decrypted_state['external_state']}" - ) + return responses.RedirectResponse(f"{redirect_uri}?code={code}&state={decrypted_state['external_state']}") diff --git a/src/diracx/routers/configuration.py b/src/diracx/routers/configuration.py index 97a80be9..020ab5ea 100644 --- a/src/diracx/routers/configuration.py +++ b/src/diracx/routers/configuration.py @@ -47,16 +47,12 @@ async def serve_config( # a server gets out of sync with disk if if_modified_since: try: - not_before = datetime.strptime( - if_modified_since, LAST_MODIFIED_FORMAT - ).astimezone(timezone.utc) + not_before = datetime.strptime(if_modified_since, LAST_MODIFIED_FORMAT).astimezone(timezone.utc) except ValueError: pass else: if not_before > config._modified: - raise HTTPException( - status_code=status.HTTP_304_NOT_MODIFIED, headers=headers - ) + raise HTTPException(status_code=status.HTTP_304_NOT_MODIFIED, headers=headers) response.headers.update(headers) diff --git a/src/diracx/routers/dependencies.py b/src/diracx/routers/dependencies.py index 9df412a1..1aeef9a8 100644 --- a/src/diracx/routers/dependencies.py +++ b/src/diracx/routers/dependencies.py @@ -32,6 +32,4 @@ def add_settings_annotation(cls: T) -> T: # Miscellaneous Config = Annotated[_Config, Depends(ConfigSource.create)] -AvailableSecurityProperties = Annotated[ - set[SecurityProperty], Depends(SecurityProperty.available_properties) -] +AvailableSecurityProperties = Annotated[set[SecurityProperty], Depends(SecurityProperty.available_properties)] diff --git a/src/diracx/routers/fastapi_classes.py b/src/diracx/routers/fastapi_classes.py index 33d7d863..5e5e3f6e 100644 --- a/src/diracx/routers/fastapi_classes.py +++ b/src/diracx/routers/fastapi_classes.py @@ -19,9 +19,7 @@ def __init__(self): @contextlib.asynccontextmanager async def lifespan(app: DiracFastAPI): async with contextlib.AsyncExitStack() as stack: - await asyncio.gather( - *(stack.enter_async_context(f()) for f in app.lifetime_functions) - ) + await asyncio.gather(*(stack.enter_async_context(f()) for f in app.lifetime_functions)) yield self.lifetime_functions = [] diff --git a/src/diracx/routers/job_manager/__init__.py b/src/diracx/routers/job_manager/__init__.py index 0092a4fa..92c04066 100644 --- a/src/diracx/routers/job_manager/__init__.py +++ b/src/diracx/routers/job_manager/__init__.py @@ -152,8 +152,7 @@ async def submit_bulk_jobs( if nJobs > MAX_PARAMETRIC_JOBS: raise NotImplementedError( EWMSJDL, - "Number of parametric jobs exceeds the limit of %d" - % MAX_PARAMETRIC_JOBS, + "Number of parametric jobs exceeds the limit of %d" % MAX_PARAMETRIC_JOBS, ) result = generateParametricJobs(jobClassAd) if not result["OK"]: @@ -172,11 +171,7 @@ async def submit_bulk_jobs( initialStatus = JobStatus.RECEIVED initialMinorStatus = "Job accepted" - for ( - jobDescription - ) in ( - jobDescList - ): # jobDescList because there might be a list generated by a parametric job + for jobDescription in jobDescList: # jobDescList because there might be a list generated by a parametric job job_id = await job_db.insert( jobDescription, user_info.sub, @@ -188,9 +183,7 @@ async def submit_bulk_jobs( user_info.vo, ) - logging.debug( - f'Job added to the JobDB", "{job_id} for {fixme_ownerDN}/{fixme_ownerGroup}' - ) + logging.debug(f'Job added to the JobDB", "{job_id} for {fixme_ownerDN}/{fixme_ownerGroup}') # TODO comment out for test just now # self.jobLoggingDB.addLoggingRecord( @@ -210,9 +203,7 @@ async def submit_bulk_jobs( # self.__sendJobsToOptimizationMind(jobIDList) # return result - return await asyncio.gather( - *(job_db.insert(j.owner, j.group, j.vo) for j in job_definitions) - ) + return await asyncio.gather(*(job_db.insert(j.owner, j.group, j.vo) for j in job_definitions)) @router.delete("/") @@ -276,9 +267,7 @@ async def set_status_bulk(job_update: list[JobStatusUpdate]) -> list[JobStatusRe "description": "Get only job statuses for specific jobs, ordered by status", "value": { "parameters": ["JobID", "Status"], - "search": [ - {"parameter": "JobID", "operator": "in", "values": ["6", "2", "3"]} - ], + "search": [{"parameter": "JobID", "operator": "in", "values": ["6", "2", "3"]}], "sort": [{"parameter": "JobID", "direction": "asc"}], }, }, @@ -324,9 +313,7 @@ async def search( user_info: Annotated[UserInfo, Depends(verify_dirac_token)], page: int = 0, per_page: int = 100, - body: Annotated[ - JobSearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES) - ] = None, + body: Annotated[JobSearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES)] = None, ) -> list[dict[str, Any]]: """Retrieve information about jobs. @@ -344,9 +331,7 @@ async def search( } ) # TODO: Pagination - return await job_db.search( - body.parameters, body.search, body.sort, page=page, per_page=per_page - ) + return await job_db.search(body.parameters, body.search, body.sort, page=page, per_page=per_page) @router.post("/summary") diff --git a/tests/client/test_regenerate.py b/tests/client/test_regenerate.py index 2f0712d7..6c71d61e 100644 --- a/tests/client/test_regenerate.py +++ b/tests/client/test_regenerate.py @@ -29,9 +29,7 @@ def test_regenerate_client(test_client, tmp_path): assert (repo_root / ".git").is_dir() repo = git.Repo(repo_root) if repo.is_dirty(path=repo_root / "src" / "diracx" / "client"): - raise AssertionError( - "Client is currently in a modified state, skipping regeneration" - ) + raise AssertionError("Client is currently in a modified state, skipping regeneration") cmd = [ "autorest", diff --git a/tests/conftest.py b/tests/conftest.py index b53bec1d..212079e4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -75,9 +75,7 @@ def with_app(test_auth_settings, with_config_repo): "JobDB": "sqlite+aiosqlite:///:memory:", "AuthDB": "sqlite+aiosqlite:///:memory:", }, - config_source=ConfigSource.create_from_url( - backend_url=f"git+file://{with_config_repo}" - ), + config_source=ConfigSource.create_from_url(backend_url=f"git+file://{with_config_repo}"), ) diff --git a/tests/core/test_secrets.py b/tests/core/test_secrets.py index 0a521dca..46ae2d23 100644 --- a/tests/core/test_secrets.py +++ b/tests/core/test_secrets.py @@ -23,9 +23,7 @@ def test_token_signing_key(tmp_path): key_file.write_text(private_key_pem) # Test that we can load a key from a file - compare_keys( - parse_obj_as(TokenSigningKey, f"{key_file}").jwk.get_private_key(), private_key - ) + compare_keys(parse_obj_as(TokenSigningKey, f"{key_file}").jwk.get_private_key(), private_key) compare_keys( parse_obj_as(TokenSigningKey, f"file://{key_file}").jwk.get_private_key(), private_key, diff --git a/tests/db/auth/test_authorization_flow.py b/tests/db/auth/test_authorization_flow.py index 22ffeca3..9bff404e 100644 --- a/tests/db/auth/test_authorization_flow.py +++ b/tests/db/auth/test_authorization_flow.py @@ -28,20 +28,14 @@ async def test_insert_id_token(auth_db: AuthDB): async with auth_db as auth_db: with pytest.raises(AuthorizationError): - code, redirect_uri = await auth_db.authorization_flow_insert_id_token( - uuid, id_token, EXPIRED - ) - code, redirect_uri = await auth_db.authorization_flow_insert_id_token( - uuid, id_token, MAX_VALIDITY - ) + code, redirect_uri = await auth_db.authorization_flow_insert_id_token(uuid, id_token, EXPIRED) + code, redirect_uri = await auth_db.authorization_flow_insert_id_token(uuid, id_token, MAX_VALIDITY) assert redirect_uri == "redirect_uri" # Cannot add a id_token a second time async with auth_db as auth_db: with pytest.raises(AuthorizationError): - await auth_db.authorization_flow_insert_id_token( - uuid, id_token, MAX_VALIDITY - ) + await auth_db.authorization_flow_insert_id_token(uuid, id_token, MAX_VALIDITY) async with auth_db as auth_db: with pytest.raises(NoResultFound): @@ -52,9 +46,7 @@ async def test_insert_id_token(auth_db: AuthDB): # Cannot add a id_token after finishing the flow async with auth_db as auth_db: with pytest.raises(AuthorizationError): - await auth_db.authorization_flow_insert_id_token( - uuid, id_token, MAX_VALIDITY - ) + await auth_db.authorization_flow_insert_id_token(uuid, id_token, MAX_VALIDITY) # We shouldn't be able to retrieve it twice async with auth_db as auth_db: diff --git a/tests/db/auth/test_device_flow.py b/tests/db/auth/test_device_flow.py index 05d703e2..3b81c165 100644 --- a/tests/db/auth/test_device_flow.py +++ b/tests/db/auth/test_device_flow.py @@ -25,9 +25,7 @@ async def test_device_user_code_collision(auth_db: AuthDB, monkeypatch): # First insert should work async with auth_db as auth_db: - code, device = await auth_db.insert_device_flow( - "client_id", "scope", "audience" - ) + code, device = await auth_db.insert_device_flow("client_id", "scope", "audience") assert code == "A" * USER_CODE_LENGTH assert device @@ -38,9 +36,7 @@ async def test_device_user_code_collision(auth_db: AuthDB, monkeypatch): monkeypatch.setattr(secrets, "choice", lambda _: "B") async with auth_db as auth_db: - code, device = await auth_db.insert_device_flow( - "client_id", "scope", "audience" - ) + code, device = await auth_db.insert_device_flow("client_id", "scope", "audience") assert code == "B" * USER_CODE_LENGTH assert device @@ -56,12 +52,8 @@ async def test_device_flow_lookup(auth_db: AuthDB, monkeypatch): # First insert async with auth_db as auth_db: - user_code1, device_code1 = await auth_db.insert_device_flow( - "client_id1", "scope1", "audience1" - ) - user_code2, device_code2 = await auth_db.insert_device_flow( - "client_id2", "scope2", "audience2" - ) + user_code1, device_code1 = await auth_db.insert_device_flow("client_id1", "scope1", "audience1") + user_code2, device_code2 = await auth_db.insert_device_flow("client_id2", "scope2", "audience2") assert user_code1 != user_code2 @@ -83,19 +75,13 @@ async def test_device_flow_lookup(auth_db: AuthDB, monkeypatch): async with auth_db as auth_db: with pytest.raises(AuthorizationError): - await auth_db.device_flow_insert_id_token( - user_code1, {"token": "mytoken"}, EXPIRED - ) + await auth_db.device_flow_insert_id_token(user_code1, {"token": "mytoken"}, EXPIRED) - await auth_db.device_flow_insert_id_token( - user_code1, {"token": "mytoken"}, MAX_VALIDITY - ) + await auth_db.device_flow_insert_id_token(user_code1, {"token": "mytoken"}, MAX_VALIDITY) # We should not be able to insert a id_token a second time with pytest.raises(AuthorizationError): - await auth_db.device_flow_insert_id_token( - user_code1, {"token": "mytoken2"}, MAX_VALIDITY - ) + await auth_db.device_flow_insert_id_token(user_code1, {"token": "mytoken2"}, MAX_VALIDITY) with pytest.raises(ExpiredFlowError): await auth_db.get_device_flow(device_code1, EXPIRED) @@ -112,17 +98,13 @@ async def test_device_flow_lookup(auth_db: AuthDB, monkeypatch): # Re-adding a token should not work after it's been minted async with auth_db as auth_db: with pytest.raises(AuthorizationError): - await auth_db.device_flow_insert_id_token( - user_code1, {"token": "mytoken"}, MAX_VALIDITY - ) + await auth_db.device_flow_insert_id_token(user_code1, {"token": "mytoken"}, MAX_VALIDITY) async def test_device_flow_insert_id_token(auth_db: AuthDB): # First insert async with auth_db as auth_db: - user_code, device_code = await auth_db.insert_device_flow( - "client_id", "scope", "audience" - ) + user_code, device_code = await auth_db.insert_device_flow("client_id", "scope", "audience") # Make sure it exists, and is Pending async with auth_db as auth_db: diff --git a/tests/db/test_dummyDB.py b/tests/db/test_dummyDB.py index 38c75792..ae322edc 100644 --- a/tests/db/test_dummyDB.py +++ b/tests/db/test_dummyDB.py @@ -34,9 +34,7 @@ async def test_insert_and_summary(dummy_db: DummyDB): assert owner_id # Add cars, belonging to the same guy - result = await asyncio.gather( - *(dummy_db.insert_car(uuid4(), f"model_{i}", owner_id) for i in range(10)) - ) + result = await asyncio.gather(*(dummy_db.insert_car(uuid4(), f"model_{i}", owner_id) for i in range(10))) assert result # Check that there are now 10 cars assigned to a single driver @@ -47,9 +45,7 @@ async def test_insert_and_summary(dummy_db: DummyDB): # Test the selection async with dummy_db as dummy_db: - result = await dummy_db.summary( - ["ownerID"], [{"parameter": "model", "operator": "eq", "value": "model_1"}] - ) + result = await dummy_db.summary(["ownerID"], [{"parameter": "model", "operator": "eq", "value": "model_1"}]) assert result[0]["count"] == 1 diff --git a/tests/routers/test_auth.py b/tests/routers/test_auth.py index 8b40cbba..200405a6 100644 --- a/tests/routers/test_auth.py +++ b/tests/routers/test_auth.py @@ -68,11 +68,7 @@ async def fake_parse_id_token(raw_id_token: str, audience: str, *args, **kwargs) async def test_authorization_flow(test_client, auth_httpx_mock: HTTPXMock): code_verifier = secrets.token_hex() - code_challenge = ( - base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) - .decode() - .replace("=", "") - ) + code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).decode().replace("=", "") r = test_client.get( "/auth/authorize", @@ -97,9 +93,7 @@ async def test_authorization_flow(test_client, auth_httpx_mock: HTTPXMock): assert r.status_code == 401, r.text # Check that an invalid state returns an error - r = test_client.get( - redirect_uri, params={"code": "invalid-code", "state": "invalid-state"} - ) + r = test_client.get(redirect_uri, params={"code": "invalid-code", "state": "invalid-state"}) assert r.status_code == 400, r.text assert "Invalid state" in r.text @@ -173,9 +167,7 @@ async def test_device_flow(test_client, auth_httpx_mock: HTTPXMock): assert r.status_code == 401, r.text # Check that an invalid state returns an error - r = test_client.get( - redirect_uri, params={"code": "invalid-code", "state": "invalid-state"} - ) + r = test_client.get(redirect_uri, params={"code": "invalid-code", "state": "invalid-state"}) assert r.status_code == 400, r.text assert "Invalid state" in r.text @@ -237,10 +229,7 @@ def test_parse_scopes(vos, groups, scope, expected): "DefaultGroup": "lhcb_user", "IdP": {"URL": "https://idp.invalid", "ClientID": "test-idp"}, "Users": {}, - "Groups": { - group: {"Properties": ["NormalUser"], "Users": []} - for group in groups - }, + "Groups": {group: {"Properties": ["NormalUser"], "Users": []} for group in groups}, } for vo in vos }, @@ -272,10 +261,7 @@ def test_parse_scopes_invalid(vos, groups, scope, expected_error): "DefaultGroup": "lhcb_user", "IdP": {"URL": "https://idp.invalid", "ClientID": "test-idp"}, "Users": {}, - "Groups": { - group: {"Properties": ["NormalUser"], "Users": []} - for group in groups - }, + "Groups": {group: {"Properties": ["NormalUser"], "Users": []} for group in groups}, } for vo in vos }, diff --git a/tests/routers/test_job_manager.py b/tests/routers/test_job_manager.py index f9e9c8ed..7888b646 100644 --- a/tests/routers/test_job_manager.py +++ b/tests/routers/test_job_manager.py @@ -115,25 +115,17 @@ def test_insert_and_search(normal_user_client): r = normal_user_client.post( "/jobs/search", - json={ - "search": [{"parameter": "Status", "operator": "eq", "value": "RECEIVED"}] - }, + json={"search": [{"parameter": "Status", "operator": "eq", "value": "RECEIVED"}]}, ) assert r.status_code == 200, r.json() assert [x["JobID"] for x in r.json()] == submitted_job_ids - r = normal_user_client.post( - "/jobs/search", json={"parameters": ["JobID", "Status"]} - ) + r = normal_user_client.post("/jobs/search", json={"parameters": ["JobID", "Status"]}) assert r.status_code == 200, r.json() - assert r.json() == [ - {"JobID": jid, "Status": "RECEIVED"} for jid in submitted_job_ids - ] + assert r.json() == [{"JobID": jid, "Status": "RECEIVED"} for jid in submitted_job_ids] # Test /jobs/summary - r = normal_user_client.post( - "/jobs/summary", json={"grouping": ["Status", "OwnerDN"]} - ) + r = normal_user_client.post("/jobs/summary", json={"grouping": ["Status", "OwnerDN"]}) assert r.status_code == 200, r.json() assert r.json() == [{"Status": "RECEIVED", "OwnerDN": "ownerDN", "count": 1}]