Cost Matrix Computation in Weight Matching #4

Open
frallebini opened this issue Oct 29, 2022 · 10 comments
Comments

@frallebini

Hi, I read the paper and I am having a really hard time reconciling the formula

[image: the weight matching formula from the paper]

with the actual computation of the cost matrix for the LAP in weight_matching.py, namely

A = jnp.zeros((n, n))
for wk, axis in ps.perm_to_axes[p]:
  w_a = params_a[wk]
  w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
  w_a = jnp.moveaxis(w_a, axis, 0).reshape((n, -1))
  w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1))
  A += w_a @ w_b.T

Are you following a different mathematical derivation or am I missing something?

@samuela
Owner

samuela commented Oct 29, 2022

Hi @frallebini! The writeup in the paper is for the special case of an MLP with no bias terms -- the version in the code is just more general. The connection here is that there's a sum over all weight arrays that interact with that P_\ell. Then, for each one, we need to apply its relevant permutations on all other axes, take the Frobenius inner product with the reference model, and add all those terms together. So A represents that sum, each for-loop iteration adds a single term to the sum, get_permuted_param applies the other (non-P_\ell) permutations to w_b, and the moveaxis-reshape-matmul corresponds to the Frobenius inner product with w_a.
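
For intuition, here is a minimal sketch of what get_permuted_param does, assuming the permutation spec also carries an axes_to_perm mapping from parameter name to the permutation id acting on each axis (the names axes_to_perm and get_permuted_param_sketch are illustrative assumptions, not a quote of the repo):

import jax.numpy as jnp

def get_permuted_param_sketch(axes_to_perm, perm, k, params, except_axis=None):
  # Apply every permutation acting on params[k] except the one on `except_axis`.
  # axes_to_perm[k] is assumed to be a tuple like ("P_4", "P_5") naming the
  # permutation id for each axis of params[k] (None for unpermuted axes);
  # perm maps permutation ids to index vectors.
  w = params[k]
  for axis, p in enumerate(axes_to_perm[k]):
    if axis == except_axis or p is None:
      continue  # leave the P_\ell axis (and unpermuted axes) alone
    w = jnp.take(w, perm[p], axis=axis)  # permute along this axis
  return w

So when solving for P_\ell, the P_\ell axis itself is left untouched and every neighbouring, currently-fixed permutation is applied to its own axis.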

@frallebini
Author

Thanks @samuela, I understand that the code is a generalization of the MLP with no bias case, but still:

  1. If the moveaxis-reshape-@ operation corresponded to the Frobenius inner product with w_a, wouldn't A be a scalar?
  2. How does get_permuted_param "skip" the non-P_\ell permutations? Doesn't the except_axis argument mean that, for example, if I want to permute rows, then I have to apply the permutation vector perm[p] along the column dimension?

@samuela
Owner

samuela commented Oct 29, 2022

If the moveaxis-reshape-@ operation corresponded to the Frobenius inner product with w_a, wouldn't A be a scalar?

Ack, you're right! I messed up: it's not actually a Frobenius inner product, just a regular matrix product. The moveaxis-reshape combo is necessary to flatten dimensions that we don't care about in the case of non-2d weight arrays.

How does get_permuted_param "skip" the non-P_\ell permutations? Doesn't the except_axis argument mean that, for example, if I want to permute rows, then I have to apply the permutation vector perm[p] along the column dimension?

Yup, that's exactly what except_axis is doing. But I think you may have it backwards -- except_axis is excepting the P_\ell axis but applying all other fixed P's to all the other axes.

@frallebini
Author

Ok, but let us consider the MLP-with-no-bias case. The way the paper models weight matching as an LAP is

$$\arg\max_{P_\ell} \; \langle W_\ell^A,\, P_\ell W_\ell^B P_{\ell-1}^\top \rangle_F + \langle W_{\ell+1}^A,\, P_{\ell+1} W_{\ell+1}^B P_\ell^\top \rangle_F \;=\; \arg\max_{P_\ell} \; \langle P_\ell,\, W_\ell^A P_{\ell-1} (W_\ell^B)^\top + (W_{\ell+1}^A)^\top P_{\ell+1} W_{\ell+1}^B \rangle_F$$

In other words, it computes A as

$$A = W_\ell^A P_{\ell-1} (W_\ell^B)^\top + (W_{\ell+1}^A)^\top P_{\ell+1} W_{\ell+1}^B \qquad (1)$$

What the code does instead, if I understood correctly, is compute A by

  1. Permuting w_b disregarding P_\ell
  2. Transposing it
  3. Multiplying w_a by it

In other words

[image: frallebini's derivation of A from the code, labeled (2)]

I don't think (1) and (2) are the same thing though.

@samuela
Owner

samuela commented Nov 1, 2022

Hmm, I think the error here is in the first line of (2): the shapes don't line up, since $W_\ell^A$ has shape (n, *) and $W_{\ell+1}^A$ has shape (*, n), so adding those together will result in a shape error if your layers have different widths.

I think tracing out the code for the MLP-without-bias-terms case is a good idea. In that case we run through the for wk, axis in ps.perm_to_axes[p]: loop two times: once for $W_\ell$ and once for $W_{\ell+1}$ (a small numeric check of both cases is sketched after the list below).

  • For $W_\ell$: First of all, axis=0 since $W_\ell$ has shape (n, *). Then, w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis) will give us $W_\ell^B P_{\ell-1}^T$. In other words, $W_\ell^B$ but with the other permutations -- $P_{\ell-1}$ in this case -- applied to the other axes. jnp.moveaxis(w_a, axis, 0).reshape((n, -1)) will be a no-op since axis = 0, and w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1)) will also be a no-op. So w_a @ w_b.T is $W_\ell^A (W_\ell^B P_{\ell-1}^T)^T = W_\ell^A P_{\ell-1} (W_\ell^B)^T$, which matches up with the first term in the sum.
  • For $W_{\ell+1}$: In this case axis = 1 since $W_{\ell+1}$ has shape (*, n). Then, w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis) will give us $P_{\ell+1} W_{\ell+1}^B$. In other words, $W_{\ell+1}^B$ but with the other permutations -- $P_{\ell+1}$ in this case -- applied to the other axes. jnp.moveaxis(w_a, axis, 0).reshape((n, -1)) will result in a transpose, aka $(W_{\ell+1}^A)^T$, and w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1)) will also result in a transpose, aka $(W_{\ell+1}^B)^T P_{\ell+1}^T$. So w_a @ w_b.T is $(W_{\ell+1}^A)^T P_{\ell+1} W_{\ell+1}^B$, which matches up with the second term in the sum.
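
To double-check the two cases above numerically, here is a small self-contained example (all shapes, names, and values are made up for illustration):

import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
n, d_in, d_out = 4, 3, 5  # hypothetical layer widths

W_l_A  = jnp.asarray(rng.normal(size=(n, d_in)))    # W_\ell^A, shape (n, *)
W_l_B  = jnp.asarray(rng.normal(size=(n, d_in)))    # W_\ell^B
W_l1_A = jnp.asarray(rng.normal(size=(d_out, n)))   # W_{\ell+1}^A, shape (*, n)
W_l1_B = jnp.asarray(rng.normal(size=(d_out, n)))   # W_{\ell+1}^B
P_lm1  = jnp.eye(d_in)[rng.permutation(d_in)]       # fixed P_{\ell-1}
P_lp1  = jnp.eye(d_out)[rng.permutation(d_out)]     # fixed P_{\ell+1}

# The two terms of A as written in the paper.
A_paper = W_l_A @ P_lm1 @ W_l_B.T + W_l1_A.T @ P_lp1 @ W_l1_B

# What the two loop iterations compute: w_b already has the *other*
# permutations applied (that is get_permuted_param's job), followed by
# the moveaxis-reshape-matmul.
A_code = jnp.zeros((n, n))
for w_a, w_b, axis in [(W_l_A,  W_l_B @ P_lm1.T, 0),   # W_\ell iteration
                       (W_l1_A, P_lp1 @ W_l1_B,  1)]:  # W_{\ell+1} iteration
  w_a = jnp.moveaxis(w_a, axis, 0).reshape((n, -1))
  w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1))
  A_code += w_a @ w_b.T

print(jnp.allclose(A_paper, A_code, atol=1e-5))  # True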

@frallebini
Author

frallebini commented Nov 1, 2022

Ok, the role of moveaxis is clear, and the computation matches the formula in the paper for an MLP with no biases.

On the other hand, extending the reasoning to the case with biases, the reshape((n, -1)):

  • Is always a no-op for weight matrices, since n is either the number of rows of $W_\ell$ or the number of columns of $W_{\ell+1}$, which has already been moved to the front by the moveaxis.
  • Is needed to turn the (n,) bias vectors into (n, 1) matrices, so that w_a @ w_b.T is an (n, n) matrix that can be added to A (a tiny illustration is sketched below).

Right?
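
For concreteness, a tiny illustration of the bias point above (the sizes and values are made up):

import jax.numpy as jnp

n = 4
b_a = jnp.arange(n, dtype=jnp.float32)  # bias of model A, shape (n,)
b_b = jnp.ones(n)                       # bias of model B, shape (n,)

# axis = 0 for a bias, so the moveaxis is a no-op; reshape((n, -1)) turns the
# (n,) vectors into (n, 1) matrices.
w_a = jnp.moveaxis(b_a, 0, 0).reshape((n, -1))  # shape (n, 1)
w_b = jnp.moveaxis(b_b, 0, 0).reshape((n, -1))  # shape (n, 1)

print((w_a @ w_b.T).shape)  # (4, 4): an outer product that can be added to A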

@samuela
Owner

samuela commented Nov 1, 2022

That's correct! In addition, it's necessary when dealing with weight arrays of higher rank as well, e.g. in a convolutional layer where the weights have shape (w, h, channel_in, channel_out).
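
For example, with a conv kernel of shape (w, h, channel_in, channel_out) and the output-channel permutation acting on the last axis (the sizes below are arbitrary):

import jax.numpy as jnp

w, h, c_in, c_out = 3, 3, 8, 16
kernel = jnp.zeros((w, h, c_in, c_out))

# The output-channel permutation acts on axis 3, so that axis is moved to the
# front and everything else is flattened before the matmul.
n = c_out
flat = jnp.moveaxis(kernel, 3, 0).reshape((n, -1))
print(flat.shape)  # (16, 72), i.e. (c_out, w * h * c_in)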

@LeCongThuong

LeCongThuong commented Dec 9, 2022

Hi, I read the code and I really did not understand the following snippet. Since it relates to the weight matching algorithm, I am posting it here.
In line 199 of weight_matching.py:

perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items()}

According to the above line, if W_\ell has shape [m, n] (m is the output feature dim, n is the input feature dim) in a Dense layer, then the shape of the permutation matrix P_\ell will be [n, n]. But when I read the paper, I think it should be [m, m].

Sorry for the silly question, but could you explain? @samuela @frallebini

Thank you!

@samuela
Owner

samuela commented Dec 9, 2022

Hi @LeCongThuong, ps.perm_to_axes is a dict of the form PermutationId => [(ParamId, Axis), ...], where in this case the PermutationIds are strings, the ParamIds are also strings, and the Axis entries are integers. So, for example, for an MLP (without bias, and assuming that weights have shape [out_dim, in_dim]) this dict would look something like

{ "P_5": [("Dense_5/kernel", 0), ("Dense_6/kernel", 1)], ... }

Therefore, axes[0][0] will be something like "Dense_5/kernel" and axes[0][1] will be 0. HTH!
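
Putting that together with the perm_sizes line, and using made-up parameter shapes for the [out_dim, in_dim] convention above:

import jax.numpy as jnp

params_a = {
    "Dense_5/kernel": jnp.zeros((512, 256)),  # out_dim = 512, in_dim = 256
    "Dense_6/kernel": jnp.zeros((10, 512)),
}
perm_to_axes = {"P_5": [("Dense_5/kernel", 0), ("Dense_6/kernel", 1)]}

perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]]
              for p, axes in perm_to_axes.items()}
print(perm_sizes)  # {'P_5': 512}, i.e. P_5 is a 512x512 permutation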

@LeCongThuong

Thank you so much for replying @samuela!

I tried to understand ps.perm_to_axes and got the meaning of Axis. From what I got from your comment, Axis tells us to permute W_b along the axes other than Axis. Following your above example, I think it should be

{ "P_5": [("Dense_5/kernel", 1), ("Dense_6/kernel", 0)], ... }

From that, axes[0][1] will be 1, and thus the shape of P_\ell will be [n, n].

Thank you again for replying to my question.
