This repository has been archived by the owner on Dec 2, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Milestone 1: Updating Oppia-ML pipeline (#32)
* Milestone 1: Updating Oppia-ML pipeline * Disable lint (no-name-in-module) for importing generated proto files * Refactor protobuf implementation and address review comments * Fix lint tests * Addressed review comments * Fix __init__.py inclusion * Address review comments * Address Review Comments * Correct doc string * Nit changes
- Loading branch information
1 parent
425bd4a
commit 27fd0cf
Showing
20 changed files
with
387 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,5 @@ | |
*.swp | ||
*.swo | ||
third_party/* | ||
core/domain/proto/*.py | ||
!core/domain/proto/__init__.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
// coding: utf-8 | ||
// | ||
// Copyright 2020 The Oppia Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS-IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
syntax = "proto3"; | ||
|
||
message TextClassifierFrozenModel { | ||
// Parameters of a trained text classifier model that are required for | ||
// inference (presumably JSON-encoded, per the field name -- TODO confirm). | ||
string model_json = 1; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
// coding: utf-8 | ||
// | ||
// Copyright 2020 The Oppia Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS-IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
syntax = "proto3"; | ||
|
||
import "core/domain/proto/text_classifier.proto"; | ||
|
||
// The training job response payload contains the result of a training job | ||
// along with metadata: vm_id (identifies which VM executed the job) and a | ||
// signature of the payload for security purposes. | ||
message TrainingJobResponsePayload { | ||
// Result of the training job: the ID of the job and its trained (frozen) | ||
// model. | ||
message JobResult { | ||
// Id of the training job whose data is being stored. | ||
string job_id = 1; | ||
|
||
// Each classifier algorithm's frozen-model proto message must be listed as | ||
// a member of the classifier_frozen_model oneof. | ||
oneof classifier_frozen_model { | ||
TextClassifierFrozenModel text_classifier = 2; | ||
} | ||
} | ||
JobResult job_result = 1; | ||
|
||
// Id of the VM instance that trained the job. | ||
string vm_id = 2; | ||
|
||
// Signature of the job data for authenticated communication. | ||
string signature = 3; | ||
} |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# coding: utf-8 | ||
# | ||
# Copyright 2020 The Oppia Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS-IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Functions and classdefs related to protobuf files used in Oppia-ml""" | ||
|
||
from core.classifiers import algorithm_registry | ||
from core.domain.proto import training_job_response_payload_pb2 | ||
|
||
class TrainingJobResult(object): | ||
"""TrainingJobResult domain object. | ||
This domain object stores the training job result along with job_id and | ||
algorithm_id. The training job result is the trained classifier data. | ||
""" | ||
|
||
def __init__(self, job_id, algorithm_id, classifier_data): | ||
"""Initializes TrainingJobResult object. | ||
Args: | ||
job_id: str. The id of the training job whose results are stored | ||
in classifier_data. | ||
algorithm_id: str. The id of the algorithm of the training job. | ||
classifier_data: object. Frozen model of the corresponding | ||
training job. | ||
""" | ||
self.job_id = job_id | ||
self.algorithm_id = algorithm_id | ||
self.classifier_data = classifier_data | ||
|
||
def validate(self): | ||
"""Validate that TrainingJobResult object stores correct data. | ||
Raises: | ||
Exception. The type name of classifier_data does not match the | ||
type expected by the classifier for algorithm_id. | ||
""" | ||
|
||
# Ensure that the classifier_data corresponds to the classifier | ||
# having the given algorithm_id. | ||
classifier = algorithm_registry.Registry.get_classifier_by_algorithm_id( | ||
self.algorithm_id) | ||
if ( | ||
type(self.classifier_data).__name__ != | ||
classifier.type_in_job_result_proto): | ||
raise Exception( | ||
"Expected classifier data of type %s but found %s type" % ( | ||
classifier.type_in_job_result_proto, | ||
type(self.classifier_data).__name__)) | ||
|
||
def to_proto(self): | ||
"""Generate JobResult protobuf object from the TrainingJobResult | ||
domain object. The object is validated before conversion. | ||
Returns: | ||
TrainingJobResponsePayload.JobResult. Protobuf object holding the | ||
job_id and a copy of classifier_data. | ||
""" | ||
self.validate() | ||
proto_message = ( | ||
training_job_response_payload_pb2. | ||
TrainingJobResponsePayload.JobResult()) | ||
proto_message.job_id = self.job_id | ||
job_result_attribute = ( | ||
algorithm_registry.Registry.get_classifier_by_algorithm_id( | ||
self.algorithm_id).name_in_job_result_proto) | ||
getattr(proto_message, job_result_attribute).CopyFrom( | ||
self.classifier_data) | ||
return proto_message |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# coding: utf-8 | ||
# | ||
# Copyright 2020 The Oppia Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS-IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Tests for training job result domain""" | ||
|
||
from core.classifiers import algorithm_registry | ||
from core.domain import training_job_result_domain | ||
from core.domain.proto import training_job_response_payload_pb2 | ||
from core.domain.proto import text_classifier_pb2 | ||
from core.tests import test_utils | ||
|
||
class TrainingJobResultTests(test_utils.GenericTestBase): | ||
"""Tests for TrainingJobResult domain object.""" | ||
|
||
def test_validate_job_data_with_valid_model_does_not_raise_exception(self): # pylint: disable=no-self-use | ||
"""Ensure that validation checks do not raise exceptions when | ||
a valid classifier model is supplied. | ||
""" | ||
job_id = 'job_id' | ||
algorithm_id = 'TextClassifier' | ||
classifier_data = text_classifier_pb2.TextClassifierFrozenModel() | ||
classifier_data.model_json = 'dummy model' | ||
job_result = training_job_result_domain.TrainingJobResult( | ||
job_id, algorithm_id, classifier_data) | ||
job_result.validate() | ||
|
||
def test_validate_job_data_with_invalid_model_raises_exception(self): | ||
"""Ensure that validation checks raise an exception when | ||
an invalid classifier model is supplied. | ||
""" | ||
job_id = 'job_id' | ||
algorithm_id = 'TextClassifier' | ||
classifier_data = 'simple classifier' | ||
job_result = training_job_result_domain.TrainingJobResult( | ||
job_id, algorithm_id, classifier_data) | ||
with self.assertRaisesRegexp( | ||
Exception, | ||
'Expected classifier data of type TextClassifier'): | ||
job_result.validate() | ||
|
||
def test_that_all_algorithms_have_job_result_information(self): | ||
"""Test that all algorithms have properties to identify name and type | ||
of attribute in job result proto which stores classifier data for that | ||
algorithm. | ||
""" | ||
job_result_proto = ( | ||
training_job_response_payload_pb2. | ||
TrainingJobResponsePayload.JobResult()) | ||
for classifier in algorithm_registry.Registry.get_all_classifiers(): | ||
self.assertIsNotNone(classifier.name_in_job_result_proto) | ||
attribute_type_name = type(getattr( | ||
job_result_proto, classifier.name_in_job_result_proto)).__name__ | ||
self.assertEqual( | ||
attribute_type_name, classifier.type_in_job_result_proto) | ||
|
||
def test_that_training_job_result_proto_is_generated_with_correct_details( | ||
self): | ||
"""Ensure that the JobResult proto is correctly generated from | ||
TrainingJobResult domain object. | ||
""" | ||
classifier_data = text_classifier_pb2.TextClassifierFrozenModel() | ||
classifier_data.model_json = 'dummy model' | ||
job_id = 'job_id' | ||
algorithm_id = 'TextClassifier' | ||
classifier = algorithm_registry.Registry.get_classifier_by_algorithm_id( | ||
algorithm_id) | ||
job_result = training_job_result_domain.TrainingJobResult( | ||
job_id, algorithm_id, classifier_data) | ||
job_result_proto = job_result.to_proto() | ||
|
||
# The no-member lint check needs to be disabled as protobuf generated | ||
# classes are built by a metaclass and hence their attributes are defined | ||
# at runtime. | ||
self.assertEqual(job_result_proto.job_id, job_id) # pylint: disable=no-member | ||
self.assertEqual( | ||
job_result_proto.WhichOneof('classifier_frozen_model'), | ||
classifier.name_in_job_result_proto) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,3 +24,4 @@ scikit-learn 0.18.1 | |
pylint 1.7.1 | ||
requests 2.17.1 | ||
responses 0.5.1 | ||
protobuf 3.12.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
protoc: | ||
version: 3.8.0 | ||
lint: | ||
group: google | ||
generate: | ||
plugins: | ||
- name: python | ||
output: ./ |
Oops, something went wrong.