fix: Implement graph acquisition #164
base: feat-torch-wl-kernel
```diff
@@ -1,38 +1,48 @@
 from __future__ import annotations
 
+from collections.abc import Iterator
+from contextlib import contextmanager
 from itertools import product
+from typing import TYPE_CHECKING
 
 import networkx as nx
 import torch
-from botorch import fit_gpytorch_mll
 from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement
-from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel
+from botorch.fit import fit_gpytorch_mll
+from botorch.models.gp_regression_mixed import CategoricalKernel, Kernel, ScaleKernel
 from gpytorch import ExactMarginalLogLikelihood
-from gpytorch.distributions.multivariate_normal import MultivariateNormal
 from gpytorch.kernels import AdditiveKernel, MaternKernel
 from grakel_replace.mixed_single_task_gp import MixedSingleTaskGP
 from grakel_replace.optimize import optimize_acqf_graph
 from grakel_replace.torch_wl_kernel import TorchWLKernel
 
-TRAIN_CONFIGS = 10
+if TYPE_CHECKING:
+    from gpytorch.distributions.multivariate_normal import MultivariateNormal
+
+TRAIN_CONFIGS = 50
 TEST_CONFIGS = 10
 TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS
 
 N_NUMERICAL = 2
-N_CATEGORICAL = 2
-N_CATEGORICAL_VALUES_PER_CATEGORY = 3
-N_GRAPH = 2
+N_CATEGORICAL = 1
+N_CATEGORICAL_VALUES_PER_CATEGORY = 2
+N_GRAPH = 1
+assert N_GRAPH == 1, "This example only supports a single graph feature"
```
> **Review comment (lines +28 to +29):** We can assume we'll only ever have 1 graph parameter for now. In the future, if we need more, we could have a kernel per graph hyperparameter.
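If more graph parameters are ever needed, a kernel-per-graph-column setup could look like the sketch below. This is not part of the PR: it reuses the example's `kernels`, `X`, `ScaleKernel`, and `TorchWLKernel` names, and the lookup lists `train_graphs_a` / `train_graphs_b` are hypothetical.

```python
# Hypothetical sketch: one WL kernel per graph column, assuming each graph
# column stores indices into its own lookup list (names are illustrative).
graph_lookups = [train_graphs_a, train_graphs_b]  # hypothetical lookups
n_graph_cols = len(graph_lookups)
for i, lookup in enumerate(graph_lookups):
    kernels.append(
        ScaleKernel(
            TorchWLKernel(
                graph_lookup=lookup,
                n_iter=5,
                normalize=True,
                # each kernel sees only its own index column at the tail of X
                active_dims=(X.shape[1] - n_graph_cols + i,),
            )
        )
    )
```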
```diff
 kernels = []
 
 # Create numerical and categorical features
-X = torch.empty(size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL), dtype=torch.float64)
+X = torch.empty(
+    size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL + N_GRAPH),
+    dtype=torch.float64,
+)
```
> **Review comment (lines +34 to +37):** The […]
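Concretely, the diff widens `X` by one trailing column that stores a float-encoded graph index rather than a real feature value. A small self-contained check of that layout, using this file's constants:

```python
# Layout check for the feature matrix used in this example:
# columns are [num_0, num_1, cat_0, graph_index].
import torch

N_NUMERICAL, N_CATEGORICAL, N_GRAPH = 2, 1, 1
TOTAL_CONFIGS = 60  # TRAIN_CONFIGS (50) + TEST_CONFIGS (10)
X = torch.empty(
    size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL + N_GRAPH),
    dtype=torch.float64,
)
X[:, -1] = torch.arange(TOTAL_CONFIGS, dtype=torch.float64)  # graph ids 0..59
assert X.shape == (60, 4)
assert X[7, -1].item() == 7.0  # row 7 refers to graphs[7]
```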
```diff
 if N_NUMERICAL > 0:
     X[:, :N_NUMERICAL] = torch.rand(
         size=(TOTAL_CONFIGS, N_NUMERICAL),
         dtype=torch.float64,
     )
 
 if N_CATEGORICAL > 0:
-    X[:, N_NUMERICAL:] = torch.randint(
+    X[:, N_NUMERICAL : N_NUMERICAL + N_CATEGORICAL] = torch.randint(
         0,
         N_CATEGORICAL_VALUES_PER_CATEGORY,
         size=(TOTAL_CONFIGS, N_CATEGORICAL),
```
```diff
@@ -45,8 +55,21 @@
     G = nx.erdos_renyi_graph(n=5, p=0.5)  # Random graph with 5 nodes
     graphs.append(G)
 
+# Assign a new index column to the graphs
+X[:, -1] = torch.arange(TOTAL_CONFIGS, dtype=torch.float64)
+
 # Create random target values
-y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64)
+y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64) + 0.5
+
+# Split into train and test sets
+train_x = X[:TRAIN_CONFIGS]
+train_graphs = graphs[:TRAIN_CONFIGS]
+train_y = y[:TRAIN_CONFIGS].unsqueeze(-1)  # Add dimension for botorch
+
+test_x = X[TRAIN_CONFIGS:]
+test_graphs = graphs[TRAIN_CONFIGS:]
+test_y = y[TRAIN_CONFIGS:].unsqueeze(-1)
 
 # Setup kernels for numerical and categorical features
 if N_NUMERICAL > 0:
```
```diff
@@ -68,47 +91,56 @@
     )
     kernels.append(hamming)
 
-# Combine numerical and categorical kernels
-combined_num_cat_kernel = AdditiveKernel(*kernels) if kernels else None
+if N_GRAPH > 0:
+    wl_kernel = ScaleKernel(
+        TorchWLKernel(
+            graph_lookup=train_graphs,
+            n_iter=5,
+            normalize=True,
+            active_dims=(X.shape[1] - 1,),  # Last column
+        )
+    )
+    kernels.append(wl_kernel)
```
> **Review comment (lines +94 to +103):** Now that […] Importantly, the […]
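`active_dims` here is standard gpytorch `Kernel` machinery: gpytorch slices inputs down to the active columns before `forward` is called, so the WL kernel only ever sees the graph-index column. The sketch below shows the lookup pattern this implies; it is an assumption about `TorchWLKernel`'s internals, not its actual code.

```python
# Simplified sketch of an index-column graph kernel (assumed behaviour of
# TorchWLKernel; the real implementation is in grakel_replace.torch_wl_kernel).
from __future__ import annotations

import networkx as nx
import torch
from gpytorch.kernels import Kernel


class GraphIndexKernelSketch(Kernel):
    def __init__(self, graph_lookup: list[nx.Graph], **kwargs) -> None:
        super().__init__(**kwargs)  # active_dims is handled by gpytorch itself
        self.graph_lookup = graph_lookup

    def set_graph_lookup(self, graph_lookup: list[nx.Graph]) -> None:
        # Swapping the lookup is what makes the context-manager hack below work.
        self.graph_lookup = graph_lookup

    def forward(self, x1: torch.Tensor, x2: torch.Tensor, **params) -> torch.Tensor:
        # By the time forward runs, x1/x2 contain only the active column,
        # i.e. float graph indices; cast back and fetch the graphs.
        g1 = [self.graph_lookup[int(i)] for i in x1.flatten().tolist()]
        g2 = [self.graph_lookup[int(i)] for i in x2.flatten().tolist()]
        # ... a real kernel computes WL similarities over g1 x g2 here ...
        return torch.ones(len(g1), len(g2), dtype=x1.dtype)  # placeholder Gram matrix
```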
```diff
 
-# Create WL kernel for graphs
-wl_kernel = TorchWLKernel(n_iter=5, normalize=True)
-
-# Split into train and test sets
-train_x = X[:TRAIN_CONFIGS]
-train_graphs = graphs[:TRAIN_CONFIGS]
-train_y = y[:TRAIN_CONFIGS].unsqueeze(-1)  # Add dimension for botorch
+# Combine numerical and categorical kernels
+kernel = AdditiveKernel(*kernels)
```
> **Review comment:** Can now dump the […]
```diff
 
-test_x = X[TRAIN_CONFIGS:]
-test_graphs = graphs[TRAIN_CONFIGS:]
-test_y = y[TRAIN_CONFIGS:].unsqueeze(-1)
+from botorch.models import SingleTaskGP
 
-# Initialize the mixed GP
-gp = MixedSingleTaskGP(
-    train_X=train_x,
-    train_graphs=train_graphs,
-    train_Y=train_y,
-    num_cat_kernel=combined_num_cat_kernel,
-    wl_kernel=wl_kernel,
-)
+gp = SingleTaskGP(train_X=train_x, train_Y=train_y, covar_module=kernel)
```
> **Review comment (lines +109 to +112):** … which also means we can use the default […]
```diff
 
 # Compute the posterior distribution
-multivariate_normal: MultivariateNormal = gp.forward(train_x, train_graphs)
-print("Posterior distribution:", multivariate_normal)
+# The wl_kernel will use the indices to index into the training graphs it is
+# holding on to...
+multivariate_normal: MultivariateNormal = gp.forward(train_x)
 
 
 # Making predictions on test data
-with torch.no_grad():
-    posterior = gp.forward(test_x, test_graphs)
+# Now the wl_kernel needs to be aware of the test graphs
+@contextmanager
+def set_graph_lookup(_gp: SingleTaskGP, new_graphs: list[nx.Graph]) -> Iterator[None]:
+    kernel_prev_graphs: list[tuple[Kernel, list[nx.Graph]]] = []
+    for kern in _gp.covar_module.sub_kernels():
+        if isinstance(kern, TorchWLKernel):
+            kernel_prev_graphs.append((kern, kern.graph_lookup))
+            kern.set_graph_lookup(new_graphs)
+
+    yield
+
+    for _kern, _prev_graphs in kernel_prev_graphs:
+        _kern.set_graph_lookup(_prev_graphs)
```
> **Review comment (lines +121 to +133):** Alright, here's the hack that makes it work. We iterate through the kernels of the GP and if we find a […]
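One caveat with the context manager as written: if the body of the `with` block raises, the code after `yield` never runs and the original lookups are not restored. A `try`/`finally` variant (a suggestion, otherwise identical to the diff's version):

```python
# Suggested exception-safe variant of the diff's set_graph_lookup.
from collections.abc import Iterator
from contextlib import contextmanager

import networkx as nx
from botorch.models import SingleTaskGP
from botorch.models.gp_regression_mixed import Kernel
from grakel_replace.torch_wl_kernel import TorchWLKernel


@contextmanager
def set_graph_lookup(_gp: SingleTaskGP, new_graphs: list[nx.Graph]) -> Iterator[None]:
    kernel_prev_graphs: list[tuple[Kernel, list[nx.Graph]]] = []
    for kern in _gp.covar_module.sub_kernels():
        if isinstance(kern, TorchWLKernel):
            kernel_prev_graphs.append((kern, kern.graph_lookup))
            kern.set_graph_lookup(new_graphs)
    try:
        yield
    finally:  # restore lookups on normal exit and on exception
        for _kern, _prev_graphs in kernel_prev_graphs:
            _kern.set_graph_lookup(_prev_graphs)
```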
```diff
 
 
+with torch.no_grad(), set_graph_lookup(gp, train_graphs + test_graphs):
```
> **Review comment:** … and use it here.
```diff
+    posterior = gp.forward(test_x)
     predictions = posterior.mean
     uncertainties = posterior.variance.sqrt()
     covar = posterior.covariance_matrix
 
 print("\nMean:", predictions)
 print("Variance:", uncertainties)
 
 # =============== Fitting the GP using botorch ===============
 
 print("\nFitting the GP model using botorch...")
 
 mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
 fit_gpytorch_mll(mll)
```
```diff
@@ -124,8 +156,10 @@
 # Define bounds
 bounds = torch.tensor(
     [
-        [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL,
-        [1.0] * N_NUMERICAL + [float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL
+        [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL + [-1.0] * N_GRAPH,
+        [1.0] * N_NUMERICAL
+        + [float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL
+        + [len(X) - 1] * N_GRAPH,
```
> **Review comment (lines 157 to +162):** Because the […]
```diff
     ]
 )
```
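For concreteness, with this file's constants (`N_NUMERICAL = 2`, `N_CATEGORICAL = 1`, `N_CATEGORICAL_VALUES_PER_CATEGORY = 2`, `N_GRAPH = 1`, `len(X) = 60`) the bounds evaluate to:

```python
import torch

# lower row: numerical in [0, 1], categorical in {0, 1}, graph index >= -1
# upper row: categorical max is 2 - 1 = 1.0; graph index max is len(X) - 1 = 59
bounds = torch.tensor(
    [
        [0.0, 0.0, 0.0, -1.0],
        [1.0, 1.0, 1.0, 59.0],
    ]
)
assert bounds.shape == (2, 4)
```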
```diff
@@ -142,21 +176,20 @@
     fixed_cats = [{col: i} for i in choice_indices]
 else:
     fixed_cats = [
-        dict(zip(cats_per_column.keys(), combo))
+        dict(zip(cats_per_column.keys(), combo, strict=False))
         for combo in product(*cats_per_column.values())
     ]
 
 
 print("------------------")  # noqa: T201
 # Use the graph-optimized acquisition function
 best_candidate, best_score = optimize_acqf_graph(
     acq_function=acq_function,
     bounds=bounds,
     fixed_features_list=fixed_cats,
     train_graphs=train_graphs,
-    num_graph_samples=20,
-    num_restarts=10,
-    raw_samples=10,
+    num_graph_samples=2,
+    num_restarts=2,
+    raw_samples=16,
     q=1,
 )
 
 print("Best candidate:", best_candidate)
 print("Acquisition score:", best_score)
```
> **Review comment:** This isn't actually used anymore.
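Since the candidate shares the column layout of `X`, its last entry should be a graph index. A sketch of decoding it; the return layout and the meaning of `-1` (the graph column's lower bound) are assumptions, not confirmed by this diff:

```python
# Sketch (assumed return layout): split the candidate back into features
# and a graph index, reusing the example's names.
numeric_and_cat = best_candidate[..., : N_NUMERICAL + N_CATEGORICAL]
graph_idx = int(best_candidate[..., -1].item())
# -1 would fall outside any lookup; treat it as "no graph chosen" (assumption)
chosen_graph = train_graphs[graph_idx] if graph_idx >= 0 else None
```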