
Commit 539b7e8

Makedirs helper function
1 parent 87d659c commit 539b7e8

File tree

9 files changed: +44, -32 lines

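In short, the commit factors the repeated directory-creation boilerplate into a single makedirs helper in chemprop/utils.py and updates every call site. Below is a minimal sketch of the two idioms for readers skimming the diff; it assumes chemprop with this commit is importable, and the preds/out.csv path is purely hypothetical:

import os
from chemprop.utils import makedirs  # helper added by this commit

# Before: each call site repeated the dirname / exist_ok dance (hypothetical path).
preds_dir = os.path.dirname('preds/out.csv')
if preds_dir != '':
    os.makedirs(preds_dir, exist_ok=True)

# After: one call; isfile=True means "create the parent directory of this file path".
makedirs('preds/out.csv', isfile=True)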

chemprop/models/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-from .model import build_model
+from .model import build_model, MoleculeModel

chemprop/parsing.py

Lines changed: 4 additions & 4 deletions

@@ -5,6 +5,8 @@
 
 import torch
 
+from chemprop.utils import makedirs
+
 
 def add_predict_args(parser: ArgumentParser):
     """
@@ -202,9 +204,7 @@ def modify_predict_args(args: Namespace):
     del args.no_cuda
 
     # Create directory for preds path
-    preds_dir = os.path.dirname(args.preds_path)
-    if preds_dir != '':
-        os.makedirs(preds_dir, exist_ok=True)
+    makedirs(args.preds_path, isfile=True)
 
 
 def parse_predict_args() -> Namespace:
@@ -235,7 +235,7 @@ def modify_train_args(args: Namespace):
     assert args.dataset_type is not None
 
     if args.save_dir is not None:
-        os.makedirs(args.save_dir, exist_ok=True)
+        makedirs(args.save_dir)
     else:
         temp_dir = TemporaryDirectory()
         args.save_dir = temp_dir.name

chemprop/train/cross_validate.py

Lines changed: 3 additions & 2 deletions

@@ -5,8 +5,9 @@
 
 import numpy as np
 
-from chemprop.data.utils import get_task_names
 from .run_training import run_training
+from chemprop.data.utils import get_task_names
+from chemprop.utils import makedirs
 
 
 def cross_validate(args: Namespace, logger: Logger = None) -> Tuple[float, float]:
@@ -24,7 +25,7 @@ def cross_validate(args: Namespace, logger: Logger = None) -> Tuple[float, float]:
         info(f'Fold {fold_num}')
         args.seed = init_seed + fold_num
         args.save_dir = os.path.join(save_dir, f'fold_{fold_num}')
-        os.makedirs(args.save_dir, exist_ok=True)
+        makedirs(args.save_dir)
         model_scores = run_training(args, logger)
         all_scores.append(model_scores)
     all_scores = np.array(all_scores)

chemprop/train/run_training.py

Lines changed: 2 additions & 2 deletions

@@ -20,7 +20,7 @@
 from chemprop.models import build_model
 from chemprop.nn_utils import param_count
 from chemprop.utils import build_optimizer, build_lr_scheduler, get_loss_func, get_metric_func, load_checkpoint,\
-    save_checkpoint
+    makedirs, save_checkpoint
 
 
 def run_training(args: Namespace, logger: Logger = None) -> List[float]:
@@ -140,7 +140,7 @@ def run_training(args: Namespace, logger: Logger = None) -> List[float]:
     for model_idx in range(args.ensemble_size):
         # Tensorboard writer
         save_dir = os.path.join(args.save_dir, f'model_{model_idx}')
-        os.makedirs(save_dir, exist_ok=True)
+        makedirs(save_dir)
         writer = SummaryWriter(log_dir=save_dir)
 
         # Load/build model

chemprop/utils.py

Lines changed: 22 additions & 7 deletions

@@ -12,19 +12,35 @@
 from torch.optim.lr_scheduler import _LRScheduler
 
 from chemprop.data import StandardScaler
-from chemprop.models import build_model
+from chemprop.models import build_model, MoleculeModel
 from chemprop.nn_utils import NoamLR
 
 
+def makedirs(path: str, isfile: bool = False):
+    """
+    Creates a directory given a path to either a directory or file.
+
+    If a directory is provided, creates that directory. If a file is provided (i.e. isfile == True),
+    creates the parent directory for that file.
+
+    :param path: Path to a directory or file.
+    :param isfile: Whether the provided path is a file rather than a directory.
+    """
+    if isfile:
+        path = os.path.dirname(path)
+    if path != '':
+        os.makedirs(path, exist_ok=True)
+
+
 def save_checkpoint(path: str,
-                    model: nn.Module,
+                    model: MoleculeModel,
                     scaler: StandardScaler = None,
                     features_scaler: StandardScaler = None,
                     args: Namespace = None):
     """
     Saves a model checkpoint.
 
-    :param model: A PyTorch model.
+    :param model: A MoleculeModel.
     :param scaler: A StandardScaler fitted on the data.
     :param features_scaler: A StandardScaler fitted on the features.
     :param args: Arguments namespace.
@@ -48,15 +64,15 @@ def save_checkpoint(path: str,
 def load_checkpoint(path: str,
                     current_args: Namespace = None,
                     cuda: bool = False,
-                    logger: logging.Logger = None) -> nn.Module:
+                    logger: logging.Logger = None) -> MoleculeModel:
     """
     Loads a model checkpoint.
 
     :param path: Path where checkpoint is saved.
     :param current_args: The current arguments. Replaces the arguments loaded from the checkpoint if provided.
     :param cuda: Whether to move model to cuda.
     :param logger: A logger.
-    :return: The loaded model.
+    :return: The loaded MoleculeModel.
     """
     debug = logger.debug if logger is not None else print
 
@@ -275,8 +291,7 @@ def create_logger(name: str, save_dir: str = None, quiet: bool = False) -> logging.Logger:
     logger.addHandler(ch)
 
     if save_dir is not None:
-        if save_dir != '':
-            os.makedirs(save_dir, exist_ok=True)
+        makedirs(save_dir)
 
         fh_v = logging.FileHandler(os.path.join(save_dir, 'verbose.log'))
         fh_v.setLevel(logging.DEBUG)
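For reviewers who want to exercise the new helper in isolation, here is a self-contained sketch of its two modes; the helper body is copied from the hunk above, and the temporary paths exist only for demonstration:

import os
import tempfile


def makedirs(path: str, isfile: bool = False):
    """Copy of the helper introduced above: create `path` as a directory,
    or, when isfile=True, create the parent directory of the file `path`."""
    if isfile:
        path = os.path.dirname(path)
    if path != '':
        os.makedirs(path, exist_ok=True)


root = tempfile.mkdtemp()

# Directory mode: the full path is created, including intermediate directories.
makedirs(os.path.join(root, 'save', 'fold_0'))
assert os.path.isdir(os.path.join(root, 'save', 'fold_0'))

# File mode: only the parent directory of the file path is created.
makedirs(os.path.join(root, 'preds', 'out.csv'), isfile=True)
assert os.path.isdir(os.path.join(root, 'preds'))
assert not os.path.exists(os.path.join(root, 'preds', 'out.csv'))

# A bare filename has no parent directory, so nothing is created and no error is raised.
makedirs('out.csv', isfile=True)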

hyperparameter_optimization.py

Lines changed: 2 additions & 4 deletions

@@ -13,7 +13,7 @@
 from chemprop.nn_utils import param_count
 from chemprop.parsing import add_train_args, modify_train_args
 from chemprop.train import cross_validate
-from chemprop.utils import create_logger
+from chemprop.utils import create_logger, makedirs
 
 
 SPACE = {
@@ -86,9 +86,7 @@ def objective(hyperparams: Dict[str, Union[int, float]]) -> float:
     logger.info(f'{best_result["mean_score"]} +/- {best_result["std_score"]} {args.metric}')
 
     # Save best hyperparameter settings as JSON config file
-    config_save_dir = os.path.dirname(args.config_save_path)
-    if config_save_dir != '':
-        os.makedirs(config_save_dir, exist_ok=True)
+    makedirs(args.config_save_path, isfile=True)
 
     with open(args.config_save_path, 'w') as f:
         json.dump(best_result['hyperparams'], f, indent=4, sort_keys=True)

scripts/save_features.py

Lines changed: 5 additions & 5 deletions

@@ -16,6 +16,7 @@
 
 from chemprop.data.utils import get_data
 from chemprop.features import get_features_func
+from chemprop.utils import makedirs
 
 
 def load_temp(temp_dir: str) -> Tuple[List[List[float]], int]:
@@ -62,6 +63,9 @@ def save_features(args: Namespace):
 
     :param args: Arguments.
     """
+    # Create directory for save_path
+    makedirs(args.save_path, isfile=True)
+
     # Get data and features function
     data = get_data(path=args.data_path, max_data_size=None)
     features_func = get_features_func(args.features_generator)
@@ -81,7 +85,7 @@ def save_features(args: Namespace):
             features, temp_num = load_temp(temp_save_dir)
 
     if not os.path.exists(temp_save_dir):
-        os.makedirs(temp_save_dir)
+        makedirs(temp_save_dir)
         features, temp_num = [], 0
 
     # Build features map function
@@ -134,8 +138,4 @@ def save_features(args: Namespace):
                         help='Whether to run in parallel rather than sequentially (warning: doesn\'t always work')
     args = parser.parse_args()
 
-    dirname = os.path.dirname(args.save_path)
-    if dirname != '':
-        os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
-
     save_features(args)

test/test.py

Lines changed: 2 additions & 4 deletions

@@ -13,7 +13,7 @@
 from chemprop.features import clear_cache
 from chemprop.parsing import add_train_args, modify_train_args, add_predict_args, modify_predict_args
 from chemprop.train import cross_validate, make_predictions
-from chemprop.utils import create_logger
+from chemprop.utils import create_logger, makedirs
 
 from hyperparameter_optimization import grid_search
 
@@ -91,9 +91,7 @@ def test_save_features(self):
         args.save_path = NamedTemporaryFile().name
         args.features_generator = 'morgan_count'
 
-        dirname = os.path.dirname(args.save_path)
-        if dirname != '':
-            os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
+        makedirs(args.save_path, isfile=True)
 
         save_features(args)
         os.unlink(args.save_path)

web/web.py

Lines changed: 3 additions & 3 deletions

@@ -21,15 +21,15 @@
 from chemprop.parsing import add_predict_args, add_train_args, modify_predict_args, modify_train_args
 from chemprop.train.make_predictions import make_predictions
 from chemprop.train.run_training import run_training
-from chemprop.utils import create_logger, load_task_names
+from chemprop.utils import create_logger, load_task_names, makedirs
 
 TEMP_FOLDER = TemporaryDirectory()
 
 app = Flask(__name__)
 app.config['DATA_FOLDER'] = 'web_data'
-os.makedirs(app.config['DATA_FOLDER'], exist_ok=True)
+makedirs(app.config['DATA_FOLDER'])
 app.config['CHECKPOINT_FOLDER'] = 'web_checkpoints'
-os.makedirs(app.config['CHECKPOINT_FOLDER'], exist_ok=True)
+makedirs(app.config['CHECKPOINT_FOLDER'])
 app.config['TEMP_FOLDER'] = TEMP_FOLDER.name
 app.config['SMILES_FILENAME'] = 'smiles.csv'
 app.config['PREDICTIONS_FILENAME'] = 'predictions.csv'
