fix: Implement graph acquisition #164
base: feat-torch-wl-kernel
Conversation
```python
test_dataset = GraphDataset.from_networkx(x2)  # x2 should be test graphs
return self._wl_kernel(self._train_graph_dataset, test_dataset)


class MixedSingleTaskGP(SingleTaskGP):
```
This isn't actually used anymore
```python
N_GRAPH = 1
assert N_GRAPH == 1, "This example only supports a single graph feature"
```
We can assume we'll only ever have 1 graph parameter for now. In the future, if we need more, we could have a kernel per graph hyperparameter.
```python
X = torch.empty(
    size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL + N_GRAPH),
    dtype=torch.float64,
)
```
The `+ N_GRAPH` column is now where the indices for the graph lookup will go.
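As a hedged illustration (the filling code here is an assumption, not from the PR), the graph column simply holds each row's integer index into the graph list, stored as a float alongside the other features:

```python
import torch

# Sketch: the last column of X stores integer indices into the graph
# lookup; the remaining columns hold the numerical/categorical features.
TOTAL_CONFIGS, N_NUMERICAL, N_CATEGORICAL, N_GRAPH = 4, 2, 1, 1

X = torch.empty(
    size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL + N_GRAPH),
    dtype=torch.float64,
)
X[:, : N_NUMERICAL + N_CATEGORICAL] = torch.rand(
    TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL, dtype=torch.float64
)
# Row i refers to the i-th graph in the lookup.
X[:, -1] = torch.arange(TOTAL_CONFIGS, dtype=torch.float64)
```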
```python
if N_GRAPH > 0:
    wl_kernel = ScaleKernel(
        TorchWLKernel(
            graph_lookup=train_graphs,
            n_iter=5,
            normalize=True,
            active_dims=(X.shape[1] - 1,),  # Last column
        )
    )
    kernels.append(wl_kernel)
```
Now that `TorchWLKernel` inherits from `Kernel`, we can use it just like any other kernel, including wrapping it in a `ScaleKernel`. Importantly, the `graph_lookup` we pass in is what the integer column at index `X.shape[1] - 1` will refer to.
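To illustrate the `active_dims` mechanics (this example uses a stock GPyTorch kernel, not the PR's code): a kernel constructed with `active_dims` only ever sees the selected column(s), which is how the WL kernel above receives just the graph-index column.

```python
import torch
from gpytorch.kernels import RBFKernel

# A kernel with active_dims=(2,) is evaluated only on column 2 of the
# input, i.e. equivalent to evaluating it on x[:, 2:3].
k = RBFKernel(active_dims=(2,))
x = torch.rand(5, 3)
K = k(x).to_dense()  # shape (5, 5)
```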
```python
train_graphs = graphs[:TRAIN_CONFIGS]
train_y = y[:TRAIN_CONFIGS].unsqueeze(-1)  # Add dimension for botorch
# Combine numerical and categorical kernels
kernel = AdditiveKernel(*kernels)
```
Can now dump the `TorchWLKernel` in here with the rest.
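For context, a minimal sketch of how an additive kernel like this plugs into a stock BoTorch GP (the Matern stand-ins, shapes, and data here are assumptions, not the PR's code):

```python
import torch
from botorch.models import SingleTaskGP
from gpytorch.kernels import AdditiveKernel, MaternKernel, ScaleKernel

train_X = torch.rand(10, 3, dtype=torch.float64)
train_Y = torch.rand(10, 1, dtype=torch.float64)

kernel = AdditiveKernel(
    ScaleKernel(MaternKernel(active_dims=(0, 1))),  # numerical dims
    ScaleKernel(MaternKernel(active_dims=(2,))),    # stand-in for the WL kernel
)
gp = SingleTaskGP(train_X=train_X, train_Y=train_Y, covar_module=kernel)
```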
```python
assert x1.shape[-1] == 1, "Last dimension must be the graph index"
assert x2.shape[-1] == 1, "Last dimension must be the graph index"

# TODO: Optimizations
```
This lists suspected reasons for it being slow, as well as some possible solutions.
```python
# NOTE: The active dim is already selected out for us and is the last
# dimension (not including whatever happens when last_dim_is_batch is True).
if x1.ndim == 3:
```
This is to do with the batching that goes on during `optimize_acqf`. It corresponds with the `raw_samples=` parameter.
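For reference, a hedged sketch of the call that produces these 3-dimensional batched inputs (the acquisition function, bounds, and argument values are illustrative; `gp` and `train_Y` follow the GP sketch above):

```python
import torch
from botorch.acquisition import qLogExpectedImprovement
from botorch.optim import optimize_acqf

acqf = qLogExpectedImprovement(model=gp, best_f=train_Y.max())
bounds = torch.stack(
    [torch.zeros(3, dtype=torch.float64), torch.ones(3, dtype=torch.float64)]
)
# raw_samples controls how many candidates get scored in one batched
# kernel call, which is where the extra leading dimension comes from.
candidate, value = optimize_acqf(
    acq_function=acqf,
    bounds=bounds,
    q=1,
    num_restarts=5,
    raw_samples=32,
)
```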
```python
q_dim_size = x1.shape[0]
assert x2.shape[0] == q_dim_size

out = torch.empty((q_dim_size, x1.shape[1], x2.shape[1]), device=x1.device)
for q in range(q_dim_size):
    out[q] = self.forward(x1[q], x2[q], diag=diag)
return out
```
Since the kernel can't natively handle this batching, we essentially do a for-loop over the `q` dimension. With the code comments above, that would be 32 iterations of `x1.shape == (5, 1)` and `x2.shape == (55, 1)`, the `1` being that we have only one column of indices.
```python
if x1_is_x2:
    _ixs = x1.flatten().to(torch.int64).tolist()
    all_graphs = [self.graph_lookup[i] for i in _ixs]

    # No selection required
    select = None
else:
    _ixs1 = x1.flatten().to(torch.int64).tolist()
    _ixs2 = x2.flatten().to(torch.int64).tolist()
    all_graphs = [self.graph_lookup[i] for i in _ixs1 + _ixs2]

    # Select out K_x1_x2
    select = lambda _K: _K[: len(_ixs1), len(_ixs1) :]

_kernel = _TorchWLKernel(n_iter=self.n_iter, normalize=self.normalize)
K = _kernel(all_graphs)
K_selected = K if select is None else select(K)
if diag:
    return torch.diag(K_selected)
return K_selected
```
Here we pull out the graphs from the lookup. I put in a mini-optimization that occurs during `fit_gpytorch_mll`, i.e. when `x1_is_x2`, meaning we only compute the `K_x1_x1` part and use that. Otherwise, I end up computing the whole `K` matrix and just subselecting out `K_x1_x2`. It would be much more efficient if we could calculate just that part.
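Concretely, an illustrative sketch of the wasted work (not PR code): when `x1` and `x2` differ, the WL kernel is evaluated over the concatenated graph list and only the cross block is kept.

```python
import torch

n1, n2 = 3, 5  # number of graphs referenced by x1 and x2
K = torch.rand(n1 + n2, n1 + n2)
K = (K + K.T) / 2  # symmetric, like a real kernel matrix

# Only the off-diagonal block K_x1_x2 is returned; the diagonal blocks
# K_x1_x1 and K_x2_x2 are computed and then thrown away.
K_x1_x2 = K[:n1, n1:]  # shape (n1, n2)
```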
```python
return K_selected


class _TorchWLKernel(nn.Module):
```
I just wrapped this class, could probably lift the logic out into functions, or otherwise move the logic into the class above.
@vladislavalerievich here's a working but highly un-optimized acquisition over mixed spaces with categoricals, numericals and graphs.

The main thing done was to not use the overwritten `SingleTaskGP` and to pass in graphs as a column in the `X` tensor. This allows us to use almost all of the underlying botorch functionality. I made this as a PR so it can be annotated; feel free to merge as you need (it's into your branch).

- The graph indices are passed as `fixed_features` and they are not optimized over. Instead, we optimize them in an outer loop, as was done with the categoricals.
- There is a `graph_lookup` that we attach to the `TorchWLKernel` when we need it. See the associated `@contextmanager` (a hedged sketch of it follows below). This contextmanager might be overkill but it worked for me.

It's slow, horribly slow... I changed the parameters to reduce the number of iterations it will do, but ultimately it needs to be sped up. I left a big `TODO` w.r.t. the optimizations. Most of it is just redundant calculations which need to be fixed up.

If you need some guidance on where it's slow, I highly recommend `py-spy`: `py-spy record -f speedscope -o profile.speedscope -- python <pythonfile>`. You can then upload the `profile.speedscope` file to the speedscope viewer to see a flamegraph of where all the time is being spent. I recommend using the `left heavy` view (top left).

If you need to see the underlying non-python code in the output, you can add a `-n` argument, although it gets noisy until you're used to looking at it.
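For reference, a minimal sketch of what such a contextmanager could look like (this is an assumption about the PR's helper, not its actual code; only the `graph_lookup` attribute name is taken from the kernel above):

```python
from contextlib import contextmanager


@contextmanager
def set_graph_lookup(kernel, new_graphs):
    """Temporarily point a TorchWLKernel's index column at a new graph list.

    Hypothetical helper: the real PR's version may differ in name and scope.
    """
    old_graphs = kernel.graph_lookup
    kernel.graph_lookup = new_graphs
    try:
        yield kernel
    finally:
        kernel.graph_lookup = old_graphs
```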