
fix: Implement graph acquisition #164

Open
wants to merge 1 commit into feat-torch-wl-kernel

Conversation

eddiebergman (Contributor)

@vladislavalerievich here's a working but highly unoptimized acquisition over mixed spaces with categoricals, numericals and graphs.

The main thing done was to not use the overridden SingleTaskGP and instead pass graphs in as a column of the X tensor. This allows us to use almost all of the underlying BoTorch functionality. I made this a PR so it can be annotated; feel free to merge as you need (it's into your branch).

  • This tensor represents indices. As far as the acquisition optimization is concerned, these are fixed_features and are not optimized over. Instead, we optimize them in an outer loop, as was done with the categoricals.
  • These indices map into a graph_lookup that we attach to the TorchWLKernel when we need it. See the associated @contextmanager (a rough sketch follows below). This contextmanager might be overkill but it worked for me.
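
Roughly, the idea of that contextmanager is just to swap the lookup in and out around the call. A minimal sketch (the name set_graph_lookup and the attribute handling are illustrative, not necessarily what's in the PR):

from contextlib import contextmanager

@contextmanager
def set_graph_lookup(kernel, graphs):
    # Temporarily point the WL kernel's index column at a new list of graphs,
    # restoring the previous lookup afterwards.
    old_lookup = kernel.graph_lookup
    kernel.graph_lookup = graphs
    try:
        yield kernel
    finally:
        kernel.graph_lookup = old_lookup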

It's slow, horribly slow ... I changed the parameters to reduce the number of iterations it will do, but ultimately it needs to be sped up. I left a big TODO w.r.t. the optimizations. Most of it is just redundant calculations that need to be fixed up.

If you need some guidance on where it's slow, I highly recommend py-spy:

py-spy record -f speedscope -o profile.speedscope -- python <pythonfile>

You can then upload profile.speedscope to the speedscope web viewer (speedscope.app) to see a flamegraph of where all the time is being spent. I recommend using the Left Heavy view (top left).

If you need to see the underlying non-Python code in the output, you can add the -n argument, although it gets noisy until you're used to looking at it.
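
For example:

py-spy record -f speedscope -o profile.speedscope -n -- python <pythonfile>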

test_dataset = GraphDataset.from_networkx(x2) # x2 should be test graphs
return self._wl_kernel(self._train_graph_dataset, test_dataset)


class MixedSingleTaskGP(SingleTaskGP):

This isn't actually used anymore

Comment on lines +28 to +29
N_GRAPH = 1
assert N_GRAPH == 1, "This example only supports a single graph feature"

We can assume we'll only ever have 1 graph parameter for now. In the future, if we need more, we could have a kernel per graph hyperparameter.
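
If that ever becomes necessary, one option (just a sketch, not part of this PR; lookups and first_graph_col are placeholder names, reusing the kernel classes from the example above) would be one WL kernel per graph column:

graph_kernels = [
    ScaleKernel(
        TorchWLKernel(
            graph_lookup=lookups[i],             # one lookup per graph feature
            n_iter=5,
            normalize=True,
            active_dims=(first_graph_col + i,),  # one index column per graph feature
        )
    )
    for i in range(N_GRAPH)
]
kernels.extend(graph_kernels)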

Comment on lines +34 to +37
X = torch.empty(
    size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL + N_GRAPH),
    dtype=torch.float64,
)

The + N_GRAPH column is where the indices into the graph lookup will go.
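
Concretely, that extra column just holds integer positions into the lookup (a sketch, assuming row i's graph lives at index i of the lookup):

# Indices are stored as floats because X is a single float64 tensor;
# row i of X refers to graph_lookup[int(X[i, -1])].
X[:, -1] = torch.arange(TOTAL_CONFIGS, dtype=torch.float64)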

Comment on lines +94 to +103
if N_GRAPH > 0:
    wl_kernel = ScaleKernel(
        TorchWLKernel(
            graph_lookup=train_graphs,
            n_iter=5,
            normalize=True,
            active_dims=(X.shape[1] - 1,),  # Last column
        )
    )
    kernels.append(wl_kernel)

Now that TorchWLKernel inherits from Kernel, we can use it just like any other kernel, including wrapping it in a ScaleKernel.

Importantly, the graph_lookup we pass in is what the integer indices in column X.shape[1] - 1 will refer to.

train_graphs = graphs[:TRAIN_CONFIGS]
train_y = y[:TRAIN_CONFIGS].unsqueeze(-1) # Add dimension for botorch
# Combine numerical and categorical kernels
kernel = AdditiveKernel(*kernels)

Can now dump the TorchWLKernel in here with the rest.
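
From there it's the standard BoTorch setup (a sketch, assuming train_x is the training slice of X):

from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

# The combined AdditiveKernel plugs straight into a plain SingleTaskGP,
# so no custom GP subclass is needed.
gp = SingleTaskGP(train_X=train_x, train_Y=train_y, covar_module=kernel)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)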

assert x1.shape[-1] == 1, "Last dimension must be the graph index"
assert x2.shape[-1] == 1, "Last dimension must be the graph index"

# TODO: Optimizations

These are suspected reasons for it being slow, as well as some possible solutions.


# NOTE: The active dim is already selected out for us and is the last dimension
# (not including whatever happens when last_dim_is_batch is True).
if x1.ndim == 3:

This is to do with the batching that goes on during optimize_acqf.

It corresponds to the raw_samples= parameter.
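
For context, a sketch of the call that produces those batches (acq, bounds and graph_idx are placeholders; the values are illustrative): raw_samples controls the size of the extra batch dimension the kernel sees, and the graph index column is pinned via fixed_features so the continuous optimizer never touches it.

from botorch.optim import optimize_acqf

candidates, acq_value = optimize_acqf(
    acq_function=acq,
    bounds=bounds,
    q=1,
    num_restarts=10,
    raw_samples=32,
    # Pin the graph index column; it is optimized in the outer loop instead.
    fixed_features={X.shape[1] - 1: float(graph_idx)},
)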

Comment on lines +89 to +95
q_dim_size = x1.shape[0]
assert x2.shape[0] == q_dim_size

out = torch.empty((q_dim_size, x1.shape[1], x2.shape[1]), device=x1.device)
for q in range(q_dim_size):
    out[q] = self.forward(x1[q], x2[q], diag=diag)
return out

Since the kernel can't natively handle this batching, we essentially do a for loop over the q dimension. With the code comments above, that would be 32 iterations of x1.shape == (5, 1) and x2.shape == (55, 1), the 1 being that we have only one column of indices.

Comment on lines +97 to +116
if x1_is_x2:
    _ixs = x1.flatten().to(torch.int64).tolist()
    all_graphs = [self.graph_lookup[i] for i in _ixs]

    # No selection required
    select = None
else:
    _ixs1 = x1.flatten().to(torch.int64).tolist()
    _ixs2 = x2.flatten().to(torch.int64).tolist()
    all_graphs = [self.graph_lookup[i] for i in _ixs1 + _ixs2]

    # Select out K_x1_x2
    select = lambda _K: _K[: len(_ixs1), len(_ixs1) :]

_kernel = _TorchWLKernel(n_iter=self.n_iter, normalize=self.normalize)
K = _kernel(all_graphs)
K_selected = K if select is None else select(K)
if diag:
    return torch.diag(K_selected)
return K_selected

Here we pull out the graphs from the lookup. I put in a mini-optimization that occurs during fit_gpytorch_mll, i.e. when x1_is_x2, meaning we only compute the K_x1_x1 part and use that.

Otherwise, I end up computing the whole K matrix and just sub-selecting out K_x1_x2. It would be much more efficient if we could calculate just that part.

return K_selected


class _TorchWLKernel(nn.Module):

I just wrapped this class; we could probably lift the logic out into functions, or otherwise move the logic into the class above.
