Skip to content

Commit

Permalink
Mod: Update setup_logging function.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Apr 18, 2024
1 parent e6c2a9e commit a8c8497
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 28 deletions.
7 changes: 2 additions & 5 deletions src/aac_metrics/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import os
import os.path as osp
import shutil

from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Union
Expand All @@ -14,7 +13,6 @@
from torch.hub import download_url_to_file

import aac_metrics

from aac_metrics.classes.bert_score_mrefs import BERTScoreMRefs
from aac_metrics.classes.fense import FENSE
from aac_metrics.functional.meteor import DNAME_METEOR_CACHE
Expand All @@ -24,7 +22,7 @@
FNAME_SPICE_JAR,
check_spice_install,
)
from aac_metrics.utils.cmdline import _str_to_bool, _setup_logging
from aac_metrics.utils.cmdline import _str_to_bool, setup_logging
from aac_metrics.utils.globals import (
_get_cache_path,
_get_tmp_path,
Expand All @@ -33,7 +31,6 @@
)
from aac_metrics.utils.tokenization import FNAME_STANFORD_CORENLP_3_4_1_JAR


pylog = logging.getLogger(__name__)


Expand Down Expand Up @@ -387,7 +384,7 @@ def _get_main_download_args() -> Namespace:

def _main_download() -> None:
args = _get_main_download_args()
_setup_logging(aac_metrics.__package__, args.verbose)
setup_logging(aac_metrics.__package__, args.verbose)

download_metrics(
cache_path=args.cache_path,
Expand Down
11 changes: 4 additions & 7 deletions src/aac_metrics/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,26 @@

import csv
import logging

from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Iterable, Union

import yaml

import aac_metrics

from aac_metrics.functional.evaluate import (
evaluate,
DEFAULT_METRICS_SET_NAME,
METRICS_SETS,
evaluate,
)
from aac_metrics.utils.checks import check_metric_inputs, check_java_path
from aac_metrics.utils.cmdline import _str_to_bool, _str_to_opt_str, _setup_logging
from aac_metrics.utils.checks import check_java_path, check_metric_inputs
from aac_metrics.utils.cmdline import _str_to_bool, _str_to_opt_str, setup_logging
from aac_metrics.utils.globals import (
get_default_cache_path,
get_default_java_path,
get_default_tmp_path,
)


pylog = logging.getLogger(__name__)


Expand Down Expand Up @@ -230,7 +227,7 @@ def _get_main_evaluate_args() -> Namespace:

def _main_eval() -> None:
args = _get_main_evaluate_args()
_setup_logging(aac_metrics.__package__, args.verbose)
setup_logging(aac_metrics.__package__, args.verbose)

if not check_java_path(args.java_path):
raise RuntimeError(f"Invalid Java executable. ({args.java_path})")
Expand Down
64 changes: 48 additions & 16 deletions src/aac_metrics/utils/cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

import logging
import sys

from typing import Optional

from logging import Logger
from types import ModuleType
from typing import Optional, Sequence, Union

_TRUE_VALUES = ("true", "1", "t", "yes", "y")
_FALSE_VALUES = ("false", "0", "f", "no", "n")
Expand All @@ -31,22 +31,55 @@ def _str_to_opt_str(s: str) -> Optional[str]:
return s


def _setup_logging(pkg_name: str, verbose: int, set_format: bool = True) -> None:
def setup_logging(
package_or_logger: Union[
str,
ModuleType,
None,
Logger,
Sequence[Union[str, ModuleType, None]],
Sequence[Logger],
],
verbose: int,
format_: Optional[str] = "[%(asctime)s][%(name)s][%(levelname)s] - %(message)s",
) -> None:
if package_or_logger is None or isinstance(
package_or_logger, (str, Logger, ModuleType)
):
package_or_logger_lst = [package_or_logger]
else:
package_or_logger_lst = list(package_or_logger)

name_or_logger_lst = [
pkg.__name__ if isinstance(pkg, ModuleType) else pkg
for pkg in package_or_logger_lst
]
logger_lst = [
logging.getLogger(pkg_i) if not isinstance(pkg_i, Logger) else pkg_i
for pkg_i in name_or_logger_lst
]

handler = logging.StreamHandler(sys.stdout)
if set_format:
format_ = "[%(asctime)s][%(name)s][%(levelname)s] - %(message)s"
if format_ is not None:
handler.setFormatter(logging.Formatter(format_))

pkg_logger = logging.getLogger(pkg_name)
for logger in logger_lst:
found = False
for handler in logger.handlers:
if (
isinstance(handler, logging.StreamHandler)
and handler.stream is sys.stdout
):
found = True
break
if not found:
logger.addHandler(handler)

level = _verbose_to_logging_level(verbose)
logger.setLevel(level)

found = False
for handler in pkg_logger.handlers:
if isinstance(handler, logging.StreamHandler) and handler.stream is sys.stdout:
found = True
break
if not found:
pkg_logger.addHandler(handler)

def _verbose_to_logging_level(verbose: int) -> int:
if verbose < 0:
level = logging.ERROR
elif verbose == 0:
Expand All @@ -55,5 +88,4 @@ def _setup_logging(pkg_name: str, verbose: int, set_format: bool = True) -> None
level = logging.INFO
else:
level = logging.DEBUG

pkg_logger.setLevel(level)
return level

0 comments on commit a8c8497

Please sign in to comment.