Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Train only an output model, freezing the representation model. #317

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

RaulPPelaez
Copy link
Collaborator

@RaulPPelaez RaulPPelaez commented Apr 17, 2024

Adds the --freeze-representation, --reset-output-model and --overwrite-representation to train.py.

  • Freeze representation: Makes it so that the representation model weights are not trained
  • Reset output model: Makes it so that the reset_parameters() is called on the output model after loading it for training. Ignored if load-model is not used.
  • Overwrite representation: Takes a path to a checkpoint, if present the weights of the representation model will be taken from here as initial weights.

This allows to train many output modules while keeping a single representation model. The workflow is intended to work like this:

 $ torchmd-train --conf my_model1.yaml --log-dir model1 # Initial training for the representation model
  # Train the second model but load the representation weighs from the first one.
  # Note that there are no limitations on the output model here with respect to model1.
 $ torchmd-train --conf my_model2.yaml --log-dir model2 --freeze-representation --overwrite-representation model1/best.ckpt
 # Now you have two models that share the representation model

For inference we can take advantage of the shared representation model and trick torch into calling it only one time.
For this we can create a class similar to Ensemble. For prototyping we can simply make it similar to TorchMD_Net like:

    def forward(
        self,
        z: Tensor,
        pos: Tensor,
        batch: Optional[Tensor] = None,
        box: Optional[Tensor] = None,
        q: Optional[Tensor] = None,
        s: Optional[Tensor] = None,
        extra_args: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Tensor]:
        assert z.dim() == 1 and z.dtype == torch.long
        batch = torch.zeros_like(z) if batch is None else batch

        if self.derivative:
            pos.requires_grad_(True)
        x, v, z, pos, batch = self.models[0].representation_model(
            z, pos, batch, box=box, q=q, s=s
        )
        y = []
        neg_dy = []
        for m in self.models:
           o = m.output_model
           x_o = o.pre_reduce(x,v,z,pos,batch)
           if self.prior_model is not None:
               for prior in self.prior_model:
                   x_o = prior.pre_reduce(x_o, z, pos, batch, extra_args)
            x_o = o.reduce(x_o, batch)
            y_o = o.post_reduce(x_o)
            if self.prior_model is not None:
                for prior in self.prior_model:
                    y_o = prior.post_reduce(y_o, z, pos, batch, box, extra_args)
            y.append(y_o)
            if self.derivative:
                grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y_o)]
                dy_o = grad(
                    [y_o],
                    [pos],
                    grad_outputs=grad_outputs,
                    create_graph=self.training,
                    retain_graph=self.training,
                )[0]
                assert dy_o is not None, "Autograd returned None for the force prediction."
                neg_dy.append(-dy_o)
        y = torch.stack(y)
        neg_dy = torch.stack(neg_dy) if self.derivative else torch.empty(0)
        y_mean = torch.mean(y, axis=0)
        neg_dy_mean = torch.mean(neg_dy, axis=0)  if self.derivative else torch.empty(0)
        y_std = torch.std(y, axis=0)
        neg_dy_std = torch.std(neg_dy, axis=0)  if self.derivative else torch.empty(0)

        if self.return_std:
            return y_mean, neg_dy_mean, y_std, neg_dy_std
        else:
            return y_mean, neg_dy_mean

weights when loading an already trained model for training.
@RaulPPelaez
Copy link
Collaborator Author

cc @stefdoerr

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant