-
Notifications
You must be signed in to change notification settings - Fork 190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Example] Add STAFNet Model for Air Quality Prediction #1070
base: develop
Are you sure you want to change the base?
Changes from all commits
f79a3f9
68b23d1
d9d2b54
bfa3e69
2d9dc85
57dc7c2
fa1cdee
ab1ae03
d257a49
b43c7f5
757477a
2b46497
711cd36
a79ad1d
af96434
86a9c0b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,84 @@ | ||||||||||||||||||||
defaults: | ||||||||||||||||||||
- ppsci_default | ||||||||||||||||||||
- TRAIN: train_default | ||||||||||||||||||||
- TRAIN/ema: ema_default | ||||||||||||||||||||
- TRAIN/swa: swa_default | ||||||||||||||||||||
- EVAL: eval_default | ||||||||||||||||||||
- INFER: infer_default | ||||||||||||||||||||
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default | ||||||||||||||||||||
- _self_ | ||||||||||||||||||||
hydra: | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 配置文件开头请加上以下字段: PaddleScience/examples/ldc/conf/ldc_2d_Re3200_piratenet.yaml Lines 1 to 9 in fad6927
|
||||||||||||||||||||
run: | ||||||||||||||||||||
# dynamic output directory according to running time and override name | ||||||||||||||||||||
dir: outputs_chip_heat/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} | ||||||||||||||||||||
job: | ||||||||||||||||||||
name: ${mode} # name of logfile | ||||||||||||||||||||
chdir: false # keep current working directory unchanged | ||||||||||||||||||||
callbacks: | ||||||||||||||||||||
init_callback: | ||||||||||||||||||||
_target_: ppsci.utils.callbacks.InitCallback | ||||||||||||||||||||
sweep: | ||||||||||||||||||||
# output directory for multirun | ||||||||||||||||||||
dir: ${hydra.run.dir} | ||||||||||||||||||||
subdir: ./ | ||||||||||||||||||||
|
||||||||||||||||||||
# general settings | ||||||||||||||||||||
mode: train # running mode: train/eval | ||||||||||||||||||||
seed: 42 | ||||||||||||||||||||
output_dir: ${hydra:run.dir} | ||||||||||||||||||||
log_freq: 20 | ||||||||||||||||||||
# dataset setting | ||||||||||||||||||||
DATASET: | ||||||||||||||||||||
label_keys: [label] | ||||||||||||||||||||
data_dir: ./dataset/train_data.pkl | ||||||||||||||||||||
|
||||||||||||||||||||
|
||||||||||||||||||||
MODEL: | ||||||||||||||||||||
input_keys: [aq_train_data, mete_train_data] | ||||||||||||||||||||
output_keys: [label] | ||||||||||||||||||||
output_attention: True | ||||||||||||||||||||
seq_len: 72 | ||||||||||||||||||||
pred_len: 48 | ||||||||||||||||||||
aq_gat_node_features: 7 | ||||||||||||||||||||
aq_gat_node_num: 35 | ||||||||||||||||||||
mete_gat_node_features: 7 | ||||||||||||||||||||
mete_gat_node_num: 18 | ||||||||||||||||||||
gat_hidden_dim: 32 | ||||||||||||||||||||
gat_edge_dim: 3 | ||||||||||||||||||||
e_layers: 1 | ||||||||||||||||||||
enc_in: 7 | ||||||||||||||||||||
dec_in: 7 | ||||||||||||||||||||
c_out: 7 | ||||||||||||||||||||
d_model: 16 | ||||||||||||||||||||
embed: "fixed" | ||||||||||||||||||||
freq: "t" | ||||||||||||||||||||
dropout: 0.05 | ||||||||||||||||||||
factor: 3 | ||||||||||||||||||||
n_heads: 4 | ||||||||||||||||||||
d_ff: 32 | ||||||||||||||||||||
num_kernels: 6 | ||||||||||||||||||||
top_k: 4 | ||||||||||||||||||||
|
||||||||||||||||||||
# training settings | ||||||||||||||||||||
TRAIN: | ||||||||||||||||||||
epochs: 100 | ||||||||||||||||||||
iters_per_epoch: 400 | ||||||||||||||||||||
save_freq: 10 | ||||||||||||||||||||
eval_during_train: true | ||||||||||||||||||||
eval_freq: 10 | ||||||||||||||||||||
batch_size: 1 | ||||||||||||||||||||
lr_scheduler: | ||||||||||||||||||||
epochs: ${TRAIN.epochs} | ||||||||||||||||||||
iters_per_epoch: ${TRAIN.iters_per_epoch} | ||||||||||||||||||||
learning_rate: 0.001 | ||||||||||||||||||||
step_size: 10 | ||||||||||||||||||||
gamma: 0.9 | ||||||||||||||||||||
pretrained_model_path: null | ||||||||||||||||||||
checkpoint_path: null | ||||||||||||||||||||
|
||||||||||||||||||||
EVAL: | ||||||||||||||||||||
eval_data_path: ./dataset/val_data.pkl | ||||||||||||||||||||
pretrained_model_path: null | ||||||||||||||||||||
compute_metric_by_batch: false | ||||||||||||||||||||
eval_with_no_grad: true | ||||||||||||||||||||
batch_size: 1 |
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,112 @@ | ||||||||||||
import ppsci | ||||||||||||
from ppsci.utils import logger | ||||||||||||
from omegaconf import DictConfig | ||||||||||||
import hydra | ||||||||||||
import paddle | ||||||||||||
from ppsci.data.dataset.stafnet_dataset import gat_lstmcollate_fn | ||||||||||||
import multiprocessing | ||||||||||||
|
||||||||||||
def train(cfg: DictConfig): | ||||||||||||
# set model | ||||||||||||
model = ppsci.arch.STAFNet(**cfg.MODEL) | ||||||||||||
train_dataloader_cfg = { | ||||||||||||
"dataset": { | ||||||||||||
"name": "STAFNetDataset", | ||||||||||||
"file_path": cfg.DATASET.data_dir, | ||||||||||||
"input_keys": cfg.MODEL.input_keys, | ||||||||||||
"label_keys": cfg.MODEL.output_keys, | ||||||||||||
"seq_len": cfg.MODEL.seq_len, | ||||||||||||
"pred_len": cfg.MODEL.pred_len, | ||||||||||||
}, | ||||||||||||
"batch_size": cfg.TRAIN.batch_size, | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里应该是EVAL? |
||||||||||||
"sampler": { | ||||||||||||
"name": "BatchSampler", | ||||||||||||
"drop_last": False, | ||||||||||||
"shuffle": True, | ||||||||||||
}, | ||||||||||||
Comment on lines
+22
to
+26
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个"sampler"字段是否可以删掉?eval应该不需要shuffle |
||||||||||||
"collate_fn": gat_lstmcollate_fn, | ||||||||||||
} | ||||||||||||
|
||||||||||||
sup_constraint = ppsci.constraint.SupervisedConstraint( | ||||||||||||
train_dataloader_cfg, | ||||||||||||
loss=ppsci.loss.MSELoss("mean"), | ||||||||||||
name="STAFNet_Sup", | ||||||||||||
) | ||||||||||||
constraint = {sup_constraint.name: sup_constraint} | ||||||||||||
|
||||||||||||
# set optimizer | ||||||||||||
lr_scheduler = ppsci.optimizer.lr_scheduler.Step(**cfg.TRAIN.lr_scheduler)() | ||||||||||||
optimizer = ppsci.optimizer.Adam(lr_scheduler)(model) | ||||||||||||
output_dir = cfg.output_dir | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
ITERS_PER_EPOCH = len(sup_constraint.data_loader) | ||||||||||||
|
||||||||||||
# initialize solver | ||||||||||||
solver = ppsci.solver.Solver( | ||||||||||||
model, | ||||||||||||
constraint, | ||||||||||||
output_dir, | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
optimizer, | ||||||||||||
lr_scheduler, | ||||||||||||
cfg.TRAIN.epochs, | ||||||||||||
ITERS_PER_EPOCH, | ||||||||||||
eval_during_train=cfg.TRAIN.eval_during_train, | ||||||||||||
seed=cfg.seed, | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
validator=validator, | ||||||||||||
compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch, | ||||||||||||
eval_with_no_grad=cfg.EVAL.eval_with_no_grad, | ||||||||||||
) | ||||||||||||
|
||||||||||||
# train model | ||||||||||||
solver.train() | ||||||||||||
|
||||||||||||
def evaluate(cfg: DictConfig): | ||||||||||||
model = ppsci.arch.STAFNet(**cfg.MODEL) | ||||||||||||
eval_dataloader_cfg= { | ||||||||||||
"dataset": { | ||||||||||||
"name": "STAFNetDataset", | ||||||||||||
"file_path": cfg.EVAL.eval_data_path, | ||||||||||||
"input_keys": cfg.MODEL.input_keys, | ||||||||||||
"label_keys": cfg.MODEL.output_keys, | ||||||||||||
"seq_len": cfg.MODEL.seq_len, | ||||||||||||
"pred_len": cfg.MODEL.pred_len, | ||||||||||||
}, | ||||||||||||
"batch_size": cfg.TRAIN.batch_size, | ||||||||||||
"sampler": { | ||||||||||||
"name": "BatchSampler", | ||||||||||||
"drop_last": False, | ||||||||||||
"shuffle": True, | ||||||||||||
}, | ||||||||||||
Comment on lines
+74
to
+78
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
"collate_fn": gat_lstmcollate_fn, | ||||||||||||
} | ||||||||||||
sup_validator = ppsci.validate.SupervisedValidator( | ||||||||||||
eval_dataloader_cfg, | ||||||||||||
loss=ppsci.loss.MSELoss("mean"), | ||||||||||||
metric={"MSE": ppsci.metric.MSE()}, | ||||||||||||
name="Sup_Validator", | ||||||||||||
) | ||||||||||||
validator = {sup_validator.name: sup_validator} | ||||||||||||
|
||||||||||||
# initialize solver | ||||||||||||
solver = ppsci.solver.Solver( | ||||||||||||
model, | ||||||||||||
validator=validator, | ||||||||||||
cfg=cfg, | ||||||||||||
pretrained_model_path=cfg.EVAL.pretrained_model_path, | ||||||||||||
compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch, | ||||||||||||
eval_with_no_grad=cfg.EVAL.eval_with_no_grad, | ||||||||||||
) | ||||||||||||
|
||||||||||||
# evaluate model | ||||||||||||
solver.eval() | ||||||||||||
|
||||||||||||
@hydra.main(version_base=None, config_path="./conf", config_name="stafnet.yaml") | ||||||||||||
def main(cfg: DictConfig): | ||||||||||||
if cfg.mode == "train": | ||||||||||||
train(cfg) | ||||||||||||
elif cfg.mode == "eval": | ||||||||||||
evaluate(cfg) | ||||||||||||
else: | ||||||||||||
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") | ||||||||||||
|
||||||||||||
if __name__ == "__main__": | ||||||||||||
main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文件建议使用vscode的yaml插件格式化一下,或者提交前用pre-commit格式化:https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/development/#1