From 173af55a8f0fbfeb78e88b913867f2ad1fa0ad78 Mon Sep 17 00:00:00 2001 From: Rohit Katlaa <42460632+rohitkatlaa@users.noreply.github.com> Date: Sat, 22 May 2021 23:41:54 +0530 Subject: [PATCH] Migrate generate_signature method (#40) * migrate generate_signature method * removing additional assert * adding utf-8 to encode * adding encoding to encode * adding encoding * change vm_id to str --- core/domain/remote_access_services.py | 17 +++++++++-------- core/domain/remote_access_services_test.py | 5 ++++- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/core/domain/remote_access_services.py b/core/domain/remote_access_services.py index 2b3039b..b4ca00a 100644 --- a/core/domain/remote_access_services.py +++ b/core/domain/remote_access_services.py @@ -68,14 +68,15 @@ def generate_signature(message, vm_id): """Generates digital signature for given message combined with vm_id. Args: - message: str. Message string. + message: bytes. Message string. vm_id: str. ID of the VM that trained the job. Returns: str. The digital signature generated from request data. """ - msg = '%s|%s' % (base64.b64encode(message), vm_id) - key = _get_shared_secret() + encoded_vm_id = vm_id.encode(encoding='utf-8') + msg = b'%s|%s' % (base64.b64encode(message), encoded_vm_id) + key = _get_shared_secret().encode(encoding='utf-8') # Generate signature and return it. return hmac.new(key, msg, digestmod=hashlib.sha256).hexdigest() @@ -93,8 +94,8 @@ def fetch_next_job_request(): _get_url(), _get_port(), vmconf.FETCH_NEXT_JOB_REQUEST_HANDLER) payload = { - 'vm_id': _get_vm_id(), - 'message': _get_vm_id(), + 'vm_id': _get_vm_id().encode(encoding='utf-8'), + 'message': _get_vm_id().encode(encoding='utf-8'), } signature = generate_signature(payload['message'], payload['vm_id']) payload['signature'] = signature @@ -119,9 +120,9 @@ def store_trained_classifier_model(job_result): job_result.validate() payload = training_job_response_payload_pb2.TrainingJobResponsePayload() payload.job_result.CopyFrom(job_result.to_proto()) - payload.vm_id = _get_vm_id() - signature = generate_signature( - payload.job_result.SerializeToString(), payload.vm_id) + payload.vm_id = _get_vm_id().encode(encoding='utf-8') + message = payload.job_result.SerializeToString().encode(encoding='utf-8') + signature = generate_signature(message, payload.vm_id) payload.signature = signature data = payload.SerializeToString() diff --git a/core/domain/remote_access_services_test.py b/core/domain/remote_access_services_test.py index 75b7f8a..6176d41 100644 --- a/core/domain/remote_access_services_test.py +++ b/core/domain/remote_access_services_test.py @@ -29,8 +29,11 @@ class RemoteAccessServicesTests(test_utils.GenericTestBase): def test_that_generate_signature_works_correctly(self): """Test that generate signature function is working as expected.""" with self.swap(vmconf, 'DEFAULT_VM_SHARED_SECRET', '1a2b3c4e'): + message = 'vm_default' + vm_id = 'vm_default' signature = remote_access_services.generate_signature( - 'vm_default', vm_id='vm_default') + message.encode(encoding='utf-8'), + vm_id.encode(encoding='utf-8')) expected_signature = ( '740ed25befc87674a82844db7769436edb7d21c29d1c9cc87d7a1f3fdefe3610')