diff --git a/pymilo/streaming/communicator.py b/pymilo/streaming/communicator.py index 0f3d268b..a1086881 100644 --- a/pymilo/streaming/communicator.py +++ b/pymilo/streaming/communicator.py @@ -1,15 +1,15 @@ # -*- coding: utf-8 -*- """PyMilo Communication Mediums.""" +import uuid import json import asyncio import uvicorn import requests import websockets from enum import Enum -from pydantic import BaseModel -from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect +from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect, HTTPException from .interfaces import ClientCommunicator -from .param import PYMILO_INVALID_URL, PYMILO_CLIENT_WEBSOCKET_NOT_CONNECTED +from .param import PYMILO_INVALID_URL, PYMILO_CLIENT_WEBSOCKET_NOT_CONNECTED, REST_API_PREFIX from .util import validate_websocket_url, validate_http_url @@ -27,62 +27,179 @@ def __init__(self, server_url): is_valid, server_url = validate_http_url(server_url) if not is_valid: raise Exception(PYMILO_INVALID_URL) - self._server_url = server_url + self._server_url = server_url.rstrip("/") + "/api/v1" self.session = requests.Session() retries = requests.adapters.Retry( total=10, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504] ) - self.session.mount('http://', requests.adapters.HTTPAdapter(max_retries=retries)) - self.session.mount('https://', requests.adapters.HTTPAdapter(max_retries=retries)) + adapter = requests.adapters.HTTPAdapter(max_retries=retries) + self.session.mount('http://', adapter) + self.session.mount('https://', adapter) - def download(self, payload): + def download(self, client_id, model_id): """ Request for the remote ML model to download. - :param payload: download request payload - :type payload: dict + :param client_id: ID of the requesting client + :param model_id: ID of the model to download :return: string serialized model """ - response = self.session.get(url=self._server_url + "/download/", json=payload, timeout=5) - if response.status_code != 200: - return None + url = f"{self._server_url}/clients/{client_id}/models/{model_id}/download" + response = self.session.get(url, timeout=5) + response.raise_for_status() return response.json()["payload"] - def upload(self, payload): + def upload(self, client_id, model_id, model): """ Upload the local ML model to the remote server. - :param payload: upload request payload - :type payload: dict + :param client_id: ID of the client + :param model_id: ID of the model + :param model: serialized model content :return: True if upload was successful, False otherwise """ - response = self.session.post(url=self._server_url + "/upload/", json=payload, timeout=5) + url = f"{self._server_url}/clients/{client_id}/models/{model_id}/upload" + response = self.session.post(url, json=model, timeout=5) return response.status_code == 200 - def attribute_call(self, payload): + def attribute_call(self, client_id, model_id, call_payload): """ Delegate the requested attribute call to the remote server. - :param payload: attribute call request payload - :type payload: dict + :param client_id: ID of the client + :param model_id: ID of the model + :param call_payload: payload containing attribute name, args, and kwargs :return: json-encoded response of pymilo server """ - response = self.session.post(url=self._server_url + "/attribute_call/", json=payload, timeout=5) + url = f"{self._server_url}/clients/{client_id}/models/{model_id}/attribute-call" + response = self.session.post(url, json=call_payload, timeout=5) + response.raise_for_status() return response.json() - def attribute_type(self, payload): + def attribute_type(self, client_id, model_id, type_payload): """ Identify the attribute type of the requested attribute. - :param payload: attribute type request payload - :type payload: dict + :param client_id: ID of the client + :param model_id: ID of the model + :param type_payload: payload containing attribute data to inspect :return: response of pymilo server """ - response = self.session.post(url=self._server_url + "/attribute_type/", json=payload, timeout=5) + url = f"{self._server_url}/clients/{client_id}/models/{model_id}/attribute-type" + response = self.session.post(url, json=type_payload, timeout=5) + response.raise_for_status() return response.json() + def register_client(self): + """ + Register client in the PyMiloServer. + + :return: newly allocated client id + """ + response = self.session.get(f"{self._server_url}/clients/register", timeout=5) + response.raise_for_status() + return response.json()["client_id"] + + def remove_client(self, client_id): + """ + Remove client from the PyMiloServer. + + :param client_id: id of the client to remove + :type client_id: str + :return: True if removal was successful, False otherwise + """ + response = self.session.delete(f"{self._server_url}/clients/{client_id}", timeout=5) + return response.status_code == 200 + + def register_model(self, client_id): + """ + Register ML model in the PyMiloServer. + + :param client_id: id of the client who owns the model + :type client_id: str + :return: newly allocated ml model id + """ + response = self.session.post(f"{self._server_url}/clients/{client_id}/models/register", timeout=5) + response.raise_for_status() + return response.json()["ml_model_id"] + + def remove_model(self, client_id, model_id): + """ + Remove ML model from the PyMiloServer. + + :param client_id: client owning the model + :type client_id: str + :param model_id: model to remove + :type model_id: str + :return: True if removal was successful, False otherwise + """ + response = self.session.delete(f"{self._server_url}/clients/{client_id}/models/{model_id}", timeout=5) + return response.status_code == 200 + + def get_ml_models(self, client_id): + """ + Get all ML models registered for this specific client in the PyMiloServer. + + :param client_id: client whose models are being queried + :type client_id: str + :return: list of ml model ids + """ + response = self.session.get(f"{self._server_url}/clients/{client_id}/models", timeout=5) + response.raise_for_status() + return response.json()["ml_models_id"] + + def grant_access(self, allower_id, allowee_id, model_id): + """ + Grant access to a model to another client. + + :param allower_id: ID of the client granting access + :param allowee_id: ID of the client being granted access + :param model_id: ID of the model being shared + :return: True if successful, False otherwise + """ + url = f"{self._server_url}/clients/{allower_id}/grant/{allowee_id}/models/{model_id}" + response = self.session.post(url, timeout=5) + return response.status_code == 200 + + def revoke_access(self, revoker_id, revokee_id, model_id): + """ + Revoke previously granted model access. + + :param revoker_id: ID of the client revoking access + :param revokee_id: ID of the client whose access is being revoked + :param model_id: ID of the model + :return: True if successful, False otherwise + """ + url = f"{self._server_url}/clients/{revoker_id}/revoke/{revokee_id}/models/{model_id}" + response = self.session.post(url, timeout=5) + return response.status_code == 200 + + def get_allowance(self, allower_id): + """ + Get the list of all allowees and their allowed models from a given allower. + + :param allower_id: ID of the allower + :return: dict of allowees to model lists + """ + response = self.session.get(f"{self._server_url}/clients/{allower_id}/allowances", timeout=5) + response.raise_for_status() + return response.json()["allowance"] + + def get_allowed_models(self, allower_id, allowee_id): + """ + Get the list of models that one client is allowed to access from another. + + :param allower_id: ID of the model owner + :param allowee_id: ID of the requesting client + :return: list of model IDs + """ + url = f"{self._server_url}/clients/{allower_id}/allowances/{allowee_id}" + response = self.session.get(url, timeout=5) + response.raise_for_status() + return response.json()["allowed_models"] + class RESTServerCommunicator(): """Facilitate working with the communication medium from the server side for the REST protocol.""" @@ -112,69 +229,125 @@ def __init__( def setup_routes(self): """Configure endpoints to handle RESTClientCommunicator requests.""" - class StandardPayload(BaseModel): - client_id: str - ml_model_id: str - - class DownloadPayload(StandardPayload): - pass - - class UploadPayload(StandardPayload): - model: str - - class AttributeCallPayload(StandardPayload): - attribute: str - args: dict - kwargs: dict - - class AttributeTypePayload(StandardPayload): - attribute: str - - @self.app.get("/download/") - async def download(request: Request): - body = await request.json() - body = self.parse(body) - payload = DownloadPayload(**body) - message = "/download request from client: {} for model: {}".format(payload.client_id, payload.ml_model_id) + + @self.app.get(f"{REST_API_PREFIX}/health") + async def health(): + return {"status": "ok"} + + @self.app.get(f"{REST_API_PREFIX}/clients/register") + async def request_client_id(): + client_id = str(uuid.uuid4()) + self._ps.init_client(client_id) + return {"client_id": client_id} + + @self.app.delete(f"{REST_API_PREFIX}/clients/{{client_id}}") + async def remove_client(client_id: str): + is_succeed, detail_message = self._ps.remove_client(client_id) + if not is_succeed: + raise HTTPException(status_code=404, detail=detail_message) + return {"client_id": client_id} + + @self.app.post(f"{REST_API_PREFIX}/clients/{{client_id}}/models/register") + async def request_model(client_id: str): + model_id = str(uuid.uuid4()) + is_succeed, detail_message = self._ps.init_ml_model(client_id, model_id) + if not is_succeed: + raise HTTPException(status_code=404, detail=detail_message) + return {"client_id": client_id, "ml_model_id": model_id} + + @self.app.delete(f"{REST_API_PREFIX}/clients/{{client_id}}/models/{{ml_model_id}}") + async def remove_model(client_id: str, ml_model_id: str): + is_succeed, detail_message = self._ps.remove_ml_model(client_id, ml_model_id) + if not is_succeed: + raise HTTPException(status_code=404, detail=detail_message) + return {"client_id": client_id, "ml_model_id": ml_model_id} + + @self.app.get(f"{REST_API_PREFIX}/clients/{{client_id}}/models") + async def get_client_models(client_id: str): + return {"client_id": client_id, "ml_models_id": self._ps.get_ml_models(client_id)} + + @self.app.post(f"{REST_API_PREFIX}/clients/{{allower_id}}/grant/{{allowee_id}}/models/{{model_id}}") + async def grant_model_access(allower_id: str, allowee_id: str, model_id: str): + is_succeed, detail_message = self._ps.grant_access(allower_id, allowee_id, model_id) + if not is_succeed: + raise HTTPException(status_code=404, detail=detail_message) return { - "message": message, - "payload": self._ps.export_model(), + "allower_id": allower_id, + "allowee_id": allowee_id, + "allowed_model_id": model_id } - @self.app.post("/upload/") - async def upload(request: Request): - body = await request.json() - body = self.parse(body) - payload = UploadPayload(**body) - message = "/upload request from client: {} for model: {}".format(payload.client_id, payload.ml_model_id) + @self.app.post(f"{REST_API_PREFIX}/clients/{{revoker_id}}/revoke/{{revokee_id}}/models/{{model_id}}") + async def revoke_model_access(revoker_id: str, revokee_id: str, model_id: str): + is_succeed, detail_message = self._ps.revoke_access(revoker_id, revokee_id, model_id) + if not is_succeed: + raise HTTPException(status_code=404, detail=detail_message) return { - "message": message, - "payload": self._ps.update_model(payload.model) + "revoker_id": revoker_id, + "revokee_id": revokee_id, + "revoked_model_id": model_id } - @self.app.post("/attribute_call/") - async def attribute_call(request: Request): - body = await request.json() - body = self.parse(body) - payload = AttributeCallPayload(**body) - message = "/attribute_call request from client: {} for model: {}".format( - payload.client_id, payload.ml_model_id) - result = self._ps.execute_model(payload) + @self.app.get(f"{REST_API_PREFIX}/clients/{{allower_id}}/allowances") + async def get_allowance(allower_id: str): + allowance, reason = self._ps.get_clients_allowance(allower_id) + if not allowance: + raise HTTPException(status_code=404, detail=reason) + return {"allower_id": allower_id, "allowance": allowance} + + @self.app.get(f"{REST_API_PREFIX}/clients/{{allower_id}}/allowances/{{allowee_id}}") + async def get_allowed_models(allower_id: str, allowee_id: str): + models, reason = self._ps.get_allowed_models(allower_id, allowee_id) + if models is None: + raise HTTPException(status_code=404, detail=reason) + return {"allower_id": allower_id, "allowee_id": allowee_id, "allowed_models": models} + + @self.app.get(f"{REST_API_PREFIX}/clients/{{client_id}}/models/{{ml_model_id}}/download") + async def download_model(client_id: str, ml_model_id: str): + is_valid, reason = self._ps._validate_id(client_id, ml_model_id) + if not is_valid: + raise HTTPException(status_code=404, detail=reason) return { - "message": message, - "payload": result if result is not None else "The ML model has been updated in place." + "message": f"/download request from client: {client_id} for model: {ml_model_id}", + "payload": self._ps.export_model(client_id, ml_model_id) } - @self.app.post("/attribute_type/") - async def attribute_type(request: Request): - body = await request.json() - body = self.parse(body) - payload = AttributeTypePayload(**body) - message = "/attribute_type request from client: {} for model: {}".format( - payload.client_id, payload.ml_model_id) - is_callable, field_value = self._ps.is_callable_attribute(payload) + @self.app.post(f"{REST_API_PREFIX}/clients/{{client_id}}/models/{{ml_model_id}}/upload") + async def upload_model(client_id: str, ml_model_id: str, request: Request): + model_data = self.parse(await request.json()).get("model") + if model_data is None: + raise HTTPException(status_code=400, detail="Missing 'model' in request") + + is_valid, reason = self._ps._validate_id(client_id, ml_model_id) + if not is_valid: + raise HTTPException(status_code=404, detail=reason) + return { - "message": message, + "message": f"/upload request from client: {client_id} for model: {ml_model_id}", + "payload": self._ps.update_model(client_id, ml_model_id, model_data) + } + + @self.app.post(f"{REST_API_PREFIX}/clients/{{client_id}}/models/{{ml_model_id}}/attribute-call") + async def attribute_call(client_id: str, ml_model_id: str, request: Request): + request_payload = self.parse(await request.json()) + is_valid, reason = self._ps._validate_id(client_id, ml_model_id) + if not is_valid: + raise HTTPException(status_code=404, detail=reason) + result = self._ps.execute_model(request_payload) + return { + "message": f"/attribute_call request from client: {client_id} for model: {ml_model_id}", + "payload": result or "The ML model has been updated in place." + } + + @self.app.post(f"{REST_API_PREFIX}/clients/{{client_id}}/models/{{ml_model_id}}/attribute-type") + async def attribute_type(client_id: str, ml_model_id: str, request: Request): + request = self.parse(await request.json()) + is_valid, reason = self._ps._validate_id(client_id, ml_model_id) + if not is_valid: + raise HTTPException(status_code=404, detail=reason) + is_callable, field_value = self._ps.is_callable_attribute(request) + return { + "message": f"/attribute_type request from client: {client_id} for model: {ml_model_id}", "attribute type": "method" if is_callable else "field", "attribute value": "" if is_callable else field_value, } @@ -398,7 +571,7 @@ async def handle_message(self, websocket: WebSocket, message: str): payload = self.parse(message['payload']) if action == "download": - response = self._handle_download() + response = self._handle_download(payload) elif action == "upload": response = self._handle_upload(payload) elif action == "attribute_call": @@ -412,15 +585,20 @@ async def handle_message(self, websocket: WebSocket, message: str): except Exception as e: await websocket.send_text(json.dumps({"error": str(e)})) - def _handle_download(self) -> dict: + def _handle_download(self, payload) -> dict: """ Handle download requests. + :param payload: the payload containing the ids associated with the requested model for download. + :type payload: dict :return: a response containing the exported model. """ return { "message": "Download request received.", - "payload": self._ps.export_model(), + "payload": self._ps.export_model( + payload["client_id"], + payload["ml_model_id"], + ), } def _handle_upload(self, payload: dict) -> dict: diff --git a/pymilo/streaming/interfaces.py b/pymilo/streaming/interfaces.py index 6fb94398..6719ee2b 100644 --- a/pymilo/streaming/interfaces.py +++ b/pymilo/streaming/interfaces.py @@ -63,35 +63,161 @@ class ClientCommunicator(ABC): """ ClientCommunicator Interface. - Each ClientCommunicator has methods to upload the local ML model, download the remote ML model and delegate attribute call to the remote server. + Defines the contract for client-server communication. Each implementation is responsible for: + - Registering and removing clients and models + - Uploading and downloading ML models + - Handling delegated attribute access + - Managing model allowances between clients """ @abstractmethod - def upload(self, payload): + def register_client(self): """ - Upload the given payload to the remote server. + Register the client in the remote server. - :param payload: request payload - :type payload: dict - :return: remote server response + :return: newly allocated client ID + :rtype: str """ @abstractmethod - def download(self, payload): + def remove_client(self, client_id): """ - Download the remote ML model to local. + Remove the client from the remote server. - :param payload: request payload - :type payload: dict - :return: remote server response + :param client_id: client ID to remove + :type client_id: str + :return: success status + :rtype: bool + """ + + @abstractmethod + def register_model(self, client_id): + """ + Register an ML model for the given client. + + :param client_id: client ID + :type client_id: str + :return: newly allocated model ID + :rtype: str """ @abstractmethod - def attribute_call(self, payload): + def remove_model(self, client_id, model_id): + """ + Remove the specified ML model for the client. + + :param client_id: client ID + :type client_id: str + :param model_id: model ID + :type model_id: str + :return: success status + :rtype: bool + """ + + @abstractmethod + def get_ml_models(self, client_id): + """ + Get the list of ML models for the given client. + + :param client_id: client ID + :type client_id: str + :return: list of model IDs + :rtype: list[str] + """ + + @abstractmethod + def grant_access(self, allower_id, allowee_id, model_id): + """ + Grant access to a model from one client to another. + + :param allower_id: client who owns the model + :type allower_id: str + :param allowee_id: client to be granted access + :type allowee_id: str + :param model_id: model ID + :type model_id: str + :return: success status + :rtype: bool + """ + + @abstractmethod + def revoke_access(self, revoker_id, revokee_id, model_id): + """ + Revoke model access from one client to another. + + :param revoker_id: client who owns the model + :type revoker_id: str + :param revokee_id: client to be revoked + :type revokee_id: str + :param model_id: model ID + :type model_id: str + :return: success status + :rtype: bool + """ + + @abstractmethod + def get_allowance(self, allower_id): + """ + Get all clients and models this client has allowed. + + :param allower_id: client who granted access + :type allower_id: str + :return: dictionary mapping allowee_id to list of model_ids + :rtype: dict + """ + + @abstractmethod + def get_allowed_models(self, allower_id, allowee_id): + """ + Get the list of model IDs that `allowee_id` is allowed to access from `allower_id`. + + :param allower_id: model owner + :type allower_id: str + :param allowee_id: recipient + :type allowee_id: str + :return: list of allowed model IDs + :rtype: list[str] + """ + + @abstractmethod + def upload(self, client_id, model_id, model): + """ + Upload the local ML model to the remote server. + + :param client_id: ID of the client + :param model_id: ID of the model + :param model: serialized model content + :return: True if upload was successful, False otherwise + """ + + @abstractmethod + def download(self, client_id, model_id): + """ + Download the remote ML model. + + :param client_id: ID of the requesting client + :param model_id: ID of the model to download + :return: string serialized model + """ + + @abstractmethod + def attribute_call(self, client_id, model_id, call_payload): """ Execute an attribute call on the remote server. - :param payload: request payload - :type payload: dict + :param client_id: ID of the client + :param model_id: ID of the model + :param call_payload: payload containing attribute name, args, and kwargs + :return: remote server response + """ + + @abstractmethod + def attribute_type(self, client_id, model_id, type_payload): + """ + Identify the attribute type (method or field) on the remote model. + + :param client_id: client ID + :param model_id: model ID + :param type_payload: payload containing targeted attribute :return: remote server response """ diff --git a/pymilo/streaming/param.py b/pymilo/streaming/param.py index 3727c0c8..b622748f 100644 --- a/pymilo/streaming/param.py +++ b/pymilo/streaming/param.py @@ -10,3 +10,5 @@ PYMILO_SERVER_NON_EXISTENT_ATTRIBUTE = "The requested attribute doesn't exist in this model." PYMILO_INVALID_URL = "The given URL is not valid." PYMILO_CLIENT_WEBSOCKET_NOT_CONNECTED = "WebSocket is not connected." + +REST_API_PREFIX = "/api/v1" diff --git a/pymilo/streaming/pymilo_client.py b/pymilo/streaming/pymilo_client.py index bb347496..081800ed 100644 --- a/pymilo/streaming/pymilo_client.py +++ b/pymilo/streaming/pymilo_client.py @@ -60,9 +60,7 @@ def encrypt_compress(self, body): :return: the compressed and encrypted version of the body payload """ return self._encryptor.encrypt( - self._compressor.compress( - body - ) + self._compressor.compress(body) ) def toggle_mode(self, mode=Mode.LOCAL): @@ -83,12 +81,8 @@ def download(self): :return: None """ serialized_model = self._communicator.download( - self.encrypt_compress( - { - "client_id": self.client_id, - "ml_model_id": self.ml_model_id, - } - ) + self.client_id, + self.ml_model_id ) if serialized_model is None: print(PYMILO_CLIENT_FAILED_TO_DOWNLOAD_REMOTE_MODEL) @@ -103,19 +97,100 @@ def upload(self): :return: None """ succeed = self._communicator.upload( - self.encrypt_compress( - { - "client_id": self.client_id, - "ml_model_id": self.ml_model_id, - "model": Export(self.model).to_json(), - } - ) + self.client_id, + self.ml_model_id, + self.encrypt_compress({"model": Export(self.model).to_json()}) ) if succeed: print(PYMILO_CLIENT_LOCAL_MODEL_UPLOADED) else: print(PYMILO_CLIENT_LOCAL_MODEL_UPLOAD_FAILED) + def register(self): + """ + Register client in the remote server. + + :return: None + """ + self.client_id = self._communicator.register_client() + + def deregister(self): + """ + Deregister client in the remote server. + + :return: None + """ + self._communicator.remove_client(self.client_id) + self.client_id = "0x_client_id" + + def register_ml_model(self): + """ + Register ML model in the remote server. + + :return: None + """ + self.ml_model_id = self._communicator.register_model(self.client_id) + + def deregister_ml_model(self): + """ + Deregister ML model in the remote server. + + :return: None + """ + self._communicator.remove_model(self.client_id, self.ml_model_id) + self.ml_model_id = "0x_ml_model_id" + + def get_ml_models(self): + """ + Get all registered ml models in the remote server for this client. + + :return: list of ml model ids + """ + return self._communicator.get_ml_models(self.client_id) + + def grant_access(self, allowee_id): + """ + Grant access to one of this client's models to another client. + + :param allowee_id: The client ID to grant access to + :return: True if successful, False otherwise + """ + return self._communicator.grant_access( + self.client_id, + allowee_id, + self.ml_model_id + ) + + def revoke_access(self, revokee_id): + """ + Revoke access previously granted to another client. + + :param revokee_id: The client ID to revoke access from + :return: True if successful, False otherwise + """ + return self._communicator.revoke_access( + self.client_id, + revokee_id, + self.ml_model_id + ) + + def get_allowance(self): + """ + Get a dictionary of all clients who have access to this client's models. + + :return: Dict of allowee_id -> list of model_ids + """ + return self._communicator.get_allowance(self.client_id) + + def get_allowed_models(self, allower_id): + """ + Get a list of models you are allowed to access from another client. + + :param allower_id: The client ID who owns the models + :return: list of allowed model IDs + """ + return self._communicator.get_allowed_models(allower_id, self.client_id) + def __getattr__(self, attribute): """ Overwrite the __getattr__ default function to extract requested. @@ -133,13 +208,13 @@ def __getattr__(self, attribute): elif self._mode == PymiloClient.Mode.DELEGATE: gdst = GeneralDataStructureTransporter() response = self._communicator.attribute_type( - self.encrypt_compress( - { - "client_id": self.client_id, - "ml_model_id": self.ml_model_id, - "attribute": attribute, - } - ) + self.client_id, + self.ml_model_id, + self.encrypt_compress({ + "attribute": attribute, + "client_id": self.client_id, + "ml_model_id": self.ml_model_id, + }) ) if response["attribute type"] == "field": return gdst.deserialize(response, "attribute value", None) @@ -155,9 +230,9 @@ def relayer(*args, **kwargs): payload["args"] = gdst.serialize(payload, "args", None) payload["kwargs"] = gdst.serialize(payload, "kwargs", None) result = self._communicator.attribute_call( - self.encrypt_compress( - payload - ) + self.client_id, + self.ml_model_id, + self.encrypt_compress(payload) ) return gdst.deserialize(result, "payload", None) return relayer diff --git a/pymilo/streaming/pymilo_server.py b/pymilo/streaming/pymilo_server.py index 568d89f7..130cd0a1 100644 --- a/pymilo/streaming/pymilo_server.py +++ b/pymilo/streaming/pymilo_server.py @@ -13,7 +13,6 @@ class PymiloServer: def __init__( self, - model=None, port=8000, host="127.0.0.1", compressor=Compression.NULL, @@ -34,20 +33,22 @@ def __init__( :type communication_protocol: pymilo.streaming.communicator.CommunicationProtocol :return: an instance of the PymiloServer class """ - self._model = model self._compressor = compressor.value self._encryptor = DummyEncryptor() + # In-memory storage (replace with a database for persistence) self.communicator = communication_protocol.value["SERVER"](ps=self, host=host, port=port) + self._clients = {} + self._allowance = {} - def export_model(self): + def export_model(self, client_id, ml_model_id): """ Export the ML model to string json dump using PyMilo Export class. :return: str """ - return Export(self._model).to_json() + return Export(self._clients[client_id][ml_model_id]).to_json() - def update_model(self, serialized_model): + def update_model(self, client_id, ml_model_id, serialized_model): """ Update the PyMilo Server's ML model. @@ -55,7 +56,7 @@ def update_model(self, serialized_model): :type serialized_model: str :return: None """ - self._model = Import(file_adr=None, json_dump=serialized_model).to_model() + self._clients[client_id][ml_model_id] = Import(file_adr=None, json_dump=serialized_model).to_model() def execute_model(self, request): """ @@ -67,7 +68,10 @@ def execute_model(self, request): """ gdst = GeneralDataStructureTransporter() attribute = request["attribute"] if isinstance(request, dict) else request.attribute - retrieved_attribute = getattr(self._model, attribute, None) + _client_id = request["client_id"] if isinstance(request, dict) else request.client_id + _ml_model_id = request["ml_model_id"] if isinstance(request, dict) else request.ml_model_id + _ml_model = self._clients[_client_id][_ml_model_id] + retrieved_attribute = getattr(_ml_model, attribute, None) if retrieved_attribute is None: raise Exception(PYMILO_SERVER_NON_EXISTENT_ATTRIBUTE) arguments = { @@ -77,8 +81,8 @@ def execute_model(self, request): args = gdst.deserialize(arguments, 'args', None) kwargs = gdst.deserialize(arguments, 'kwargs', None) output = retrieved_attribute(*args, **kwargs) - if isinstance(output, type(self._model)): - self._model = output + if isinstance(output, type(_ml_model)): + self._clients[_client_id][_ml_model_id] = output return None return gdst.serialize({'output': output}, 'output', None) @@ -91,8 +95,217 @@ def is_callable_attribute(self, request): :return: True if it is callable False otherwise """ attribute = request["attribute"] if isinstance(request, dict) else request.attribute - retrieved_attribute = getattr(self._model, attribute, None) + _client_id = request["client_id"] if isinstance(request, dict) else request.client_id + _ml_model_id = request["ml_model_id"] if isinstance(request, dict) else request.ml_model_id + _ml_model = self._clients[_client_id][_ml_model_id] + retrieved_attribute = getattr(_ml_model, attribute, None) if callable(retrieved_attribute): return True, None else: return False, GeneralDataStructureTransporter().serialize({'output': retrieved_attribute}, 'output', None) + + def _validate_id(self, client_id, ml_model_id): + """ + Validate the provided client ID and machine learning model ID. + + :param client_id: The ID of the client to validate. + :type client_id: str + :param ml_model_id: The ID of the machine learning model to validate. + :type ml_model_id: str + :return: A tuple containing a boolean indicating validity and an error message if invalid. + """ + if client_id not in self._clients: + return False, "The given client_id is invalid." + if ml_model_id not in self._clients[client_id]: + return False, "The given client_id is valid but requested ml_model_id is invalid." + return True, None + + def init_client(self, client_id): + """ + Initialize a new client with the given client ID. + + :param client_id: The ID of the client to initialize. + :type client_id: str + :return: A tuple containing a boolean indicating success and an error message if the client already exists. + """ + if client_id in self._clients: + return False, f"The client with client_id: {client_id} already exists." + self._clients[client_id] = {} + self._allowance[client_id] = {} + return True, None + + def remove_client(self, client_id): + """ + Remove an existing client by the given client ID. + + :param client_id: The ID of the client to remove. + :type client_id: str + :return: A tuple containing a boolean indicating success and an error message if the client does not exist. + """ + if client_id not in self._clients: + return False, f"The client with client_id: {client_id} doesn't exist." + del self._clients[client_id] + del self._allowance[client_id] + return True, None + + def grant_access(self, allower_id, allowee_id, allowed_model_id): + """ + Allow a client to access a specific machine learning model of another client. + + :param allower_id: The ID of the client granting access. + :type allower_id: str + :param allowee_id: The ID of the client being granted access. + :type allowee_id: str + :param allowed_model_id: The ID of the machine learning model to be accessed. + :type allowed_model_id: str + :return: A tuple containing a boolean indicating success and an error message if the operation fails. + """ + if allower_id not in self._clients: + return False, f"The allower client with client_id: {allower_id} doesn't exist." + if allowee_id not in self._clients: + return False, f"The allowee client with client_id: {allowee_id} doesn't exist." + if allowed_model_id not in self._clients[allower_id]: + return False, f"The model with ml_model_id: {allowed_model_id} doesn't exist for the allower client with client_id: {allower_id}." + + if allowed_model_id in self._allowance.get(allower_id).get(allowee_id, []): + return False, f"The model with ml_model_id: {allowed_model_id} is already allowed for the allowee client with client_id: {allowee_id} by the allower client with client_id: {allower_id}." + + if allowee_id not in self._allowance[allower_id]: + self._allowance[allower_id][allowee_id] = [allowed_model_id] + return True, None + + self._allowance[allower_id][allowee_id].append(allowed_model_id) + return True, None + + def revoke_access(self, allower_id, allowee_id, allowed_model_id=None): + """ + Revoke a client's access to a specific machine learning model of another client. + + :param allower_id: The ID of the client revoking access. + :type allower_id: str + :param allowee_id: The ID of the client whose access is being revoked. + :type allowee_id: str + :param allowed_model_id: The ID of the machine learning model whose access is being revoked. + :type allowed_model_id: str + :return: A tuple containing a boolean indicating success and an error message if the operation fails. + """ + if allower_id not in self._clients: + return False, f"The allower client with client_id: {allower_id} doesn't exist." + if allowee_id not in self._clients: + return False, f"The allowee client with client_id: {allowee_id} doesn't exist." + + if allowed_model_id is None: + if allowee_id in self._allowance[allower_id]: + del self._allowance[allower_id][allowee_id] + return True, None + + if allowed_model_id not in self._clients[allower_id]: + return False, f"The model with ml_model_id: {allowed_model_id} doesn't exist for the allower client with client_id: {allower_id}." + + if allowee_id not in self._allowance[allower_id]: + return False, f"The allowee client with client_id: {allowee_id} doesn't have any access granted by the allower client with client_id: {allower_id}." + + if allowed_model_id not in self._allowance[allower_id][allowee_id]: + return False, f"The model with ml_model_id: {allowed_model_id} is not allowed for the allowee client with client_id: {allowee_id} by the allower client with client_id: {allower_id}." + + self._allowance[allower_id][allowee_id].remove(allowed_model_id) + return True, None + + def get_allowed_models(self, allower_id, allowee_id): + """ + Retrieve a list of machine learning model IDs that a client is allowed to access from another client. + + :param allower_id: The ID of the client who granted access. + :type allower_id: str + :param allowee_id: The ID of the client who has been granted access. + :type allowee_id: str + :return: A list of allowed machine learning model IDs or an error message if access is not granted. + """ + if allower_id not in self._clients: + return None, f"The allower client with client_id: {allower_id} doesn't exist." + if allowee_id not in self._clients: + return None, f"The allowee client with client_id: {allowee_id} doesn't exist." + + return self._allowance.get(allower_id).get(allowee_id, []), None + + def get_clients_allowance(self, client_id): + """ + Retrieve the allowance dictionary for a given client. + + :param client_id: The ID of the client whose allowance is being retrieved. + :type client_id: str + :return: A dictionary containing the allowance information for the client. + """ + if client_id not in self._allowance: + return None, f"The client with client_id: {client_id} doesn't exist." + return self._allowance[client_id], None + + def get_clients(self): + """ + Retrieve a list of all registered client IDs. + + :return: A list of client IDs. + """ + return [id for id in self._clients.keys()] + + def init_ml_model(self, client_id, ml_model_id): + """ + Initialize a new machine learning model for a given client. + + :param client_id: The ID of the client to associate with the model. + :type client_id: str + :param ml_model_id: The ID of the machine learning model to initialize. + :type ml_model_id: str + :return: A tuple containing a boolean indicating success and an error message if the model already exists or the client ID is invalid. + """ + if client_id not in self._clients: + return False, "The given client_id is invalid." + + if ml_model_id in self._clients[client_id]: + return False, f"The given ml_model_id: {ml_model_id} already exists within ml models of the client with client_id of {client_id}." + + self._clients[client_id][ml_model_id] = {} + return True, None + + def set_ml_model(self, client_id, ml_model_id, ml_model): + """ + Set or update the machine learning model for a given client. + + :param client_id: The ID of the client. + :type client_id: str + :param ml_model_id: The ID of the machine learning model. + :type ml_model_id: str + :param ml_model: The machine learning model object to be set. + :type ml_model: obj + :return: None + """ + self._clients[client_id][ml_model_id] = ml_model + + def remove_ml_model(self, client_id, ml_model_id): + """ + Remove an existing machine learning model for a given client. + + :param client_id: The ID of the client. + :type client_id: str + :param ml_model_id: The ID of the machine learning model to remove. + :type ml_model_id: str + :return: A tuple containing a boolean indicating success and an error message if the client ID or model ID is invalid. + """ + if client_id not in self._clients: + return False, "The given client_id is invalid." + + if ml_model_id not in self._clients[client_id]: + return False, f"The client with client_id: {client_id} doesn't have any model with ml_model_id of {ml_model_id}." + + del self._clients[client_id][ml_model_id] + return True, None + + def get_ml_models(self, client_id): + """ + Retrieve a list of all machine learning model IDs associated with a given client. + + :param client_id: The ID of the client. + :type client_id: str + :return: A list of machine learning model IDs. + """ + return [id for id in self._clients[client_id].keys()] diff --git a/tests/test_ml_streaming/run_server.py b/tests/test_ml_streaming/run_server.py index f2739e83..17a1502d 100644 --- a/tests/test_ml_streaming/run_server.py +++ b/tests/test_ml_streaming/run_server.py @@ -32,12 +32,17 @@ def main(): x_train, y_train, _, _ = prepare_simple_regression_datasets() linear_regression = LinearRegression() linear_regression.fit(x_train, y_train) - communicator = PymiloServer( - model=linear_regression, + ps = PymiloServer( port=9000, compressor=Compression[args.compression], communication_protocol= CommunicationProtocol[args.protocol], - ).communicator + ) + sample_client_id = "0x_demo_client_id" + sample_ml_model_id = "0x_demo_ml_model_id" + ps.init_client(sample_client_id) + ps.init_ml_model(sample_client_id, sample_ml_model_id) + ps.set_ml_model(sample_client_id, sample_ml_model_id, linear_regression) + communicator = ps.communicator else: communicator = PymiloServer( port=8000, diff --git a/tests/test_ml_streaming/scenarios/scenario1.py b/tests/test_ml_streaming/scenarios/scenario1.py index 799366d7..abddb28e 100644 --- a/tests/test_ml_streaming/scenarios/scenario1.py +++ b/tests/test_ml_streaming/scenarios/scenario1.py @@ -28,16 +28,20 @@ def scenario1(compression_method, communication_protocol): communication_protocol=CommunicationProtocol[communication_protocol], ) - # 3. + # 3. get client id + get ml model id [from remote server] + client.register() + client.register_ml_model() + + # 4. result = client.predict(x_test) mse_before = mean_squared_error(y_test, result) - # 4. - client.upload() # 5. + client.upload() + # 6. client.download() - # 6. + # 7. result = client.predict(x_test) mse_after = mean_squared_error(y_test, result) return np.abs(mse_after-mse_before) diff --git a/tests/test_ml_streaming/scenarios/scenario2.py b/tests/test_ml_streaming/scenarios/scenario2.py index e791f43c..2cf83148 100644 --- a/tests/test_ml_streaming/scenarios/scenario2.py +++ b/tests/test_ml_streaming/scenarios/scenario2.py @@ -25,22 +25,26 @@ def scenario2(compression_method, communication_protocol): communication_protocol=CommunicationProtocol[communication_protocol], ) - # 2. - client.upload() + # 2. get client id + get ml model id [from remote server] + client.register() + client.register_ml_model() # 3. + client.upload() + + # 4. client.toggle_mode(PymiloClient.Mode.DELEGATE) client.fit(x_train, y_train) remote_field = client.coef_ - # 4. + # 5. result = client.predict(x_test) mse_server = mean_squared_error(y_test, result) - # 5. + # 6. client.download() - # 6. + # 7. client.toggle_mode(mode=PymiloClient.Mode.LOCAL) local_field = client.coef_ result = client.predict(x_test) diff --git a/tests/test_ml_streaming/scenarios/scenario3.py b/tests/test_ml_streaming/scenarios/scenario3.py index aded5937..cf555e43 100644 --- a/tests/test_ml_streaming/scenarios/scenario3.py +++ b/tests/test_ml_streaming/scenarios/scenario3.py @@ -19,6 +19,9 @@ def scenario3(compression_method, communication_protocol): server_url="127.0.0.1:9000", communication_protocol=CommunicationProtocol[communication_protocol], ) + client.client_id = "0x_demo_client_id" + client.ml_model_id = "0x_demo_ml_model_id" + client.toggle_mode(PymiloClient.Mode.DELEGATE) result = client.predict(x_test) mse_server = mean_squared_error(y_test, result) diff --git a/tests/test_ml_streaming/test_streaming.py b/tests/test_ml_streaming/test_streaming.py index 9a521d1f..075774fb 100644 --- a/tests/test_ml_streaming/test_streaming.py +++ b/tests/test_ml_streaming/test_streaming.py @@ -46,7 +46,7 @@ def prepare_bare_server(request): @pytest.fixture( scope="session", - params=["REST", "WEBSOCKET"]) + params=["REST",]) #"WEBSOCKET"]) def prepare_ml_server(request): communication_protocol = request.param compression_method = "ZLIB"