Add logic to metric that tracks whether all_gather_object was called; add comments at places that may cause bugs
x54-729 committed Oct 31, 2022
1 parent 0d1a580 commit 6f21084
Showing 3 changed files with 20 additions and 6 deletions.
7 changes: 4 additions & 3 deletions fastNLP/core/callbacks/has_monitor_callback.py
@@ -218,9 +218,10 @@ def on_after_trainer_initialized(self, trainer, driver):
         if self.must_have_monitor and self.monitor is None:
             raise RuntimeError(f"No `monitor` is set for {self.log_name}. "
                                f"You can set it in the initialization or through Trainer.")
-        if self.must_have_monitor and self.monitor is not None and trainer.evaluator is None:
-            raise RuntimeError(f"No `evaluate_dataloaders` is set for Trainer. But Callback: {self.log_name}"
-                               f" need to watch the monitor:`{self.monitor_name}`.")
+        # Users may run evaluation themselves in a custom Callback without using an Evaluator; the restriction below is unreasonable in that case, so it is commented out for now.
+        # if self.must_have_monitor and self.monitor is not None and trainer.evaluator is None:
+        #     raise RuntimeError(f"No `evaluate_dataloaders` is set for Trainer. But Callback: {self.log_name}"
+        #                        f" need to watch the monitor:`{self.monitor_name}`.")
 
     def on_sanity_check_end(self, trainer, sanity_check_res):
         # Mainly check that the monitor exists.
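The scenario the new comment describes can be sketched as follows. This is a minimal, hypothetical example, not code from the commit: `my_evaluate` and the printed hand-off are stand-ins, and the `Callback` import path and `on_train_epoch_end` hook name are assumptions about fastNLP's public API.

from fastNLP import Callback  # assumed public import path

def my_evaluate(model):
    # hypothetical stand-in for the user's own evaluation loop
    return {"acc#manual": 0.9}

class ManualEvalCallback(Callback):
    # runs its own evaluation, so Trainer is built without `evaluate_dataloaders`
    # and `trainer.evaluator` is None, yet a monitorable value ("acc#manual") exists
    def on_train_epoch_end(self, trainer):
        results = my_evaluate(trainer.model)
        print(results)  # a real callback would feed this into its own monitor logic

With the old check, constructing such a Trainer together with a monitor-watching callback would raise even though the user legitimately produces evaluation results on their own.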
3 changes: 3 additions & 0 deletions fastNLP/core/callbacks/topk_saver.py
@@ -75,6 +75,9 @@ def save(self, trainer, folder_name):
             model_save_fn=self.model_save_fn,
             **self.kwargs
         )
+        # TODO If the Metric performs no aggregation, multiple folders are created here but saving only happens in the rank 0 folder.
+        # A possible fix: detect the empty folders and delete them.
+
         return str(os.path.abspath(folder))
 
     @rank_zero_call
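A minimal sketch of the cleanup the TODO suggests, assuming a hypothetical helper that rank 0 runs over the directory holding the per-save folders:

import os

def remove_empty_save_folders(root: str) -> None:
    # hypothetical helper: drop the folders that non-zero ranks created but never saved into
    for name in os.listdir(root):
        path = os.path.join(root, name)
        if os.path.isdir(path) and not os.listdir(path):
            os.rmdir(path)  # os.rmdir only succeeds on empty directories, so non-empty saves are safe

Wrapping the call with the `rank_zero_call` decorator already used in this file would keep the cleanup itself from racing across ranks.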
16 changes: 13 additions & 3 deletions fastNLP/core/metrics/metric.py
@@ -11,6 +11,7 @@

 from fastNLP.core.metrics.backend import Backend, AutoBackend
 from fastNLP.core.metrics.element import Element
+from fastNLP.envs import is_cur_env_distributed
 from fastNLP.core.log import logger


@@ -42,11 +43,9 @@ def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric
         self.get_metric = self._sync_get_metric(self.get_metric)
         self.update = self._wrap_update(self.update)
         self.reset = self._wrap_auto_reset_elements(self.reset)
-        self.get_metric = self._wrap_check_get_metric(self.get_metric)
         self.aggregate_when_get_metric = aggregate_when_get_metric
         self._cannot_change_element = False
-        self._call_gather_object = False
-        self._check_get_metric = False
+        self._call_gather_object = False  # used to check whether the user called all_gather_object in get_metric
         self._elements = {}
 
     @property
@@ -108,7 +107,18 @@ def _wrap_get_metric(*args, **kwargs):
             assert self._updated, f"You have to call `{self.__class__.__name__}'s update() function before calling " \
                                   f"get_metric()."
             with self.sync(recover=True, aggregate=self.aggregate_when_get_metric):
+                self._call_gather_object = False
                 results = get_metric(*args, **kwargs)
+
+            # warn if no Element is registered and all_gather_object was never called
+            if len(self._elements) == 0 and not self._call_gather_object:
+                # aggregation was requested and we are in a multi-card environment
+                if self.aggregate_when_get_metric and is_cur_env_distributed():
+                    logger.rank_zero_warning("There is no `<class 'Element'>` registered in metric `{}` and you didn't call "
+                                             "`Metric.all_gather_object()` in method `get_metric()` either. Therefore your "
+                                             "results may not be aggregated in distributed training."
+                                             .format(self.__class__), once=True)
+
             return results
 
         return _wrap_get_metric
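To see what the new warning guards against, here is a sketch of a metric that keeps plain Python state instead of Element objects and therefore must gather manually in get_metric(). The class and its state are hypothetical; `Metric.all_gather_object()` and `aggregate_when_get_metric` are taken from this diff, while the `Metric` import path is an assumption.

from fastNLP import Metric  # assumed public import path

class PairCounter(Metric):
    # keeps plain Python state, so no Element is registered and `self._elements` stays empty
    def __init__(self):
        super().__init__(aggregate_when_get_metric=True)
        self.pairs = []

    def update(self, preds, targets):
        self.pairs += list(zip(preds, targets))

    def get_metric(self):
        # all_gather_object collects each rank's list and records that it was called,
        # so the rank_zero_warning added above is not triggered
        gathered = self.all_gather_object(self.pairs)
        return {"total_pairs": sum(len(chunk) for chunk in gathered)}

Without that `all_gather_object()` call, each rank would report only its local count, which is exactly the silent under-aggregation the new warning flags.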
