Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unused out_dir #735

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pytorch_pfn_extras/training/extensions/_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,9 +444,8 @@ def _make_snapshot(self, manager: ExtensionsManagerProtocol) -> None:
filename = filename(manager)
else:
filename = filename.format(manager)
outdir = manager.out
writer( # type: ignore
filename, outdir, serialized_target, savefun=self._savefun
filename, serialized_target, savefun=self._savefun
)

def finalize(self, manager: ExtensionsManagerProtocol) -> None:
Expand Down
2 changes: 0 additions & 2 deletions pytorch_pfn_extras/training/extensions/log_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,9 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:

# write to the log file
log_name = self._filename.format(**stats_cpu)
out = manager.out
savefun = LogWriterSaveFunc(self._format, self._append)
writer(
log_name,
out,
self._log_looker.get(),
savefun=savefun,
append=self._append,
Expand Down
1 change: 0 additions & 1 deletion pytorch_pfn_extras/training/extensions/profile_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
)
writer(
log_name,
out,
self._log, # type: ignore
savefun=savefun,
append=self._append,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,6 @@ def save_plot_using_module(

writer(
self._filename,
manager.out,
(fig, plt), # type: ignore
savefun=matplotlib_savefun,
)
Expand Down
11 changes: 3 additions & 8 deletions pytorch_pfn_extras/writing/_parallel_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,13 @@ def __init__(
def _save_with_exitcode(
self,
filename: str,
out_dir: str,
target: _TargetType,
savefun: _SaveFun,
append: bool,
**savefun_kwargs: Any,
) -> None:
try:
self.save(
filename, out_dir, target, savefun, append, **savefun_kwargs
)
self.save(filename, target, savefun, append, **savefun_kwargs)
except Exception as e:
thread = threading.current_thread()
thread.exitcode = -1 # type: ignore[attr-defined]
Expand All @@ -56,7 +53,6 @@ def _save_with_exitcode(
def create_worker(
self,
filename: str,
out_dir: str,
target: _TargetType,
*,
savefun: Optional[_SaveFun] = None,
Expand All @@ -65,7 +61,7 @@ def create_worker(
) -> threading.Thread:
return threading.Thread(
target=self._save_with_exitcode,
args=(filename, out_dir, target, savefun, append),
args=(filename, target, savefun, append),
kwargs=savefun_kwargs,
)

Expand Down Expand Up @@ -97,7 +93,6 @@ def __init__(
def create_worker(
self,
filename: str,
out_dir: str,
target: _TargetType,
*,
savefun: Optional[_SaveFun] = None,
Expand All @@ -106,6 +101,6 @@ def create_worker(
) -> multiprocessing.Process:
return multiprocessing.Process(
target=self.save,
args=(filename, out_dir, target, savefun, append),
args=(filename, target, savefun, append),
kwargs=savefun_kwargs,
)
13 changes: 3 additions & 10 deletions pytorch_pfn_extras/writing/_queue_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
_Worker,
)

_QueUnit = Optional[
Tuple[_TaskFun, str, str, _TargetType, Optional[_SaveFun], bool]
]
_QueUnit = Optional[Tuple[_TaskFun, str, _TargetType, Optional[_SaveFun], bool]]


class QueueWriter(Writer, Generic[_Worker]):
Expand Down Expand Up @@ -66,16 +64,13 @@ def __init__(
def __call__(
self,
filename: str,
out_dir: str,
target: _TargetType,
*,
savefun: Optional[_SaveFun] = None,
append: bool = False,
) -> None:
assert not self._finalized
self._queue.put(
(self._task, filename, out_dir, target, savefun, append)
)
self._queue.put((self._task, filename, target, savefun, append))

def create_task(self, savefun: _SaveFun) -> _TaskFun:
return SimpleWriter(savefun=savefun)
Expand All @@ -93,9 +88,7 @@ def consume(self, q: "queue.Queue[_QueUnit]") -> None:
q.task_done()
return
else:
task[0](
task[1], task[2], task[3], savefun=task[4], append=task[5]
)
task[0](task[1], task[2], savefun=task[3], append=task[4])
q.task_done()

def finalize(self) -> None:
Expand Down
3 changes: 1 addition & 2 deletions pytorch_pfn_extras/writing/_simple_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,11 @@ def __init__(
def __call__(
self,
filename: str,
out_dir: str,
target: _TargetType,
*,
savefun: Optional[_SaveFun] = None,
append: bool = False,
) -> None:
if savefun is None:
savefun = self._savefun
self.save(filename, out_dir, target, savefun, append, **self._kwds)
self.save(filename, target, savefun, append, **self._kwds)
1 change: 0 additions & 1 deletion pytorch_pfn_extras/writing/_tensorboard_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def __del__(self) -> None:
def __call__(
self,
filename: str,
out_dir: str,
target: _TargetType,
*,
savefun: Optional[_SaveFun] = None,
Expand Down
5 changes: 0 additions & 5 deletions pytorch_pfn_extras/writing/_writer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,6 @@ def __init__(
def __call__(
self,
filename: str,
out_dir: str,
target: _TargetType,
*,
savefun: Optional[_SaveFun] = None,
Expand Down Expand Up @@ -280,7 +279,6 @@ def finalize(self) -> None:
def save(
self,
filename: str,
out_dir: str,
target: _TargetType,
savefun: _SaveFun,
append: bool,
Expand Down Expand Up @@ -375,7 +373,6 @@ def __init__(
def __call__(
self,
filename: str,
out_dir: str,
target: _TargetType,
*,
savefun: Optional[_SaveFun] = None,
Expand All @@ -389,7 +386,6 @@ def __call__(
self._filename = filename
self._worker = self.create_worker(
filename,
out_dir,
target,
savefun=savefun,
append=append,
Expand All @@ -402,7 +398,6 @@ def __call__(
def create_worker(
self,
filename: str,
out_dir: str,
target: _TargetType,
*,
savefun: Optional[_SaveFun] = None,
Expand Down
2 changes: 1 addition & 1 deletion tests/pytorch_pfn_extras_tests/test_writing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_tensorboard_writing():
writer = ppe.writing.TensorBoardWriter(
out_dir=tempd, filename_suffix="_test"
)
writer(None, None, data)
writer(None, data)
# Check that the file was generated
for snap in os.listdir(tempd):
assert "_test" in snap
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import multiprocessing
import os
import tempfile
import threading
from unittest import mock
Expand All @@ -11,25 +12,27 @@

def test_simple_writer():
target = mock.MagicMock()
w = writing.SimpleWriter(foo=True)
savefun = mock.MagicMock()
with tempfile.TemporaryDirectory() as tempd:
w("myfile.dat", tempd, target, savefun=savefun)
w = writing.SimpleWriter(foo=True, out_dir=tempd)
filename = "myfile.dat"
w(filename, target, savefun=savefun)
assert os.path.exists(os.path.join(tempd, filename))
assert savefun.call_count == 1
assert savefun.call_args[0][0] == target
assert savefun.call_args[1]["foo"] is True


def test_standard_writer():
target = mock.MagicMock()
w = writing.StandardWriter()
worker = mock.MagicMock()
worker.exitcode = 0
name = spshot_writers_path + ".StandardWriter.create_worker"
with mock.patch(name, return_value=worker):
with tempfile.TemporaryDirectory() as tempd:
w("myfile.dat", tempd, target)
w("myfile.dat", tempd, target)
w = writing.StandardWriter(out_dir=tempd)
w("myfile.dat", target)
w("myfile.dat", target)
w.finalize()

assert worker.start.call_count == 2
Expand All @@ -38,36 +41,36 @@ def test_standard_writer():

def test_thread_writer_create_worker():
target = mock.MagicMock()
w = writing.ThreadWriter()
with tempfile.TemporaryDirectory() as tempd:
worker = w.create_worker("myfile.dat", tempd, target, append=False)
w = writing.ThreadWriter(out_dir=tempd)
worker = w.create_worker("myfile.dat", target, append=False)
assert isinstance(worker, threading.Thread)
w("myfile2.dat", tempd, "test")
w("myfile2.dat", "test")
w.finalize()


def test_thread_writer_fail():
w = writing.ThreadWriter(savefun=None)
with tempfile.TemporaryDirectory() as tempd:
w("myfile2.dat", tempd, "test")
w = writing.ThreadWriter(savefun=None, out_dir=tempd)
w("myfile2.dat", "test")
with pytest.raises(RuntimeError):
w.finalize()


def test_process_writer_create_worker():
target = mock.MagicMock()
w = writing.ProcessWriter()
with tempfile.TemporaryDirectory() as tempd:
worker = w.create_worker("myfile.dat", tempd, target, append=False)
w = writing.ProcessWriter(out_dir=tempd)
worker = w.create_worker("myfile.dat", target, append=False)
assert isinstance(worker, multiprocessing.Process)
w("myfile2.dat", tempd, "test")
w("myfile2.dat", "test")
w.finalize()


def test_process_writer_fail():
w = writing.ProcessWriter(savefun=None)
with tempfile.TemporaryDirectory() as tempd:
w("myfile2.dat", tempd, "test")
w = writing.ProcessWriter(savefun=None, out_dir=tempd)
w("myfile2.dat", "test")
with pytest.raises(RuntimeError):
w.finalize()

Expand All @@ -82,11 +85,10 @@ def test_queue_writer():
]
with mock.patch(names[0], return_value=q):
with mock.patch(names[1], return_value=consumer):
w = writing.QueueWriter()

with tempfile.TemporaryDirectory() as tempd:
w("myfile.dat", tempd, target)
w("myfile.dat", tempd, target)
w = writing.QueueWriter(out_dir=tempd)
w("myfile.dat", target)
w("myfile.dat", target)
w.finalize()

assert consumer.start.call_count == 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,11 @@ def test_extensions_accessing_models_without_flag(priority):
if priority is not None:
extension.priority = priority
manager = training.ExtensionsManager(
m, optimizer, 1, iters_per_epoch=5, extensions=[extension]
m,
optimizer,
1,
iters_per_epoch=5,
extensions=[extension],
)
while not manager.stop_trigger:
with pytest.raises(RuntimeError):
Expand Down
Loading