Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do 12 #18

Merged
merged 6 commits into from
Oct 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions data_owner/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def get_col_spec(self, **kw):

def bind_processor(self, dialect):
def process(value):
x = value.X.tolist() if value.X is not None else None
y = value.y.tolist() if value.y is not None else None
weights = value.weights if type(value.weights) == list else value.weights.tolist() if value and value.weights.any() else None
model_type = value.type
x = value.X.tolist() if value and value.X is not None else None
y = value.y.tolist() if value and value.y is not None else None
weights = value.weights if value and type(value.weights) == list else value.weights.tolist() if value and value.weights.any() else None
model_type = value.type if value else None
return json.dumps({
'x': x, 'y': y, 'weights': weights, 'type': model_type
})
Expand All @@ -56,7 +56,7 @@ class BaseModel(DbEntity):
request_data = Column(JSON)
mse = Column(Float)
initial_mse = Column(Float)
status = Column(String(50), default=TrainingStatus.INITIATED.name)
status = Column(String(50), default=TrainingStatus.WAITING.name)
improvement = Column(Float)
name = Column(String(100))
iterations = Column(Integer)
Expand All @@ -68,14 +68,15 @@ class BaseModel(DbEntity):
user = relationship("User", back_populates="models")
User.models = relationship("Model", back_populates="user")

def __init__(self, model_id, model_type, data, name="default"):
def __init__(self, model_id, model_type, reqs, name="default"):
self.id = model_id
self.model_type = model_type
_model = ModelFactory.get_model(model_type)(X=data[0], y=data[1])
_model = ModelFactory.get_model(self.model_type)(X=None, y=None, requirements=reqs)
self.model = _model
self.model.set_weights(_model.weights.tolist())
self.model.type = model_type
self.status = TrainingStatus.INITIATED.name
self.model.type = self.model_type
self.requirements = reqs
self.status = TrainingStatus.WAITING.name
self.iterations = 0
self.improvement = 0.0
self.name = name
Expand All @@ -84,6 +85,12 @@ def __init__(self, model_id, model_type, data, name="default"):
self.initial_mse = 0.0
self.mse_history = []

def link_to_dataset(self, data):
    """Attach a dataset to this model and mark it as ready for training.

    Rebuilds the wrapped model through ``ModelFactory`` using the given
    data, switches the status to ``INITIATED`` and persists the change
    with ``update()``.

    :param data: indexable pair; ``data[0]`` is passed as ``X`` and
        ``data[1]`` as ``y``.
    """
    linked_model = ModelFactory.get_model(self.model_type)(X=data[0], y=data[1])
    # __init__ tags the wrapped model with its type and the column
    # serializer reads ``value.type``; tag the replacement the same way.
    linked_model.type = self.model_type
    self.model = linked_model
    self.status = TrainingStatus.INITIATED.name
    self.update()

def set_weights(self, weights):
if type(weights) == list:
weights = np.asarray(weights)
Expand Down Expand Up @@ -125,6 +132,15 @@ def add_mse(self, mse):
def find_all(cls):
return DbEntity.find(BaseModel)

@classmethod
def find_all_by(cls, filters):
    """Return the ``BaseModel`` entities matching ``filters``.

    :param filters: forwarded unchanged to ``DbEntity.find``; the sibling
        ``find_by_status`` passes a ``{column_name: value}`` dict.
    :return: whatever ``DbEntity.find`` returns for ``BaseModel``.
    """
    # The lookup is deliberately pinned to BaseModel (same as find_all),
    # so ``cls`` is not used to select the entity.
    return DbEntity.find(BaseModel, filters)

@classmethod
def find_by_status(cls, status):
    """Return the ``BaseModel`` entities whose ``status`` equals ``status``.

    :param status: value compared against the ``status`` column. The
        column is written with ``TrainingStatus.<X>.name`` elsewhere in
        this class, so callers presumably pass that name string.
    :return: whatever ``DbEntity.find`` returns for ``BaseModel``.
    """
    filters = {'status': status}
    return DbEntity.find(BaseModel, filters)


class Model(BaseModel):

Expand Down
2 changes: 1 addition & 1 deletion data_owner/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class User(DbEntity):
name = Column(String(100))
token = Column(String(100))

def __init__(self, external_id, email, name, token):
def __init__(self, email, name, token, external_id=None):
self.external_id = external_id
self.email = email
self.name = name
Expand Down
9 changes: 6 additions & 3 deletions data_owner/resources/models_resource.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from flask_restplus import Resource, Namespace, fields

from flask_restplus import Resource, Namespace, fields, reqparse
from data_owner.models.model import TrainingStatus
from data_owner.services.model_service import ModelService

api = Namespace('models', description='Model related operations')
Expand Down Expand Up @@ -62,7 +62,10 @@ class ModelsResources(Resource):

