-
Notifications
You must be signed in to change notification settings - Fork 599
feat(pt_expt): add dp freeze support for pt_expt backend (.pte/.pt2)
#5299
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
3606340
50155fb
ac78e07
6145fe8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -20,6 +20,7 @@ | |||||
| from deepmd.dpmodel.array_api import ( | ||||||
| Array, | ||||||
| xp_take_along_axis, | ||||||
| xp_take_first_n, | ||||||
| ) | ||||||
| from deepmd.dpmodel.common import ( | ||||||
| cast_precision, | ||||||
|
|
@@ -534,7 +535,7 @@ | |||||
| (nf, nall, self.tebd_dim), | ||||||
| ) | ||||||
| # nfnl x tebd_dim | ||||||
| atype_embd = atype_embd_ext[:, :nloc, :] | ||||||
| atype_embd = xp_take_first_n(atype_embd_ext, 1, nloc) | ||||||
| grrg, g2, h2, rot_mat, sw = self.se_atten( | ||||||
| nlist, | ||||||
| coord_ext, | ||||||
|
|
@@ -1056,7 +1057,8 @@ | |||||
| self.stddev[...], | ||||||
| ) | ||||||
| nf, nloc, nnei, _ = dmatrix.shape | ||||||
| atype = atype_ext[:, :nloc] | ||||||
| nall = atype_ext.shape[1] | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Drop the dead assignment: `nall` is assigned but never used, and Ruff already flags this as F841. 💡 Suggested fix: `- nall = atype_ext.shape[1]`
atype = xp_take_first_n(atype_ext, 1, nloc)📝 Committable suggestion
Suggested change
🧰 Tools 🪛 Ruff (0.15.4) [error] 1060-1060: Local variable `nall` is assigned to but never used. Remove assignment to unused variable (F841). 🤖 Prompt for AI Agents |
||||||
| atype = xp_take_first_n(atype_ext, 1, nloc) | ||||||
| exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) | ||||||
| # nfnl x nnei | ||||||
| exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) | ||||||
|
|
@@ -1075,6 +1077,12 @@ | |||||
| nlist_masked = xp.where(nlist_mask, nlist, xp.zeros_like(nlist)) | ||||||
| ng = self.neuron[-1] | ||||||
| nt = self.tebd_dim | ||||||
|
|
||||||
| # Gather neighbor info using xp_take_along_axis along axis=1. | ||||||
| # This avoids flat (nf*nall,) indexing that creates Ne(nall, nloc) | ||||||
| # constraints in torch.export, breaking NoPbc (nall == nloc). | ||||||
| nlist_2d = xp.reshape(nlist_masked, (nf, nloc * nnei)) # (nf, nloc*nnei) | ||||||
|
|
||||||
| # nfnl x nnei x 4 | ||||||
| rr = xp.reshape(dmatrix, (nf * nloc, nnei, 4)) | ||||||
| rr = rr * xp.astype(exclude_mask[:, :, None], rr.dtype) | ||||||
|
|
@@ -1083,15 +1091,16 @@ | |||||
| if self.tebd_input_mode in ["concat"]: | ||||||
| # nfnl x tebd_dim | ||||||
| atype_embd = xp.reshape( | ||||||
| atype_embd_ext[:, :nloc, :], (nf * nloc, self.tebd_dim) | ||||||
| xp_take_first_n(atype_embd_ext, 1, nloc), (nf * nloc, self.tebd_dim) | ||||||
| ) | ||||||
| # nfnl x nnei x tebd_dim | ||||||
| atype_embd_nnei = xp.tile(atype_embd[:, xp.newaxis, :], (1, nnei, 1)) | ||||||
| index = xp.tile( | ||||||
| xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim) | ||||||
| # Gather neighbor type embeddings: (nf, nall, tebd_dim) -> (nf, nloc*nnei, tebd_dim) | ||||||
| nlist_idx_tebd = xp.tile(nlist_2d[:, :, xp.newaxis], (1, 1, self.tebd_dim)) | ||||||
| atype_embd_nlist = xp_take_along_axis( | ||||||
| atype_embd_ext, nlist_idx_tebd, axis=1 | ||||||
| ) | ||||||
| # nfnl x nnei x tebd_dim | ||||||
| atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1) | ||||||
| atype_embd_nlist = xp.reshape( | ||||||
| atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim) | ||||||
| ) | ||||||
|
|
@@ -1110,10 +1119,9 @@ | |||||
| assert self.embeddings_strip is not None | ||||||
| assert type_embedding is not None | ||||||
| ntypes_with_padding = type_embedding.shape[0] | ||||||
| # nf x (nl x nnei) | ||||||
| nlist_index = xp.reshape(nlist_masked, (nf, nloc * nnei)) | ||||||
| # nf x (nl x nnei) | ||||||
| nei_type = xp_take_along_axis(atype_ext, nlist_index, axis=1) | ||||||
| # Gather neighbor types: (nf, nall) -> (nf, nloc*nnei) | ||||||
| nei_type = xp_take_along_axis(atype_ext, nlist_2d, axis=1) | ||||||
| nei_type = xp.reshape(nei_type, (-1,)) # (nf * nloc * nnei,) | ||||||
| # (nf x nl x nnei) x ng | ||||||
| nei_type_index = xp.tile(xp.reshape(nei_type, (-1, 1)), (1, ng)) | ||||||
| if self.type_one_side: | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -160,6 +160,54 @@ def train( | |||||||||||||||||
| trainer.run() | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def freeze( | ||||||||||||||||||
| model: str, | ||||||||||||||||||
| output: str = "frozen_model.pt2", | ||||||||||||||||||
| head: str | None = None, | ||||||||||||||||||
| ) -> None: | ||||||||||||||||||
| """Freeze a pt_expt training checkpoint to .pte or .pt2 format. | ||||||||||||||||||
|
|
||||||||||||||||||
| Parameters | ||||||||||||||||||
| ---------- | ||||||||||||||||||
| model : str | ||||||||||||||||||
| Path to the training checkpoint (.pt file). | ||||||||||||||||||
| output : str | ||||||||||||||||||
| Path for the frozen model output (.pte or .pt2). | ||||||||||||||||||
| head : str or None | ||||||||||||||||||
| Head to freeze in a multi-task model (not yet supported). | ||||||||||||||||||
| """ | ||||||||||||||||||
| import torch | ||||||||||||||||||
|
|
||||||||||||||||||
| from deepmd.pt_expt.model import ( | ||||||||||||||||||
| get_model, | ||||||||||||||||||
| ) | ||||||||||||||||||
| from deepmd.pt_expt.train.wrapper import ( | ||||||||||||||||||
| ModelWrapper, | ||||||||||||||||||
| ) | ||||||||||||||||||
| from deepmd.pt_expt.utils.env import ( | ||||||||||||||||||
| DEVICE, | ||||||||||||||||||
| ) | ||||||||||||||||||
| from deepmd.pt_expt.utils.serialization import ( | ||||||||||||||||||
| deserialize_to_file, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| state_dict = torch.load(model, map_location=DEVICE, weights_only=True) | ||||||||||||||||||
| if "model" in state_dict: | ||||||||||||||||||
| state_dict = state_dict["model"] | ||||||||||||||||||
| model_params = state_dict["_extra_state"]["model_params"] | ||||||||||||||||||
|
|
||||||||||||||||||
| # Reconstruct model and load weights | ||||||||||||||||||
| pt_expt_model = get_model(model_params).to(DEVICE) | ||||||||||||||||||
| wrapper = ModelWrapper(pt_expt_model) | ||||||||||||||||||
| wrapper.load_state_dict(state_dict) | ||||||||||||||||||
| pt_expt_model.eval() | ||||||||||||||||||
|
|
||||||||||||||||||
| # Serialize to dict and export | ||||||||||||||||||
| model_dict = pt_expt_model.serialize() | ||||||||||||||||||
| deserialize_to_file(output, {"model": model_dict}) | ||||||||||||||||||
| log.info(f"Saved frozen model to {output}") | ||||||||||||||||||
|
Comment on lines
+163
to
+208
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't silently ignore `head`: if the caller passes a head today, it is silently dropped instead of raising an error. 🛑 Suggested guard def freeze(
model: str,
output: str = "frozen_model.pt2",
head: str | None = None,
) -> None:
@@
from deepmd.pt_expt.utils.serialization import (
deserialize_to_file,
)
+ if head is not None:
+ raise NotImplementedError(
+ "--head is not supported for the pt_expt freeze command yet."
+ )
+
state_dict = torch.load(model, map_location=DEVICE, weights_only=True)As per coding guidelines, "Always run `ruff` on changed code." 🧰 Tools 🪛 Ruff (0.15.4) [warning] 166-166: Unused function argument: `head` (ARG001). 🤖 Prompt for AI Agents |
||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def main(args: list[str] | argparse.Namespace | None = None) -> None: | ||||||||||||||||||
| """Entry point for the pt_expt backend CLI. | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -195,6 +243,18 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None: | |||||||||||||||||
| skip_neighbor_stat=FLAGS.skip_neighbor_stat, | ||||||||||||||||||
| output=FLAGS.output, | ||||||||||||||||||
| ) | ||||||||||||||||||
| elif FLAGS.command == "freeze": | ||||||||||||||||||
| if Path(FLAGS.checkpoint_folder).is_dir(): | ||||||||||||||||||
| checkpoint_path = Path(FLAGS.checkpoint_folder) | ||||||||||||||||||
| latest_ckpt_file = (checkpoint_path / "checkpoint").read_text() | ||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Reading the `checkpoint` pointer file without stripping the trailing newline can produce an invalid joined path. Useful? React with 👍 / 👎. |
||||||||||||||||||
| FLAGS.model = str(checkpoint_path.joinpath(latest_ckpt_file)) | ||||||||||||||||||
|
Comment on lines
+247
to
+250
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Strip the checkpoint pointer before joining the path.
🧩 Suggested fix- latest_ckpt_file = (checkpoint_path / "checkpoint").read_text()
- FLAGS.model = str(checkpoint_path.joinpath(latest_ckpt_file))
+ latest_ckpt_file = (checkpoint_path / "checkpoint").read_text().strip()
+ FLAGS.model = str(checkpoint_path / latest_ckpt_file)📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||
| else: | ||||||||||||||||||
| FLAGS.model = FLAGS.checkpoint_folder | ||||||||||||||||||
| # Default to .pt2; user can specify .pte via -o flag | ||||||||||||||||||
| suffix = Path(FLAGS.output).suffix | ||||||||||||||||||
| if suffix not in (".pte", ".pt2"): | ||||||||||||||||||
| FLAGS.output = str(Path(FLAGS.output).with_suffix(".pt2")) | ||||||||||||||||||
| freeze(model=FLAGS.model, output=FLAGS.output, head=FLAGS.head) | ||||||||||||||||||
| else: | ||||||||||||||||||
| raise RuntimeError( | ||||||||||||||||||
| f"Unsupported command '{FLAGS.command}' for the pt_expt backend." | ||||||||||||||||||
|
|
||||||||||||||||||
Check notice
Code scanning / CodeQL
Unused local variable `nall` (Note)