
Update to PyTorch 2.2 #52

Merged: 19 commits merged into main from awf/pt23 on Apr 22, 2024

Conversation

@awf (Contributor) commented Apr 15, 2024

A number of changes are needed to work with PT2.2+. Primarily, there is now (since commit) no good way to intercept module calls, so we instead replace all nn.Modules with trivial subclasses, making them "user" modules.
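
Below is a minimal sketch of how this could look (the helper names `_as_user_module` and `replace_with_user_modules` are illustrative, not the actual `unit_scaling.transforms` API): each module instance is re-classed into a trivial subclass of its own type, so dynamo no longer recognises it as a torch.nn builtin and its `forward` stays visible for patching.

```python
# Illustrative sketch only (not the unit_scaling implementation).
import torch.nn as nn


def _as_user_module(module: nn.Module) -> nn.Module:
    # Build a trivial subclass of the module's own class. It adds no
    # behaviour; it only changes the instance's type so it is no longer a
    # recognised torch.nn builtin (a "user" module from dynamo's perspective).
    user_cls = type(f"User{type(module).__name__}", (type(module),), {})
    module.__class__ = user_cls  # re-class the instance in place
    return module


def replace_with_user_modules(root: nn.Module) -> nn.Module:
    # Recursively re-class the root module and every submodule.
    for child in root.modules():
        _as_user_module(child)
    return root
```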

Passes tests on 2.1, 2.2 (stable), and 2.4 (nightly), and examples/scale_analysis.py produces identical output (with torch.manual_seed).

This implementation is somewhat faster, as it does less work in the patched forward functions.

One source of changes is that post-2.2 the node naming better reflects the input code.
For example, the input code

def forward(self, idxs: Tensor) -> Tuple[Tensor, Tensor]:  # pragma: no cover
    # idxs has 0 args -> shouldn't be pruned
    x = self.emb(idxs)  # emb has 1 float arg (weights) -> depends on tol
    _x = x.flatten(start_dim=0, end_dim=-1)  # 1 float, same scale -> prune
    x = _x.view(x.shape)  # 1 float arg, same scale -> prune
    y = self.linear(x)  # scale changes -> shouldn't be pruned
    scores = F.softmax(y, dim=-1)  # scale changes -> shouldn't be pruned
    top_idx = torch.argmax(scores, dim=-1)  # not float -> shouldn't be pruned
    top_idx = torch.unsqueeze(top_idx, -1)  # not float -> shouldn't be pruned
    top_score_x = torch.gather(x, -1, top_idx)  # small change -> depends on tol
    top_score_x += randn_like(top_score_x)  # 2 floats, same scale -> no prune
    return top_score_x, top_idx

was captured pre-2.2 as the following graph, where the variable names are derived from the operators:

# Example of a pre-2.2 captured graph
def forward(self, L_idxs_ : torch.Tensor):
    l_idxs_ = L_idxs_
    l__self___emb_weight = foo.L__self___emb_weight
    embedding = torch.nn.functional.embedding(l_idxs_, l__self___emb_weight, None, None, 2.0, False, False);  l_idxs_ = l__self___emb_weight = None
    flatten = embedding.flatten(start_dim = 0, end_dim = -1);  embedding = None
    view = flatten.view((8, 32, 64));  flatten = None
    l__self___linear_weight = foo.L__self___linear_weight
    l__self___linear_bias = foo.L__self___linear_bias
    linear = torch._C._nn.linear(view, l__self___linear_weight, l__self___linear_bias);  l__self___linear_weight = l__self___linear_bias = None
    softmax = torch.nn.functional.softmax(linear, dim = -1);  linear = None
    argmax = torch.argmax(softmax, dim = -1);  softmax = None
    unsqueeze = torch.unsqueeze(argmax, -1);  argmax = None
    gather = torch.gather(view, -1, unsqueeze);  view = None
    randn_like = torch.randn_like(gather)
    gather += randn_like;  iadd = gather;  gather = randn_like = None
    return (iadd, unsqueeze)

and post-2.2 is captured as

# Example of a post-2.2 captured graph
def forward(self, L_idxs_ : torch.Tensor):
    l_idxs_ = L_idxs_
    l__self___emb_weight = foo.L__self___emb_weight
    x = torch.nn.functional.embedding(l_idxs_, l__self___emb_weight, None, None, 2.0, False, False);  l_idxs_ = l__self___emb_weight = None
    _x = x.flatten(start_dim = 0, end_dim = -1);  x = None
    x_1 = _x.view((8, 32, 64));  _x = None
    l__self___linear_weight = foo.L__self___linear_weight
    l__self___linear_bias = foo.L__self___linear_bias
    y = torch._C._nn.linear(x_1, l__self___linear_weight, l__self___linear_bias);  l__self___linear_weight = l__self___linear_bias = None
    scores = torch.nn.functional.softmax(y, dim = -1);  y = None
    top_idx = torch.argmax(scores, dim = -1);  scores = None
    top_idx_1 = torch.unsqueeze(top_idx, -1);  top_idx = None
    top_score_x = torch.gather(x_1, -1, top_idx_1);  x_1 = None
    randn_like = torch.randn_like(top_score_x)
    top_score_x += randn_like;  top_score_x_1 = top_score_x;  top_score_x = randn_like = None
    return (top_score_x_1, top_idx_1)

@awf awf marked this pull request as draft April 15, 2024 21:40
unit_scaling/transforms/utils.py: review comment (outdated, resolved)
@awf awf changed the title from "Update to PyTorch 2.3" to "WIP: Update to PyTorch 2.3" Apr 16, 2024
@awf awf changed the title from "WIP: Update to PyTorch 2.3" to "Update to PyTorch 2.3" Apr 17, 2024
@awf awf marked this pull request as ready for review April 17, 2024 11:19
Comment on lines +439 to +441
p.set_yticks(p.get_yticks())
p.set_yticklabels([_rename(item.get_text()) for item in p.get_yticklabels()])

@awf (Contributor, Author) commented:
This is suppressing a warning about setting yticklabels without setting ticks
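
A standalone illustration of the pattern (not the library code; the plot data here is made up): matplotlib warns if tick labels are set while the tick locations are still chosen automatically, so the current locations are pinned with set_yticks before set_yticklabels is called.

```python
# Illustrative example of the set_yticks / set_yticklabels pattern above.
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.barh(["layer.0.weight", "layer.1.weight"], [1.0, 0.5])
ax.set_yticks(ax.get_yticks())  # pin the current tick positions (FixedLocator)
ax.set_yticklabels(
    [label.get_text().replace("layer.", "") for label in ax.get_yticklabels()]
)
```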

@awf awf marked this pull request as draft April 18, 2024 09:42
@awf awf changed the title from "Update to PyTorch 2.3" to "Update to PyTorch 2.2" Apr 18, 2024
@awf awf marked this pull request as ready for review April 19, 2024 17:17
@thecharlieblake (Collaborator) left a comment:
Thanks for looking into this Andrew, pleased with the approach we've converged on (and learned a few things about deepcopy and pickling in the process).

I'll leave merging to you in case there are any other adjustments.

@awf awf merged commit 2394ee4 into main Apr 22, 2024
1 check passed
@awf awf deleted the awf/pt23 branch April 22, 2024 08:49
@awf awf restored the awf/pt23 branch April 22, 2024 10:38
@awf awf deleted the awf/pt23 branch April 22, 2024 10:41