diff --git a/diracx-client/src/diracx/client/_generated/_client.py b/diracx-client/src/diracx/client/_generated/_client.py index aa558f636..9e37d5081 100644 --- a/diracx-client/src/diracx/client/_generated/_client.py +++ b/diracx-client/src/diracx/client/_generated/_client.py @@ -15,7 +15,7 @@ from . import models as _models from ._configuration import DiracConfiguration from ._utils.serialization import Deserializer, Serializer -from .operations import AuthOperations, ConfigOperations, JobsOperations, WellKnownOperations +from .operations import AuthOperations, ConfigOperations, JobsOperations, PilotsOperations, WellKnownOperations class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -29,6 +29,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype config: _generated.operations.ConfigOperations :ivar jobs: JobsOperations operations :vartype jobs: _generated.operations.JobsOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.operations.PilotsOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -65,6 +67,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.auth = AuthOperations(self._client, self._config, self._serialize, self._deserialize) self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: """Runs the network request through the client's chained policies. diff --git a/diracx-client/src/diracx/client/_generated/aio/_client.py b/diracx-client/src/diracx/client/_generated/aio/_client.py index 10cfad884..397b7f989 100644 --- a/diracx-client/src/diracx/client/_generated/aio/_client.py +++ b/diracx-client/src/diracx/client/_generated/aio/_client.py @@ -15,7 +15,7 @@ from .. import models as _models from .._utils.serialization import Deserializer, Serializer from ._configuration import DiracConfiguration -from .operations import AuthOperations, ConfigOperations, JobsOperations, WellKnownOperations +from .operations import AuthOperations, ConfigOperations, JobsOperations, PilotsOperations, WellKnownOperations class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -29,6 +29,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype config: _generated.aio.operations.ConfigOperations :ivar jobs: JobsOperations operations :vartype jobs: _generated.aio.operations.JobsOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.aio.operations.PilotsOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -65,6 +67,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.auth = AuthOperations(self._client, self._config, self._serialize, self._deserialize) self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) def send_request( self, request: HttpRequest, *, stream: bool = False, **kwargs: Any diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py b/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py index 10db0c7a9..be02776fc 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py @@ -14,6 +14,7 @@ from ._operations import AuthOperations # type: ignore from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore +from ._operations import PilotsOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -24,6 +25,7 @@ "AuthOperations", "ConfigOperations", "JobsOperations", + "PilotsOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py index a87316174..82c27ee1b 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py @@ -52,6 +52,12 @@ build_jobs_summary_request, build_jobs_unassign_bulk_jobs_sandboxes_request, build_jobs_unassign_job_sandboxes_request, + build_pilots_add_pilot_stamps_request, + build_pilots_delete_pilots_request, + build_pilots_get_pilot_jobs_request, + build_pilots_search_request, + build_pilots_summary_request, + build_pilots_update_pilot_fields_request, build_well_known_get_installation_metadata_request, build_well_known_get_jwks_request, build_well_known_get_openid_configuration_request, @@ -2356,3 +2362,583 @@ async def submit_jdl_jobs(self, body: Union[List[str], IO[bytes]], **kwargs: Any return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.aio.Dirac`'s + :attr:`pilots` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + async def add_pilot_stamps( + self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + + _request = build_pilots_add_pilot_stamps_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def delete_pilots( + self, + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any + ) -> None: + """Delete Pilots. + + Endpoint to delete a pilot. + + Two features: + + + #. Or you provide pilot_stamps, so you can delete pilots by their stamp + #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. + :paramtype pilot_stamps: list[str] + :keyword age_in_days: The number of days that define the maximum age of pilots to be + deleted.Pilots older than this age will be considered for deletion. Default value is None. + :paramtype age_in_days: int + :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is + 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by + default as True to avoid any mistake.This flag is only used for deletion by time. Default value + is False. + :paramtype delete_only_aborted: bool + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_delete_pilots_request( + pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + async def update_pilot_fields( + self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def update_pilot_fields( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def update_pilot_fields( + self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields") + + _request = build_pilots_update_pilot_fields_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @distributed_trace_async + async def get_pilot_jobs( + self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any + ) -> List[int]: + """Get Pilot Jobs. + + Endpoint only for admins, to get jobs of a pilot. + + :keyword pilot_stamp: The stamp of the pilot. Default value is None. + :paramtype pilot_stamp: str + :keyword job_id: The ID of the job. Default value is None. + :paramtype job_id: int + :return: list of int + :rtype: list[int] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[List[int]] = kwargs.pop("cls", None) + + _request = build_pilots_get_pilot_jobs_request( + pilot_stamp=pilot_stamp, + job_id=job_id, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("[int]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + async def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + async def summary( + self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SummaryParams") + + _request = build_pilots_summary_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py b/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py index a408e57d2..0c70ce3e9 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py @@ -11,10 +11,12 @@ __all__ = [ "AuthOperations", "JobsOperations", + "PilotsOperations", ] # Add all objects you want publicly available to users at this package level from ....patches.auth.aio import AuthOperations from ....patches.jobs.aio import JobsOperations +from ....patches.pilots.aio import PilotsOperations def patch_sdk(): diff --git a/diracx-client/src/diracx/client/_generated/models/__init__.py b/diracx-client/src/diracx/client/_generated/models/__init__.py index 06de02aab..ae52349c3 100644 --- a/diracx-client/src/diracx/client/_generated/models/__init__.py +++ b/diracx-client/src/diracx/client/_generated/models/__init__.py @@ -16,6 +16,8 @@ BodyAuthGetOidcTokenGrantType, BodyJobsRescheduleJobs, BodyJobsUnassignBulkJobsSandboxes, + BodyPilotsAddPilotStamps, + BodyPilotsUpdatePilotFields, GroupInfo, HTTPValidationError, HeartbeatData, @@ -27,6 +29,7 @@ JobStatusUpdate, Metadata, OpenIDConfiguration, + PilotFieldsMapping, SandboxDownloadResponse, SandboxInfo, SandboxUploadResponse, @@ -52,6 +55,7 @@ from ._enums import ( # type: ignore ChecksumAlgorithm, JobStatus, + PilotStatus, SandboxFormat, SandboxType, ScalarSearchOperator, @@ -67,6 +71,8 @@ "BodyAuthGetOidcTokenGrantType", "BodyJobsRescheduleJobs", "BodyJobsUnassignBulkJobsSandboxes", + "BodyPilotsAddPilotStamps", + "BodyPilotsUpdatePilotFields", "GroupInfo", "HTTPValidationError", "HeartbeatData", @@ -78,6 +84,7 @@ "JobStatusUpdate", "Metadata", "OpenIDConfiguration", + "PilotFieldsMapping", "SandboxDownloadResponse", "SandboxInfo", "SandboxUploadResponse", @@ -100,6 +107,7 @@ "VectorSearchSpecValues", "ChecksumAlgorithm", "JobStatus", + "PilotStatus", "SandboxFormat", "SandboxType", "ScalarSearchOperator", diff --git a/diracx-client/src/diracx/client/_generated/models/_enums.py b/diracx-client/src/diracx/client/_generated/models/_enums.py index 663d9c951..23edf99d3 100644 --- a/diracx-client/src/diracx/client/_generated/models/_enums.py +++ b/diracx-client/src/diracx/client/_generated/models/_enums.py @@ -34,6 +34,19 @@ class JobStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): RESCHEDULED = "Rescheduled" +class PilotStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """PilotStatus.""" + + SUBMITTED = "Submitted" + WAITING = "Waiting" + RUNNING = "Running" + DONE = "Done" + FAILED = "Failed" + DELETED = "Deleted" + ABORTED = "Aborted" + UNKNOWN = "Unknown" + + class SandboxFormat(str, Enum, metaclass=CaseInsensitiveEnumMeta): """SandboxFormat.""" diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index fc909fe5a..8763de15c 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -146,6 +146,109 @@ def __init__(self, *, job_ids: List[int], **kwargs: Any) -> None: self.job_ids = job_ids +class BodyPilotsAddPilotStamps(_serialization.Model): + """Body_pilots_add_pilot_stamps. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :vartype pilot_stamps: list[str] + :ivar vo: Pilot virtual organization. Required. + :vartype vo: str + :ivar grid_type: Grid type of the pilots. + :vartype grid_type: str + :ivar grid_site: Pilots grid site. + :vartype grid_site: str + :ivar destination_site: Pilots destination site. + :vartype destination_site: str + :ivar pilot_references: Association of a pilot reference with a pilot stamp. + :vartype pilot_references: dict[str, str] + :ivar pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", "Running", + "Done", "Failed", "Deleted", "Aborted", and "Unknown". + :vartype pilot_status: str or ~_generated.models.PilotStatus + """ + + _validation = { + "pilot_stamps": {"required": True}, + "vo": {"required": True}, + } + + _attribute_map = { + "pilot_stamps": {"key": "pilot_stamps", "type": "[str]"}, + "vo": {"key": "vo", "type": "str"}, + "grid_type": {"key": "grid_type", "type": "str"}, + "grid_site": {"key": "grid_site", "type": "str"}, + "destination_site": {"key": "destination_site", "type": "str"}, + "pilot_references": {"key": "pilot_references", "type": "{str}"}, + "pilot_status": {"key": "pilot_status", "type": "str"}, + } + + def __init__( + self, + *, + pilot_stamps: List[str], + vo: str, + grid_type: str = "Dirac", + grid_site: str = "Unknown", + destination_site: str = "NotAssigned", + pilot_references: Optional[Dict[str, str]] = None, + pilot_status: Optional[Union[str, "_models.PilotStatus"]] = None, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :paramtype pilot_stamps: list[str] + :keyword vo: Pilot virtual organization. Required. + :paramtype vo: str + :keyword grid_type: Grid type of the pilots. + :paramtype grid_type: str + :keyword grid_site: Pilots grid site. + :paramtype grid_site: str + :keyword destination_site: Pilots destination site. + :paramtype destination_site: str + :keyword pilot_references: Association of a pilot reference with a pilot stamp. + :paramtype pilot_references: dict[str, str] + :keyword pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", + "Running", "Done", "Failed", "Deleted", "Aborted", and "Unknown". + :paramtype pilot_status: str or ~_generated.models.PilotStatus + """ + super().__init__(**kwargs) + self.pilot_stamps = pilot_stamps + self.vo = vo + self.grid_type = grid_type + self.grid_site = grid_site + self.destination_site = destination_site + self.pilot_references = pilot_references + self.pilot_status = pilot_status + + +class BodyPilotsUpdatePilotFields(_serialization.Model): + """Body_pilots_update_pilot_fields. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. Required. + :vartype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + """ + + _validation = { + "pilot_stamps_to_fields_mapping": {"required": True}, + } + + _attribute_map = { + "pilot_stamps_to_fields_mapping": {"key": "pilot_stamps_to_fields_mapping", "type": "[PilotFieldsMapping]"}, + } + + def __init__(self, *, pilot_stamps_to_fields_mapping: List["_models.PilotFieldsMapping"], **kwargs: Any) -> None: + """ + :keyword pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. + Required. + :paramtype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + """ + super().__init__(**kwargs) + self.pilot_stamps_to_fields_mapping = pilot_stamps_to_fields_mapping + + class GroupInfo(_serialization.Model): """GroupInfo. @@ -886,6 +989,102 @@ def __init__( self.code_challenge_methods_supported = code_challenge_methods_supported +class PilotFieldsMapping(_serialization.Model): + """All the fields that a user can modify on a Pilot (except PilotStamp). + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamp: Pilotstamp. Required. + :vartype pilot_stamp: str + :ivar status_reason: Statusreason. + :vartype status_reason: str + :ivar status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", + "Failed", "Deleted", "Aborted", and "Unknown". + :vartype status: str or ~_generated.models.PilotStatus + :ivar bench_mark: Benchmark. + :vartype bench_mark: float + :ivar destination_site: Destinationsite. + :vartype destination_site: str + :ivar queue: Queue. + :vartype queue: str + :ivar grid_site: Gridsite. + :vartype grid_site: str + :ivar grid_type: Gridtype. + :vartype grid_type: str + :ivar accounting_sent: Accountingsent. + :vartype accounting_sent: bool + :ivar current_job_id: Currentjobid. + :vartype current_job_id: int + """ + + _validation = { + "pilot_stamp": {"required": True}, + } + + _attribute_map = { + "pilot_stamp": {"key": "PilotStamp", "type": "str"}, + "status_reason": {"key": "StatusReason", "type": "str"}, + "status": {"key": "Status", "type": "str"}, + "bench_mark": {"key": "BenchMark", "type": "float"}, + "destination_site": {"key": "DestinationSite", "type": "str"}, + "queue": {"key": "Queue", "type": "str"}, + "grid_site": {"key": "GridSite", "type": "str"}, + "grid_type": {"key": "GridType", "type": "str"}, + "accounting_sent": {"key": "AccountingSent", "type": "bool"}, + "current_job_id": {"key": "CurrentJobID", "type": "int"}, + } + + def __init__( + self, + *, + pilot_stamp: str, + status_reason: Optional[str] = None, + status: Optional[Union[str, "_models.PilotStatus"]] = None, + bench_mark: Optional[float] = None, + destination_site: Optional[str] = None, + queue: Optional[str] = None, + grid_site: Optional[str] = None, + grid_type: Optional[str] = None, + accounting_sent: Optional[bool] = None, + current_job_id: Optional[int] = None, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamp: Pilotstamp. Required. + :paramtype pilot_stamp: str + :keyword status_reason: Statusreason. + :paramtype status_reason: str + :keyword status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", + "Failed", "Deleted", "Aborted", and "Unknown". + :paramtype status: str or ~_generated.models.PilotStatus + :keyword bench_mark: Benchmark. + :paramtype bench_mark: float + :keyword destination_site: Destinationsite. + :paramtype destination_site: str + :keyword queue: Queue. + :paramtype queue: str + :keyword grid_site: Gridsite. + :paramtype grid_site: str + :keyword grid_type: Gridtype. + :paramtype grid_type: str + :keyword accounting_sent: Accountingsent. + :paramtype accounting_sent: bool + :keyword current_job_id: Currentjobid. + :paramtype current_job_id: int + """ + super().__init__(**kwargs) + self.pilot_stamp = pilot_stamp + self.status_reason = status_reason + self.status = status + self.bench_mark = bench_mark + self.destination_site = destination_site + self.queue = queue + self.grid_site = grid_site + self.grid_type = grid_type + self.accounting_sent = accounting_sent + self.current_job_id = current_job_id + + class SandboxDownloadResponse(_serialization.Model): """SandboxDownloadResponse. diff --git a/diracx-client/src/diracx/client/_generated/operations/__init__.py b/diracx-client/src/diracx/client/_generated/operations/__init__.py index 10db0c7a9..be02776fc 100644 --- a/diracx-client/src/diracx/client/_generated/operations/__init__.py +++ b/diracx-client/src/diracx/client/_generated/operations/__init__.py @@ -14,6 +14,7 @@ from ._operations import AuthOperations # type: ignore from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore +from ._operations import PilotsOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -24,6 +25,7 @@ "AuthOperations", "ConfigOperations", "JobsOperations", + "PilotsOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/diracx-client/src/diracx/client/_generated/operations/_operations.py b/diracx-client/src/diracx/client/_generated/operations/_operations.py index b8800ca84..c682d2a3a 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/operations/_operations.py @@ -586,6 +586,124 @@ def build_jobs_submit_jdl_jobs_request(**kwargs: Any) -> HttpRequest: return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) +def build_pilots_add_pilot_stamps_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + +def build_pilots_delete_pilots_request( + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any +) -> HttpRequest: + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + # Construct URL + _url = "/api/pilots/" + + # Construct parameters + if pilot_stamps is not None: + _params["pilot_stamps"] = _SERIALIZER.query("pilot_stamps", pilot_stamps, "[str]") + if age_in_days is not None: + _params["age_in_days"] = _SERIALIZER.query("age_in_days", age_in_days, "int") + if delete_only_aborted is not None: + _params["delete_only_aborted"] = _SERIALIZER.query("delete_only_aborted", delete_only_aborted, "bool") + + return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) + + +def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + # Construct URL + _url = "/api/pilots/metadata" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + + return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) + + +def build_pilots_get_pilot_jobs_request( + *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/jobs" + + # Construct parameters + if pilot_stamp is not None: + _params["pilot_stamp"] = _SERIALIZER.query("pilot_stamp", pilot_stamp, "str") + if job_id is not None: + _params["job_id"] = _SERIALIZER.query("job_id", job_id, "int") + + # Construct headers + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/search" + + # Construct parameters + if page is not None: + _params["page"] = _SERIALIZER.query("page", page, "int") + if per_page is not None: + _params["per_page"] = _SERIALIZER.query("per_page", per_page, "int") + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_pilots_summary_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/summary" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + class WellKnownOperations: """ .. warning:: @@ -2875,3 +2993,579 @@ def submit_jdl_jobs(self, body: Union[List[str], IO[bytes]], **kwargs: Any) -> L return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.Dirac`'s + :attr:`pilots` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + def add_pilot_stamps( + self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + + _request = build_pilots_add_pilot_stamps_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def delete_pilots( # pylint: disable=inconsistent-return-statements + self, + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any + ) -> None: + """Delete Pilots. + + Endpoint to delete a pilot. + + Two features: + + + #. Or you provide pilot_stamps, so you can delete pilots by their stamp + #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. + :paramtype pilot_stamps: list[str] + :keyword age_in_days: The number of days that define the maximum age of pilots to be + deleted.Pilots older than this age will be considered for deletion. Default value is None. + :paramtype age_in_days: int + :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is + 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by + default as True to avoid any mistake.This flag is only used for deletion by time. Default value + is False. + :paramtype delete_only_aborted: bool + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_delete_pilots_request( + pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + def update_pilot_fields( + self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def update_pilot_fields(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def update_pilot_fields( # pylint: disable=inconsistent-return-statements + self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields") + + _request = build_pilots_update_pilot_fields_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @distributed_trace + def get_pilot_jobs( + self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any + ) -> List[int]: + """Get Pilot Jobs. + + Endpoint only for admins, to get jobs of a pilot. + + :keyword pilot_stamp: The stamp of the pilot. Default value is None. + :paramtype pilot_stamp: str + :keyword job_id: The ID of the job. Default value is None. + :paramtype job_id: int + :return: list of int + :rtype: list[int] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[List[int]] = kwargs.pop("cls", None) + + _request = build_pilots_get_pilot_jobs_request( + pilot_stamp=pilot_stamp, + job_id=job_id, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("[int]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + def summary(self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SummaryParams") + + _request = build_pilots_summary_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/diracx-client/src/diracx/client/_generated/operations/_patch.py b/diracx-client/src/diracx/client/_generated/operations/_patch.py index b7b8c67fa..b14e98b84 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_patch.py +++ b/diracx-client/src/diracx/client/_generated/operations/_patch.py @@ -11,10 +11,12 @@ __all__ = [ "AuthOperations", "JobsOperations", + "PilotsOperations", ] # Add all objects you want publicly available to users at this package level from ...patches.auth.sync import AuthOperations from ...patches.jobs.sync import JobsOperations +from ...patches.pilots.sync import PilotsOperations def patch_sdk(): diff --git a/diracx-client/src/diracx/client/patches/pilots/aio.py b/diracx-client/src/diracx/client/patches/pilots/aio.py new file mode 100644 index 000000000..ac533a67c --- /dev/null +++ b/diracx-client/src/diracx/client/patches/pilots/aio.py @@ -0,0 +1,53 @@ +"""Patches for the autorest-generated pilots client. + +This file can be used to customize the generated code for the pilots client. +When adding new classes to this file, make sure to also add them to the +__all__ list in the corresponding file in the patches directory. +""" + +from __future__ import annotations + +__all__ = [ + "PilotsOperations", +] + +from typing import Any, Unpack + +from azure.core.tracing.decorator_async import distributed_trace_async + +from ..._generated.aio.operations._operations import PilotsOperations as _PilotsOperations +from .common import ( + make_search_body, + make_summary_body, + make_add_pilot_stamps_body, + make_update_pilot_fields_body, + SearchKwargs, + SummaryKwargs, + AddPilotStampsKwargs, + UpdatePilotFieldsKwargs +) + +# We're intentionally ignoring overrides here because we want to change the interface. +# mypy: disable-error-code=override + + +class PilotsOperations(_PilotsOperations): + @distributed_trace_async + async def search(self, **kwargs: Unpack[SearchKwargs]) -> list[dict[str, Any]]: + """TODO""" + return await super().search(**make_search_body(**kwargs)) + + @distributed_trace_async + async def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]: + """TODO""" + return await super().summary(**make_summary_body(**kwargs)) + + @distributed_trace_async + async def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> None: + """TODO""" + return await super().add_pilot_stamps(**make_add_pilot_stamps_body(**kwargs)) + + @distributed_trace_async + async def update_pilot_fields(self, **kwargs: Unpack[UpdatePilotFieldsKwargs]) -> None: + """TODO""" + return await super().update_pilot_fields(**make_update_pilot_fields_body(**kwargs)) diff --git a/diracx-client/src/diracx/client/patches/pilots/common.py b/diracx-client/src/diracx/client/patches/pilots/common.py new file mode 100644 index 000000000..3f5ec8c4b --- /dev/null +++ b/diracx-client/src/diracx/client/patches/pilots/common.py @@ -0,0 +1,146 @@ +"""Utilities which are common to the sync and async pilots operator patches.""" + +from __future__ import annotations + +__all__ = [ + "make_search_body", + "SearchKwargs", + "make_summary_body", + "SummaryKwargs", + "AddPilotStampsKwargs", + "make_add_pilot_stamps_body", + "UpdatePilotFieldsKwargs", + "make_update_pilot_fields_body" +] + +import json +from io import BytesIO +from typing import Any, IO, TypedDict, Unpack, cast, Literal + +from diracx.core.models import SearchSpec, PilotStatus, PilotFieldsMapping + + +class ResponseExtra(TypedDict, total=False): + content_type: str + headers: dict[str, str] + params: dict[str, str] + cls: Any + + +# ------------------ Search ------------------ +class SearchBody(TypedDict, total=False): + parameters: list[str] | None + search: list[SearchSpec] | None + sort: list[str] | None + + +class SearchExtra(ResponseExtra, total=False): + page: int + per_page: int + + +class SearchKwargs(SearchBody, SearchExtra): ... + + +class UnderlyingSearchArgs(ResponseExtra, total=False): + # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite + # the code being generated to support IO[bytes] | bytes. + body: IO[bytes] + + +def make_search_body(**kwargs: Unpack[SearchKwargs]) -> UnderlyingSearchArgs: + body: SearchBody = {} + for key in SearchBody.__optional_keys__: + if key not in kwargs: + continue + key = cast(Literal["parameters", "search", "sort"], key) + value = kwargs.pop(key) + if value is not None: + body[key] = value + result: UnderlyingSearchArgs = {"body": BytesIO(json.dumps(body).encode("utf-8"))} + result.update(cast(SearchExtra, kwargs)) + return result + +# ------------------ Summary ------------------ + +class SummaryBody(TypedDict, total=False): + grouping: list[str] + search: list[str] + + +class SummaryKwargs(SummaryBody, ResponseExtra): ... + + +class UnderlyingSummaryArgs(ResponseExtra, total=False): + # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite + # the code being generated to support IO[bytes] | bytes. + body: IO[bytes] + + +def make_summary_body(**kwargs: Unpack[SummaryKwargs]) -> UnderlyingSummaryArgs: + body: SummaryBody = {} + for key in SummaryBody.__optional_keys__: + if key not in kwargs: + continue + key = cast(Literal["grouping", "search"], key) + value = kwargs.pop(key) + if value is not None: + body[key] = value + result: UnderlyingSummaryArgs = {"body": BytesIO(json.dumps(body).encode("utf-8"))} + result.update(cast(ResponseExtra, kwargs)) + return result + +# ------------------ AddPilotStamps ------------------ + +class AddPilotStampsBody(TypedDict, total=False): + pilot_stamps: list[str] + grid_type: str + grid_site: str + pilot_references: dict[str, str] + pilot_status: PilotStatus + vo: str + +class AddPilotStampsKwargs(AddPilotStampsBody, ResponseExtra): ... + +class UnderlyingAddPilotStampsArgs(ResponseExtra, total=False): + # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite + # the code being generated to support IO[bytes] | bytes. + body: IO[bytes] + +def make_add_pilot_stamps_body(**kwargs: Unpack[AddPilotStampsKwargs]) -> UnderlyingAddPilotStampsArgs: + body: AddPilotStampsBody = {} + for key in AddPilotStampsBody.__optional_keys__: + if key not in kwargs: + continue + key = cast(Literal["pilot_stamps", "grid_type", "grid_site", "pilot_references", "pilot_status", "vo"], key) + value = kwargs.pop(key) + if value is not None: + body[key] = value + result: UnderlyingAddPilotStampsArgs = {"body": BytesIO(json.dumps(body).encode("utf-8"))} + result.update(cast(ResponseExtra, kwargs)) + return result + +# ------------------ UpdatePilotFields ------------------ + +class UpdatePilotFieldsBody(TypedDict, total=False): + pilot_stamps_to_fields_mapping: list[PilotFieldsMapping] + +class UpdatePilotFieldsKwargs(UpdatePilotFieldsBody, ResponseExtra): ... + +class UnderlyingUpdatePilotFields(ResponseExtra, total=False): + # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite + # the code being generated to support IO[bytes] | bytes. + body: IO[bytes] + +def make_update_pilot_fields_body(**kwargs: Unpack[UpdatePilotFieldsKwargs]) -> UnderlyingUpdatePilotFields: + body: UpdatePilotFieldsBody = {} + for key in UpdatePilotFieldsBody.__optional_keys__: + if key not in kwargs: + continue + key = cast(Literal["pilot_stamps_to_fields_mapping"], key) + value = kwargs.pop(key) + if value is not None: + body[key] = value + result: UnderlyingUpdatePilotFields = {"body": BytesIO(json.dumps(body).encode("utf-8"))} + result.update(cast(ResponseExtra, kwargs)) + return result diff --git a/diracx-client/src/diracx/client/patches/pilots/sync.py b/diracx-client/src/diracx/client/patches/pilots/sync.py new file mode 100644 index 000000000..744cee161 --- /dev/null +++ b/diracx-client/src/diracx/client/patches/pilots/sync.py @@ -0,0 +1,53 @@ +"""Patches for the autorest-generated pilots client. + +This file can be used to customize the generated code for the pilots client. +When adding new classes to this file, make sure to also add them to the +__all__ list in the corresponding file in the patches directory. +""" + +from __future__ import annotations + +__all__ = [ + "PilotsOperations", +] + +from typing import Any, Unpack + +from azure.core.tracing.decorator import distributed_trace + +from ..._generated.operations._operations import PilotsOperations as _PilotsOperations +from .common import ( + make_search_body, + make_summary_body, + make_add_pilot_stamps_body, + make_update_pilot_fields_body, + SearchKwargs, + SummaryKwargs, + AddPilotStampsKwargs, + UpdatePilotFieldsKwargs +) + +# We're intentionally ignoring overrides here because we want to change the interface. +# mypy: disable-error-code=override + + +class PilotsOperations(_PilotsOperations): + @distributed_trace + def search(self, **kwargs: Unpack[SearchKwargs]) -> list[dict[str, Any]]: + """TODO""" + return super().search(**make_search_body(**kwargs)) + + @distributed_trace + def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]: + """TODO""" + return super().summary(**make_summary_body(**kwargs)) + + @distributed_trace + def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> None: + """TODO""" + return super().add_pilot_stamps(**make_add_pilot_stamps_body(**kwargs)) + + @distributed_trace + def update_pilot_fields(self, **kwargs: Unpack[UpdatePilotFieldsKwargs]) -> None: + """TODO""" + return super().update_pilot_fields(**make_update_pilot_fields_body(**kwargs)) diff --git a/diracx-core/src/diracx/core/exceptions.py b/diracx-core/src/diracx/core/exceptions.py index 54d7c240d..19d8d5a41 100644 --- a/diracx-core/src/diracx/core/exceptions.py +++ b/diracx-core/src/diracx/core/exceptions.py @@ -15,6 +15,7 @@ class DiracError(RuntimeError): def __init__(self, detail: str = "Unknown"): self.detail = detail + super().__init__(detail) class AuthorizationError(DiracError): ... @@ -49,19 +50,19 @@ class InvalidQueryError(DiracError): class TokenNotFoundError(DiracError): - def __init__(self, jti: str, detail: str | None = None): + def __init__(self, jti: str, detail: str = ""): self.jti: str = jti super().__init__(f"Token {jti} not found" + (f" ({detail})" if detail else "")) class JobNotFoundError(DiracError): - def __init__(self, job_id: int, detail: str | None = None): + def __init__(self, job_id: int, detail: str = ""): self.job_id: int = job_id super().__init__(f"Job {job_id} not found" + (f" ({detail})" if detail else "")) class SandboxNotFoundError(DiracError): - def __init__(self, pfn: str, se_name: str, detail: str | None = None): + def __init__(self, pfn: str, se_name: str, detail: str = ""): self.pfn: str = pfn self.se_name: str = se_name super().__init__( @@ -71,7 +72,7 @@ def __init__(self, pfn: str, se_name: str, detail: str | None = None): class SandboxAlreadyAssignedError(DiracError): - def __init__(self, pfn: str, se_name: str, detail: str | None = None): + def __init__(self, pfn: str, se_name: str, detail: str = ""): self.pfn: str = pfn self.se_name: str = se_name super().__init__( @@ -81,7 +82,7 @@ def __init__(self, pfn: str, se_name: str, detail: str | None = None): class SandboxAlreadyInsertedError(DiracError): - def __init__(self, pfn: str, se_name: str, detail: str | None = None): + def __init__(self, pfn: str, se_name: str, detail: str = ""): self.pfn: str = pfn self.se_name: str = se_name super().__init__( @@ -91,7 +92,7 @@ def __init__(self, pfn: str, se_name: str, detail: str | None = None): class JobError(DiracError): - def __init__(self, job_id, detail: str | None = None): + def __init__(self, job_id, detail: str = ""): self.job_id: int = job_id super().__init__( f"Error concerning job {job_id}" + (f" ({detail})" if detail else "") @@ -100,3 +101,15 @@ def __init__(self, job_id, detail: str | None = None): class NotReadyError(DiracError): """Tried to access a value which is asynchronously loaded but not yet available.""" + + +class PilotNotFoundError(DiracError): + """At least one pilot is not found.""" + + +class PilotAlreadyExistsError(DiracError): + """At least one pilot already exists, we avoid collitions.""" + + +class PilotAlreadyAssociatedWithJobError(DiracError): + """We can't associate a pilot with the same job twice.""" diff --git a/diracx-core/src/diracx/core/models.py b/diracx-core/src/diracx/core/models.py index bacecd5ad..18144fc38 100644 --- a/diracx-core/src/diracx/core/models.py +++ b/diracx-core/src/diracx/core/models.py @@ -7,7 +7,7 @@ from datetime import datetime from enum import StrEnum -from typing import Literal +from typing import Literal, Optional from pydantic import BaseModel, Field from typing_extensions import TypedDict @@ -31,7 +31,7 @@ class VectorSearchOperator(StrEnum): class ScalarSearchSpec(TypedDict): parameter: str operator: ScalarSearchOperator - value: str | int + value: str | int | datetime class VectorSearchSpec(TypedDict): @@ -325,3 +325,37 @@ class JobCommand(BaseModel): job_id: int command: Literal["Kill"] arguments: str | None = None + + +class PilotFieldsMapping(BaseModel, extra="forbid"): + """All the fields that a user can modify on a Pilot (except PilotStamp).""" + + PilotStamp: str + StatusReason: Optional[str] = None + Status: Optional[PilotStatus] = None + BenchMark: Optional[float] = None + DestinationSite: Optional[str] = None + Queue: Optional[str] = None + GridSite: Optional[str] = None + GridType: Optional[str] = None + AccountingSent: Optional[bool] = None + CurrentJobID: Optional[int] = None + + +class PilotStatus(StrEnum): + #: The pilot has been generated and is transferred to a remote site: + SUBMITTED = "Submitted" + #: The pilot is waiting for a computing resource in a batch queue: + WAITING = "Waiting" + #: The pilot is running a payload on a worker node: + RUNNING = "Running" + #: The pilot finished its execution: + DONE = "Done" + #: The pilot execution failed: + FAILED = "Failed" + #: The pilot was deleted: + DELETED = "Deleted" + #: The pilot execution was aborted: + ABORTED = "Aborted" + #: Cannot get information about the pilot status: + UNKNOWN = "Unknown" diff --git a/diracx-db/src/diracx/db/sql/__init__.py b/diracx-db/src/diracx/db/sql/__init__.py index 3be3af8a3..e2f141ad5 100644 --- a/diracx-db/src/diracx/db/sql/__init__.py +++ b/diracx-db/src/diracx/db/sql/__init__.py @@ -12,6 +12,6 @@ from .auth.db import AuthDB from .job.db import JobDB from .job_logging.db import JobLoggingDB -from .pilot_agents.db import PilotAgentsDB +from .pilots.db import PilotAgentsDB from .sandbox_metadata.db import SandboxMetadataDB from .task_queue.db import TaskQueueDB diff --git a/diracx-db/src/diracx/db/sql/dummy/db.py b/diracx-db/src/diracx/db/sql/dummy/db.py index 5735b43bb..12b0d719c 100644 --- a/diracx-db/src/diracx/db/sql/dummy/db.py +++ b/diracx-db/src/diracx/db/sql/dummy/db.py @@ -3,6 +3,7 @@ from sqlalchemy import insert from uuid_utils import UUID +from diracx.core.models import SearchSpec from diracx.db.sql.utils import BaseSQLDB from .schema import Base as DummyDBBase @@ -21,8 +22,11 @@ class DummyDB(BaseSQLDB): # This needs to be here for the BaseSQLDB to create the engine metadata = DummyDBBase.metadata - async def summary(self, group_by, search) -> list[dict[str, str | int]]: - return await self._summary(Cars, group_by, search) + async def summary( + self, group_by: list[str], search: list[SearchSpec] + ) -> list[dict[str, str | int]]: + """Get a summary of the pilots.""" + return await self._summary(table=Cars, group_by=group_by, search=search) async def insert_owner(self, name: str) -> int: stmt = insert(Owners).values(name=name) diff --git a/diracx-db/src/diracx/db/sql/job/db.py b/diracx-db/src/diracx/db/sql/job/db.py index 01cdb83a1..40b39f33b 100644 --- a/diracx-db/src/diracx/db/sql/job/db.py +++ b/diracx-db/src/diracx/db/sql/job/db.py @@ -13,8 +13,7 @@ from diracx.core.exceptions import InvalidQueryError from diracx.core.models import JobCommand, SearchSpec, SortSpec -from ..utils import BaseSQLDB, _get_columns -from ..utils.functions import utcnow +from ..utils import BaseSQLDB, _get_columns, utcnow from .schema import ( HeartBeatLoggingInfo, InputData, diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/db.py b/diracx-db/src/diracx/db/sql/pilot_agents/db.py deleted file mode 100644 index 954f081b1..000000000 --- a/diracx-db/src/diracx/db/sql/pilot_agents/db.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -from datetime import datetime, timezone - -from sqlalchemy import insert - -from ..utils import BaseSQLDB -from .schema import PilotAgents, PilotAgentsDBBase - - -class PilotAgentsDB(BaseSQLDB): - """PilotAgentsDB class is a front-end to the PilotAgents Database.""" - - metadata = PilotAgentsDBBase.metadata - - async def add_pilot_references( - self, - pilot_ref: list[str], - vo: str, - grid_type: str = "DIRAC", - pilot_stamps: dict | None = None, - ) -> None: - if pilot_stamps is None: - pilot_stamps = {} - - now = datetime.now(tz=timezone.utc) - - # Prepare the list of dictionaries for bulk insertion - values = [ - { - "PilotJobReference": ref, - "VO": vo, - "GridType": grid_type, - "SubmissionTime": now, - "LastUpdateTime": now, - "Status": "Submitted", - "PilotStamp": pilot_stamps.get(ref, ""), - } - for ref in pilot_ref - ] - - # Insert multiple rows in a single execute call - stmt = insert(PilotAgents).values(values) - await self.conn.execute(stmt) - return diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/__init__.py b/diracx-db/src/diracx/db/sql/pilots/__init__.py similarity index 100% rename from diracx-db/src/diracx/db/sql/pilot_agents/__init__.py rename to diracx-db/src/diracx/db/sql/pilots/__init__.py diff --git a/diracx-db/src/diracx/db/sql/pilots/db.py b/diracx-db/src/diracx/db/sql/pilots/db.py new file mode 100644 index 000000000..0bfb32e07 --- /dev/null +++ b/diracx-db/src/diracx/db/sql/pilots/db.py @@ -0,0 +1,245 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +from sqlalchemy import bindparam +from sqlalchemy.exc import IntegrityError +from sqlalchemy.sql import delete, insert, update + +from diracx.core.exceptions import ( + PilotAlreadyAssociatedWithJobError, + PilotNotFoundError, +) +from diracx.core.models import ( + PilotFieldsMapping, + PilotStatus, + SearchSpec, + SortSpec, +) + +from ..utils import ( + BaseSQLDB, +) +from .schema import ( + JobToPilotMapping, + PilotAgents, + PilotAgentsDBBase, + PilotOutput, +) + + +class PilotAgentsDB(BaseSQLDB): + """PilotAgentsDB class is a front-end to the PilotAgents Database.""" + + metadata = PilotAgentsDBBase.metadata + + # ----------------------------- Insert Functions ----------------------------- + + async def add_pilots( + self, + pilot_stamps: list[str], + vo: str, + grid_type: str = "DIRAC", + grid_site: str = "Unknown", + destination_site: str = "NotAssigned", + pilot_references: dict[str, str] | None = None, + status: str = PilotStatus.SUBMITTED, + ): + """Bulk add pilots in the DB. + + If we can't find a pilot_reference associated with a stamp, we take the stamp by default. + """ + if pilot_references is None: + pilot_references = {} + + now = datetime.now(tz=timezone.utc) + + # Prepare the list of dictionaries for bulk insertion + values = [ + { + "PilotJobReference": pilot_references.get(stamp, stamp), + "VO": vo, + "GridType": grid_type, + "GridSite": grid_site, + "DestinationSite": destination_site, + "SubmissionTime": now, + "LastUpdateTime": now, + "Status": status, + "PilotStamp": stamp, + } + for stamp in pilot_stamps + ] + + # Insert multiple rows in a single execute call and use 'returning' to get primary keys + stmt = insert(PilotAgents).values(values) # Assuming 'id' is the primary key + + await self.conn.execute(stmt) + + async def add_jobs_to_pilot(self, job_to_pilot_mapping: list[dict[str, Any]]): + """Associate a pilot with jobs. + + job_to_pilot_mapping format: + ```py + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + ] + ``` + + Raises: + - PilotNotFoundError if a pilot_id is not associated with a pilot. + - PilotAlreadyAssociatedWithJobError if the pilot is already associated with one of the given jobs. + - NotImplementedError if the integrity error is not caught. + + **Important note**: We assume that a job exists. + + """ + # Insert multiple rows in a single execute call + stmt = insert(JobToPilotMapping).values(job_to_pilot_mapping) + + try: + await self.conn.execute(stmt) + except IntegrityError as e: + if "foreign key" in str(e.orig).lower(): + raise PilotNotFoundError( + detail="at least one of these pilots do not exist", + ) from e + + if ( + "duplicate entry" in str(e.orig).lower() + or "unique constraint" in str(e.orig).lower() + ): + raise PilotAlreadyAssociatedWithJobError( + detail="at least one of these pilots is already associated with a given job." + ) from e + + # Other errors to catch + raise NotImplementedError( + "Engine Specific error not caught" + str(e) + ) from e + + # ----------------------------- Delete Functions ----------------------------- + + async def delete_pilots(self, pilot_ids: list[int]): + """Destructive function. Delete pilots.""" + stmt = delete(PilotAgents).where(PilotAgents.pilot_id.in_(pilot_ids)) + + await self.conn.execute(stmt) + + async def remove_jobs_from_pilots(self, pilot_ids: list[int]): + """Destructive function. De-associate jobs and pilots.""" + stmt = delete(JobToPilotMapping).where( + JobToPilotMapping.pilot_id.in_(pilot_ids) + ) + + await self.conn.execute(stmt) + + async def delete_pilot_logs(self, pilot_ids: list[int]): + """Destructive function. Remove logs from pilots.""" + stmt = delete(PilotOutput).where(PilotOutput.pilot_id.in_(pilot_ids)) + + await self.conn.execute(stmt) + + # ----------------------------- Update Functions ----------------------------- + + async def update_pilot_fields( + self, pilot_stamps_to_fields_mapping: list[PilotFieldsMapping] + ): + """Bulk update pilots with a mapping. + + pilot_stamps_to_fields_mapping format: + ```py + [ + { + "PilotStamp": pilot_stamp, + "BenchMark": bench_mark, + "StatusReason": pilot_reason, + "AccountingSent": accounting_sent, + "Status": status, + "CurrentJobID": current_job_id, + "Queue": queue, + ... + } + ] + ``` + + The mapping helps to update multiple fields at a time. + + Raises PilotNotFoundError if one of the pilots is not found. + """ + stmt = ( + update(PilotAgents) + .where(PilotAgents.pilot_stamp == bindparam("b_pilot_stamp")) + .values( + { + key: bindparam(key) + for key in pilot_stamps_to_fields_mapping[0] + .model_dump(exclude_none=True) + .keys() + if key != "PilotStamp" + } + ) + ) + + values = [ + { + **{"b_pilot_stamp": mapping.PilotStamp}, + **mapping.model_dump(exclude={"PilotStamp"}, exclude_none=True), + } + for mapping in pilot_stamps_to_fields_mapping + ] + + res = await self.conn.execute(stmt, values) + + if res.rowcount != len(pilot_stamps_to_fields_mapping): + raise PilotNotFoundError("at least one of the given pilot does not exist.") + + # ----------------------------- Search Functions ----------------------------- + + async def search_pilots( + self, + parameters: list[str] | None, + search: list[SearchSpec], + sorts: list[SortSpec], + *, + distinct: bool = False, + per_page: int = 100, + page: int | None = None, + ) -> tuple[int, list[dict[Any, Any]]]: + """Search for pilot information in the database.""" + return await self._search( + table=PilotAgents, + parameters=parameters, + search=search, + sorts=sorts, + distinct=distinct, + per_page=per_page, + page=page, + ) + + async def search_pilot_to_job_mapping( + self, + parameters: list[str] | None, + search: list[SearchSpec], + sorts: list[SortSpec], + *, + distinct: bool = False, + per_page: int = 100, + page: int | None = None, + ) -> tuple[int, list[dict[Any, Any]]]: + """Search for jobs that are associated with pilots.""" + return await self._search( + table=JobToPilotMapping, + parameters=parameters, + search=search, + sorts=sorts, + distinct=distinct, + per_page=per_page, + page=page, + ) + + async def pilot_summary( + self, group_by: list[str], search: list[SearchSpec] + ) -> list[dict[str, str | int]]: + """Get a summary of the pilots.""" + return await self._summary(table=PilotAgents, group_by=group_by, search=search) diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py b/diracx-db/src/diracx/db/sql/pilots/schema.py similarity index 92% rename from diracx-db/src/diracx/db/sql/pilot_agents/schema.py rename to diracx-db/src/diracx/db/sql/pilots/schema.py index bff7c460c..af087f1f8 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py +++ b/diracx-db/src/diracx/db/sql/pilots/schema.py @@ -10,6 +10,8 @@ ) from sqlalchemy.orm import declarative_base +from diracx.core.models import PilotStatus + from ..utils import Column, EnumBackedBool, NullColumn PilotAgentsDBBase = declarative_base() @@ -31,12 +33,13 @@ class PilotAgents(PilotAgentsDBBase): benchmark = Column("BenchMark", Double, default=0.0) submission_time = NullColumn("SubmissionTime", DateTime) last_update_time = NullColumn("LastUpdateTime", DateTime) - status = Column("Status", String(32), default="Unknown") + status = Column("Status", String(32), default=PilotStatus.UNKNOWN) status_reason = Column("StatusReason", String(255), default="Unknown") accounting_sent = Column("AccountingSent", EnumBackedBool(), default=False) __table_args__ = ( Index("PilotJobReference", "PilotJobReference"), + Index("PilotStamp", "PilotStamp"), Index("Status", "Status"), Index("Statuskey", "GridSite", "DestinationSite", "Status"), ) diff --git a/diracx-db/src/diracx/db/sql/utils/__init__.py b/diracx-db/src/diracx/db/sql/utils/__init__.py index 5cbb31b3f..53b3f3c96 100644 --- a/diracx-db/src/diracx/db/sql/utils/__init__.py +++ b/diracx-db/src/diracx/db/sql/utils/__init__.py @@ -7,21 +7,25 @@ apply_search_filters, apply_sort_constraints, ) -from .functions import hash, substract_date, utcnow +from .functions import ( + hash, + substract_date, + utcnow, +) from .types import Column, DateNowColumn, EnumBackedBool, EnumColumn, NullColumn __all__ = ( "_get_columns", - "utcnow", + "apply_search_filters", + "apply_sort_constraints", + "BaseSQLDB", "Column", - "NullColumn", "DateNowColumn", - "BaseSQLDB", "EnumBackedBool", "EnumColumn", - "apply_search_filters", - "apply_sort_constraints", - "substract_date", "hash", + "NullColumn", + "substract_date", "SQLDBUnavailableError", + "utcnow", ) diff --git a/diracx-db/tests/pilot_agents/test_pilot_agents_db.py b/diracx-db/tests/pilot_agents/test_pilot_agents_db.py deleted file mode 100644 index 3ca989885..000000000 --- a/diracx-db/tests/pilot_agents/test_pilot_agents_db.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -import pytest - -from diracx.db.sql.pilot_agents.db import PilotAgentsDB - - -@pytest.fixture -async def pilot_agents_db(tmp_path) -> PilotAgentsDB: - agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") - async with agents_db.engine_context(): - async with agents_db.engine.begin() as conn: - await conn.run_sync(agents_db.metadata.create_all) - yield agents_db - - -async def test_insert_and_select(pilot_agents_db: PilotAgentsDB): - async with pilot_agents_db as pilot_agents_db: - # Add a pilot reference - refs = [f"ref_{i}" for i in range(10)] - stamps = [f"stamp_{i}" for i in range(10)] - stamp_dict = dict(zip(refs, stamps)) - - await pilot_agents_db.add_pilot_references( - refs, "test_vo", grid_type="DIRAC", pilot_stamps=stamp_dict - ) - - await pilot_agents_db.add_pilot_references( - refs, "test_vo", grid_type="DIRAC", pilot_stamps=None - ) diff --git a/diracx-db/tests/pilot_agents/__init__.py b/diracx-db/tests/pilots/__init__.py similarity index 100% rename from diracx-db/tests/pilot_agents/__init__.py rename to diracx-db/tests/pilots/__init__.py diff --git a/diracx-db/tests/pilots/test_pilot_management.py b/diracx-db/tests/pilots/test_pilot_management.py new file mode 100644 index 000000000..1e7397b39 --- /dev/null +++ b/diracx-db/tests/pilots/test_pilot_management.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest + +from diracx.core.exceptions import ( + PilotAlreadyAssociatedWithJobError, +) +from diracx.core.models import ( + PilotFieldsMapping, + PilotStatus, +) +from diracx.db.sql.pilots.db import PilotAgentsDB + +from .utils import ( + add_stamps, # noqa: F401 + create_old_pilots_environment, # noqa: F401 + create_timed_pilots, # noqa: F401 + get_pilot_jobs_ids_by_pilot_id, + get_pilots_by_stamp, +) + +MAIN_VO = "lhcb" +N = 100 + + +@pytest.fixture +async def pilot_db(tmp_path): + agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") + async with agents_db.engine_context(): + async with agents_db.engine.begin() as conn: + await conn.run_sync(agents_db.metadata.create_all) + yield agents_db + + +@pytest.mark.asyncio +async def test_insert_and_select(pilot_db: PilotAgentsDB): + async with pilot_db as pilot_db: + # Add pilots + refs = [f"ref_{i}" for i in range(10)] + stamps = [f"stamp_{i}" for i in range(10)] + pilot_references = dict(zip(stamps, refs)) + + await pilot_db.add_pilots( + stamps, MAIN_VO, grid_type="DIRAC", pilot_references=pilot_references + ) + + # Accept duplicates because it is checked by the logic + await pilot_db.add_pilots( + stamps, MAIN_VO, grid_type="DIRAC", pilot_references=None + ) + + +@pytest.mark.asyncio +async def test_insert_and_delete(pilot_db: PilotAgentsDB): + async with pilot_db as pilot_db: + # Add pilots + refs = [f"ref_{i}" for i in range(2)] + stamps = [f"stamp_{i}" for i in range(2)] + pilot_references = dict(zip(stamps, refs)) + + await pilot_db.add_pilots( + stamps, MAIN_VO, grid_type="DIRAC", pilot_references=pilot_references + ) + + # Works, the pilots exists + res = await get_pilots_by_stamp(pilot_db, [stamps[0]]) + await get_pilots_by_stamp(pilot_db, [stamps[0]]) + + # We delete the first pilot + await pilot_db.delete_pilots([res[0]["PilotID"]]) + + # We get the 2nd pilot that is not delete (no error) + await get_pilots_by_stamp(pilot_db, [stamps[1]]) + # We get the 1st pilot that is delete (error) + + assert not await get_pilots_by_stamp(pilot_db, [stamps[0]]) + + +@pytest.mark.asyncio +async def test_insert_and_select_single_then_modify(pilot_db: PilotAgentsDB): + async with pilot_db as pilot_db: + pilot_stamp = "stamp-test" + await pilot_db.add_pilots( + vo=MAIN_VO, + pilot_stamps=[pilot_stamp], + grid_type="grid-type", + ) + + res = await get_pilots_by_stamp(pilot_db, [pilot_stamp]) + assert len(res) == 1 + pilot = res[0] + + # Assert values + assert pilot["VO"] == MAIN_VO + assert pilot["PilotStamp"] == pilot_stamp + assert pilot["GridType"] == "grid-type" + assert pilot["BenchMark"] == 0.0 + assert pilot["Status"] == PilotStatus.SUBMITTED + assert pilot["StatusReason"] == "Unknown" + assert not pilot["AccountingSent"] + + # + # Modify a pilot, then check if every change is done + # + await pilot_db.update_pilot_fields( + [ + PilotFieldsMapping( + PilotStamp=pilot_stamp, + BenchMark=1.0, + StatusReason="NewReason", + AccountingSent=True, + Status=PilotStatus.WAITING, + ) + ] + ) + + res = await get_pilots_by_stamp(pilot_db, [pilot_stamp]) + assert len(res) == 1 + pilot = res[0] + + # Set values + assert pilot["VO"] == MAIN_VO + assert pilot["PilotStamp"] == pilot_stamp + assert pilot["GridType"] == "grid-type" + assert pilot["BenchMark"] == 1.0 + assert pilot["Status"] == PilotStatus.WAITING + assert pilot["StatusReason"] == "NewReason" + assert pilot["AccountingSent"] + + +@pytest.mark.asyncio +async def test_associate_pilot_with_job_and_get_it(pilot_db: PilotAgentsDB): + """We will proceed in few steps. + + 1. Create a pilot + 2. Verify that he is not associated with any job + 3. Associate with jobs + 4. Verify that he is associate with this job + 5. Associate with jobs that he already has and two that he has not + 6. Associate with jobs that he has not, but were involved in a crash + """ + async with pilot_db as pilot_db: + pilot_stamp = "stamp-test" + # Add pilot + await pilot_db.add_pilots( + vo=MAIN_VO, + pilot_stamps=[pilot_stamp], + grid_type="grid-type", + ) + + res = await get_pilots_by_stamp(pilot_db, [pilot_stamp]) + assert len(res) == 1 + pilot = res[0] + pilot_id = pilot["PilotID"] + + # Verify that he has no jobs + assert len(await get_pilot_jobs_ids_by_pilot_id(pilot_db, pilot_id)) == 0 + + now = datetime.now(tz=timezone.utc) + + # Associate pilot with jobs + pilot_jobs = [1, 2, 3] + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + for job_id in pilot_jobs + ] + await pilot_db.add_jobs_to_pilot(job_to_pilot_mapping) + + # Verify that he has all jobs + db_jobs = await get_pilot_jobs_ids_by_pilot_id(pilot_db, pilot_id) + # We test both length and if every job is included if for any reason we have duplicates + assert all(job in db_jobs for job in pilot_jobs) + assert len(pilot_jobs) == len(db_jobs) + + # Associate pilot with a job that he already has, and one that he has not + pilot_jobs = [10, 1, 5] + with pytest.raises(PilotAlreadyAssociatedWithJobError): + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + for job_id in pilot_jobs + ] + await pilot_db.add_jobs_to_pilot(job_to_pilot_mapping) + + # Associate pilot with jobs that he has not, but was previously in an error + # To test that the rollback worked + pilot_jobs = [5, 10] + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + for job_id in pilot_jobs + ] + await pilot_db.add_jobs_to_pilot(job_to_pilot_mapping) diff --git a/diracx-db/tests/pilots/test_query.py b/diracx-db/tests/pilots/test_query.py new file mode 100644 index 000000000..be80f0179 --- /dev/null +++ b/diracx-db/tests/pilots/test_query.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +import pytest + +from diracx.core.exceptions import InvalidQueryError +from diracx.core.models import ( + PilotFieldsMapping, + PilotStatus, + ScalarSearchOperator, + ScalarSearchSpec, + SortDirection, + SortSpec, + VectorSearchOperator, + VectorSearchSpec, +) +from diracx.db.sql.pilots.db import PilotAgentsDB + +MAIN_VO = "lhcb" +N = 100 + + +@pytest.fixture +async def pilot_db(tmp_path): + agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") + async with agents_db.engine_context(): + async with agents_db.engine.begin() as conn: + await conn.run_sync(agents_db.metadata.create_all) + yield agents_db + + +PILOT_REASONS = [ + "I was sick", + "I can't, I have a pony.", + "I was shopping", + "I was sleeping", +] + +PILOT_STATUSES = list(PilotStatus) + + +@pytest.fixture +async def populated_pilot_db(pilot_db): + async with pilot_db as pilot_db: + # Add pilots + refs = [f"ref_{i + 1}" for i in range(N)] + stamps = [f"stamp_{i + 1}" for i in range(N)] + pilot_references = dict(zip(stamps, refs)) + + vo = MAIN_VO + + await pilot_db.add_pilots( + stamps, vo, grid_type="DIRAC", pilot_references=pilot_references + ) + + await pilot_db.update_pilot_fields( + [ + PilotFieldsMapping( + PilotStamp=pilot_stamp, + BenchMark=i**2, + StatusReason=PILOT_REASONS[i % len(PILOT_REASONS)], + AccountingSent=True, + Status=PILOT_STATUSES[i % len(PILOT_STATUSES)], + CurrentJobID=i, + Queue=f"queue_{i}", + ) + for i, pilot_stamp in enumerate(stamps) + ] + ) + + yield pilot_db + + +async def test_search_parameters(populated_pilot_db): + """Test that we can search specific parameters for pilots in the database.""" + async with populated_pilot_db as pilot_db: + # Search a specific parameter: PilotID + total, result = await pilot_db.search_pilots(["PilotID"], [], []) + assert total == N + assert result + for r in result: + assert r.keys() == {"PilotID"} + + # Search a specific parameter: Status + total, result = await pilot_db.search_pilots(["Status"], [], []) + assert total == N + assert result + for r in result: + assert r.keys() == {"Status"} + + # Search for multiple parameters: PilotID, Status + total, result = await pilot_db.search_pilots(["PilotID", "Status"], [], []) + assert total == N + assert result + for r in result: + assert r.keys() == {"PilotID", "Status"} + + # Search for a specific parameter but use distinct: Status + total, result = await pilot_db.search_pilots(["Status"], [], [], distinct=True) + assert total == len(PILOT_STATUSES) + assert result + + # Search for a non-existent parameter: Dummy + with pytest.raises(InvalidQueryError): + total, result = await pilot_db.search_pilots(["Dummy"], [], []) + + +async def test_search_conditions(populated_pilot_db): + """Test that we can search for specific pilots in the database.""" + async with populated_pilot_db as pilot_db: + # Search a specific scalar condition: PilotID eq 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=3 + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 1 + assert result + assert len(result) == 1 + assert result[0]["PilotID"] == 3 + + # Search a specific scalar condition: PilotID lt 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.LESS_THAN, value=3 + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 2 + assert result + assert len(result) == 2 + assert result[0]["PilotID"] == 1 + assert result[1]["PilotID"] == 2 + + # Search a specific scalar condition: PilotID neq 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.NOT_EQUAL, value=3 + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 99 + assert result + assert len(result) == 99 + assert all(r["PilotID"] != 3 for r in result) + + # Search a specific scalar condition: PilotID eq 5873 (does not exist) + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=5873 + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert not result + + # Search a specific vector condition: PilotID in 1,2,3 + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 3] + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 3 + assert result + assert len(result) == 3 + assert all(r["PilotID"] in [1, 2, 3] for r in result) + + # Search a specific vector condition: PilotID in 1,2,5873 (one of them does not exist) + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 5873] + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 2 + assert result + assert len(result) == 2 + assert all(r["PilotID"] in [1, 2] for r in result) + + # Search a specific vector condition: PilotID not in 1,2,3 + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.NOT_IN, values=[1, 2, 3] + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 97 + assert result + assert len(result) == 97 + assert all(r["PilotID"] not in [1, 2, 3] for r in result) + + # Search a specific vector condition: PilotID not in 1,2,5873 (one of them does not exist) + condition = VectorSearchSpec( + parameter="PilotID", + operator=VectorSearchOperator.NOT_IN, + values=[1, 2, 5873], + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 98 + assert result + assert len(result) == 98 + assert all(r["PilotID"] not in [1, 2] for r in result) + + # Search for multiple conditions based on different parameters: PilotID eq 70, PilotID in 4,5,6 + condition1 = ScalarSearchSpec( + parameter="PilotStamp", operator=ScalarSearchOperator.EQUAL, value="stamp_5" + ) + condition2 = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + ) + total, result = await pilot_db.search_pilots([], [condition1, condition2], []) + assert total == 1 + assert result + assert len(result) == 1 + assert result[0]["PilotID"] == 5 + assert result[0]["PilotStamp"] == "stamp_5" + + # Search for multiple conditions based on the same parameter: PilotID eq 70, PilotID in 4,5,6 + condition1 = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=70 + ) + condition2 = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + ) + total, result = await pilot_db.search_pilots([], [condition1, condition2], []) + assert total == 0 + assert not result + + +async def test_search_sorts(populated_pilot_db): + """Test that we can search for pilots in the database and sort the results.""" + async with populated_pilot_db as pilot_db: + # Search and sort by PilotID in ascending order + sort = SortSpec(parameter="PilotID", direction=SortDirection.ASC) + total, result = await pilot_db.search_pilots([], [], [sort]) + assert total == N + assert result + for i, r in enumerate(result): + assert r["PilotID"] == i + 1 + + # Search and sort by PilotID in descending order + sort = SortSpec(parameter="PilotID", direction=SortDirection.DESC) + total, result = await pilot_db.search_pilots([], [], [sort]) + assert total == N + assert result + for i, r in enumerate(result): + assert r["PilotID"] == N - i + + # Search and sort by PilotStamp in ascending order + sort = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) + total, result = await pilot_db.search_pilots([], [], [sort]) + assert total == N + assert result + # Assert that stamp_10 is before stamp_2 because of the lexicographical order + assert result[2]["PilotStamp"] == "stamp_100" + assert result[12]["PilotStamp"] == "stamp_2" + + # Search and sort by PilotStamp in descending order + sort = SortSpec(parameter="PilotStamp", direction=SortDirection.DESC) + total, result = await pilot_db.search_pilots([], [], [sort]) + assert total == N + assert result + # Assert that stamp_10 is before stamp_2 because of the lexicographical order + assert result[97]["PilotStamp"] == "stamp_100" + assert result[87]["PilotStamp"] == "stamp_2" + + # Search and sort by PilotStamp in ascending order and PilotID in descending order + sort1 = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) + sort2 = SortSpec(parameter="PilotID", direction=SortDirection.DESC) + total, result = await pilot_db.search_pilots([], [], [sort1, sort2]) + assert total == N + assert result + assert result[0]["PilotStamp"] == "stamp_1" + assert result[0]["PilotID"] == 1 + assert result[99]["PilotStamp"] == "stamp_99" + assert result[99]["PilotID"] == 99 + + +@pytest.mark.parametrize( + "per_page, page, expected_len, expected_first_id, expect_exception", + [ + (10, 1, 10, 1, None), # Page 1 + (10, 2, 10, 11, None), # Page 2 + (10, 10, 10, 91, None), # Page 10 + (50, 2, 50, 51, None), # Page 2 with 50 per page + (10, 11, 0, None, None), # Page beyond range, should return empty + (10, 0, None, None, InvalidQueryError), # Invalid page + (0, 1, None, None, InvalidQueryError), # Invalid per_page + ], +) +async def test_search_pagination( + populated_pilot_db, + per_page, + page, + expected_len, + expected_first_id, + expect_exception, +): + """Test pagination logic in pilot search.""" + async with populated_pilot_db as pilot_db: + if expect_exception: + with pytest.raises(expect_exception): + await pilot_db.search_pilots([], [], [], per_page=per_page, page=page) + else: + total, result = await pilot_db.search_pilots( + [], [], [], per_page=per_page, page=page + ) + assert total == N + if expected_len == 0: + assert not result + else: + assert result + assert len(result) == expected_len + assert result[0]["PilotID"] == expected_first_id diff --git a/diracx-db/tests/pilots/utils.py b/diracx-db/tests/pilots/utils.py new file mode 100644 index 000000000..793310d0d --- /dev/null +++ b/diracx-db/tests/pilots/utils.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +import pytest +from sqlalchemy import update + +from diracx.core.models import ( + ScalarSearchOperator, + ScalarSearchSpec, + VectorSearchOperator, + VectorSearchSpec, +) +from diracx.db.sql.pilots.db import PilotAgentsDB +from diracx.db.sql.pilots.schema import PilotAgents + +MAIN_VO = "lhcb" +N = 100 + +# ------------ Fetching data ------------ + + +async def get_pilots_by_stamp( + pilot_db: PilotAgentsDB, pilot_stamps: list[str], parameters: list[str] = [] +) -> list[dict[Any, Any]]: + _, pilots = await pilot_db.search_pilots( + parameters=parameters, + search=[ + VectorSearchSpec( + parameter="PilotStamp", + operator=VectorSearchOperator.IN, + values=pilot_stamps, + ) + ], + sorts=[], + distinct=True, + per_page=1000, + ) + + return pilots + + +async def get_pilot_jobs_ids_by_pilot_id( + pilot_db: PilotAgentsDB, pilot_id: int +) -> list[int]: + _, jobs = await pilot_db.search_pilot_to_job_mapping( + parameters=["JobID"], + search=[ + ScalarSearchSpec( + parameter="PilotID", + operator=ScalarSearchOperator.EQUAL, + value=pilot_id, + ) + ], + sorts=[], + distinct=True, + per_page=10000, + ) + + return [job["JobID"] for job in jobs] + + +# ------------ Creating data ------------ + + +@pytest.fixture +async def add_stamps(pilot_db): + async def _add_stamps(start_n=0): + async with pilot_db as db: + # Add pilots + refs = [f"ref_{i}" for i in range(start_n, start_n + N)] + stamps = [f"stamp_{i}" for i in range(start_n, start_n + N)] + pilot_references = dict(zip(stamps, refs)) + + vo = MAIN_VO + + await db.add_pilots( + stamps, vo, grid_type="DIRAC", pilot_references=pilot_references + ) + + return await get_pilots_by_stamp(db, stamps) + + return _add_stamps + + +@pytest.fixture +async def create_timed_pilots(pilot_db, add_stamps): + async def _create_timed_pilots( + old_date: datetime, aborted: bool = False, start_n=0 + ): + # Get pilots + pilots = await add_stamps(start_n) + + async with pilot_db as db: + # Update manually their age + # Collect PilotStamps + pilot_stamps = [pilot["PilotStamp"] for pilot in pilots] + + stmt = ( + update(PilotAgents) + .where(PilotAgents.pilot_stamp.in_(pilot_stamps)) + .values(SubmissionTime=old_date) + ) + + if aborted: + stmt = stmt.values(Status="Aborted") + + res = await db.conn.execute(stmt) + assert res.rowcount == len(pilot_stamps) + + pilots = await get_pilots_by_stamp(db, pilot_stamps) + return pilots + + return _create_timed_pilots + + +@pytest.fixture +async def create_old_pilots_environment(pilot_db, create_timed_pilots): + non_aborted_recent = await create_timed_pilots( + datetime(2025, 1, 1, tzinfo=timezone.utc), False, N + ) + aborted_recent = await create_timed_pilots( + datetime(2025, 1, 1, tzinfo=timezone.utc), True, 2 * N + ) + + aborted_very_old = await create_timed_pilots( + datetime(2003, 3, 10, tzinfo=timezone.utc), True, 3 * N + ) + non_aborted_very_old = await create_timed_pilots( + datetime(2003, 3, 10, tzinfo=timezone.utc), False, 4 * N + ) + + pilot_number = 4 * N + + assert pilot_number == ( + len(non_aborted_recent) + + len(aborted_recent) + + len(aborted_very_old) + + len(non_aborted_very_old) + ) + + # Phase 0. Verify that we have the right environment + async with pilot_db as pilot_db: + # Ensure that we can get every pilot (only get first of each group) + await get_pilots_by_stamp(pilot_db, [non_aborted_recent[0]["PilotStamp"]]) + await get_pilots_by_stamp(pilot_db, [aborted_recent[0]["PilotStamp"]]) + await get_pilots_by_stamp(pilot_db, [aborted_very_old[0]["PilotStamp"]]) + await get_pilots_by_stamp(pilot_db, [non_aborted_very_old[0]["PilotStamp"]]) + + return non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old diff --git a/diracx-db/tests/test_dummy_db.py b/diracx-db/tests/test_dummy_db.py index f94eda5b7..8e324a28e 100644 --- a/diracx-db/tests/test_dummy_db.py +++ b/diracx-db/tests/test_dummy_db.py @@ -149,6 +149,7 @@ async def test_failed_transaction(dummy_db): assert result # This will raise an exception and the transaction will be rolled back + result = await dummy_db.summary(["unexistingfieldraisinganerror"], []) assert result[0]["count"] == 10 diff --git a/diracx-logic/src/diracx/logic/pilots/__init__.py b/diracx-logic/src/diracx/logic/pilots/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/diracx-logic/src/diracx/logic/pilots/management.py b/diracx-logic/src/diracx/logic/pilots/management.py new file mode 100644 index 000000000..a6256a742 --- /dev/null +++ b/diracx-logic/src/diracx/logic/pilots/management.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +from diracx.core.exceptions import PilotAlreadyExistsError, PilotNotFoundError +from diracx.core.models import PilotFieldsMapping +from diracx.db.sql import PilotAgentsDB + +from .query import ( + get_outdated_pilots, + get_pilot_ids_by_stamps, + get_pilot_jobs_ids_by_pilot_id, + get_pilots_by_stamp, +) + + +async def register_new_pilots( + pilot_db: PilotAgentsDB, + pilot_stamps: list[str], + vo: str, + grid_type: str, + grid_site: str, + destination_site: str, + status: str, + pilot_job_references: dict[str, str] | None, +): + # [IMPORTANT] Check unicity of pilot stamps + # If a pilot already exists, we raise an error (transaction will rollback) + existing_pilots = await get_pilots_by_stamp( + pilot_db=pilot_db, pilot_stamps=pilot_stamps + ) + + # If we found pilots from the list, this means some pilots already exist + if len(existing_pilots) > 0: + found_keys = {pilot["PilotStamp"] for pilot in existing_pilots} + + raise PilotAlreadyExistsError( + f"The following pilots already exist: {found_keys}" + ) + + await pilot_db.add_pilots( + pilot_stamps=pilot_stamps, + vo=vo, + grid_type=grid_type, + grid_site=grid_site, + destination_site=destination_site, + pilot_references=pilot_job_references, + status=status, + ) + + +async def delete_pilots( + pilot_db: PilotAgentsDB, + pilot_stamps: list[str] | None = None, + age_in_days: int | None = None, + delete_only_aborted: bool = True, + vo_constraint: str | None = None, +): + if pilot_stamps: + pilot_ids = await get_pilot_ids_by_stamps( + pilot_db=pilot_db, pilot_stamps=pilot_stamps, allow_missing=True + ) + else: + assert age_in_days + assert vo_constraint + + cutoff_date = datetime.now(tz=timezone.utc) - timedelta(days=age_in_days) + + pilots = await get_outdated_pilots( + pilot_db=pilot_db, + cutoff_date=cutoff_date, + only_aborted=delete_only_aborted, + parameters=["PilotID"], + vo_constraint=vo_constraint, + ) + + pilot_ids = [pilot["PilotID"] for pilot in pilots] + + await pilot_db.remove_jobs_from_pilots(pilot_ids) + await pilot_db.delete_pilot_logs(pilot_ids) + await pilot_db.delete_pilots(pilot_ids) + + +async def update_pilots_fields( + pilot_db: PilotAgentsDB, pilot_stamps_to_fields_mapping: list[PilotFieldsMapping] +): + await pilot_db.update_pilot_fields(pilot_stamps_to_fields_mapping) + + +async def add_jobs_to_pilot( + pilot_db: PilotAgentsDB, pilot_stamp: str, job_ids: list[int] +): + pilot_ids = await get_pilot_ids_by_stamps( + pilot_db=pilot_db, pilot_stamps=[pilot_stamp] + ) + pilot_id = pilot_ids[0] + + now = datetime.now(tz=timezone.utc) + + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} for job_id in job_ids + ] + + await pilot_db.add_jobs_to_pilot( + job_to_pilot_mapping=job_to_pilot_mapping, + ) + + +async def get_pilot_jobs_ids_by_stamp( + pilot_db: PilotAgentsDB, pilot_stamp: str +) -> list[int]: + """Fetch pilot jobs by stamp.""" + try: + pilot_ids = await get_pilot_ids_by_stamps( + pilot_db=pilot_db, pilot_stamps=[pilot_stamp] + ) + pilot_id = pilot_ids[0] + except PilotNotFoundError: + return [] + + return await get_pilot_jobs_ids_by_pilot_id(pilot_db=pilot_db, pilot_id=pilot_id) diff --git a/diracx-logic/src/diracx/logic/pilots/query.py b/diracx-logic/src/diracx/logic/pilots/query.py new file mode 100644 index 000000000..b6cf504d7 --- /dev/null +++ b/diracx-logic/src/diracx/logic/pilots/query.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from diracx.core.exceptions import PilotNotFoundError +from diracx.core.models import ( + PilotStatus, + ScalarSearchOperator, + ScalarSearchSpec, + SearchParams, + SearchSpec, + SummaryParams, + VectorSearchOperator, + VectorSearchSpec, +) +from diracx.db.sql import PilotAgentsDB + +MAX_PER_PAGE = 10000 + + +async def search( + pilot_db: PilotAgentsDB, + user_vo: str, + page: int = 1, + per_page: int = 100, + body: SearchParams | None = None, +) -> tuple[int, list[dict[str, Any]]]: + """Retrieve information about jobs.""" + # Apply a limit to per_page to prevent abuse of the API + if per_page > MAX_PER_PAGE: + per_page = MAX_PER_PAGE + + if body is None: + body = SearchParams() + + body.search.append( + ScalarSearchSpec( + parameter="VO", operator=ScalarSearchOperator.EQUAL, value=user_vo + ) + ) + + total, pilots = await pilot_db.search_pilots( + body.parameters, + body.search, + body.sort, + distinct=body.distinct, + page=page, + per_page=per_page, + ) + + return total, pilots + + +async def get_pilots_by_stamp( + pilot_db: PilotAgentsDB, + pilot_stamps: list[str], + parameters: list[str] = [], + allow_missing: bool = True, +) -> list[dict[Any, Any]]: + """Get pilots by their stamp. + + If `allow_missing` is set to False, if a pilot is missing, PilotNotFoundError will be raised. + """ + if parameters: + parameters.append("PilotStamp") + + _, pilots = await pilot_db.search_pilots( + parameters=parameters, + search=[ + VectorSearchSpec( + parameter="PilotStamp", + operator=VectorSearchOperator.IN, + values=pilot_stamps, + ) + ], + sorts=[], + distinct=True, + per_page=MAX_PER_PAGE, + ) + + # allow_missing is set as True by default to mark explicitly when we allow or not + if not allow_missing: + # Custom handling, to see which pilot_stamp does not exist (if so, say which one) + found_keys = {row["PilotStamp"] for row in pilots} + missing = set(pilot_stamps) - found_keys + + if missing: + raise PilotNotFoundError( + detail=str(missing), + ) + + return pilots + + +async def get_pilot_ids_by_stamps( + pilot_db: PilotAgentsDB, pilot_stamps: list[str], allow_missing=False +) -> list[int]: + pilots = await get_pilots_by_stamp( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + parameters=["PilotID"], + allow_missing=allow_missing, + ) + + return [pilot["PilotID"] for pilot in pilots] + + +async def get_pilot_jobs_ids_by_pilot_id( + pilot_db: PilotAgentsDB, pilot_id: int +) -> list[int]: + _, jobs = await pilot_db.search_pilot_to_job_mapping( + parameters=["JobID"], + search=[ + ScalarSearchSpec( + parameter="PilotID", + operator=ScalarSearchOperator.EQUAL, + value=pilot_id, + ) + ], + sorts=[], + distinct=True, + per_page=MAX_PER_PAGE, + ) + + return [job["JobID"] for job in jobs] + + +async def get_pilot_ids_by_job_id(pilot_db: PilotAgentsDB, job_id: int) -> list[int]: + _, pilots = await pilot_db.search_pilot_to_job_mapping( + parameters=["PilotID"], + search=[ + ScalarSearchSpec( + parameter="JobID", + operator=ScalarSearchOperator.EQUAL, + value=job_id, + ) + ], + sorts=[], + distinct=True, + per_page=MAX_PER_PAGE, + ) + + return [pilot["PilotID"] for pilot in pilots] + + +async def get_outdated_pilots( + pilot_db: PilotAgentsDB, + cutoff_date: datetime, + vo_constraint: str, + only_aborted: bool = True, + parameters: list[str] = [], +): + query: list[SearchSpec] = [ + ScalarSearchSpec( + parameter="SubmissionTime", + operator=ScalarSearchOperator.LESS_THAN, + value=cutoff_date, + ), + # Add VO to avoid deleting other VO's pilots + ScalarSearchSpec( + parameter="VO", operator=ScalarSearchOperator.EQUAL, value=vo_constraint + ), + ] + + if only_aborted: + query.append( + ScalarSearchSpec( + parameter="Status", + operator=ScalarSearchOperator.EQUAL, + value=PilotStatus.ABORTED, + ) + ) + + _, pilots = await pilot_db.search_pilots( + parameters=parameters, search=query, sorts=[] + ) + + return pilots + + +async def summary(pilot_db: PilotAgentsDB, body: SummaryParams, vo: str): + """Show information suitable for plotting.""" + body.search.append( + { + "parameter": "VO", + "operator": ScalarSearchOperator.EQUAL, + "value": vo, + } + ) + return await pilot_db.pilot_summary(body.grouping, body.search) diff --git a/diracx-routers/pyproject.toml b/diracx-routers/pyproject.toml index 6f554c74e..2038223ce 100644 --- a/diracx-routers/pyproject.toml +++ b/diracx-routers/pyproject.toml @@ -46,10 +46,12 @@ auth = "diracx.routers.auth:router" config = "diracx.routers.configuration:router" health = "diracx.routers.health:router" jobs = "diracx.routers.jobs:router" +pilots = "diracx.routers.pilots:router" [project.entry-points."diracx.access_policies"] WMSAccessPolicy = "diracx.routers.jobs.access_policies:WMSAccessPolicy" SandboxAccessPolicy = "diracx.routers.jobs.access_policies:SandboxAccessPolicy" +PilotManagementAccessPolicy = "diracx.routers.pilots.access_policies:PilotManagementAccessPolicy" # Minimum version of the client supported [project.entry-points."diracx.min_client_version"] diff --git a/diracx-routers/src/diracx/routers/pilots/__init__.py b/diracx-routers/src/diracx/routers/pilots/__init__.py new file mode 100644 index 000000000..03f9b8422 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/__init__.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import logging + +from ..fastapi_classes import DiracxRouter +from .management import router as management_router +from .query import router as query_router + +logger = logging.getLogger(__name__) + +router = DiracxRouter() +router.include_router(management_router) +router.include_router(query_router) diff --git a/diracx-routers/src/diracx/routers/pilots/access_policies.py b/diracx-routers/src/diracx/routers/pilots/access_policies.py new file mode 100644 index 000000000..61a324f79 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/access_policies.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from collections.abc import Callable +from enum import StrEnum, auto +from typing import Annotated + +from fastapi import Depends, HTTPException, status + +from diracx.core.models import VectorSearchOperator, VectorSearchSpec +from diracx.core.properties import GENERIC_PILOT, SERVICE_ADMINISTRATOR +from diracx.db.sql.job.db import JobDB +from diracx.db.sql.pilots.db import PilotAgentsDB +from diracx.logic.pilots.query import get_pilots_by_stamp +from diracx.routers.access_policies import BaseAccessPolicy +from diracx.routers.utils.users import AuthorizedUserInfo + + +class ActionType(StrEnum): + # Change some pilot fields + MANAGE_PILOTS = auto() + # Read some pilot info + READ_PILOT_FIELDS = auto() + + +class PilotManagementAccessPolicy(BaseAccessPolicy): + """Rules: + * Every user can access data about his VO + * An administrator can modify a pilot. + """ + + @staticmethod + async def policy( + policy_name: str, + user_info: AuthorizedUserInfo, + /, + *, + action: ActionType | None = None, + pilot_db: PilotAgentsDB | None = None, + pilot_stamps: list[str] | None = None, + job_db: JobDB | None = None, + job_ids: list[int] | None = None, + allow_legacy_pilots: bool = False, + ): + assert action, "action is a mandatory parameter" + + # Users can query + # NOTE: Add into queries a VO constraint + # To manage pilots, user have to be an admin + # In some special cases (described with allow_legacy_pilots), we can allow pilots + if action == ActionType.MANAGE_PILOTS: + # To make it clear, we separate + is_an_admin = SERVICE_ADMINISTRATOR in user_info.properties + is_a_pilot_if_allowed = ( + allow_legacy_pilots and GENERIC_PILOT in user_info.properties + ) + + if not is_an_admin and not is_a_pilot_if_allowed: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have the permission to manage pilots.", + ) + + if action == ActionType.READ_PILOT_FIELDS: + if GENERIC_PILOT in user_info.properties: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Pilots can't read other pilots info.", + ) + + # + # Additional checks if job_ids or pilot_stamps are provided + # + + # First, if job_ids are provided, we check who is the owner + if job_db and job_ids: + job_owners = await job_db.summary( + ["Owner", "VO"], + [ + VectorSearchSpec( + parameter="JobID", + operator=VectorSearchOperator.IN, + values=job_ids, + ) + ], + ) + + expected_owner = { + "Owner": user_info.preferred_username, + "VO": user_info.vo, + "count": len(set(job_ids)), + } + # All the jobs belong to the user doing the query + # and all of them are present + if not job_owners == [expected_owner]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have the rights to modify a pilot.", + ) + + # This is for example when we submit pilots, we use the user VO, so no need to verify + if pilot_db and pilot_stamps: + # Else, check its VO + pilots = await get_pilots_by_stamp( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + parameters=["VO"], + allow_missing=True, + ) + + if len(pilots) != len(pilot_stamps): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="At least one pilot does not exist.", + ) + + if not all(pilot["VO"] == user_info.vo for pilot in pilots): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have access to all pilots.", + ) + + +CheckPilotManagementPolicyCallable = Annotated[ + Callable, Depends(PilotManagementAccessPolicy.check) +] diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py new file mode 100644 index 000000000..a383643d1 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -0,0 +1,267 @@ +from __future__ import annotations + +from http import HTTPStatus +from typing import Annotated + +from fastapi import Body, Depends, HTTPException, Query, status + +from diracx.core.exceptions import ( + PilotAlreadyExistsError, +) +from diracx.core.models import ( + PilotFieldsMapping, + PilotStatus, +) +from diracx.core.properties import GENERIC_PILOT, JOB_ADMINISTRATOR +from diracx.logic.pilots.management import ( + delete_pilots as delete_pilots_bl, +) +from diracx.logic.pilots.management import ( + get_pilot_jobs_ids_by_stamp, + register_new_pilots, + update_pilots_fields, +) +from diracx.logic.pilots.query import get_pilot_ids_by_job_id +from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token + +from ..dependencies import JobDB, PilotAgentsDB +from ..fastapi_classes import DiracxRouter +from .access_policies import ( + ActionType, + CheckPilotManagementPolicyCallable, +) + +router = DiracxRouter() + + +@router.post("/") +async def add_pilot_stamps( + pilot_db: PilotAgentsDB, + pilot_stamps: Annotated[ + list[str], + Body(description="List of the pilot stamps we want to add to the db."), + ], + vo: Annotated[str, Body(description="Pilot virtual organization.")], + check_permissions: CheckPilotManagementPolicyCallable, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + grid_type: Annotated[str, Body(description="Grid type of the pilots.")] = "Dirac", + grid_site: Annotated[str, Body(description="Pilots grid site.")] = "Unknown", + destination_site: Annotated[ + str, Body(description="Pilots destination site.") + ] = "NotAssigned", + pilot_references: Annotated[ + dict[str, str] | None, + Body(description="Association of a pilot reference with a pilot stamp."), + ] = None, + pilot_status: Annotated[ + PilotStatus, Body(description="Status of the pilots.") + ] = PilotStatus.SUBMITTED, +): + """Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + """ + # TODO: Verify that grid types, sites, destination sites, etc. are valids + await check_permissions( + action=ActionType.MANAGE_PILOTS, + allow_legacy_pilots=True, # dirac-admin-add-pilot + ) + + # Prevent someone who stole a pilot X509 to create thousands of pilots at a time + # (It would be still able to create thousands of pilots, but slower) + if GENERIC_PILOT in user_info.properties: + if len(pilot_stamps) != 1: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="As a pilot, you can only create yourself.", + ) + + if JOB_ADMINISTRATOR not in user_info.properties: + if not vo == user_info.vo: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You can create pilots only for your VO.", + ) + + try: + await register_new_pilots( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + vo=vo, + grid_type=grid_type, + grid_site=grid_site, + destination_site=destination_site, + pilot_job_references=pilot_references, + status=pilot_status, + ) + except PilotAlreadyExistsError as e: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e + + +@router.delete("/", status_code=HTTPStatus.NO_CONTENT) +async def delete_pilots( + pilot_db: PilotAgentsDB, + check_permissions: CheckPilotManagementPolicyCallable, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + pilot_stamps: Annotated[ + list[str] | None, Query(description="Stamps of the pilots we want to delete.") + ] = None, + age_in_days: Annotated[ + int | None, + Query( + description=( + "The number of days that define the maximum age of pilots to be deleted." + "Pilots older than this age will be considered for deletion." + ) + ), + ] = None, + delete_only_aborted: Annotated[ + bool, + Query( + description=( + "Flag indicating whether to only delete pilots whose status is 'Aborted'." + "If set to True, only pilots with the 'Aborted' status will be deleted." + "It is set by default as True to avoid any mistake." + "This flag is only used for deletion by time." + ) + ), + ] = False, +): + """Endpoint to delete a pilot. + + Two features: + + 1. Or you provide pilot_stamps, so you can delete pilots by their stamp + 2. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. + """ + vo_constraint: str | None = None + + # If we delete by pilot_stamps, we check that we can access them + # Else, we add a constraint to the request, to avoid deleting pilots from another VO + if pilot_stamps: + await check_permissions( + action=ActionType.MANAGE_PILOTS, + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + ) + else: + vo_constraint = user_info.vo + + if not pilot_stamps and not age_in_days: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="pilot_stamps or age_in_days have to be provided.", + ) + + await delete_pilots_bl( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + vo_constraint=vo_constraint, + ) + + +EXAMPLE_UPDATE_FIELDS = { + "Update the BenchMark field": { + "summary": "Update BenchMark", + "description": "Update only the BenchMark for one pilot.", + "value": { + "pilot_stamps_to_fields_mapping": [ + {"PilotStamp": "the_pilot_stamp", "BenchMark": 1.0} + ] + }, + }, + "Update multiple statuses": { + "summary": "Update multiple pilots", + "description": "Update multiple pilots statuses.", + "value": { + "pilot_stamps_to_fields_mapping": [ + {"PilotStamp": "the_first_pilot_stamp", "Status": "Waiting"}, + {"PilotStamp": "the_second_pilot_stamp", "Status": "Waiting"}, + ] + }, + }, +} + + +@router.patch("/metadata", status_code=HTTPStatus.NO_CONTENT) +async def update_pilot_fields( + pilot_stamps_to_fields_mapping: Annotated[ + list[PilotFieldsMapping], + Body( + description="(pilot_stamp, pilot_fields) mapping to change.", + embed=True, + openapi_examples=EXAMPLE_UPDATE_FIELDS, # type: ignore + ), + ], + pilot_db: PilotAgentsDB, + check_permissions: CheckPilotManagementPolicyCallable, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], +): + """Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + """ + # Ensures stamps validity + pilot_stamps = [mapping.PilotStamp for mapping in pilot_stamps_to_fields_mapping] + await check_permissions( + action=ActionType.MANAGE_PILOTS, + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + allow_legacy_pilots=True, # dirac-admin-add-pilot + ) + + # Prevent someone who stole a pilot X509 to modify thousands of pilots at a time + # (It would be still able to modify thousands of pilots, but slower) + # We are not able to affirm that this pilot modifies itself + if GENERIC_PILOT in user_info.properties: + if len(pilot_stamps) != 1: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="As a pilot, you can only modify yourself.", + ) + + await update_pilots_fields( + pilot_db=pilot_db, + pilot_stamps_to_fields_mapping=pilot_stamps_to_fields_mapping, + ) + + +@router.get("/jobs") +async def get_pilot_jobs( + pilot_db: PilotAgentsDB, + job_db: JobDB, + check_permissions: CheckPilotManagementPolicyCallable, + pilot_stamp: Annotated[ + str | None, Query(description="The stamp of the pilot.") + ] = None, + job_id: Annotated[int | None, Query(description="The ID of the job.")] = None, +) -> list[int]: + """Endpoint only for admins, to get jobs of a pilot.""" + if pilot_stamp: + # Check VO + await check_permissions( + action=ActionType.READ_PILOT_FIELDS, + pilot_db=pilot_db, + pilot_stamps=[pilot_stamp], + ) + + return await get_pilot_jobs_ids_by_stamp( + pilot_db=pilot_db, + pilot_stamp=pilot_stamp, + ) + elif job_id: + # Check job owner + await check_permissions( + action=ActionType.READ_PILOT_FIELDS, job_db=job_db, job_ids=[job_id] + ) + + return await get_pilot_ids_by_job_id(pilot_db=pilot_db, job_id=job_id) + + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="You must provide either pilot_stamp or job_id", + ) diff --git a/diracx-routers/src/diracx/routers/pilots/query.py b/diracx-routers/src/diracx/routers/pilots/query.py new file mode 100644 index 000000000..56655b46c --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/query.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +from http import HTTPStatus +from typing import Annotated, Any + +from fastapi import Body, Depends, Response + +from diracx.core.models import SearchParams, SummaryParams +from diracx.logic.pilots.query import search as search_bl +from diracx.logic.pilots.query import summary as summary_bl + +from ..dependencies import PilotAgentsDB +from ..fastapi_classes import DiracxRouter +from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token +from .access_policies import ( + ActionType, + CheckPilotManagementPolicyCallable, +) + +router = DiracxRouter() + +EXAMPLE_SEARCHES = { + "Show all": { + "summary": "Show all", + "description": "Shows all pilots the current user has access to.", + "value": {}, + }, + "A specific pilot": { + "summary": "A specific pilot", + "description": "Search for a specific pilot by ID", + "value": {"search": [{"parameter": "PilotID", "operator": "eq", "value": "5"}]}, + }, + "Get ordered pilot statuses": { + "summary": "Get ordered pilot statuses", + "description": "Get only pilot statuses for specific pilots, ordered by status", + "value": { + "parameters": ["PilotID", "Status"], + "search": [ + {"parameter": "PilotID", "operator": "in", "values": ["6", "2", "3"]} + ], + "sort": [{"parameter": "PilotID", "direction": "asc"}], + }, + }, +} + + +EXAMPLE_RESPONSES: dict[int | str, dict[str, Any]] = { + 200: { + "description": "List of matching results", + "content": { + "application/json": { + "example": [ + { + "PilotID": 3, + "SubmissionTime": "2023-05-25T07:03:35.602654", + "LastUpdateTime": "2023-05-25T07:03:35.602656", + "Status": "RUNNING", + "GridType": "Dirac", + "BenchMark": 1.0, + }, + { + "PilotID": 5, + "SubmissionTime": "2023-06-25T07:03:35.602654", + "LastUpdateTime": "2023-07-25T07:03:35.602652", + "Status": "RUNNING", + "GridType": "Dirac", + "BenchMark": 63.1, + }, + ] + } + }, + }, + 206: { + "description": "Partial Content. Only a part of the requested range could be served.", + "headers": { + "Content-Range": { + "description": "The range of pilots returned in this response", + "schema": {"type": "string", "example": "pilots 0-1/4"}, + } + }, + "model": list[dict[str, Any]], + "content": { + "application/json": { + "example": [ + { + "PilotID": 3, + "SubmissionTime": "2023-05-25T07:03:35.602654", + "LastUpdateTime": "2023-05-25T07:03:35.602656", + "Status": "RUNNING", + "GridType": "Dirac", + "BenchMark": 1.0, + }, + { + "PilotID": 5, + "SubmissionTime": "2023-06-25T07:03:35.602654", + "LastUpdateTime": "2023-07-25T07:03:35.602652", + "Status": "RUNNING", + "GridType": "Dirac", + "BenchMark": 63.1, + }, + ] + } + }, + }, +} + + +@router.post("/search", responses=EXAMPLE_RESPONSES) +async def search( + pilot_db: PilotAgentsDB, + check_permissions: CheckPilotManagementPolicyCallable, + response: Response, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + page: int = 1, + per_page: int = 100, + body: Annotated[ + SearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES) # type: ignore + ] = None, +) -> list[dict[str, Any]]: + """Retrieve information about pilots.""" + # Inspired by /api/jobs/query + await check_permissions(action=ActionType.READ_PILOT_FIELDS) + + total, pilots = await search_bl( + pilot_db=pilot_db, + user_vo=user_info.vo, + page=page, + per_page=per_page, + body=body, + ) + + # Set the Content-Range header if needed + # https://datatracker.ietf.org/doc/html/rfc7233#section-4 + + # No pilots found but there are pilots for the requested search + # https://datatracker.ietf.org/doc/html/rfc7233#section-4.4 + if len(pilots) == 0 and total > 0: + response.headers["Content-Range"] = f"pilots */{total}" + response.status_code = HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE + + # The total number of pilots is greater than the number of pilots returned + # https://datatracker.ietf.org/doc/html/rfc7233#section-4.2 + elif len(pilots) < total: + first_idx = per_page * (page - 1) + last_idx = min(first_idx + len(pilots), total) - 1 if total > 0 else 0 + response.headers["Content-Range"] = f"pilots {first_idx}-{last_idx}/{total}" + response.status_code = HTTPStatus.PARTIAL_CONTENT + return pilots + + +@router.post("/summary") +async def summary( + pilot_db: PilotAgentsDB, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + body: SummaryParams, + check_permissions: CheckPilotManagementPolicyCallable, +): + """Show information suitable for plotting.""" + await check_permissions(action=ActionType.READ_PILOT_FIELDS) + + return await summary_bl( + pilot_db=pilot_db, + body=body, + vo=user_info.vo, + ) diff --git a/diracx-routers/tests/pilots/test_pilot_creation.py b/diracx-routers/tests/pilots/test_pilot_creation.py new file mode 100644 index 000000000..c055727c9 --- /dev/null +++ b/diracx-routers/tests/pilots/test_pilot_creation.py @@ -0,0 +1,284 @@ +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest +from sqlalchemy import update + +from diracx.core.models import ( + PilotFieldsMapping, + PilotStatus, +) +from diracx.db.sql import PilotAgentsDB +from diracx.db.sql.pilots.schema import PilotAgents + +pytestmark = pytest.mark.enabled_dependencies( + [ + "PilotCredentialsAccessPolicy", + "DevelopmentSettings", + "AuthDB", + "AuthSettings", + "ConfigSource", + "BaseAccessPolicy", + "PilotAgentsDB", + "PilotManagementAccessPolicy", + "JobDB", + ] +) + +MAIN_VO = "lhcb" +N = 100 + + +@pytest.fixture +def normal_test_client(client_factory): + with client_factory.normal_user() as client: + yield client + + +async def test_create_pilots(normal_test_client): + # Lots of request, to validate that it returns the credentials in the same order as the input references + pilot_stamps = [f"stamps_{i}" for i in range(N)] + + # -------------- Bulk insert -------------- + body = {"pilot_stamps": pilot_stamps, "vo": MAIN_VO} + + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + # -------------- Register a pilot that already exists, and one that does not -------------- + + body = { + "pilot_stamps": [pilot_stamps[0], pilot_stamps[0] + "_new_one"], + "vo": MAIN_VO, + } + + r = normal_test_client.post( + "/api/pilots/", + json=body, + headers={ + "Content-Type": "application/json", + }, + ) + + assert r.status_code == 409 + assert ( + r.json()["detail"] + == f"The following pilots already exist: {{'{pilot_stamps[0]}'}}" + ) + + # -------------- Register a pilot that does not exists **but** was called before in an error -------------- + # To prove that, if I tried to register a pilot that does not exist with one that already exists, + # i can normally add the one that did not exist before (it should not have added it before) + body = {"pilot_stamps": [pilot_stamps[0] + "_new_one"], "vo": MAIN_VO} + + r = normal_test_client.post( + "/api/pilots/", + json=body, + headers={ + "Content-Type": "application/json", + }, + ) + + assert r.status_code == 200 + + +async def test_create_pilot_and_delete_it(normal_test_client): + pilot_stamp = "stamps_1" + + # -------------- Insert -------------- + body = {"pilot_stamps": [pilot_stamp], "vo": MAIN_VO} + + # Create a pilot + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + # -------------- Duplicate -------------- + # Duplicate because it exists, should have 409 + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 409, r.json() + + # -------------- Delete -------------- + params = {"pilot_stamps": [pilot_stamp]} + + # We delete the pilot + r = normal_test_client.delete( + "/api/pilots/", + params=params, + ) + + assert r.status_code == 204 + + # -------------- Insert -------------- + # Create a the same pilot, but works because it does not exist anymore + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + +async def test_create_pilot_and_modify_it(normal_test_client): + pilot_stamps = ["stamps_1", "stamp_2"] + + # -------------- Insert -------------- + body = {"pilot_stamps": pilot_stamps, "vo": MAIN_VO} + + # Create pilots + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + # -------------- Modify -------------- + # We modify only the first pilot + body = { + "pilot_stamps_to_fields_mapping": [ + PilotFieldsMapping( + PilotStamp=pilot_stamps[0], + BenchMark=1.0, + StatusReason="NewReason", + AccountingSent=True, + Status=PilotStatus.WAITING, + ).model_dump(exclude_unset=True) + ] + } + + r = normal_test_client.patch("/api/pilots/metadata", json=body) + + assert r.status_code == 204 + + body = { + "parameters": [], + "search": [], + "sort": [], + "distinct": True, + } + + r = normal_test_client.post("/api/pilots/search", json=body) + assert r.status_code == 200, r.json() + pilot1 = r.json()[0] + pilot2 = r.json()[1] + + assert pilot1["BenchMark"] == 1.0 + assert pilot1["StatusReason"] == "NewReason" + assert pilot1["AccountingSent"] + assert pilot1["Status"] == PilotStatus.WAITING + + assert pilot2["BenchMark"] != pilot1["BenchMark"] + assert pilot2["StatusReason"] != pilot1["StatusReason"] + assert pilot2["AccountingSent"] != pilot1["AccountingSent"] + assert pilot2["Status"] != pilot1["Status"] + + +@pytest.mark.asyncio +async def test_delete_pilots_by_age_and_stamp(normal_test_client): + # Generate 100 pilot stamps + pilot_stamps = [f"stamp_{i}" for i in range(100)] + + # -------------- Insert all pilots -------------- + body = {"pilot_stamps": pilot_stamps, "vo": MAIN_VO} + r = normal_test_client.post("/api/pilots/", json=body) + assert r.status_code == 200, r.json() + + # -------------- Modify last 50 pilots' fields -------------- + to_modify = pilot_stamps[50:] + mappings = [] + for idx, stamp in enumerate(to_modify): + # First 25 of modified set to ABORTED, others to WAITING + status = PilotStatus.ABORTED if idx < 25 else PilotStatus.WAITING + mapping = PilotFieldsMapping( + PilotStamp=stamp, + BenchMark=idx + 0.1, + StatusReason=f"Reason_{idx}", + AccountingSent=(idx % 2 == 0), + Status=status, + ).model_dump(exclude_unset=True) + mappings.append(mapping) + + r = normal_test_client.patch( + "/api/pilots/metadata", + json={"pilot_stamps_to_fields_mapping": mappings}, + ) + assert r.status_code == 204 + + # -------------- Directly set SubmissionTime to March 14, 2003 for last 50 -------------- + old_date = datetime(2003, 3, 14, tzinfo=timezone.utc) + # Access DB session from normal_test_client fixtures + db = normal_test_client.app.dependency_overrides[PilotAgentsDB.transaction].args[0] + + async with db: + stmt = ( + update(PilotAgents) + .where(PilotAgents.pilot_stamp.in_(to_modify)) + .values(SubmissionTime=old_date) + ) + await db.conn.execute(stmt) + await db.conn.commit() + + # -------------- Verify all 100 pilots exist -------------- + search_body = {"parameters": [], "search": [], "sort": [], "distinct": True} + r = normal_test_client.post("/api/pilots/search", json=search_body) + assert r.status_code == 200, r.json() + assert len(r.json()) == 100 + + # -------------- 1) Delete only old aborted pilots (25 expected) -------------- + # age_in_days large enough to include 2003-03-14 + r = normal_test_client.delete( + "/api/pilots/", + params={"age_in_days": 15, "delete_only_aborted": True}, + ) + assert r.status_code == 204 + # Expect 75 remaining + r = normal_test_client.post("/api/pilots/search", json=search_body) + assert len(r.json()) == 75 + + # -------------- 2) Delete all old pilots (remaining 25 old) -------------- + r = normal_test_client.delete( + "/api/pilots/", + params={"age_in_days": 15}, + ) + assert r.status_code == 204 + + # Expect 50 remaining + r = normal_test_client.post("/api/pilots/search", json=search_body) + assert len(r.json()) == 50 + + # -------------- 3) Delete one recent pilot by stamp -------------- + one_stamp = pilot_stamps[10] + r = normal_test_client.delete("/api/pilots/", params={"pilot_stamps": [one_stamp]}) + assert r.status_code == 204 + # Expect 49 remaining + r = normal_test_client.post("/api/pilots/search", json=search_body) + assert len(r.json()) == 49 + + # -------------- 4) Delete all remaining pilots -------------- + # Collect remaining stamps + remaining = [p["PilotStamp"] for p in r.json()] + r = normal_test_client.delete("/api/pilots/", params={"pilot_stamps": remaining}) + assert r.status_code == 204 + # Expect none remaining + r = normal_test_client.post("/api/pilots/search", json=search_body) + assert r.status_code == 200 + assert len(r.json()) == 0 + + # -------------- 5) Attempt deleting unknown pilot, expect 400 -------------- + r = normal_test_client.delete( + "/api/pilots/", params={"pilot_stamps": ["unknown_stamp"]} + ) + assert r.status_code == 204 diff --git a/diracx-routers/tests/pilots/test_query.py b/diracx-routers/tests/pilots/test_query.py new file mode 100644 index 000000000..c6d5cedb4 --- /dev/null +++ b/diracx-routers/tests/pilots/test_query.py @@ -0,0 +1,414 @@ +"""Inspired by pilots and jobs db search tests.""" + +from __future__ import annotations + +import pytest +from fastapi.testclient import TestClient + +from diracx.core.exceptions import InvalidQueryError +from diracx.core.models import ( + PilotFieldsMapping, + PilotStatus, + ScalarSearchOperator, + ScalarSearchSpec, + SortDirection, + SortSpec, + VectorSearchOperator, + VectorSearchSpec, +) + +pytestmark = pytest.mark.enabled_dependencies( + [ + "AuthSettings", + "ConfigSource", + "DevelopmentSettings", + "PilotAgentsDB", + "PilotManagementAccessPolicy", + ] +) + + +@pytest.fixture +def normal_test_client(client_factory): + with client_factory.normal_user() as client: + yield client + + +MAIN_VO = "lhcb" +N = 100 + +PILOT_REASONS = [ + "I was sick", + "I can't, I have a pony.", + "I was shopping", + "I was sleeping", +] + +PILOT_STATUSES = list(PilotStatus) + + +@pytest.fixture +async def populated_pilot_client(normal_test_client): + pilot_stamps = [f"stamp_{i}" for i in range(1, N + 1)] + + # -------------- Bulk insert -------------- + body = {"vo": MAIN_VO, "pilot_stamps": pilot_stamps} + + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + body = { + "pilot_stamps_to_fields_mapping": [ + PilotFieldsMapping( + PilotStamp=pilot_stamp, + BenchMark=i**2, + StatusReason=PILOT_REASONS[i % len(PILOT_REASONS)], + AccountingSent=True, + Status=PILOT_STATUSES[i % len(PILOT_STATUSES)], + CurrentJobID=i, + Queue=f"queue_{i}", + ).model_dump(exclude_unset=True) + for i, pilot_stamp in enumerate(pilot_stamps) + ] + } + + r = normal_test_client.patch("/api/pilots/metadata", json=body) + + assert r.status_code == 204 + + yield normal_test_client + + +async def test_pilot_summary(populated_pilot_client: TestClient): + # Group by StatusReason + r = populated_pilot_client.post( + "/api/pilots/summary", + json={ + "grouping": ["StatusReason"], + }, + ) + + assert r.status_code == 200 + + assert sum([el["count"] for el in r.json()]) == N + assert len(r.json()) == len(PILOT_REASONS) + + # Group by CurrentJobID + r = populated_pilot_client.post( + "/api/pilots/summary", + json={ + "grouping": ["CurrentJobID"], + }, + ) + + assert r.status_code == 200 + + assert all(el["count"] == 1 for el in r.json()) + assert len(r.json()) == N + + # Group by CurrentJobID where BenchMark < 10^2 + r = populated_pilot_client.post( + "/api/pilots/summary", + json={ + "grouping": ["CurrentJobID"], + "search": [{"parameter": "BenchMark", "operator": "lt", "value": 10**2}], + }, + ) + + assert r.status_code == 200, r.json() + + assert all(el["count"] == 1 for el in r.json()) + assert len(r.json()) == 10 + + +@pytest.fixture +async def search(populated_pilot_client): + async def _search( + parameters, conditions, sorts, distinct=False, page=1, per_page=100 + ): + body = { + "parameters": parameters, + "search": conditions, + "sort": sorts, + "distinct": distinct, + } + + params = {"per_page": per_page, "page": page} + + r = populated_pilot_client.post("/api/pilots/search", json=body, params=params) + + if r.status_code == 400: + # If we have a status_code 400, that means that the query failed + raise InvalidQueryError() + + return r.json(), r.headers + + return _search + + +async def test_search_parameters(search): + """Test that we can search specific parameters for pilots.""" + # Search a specific parameter: PilotID + result, headers = await search(["PilotID"], [], []) + assert len(result) == N + assert result + for r in result: + assert r.keys() == {"PilotID"} + assert "Content-Range" not in headers + + # Search a specific parameter: Status + result, headers = await search(["Status"], [], []) + assert len(result) == N + assert result + for r in result: + assert r.keys() == {"Status"} + assert "Content-Range" not in headers + + # Search for multiple parameters: PilotID, Status + result, headers = await search(["PilotID", "Status"], [], []) + assert len(result) == N + assert result + for r in result: + assert r.keys() == {"PilotID", "Status"} + assert "Content-Range" not in headers + + # Search for a specific parameter but use distinct: Status + result, headers = await search(["Status"], [], [], distinct=True) + assert len(result) == len(PILOT_STATUSES) + assert result + assert "Content-Range" not in headers + + # Search for a non-existent parameter: Dummy + with pytest.raises(InvalidQueryError): + result, headers = await search(["Dummy"], [], []) + + +async def test_search_conditions(search): + """Test that we can search for specific pilots.""" + # Search a specific scalar condition: PilotID eq 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=3 + ) + result, headers = await search([], [condition], []) + assert len(result) == 1 + assert result + assert len(result) == 1 + assert result[0]["PilotID"] == 3 + assert "Content-Range" not in headers + + # Search a specific scalar condition: PilotID lt 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.LESS_THAN, value=3 + ) + result, headers = await search([], [condition], []) + assert len(result) == 2 + assert result + assert len(result) == 2 + assert result[0]["PilotID"] == 1 + assert result[1]["PilotID"] == 2 + assert "Content-Range" not in headers + + # Search a specific scalar condition: PilotID neq 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.NOT_EQUAL, value=3 + ) + result, headers = await search([], [condition], []) + assert len(result) == 99 + assert result + assert len(result) == 99 + assert all(r["PilotID"] != 3 for r in result) + assert "Content-Range" not in headers + + # Search a specific scalar condition: PilotID eq 5873 (does not exist) + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=5873 + ) + result, headers = await search([], [condition], []) + assert not result + assert "Content-Range" not in headers + + # Search a specific vector condition: PilotID in 1,2,3 + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 3] + ) + result, headers = await search([], [condition], []) + assert len(result) == 3 + assert result + assert len(result) == 3 + assert all(r["PilotID"] in [1, 2, 3] for r in result) + assert "Content-Range" not in headers + + # Search a specific vector condition: PilotID in 1,2,5873 (one of them does not exist) + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 5873] + ) + result, headers = await search([], [condition], []) + assert len(result) == 2 + assert result + assert len(result) == 2 + assert all(r["PilotID"] in [1, 2] for r in result) + assert "Content-Range" not in headers + + # Search a specific vector condition: PilotID not in 1,2,3 + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.NOT_IN, values=[1, 2, 3] + ) + result, headers = await search([], [condition], []) + assert len(result) == 97 + assert result + assert len(result) == 97 + assert all(r["PilotID"] not in [1, 2, 3] for r in result) + assert "Content-Range" not in headers + + # Search a specific vector condition: PilotID not in 1,2,5873 (one of them does not exist) + condition = VectorSearchSpec( + parameter="PilotID", + operator=VectorSearchOperator.NOT_IN, + values=[1, 2, 5873], + ) + result, headers = await search([], [condition], []) + assert len(result) == 98 + assert result + assert len(result) == 98 + assert all(r["PilotID"] not in [1, 2] for r in result) + assert "Content-Range" not in headers + + # Search for multiple conditions based on different parameters: PilotID eq 70, PilotID in 4,5,6 + condition1 = ScalarSearchSpec( + parameter="PilotStamp", operator=ScalarSearchOperator.EQUAL, value="stamp_5" + ) + condition2 = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + ) + result, headers = await search([], [condition1, condition2], []) + + assert result + assert len(result) == 1 + assert result[0]["PilotID"] == 5 + assert result[0]["PilotStamp"] == "stamp_5" + assert "Content-Range" not in headers + + # Search for multiple conditions based on the same parameter: PilotID eq 70, PilotID in 4,5,6 + condition1 = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=70 + ) + condition2 = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + ) + result, headers = await search([], [condition1, condition2], []) + assert len(result) == 0 + assert not result + assert "Content-Range" not in headers + + +async def test_search_sorts(search): + """Test that we can search for pilots and sort the results.""" + # Search and sort by PilotID in ascending order + sort = SortSpec(parameter="PilotID", direction=SortDirection.ASC) + result, headers = await search([], [], [sort]) + assert len(result) == N + assert result + for i, r in enumerate(result): + assert r["PilotID"] == i + 1 + assert "Content-Range" not in headers + + # Search and sort by PilotID in descending order + sort = SortSpec(parameter="PilotID", direction=SortDirection.DESC) + result, headers = await search([], [], [sort]) + assert len(result) == N + assert result + for i, r in enumerate(result): + assert r["PilotID"] == N - i + assert "Content-Range" not in headers + + # Search and sort by PilotStamp in ascending order + sort = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) + result, headers = await search([], [], [sort]) + assert len(result) == N + assert result + # Assert that stamp_10 is before stamp_2 because of the lexicographical order + assert result[2]["PilotStamp"] == "stamp_100" + assert result[12]["PilotStamp"] == "stamp_2" + assert "Content-Range" not in headers + + # Search and sort by PilotStamp in descending order + sort = SortSpec(parameter="PilotStamp", direction=SortDirection.DESC) + result, headers = await search([], [], [sort]) + assert len(result) == N + assert result + # Assert that stamp_10 is before stamp_2 because of the lexicographical order + assert result[97]["PilotStamp"] == "stamp_100" + assert result[87]["PilotStamp"] == "stamp_2" + assert "Content-Range" not in headers + + # Search and sort by PilotStamp in ascending order and PilotID in descending order + sort1 = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) + sort2 = SortSpec(parameter="PilotID", direction=SortDirection.DESC) + result, headers = await search([], [], [sort1, sort2]) + assert len(result) == N + assert result + assert result[0]["PilotStamp"] == "stamp_1" + assert result[0]["PilotID"] == 1 + assert result[99]["PilotStamp"] == "stamp_99" + assert result[99]["PilotID"] == 99 + assert "Content-Range" not in headers + + +async def test_search_pagination(search): + """Test that we can search for pilots.""" + # Search for the first 10 pilots + result, headers = await search([], [], [], per_page=10, page=1) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert total == N + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 1 + + # Search for the second 10 pilots + result, headers = await search([], [], [], per_page=10, page=2) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert total == N + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 11 + + # Search for the last 10 pilots + result, headers = await search([], [], [], per_page=10, page=10) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 91 + + # Search for the second 50 pilots + result, headers = await search([], [], [], per_page=50, page=2) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert result + assert len(result) == 50 + assert result[0]["PilotID"] == 51 + + # Invalid page number + result, headers = await search([], [], [], per_page=10, page=11) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert not result + + # Invalid page number + with pytest.raises(InvalidQueryError): + result = await search([], [], [], per_page=10, page=0) + + # Invalid per_page number + with pytest.raises(InvalidQueryError): + result = await search([], [], [], per_page=0, page=1) diff --git a/docs/dev/explanations/pilots.md b/docs/dev/explanations/pilots.md new file mode 100644 index 000000000..544e12db5 --- /dev/null +++ b/docs/dev/explanations/pilots.md @@ -0,0 +1,20 @@ +## Presentation + +Pilots are a piece of software that is running on *worker nodes*. There are two types of pilots: "DIRAC pilots", and "DiracX pilots". The first type corresponds to pilots with proxies, sent by DIRAC; and the second type corresponds to pilots with secrets. Both kinds will eventually interact with DiracX using tokens (DIRAC pilots by exchanging their proxies for tokens, DiracX by exchanging their secrets for tokens). + +## Management + +Their management is adapted in DiracX, and each feature has its own route in DiracX. We will split the `/pilots` route into two parts: + +1. `/api/pilots/*` to allow administrators and users to access and modify pilots +2. `/api/pilots/internal/*` is allocated for pilots resources: only DiracX pilots will have access to these resources + +Each part has its own security policy: we want to prevent pilots to access users resources and vice-versa. To differentiate DIRAC pilots from users, we can get their token and compare their properties: `GENERIC_PILOT` is the property that defines a pilot. For DiracX pilots, we can differentiate them by looking at the token structure: they don't have properties, but a "stamp" (their identifier). + +## Endpoints + +We ordered our endpoints like so: + +1. Creation: `POST /api/pilots/` +2. Deletion: `DELETE /api/pilots/` +3. Modification: `PATCH /api/pilots/metadata` diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py index 65282efb6..fdf17b6a3 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py @@ -15,7 +15,14 @@ from . import models as _models from ._configuration import DiracConfiguration from ._utils.serialization import Deserializer, Serializer -from .operations import AuthOperations, ConfigOperations, JobsOperations, LollygagOperations, WellKnownOperations +from .operations import ( + AuthOperations, + ConfigOperations, + JobsOperations, + LollygagOperations, + PilotsOperations, + WellKnownOperations, +) class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -31,6 +38,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype jobs: _generated.operations.JobsOperations :ivar lollygag: LollygagOperations operations :vartype lollygag: _generated.operations.LollygagOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.operations.PilotsOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -68,6 +77,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.lollygag = LollygagOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: """Runs the network request through the client's chained policies. diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py index d67986dae..76280797e 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py @@ -15,7 +15,14 @@ from .. import models as _models from .._utils.serialization import Deserializer, Serializer from ._configuration import DiracConfiguration -from .operations import AuthOperations, ConfigOperations, JobsOperations, LollygagOperations, WellKnownOperations +from .operations import ( + AuthOperations, + ConfigOperations, + JobsOperations, + LollygagOperations, + PilotsOperations, + WellKnownOperations, +) class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -31,6 +38,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype jobs: _generated.aio.operations.JobsOperations :ivar lollygag: LollygagOperations operations :vartype lollygag: _generated.aio.operations.LollygagOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.aio.operations.PilotsOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -68,6 +77,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.lollygag = LollygagOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) def send_request( self, request: HttpRequest, *, stream: bool = False, **kwargs: Any diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py index 572930a93..3408891fc 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py @@ -15,6 +15,7 @@ from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore from ._operations import LollygagOperations # type: ignore +from ._operations import PilotsOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -26,6 +27,7 @@ "ConfigOperations", "JobsOperations", "LollygagOperations", + "PilotsOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py index 8927a2921..19925b650 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py @@ -55,6 +55,12 @@ build_lollygag_get_gubbins_secrets_request, build_lollygag_get_owner_object_request, build_lollygag_insert_owner_object_request, + build_pilots_add_pilot_stamps_request, + build_pilots_delete_pilots_request, + build_pilots_get_pilot_jobs_request, + build_pilots_search_request, + build_pilots_summary_request, + build_pilots_update_pilot_fields_request, build_well_known_get_installation_metadata_request, build_well_known_get_jwks_request, build_well_known_get_openid_configuration_request, @@ -2523,3 +2529,583 @@ async def get_gubbins_secrets(self, **kwargs: Any) -> Any: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.aio.Dirac`'s + :attr:`pilots` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + async def add_pilot_stamps( + self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + + _request = build_pilots_add_pilot_stamps_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def delete_pilots( + self, + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any + ) -> None: + """Delete Pilots. + + Endpoint to delete a pilot. + + Two features: + + + #. Or you provide pilot_stamps, so you can delete pilots by their stamp + #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. + :paramtype pilot_stamps: list[str] + :keyword age_in_days: The number of days that define the maximum age of pilots to be + deleted.Pilots older than this age will be considered for deletion. Default value is None. + :paramtype age_in_days: int + :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is + 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by + default as True to avoid any mistake.This flag is only used for deletion by time. Default value + is False. + :paramtype delete_only_aborted: bool + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_delete_pilots_request( + pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + async def update_pilot_fields( + self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def update_pilot_fields( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def update_pilot_fields( + self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields") + + _request = build_pilots_update_pilot_fields_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @distributed_trace_async + async def get_pilot_jobs( + self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any + ) -> List[int]: + """Get Pilot Jobs. + + Endpoint only for admins, to get jobs of a pilot. + + :keyword pilot_stamp: The stamp of the pilot. Default value is None. + :paramtype pilot_stamp: str + :keyword job_id: The ID of the job. Default value is None. + :paramtype job_id: int + :return: list of int + :rtype: list[int] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[List[int]] = kwargs.pop("cls", None) + + _request = build_pilots_get_pilot_jobs_request( + pilot_stamp=pilot_stamp, + job_id=job_id, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("[int]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + async def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + async def summary( + self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SummaryParams") + + _request = build_pilots_summary_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py index d8e29cfeb..7bdd59b63 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py @@ -16,6 +16,8 @@ BodyAuthGetOidcTokenGrantType, BodyJobsRescheduleJobs, BodyJobsUnassignBulkJobsSandboxes, + BodyPilotsAddPilotStamps, + BodyPilotsUpdatePilotFields, ExtendedMetadata, GroupInfo, HTTPValidationError, @@ -27,6 +29,7 @@ JobMetaDataAccountedFlag, JobStatusUpdate, OpenIDConfiguration, + PilotFieldsMapping, SandboxDownloadResponse, SandboxInfo, SandboxUploadResponse, @@ -52,6 +55,7 @@ from ._enums import ( # type: ignore ChecksumAlgorithm, JobStatus, + PilotStatus, SandboxFormat, SandboxType, ScalarSearchOperator, @@ -67,6 +71,8 @@ "BodyAuthGetOidcTokenGrantType", "BodyJobsRescheduleJobs", "BodyJobsUnassignBulkJobsSandboxes", + "BodyPilotsAddPilotStamps", + "BodyPilotsUpdatePilotFields", "ExtendedMetadata", "GroupInfo", "HTTPValidationError", @@ -78,6 +84,7 @@ "JobMetaDataAccountedFlag", "JobStatusUpdate", "OpenIDConfiguration", + "PilotFieldsMapping", "SandboxDownloadResponse", "SandboxInfo", "SandboxUploadResponse", @@ -100,6 +107,7 @@ "VectorSearchSpecValues", "ChecksumAlgorithm", "JobStatus", + "PilotStatus", "SandboxFormat", "SandboxType", "ScalarSearchOperator", diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py index 663d9c951..23edf99d3 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py @@ -34,6 +34,19 @@ class JobStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): RESCHEDULED = "Rescheduled" +class PilotStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """PilotStatus.""" + + SUBMITTED = "Submitted" + WAITING = "Waiting" + RUNNING = "Running" + DONE = "Done" + FAILED = "Failed" + DELETED = "Deleted" + ABORTED = "Aborted" + UNKNOWN = "Unknown" + + class SandboxFormat(str, Enum, metaclass=CaseInsensitiveEnumMeta): """SandboxFormat.""" diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index faaea49b4..2e8717cb6 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -146,6 +146,109 @@ def __init__(self, *, job_ids: List[int], **kwargs: Any) -> None: self.job_ids = job_ids +class BodyPilotsAddPilotStamps(_serialization.Model): + """Body_pilots_add_pilot_stamps. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :vartype pilot_stamps: list[str] + :ivar vo: Pilot virtual organization. Required. + :vartype vo: str + :ivar grid_type: Grid type of the pilots. + :vartype grid_type: str + :ivar grid_site: Pilots grid site. + :vartype grid_site: str + :ivar destination_site: Pilots destination site. + :vartype destination_site: str + :ivar pilot_references: Association of a pilot reference with a pilot stamp. + :vartype pilot_references: dict[str, str] + :ivar pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", "Running", + "Done", "Failed", "Deleted", "Aborted", and "Unknown". + :vartype pilot_status: str or ~_generated.models.PilotStatus + """ + + _validation = { + "pilot_stamps": {"required": True}, + "vo": {"required": True}, + } + + _attribute_map = { + "pilot_stamps": {"key": "pilot_stamps", "type": "[str]"}, + "vo": {"key": "vo", "type": "str"}, + "grid_type": {"key": "grid_type", "type": "str"}, + "grid_site": {"key": "grid_site", "type": "str"}, + "destination_site": {"key": "destination_site", "type": "str"}, + "pilot_references": {"key": "pilot_references", "type": "{str}"}, + "pilot_status": {"key": "pilot_status", "type": "str"}, + } + + def __init__( + self, + *, + pilot_stamps: List[str], + vo: str, + grid_type: str = "Dirac", + grid_site: str = "Unknown", + destination_site: str = "NotAssigned", + pilot_references: Optional[Dict[str, str]] = None, + pilot_status: Optional[Union[str, "_models.PilotStatus"]] = None, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :paramtype pilot_stamps: list[str] + :keyword vo: Pilot virtual organization. Required. + :paramtype vo: str + :keyword grid_type: Grid type of the pilots. + :paramtype grid_type: str + :keyword grid_site: Pilots grid site. + :paramtype grid_site: str + :keyword destination_site: Pilots destination site. + :paramtype destination_site: str + :keyword pilot_references: Association of a pilot reference with a pilot stamp. + :paramtype pilot_references: dict[str, str] + :keyword pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", + "Running", "Done", "Failed", "Deleted", "Aborted", and "Unknown". + :paramtype pilot_status: str or ~_generated.models.PilotStatus + """ + super().__init__(**kwargs) + self.pilot_stamps = pilot_stamps + self.vo = vo + self.grid_type = grid_type + self.grid_site = grid_site + self.destination_site = destination_site + self.pilot_references = pilot_references + self.pilot_status = pilot_status + + +class BodyPilotsUpdatePilotFields(_serialization.Model): + """Body_pilots_update_pilot_fields. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. Required. + :vartype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + """ + + _validation = { + "pilot_stamps_to_fields_mapping": {"required": True}, + } + + _attribute_map = { + "pilot_stamps_to_fields_mapping": {"key": "pilot_stamps_to_fields_mapping", "type": "[PilotFieldsMapping]"}, + } + + def __init__(self, *, pilot_stamps_to_fields_mapping: List["_models.PilotFieldsMapping"], **kwargs: Any) -> None: + """ + :keyword pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. + Required. + :paramtype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + """ + super().__init__(**kwargs) + self.pilot_stamps_to_fields_mapping = pilot_stamps_to_fields_mapping + + class ExtendedMetadata(_serialization.Model): """ExtendedMetadata. @@ -907,6 +1010,102 @@ def __init__( self.code_challenge_methods_supported = code_challenge_methods_supported +class PilotFieldsMapping(_serialization.Model): + """All the fields that a user can modify on a Pilot (except PilotStamp). + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamp: Pilotstamp. Required. + :vartype pilot_stamp: str + :ivar status_reason: Statusreason. + :vartype status_reason: str + :ivar status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", + "Failed", "Deleted", "Aborted", and "Unknown". + :vartype status: str or ~_generated.models.PilotStatus + :ivar bench_mark: Benchmark. + :vartype bench_mark: float + :ivar destination_site: Destinationsite. + :vartype destination_site: str + :ivar queue: Queue. + :vartype queue: str + :ivar grid_site: Gridsite. + :vartype grid_site: str + :ivar grid_type: Gridtype. + :vartype grid_type: str + :ivar accounting_sent: Accountingsent. + :vartype accounting_sent: bool + :ivar current_job_id: Currentjobid. + :vartype current_job_id: int + """ + + _validation = { + "pilot_stamp": {"required": True}, + } + + _attribute_map = { + "pilot_stamp": {"key": "PilotStamp", "type": "str"}, + "status_reason": {"key": "StatusReason", "type": "str"}, + "status": {"key": "Status", "type": "str"}, + "bench_mark": {"key": "BenchMark", "type": "float"}, + "destination_site": {"key": "DestinationSite", "type": "str"}, + "queue": {"key": "Queue", "type": "str"}, + "grid_site": {"key": "GridSite", "type": "str"}, + "grid_type": {"key": "GridType", "type": "str"}, + "accounting_sent": {"key": "AccountingSent", "type": "bool"}, + "current_job_id": {"key": "CurrentJobID", "type": "int"}, + } + + def __init__( + self, + *, + pilot_stamp: str, + status_reason: Optional[str] = None, + status: Optional[Union[str, "_models.PilotStatus"]] = None, + bench_mark: Optional[float] = None, + destination_site: Optional[str] = None, + queue: Optional[str] = None, + grid_site: Optional[str] = None, + grid_type: Optional[str] = None, + accounting_sent: Optional[bool] = None, + current_job_id: Optional[int] = None, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamp: Pilotstamp. Required. + :paramtype pilot_stamp: str + :keyword status_reason: Statusreason. + :paramtype status_reason: str + :keyword status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", + "Failed", "Deleted", "Aborted", and "Unknown". + :paramtype status: str or ~_generated.models.PilotStatus + :keyword bench_mark: Benchmark. + :paramtype bench_mark: float + :keyword destination_site: Destinationsite. + :paramtype destination_site: str + :keyword queue: Queue. + :paramtype queue: str + :keyword grid_site: Gridsite. + :paramtype grid_site: str + :keyword grid_type: Gridtype. + :paramtype grid_type: str + :keyword accounting_sent: Accountingsent. + :paramtype accounting_sent: bool + :keyword current_job_id: Currentjobid. + :paramtype current_job_id: int + """ + super().__init__(**kwargs) + self.pilot_stamp = pilot_stamp + self.status_reason = status_reason + self.status = status + self.bench_mark = bench_mark + self.destination_site = destination_site + self.queue = queue + self.grid_site = grid_site + self.grid_type = grid_type + self.accounting_sent = accounting_sent + self.current_job_id = current_job_id + + class SandboxDownloadResponse(_serialization.Model): """SandboxDownloadResponse. diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py index 572930a93..3408891fc 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py @@ -15,6 +15,7 @@ from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore from ._operations import LollygagOperations # type: ignore +from ._operations import PilotsOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -26,6 +27,7 @@ "ConfigOperations", "JobsOperations", "LollygagOperations", + "PilotsOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py index fa5e665ce..4358ecf51 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py @@ -635,6 +635,124 @@ def build_lollygag_get_gubbins_secrets_request(**kwargs: Any) -> HttpRequest: # return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) +def build_pilots_add_pilot_stamps_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + +def build_pilots_delete_pilots_request( + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any +) -> HttpRequest: + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + # Construct URL + _url = "/api/pilots/" + + # Construct parameters + if pilot_stamps is not None: + _params["pilot_stamps"] = _SERIALIZER.query("pilot_stamps", pilot_stamps, "[str]") + if age_in_days is not None: + _params["age_in_days"] = _SERIALIZER.query("age_in_days", age_in_days, "int") + if delete_only_aborted is not None: + _params["delete_only_aborted"] = _SERIALIZER.query("delete_only_aborted", delete_only_aborted, "bool") + + return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) + + +def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + # Construct URL + _url = "/api/pilots/metadata" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + + return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) + + +def build_pilots_get_pilot_jobs_request( + *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/jobs" + + # Construct parameters + if pilot_stamp is not None: + _params["pilot_stamp"] = _SERIALIZER.query("pilot_stamp", pilot_stamp, "str") + if job_id is not None: + _params["job_id"] = _SERIALIZER.query("job_id", job_id, "int") + + # Construct headers + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/search" + + # Construct parameters + if page is not None: + _params["page"] = _SERIALIZER.query("page", page, "int") + if per_page is not None: + _params["per_page"] = _SERIALIZER.query("per_page", per_page, "int") + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_pilots_summary_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/summary" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + class WellKnownOperations: """ .. warning:: @@ -3088,3 +3206,579 @@ def get_gubbins_secrets(self, **kwargs: Any) -> Any: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.Dirac`'s + :attr:`pilots` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + def add_pilot_stamps( + self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + + _request = build_pilots_add_pilot_stamps_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def delete_pilots( # pylint: disable=inconsistent-return-statements + self, + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any + ) -> None: + """Delete Pilots. + + Endpoint to delete a pilot. + + Two features: + + + #. Or you provide pilot_stamps, so you can delete pilots by their stamp + #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. + :paramtype pilot_stamps: list[str] + :keyword age_in_days: The number of days that define the maximum age of pilots to be + deleted.Pilots older than this age will be considered for deletion. Default value is None. + :paramtype age_in_days: int + :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is + 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by + default as True to avoid any mistake.This flag is only used for deletion by time. Default value + is False. + :paramtype delete_only_aborted: bool + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_delete_pilots_request( + pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + def update_pilot_fields( + self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def update_pilot_fields(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def update_pilot_fields( # pylint: disable=inconsistent-return-statements + self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields") + + _request = build_pilots_update_pilot_fields_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @distributed_trace + def get_pilot_jobs( + self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any + ) -> List[int]: + """Get Pilot Jobs. + + Endpoint only for admins, to get jobs of a pilot. + + :keyword pilot_stamp: The stamp of the pilot. Default value is None. + :paramtype pilot_stamp: str + :keyword job_id: The ID of the job. Default value is None. + :paramtype job_id: int + :return: list of int + :rtype: list[int] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[List[int]] = kwargs.pop("cls", None) + + _request = build_pilots_get_pilot_jobs_request( + pilot_stamp=pilot_stamp, + job_id=job_id, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("[int]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + def summary(self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SummaryParams") + + _request = build_pilots_summary_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore