
Commit e2425ec

Rework get optimizer function (#97)
1 parent 4376be1 commit e2425ec

File tree

5 files changed (+20 / -53 lines)

requirements-dev.txt

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@ flake8==3.7.9
 ipdb==0.13.2
 ipython==7.13.0
 mypy==0.770
+numpy==1.18.2
 pyroma==2.6
 pytest-cov==2.8.1
 pytest==5.4.1

tests/test_get.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

tests/test_optimizer.py

Lines changed: 1 addition & 1 deletion

@@ -80,7 +80,7 @@ def build_lookahead(*a, **kw):
     build_lookahead,
     optim.Ranger,
     optim.RangerQH,
-    optim.RangerVA
+    optim.RangerVA,
 ]
 
 
tests/test_optimizer_with_nn.py

Lines changed: 3 additions & 3 deletions

@@ -64,9 +64,9 @@ def build_lookahead(*a, **kw):
     (build_lookahead, {'lr': 0.1, 'weight_decay': 1e-3}, 200),
     (optim.QHM, {'lr': 0.1, 'weight_decay': 1e-5, 'momentum': 0.2}, 200),
     (optim.QHAdam, {'lr': 0.1, 'weight_decay': 1e-3}, 200),
-    (optim.Ranger, {'lr': .1, 'weight_decay': 1e-3}, 200),
-    (optim.RangerQH, {'lr': .01, 'weight_decay': 1e-3}, 200),
-    (optim.RangerVA, {'lr': .01, 'weight_decay': 1e-3}, 200),
+    (optim.Ranger, {'lr': 0.1, 'weight_decay': 1e-3}, 200),
+    (optim.RangerQH, {'lr': 0.01, 'weight_decay': 1e-3}, 200),
+    (optim.RangerVA, {'lr': 0.01, 'weight_decay': 1e-3}, 200),
 ]
 
 
torch_optimizer/__init__.py

Lines changed: 15 additions & 31 deletions

@@ -1,4 +1,8 @@
-from typing import Optional, Type
+from typing import Optional, Type, List, Dict
+
+from pytorch_ranger import Ranger, RangerQH, RangerVA
+from torch.optim import Optimizer
+
 from .accsgd import AccSGD
 from .adabound import AdaBound
 from .adamod import AdaMod
@@ -12,8 +16,6 @@
 from .radam import RAdam
 from .sgdw import SGDW
 from .yogi import Yogi
-from pytorch_ranger import Ranger, RangerQH, RangerVA
-from torch import optim
 
 
 __all__ = (
@@ -32,12 +34,14 @@
     'Yogi',
     'Ranger',
     'RangerQH',
-    'RangerVA'
+    'RangerVA',
+    # utils
+    'get',
 )
 __version__ = '0.0.1a11'
 
 
-package_opts = [
+_package_opts = [
     AccSGD,
     AdaBound,
     AdaMod,
@@ -54,38 +58,18 @@
     Ranger,
     RangerQH,
     RangerVA,
-]
+]  # type: List[Optimizer]
 
-builtin_opts = [
-    optim.Adadelta,
-    optim.Adagrad,
-    optim.Adam,
-    optim.AdamW,
-    optim.SparseAdam,
-    optim.Adamax,
-    optim.ASGD,
-    optim.SGD,
-    optim.Rprop,
-    optim.RMSprop,
-    optim.LBFGS
-]
 
-NAME_OPTIM_MAP = {
-    opt.__name__.lower(): opt for opt in package_opts + builtin_opts
-}
+_NAME_OPTIM_MAP = {
+    opt.__name__.lower(): opt for opt in _package_opts
+}  # type: Dict[str, Optimizer]
 
 
-def get(name: str,) -> Optional[Type[optim.Optimizer]]:
+def get(name: str) -> Optional[Type[Optimizer]]:
     r"""Returns an optimizer class from its name. Case insensitive.
 
     Args:
         name: the optimizer name.
     """
-    if isinstance(name, str):
-        cls = NAME_OPTIM_MAP.get(name.lower())
-        if cls is None:
-            raise ValueError('Could not interpret optimizer name: ' +
-                             str(name))
-        return cls
-    raise ValueError('Could not interpret optimizer name: ' +
-                     str(name))
+    return _NAME_OPTIM_MAP.get(name.lower())
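
After this rework, get() resolves only the optimizers bundled with torch_optimizer (the built-in torch.optim classes were dropped from the lookup table) and returns None for an unknown name instead of raising ValueError. A minimal usage sketch of the reworked function; the model and learning rate below are illustrative, not part of the commit:

import torch
import torch_optimizer as optim

# Case-insensitive lookup of a packaged optimizer class.
cls = optim.get('radam')
assert cls is optim.RAdam

# torch.optim names such as 'adam' are no longer in the map, so lookup yields None.
assert optim.get('adam') is None

# The returned class is instantiated like any other optimizer.
model = torch.nn.Linear(2, 1)
optimizer = cls(model.parameters(), lr=1e-3)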

0 commit comments
