Skip to content

Commit

Permalink
Merge pull request #225 from thodkatz/add-training-servicer
Browse files Browse the repository at this point in the history
Add training service
  • Loading branch information
thodkatz authored Jan 16, 2025
2 parents 5ea5d3a + c699344 commit 57df547
Show file tree
Hide file tree
Showing 33 changed files with 2,169 additions and 656 deletions.
4 changes: 2 additions & 2 deletions examples/grpc_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import grpc

from tiktorch.proto import inference_pb2, inference_pb2_grpc
from tiktorch.proto import inference_pb2_grpc, utils_pb2


def run():
with grpc.insecure_channel("127.0.0.1:5567") as channel:
stub = inference_pb2_grpc.InferenceStub(channel)
response = stub.ListDevices(inference_pb2.Empty())
response = stub.ListDevices(utils_pb2.Empty())
print(response)


Expand Down
35 changes: 5 additions & 30 deletions proto/inference.proto
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
syntax = "proto3";

package inference;

import "utils.proto";


service Inference {
rpc CreateModelSession(CreateModelSessionRequest) returns (ModelSession) {}

Expand All @@ -14,15 +19,6 @@ service Inference {
rpc Predict(PredictRequest) returns (PredictResponse) {}
}

message Device {
enum Status {
AVAILABLE = 0;
IN_USE = 1;
}

string id = 1;
Status status = 2;
}

message CreateDatasetDescriptionRequest {
string modelSessionId = 1;
Expand Down Expand Up @@ -76,26 +72,6 @@ message LogEntry {
string content = 3;
}

message Devices {
repeated Device devices = 1;
}

message NamedInt {
uint32 size = 1;
string name = 2;
}

message NamedFloat {
float size = 1;
string name = 2;
}

message Tensor {
bytes buffer = 1;
string dtype = 2;
string tensorId = 3;
repeated NamedInt shape = 4;
}

message PredictRequest {
string modelSessionId = 1;
Expand All @@ -107,7 +83,6 @@ message PredictResponse {
repeated Tensor tensors = 1;
}

message Empty {}

service FlightControl {
rpc Ping(Empty) returns (Empty) {}
Expand Down
95 changes: 95 additions & 0 deletions proto/training.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
syntax = "proto3";

package training;

import "utils.proto";



// Training service: lifecycle control of trainer sessions (init/start/
// pause/resume/save/export/close) plus prediction, logs, status and a
// server-side stream of training updates.
service Training {
// List the server's devices together with their availability status.
rpc ListDevices(Empty) returns (Devices) {}

// Create a new trainer session from a yaml configuration string and
// return its session id.
rpc Init(TrainingConfig) returns (TrainingSessionId) {}

// Start training for the given session.
rpc Start(TrainingSessionId) returns (Empty) {}

// Resume a paused session.
rpc Resume(TrainingSessionId) returns (Empty) {}

// Pause a running session.
rpc Pause(TrainingSessionId) returns (Empty) {}

// Server-side stream of updates (best model index + latest log record)
// for the given session.
rpc StreamUpdates(TrainingSessionId) returns (stream StreamUpdateResponse) {}

// Fetch the accumulated train/eval log records for the session.
rpc GetLogs(TrainingSessionId) returns (GetLogsResponse) {}

// NOTE(review): save target (checkpoint path/format) is not visible in
// this file — presumably persists the session's current model state.
rpc Save(TrainingSessionId) returns (Empty) {}

// NOTE(review): export destination/format is not visible in this file —
// presumably exports the trained model for external use.
rpc Export(TrainingSessionId) returns (Empty) {}

// Run inference with the session's model on the request tensors.
rpc Predict(PredictRequest) returns (PredictResponse) {}

// Report the session state (Idle/Running/Paused/Failed/Finished).
rpc GetStatus(TrainingSessionId) returns (GetStatusResponse) {}

// Close the session; subsequent rpcs with this id are expected to fail.
rpc CloseTrainerSession(TrainingSessionId) returns (Empty) {}
}

// Opaque identifier of a trainer session; returned by Init and passed to
// every per-session rpc of the Training service.
message TrainingSessionId {
string id = 1;
}

// One log record produced during training, tagged with the phase it came
// from and the metrics at that iteration.
message Logs {
// Phase that produced this record.
enum ModelPhase {
Train = 0;
Eval = 1;
}
// Whether the metrics below come from a train or an eval step.
ModelPhase mode = 1;
// Evaluation score at this iteration.
double eval_score = 2;
// Training loss at this iteration.
double loss = 3;
// Iteration counter the record belongs to.
uint32 iteration = 4;
}


// Single element of the Training.StreamUpdates server stream: the index of
// the best model so far plus the latest log record.
message StreamUpdateResponse {
uint32 best_model_idx = 1;
Logs logs = 2;
}


// Response of Training.GetLogs: all accumulated log records for a session.
message GetLogsResponse {
repeated Logs logs = 1;
}



// Request for Training.Predict: input tensors plus the session whose model
// should run the prediction.
message PredictRequest {
repeated Tensor tensors = 1;
TrainingSessionId sessionId = 2;
}


// Response of Training.Predict: the output tensors of the model run.
message PredictResponse {
repeated Tensor tensors = 1;
}

// Average validation score.
// NOTE(review): no rpc in this file references this message — confirm it is
// used elsewhere or remove it.
message ValidationResponse {
double validation_score_average = 1;
}

// Response of Training.GetStatus: the current lifecycle state of a session.
message GetStatusResponse {
enum State {
Idle = 0;
Running = 1;
Paused = 2;
Failed = 3;
Finished = 4;
}
State state = 1;
}


// Index of the current best model.
// NOTE(review): no rpc in this file references this message (StreamUpdates
// carries best_model_idx inline) — confirm it is used elsewhere or remove it.
message GetCurrentBestModelIdxResponse {
uint32 id = 1;
}

// Trainer configuration passed to Training.Init, carried as a raw yaml
// document; the schema of the yaml is defined by the server implementation.
message TrainingConfig {
string yaml_content = 1;
}
34 changes: 34 additions & 0 deletions proto/utils.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
syntax = "proto3";

// Empty request/response placeholder shared by rpcs that carry no payload.
message Empty {}

// An unsigned integer with an associated name; used by Tensor.shape to
// carry a named dimension size.
message NamedInt {
uint32 size = 1;
string name = 2;
}

// A float with an associated name; counterpart of NamedInt.
// NOTE(review): not referenced by any message in this file — confirm usage.
message NamedFloat {
float size = 1;
string name = 2;
}

// Serialized n-dimensional array: raw bytes plus the dtype string and the
// named per-dimension sizes needed to reconstruct it.
message Tensor {
// Raw tensor data; layout is implied by dtype and shape.
bytes buffer = 1;
// Element dtype as a string (e.g. "int64", "float32").
string dtype = 2;
// Identifier of the tensor (e.g. a model input/output name).
string tensorId = 3;
// Ordered dimensions, each with a name and a size.
repeated NamedInt shape = 4;
}

// A compute device known to the server, with its availability status.
message Device {
enum Status {
AVAILABLE = 0;
IN_USE = 1;
}

// Device identifier string.
string id = 1;
Status status = 2;
}

// Response of the ListDevices rpcs: the server's device list.
message Devices {
repeated Device devices = 1;
}
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[pytest]
python_files = test_*.py
addopts =
--timeout 10
--timeout 60
-v
-s
--color=yes
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ max-line-length = 120

[flake8]
max-line-length = 120
ignore=E203
ignore=E203,W503
exclude = tiktorch/proto/*,vendor
34 changes: 17 additions & 17 deletions tests/test_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
xarray_to_pb_tensor,
xr_tensors_to_sample,
)
from tiktorch.proto import inference_pb2
from tiktorch.proto import utils_pb2


def _numpy_to_pb_tensor(arr, tensor_id: str = "dummy_tensor_name"):
"""
Makes sure that tensor was serialized/deserialized
"""
tensor = numpy_to_pb_tensor(tensor_id, arr)
parsed = inference_pb2.Tensor()
parsed = utils_pb2.Tensor()
parsed.ParseFromString(tensor.SerializeToString())
return parsed

Expand All @@ -31,7 +31,7 @@ def to_pb_tensor(tensor_id: str, arr: xr.DataArray):
Makes sure that tensor was serialized/deserialized
"""
tensor = xarray_to_pb_tensor(tensor_id, arr)
parsed = inference_pb2.Tensor()
parsed = utils_pb2.Tensor()
parsed.ParseFromString(tensor.SerializeToString())
return parsed

Expand All @@ -40,7 +40,7 @@ class TestNumpyToPBTensor:
def test_should_serialize_to_tensor_type(self):
arr = np.arange(9)
tensor = _numpy_to_pb_tensor(arr)
assert isinstance(tensor, inference_pb2.Tensor)
assert isinstance(tensor, utils_pb2.Tensor)

@pytest.mark.parametrize("np_dtype,dtype_str", [(np.int64, "int64"), (np.uint8, "uint8"), (np.float32, "float32")])
def test_should_have_dtype_as_str(self, np_dtype, dtype_str):
Expand All @@ -65,12 +65,12 @@ def test_should_have_serialized_bytes(self):

class TestPBTensorToNumpy:
def test_should_raise_on_empty_dtype(self):
tensor = inference_pb2.Tensor(dtype="", shape=[inference_pb2.NamedInt(size=1), inference_pb2.NamedInt(size=2)])
tensor = utils_pb2.Tensor(dtype="", shape=[utils_pb2.NamedInt(size=1), utils_pb2.NamedInt(size=2)])
with pytest.raises(ValueError):
pb_tensor_to_numpy(tensor)

def test_should_raise_on_empty_shape(self):
tensor = inference_pb2.Tensor(dtype="int64", shape=[])
tensor = utils_pb2.Tensor(dtype="int64", shape=[])
with pytest.raises(ValueError):
pb_tensor_to_numpy(tensor)

Expand Down Expand Up @@ -109,7 +109,7 @@ class TestXarrayToPBTensor:
def test_should_serialize_to_tensor_type(self):
xarr = xr.DataArray(np.arange(8).reshape((2, 4)), dims=("x", "y"))
pb_tensor = to_pb_tensor("input0", xarr)
assert isinstance(pb_tensor, inference_pb2.Tensor)
assert isinstance(pb_tensor, utils_pb2.Tensor)
assert len(pb_tensor.shape) == 2
dim1 = pb_tensor.shape[0]
dim2 = pb_tensor.shape[1]
Expand Down Expand Up @@ -137,12 +137,12 @@ def test_should_have_serialized_bytes(self):

class TestPBTensorToXarray:
def test_should_raise_on_empty_dtype(self):
tensor = inference_pb2.Tensor(dtype="", shape=[inference_pb2.NamedInt(size=1), inference_pb2.NamedInt(size=2)])
tensor = utils_pb2.Tensor(dtype="", shape=[utils_pb2.NamedInt(size=1), utils_pb2.NamedInt(size=2)])
with pytest.raises(ValueError):
pb_tensor_to_xarray(tensor)

def test_should_raise_on_empty_shape(self):
tensor = inference_pb2.Tensor(dtype="int64", shape=[])
tensor = utils_pb2.Tensor(dtype="int64", shape=[])
with pytest.raises(ValueError):
pb_tensor_to_xarray(tensor)

Expand Down Expand Up @@ -178,19 +178,19 @@ def test_should_same_data(self, shape):
class TestSample:
def test_pb_tensors_to_sample(self):
arr_1 = np.arange(32 * 32, dtype=np.int64).reshape(32, 32)
tensor_1 = inference_pb2.Tensor(
tensor_1 = utils_pb2.Tensor(
dtype="int64",
tensorId="input1",
buffer=bytes(arr_1),
shape=[inference_pb2.NamedInt(name="x", size=32), inference_pb2.NamedInt(name="y", size=32)],
shape=[utils_pb2.NamedInt(name="x", size=32), utils_pb2.NamedInt(name="y", size=32)],
)

arr_2 = np.arange(64 * 64, dtype=np.int64).reshape(64, 64)
tensor_2 = inference_pb2.Tensor(
tensor_2 = utils_pb2.Tensor(
dtype="int64",
tensorId="input2",
buffer=bytes(arr_2),
shape=[inference_pb2.NamedInt(name="x", size=64), inference_pb2.NamedInt(name="y", size=64)],
shape=[utils_pb2.NamedInt(name="x", size=64), utils_pb2.NamedInt(name="y", size=64)],
)

sample = pb_tensors_to_sample([tensor_1, tensor_2])
Expand Down Expand Up @@ -218,17 +218,17 @@ def test_sample_to_pb_tensors(self):
tensors_ids = ["input1", "input2"]
sample = xr_tensors_to_sample(tensors_ids, [tensor_1, tensor_2])

pb_tensor_1 = inference_pb2.Tensor(
pb_tensor_1 = utils_pb2.Tensor(
dtype="int64",
tensorId="input1",
buffer=bytes(arr_1),
shape=[inference_pb2.NamedInt(name="x", size=32), inference_pb2.NamedInt(name="y", size=32)],
shape=[utils_pb2.NamedInt(name="x", size=32), utils_pb2.NamedInt(name="y", size=32)],
)
pb_tensor_2 = inference_pb2.Tensor(
pb_tensor_2 = utils_pb2.Tensor(
dtype="int64",
tensorId="input2",
buffer=bytes(arr_2),
shape=[inference_pb2.NamedInt(name="x", size=64), inference_pb2.NamedInt(name="y", size=64)],
shape=[utils_pb2.NamedInt(name="x", size=64), utils_pb2.NamedInt(name="y", size=64)],
)
expected_tensors = [pb_tensor_1, pb_tensor_2]

Expand Down
4 changes: 2 additions & 2 deletions tests/test_server/test_grpc/test_fligh_control_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def _pinger():
pinger_thread.start()

assert not evt.is_set()
assert not evt.wait(timeout=0.2)
assert not evt.wait(timeout=1)

stop_pinger.set()
assert evt.wait(timeout=0.2)
assert evt.wait(timeout=1)


def test_shutdown_timeout_0_means_no_watchdog():
Expand Down
Loading

0 comments on commit 57df547

Please sign in to comment.