Loading checkpoint of trained model fails #183

Open
utkarshp1161 opened this issue Jun 7, 2023 · 5 comments

utkarshp1161 commented Jun 7, 2023

I don't get this error when loading another checkpoint trained on a different dataset.

Please let me know if any more information is required. Thanks

Error:
RuntimeError                              Traceback (most recent call last)
Cell In[3], line 2
      1 #model = load_model("/home/benchmarking_datasets/fone_output/torchmdnet/lips2/logs/epoch=209-val_loss=0.0165-test_loss=0.0897.ckpt", derivative=True)
----> 2 model = load_model("/home/torchmd-net/bench_data_sl/fone_output/torchmdnet/ala/logs/epoch=49-val_loss=3.3522-test_loss=1.1794.ckpt", derivative = True)

File ~/torchmd-net/torchmdnet/models/model.py:116, in load_model(filepath, args, device, **kwargs)
--> 116 model.load_state_dict(state_dict)

File ~/anaconda3/envs/bebam_tmd/lib/python3.10/site-packages/torch/nn/modules/module.py:2041, in Module.load_state_dict(self, state_dict, strict)
   2036         error_msgs.insert(
   2037             0, 'Missing key(s) in state_dict: {}. '.format(
   2038                 ', '.join('"{}"'.format(k) for k in missing_keys)))
   2040 if len(error_msgs) > 0:
-> 2041     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2043 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for TorchMD_Net:
	size mismatch for mean: copying a param with shape torch.Size([1, 1]) from checkpoint, the shape in current model is torch.Size([]).
	size mismatch for std: copying a param with shape torch.Size([1, 1]) from checkpoint, the shape in current model is torch.Size([]).
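
For reference, the mismatch can be reproduced with plain PyTorch. As far as we can tell from PyTorch's `Module._load_from_state_dict`, a 1-element 1-D tensor (shape `[1]`) is accepted for a 0-d buffer through a backward-compatibility shim, while shape `[1, 1]` is rejected. That could explain why one checkpoint loads and the other does not, but it has not been confirmed against these particular checkpoints. In the minimal sketch below, `ToyNet` is a hypothetical stand-in for `TorchMD_Net` (only the `mean`/`std` buffer names come from the error message), and `reshape_mismatched` is a hypothetical helper that reshapes checkpoint tensors to the model's shape when their element counts agree.

```python
import torch
from torch import nn


class ToyNet(nn.Module):
    """Hypothetical stand-in for TorchMD_Net; only the mean/std buffers matter here."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)
        # A freshly constructed model holds 0-d (scalar) buffers: torch.Size([]).
        self.register_buffer("mean", torch.tensor(0.0))
        self.register_buffer("std", torch.tensor(1.0))


def reshape_mismatched(model, state_dict):
    """Hypothetical helper: reshape entries whose shape differs from the model's
    but whose element count matches (e.g. [1, 1] -> [])."""
    expected = model.state_dict()
    fixed = {}
    for key, value in state_dict.items():
        target = expected.get(key)
        if (
            target is not None
            and value.shape != target.shape
            and value.numel() == target.numel()
        ):
            value = value.reshape(target.shape)
        fixed[key] = value
    return fixed


ckpt = ToyNet().state_dict()

# Shape [1] loads into a 0-d buffer (PyTorch's backward-compatibility path).
ckpt_1d = dict(ckpt, mean=torch.tensor([0.5]), std=torch.tensor([2.0]))
ToyNet().load_state_dict(ckpt_1d)

# Shape [1, 1], as in the failing checkpoint, raises a size-mismatch RuntimeError.
ckpt_2d = dict(ckpt, mean=torch.tensor([[0.5]]), std=torch.tensor([[2.0]]))
try:
    ToyNet().load_state_dict(ckpt_2d)
except RuntimeError as err:
    print("size mismatch" in str(err))  # True

# Reshaping the checkpoint entries first lets the load succeed.
model = ToyNet()
model.load_state_dict(reshape_mismatched(model, ckpt_2d))
print(model.mean.item(), model.std.item())  # 0.5 2.0
```

Passing `strict=False` to `load_state_dict` does not help here: it only relaxes missing and unexpected keys, and size mismatches are still raised. Reshaping a real checkpoint's state dict like this would be a workaround, not a fix for whatever produced the `[1, 1]` statistics in the first place.
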
@RaulPPelaez
Collaborator

Could you provide some more information?
Your current conda environment (run `conda list`) and the YAML configuration you are using for this run would help.

Author

utkarshp1161 commented Jun 8, 2023

Sure

# packages in environment 
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                       1_gnu  
absl-py                   1.4.0                    pypi_0    pypi
aiohttp                   3.8.4                    pypi_0    pypi
aiosignal                 1.3.1                    pypi_0    pypi
anyio                     3.6.2                    pypi_0    pypi
appdirs                   1.4.4                    pypi_0    pypi
argon2-cffi               21.3.0                   pypi_0    pypi
argon2-cffi-bindings      21.2.0                   pypi_0    pypi
arrow                     1.2.3                    pypi_0    pypi
ase                       3.22.1                   pypi_0    pypi
asttokens                 2.2.1                    pypi_0    pypi
async-timeout             4.0.2                    pypi_0    pypi
attrs                     23.1.0                   pypi_0    pypi
backcall                  0.2.0                    pypi_0    pypi
beautifulsoup4            4.12.2                   pypi_0    pypi
bleach                    6.0.0                    pypi_0    pypi
bzip2                     1.0.8                h7b6447c_0  
ca-certificates           2023.01.10           h06a4308_0  
cachetools                5.3.0                    pypi_0    pypi
certifi                   2022.12.7                pypi_0    pypi
cffi                      1.15.1                   pypi_0    pypi
charset-normalizer        2.1.1                    pypi_0    pypi
click                     8.1.3                    pypi_0    pypi
cmake                     3.25.0                   pypi_0    pypi
comm                      0.1.3                    pypi_0    pypi
contourpy                 1.0.7                    pypi_0    pypi
cycler                    0.11.0                   pypi_0    pypi
debugpy                   1.6.7                    pypi_0    pypi
decorator                 5.1.1                    pypi_0    pypi
defusedxml                0.7.1                    pypi_0    pypi
docker-pycreds            0.4.0                    pypi_0    pypi
e3nn                      0.5.1                    pypi_0    pypi
executing                 1.2.0                    pypi_0    pypi
fastjsonschema            2.16.3                   pypi_0    pypi
filelock                  3.9.0                    pypi_0    pypi
fonttools                 4.39.3                   pypi_0    pypi
fqdn                      1.5.1                    pypi_0    pypi
frozenlist                1.3.3                    pypi_0    pypi
fsspec                    2023.4.0                 pypi_0    pypi
gitdb                     4.0.10                   pypi_0    pypi
gitpython                 3.1.31                   pypi_0    pypi
google-auth               2.17.3                   pypi_0    pypi
google-auth-oauthlib      1.0.0                    pypi_0    pypi
grpcio                    1.54.0                   pypi_0    pypi
h5py                      3.8.0                    pypi_0    pypi
idna                      3.4                      pypi_0    pypi
ipykernel                 6.22.0                   pypi_0    pypi
ipython                   8.12.0                   pypi_0    pypi
ipython-genutils          0.2.0                    pypi_0    pypi
ipywidgets                8.0.6                    pypi_0    pypi
isoduration               20.11.0                  pypi_0    pypi
jedi                      0.18.2                   pypi_0    pypi
jinja2                    3.1.2                    pypi_0    pypi
joblib                    1.2.0                    pypi_0    pypi
jsonpointer               2.3                      pypi_0    pypi
jsonschema                4.17.3                   pypi_0    pypi
jupyter                   1.0.0                    pypi_0    pypi
jupyter-client            8.2.0                    pypi_0    pypi
jupyter-console           6.6.3                    pypi_0    pypi
jupyter-core              5.3.0                    pypi_0    pypi
jupyter-events            0.6.3                    pypi_0    pypi
jupyter-server            2.5.0                    pypi_0    pypi
jupyter-server-terminals  0.4.4                    pypi_0    pypi
jupyterlab-pygments       0.2.2                    pypi_0    pypi
jupyterlab-widgets        3.0.7                    pypi_0    pypi
kiwisolver                1.4.4                    pypi_0    pypi
ld_impl_linux-64          2.38                 h1181459_1  
libffi                    3.3                  he6710b0_2  
libgcc-ng                 11.2.0               h1234567_1  
libgomp                   11.2.0               h1234567_1  
libstdcxx-ng              11.2.0               h1234567_1  
libuuid                   1.41.5               h5eee18b_0  
lit                       15.0.7                   pypi_0    pypi
markdown                  3.4.3                    pypi_0    pypi
markupsafe                2.1.2                    pypi_0    pypi
matplotlib                3.7.1                    pypi_0    pypi
matplotlib-inline         0.1.6                    pypi_0    pypi
mistune                   2.0.5                    pypi_0    pypi
mpmath                    1.2.1                    pypi_0    pypi
multidict                 6.0.4                    pypi_0    pypi
nbclassic                 0.5.5                    pypi_0    pypi
nbclient                  0.7.3                    pypi_0    pypi
nbconvert                 7.3.1                    pypi_0    pypi
nbformat                  5.8.0                    pypi_0    pypi
ncurses                   6.4                  h6a678d5_0  
nequip                    0.5.6                    pypi_0    pypi
nest-asyncio              1.5.6                    pypi_0    pypi
networkx                  3.0                      pypi_0    pypi
notebook                  6.5.4                    pypi_0    pypi
notebook-shim             0.2.3                    pypi_0    pypi
numpy                     1.24.3                   pypi_0    pypi
oauthlib                  3.2.2                    pypi_0    pypi
openssl                   1.1.1t               h7f8727e_0  
opt-einsum                3.3.0                    pypi_0    pypi
opt-einsum-fx             0.1.4                    pypi_0    pypi
packaging                 23.1                     pypi_0    pypi
pandas                    2.0.1                    pypi_0    pypi
pandocfilters             1.5.0                    pypi_0    pypi
parso                     0.8.3                    pypi_0    pypi
pathtools                 0.1.2                    pypi_0    pypi
pexpect                   4.8.0                    pypi_0    pypi
pickleshare               0.7.5                    pypi_0    pypi
pillow                    9.3.0                    pypi_0    pypi
pip                       23.0.1          py310h06a4308_0  
platformdirs              3.2.0                    pypi_0    pypi
prometheus-client         0.16.0                   pypi_0    pypi
prompt-toolkit            3.0.38                   pypi_0    pypi
protobuf                  4.22.3                   pypi_0    pypi
psutil                    5.9.5                    pypi_0    pypi
ptyprocess                0.7.0                    pypi_0    pypi
pure-eval                 0.2.2                    pypi_0    pypi
pyasn1                    0.5.0                    pypi_0    pypi
pyasn1-modules            0.3.0                    pypi_0    pypi
pycparser                 2.21                     pypi_0    pypi
pydeprecate               0.3.2                    pypi_0    pypi
pyg-lib                   0.2.0+pt20cu118          pypi_0    pypi
pygments                  2.15.1                   pypi_0    pypi
pyparsing                 3.0.9                    pypi_0    pypi
pyrsistent                0.19.3                   pypi_0    pypi
python                    3.10.0               h12debd9_5  
python-dateutil           2.8.2                    pypi_0    pypi
python-json-logger        2.0.7                    pypi_0    pypi
pytorch-lightning         1.6.3                    pypi_0    pypi
pytz                      2023.3                   pypi_0    pypi
pyyaml                    6.0                      pypi_0    pypi
pyzmq                     25.0.2                   pypi_0    pypi
qtconsole                 5.4.2                    pypi_0    pypi
qtpy                      2.3.1                    pypi_0    pypi
readline                  8.2                  h5eee18b_0  
requests                  2.28.1                   pypi_0    pypi
requests-oauthlib         1.3.1                    pypi_0    pypi
rfc3339-validator         0.1.4                    pypi_0    pypi
rfc3986-validator         0.1.1                    pypi_0    pypi
rsa                       4.9                      pypi_0    pypi
scikit-learn              1.2.2                    pypi_0    pypi
scipy                     1.10.1                   pypi_0    pypi
send2trash                1.8.0                    pypi_0    pypi
sentry-sdk                1.21.0                   pypi_0    pypi
setproctitle              1.3.2                    pypi_0    pypi
setuptools                66.0.0          py310h06a4308_0  
six                       1.16.0                   pypi_0    pypi
smmap                     5.0.0                    pypi_0    pypi
sniffio                   1.3.0                    pypi_0    pypi
soupsieve                 2.4.1                    pypi_0    pypi
sqlite                    3.41.2               h5eee18b_0  
stack-data                0.6.2                    pypi_0    pypi
sympy                     1.11.1                   pypi_0    pypi
tabulate                  0.9.0                    pypi_0    pypi
tensorboard               2.12.3                   pypi_0    pypi
tensorboard-data-server   0.7.0                    pypi_0    pypi
terminado                 0.17.1                   pypi_0    pypi
threadpoolctl             3.1.0                    pypi_0    pypi
tinycss2                  1.2.1                    pypi_0    pypi
tk                        8.6.12               h1ccaba5_0  
torch                     2.0.0+cu118              pypi_0    pypi
torch-cluster             1.6.1+pt20cu118          pypi_0    pypi
torch-ema                 0.3                      pypi_0    pypi
torch-geometric           2.3.0                    pypi_0    pypi
torch-runstats            0.2.0                    pypi_0    pypi
torch-scatter             2.1.1+pt20cu118          pypi_0    pypi
torch-sparse              0.6.17+pt20cu118          pypi_0    pypi
torch-spline-conv         1.2.2+pt20cu118          pypi_0    pypi
torchaudio                2.0.1+cu118              pypi_0    pypi
torchmd-net               0                        pypi_0    pypi
torchmetrics              0.11.4                   pypi_0    pypi
torchvision               0.15.1+cu118             pypi_0    pypi
tornado                   6.3.1                    pypi_0    pypi
tqdm                      4.65.0                   pypi_0    pypi
traitlets                 5.9.0                    pypi_0    pypi
triton                    2.0.0                    pypi_0    pypi
typing-extensions         4.4.0                    pypi_0    pypi
tzdata                    2023.3                   pypi_0    pypi
uri-template              1.2.0                    pypi_0    pypi
urllib3                   1.26.13                  pypi_0    pypi
wandb                     0.15.0                   pypi_0    pypi
wcwidth                   0.2.6                    pypi_0    pypi
webcolors                 1.13                     pypi_0    pypi
webencodings              0.5.1                    pypi_0    pypi
websocket-client          1.5.1                    pypi_0    pypi
werkzeug                  2.3.3                    pypi_0    pypi
wheel                     0.38.4          py310h06a4308_0  
widgetsnbextension        4.0.7                    pypi_0    pypi
xz                        5.2.10               h5eee18b_1  
yarl                      1.9.2                    pypi_0    pypi
zlib                      1.2.13               h5eee18b_0  

Author

utkarshp1161 commented Jun 8, 2023

YAML file:

activation: silu
aggr: add
atom_filter: -1
attn_activation: silu
batch_size: 8
charge: false
coord_files: null
cutoff_lower: 0.0
cutoff_upper: 5.0
dataset: MD17
dataset_arg:
  molecules: aspirin
dataset_root: mdsim_data/ala/40k/
derivative: true
distance_influence: both
early_stopping_patience: 300
ema_alpha_neg_dy: 1.0
ema_alpha_y: 0.05
embed_files: null
embedding_dimension: 128
energy_files: null
force_files: null
inference_batch_size: 64
load_model: null
log_dir: torchmdnet/ala/logs/
lr: 0.001
lr_factor: 0.8
lr_metric: val_loss
lr_min: 1.0e-07
lr_patience: 30
lr_warmup_steps: 1000
max_num_neighbors: 32
max_z: 100
model: equivariant-transformer
neg_dy_weight: 0.8
neighbor_embedding: true
ngpus: -1
num_epochs: 3000
num_heads: 8
num_layers: 6
num_nodes: 1
num_rbf: 32
num_workers: 6
output_model: Scalar
precision: 32
prior_model: null
rbf_type: expnorm
redirect: false
reduce_op: add
reset_trainer: false
save_interval: 10
seed: 1
spin: false
splits: null
standardize: true
tensorboard_use: false
test_interval: 10
test_size: null
train_size: 950
trainable_rbf: false
val_size: 50
wandb_name: training
wandb_project: training_
wandb_use: false
weight_decay: 0.0
y_weight: 0.2

Collaborator

RaulPPelaez commented Jun 8, 2023

This could be an issue with `standardize`. Do you also see it when `standardize` is set to `false`?
Could you tell me a bit about your use case for this flag?
EDIT: Well, I see you are using MD17, and the example sets this to `true`.
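
One way a `[1, 1]` statistic could arise, assuming the standardization step concatenates the training energies along dim 0 and reduces over that dim only (an assumption about torchmd-net's data module that we have not verified for this version): the shape of the stored `mean`/`std` then tracks the shape each sample's energy is stored with. The per-sample layouts and the `stats_shapes` helper below are hypothetical.

```python
import torch


def stats_shapes(per_sample_shape, n_samples=950):
    """Shapes of mean/std when per-sample energies of `per_sample_shape` are
    concatenated along dim 0 and reduced over dim 0 (an assumed procedure).
    n_samples=950 mirrors train_size in the config above."""
    ys = torch.cat([torch.randn(per_sample_shape) for _ in range(n_samples)])
    return tuple(ys.mean(dim=0).shape), tuple(ys.std(dim=0).shape)


# Hypothetical per-sample energy layouts; each holds a single energy value.
for shape in [(1,), (1, 1), (1, 1, 1)]:
    mean_shape, std_shape = stats_shapes(shape)
    print(f"per-sample {shape}: mean {mean_shape}, std {std_shape}")
# per-sample (1,): mean (), std ()
# per-sample (1, 1): mean (1,), std (1,)
# per-sample (1, 1, 1): mean (1, 1), std (1, 1)
```

PyTorch's `load_state_dict` tolerates a `[1]` tensor for a scalar buffer but not `[1, 1]`, so under this assumption only the last layout would reproduce the reported error.
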

@utkarshp1161
Author

Actually, I am training on a custom alanine dipeptide dataset. I edited the example MD17 YAML file: the `dataset` and `dataset_arg` entries were commented out and adjusted so that I can train from custom .npz files. I have successfully trained on LiPS data using a similar strategy, but it fails in this case.
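
Since the data come from custom .npz files, a quick check is whether each frame's energy is stored as a single value without extra axes. A minimal NumPy sketch; the `energy` key, the file name, and the `normalize_energies` helper are hypothetical placeholders, not necessarily the layout torchmd-net's custom loader expects.

```python
import os
import tempfile

import numpy as np


def normalize_energies(path, key="energy"):
    """Hypothetical helper: load an .npz file and return its energies as a flat
    array with one value per frame, rejecting arrays with more than one value per frame."""
    with np.load(path) as data:
        energies = np.asarray(data[key])
    n_frames = energies.shape[0]
    if energies.size != n_frames:
        raise ValueError(f"expected one energy per frame, got shape {energies.shape}")
    # Drop singleton axes such as (n_frames, 1, 1) -> (n_frames,).
    return energies.reshape(n_frames)


# Demo with a synthetic file whose energies carry extra singleton axes.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "ala_energies.npz")
    np.savez(path, energy=np.zeros((10, 1, 1)))
    flat = normalize_energies(path)
    print(flat.shape)  # (10,)
```

If the flattened array differs in shape from what was saved, re-saving the file with the flat energies and retraining would be one way to test whether the extra axes are the cause.
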
