diff --git a/proto/training.proto b/proto/training.proto index 1c85b7e3..e0151152 100644 --- a/proto/training.proto +++ b/proto/training.proto @@ -21,7 +21,9 @@ service Training { rpc GetLogs(ModelSession) returns (GetLogsResponse) {} - rpc Export(ModelSession) returns (Empty) {} + rpc Save(SaveRequest) returns (Empty) {} + + rpc Export(ExportRequest) returns (Empty) {} rpc Predict(PredictRequest) returns (PredictResponse) {} @@ -59,6 +61,17 @@ message GetLogsResponse { } +message SaveRequest { + ModelSession modelSessionId = 1; + string filePath = 2; +} + + +message ExportRequest { + ModelSession modelSessionId = 1; + string filePath = 2; +} + message ValidationResponse { double validation_score_average = 1; } diff --git a/tests/test_server/test_grpc/test_training_servicer.py b/tests/test_server/test_grpc/test_training_servicer.py index 067b4fe5..a5df763b 100644 --- a/tests/test_server/test_grpc/test_training_servicer.py +++ b/tests/test_server/test_grpc/test_training_servicer.py @@ -2,7 +2,7 @@ import threading import time from pathlib import Path -from typing import Callable +from typing import Callable, Optional import grpc import h5py @@ -41,8 +41,11 @@ def grpc_stub_cls(): return training_pb2_grpc.TrainingStub -def unet2d_config_path(checkpoint_dir, train_data_dir, val_data_path, device: str = "cpu"): - return f""" +def unet2d_config_path( + checkpoint_dir: Path, train_data_dir: str, val_data_path: str, resume: Optional[str] = None, device: str = "cpu" +): + # todo: upsampling makes model torchscript incompatible + config = f""" device: {device} # Use CPU for faster test execution, change to 'cuda' if GPU is available and necessary model: name: UNet2D @@ -53,13 +56,14 @@ def unet2d_config_path(checkpoint_dir, train_data_dir, val_data_path, device: st num_groups: 4 final_sigmoid: false is_segmentation: true + upsample: default trainer: checkpoint_dir: {checkpoint_dir} - resume: null - validate_after_iters: 2 + resume: {resume if resume else "null"} + validate_after_iters: 250 log_after_iters: 2 - max_num_epochs: 1000 - max_num_iterations: 10000 + max_num_epochs: 10000 + max_num_iterations: 100000 eval_score_higher_is_better: True optimizer: learning_rate: 0.0002 @@ -149,6 +153,7 @@ def unet2d_config_path(checkpoint_dir, train_data_dir, val_data_path, device: st - name: ToTensor expand_dims: false """ + return config def create_random_dataset(shape, channel_per_class): @@ -171,15 +176,22 @@ def create_random_dataset(shape, channel_per_class): return tmp.name -def prepare_unet2d_test_environment(device: str = "cpu") -> str: +def prepare_unet2d_test_environment(resume: Optional[str] = None, device: str = "cpu") -> str: checkpoint_dir = Path(tempfile.mkdtemp()) - shape = (3, 1, 128, 128) + in_channel = 3 + z = 1 # 2d + y = 128 + x = 128 + shape = (in_channel, z, y, x) binary_loss = False train = create_random_dataset(shape, binary_loss) val = create_random_dataset(shape, binary_loss) - return unet2d_config_path(checkpoint_dir=checkpoint_dir, train_data_dir=train, val_data_path=val, device=device) + config = unet2d_config_path( + resume=resume, checkpoint_dir=checkpoint_dir, train_data_dir=train, val_data_path=val, device=device + ) + return config class TestTrainingServicer: @@ -561,6 +573,98 @@ def test_forward_invalid_dims(self, grpc_stub): grpc_stub.Predict(predict_request) assert "Tensor dims should be" in excinfo.value.details() + def test_save_while_running(self, grpc_stub): + training_session_id = grpc_stub.Init( + training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()) + ) + + grpc_stub.Start(training_session_id) + + with tempfile.TemporaryDirectory() as model_checkpoint_dir: + model_checkpoint_file = Path(model_checkpoint_dir) / "model.pth" + save_request = training_pb2.SaveRequest( + modelSessionId=training_session_id, filePath=str(model_checkpoint_file) + ) + grpc_stub.Save(save_request) + assert model_checkpoint_file.exists() + self.assert_state(grpc_stub, training_session_id, TrainerState.RUNNING) + + # assume stopping training to release devices + grpc_stub.CloseTrainerSession(training_session_id) + + # attempt to init a new model with the new checkpoint and start training + training_session_id = grpc_stub.Init( + training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment(resume=model_checkpoint_file)) + ) + grpc_stub.Start(training_session_id) + + def test_save_while_paused(self, grpc_stub): + training_session_id = grpc_stub.Init( + training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()) + ) + + grpc_stub.Start(training_session_id) + time.sleep(1) + grpc_stub.Pause(training_session_id) + + with tempfile.TemporaryDirectory() as model_checkpoint_dir: + model_checkpoint_file = Path(model_checkpoint_dir) / "model.pth" + save_request = training_pb2.SaveRequest( + modelSessionId=training_session_id, filePath=str(model_checkpoint_file) + ) + grpc_stub.Save(save_request) + assert model_checkpoint_file.exists() + self.assert_state(grpc_stub, training_session_id, TrainerState.PAUSED) + + # assume stopping training to release devices + grpc_stub.CloseTrainerSession(training_session_id) + + # attempt to init a new model with the new checkpoint and start training + training_session_id = grpc_stub.Init( + training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment(resume=model_checkpoint_file)) + ) + grpc_stub.Start(training_session_id) + + def test_export_while_running(self, grpc_stub): + training_session_id = grpc_stub.Init( + training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()) + ) + + grpc_stub.Start(training_session_id) + + with tempfile.TemporaryDirectory() as model_checkpoint_dir: + model_export_file = Path(model_checkpoint_dir) / "bioimageio.zip" + export_request = training_pb2.ExportRequest( + modelSessionId=training_session_id, filePath=str(model_export_file) + ) + grpc_stub.Export(export_request) + assert model_export_file.exists() + self.assert_state(grpc_stub, training_session_id, TrainerState.PAUSED) + + # assume stopping training since model is exported + grpc_stub.CloseTrainerSession(training_session_id) + + def test_export_while_paused(self, grpc_stub): + training_session_id = grpc_stub.Init( + training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()) + ) + + grpc_stub.Start(training_session_id) + time.sleep(1) + grpc_stub.Pause(training_session_id) + + with tempfile.TemporaryDirectory() as model_checkpoint_dir: + model_export_file = Path(model_checkpoint_dir) / "bioimageio.zip" + export_request = training_pb2.ExportRequest( + modelSessionId=training_session_id, filePath=str(model_export_file) + ) + grpc_stub.Export(export_request) + assert model_export_file.exists() + self.assert_state(grpc_stub, training_session_id, TrainerState.PAUSED) + + # assume stopping training since model is exported + grpc_stub.CloseTrainerSession(training_session_id) + def test_close_session(self, grpc_stub): """ Test closing a training session. diff --git a/tiktorch/proto/training_pb2.py b/tiktorch/proto/training_pb2.py index ee0a9d27..18f1a52e 100644 --- a/tiktorch/proto/training_pb2.py +++ b/tiktorch/proto/training_pb2.py @@ -14,7 +14,7 @@ from . import utils_pb2 as utils__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0etraining.proto\x12\x08training\x1a\x0butils.proto\"%\n\x17GetBestModelIdxResponse\x12\n\n\x02id\x18\x01 \x01(\t\"\x87\x01\n\x04Logs\x12\'\n\x04mode\x18\x01 \x01(\x0e\x32\x19.training.Logs.ModelPhase\x12\x12\n\neval_score\x18\x02 \x01(\x01\x12\x0c\n\x04loss\x18\x03 \x01(\x01\x12\x11\n\titeration\x18\x04 \x01(\r\"!\n\nModelPhase\x12\t\n\x05Train\x10\x00\x12\x08\n\x04\x45val\x10\x01\"L\n\x14StreamUpdateResponse\x12\x16\n\x0e\x62\x65st_model_idx\x18\x01 \x01(\r\x12\x1c\n\x04logs\x18\x02 \x01(\x0b\x32\x0e.training.Logs\"/\n\x0fGetLogsResponse\x12\x1c\n\x04logs\x18\x01 \x03(\x0b\x32\x0e.training.Logs\"6\n\x12ValidationResponse\x12 \n\x18validation_score_average\x18\x01 \x01(\x01\"\x8b\x01\n\x11GetStatusResponse\x12\x30\n\x05state\x18\x01 \x01(\x0e\x32!.training.GetStatusResponse.State\"D\n\x05State\x12\x08\n\x04Idle\x10\x00\x12\x0b\n\x07Running\x10\x01\x12\n\n\x06Paused\x10\x02\x12\n\n\x06\x46\x61iled\x10\x03\x12\x0c\n\x08\x46inished\x10\x04\",\n\x1eGetCurrentBestModelIdxResponse\x12\n\n\x02id\x18\x01 \x01(\r\"&\n\x0eTrainingConfig\x12\x14\n\x0cyaml_content\x18\x01 \x01(\t2\x80\x04\n\x08Training\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12\x31\n\x04Init\x12\x18.training.TrainingConfig\x1a\r.ModelSession\"\x00\x12 \n\x05Start\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12!\n\x06Resume\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12 \n\x05Pause\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12\x42\n\rStreamUpdates\x12\r.ModelSession\x1a\x1e.training.StreamUpdateResponse\"\x00\x30\x01\x12\x35\n\x07GetLogs\x12\r.ModelSession\x1a\x19.training.GetLogsResponse\"\x00\x12!\n\x06\x45xport\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12.\n\x07Predict\x12\x0f.PredictRequest\x1a\x10.PredictResponse\"\x00\x12\x39\n\tGetStatus\x12\r.ModelSession\x1a\x1b.training.GetStatusResponse\"\x00\x12.\n\x13\x43loseTrainerSession\x12\r.ModelSession\x1a\x06.Empty\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0etraining.proto\x12\x08training\x1a\x0butils.proto\"%\n\x17GetBestModelIdxResponse\x12\n\n\x02id\x18\x01 \x01(\t\"\x87\x01\n\x04Logs\x12\'\n\x04mode\x18\x01 \x01(\x0e\x32\x19.training.Logs.ModelPhase\x12\x12\n\neval_score\x18\x02 \x01(\x01\x12\x0c\n\x04loss\x18\x03 \x01(\x01\x12\x11\n\titeration\x18\x04 \x01(\r\"!\n\nModelPhase\x12\t\n\x05Train\x10\x00\x12\x08\n\x04\x45val\x10\x01\"L\n\x14StreamUpdateResponse\x12\x16\n\x0e\x62\x65st_model_idx\x18\x01 \x01(\r\x12\x1c\n\x04logs\x18\x02 \x01(\x0b\x32\x0e.training.Logs\"/\n\x0fGetLogsResponse\x12\x1c\n\x04logs\x18\x01 \x03(\x0b\x32\x0e.training.Logs\"F\n\x0bSaveRequest\x12%\n\x0emodelSessionId\x18\x01 \x01(\x0b\x32\r.ModelSession\x12\x10\n\x08\x66ilePath\x18\x02 \x01(\t\"H\n\rExportRequest\x12%\n\x0emodelSessionId\x18\x01 \x01(\x0b\x32\r.ModelSession\x12\x10\n\x08\x66ilePath\x18\x02 \x01(\t\"6\n\x12ValidationResponse\x12 \n\x18validation_score_average\x18\x01 \x01(\x01\"\x8b\x01\n\x11GetStatusResponse\x12\x30\n\x05state\x18\x01 \x01(\x0e\x32!.training.GetStatusResponse.State\"D\n\x05State\x12\x08\n\x04Idle\x10\x00\x12\x0b\n\x07Running\x10\x01\x12\n\n\x06Paused\x10\x02\x12\n\n\x06\x46\x61iled\x10\x03\x12\x0c\n\x08\x46inished\x10\x04\",\n\x1eGetCurrentBestModelIdxResponse\x12\n\n\x02id\x18\x01 \x01(\r\"&\n\x0eTrainingConfig\x12\x14\n\x0cyaml_content\x18\x01 \x01(\t2\xb3\x04\n\x08Training\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12\x31\n\x04Init\x12\x18.training.TrainingConfig\x1a\r.ModelSession\"\x00\x12 \n\x05Start\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12!\n\x06Resume\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12 \n\x05Pause\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12\x42\n\rStreamUpdates\x12\r.ModelSession\x1a\x1e.training.StreamUpdateResponse\"\x00\x30\x01\x12\x35\n\x07GetLogs\x12\r.ModelSession\x1a\x19.training.GetLogsResponse\"\x00\x12\'\n\x04Save\x12\x15.training.SaveRequest\x1a\x06.Empty\"\x00\x12+\n\x06\x45xport\x12\x17.training.ExportRequest\x1a\x06.Empty\"\x00\x12.\n\x07Predict\x12\x0f.PredictRequest\x1a\x10.PredictResponse\"\x00\x12\x39\n\tGetStatus\x12\r.ModelSession\x1a\x1b.training.GetStatusResponse\"\x00\x12.\n\x13\x43loseTrainerSession\x12\r.ModelSession\x1a\x06.Empty\"\x00\x62\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'training_pb2', globals()) @@ -31,16 +31,20 @@ _STREAMUPDATERESPONSE._serialized_end=294 _GETLOGSRESPONSE._serialized_start=296 _GETLOGSRESPONSE._serialized_end=343 - _VALIDATIONRESPONSE._serialized_start=345 - _VALIDATIONRESPONSE._serialized_end=399 - _GETSTATUSRESPONSE._serialized_start=402 - _GETSTATUSRESPONSE._serialized_end=541 - _GETSTATUSRESPONSE_STATE._serialized_start=473 - _GETSTATUSRESPONSE_STATE._serialized_end=541 - _GETCURRENTBESTMODELIDXRESPONSE._serialized_start=543 - _GETCURRENTBESTMODELIDXRESPONSE._serialized_end=587 - _TRAININGCONFIG._serialized_start=589 - _TRAININGCONFIG._serialized_end=627 - _TRAINING._serialized_start=630 - _TRAINING._serialized_end=1142 + _SAVEREQUEST._serialized_start=345 + _SAVEREQUEST._serialized_end=415 + _EXPORTREQUEST._serialized_start=417 + _EXPORTREQUEST._serialized_end=489 + _VALIDATIONRESPONSE._serialized_start=491 + _VALIDATIONRESPONSE._serialized_end=545 + _GETSTATUSRESPONSE._serialized_start=548 + _GETSTATUSRESPONSE._serialized_end=687 + _GETSTATUSRESPONSE_STATE._serialized_start=619 + _GETSTATUSRESPONSE_STATE._serialized_end=687 + _GETCURRENTBESTMODELIDXRESPONSE._serialized_start=689 + _GETCURRENTBESTMODELIDXRESPONSE._serialized_end=733 + _TRAININGCONFIG._serialized_start=735 + _TRAININGCONFIG._serialized_end=773 + _TRAINING._serialized_start=776 + _TRAINING._serialized_end=1339 # @@protoc_insertion_point(module_scope) diff --git a/tiktorch/proto/training_pb2_grpc.py b/tiktorch/proto/training_pb2_grpc.py index 5286471d..fffecd37 100644 --- a/tiktorch/proto/training_pb2_grpc.py +++ b/tiktorch/proto/training_pb2_grpc.py @@ -50,9 +50,14 @@ def __init__(self, channel): request_serializer=utils__pb2.ModelSession.SerializeToString, response_deserializer=training__pb2.GetLogsResponse.FromString, ) + self.Save = channel.unary_unary( + '/training.Training/Save', + request_serializer=training__pb2.SaveRequest.SerializeToString, + response_deserializer=utils__pb2.Empty.FromString, + ) self.Export = channel.unary_unary( '/training.Training/Export', - request_serializer=utils__pb2.ModelSession.SerializeToString, + request_serializer=training__pb2.ExportRequest.SerializeToString, response_deserializer=utils__pb2.Empty.FromString, ) self.Predict = channel.unary_unary( @@ -117,6 +122,12 @@ def GetLogs(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def Save(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def Export(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -179,9 +190,14 @@ def add_TrainingServicer_to_server(servicer, server): request_deserializer=utils__pb2.ModelSession.FromString, response_serializer=training__pb2.GetLogsResponse.SerializeToString, ), + 'Save': grpc.unary_unary_rpc_method_handler( + servicer.Save, + request_deserializer=training__pb2.SaveRequest.FromString, + response_serializer=utils__pb2.Empty.SerializeToString, + ), 'Export': grpc.unary_unary_rpc_method_handler( servicer.Export, - request_deserializer=utils__pb2.ModelSession.FromString, + request_deserializer=training__pb2.ExportRequest.FromString, response_serializer=utils__pb2.Empty.SerializeToString, ), 'Predict': grpc.unary_unary_rpc_method_handler( @@ -328,6 +344,23 @@ def GetLogs(request, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + @staticmethod + def Save(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/Save', + training__pb2.SaveRequest.SerializeToString, + utils__pb2.Empty.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + @staticmethod def Export(request, target, @@ -340,7 +373,7 @@ def Export(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/training.Training/Export', - utils__pb2.ModelSession.SerializeToString, + training__pb2.ExportRequest.SerializeToString, utils__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/tiktorch/server/session/backend/base.py b/tiktorch/server/session/backend/base.py index c4bc65cd..7a0053a4 100644 --- a/tiktorch/server/session/backend/base.py +++ b/tiktorch/server/session/backend/base.py @@ -3,6 +3,7 @@ import logging from abc import ABC from concurrent.futures import Future +from pathlib import Path from bioimageio.core import PredictionPipeline, Sample @@ -81,11 +82,15 @@ def start_training(self) -> None: self._queue_tasks.send_command(start_cmd.awaitable) start_cmd.awaitable.wait() - def save(self) -> None: - raise NotImplementedError + def save(self, file_path: Path) -> None: + save_cmd = commands.SaveTrainingCmd(file_path) + self._queue_tasks.send_command(save_cmd.awaitable) + save_cmd.awaitable.wait() - def export(self) -> None: - raise NotImplementedError + def export(self, file_path: Path) -> None: + export_cmd = commands.ExportTrainingCmd(file_path) + self._queue_tasks.send_command(export_cmd.awaitable) + export_cmd.awaitable.wait() def get_state(self) -> TrainerState: return self._supervisor.get_state() diff --git a/tiktorch/server/session/backend/commands.py b/tiktorch/server/session/backend/commands.py index f19a45d2..3523c3bb 100644 --- a/tiktorch/server/session/backend/commands.py +++ b/tiktorch/server/session/backend/commands.py @@ -6,6 +6,7 @@ import threading import typing from dataclasses import dataclass, field +from pathlib import Path from typing import Generic, Type, TypeVar from tiktorch.trainer import TrainerAction, TrainerState @@ -131,6 +132,24 @@ def execute(self, ctx: Context) -> None: pass +class ExportTrainingCmd(ICommand): + def __init__(self, file_path: Path): + super().__init__() + self._file_path = file_path + + def execute(self, ctx: Context[TrainerSupervisor]) -> None: + ctx.session.export(self._file_path) + + +class SaveTrainingCmd(ICommand): + def __init__(self, file_path: Path): + super().__init__() + self._file_path = file_path + + def execute(self, ctx: Context[TrainerSupervisor]) -> None: + ctx.session.save(self._file_path) + + class ShutdownWithTeardownCmd(ShutdownCmd): def execute(self, ctx: Context[Supervisors]) -> None: ctx.session.shutdown() diff --git a/tiktorch/server/session/backend/supervisor.py b/tiktorch/server/session/backend/supervisor.py index 5bb07e2a..d4d4e45b 100644 --- a/tiktorch/server/session/backend/supervisor.py +++ b/tiktorch/server/session/backend/supervisor.py @@ -1,5 +1,6 @@ import logging import threading +from pathlib import Path from typing import Generic, Set, TypeVar, Union from bioimageio.core import PredictionPipeline, Sample @@ -134,11 +135,19 @@ def forward(self, input_tensors): self.resume() return res - def save(self): - raise NotImplementedError + def save(self, file_path: Path): + init_state = self.get_state() # retain the state after save + if init_state == TrainerState.RUNNING: + self.pause() + self._trainer.save_state_dict(file_path) + if init_state == TrainerState.RUNNING: + self.resume() - def export(self): - raise NotImplementedError + def export(self, file_path: Path): + init_state = self.get_state() + if init_state == TrainerState.RUNNING: + self.pause() + self._trainer.export(file_path) def _should_stop(self): return self._pause_triggered diff --git a/tiktorch/server/session/process.py b/tiktorch/server/session/process.py index 81aa6b82..dfb2c21c 100644 --- a/tiktorch/server/session/process.py +++ b/tiktorch/server/session/process.py @@ -1,10 +1,10 @@ import logging import multiprocessing as _mp -import pathlib import tempfile import uuid from concurrent.futures import Future from multiprocessing.connection import Connection +from pathlib import Path from typing import List, Optional, Tuple, Type, TypeVar, Union import torch @@ -139,11 +139,11 @@ def start_training(self): def pause_training(self): self.worker.pause_training() - def save(self): - self.worker.save() + def save(self, file_path: Path): + self.worker.save(file_path) - def export(self): - self.worker.export() + def export(self, file_path: Path): + self.worker.export(file_path) def get_state(self): return self.worker.get_state() @@ -210,7 +210,7 @@ def _get_prediction_pipeline_from_model_bytes(model_bytes: bytes, devices: List[ def _get_model_descr_from_model_bytes(model_bytes: bytes) -> v0_5.ModelDescr: with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as _tmp_file: _tmp_file.write(model_bytes) - temp_file_path = pathlib.Path(_tmp_file.name) + temp_file_path = Path(_tmp_file.name) model_descr = load_description(temp_file_path, format_version="latest") if isinstance(model_descr, InvalidDescr): raise ValueError(f"Failed to load valid model descriptor {model_descr.validation_summary}") diff --git a/tiktorch/server/session/rpc_interface.py b/tiktorch/server/session/rpc_interface.py index 36146b1e..b6dbb132 100644 --- a/tiktorch/server/session/rpc_interface.py +++ b/tiktorch/server/session/rpc_interface.py @@ -1,3 +1,4 @@ +from pathlib import Path from typing import List import torch @@ -78,11 +79,11 @@ def shutdown(self) -> Shutdown: raise NotImplementedError @exposed - def save(self): + def save(self, file_path: Path): raise NotImplementedError @exposed - def export(self): + def export(self, file_path: Path): raise NotImplementedError @exposed diff --git a/tiktorch/trainer.py b/tiktorch/trainer.py index a3b0512b..ef54e20f 100644 --- a/tiktorch/trainer.py +++ b/tiktorch/trainer.py @@ -1,23 +1,49 @@ from __future__ import annotations import logging +import tempfile from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum +from pathlib import Path from typing import Any, Callable, Generic, List, TypeVar import bioimageio +import numpy as np import torch import xarray as xr import yaml from bioimageio.core import Sample +from bioimageio.spec import save_bioimageio_package +from bioimageio.spec.model.v0_5 import ( + ArchitectureFromLibraryDescr, + Author, + AxisId, + BatchAxis, + ChannelAxis, + CiteEntry, + Doi, + FileDescr, + HttpUrl, + Identifier, + InputTensorDescr, + LicenseId, + ModelDescr, + OutputTensorDescr, + PytorchStateDictWeightsDescr, + SpaceInputAxis, + SpaceOutputAxis, + TensorId, + Version, + WeightsDescr, +) from pytorch3dunet.augment.transforms import Compose, Normalize, Standardize, ToTensor from pytorch3dunet.datasets.utils import get_train_loaders from pytorch3dunet.unet3d.losses import get_loss_criterion from pytorch3dunet.unet3d.metrics import get_evaluation_metric from pytorch3dunet.unet3d.model import ResidualUNet2D, ResidualUNet3D, ResidualUNetSE3D, UNet2D, UNet3D, get_model from pytorch3dunet.unet3d.trainer import UNetTrainer -from pytorch3dunet.unet3d.utils import create_lr_scheduler, create_optimizer, get_tensorboard_formatter +from pytorch3dunet.unet3d.utils import create_lr_scheduler, create_optimizer, get_class, get_tensorboard_formatter from torch import nn T = TypeVar("T", bound=Callable) @@ -160,7 +186,176 @@ def train(self): def validate(self): return super().validate() - def forward(self, input_tensors: Sample) -> Sample: + def save_state_dict(self, file_path: Path): + """ + On demand save of the training progress including the optimizer state. + + Note: pytorch-3dunet automatically saves two checkpoints latest.pytorch and best.pytorch. + The best.pytorch is updated in intervals defined by the `validation_after_iters`. + """ + if not file_path.suffix: + file_path = file_path.with_suffix(".pth") + + state = { + "num_epochs": self.num_epochs + 1, + "num_iterations": self.num_iterations, + "model_state_dict": self.model.state_dict(), + "best_eval_score": self.best_eval_score, + "optimizer_state_dict": self.optimizer.state_dict(), + } + torch.save(state, file_path) + + def save_torchscript(self, file_path: Path): + """ + Requires the model to be torchscript compatible! + """ + if not file_path.suffix: + file_path = file_path.with_suffix(".pt") + + scripted_model = torch.jit.script(self.model) + torch.jit.save(scripted_model, file_path) + + def get_model_import_file_path(self) -> str: + """ + Utility to be used for bioimageio pytorch state weight descriptor. + + Assuming that pytorch-3dunet will be installed in the environment. + """ + model_class = get_class(self.model.__class__.__name__, modules=["pytorch3dunet.unet3d.model"]) # sanity-check + return f"{model_class.__module__}" + + def export(self, file_to_save: Path): + """ + Export to bioimageio model zip. + + Use default values for the fields. Expected the configuration to be delegated. + """ + if not file_to_save.suffix: + file_to_save = file_to_save.with_suffix(".zip") + + sample_inputs = self._get_test_input_from_loaders() + sample_outputs = self._forward([sample_inputs]) + assert len(sample_outputs) == 1, "Only single output supported" + + # define the axes + if self.is_3d_model(): + b_size, c_size, z_size, y_size, x_size = sample_inputs.shape + input_axes = [ + BatchAxis(), + ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(self._in_channels)]), + SpaceInputAxis(id=AxisId("z"), size=z_size), + SpaceInputAxis(id=AxisId("y"), size=y_size), + SpaceInputAxis(id=AxisId("x"), size=x_size), + ] + b_size, c_size, z_size, y_size, x_size = sample_outputs.shape + output_axes = [ + BatchAxis(), + ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(self._out_channels)]), + SpaceOutputAxis(id=AxisId("z"), size=z_size), + SpaceOutputAxis(id=AxisId("y"), size=y_size), + SpaceOutputAxis(id=AxisId("x"), size=x_size), + ] + elif self.is_2d_model(): + b_size, c_size, z_size, y_size, x_size = sample_inputs.shape + input_axes = [ + BatchAxis(), + ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(self._in_channels)]), + SpaceInputAxis(id=AxisId("y"), size=y_size), + SpaceInputAxis(id=AxisId("x"), size=x_size), + ] + b_size, c_size, z_size, y_size, x_size = sample_outputs.shape + output_axes = [ + BatchAxis(), + ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(self._out_channels)]), + SpaceOutputAxis(id=AxisId("y"), size=y_size), + SpaceOutputAxis(id=AxisId("x"), size=x_size), + ] + else: + raise ValueError("Only 2D and 3D models are supported") + + # todo: we want to create a bioimageio zip file, + # currently it needs hard coded paths that will then be converted to zip + # so we create temp files to fulfill this dependency, + # we should be able to go directly to a zip. + + # squeeze z-dimension for 2d models + sample_inputs = sample_inputs.squeeze(dim=-3) + sample_outputs = sample_outputs.squeeze(dim=-3) + + with tempfile.TemporaryDirectory() as temp_dir: + sample_input_file = Path(temp_dir) / "input.npy" + np.save(sample_input_file, sample_inputs) + sample_output_file = Path(temp_dir) / "output.npy" + np.save(sample_output_file, sample_outputs) + + input_tensor = InputTensorDescr( + id=TensorId("input"), + axes=input_axes, + description="", + test_tensor=FileDescr(source=sample_input_file), + ) + + output_tensor = OutputTensorDescr( + id=TensorId("output"), + axes=output_axes, + description="", + test_tensor=FileDescr(source=sample_output_file), + ) + + weights_file = Path(temp_dir) / "weights.pt" + torch.save(self.model.state_dict(), weights_file) + weights = WeightsDescr( + pytorch_state_dict=PytorchStateDictWeightsDescr( + source=weights_file, + architecture=ArchitectureFromLibraryDescr( + import_from=f"{self.get_model_import_file_path()}", + callable=Identifier(f"{self.model.__class__.__name__}"), + kwargs={"in_channels": self._in_channels, "out_channels": self._out_channels}, + ), + pytorch_version=Version("1.1.1"), + ) + ) # todo: pytorch version + + mocked_descr = ModelDescr( + name="tiktorch v5 model", + description="Add description", + authors=[Author(name="me", affiliation="my institute", github_user="bioimageiobot")], + cite=[CiteEntry(text="", doi=Doi("10.1234something"))], + license=LicenseId("MIT"), + documentation=HttpUrl("https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/README.md"), + git_repo=HttpUrl("https://github.com/bioimage-io/spec-bioimage-io"), + inputs=[input_tensor], + outputs=[output_tensor], + weights=weights, + ) + save_bioimageio_package(mocked_descr, output_path=file_to_save) + + def is_3d_model(self): + return isinstance(self.model, (ResidualUNetSE3D, ResidualUNet3D, UNet3D)) + + def is_2d_model(self): + return isinstance(self.model, (ResidualUNet2D, UNet2D)) + + def _get_test_input_from_loaders(self) -> torch.Tensor: + """ + Get the test input from the loaders to be used for bioimageio ModelDescr. + """ + # Get one batch of data + sample_batch = next(iter(self.loaders["train"])) # Get one batch + sample_inputs = sample_batch[0] # Assume input is the first element + return sample_inputs[0:1] # batch size 1 + + def forward(self, input_tensors: Sample): + assert len(input_tensors.members) == 1, "We support models with 1 input" + tensor_id, input_tensor = input_tensors.members.popitem() + input_tensor = self._get_pytorch_tensor_from_bioimageio_tensor(input_tensor) + predictions = self._forward([input_tensor]) + output_sample = Sample( + members={"output": self._get_bioimageio_tensor_from_pytorch_tensor(predictions)}, stat={}, id=None + ) + return output_sample + + def _forward(self, input_tensors: List[torch.Tensor]) -> torch.Tensor: """ Note: "The 2D U-Net itself uses the standard 2D convolutional @@ -170,10 +365,8 @@ def forward(self, input_tensors: Sample) -> Sample: Thus, we drop the z dimension if we have 2d model. But the input h5 data needs to respect CxDxHxW or DxHxW. """ - - assert len(input_tensors.members) == 1, "We support models with 1 input" - tensor_id, input_tensor = input_tensors.members.popitem() - input_tensor = self._get_pytorch_tensor_from_bioimageio_tensor(input_tensor) + assert len(input_tensors) == 1, "We support models with 1 input" + input_tensor = input_tensors[0] self.model.eval() b, c, z, y, x = input_tensor.shape if self.is_2d_model() and z != 1: @@ -204,11 +397,7 @@ def apply_final_activation(input_tensors) -> torch.Tensor: # currently we scale the features from 0 - 1 (consistent scale for rendering across channels) postprocessor = Compose([Normalize(norm01=True), ToTensor(expand_dims=True)]) predictions = self._apply_transformation(compose=postprocessor, tensor=predictions) - - output_sample = Sample( - members={"output": self._get_bioimageio_tensor_from_pytorch_tensor(predictions)}, stat={}, id=None - ) - return output_sample + return predictions def _apply_transformation(self, compose: Compose, tensor: torch.Tensor) -> torch.Tensor: """ @@ -247,12 +436,6 @@ def _get_pytorch_tensor_from_bioimageio_tensor(self, bioimageio_tensor: bioimage def _get_bioimageio_tensor_from_pytorch_tensor(self, pytorch_tensor: torch.Tensor) -> bioimageio.core.Tensor: return bioimageio.core.Tensor.from_xarray(xr.DataArray(pytorch_tensor.numpy(), dims=["b", "c", "z", "y", "x"])) - def is_3d_model(self): - return isinstance(self.model, (ResidualUNetSE3D, ResidualUNet3D, UNet3D)) - - def is_2d_model(self): - return isinstance(self.model, (ResidualUNet2D, UNet2D)) - def should_stop(self) -> bool: """ Intervene on how to stop the training. @@ -277,7 +460,7 @@ def _log_stats(self, phase, loss_avg, eval_score_avg): max_iterations=self.max_num_iterations, ) self.logs_callbacks(logs) - # todo: why the internal training logging isn't printed on the stdout, although it is set + # todo: why the internal training logging isn't printed on the stdout logger.info(str(logs)) return super()._log_stats(phase, loss_avg, eval_score_avg) @@ -329,9 +512,9 @@ def parse(self) -> Trainer: pre_trained = trainer_config.pop("pre_trained", None) return Trainer( - device=config["device"], in_channels=in_channels, out_channels=out_channels, + device=config["device"], model=model, optimizer=optimizer, lr_scheduler=lr_scheduler,