Skip to content

Commit

Permalink
Merge pull request #228 from thodkatz/add-save-export-to-training-servicer
Browse files Browse the repository at this point in the history

Add save and export to training servicer
  • Loading branch information
thodkatz authored Jan 20, 2025
2 parents 90b4ce0 + 6e4ec16 commit 5f5fbd1
Show file tree
Hide file tree
Showing 10 changed files with 433 additions and 62 deletions.
15 changes: 14 additions & 1 deletion proto/training.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ service Training {

rpc GetLogs(ModelSession) returns (GetLogsResponse) {}

rpc Export(ModelSession) returns (Empty) {}
rpc Save(SaveRequest) returns (Empty) {}

rpc Export(ExportRequest) returns (Empty) {}

rpc Predict(PredictRequest) returns (PredictResponse) {}

Expand Down Expand Up @@ -59,6 +61,17 @@ message GetLogsResponse {
}


// Request to persist the current training checkpoint of a session.
message SaveRequest {
  ModelSession modelSessionId = 1;  // session whose trainer state is saved
  string filePath = 2;              // destination path (presumably on the server host — confirm)
}


// Request to export a session's trained model to an archive file.
message ExportRequest {
  ModelSession modelSessionId = 1;  // session whose model is exported
  string filePath = 2;              // destination path (presumably on the server host — confirm)
}

// Aggregated validation metric reported for a model session.
message ValidationResponse {
  double validation_score_average = 1;  // mean validation score over the evaluated samples
}
Expand Down
124 changes: 114 additions & 10 deletions tests/test_server/test_grpc/test_training_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import threading
import time
from pathlib import Path
from typing import Callable
from typing import Callable, Optional

import grpc
import h5py
Expand Down Expand Up @@ -41,8 +41,11 @@ def grpc_stub_cls():
return training_pb2_grpc.TrainingStub


def unet2d_config_path(checkpoint_dir, train_data_dir, val_data_path, device: str = "cpu"):
return f"""
def unet2d_config_path(
checkpoint_dir: Path, train_data_dir: str, val_data_path: str, resume: Optional[str] = None, device: str = "cpu"
):
# todo: upsampling makes model torchscript incompatible
config = f"""
device: {device} # Use CPU for faster test execution, change to 'cuda' if GPU is available and necessary
model:
name: UNet2D
Expand All @@ -53,13 +56,14 @@ def unet2d_config_path(checkpoint_dir, train_data_dir, val_data_path, device: st
num_groups: 4
final_sigmoid: false
is_segmentation: true
upsample: default
trainer:
checkpoint_dir: {checkpoint_dir}
resume: null
validate_after_iters: 2
resume: {resume if resume else "null"}
validate_after_iters: 250
log_after_iters: 2
max_num_epochs: 1000
max_num_iterations: 10000
max_num_epochs: 10000
max_num_iterations: 100000
eval_score_higher_is_better: True
optimizer:
learning_rate: 0.0002
Expand Down Expand Up @@ -149,6 +153,7 @@ def unet2d_config_path(checkpoint_dir, train_data_dir, val_data_path, device: st
- name: ToTensor
expand_dims: false
"""
return config


def create_random_dataset(shape, channel_per_class):
Expand All @@ -171,15 +176,22 @@ def create_random_dataset(shape, channel_per_class):
return tmp.name


def prepare_unet2d_test_environment(resume: Optional[str] = None, device: str = "cpu") -> str:
    """Build a UNet2D training YAML config backed by random test data.

    Creates a temporary checkpoint directory and two freshly generated random
    HDF5 datasets (train and validation), then renders the YAML config string.

    Args:
        resume: optional path to a checkpoint file to resume training from.
        device: torch device string for the trainer (defaults to "cpu").

    Returns:
        The rendered YAML configuration as a string.
    """
    checkpoint_dir = Path(tempfile.mkdtemp())

    # 2d data laid out as (channels, z, y, x) with a singleton z axis.
    num_channels, depth, height, width = 3, 1, 128, 128
    dataset_shape = (num_channels, depth, height, width)
    binary_loss = False
    train_path = create_random_dataset(dataset_shape, binary_loss)
    val_path = create_random_dataset(dataset_shape, binary_loss)

    return unet2d_config_path(
        resume=resume,
        checkpoint_dir=checkpoint_dir,
        train_data_dir=train_path,
        val_data_path=val_path,
        device=device,
    )


class TestTrainingServicer:
Expand Down Expand Up @@ -561,6 +573,98 @@ def test_forward_invalid_dims(self, grpc_stub):
grpc_stub.Predict(predict_request)
assert "Tensor dims should be" in excinfo.value.details()

def test_save_while_running(self, grpc_stub):
    """Saving a checkpoint during active training must not pause the session,
    and the saved file must be usable to resume a fresh session.

    Fixes: pass ``str(model_checkpoint_file)`` as ``resume`` — the helper
    declares ``resume: Optional[str]`` — and keep the resume-Init inside the
    TemporaryDirectory context so the checkpoint file still exists.
    """
    training_session_id = grpc_stub.Init(
        training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())
    )

    grpc_stub.Start(training_session_id)

    with tempfile.TemporaryDirectory() as model_checkpoint_dir:
        model_checkpoint_file = Path(model_checkpoint_dir) / "model.pth"
        save_request = training_pb2.SaveRequest(
            modelSessionId=training_session_id, filePath=str(model_checkpoint_file)
        )
        grpc_stub.Save(save_request)
        assert model_checkpoint_file.exists()
        # Save must be transparent to training: the trainer keeps RUNNING.
        self.assert_state(grpc_stub, training_session_id, TrainerState.RUNNING)

        # Stop the first session to release devices before re-initializing.
        grpc_stub.CloseTrainerSession(training_session_id)

        # Re-init a new model from the saved checkpoint while the temp file
        # still exists, and verify training starts.
        training_session_id = grpc_stub.Init(
            training_pb2.TrainingConfig(
                yaml_content=prepare_unet2d_test_environment(resume=str(model_checkpoint_file))
            )
        )
        grpc_stub.Start(training_session_id)

def test_save_while_paused(self, grpc_stub):
    """Saving a checkpoint while paused must leave the session paused, and the
    saved file must be usable to resume a fresh session.

    Fixes: pass ``str(model_checkpoint_file)`` as ``resume`` — the helper
    declares ``resume: Optional[str]`` — and keep the resume-Init inside the
    TemporaryDirectory context so the checkpoint file still exists.
    """
    training_session_id = grpc_stub.Init(
        training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())
    )

    grpc_stub.Start(training_session_id)
    # Let training make some progress before pausing.
    time.sleep(1)
    grpc_stub.Pause(training_session_id)

    with tempfile.TemporaryDirectory() as model_checkpoint_dir:
        model_checkpoint_file = Path(model_checkpoint_dir) / "model.pth"
        save_request = training_pb2.SaveRequest(
            modelSessionId=training_session_id, filePath=str(model_checkpoint_file)
        )
        grpc_stub.Save(save_request)
        assert model_checkpoint_file.exists()
        # Save must not change the session state: still PAUSED.
        self.assert_state(grpc_stub, training_session_id, TrainerState.PAUSED)

        # Stop the first session to release devices before re-initializing.
        grpc_stub.CloseTrainerSession(training_session_id)

        # Re-init a new model from the saved checkpoint while the temp file
        # still exists, and verify training starts.
        training_session_id = grpc_stub.Init(
            training_pb2.TrainingConfig(
                yaml_content=prepare_unet2d_test_environment(resume=str(model_checkpoint_file))
            )
        )
        grpc_stub.Start(training_session_id)

def test_export_while_running(self, grpc_stub):
    """Exporting during active training writes the bioimageio archive and
    leaves the session paused."""
    session = grpc_stub.Init(
        training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())
    )
    grpc_stub.Start(session)

    with tempfile.TemporaryDirectory() as export_dir:
        export_path = Path(export_dir) / "bioimageio.zip"
        grpc_stub.Export(
            training_pb2.ExportRequest(modelSessionId=session, filePath=str(export_path))
        )
        assert export_path.exists()
        # NOTE(review): export appears to pause a running trainer — the test
        # asserts PAUSED even though training was RUNNING; confirm intended.
        self.assert_state(grpc_stub, session, TrainerState.PAUSED)

        # Model is exported; stop training to release devices.
        grpc_stub.CloseTrainerSession(session)

def test_export_while_paused(self, grpc_stub):
    """Exporting a paused session writes the bioimageio archive and keeps the
    session paused."""
    session = grpc_stub.Init(
        training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())
    )
    grpc_stub.Start(session)
    # Let training make some progress, then pause it.
    time.sleep(1)
    grpc_stub.Pause(session)

    with tempfile.TemporaryDirectory() as export_dir:
        export_path = Path(export_dir) / "bioimageio.zip"
        grpc_stub.Export(
            training_pb2.ExportRequest(modelSessionId=session, filePath=str(export_path))
        )
        assert export_path.exists()
        # Export must not resume training: session stays PAUSED.
        self.assert_state(grpc_stub, session, TrainerState.PAUSED)

        # Model is exported; stop training to release devices.
        grpc_stub.CloseTrainerSession(session)

def test_close_session(self, grpc_stub):
"""
Test closing a training session.
Expand Down
30 changes: 17 additions & 13 deletions tiktorch/proto/training_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

39 changes: 36 additions & 3 deletions tiktorch/proto/training_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,14 @@ def __init__(self, channel):
request_serializer=utils__pb2.ModelSession.SerializeToString,
response_deserializer=training__pb2.GetLogsResponse.FromString,
)
self.Save = channel.unary_unary(
'/training.Training/Save',
request_serializer=training__pb2.SaveRequest.SerializeToString,
response_deserializer=utils__pb2.Empty.FromString,
)
self.Export = channel.unary_unary(
'/training.Training/Export',
request_serializer=utils__pb2.ModelSession.SerializeToString,
request_serializer=training__pb2.ExportRequest.SerializeToString,
response_deserializer=utils__pb2.Empty.FromString,
)
self.Predict = channel.unary_unary(
Expand Down Expand Up @@ -117,6 +122,12 @@ def GetLogs(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def Save(self, request, context):
    """Missing associated documentation comment in .proto file."""
    # Generated default servicer stub: concrete servicers override this to
    # persist a checkpoint. Unimplemented here, so callers see UNIMPLEMENTED.
    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
    context.set_details('Method not implemented!')
    raise NotImplementedError('Method not implemented!')

def Export(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
Expand Down Expand Up @@ -179,9 +190,14 @@ def add_TrainingServicer_to_server(servicer, server):
request_deserializer=utils__pb2.ModelSession.FromString,
response_serializer=training__pb2.GetLogsResponse.SerializeToString,
),
'Save': grpc.unary_unary_rpc_method_handler(
servicer.Save,
request_deserializer=training__pb2.SaveRequest.FromString,
response_serializer=utils__pb2.Empty.SerializeToString,
),
'Export': grpc.unary_unary_rpc_method_handler(
servicer.Export,
request_deserializer=utils__pb2.ModelSession.FromString,
request_deserializer=training__pb2.ExportRequest.FromString,
response_serializer=utils__pb2.Empty.SerializeToString,
),
'Predict': grpc.unary_unary_rpc_method_handler(
Expand Down Expand Up @@ -328,6 +344,23 @@ def GetLogs(request,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def Save(request,
        target,
        options=(),
        channel_credentials=None,
        call_credentials=None,
        insecure=False,
        compression=None,
        wait_for_ready=None,
        timeout=None,
        metadata=None):
    # Generated experimental convenience API: one-shot unary-unary call to
    # /training.Training/Save without constructing a TrainingStub first.
    return grpc.experimental.unary_unary(request, target, '/training.Training/Save',
        training__pb2.SaveRequest.SerializeToString,
        utils__pb2.Empty.FromString,
        options, channel_credentials,
        insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def Export(request,
target,
Expand All @@ -340,7 +373,7 @@ def Export(request,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/training.Training/Export',
utils__pb2.ModelSession.SerializeToString,
training__pb2.ExportRequest.SerializeToString,
utils__pb2.Empty.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
Expand Down
13 changes: 9 additions & 4 deletions tiktorch/server/session/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
from abc import ABC
from concurrent.futures import Future
from pathlib import Path

from bioimageio.core import PredictionPipeline, Sample

Expand Down Expand Up @@ -81,11 +82,15 @@ def start_training(self) -> None:
self._queue_tasks.send_command(start_cmd.awaitable)
start_cmd.awaitable.wait()

def save(self, file_path: Path) -> None:
    """Persist the current training checkpoint to *file_path*.

    Enqueues a save command for the trainer supervisor and blocks until the
    command has completed.
    """
    cmd = commands.SaveTrainingCmd(file_path)
    self._queue_tasks.send_command(cmd.awaitable)
    cmd.awaitable.wait()

def export(self, file_path: Path) -> None:
    """Export the trained model to *file_path*.

    Enqueues an export command for the trainer supervisor and blocks until
    the command has completed.
    """
    cmd = commands.ExportTrainingCmd(file_path)
    self._queue_tasks.send_command(cmd.awaitable)
    cmd.awaitable.wait()

def get_state(self) -> TrainerState:
return self._supervisor.get_state()
19 changes: 19 additions & 0 deletions tiktorch/server/session/backend/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import threading
import typing
from dataclasses import dataclass, field
from pathlib import Path
from typing import Generic, Type, TypeVar

from tiktorch.trainer import TrainerAction, TrainerState
Expand Down Expand Up @@ -131,6 +132,24 @@ def execute(self, ctx: Context) -> None:
pass


class ExportTrainingCmd(ICommand):
    """Command that exports the trained model to a given file path."""

    def __init__(self, file_path: Path):
        super().__init__()
        # Destination for the exported model archive.
        self._target_path = file_path

    def execute(self, ctx: Context[TrainerSupervisor]) -> None:
        # Delegate to the supervisor session, which performs the export.
        ctx.session.export(self._target_path)


class SaveTrainingCmd(ICommand):
    """Command that saves the current training checkpoint to a given file path."""

    def __init__(self, file_path: Path):
        super().__init__()
        # Destination for the saved checkpoint.
        self._target_path = file_path

    def execute(self, ctx: Context[TrainerSupervisor]) -> None:
        # Delegate to the supervisor session, which performs the save.
        ctx.session.save(self._target_path)


class ShutdownWithTeardownCmd(ShutdownCmd):
def execute(self, ctx: Context[Supervisors]) -> None:
ctx.session.shutdown()
Expand Down
Loading

0 comments on commit 5f5fbd1

Please sign in to comment.