Skip to content
This repository has been archived by the owner on Dec 2, 2024. It is now read-only.

Commit

Permalink
Migrate generate_signature method (#40)
Browse files Browse the repository at this point in the history
* migrate generate_signature method

* removing additional assert

* adding utf-8 to encode

* adding encoding to encode

* adding encoding

* change vm_id to str
  • Loading branch information
rohitkatlaa authored May 22, 2021
1 parent 26ea71e commit 173af55
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
17 changes: 9 additions & 8 deletions core/domain/remote_access_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,15 @@ def generate_signature(message, vm_id):
"""Generates digital signature for given message combined with vm_id.
Args:
message: str. Message string.
message: bytes. Message string.
vm_id: str. ID of the VM that trained the job.
Returns:
str. The digital signature generated from request data.
"""
msg = '%s|%s' % (base64.b64encode(message), vm_id)
key = _get_shared_secret()
encoded_vm_id = vm_id.encode(encoding='utf-8')
msg = b'%s|%s' % (base64.b64encode(message), encoded_vm_id)
key = _get_shared_secret().encode(encoding='utf-8')

# Generate signature and return it.
return hmac.new(key, msg, digestmod=hashlib.sha256).hexdigest()
Expand All @@ -93,8 +94,8 @@ def fetch_next_job_request():
_get_url(), _get_port(), vmconf.FETCH_NEXT_JOB_REQUEST_HANDLER)

payload = {
'vm_id': _get_vm_id(),
'message': _get_vm_id(),
'vm_id': _get_vm_id().encode(encoding='utf-8'),
'message': _get_vm_id().encode(encoding='utf-8'),
}
signature = generate_signature(payload['message'], payload['vm_id'])
payload['signature'] = signature
Expand All @@ -119,9 +120,9 @@ def store_trained_classifier_model(job_result):
job_result.validate()
payload = training_job_response_payload_pb2.TrainingJobResponsePayload()
payload.job_result.CopyFrom(job_result.to_proto())
payload.vm_id = _get_vm_id()
signature = generate_signature(
payload.job_result.SerializeToString(), payload.vm_id)
payload.vm_id = _get_vm_id().encode(encoding='utf-8')
message = payload.job_result.SerializeToString().encode(encoding='utf-8')
signature = generate_signature(message, payload.vm_id)
payload.signature = signature

data = payload.SerializeToString()
Expand Down
5 changes: 4 additions & 1 deletion core/domain/remote_access_services_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@ class RemoteAccessServicesTests(test_utils.GenericTestBase):
def test_that_generate_signature_works_correctly(self):
"""Test that generate signature function is working as expected."""
with self.swap(vmconf, 'DEFAULT_VM_SHARED_SECRET', '1a2b3c4e'):
message = 'vm_default'
vm_id = 'vm_default'
signature = remote_access_services.generate_signature(
'vm_default', vm_id='vm_default')
message.encode(encoding='utf-8'),
vm_id.encode(encoding='utf-8'))

expected_signature = (
'740ed25befc87674a82844db7769436edb7d21c29d1c9cc87d7a1f3fdefe3610')
Expand Down

0 comments on commit 173af55

Please sign in to comment.