Commit 312e5ed

fix bug in test_model and hyper (#277)
* fix bug in test_model and hyper
* fix pipeline
* add hyper example
* unify example

1 parent 939ef5d commit 312e5ed

File tree

7 files changed (+63 / -26 lines)


hyper_example.json

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+{
+    "learning_rate": {
+        "type": "choice",
+        "list": [0.01, 0.005, 0.001]
+    }
+}
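The JSON file above is the search-space definition consumed by hyper_tune.py, the ray[tune]-based entry point (see its diff below). As a hedged illustration only — this is not LibCity's own parser, and the "uniform" branch is an invented extra type — a file in this format could be mapped onto a ray.tune search space roughly like this:

import json
from ray import tune

def load_space(path='hyper_example.json'):
    """Sketch: map a hyper_example.json-style file onto ray.tune sample spaces."""
    with open(path, 'r') as f:
        raw = json.load(f)
    space = {}
    for name, spec in raw.items():
        if spec['type'] == 'choice':
            # "choice" with a "list" key, exactly as in the example above
            space[name] = tune.choice(spec['list'])
        elif spec['type'] == 'uniform':
            # hypothetical extra type, shown only to suggest how others could look
            space[name] = tune.uniform(spec['lower'], spec['upper'])
        else:
            raise ValueError('unsupported search type: {}'.format(spec['type']))
    return space

# load_space() -> {'learning_rate': <tune.choice over [0.01, 0.005, 0.001]>}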

hyper_example.txt

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+learning_rate choice [0.01, 0.005, 0.001]
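The one-line text format above is the params_file read by run_hyper.py, the hyperopt-based entry point (see its diff below). A rough sketch of how such a line could be parsed into a hyperopt space follows; LibCity's HyperTuning class does its own parsing, so the function name and the "choice"-only handling here are assumptions for illustration.

import ast
from hyperopt import hp

def parse_params_file(path='hyper_example.txt'):
    """Sketch: parse lines like 'learning_rate choice [0.01, 0.005, 0.001]'."""
    space = {}
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            name, kind, values = line.split(' ', 2)
            if kind == 'choice':
                space[name] = hp.choice(name, ast.literal_eval(values))
            else:
                raise ValueError('only "choice" is handled in this sketch')
    return space

# parse_params_file() -> {'learning_rate': hp.choice over [0.01, 0.005, 0.001]}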

hyper_tune.py

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 """
-Script for training and evaluating a single model
+Hyper-parameter tuning script (based on ray[tune])
 """
 
 import argparse

@@ -20,7 +20,7 @@
 parser.add_argument('--config_file', type=str,
                     default=None, help='the file name of config file')
 parser.add_argument('--space_file', type=str,
-                    default=None, help='the file which specifies the parameter search space')
+                    default='hyper_example', help='the file which specifies the parameter search space')
 parser.add_argument('--scheduler', type=str,
                     default='FIFO', help='the trial sheduler which will be used in ray.tune.run')
 parser.add_argument('--search_alg', type=str,

libcity/pipeline/pipeline.py

Lines changed: 15 additions & 5 deletions
@@ -136,8 +136,15 @@ def hyper_parameter(task=None, model_name=None, dataset_name=None, config_file=N
     # load config
     experiment_config = ConfigParser(task, model_name, dataset_name, config_file=config_file,
                                      other_args=other_args)
+    # exp_id
+    exp_id = experiment_config.get('exp_id', None)
+    if exp_id is None:
+        exp_id = int(random.SystemRandom().random() * 100000)
+        experiment_config['exp_id'] = exp_id
     # logger
     logger = get_logger(experiment_config)
+    logger.info('Begin ray-tune, task={}, model_name={}, dataset_name={}, exp_id={}'.
+                format(str(task), str(model_name), str(dataset_name), str(exp_id)))
     logger.info(experiment_config.config)
     # check space_file
     if space_file is None:

@@ -167,8 +174,11 @@ def train(config, checkpoint_dir=None, experiment_config=None,
         experiment_config[key] = config[key]
     experiment_config['hyper_tune'] = True
     logger = get_logger(experiment_config)
-    logger.info('Begin pipeline, task={}, model_name={}, dataset_name={}'
-                .format(str(task), str(model_name), str(dataset_name)))
+    # exp_id
+    exp_id = int(random.SystemRandom().random() * 100000)
+    experiment_config['exp_id'] = exp_id
+    logger.info('Begin pipeline, task={}, model_name={}, dataset_name={}, exp_id={}'.
+                format(str(task), str(model_name), str(dataset_name), str(exp_id)))
     logger.info('running parameters: ' + str(config))
     # load model
     model = get_model(experiment_config, data_feature)

@@ -215,9 +225,9 @@ def train(config, checkpoint_dir=None, experiment_config=None,
     # save best
     best_path = os.path.join(best_trial.checkpoint.value, "checkpoint")
     model_state, optimizer_state = torch.load(best_path)
-    model_cache_file = './libcity/cache/model_cache/{}_{}.m'.format(
-        model_name, dataset_name)
-    ensure_dir('./libcity/cache/model_cache')
+    model_cache_file = './libcity/cache/{}/model_cache/{}_{}.m'.format(
+        exp_id, model_name, dataset_name)
+    ensure_dir('./libcity/cache/{}/model_cache'.format(exp_id))
     torch.save((model_state, optimizer_state), model_cache_file)
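The last hunk above is the core of the fix for the ray[tune] path: the best trial's checkpoint is now saved under a per-experiment directory keyed by exp_id instead of the shared ./libcity/cache/model_cache. Restated as a standalone sketch (os.makedirs is assumed to behave like libcity's ensure_dir, and the model/dataset names are simply the ones used elsewhere in this commit):

import os
import random

# Random 5-digit experiment id, as generated in hyper_parameter()/train() above.
exp_id = int(random.SystemRandom().random() * 100000)

# Per-experiment cache directory, so parallel runs do not overwrite each
# other's best checkpoints.
cache_dir = './libcity/cache/{}/model_cache'.format(exp_id)
os.makedirs(cache_dir, exist_ok=True)  # assumed equivalent of ensure_dir()
model_cache_file = os.path.join(cache_dir, '{}_{}.m'.format('RNN', 'METR_LA'))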

run_hyper.py

Lines changed: 12 additions & 5 deletions
@@ -1,9 +1,9 @@
 """
-Hyper-parameter tuning script for a single model
+Hyper-parameter tuning script (based on hyperopt)
 """
 
 import argparse
-
+import random
 from libcity.pipeline import objective_function
 from libcity.executor import HyperTuning
 from libcity.utils import str2bool, get_logger, set_random_seed, add_general_args

@@ -26,7 +26,7 @@
                     help='whether re-train model if the model is \
                          trained before')
 parser.add_argument('--params_file', type=str,
-                    default=None, help='the file which specify the \
+                    default='hyper_example.txt', help='the file which specify the \
                          hyper-parameters and ranges to be adjusted')
 parser.add_argument('--hyper_algo', type=str,
                     default='grid_search', help='hyper-parameters search algorithm')

@@ -43,11 +43,18 @@
 other_args = {key: val for key, val in dict_args.items() if key not in [
     'task', 'model', 'dataset', 'config_file', 'saved_model', 'train',
     'params_file', 'hyper_algo'] and val is not None}
-
-logger = get_logger({'model': args.model, 'dataset': args.dataset})
+# exp_id
+exp_id = dict_args.get('exp_id', None)
+if exp_id is None:
+    # Make a new experiment ID
+    exp_id = int(random.SystemRandom().random() * 100000)
+other_args['exp_id'] = exp_id
+# logger
+logger = get_logger({'model': args.model, 'dataset': args.dataset, 'exp_id': exp_id})
 # seed
 seed = dict_args.get('seed', 0)
 set_random_seed(seed)
+other_args['seed'] = seed
 hp = HyperTuning(objective_function, params_file=args.params_file, algo=args.hyper_algo,
                  max_evals=args.max_evals, task=args.task, model_name=args.model,
                  dataset_name=args.dataset, config_file=args.config_file,
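The last hunk above threads the new exp_id (and the seed) through other_args so that every hyperopt trial logs and caches under the same experiment. A standalone sketch of that pattern, with the argument list trimmed to a few illustrative flags:

import argparse
import random

parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='traffic_state_pred')
parser.add_argument('--model', type=str, default='RNN')
parser.add_argument('--dataset', type=str, default='METR_LA')
parser.add_argument('--batch_size', type=int, default=None)
args = parser.parse_args([])  # empty list so the sketch runs without CLI input

dict_args = vars(args)
# forward everything that is not a core pipeline flag (and is not None)
other_args = {key: val for key, val in dict_args.items()
              if key not in ['task', 'model', 'dataset'] and val is not None}

exp_id = dict_args.get('exp_id', None)
if exp_id is None:
    exp_id = int(random.SystemRandom().random() * 100000)
other_args['exp_id'] = exp_id
other_args['seed'] = dict_args.get('seed', 0)
print(other_args)  # e.g. {'exp_id': 42137, 'seed': 0}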

test_model.py

Lines changed: 24 additions & 11 deletions
@@ -1,14 +1,26 @@
 from libcity.config import ConfigParser
 from libcity.data import get_dataset
-from libcity.utils import get_model, get_executor
+from libcity.utils import get_model, get_executor, get_logger, set_random_seed
+import random
+
+"""
+Take one batch of data for a preliminary test of the model
+"""
 
 # load the config file
-config = ConfigParser(task='traj_loc_pred', model='TemplateTLP',
-                      dataset='foursquare_tky', config_file=None,
-                      other_args={'batch_size': 2})
-# for traffic flow/speed prediction tasks, use the config-loading statement below instead
-# config = ConfigParser(task='traffic_state_pred', model='TemplateTSP',
-#                       dataset='METR_LA', config_file=None, other_args={'batch_size': 2})
+config = ConfigParser(task='traffic_state_pred', model='RNN',
+                      dataset='METR_LA', other_args={'batch_size': 2})
+exp_id = config.get('exp_id', None)
+if exp_id is None:
+    exp_id = int(random.SystemRandom().random() * 100000)
+    config['exp_id'] = exp_id
+# logger
+logger = get_logger(config)
+logger.info(config.config)
+# seed
+seed = config.get('seed', 0)
+set_random_seed(seed)
 # load the data module
 dataset = get_dataset(config)
 # preprocess the data and split the dataset

@@ -18,10 +30,11 @@
 batch = train_data.__iter__().__next__()
 # load the model
 model = get_model(config, data_feature)
-self = model.to(config['device'])
+model = model.to(config['device'])
+# load the executor
+executor = get_executor(config, model, data_feature)
 # model prediction
 batch.to_tensor(config['device'])
 res = model.predict(batch)
-# please check for yourself whether the shape of res satisfies the constraints of the track
-# if you also want to load the executor
-executor = get_executor(config, model)
+logger.info('Result shape is {}'.format(res.shape))
+logger.info('Success test the model!')

unit_test.py

Lines changed: 3 additions & 3 deletions
@@ -5,10 +5,10 @@
 
 #############################################
 # The parameter to control the unit testing #
-tested_trajectory_model = 'TemplateTLP'
-tested_trajectory_dataset = 'foursquare_tky'
+tested_trajectory_model = 'RNN'
+tested_trajectory_dataset = 'foursquare_nyc'
 tested_trajectory_encoder = 'StandardTrajectoryEncoder'
-tested_traffic_state_model = 'DCRNN'
+tested_traffic_state_model = 'RNN'
 tested_traffic_state_dataset = 'METR_LA'
 #############################################