@api.marshal_list_with(model_reduced_response)
def get(self):
return ModelService.get_all()
parser = reqparse.RequestParser()
parser.add_argument('status', type=str, required=False, help='Status cannot be converted', location='args')
args = parser.parse_args()
return ModelService.get_all(args)


@api.route('/<model_id>', endpoint='model_resources_ep')
Expand Down
26 changes: 19 additions & 7 deletions data_owner/resources/trainings_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
link = api.model(name='Link', model={
'model_id': fields.String(required=True, description='The model identifier'),
'data_owner_id': fields.String(required=True, description='The model identifier'),
'has_dataset': fields.Boolean(required=True, description='The model weights')
'linked': fields.Boolean(required=True, description='The model weights')
})

metric = api.model(name='Metric', model={
Expand Down Expand Up @@ -82,22 +82,23 @@
@api.route('', endpoint='training_resources_ep')
class TrainingResources(Resource):

@api.doc('Initialize new model with existing dataset')
@api.doc('Save new training request')
@api.expect(data_metadata)
@api.marshal_with(link, code=201)
def post(self):
data = request.get_json()
training_id = data['model_id']
reqs = data['requirements']
# TODO: For now i'm creating the training and linking the dataset to the training all at once
# TODO: and doing it in the back, but a future change will be to do those in separate API calls.
model_id, do_id, has_dataset = data_owner.link_model_to_dataset(training_id, data['model_type'], reqs)
return {'model_id': model_id, 'data_owner_id': do_id, 'has_dataset': has_dataset}
data_owner.init_model(training_id, data['model_type'], reqs)


@api.route('/<model_id>', endpoint='training_resource_ep')
class TrainingResource(Resource):

@api.doc('Get if data owner is training the model')
@api.marshal_with(link, code=200)
def get(self, model_id):
    """Report whether this data owner has linked a dataset to the model.

    :param model_id: identifier taken from the URL path.
    :return: dict with ``model_id``, ``data_owner_id`` and ``linked``.
    """
    # No status code is returned, so the response is a plain 200; the
    # documented code must say the same (201 would advertise "Created").
    return {
        'model_id': model_id,
        'data_owner_id': data_owner.get_id(),
        'linked': data_owner.model_is_linked(model_id),
    }

@api.doc('Get gradient updated')
@api.marshal_with(update, code=200)
def post(self, model_id):
Expand Down Expand Up @@ -136,3 +137,14 @@ def put(self, model_id):
data = request.get_json()
data_owner.update_mse(model_id, data['mse'])
return 200


# NOTE(review): this module appears to already have a resource for the
# metrics endpoint; confirm that reusing the name ``MetricsResource`` here
# is intentional, since the later definition shadows the earlier one.
@api.route('/<model_id>/accept', endpoint='accept_training_resource_ep')
class MetricsResource(Resource):
    """Resource used to accept a pending training for a given model."""

    @api.doc('Initialize new model with existing dataset')
    @api.marshal_with(link, code=200)
    def put(self, model_id):
        """Link the model to a local dataset and report the outcome.

        The request body is not read: everything needed is in the URL.

        :param model_id: identifier taken from the URL path.
        :return: dict with ``model_id``, ``data_owner_id`` and ``linked``.
        """
        model_id, do_id, linked = data_owner.link_model_to_dataset(model_id)
        # Fill every field of the ``link`` schema instead of only ``linked``.
        return {'model_id': model_id, 'data_owner_id': do_id, 'linked': linked}
4 changes: 2 additions & 2 deletions data_owner/resources/users_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
'external_id': fields.String(required=True, description='The user identifier'),
'name': fields.String(required=True, description='The user name'),
'email': fields.String(required=True, description='The user email'),
'token': fields.String(required=True, description='The user token'),
'models': fields.Nested(model, required=True, description='The user models')
'token': fields.String(required=True, description='The user token')#,
#'models': fields.Nested(model, required=True, description='The user models')
})


Expand Down
20 changes: 15 additions & 5 deletions data_owner/services/data_owner_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from commons.utils.singleton import Singleton
from commons.operations_utils.functions import deserialize, serialize
from data_owner.domain.data_owner import DataOwner
from data_owner.models.model import Model
from data_owner.models.model import Model, TrainingStatus
from data_owner.services.datasets_service import DatasetsService
from data_owner.services.federated_aggregator_connector import FederatedAggregatorConnector

Expand Down Expand Up @@ -110,11 +110,21 @@ def update_mse(self, model_id, mse):
model_orm.update()
logging.info("Calculated mse: {}".format(mse))

