[Example] Add STAFNet Model for Air Quality Prediction #1070

Open · wants to merge 16 commits into base: develop
84 changes: 84 additions & 0 deletions examples/demo/conf/stafnet.yaml
Collaborator:
Please format this file with the VS Code YAML extension, or run pre-commit before committing: https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/development/#1

pre-commit run --files xxx.yaml

@@ -0,0 +1,84 @@
defaults:
- ppsci_default
- TRAIN: train_default
- TRAIN/ema: ema_default
- TRAIN/swa: swa_default
- EVAL: eval_default
- INFER: infer_default
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
- _self_
hydra:
Collaborator:
Please add the following fields at the top of the config file:

defaults:
- ppsci_default
- TRAIN: train_default
- TRAIN/ema: ema_default
- TRAIN/swa: swa_default
- EVAL: eval_default
- INFER: infer_default
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
- _self_

  run:
    # dynamic output directory according to running time and override name
    dir: outputs_chip_heat/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
  job:
    name: ${mode} # name of logfile
    chdir: false # keep current working directory unchanged
  callbacks:
    init_callback:
      _target_: ppsci.utils.callbacks.InitCallback
  sweep:
    # output directory for multirun
    dir: ${hydra.run.dir}
    subdir: ./

# general settings
mode: train # running mode: train/eval
seed: 42
output_dir: ${hydra:run.dir}
log_freq: 20

# dataset settings
DATASET:
  label_keys: [label]
  data_dir: ./dataset/train_data.pkl

MODEL:
  input_keys: [aq_train_data, mete_train_data]
  output_keys: [label]
  output_attention: True
  seq_len: 72
  pred_len: 48
  aq_gat_node_features: 7
  aq_gat_node_num: 35
  mete_gat_node_features: 7
  mete_gat_node_num: 18
  gat_hidden_dim: 32
  gat_edge_dim: 3
  e_layers: 1
  enc_in: 7
  dec_in: 7
  c_out: 7
  d_model: 16
  embed: "fixed"
  freq: "t"
  dropout: 0.05
  factor: 3
  n_heads: 4
  d_ff: 32
  num_kernels: 6
  top_k: 4

# training settings
TRAIN:
  epochs: 100
  iters_per_epoch: 400
  save_freq: 10
  eval_during_train: true
  eval_freq: 10
  batch_size: 1
  lr_scheduler:
    epochs: ${TRAIN.epochs}
    iters_per_epoch: ${TRAIN.iters_per_epoch}
    learning_rate: 0.001
    step_size: 10
    gamma: 0.9
  pretrained_model_path: null
  checkpoint_path: null

# evaluation settings
EVAL:
  eval_data_path: ./dataset/val_data.pkl
  pretrained_model_path: null
  compute_metric_by_batch: false
  eval_with_no_grad: true
  batch_size: 1
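
Note (not part of this PR's diff): the TRAIN.lr_scheduler block above reuses the top-level TRAIN values through OmegaConf interpolation. Below is a minimal, standalone sketch of how those ${TRAIN.*} references resolve; the snippet is illustrative only and does not import PaddleScience.

from omegaconf import OmegaConf

# Reduced copy of the relevant part of stafnet.yaml, for illustration only.
cfg = OmegaConf.create(
    """
TRAIN:
  epochs: 100
  iters_per_epoch: 400
  lr_scheduler:
    epochs: ${TRAIN.epochs}
    iters_per_epoch: ${TRAIN.iters_per_epoch}
    learning_rate: 0.001
"""
)
# Interpolations are resolved against the root config at access time.
assert cfg.TRAIN.lr_scheduler.epochs == 100
assert cfg.TRAIN.lr_scheduler.iters_per_epoch == 400
print(OmegaConf.to_yaml(cfg, resolve=True))

When the example is launched through Hydra, stafnet.yaml is first composed with the ppsci_default groups listed under defaults, and the same interpolation rules apply to the composed config.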
112 changes: 112 additions & 0 deletions examples/demo/demo.py
@@ -0,0 +1,112 @@
import ppsci
from ppsci.utils import logger
from omegaconf import DictConfig
import hydra
import paddle
from ppsci.data.dataset.stafnet_dataset import gat_lstmcollate_fn
import multiprocessing

def train(cfg: DictConfig):
    # set model
    model = ppsci.arch.STAFNet(**cfg.MODEL)
    train_dataloader_cfg = {
        "dataset": {
            "name": "STAFNetDataset",
            "file_path": cfg.DATASET.data_dir,
            "input_keys": cfg.MODEL.input_keys,
            "label_keys": cfg.MODEL.output_keys,
            "seq_len": cfg.MODEL.seq_len,
            "pred_len": cfg.MODEL.pred_len,
        },
        "batch_size": cfg.TRAIN.batch_size,
Collaborator:
Should this be EVAL here?

"sampler": {
"name": "BatchSampler",
"drop_last": False,
"shuffle": True,
},
Comment on lines +22 to +26
Collaborator:
这个"sampler"字段是否可以删掉?eval应该不需要shuffle

"collate_fn": gat_lstmcollate_fn,
}

sup_constraint = ppsci.constraint.SupervisedConstraint(
train_dataloader_cfg,
loss=ppsci.loss.MSELoss("mean"),
name="STAFNet_Sup",
)
constraint = {sup_constraint.name: sup_constraint}

# set optimizer
lr_scheduler = ppsci.optimizer.lr_scheduler.Step(**cfg.TRAIN.lr_scheduler)()
optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)
output_dir = cfg.output_dir
Collaborator:
Suggested change (delete this line):
    output_dir = cfg.output_dir

    ITERS_PER_EPOCH = len(sup_constraint.data_loader)

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        constraint,
        output_dir,
Collaborator:
Suggested change (replace the first line with the second):
        output_dir,
        cfg.output_dir,

        optimizer,
        lr_scheduler,
        cfg.TRAIN.epochs,
        ITERS_PER_EPOCH,
        eval_during_train=cfg.TRAIN.eval_during_train,
        seed=cfg.seed,
Collaborator:
Suggested change (delete this line):
        seed=cfg.seed,

        validator=validator,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )

    # train model
    solver.train()
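
Note (not part of this PR's diff): ppsci.optimizer.lr_scheduler.Step above is built from TRAIN.lr_scheduler (learning_rate 0.001, step_size 10, gamma 0.9). Assuming it follows Paddle's standard step-decay semantics, lr = learning_rate * gamma ** (step // step_size), here is a standalone sketch of the schedule those values describe.

import paddle

# Illustrative only: plain Paddle StepDecay with the values from TRAIN.lr_scheduler.
sched = paddle.optimizer.lr.StepDecay(learning_rate=0.001, step_size=10, gamma=0.9)
lrs = []
for _ in range(21):
    lrs.append(sched.get_lr())
    sched.step()
print(lrs[0], lrs[10], lrs[20])  # 0.001, 0.0009, 0.00081 (up to float rounding)

Whether a "step" corresponds to an epoch or an iteration depends on how the PaddleScience wrapper drives the scheduler; the decay factor applied every step_size steps is the same either way.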

def evaluate(cfg: DictConfig):
    model = ppsci.arch.STAFNet(**cfg.MODEL)
    eval_dataloader_cfg = {
        "dataset": {
            "name": "STAFNetDataset",
            "file_path": cfg.EVAL.eval_data_path,
            "input_keys": cfg.MODEL.input_keys,
            "label_keys": cfg.MODEL.output_keys,
            "seq_len": cfg.MODEL.seq_len,
            "pred_len": cfg.MODEL.pred_len,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": True,
        },
Comment on lines +74 to +78
Collaborator:
Suggested change (delete these lines):
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": True,
        },

"collate_fn": gat_lstmcollate_fn,
}
sup_validator = ppsci.validate.SupervisedValidator(
eval_dataloader_cfg,
loss=ppsci.loss.MSELoss("mean"),
metric={"MSE": ppsci.metric.MSE()},
name="Sup_Validator",
)
validator = {sup_validator.name: sup_validator}

# initialize solver
solver = ppsci.solver.Solver(
model,
validator=validator,
cfg=cfg,
pretrained_model_path=cfg.EVAL.pretrained_model_path,
compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
)

# evaluate model
solver.eval()

@hydra.main(version_base=None, config_path="./conf", config_name="stafnet.yaml")
def main(cfg: DictConfig):
if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
else:
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")

if __name__ == "__main__":
main()
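
Note (not part of this PR's diff): with the Hydra entry point above, the example is driven from the command line, e.g. "python demo.py mode=train" to train and "python demo.py mode=eval EVAL.pretrained_model_path=/path/to/checkpoint" to evaluate (the checkpoint path is illustrative); any key in stafnet.yaml can be overridden with the same dotted syntax.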
2 changes: 2 additions & 0 deletions ppsci/arch/__init__.py
@@ -61,6 +61,7 @@
from ppsci.utils import logger # isort:skip
from ppsci.arch.regdgcnn import RegDGCNN # isort:skip
from ppsci.arch.ifm_mlp import IFMMLP # isort:skip
from ppsci.arch.stafnet import STAFNet # isort:skip

__all__ = [
"MoFlowNet",
@@ -111,6 +112,7 @@
"VelocityGenerator",
"RegDGCNN",
"IFMMLP",
"STAFNet",
]