def link_model_to_dataset(self, model_id, model_type, reqs):
def link_model_to_dataset(self, model_id):
has_dataset = False
dataset = DatasetsService().get_dataset_for_training(reqs)
model = Model.get(model_id)
dataset = DatasetsService().get_dataset_for_training(model.requirements)
if not dataset:
return model_id, self.get_id(), has_dataset
model = Model(model_id, model_type, dataset)
model.save()
model.link_to_dataset(dataset)
model.update()
self.federated_aggregator_connector.accept_model_training(self.get_id(), model_id)
return model_id, self.get_id(), not has_dataset

def model_is_linked(self, model_id):
    """Tell whether the model has left the ``WAITING`` state.

    :param model_id: identifier of the model to look up.
    :return: ``True`` when the stored status differs from ``WAITING``;
        ``False`` when it is still waiting or no model was found.
    """
    model = Model.get(model_id)
    # Defensive: an unknown model cannot be linked (assumes ``Model.get``
    # may return None for a missing id; confirm against DbEntity).
    if model is None:
        return False
    # The status column stores the enum member's *name* (a string), so the
    # comparison must use ``.name``; comparing with the member itself is
    # never equal and would report every model as linked.
    return model.status != TrainingStatus.WAITING.name

def init_model(self, model_id, model_type, reqs):
    """Create and persist a new model for a training request.

    :param model_id: identifier of the model.
    :param model_type: model type handed to the ``Model`` constructor.
    :param reqs: requirements handed to the ``Model`` constructor.
    :return: tuple ``(model_id, data_owner_id)``.
    """
    Model(model_id, model_type, reqs).save()
    owner_id = self.get_id()
    return model_id, owner_id
13 changes: 13 additions & 0 deletions data_owner/services/federated_aggregator_connector.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import requests
import logging
from commons.utils.async_thread_pool_executor import AsyncThreadPoolExecutor


class FederatedAggregatorConnector:

def __init__(self, config):
self.federated_aggregator_host = config['FEDERATED_AGGREGATOR_HOST']
self.async_thread_pool = AsyncThreadPoolExecutor()

def register(self, client_id):
"""
Expand All @@ -19,6 +21,17 @@ def register(self, client_id):
response.raise_for_status()
return response.status_code == requests.codes.ok

def accept_model_training(self, client_id, model_id):
    """Notify the federated aggregator that this client accepts a training.

    Builds the accept URL for ``model_id`` and hands the request to the
    thread pool, which runs ``send_accept_request`` with it.

    :param client_id: identifier of this data owner.
    :param model_id: identifier of the accepted model.
    """
    # format() instead of ``+`` so a non-string id does not raise TypeError.
    accept_url = "{}/model/{}/accept".format(self.federated_aggregator_host, model_id)
    logging.info("Client %s accepts training of model %s at %s", client_id, model_id, accept_url)
    args = [{'url': accept_url, 'payload': {'data_owner_id': client_id, 'model_id': model_id}}]
    self.async_thread_pool.run(executable=self.send_accept_request, args=args)

def send_accept_request(self, args):
    """POST an accept notification and report whether it succeeded.

    :param args: mapping with ``'url'`` (target) and ``'payload'`` (sent
        as the JSON body).
    :return: ``True`` when the response status equals ``requests.codes.ok``.
    :raises requests.HTTPError: via ``raise_for_status`` on an error status.
    """
    target, body = args['url'], args['payload']
    reply = requests.post(target, json=body)
    reply.raise_for_status()
    succeeded = reply.status_code == requests.codes.ok
    return succeeded

def send_prediction(self, prediction):
server_register_url = self.federated_aggregator_host + "/prediction"
logging.info("Send prediction")
Expand Down
6 changes: 3 additions & 3 deletions data_owner/services/model_service.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from data_owner.models.model import Model


class ModelService:

@classmethod
def get_all(cls):
return Model.find_all()
def get_all(cls, args=None):
    """Return the models matching the given query arguments.

    :param args: optional mapping of filter name to value; entries whose
        value is ``None`` are ignored. Omitting it (or passing ``None``)
        applies no filter, which keeps the earlier zero-argument call
        form working.
    :return: the result of ``Model.find_all_by`` for the remaining filters.
    """
    filters = {key: value for key, value in (args or {}).items() if value is not None}
    return Model.find_all_by(filters)

@classmethod
def get(cls, model_id):
Expand Down
4 changes: 2 additions & 2 deletions data_owner/services/user_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get(user_id):

@staticmethod
def update(user_id, user_data):
return User().partial_update(user_id, user_data)
return UserService.get(user_id).partial_update(user_id, user_data)

def delete(self, user_id):
user = self.get(user_id)
Expand All @@ -51,4 +51,4 @@ def login(data):
email=user_info["email"],
token=token)
user.save()
return user
return User.find_one_by_external_id(user_external_id)