fix: Implement graph acquisition #164

Open · wants to merge 1 commit into base: feat-torch-wl-kernel
126 changes: 6 additions & 120 deletions grakel_replace/mixed_single_task_gp.py
@@ -2,67 +2,15 @@

from typing import TYPE_CHECKING

import networkx as nx
import torch
from botorch.models import SingleTaskGP
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels import AdditiveKernel, Kernel
from grakel_replace.torch_wl_kernel import GraphDataset, TorchWLKernel

if TYPE_CHECKING:
from gpytorch.module import Module
import networkx as nx
from torch import Tensor


class WLKernel(Kernel):
"""Weisfeiler-Lehman Kernel for graph similarity
integrated into the GPyTorch framework.

This kernel encapsulates the precomputed Weisfeiler-Lehman graph kernel matrix
and provides it in a GPyTorch-compatible format.
It computes either the training kernel
or the cross-kernel between training and test graphs as needed.
"""

def __init__(
self,
K_train: Tensor,
wl_kernel: TorchWLKernel,
train_graph_dataset: GraphDataset
) -> None:
super().__init__()
self._K_train = K_train
self._wl_kernel = wl_kernel
self._train_graph_dataset = train_graph_dataset

def forward(
self, x1: Tensor,
x2: Tensor | None = None,
diag: bool = False,
last_dim_is_batch: bool = False
) -> Tensor:
"""Forward method to compute the kernel matrix for the graph inputs.

Args:
x1 (Tensor): First input tensor
(unused, required for interface compatibility).
x2 (Tensor | None): Second input tensor.
If None, computes the training kernel matrix.
diag (bool): Whether to return only the diagonal of the kernel matrix.
last_dim_is_batch (bool): Whether the last dimension is a batch dimension.

Returns:
Tensor: The computed kernel matrix.
"""
if x2 is None:
# Return the precomputed training kernel matrix
return self._K_train

# Compute cross-kernel between training graphs and new test graphs
test_dataset = GraphDataset.from_networkx(x2) # x2 should be test graphs
return self._wl_kernel(self._train_graph_dataset, test_dataset)


class MixedSingleTaskGP(SingleTaskGP):
Contributor Author:
This isn't actually used anymore.

"""A Gaussian Process model for mixed input spaces containing numerical, categorical,
and graph features.
@@ -85,9 +33,9 @@ def __init__(
train_X: Tensor,
train_graphs: list[nx.Graph],
train_Y: Tensor,
num_cat_kernel: Kernel,
wl_kernel: TorchWLKernel,
train_Yvar: Tensor | None = None,
num_cat_kernel: Module | None = None,
wl_kernel: TorchWLKernel | None = None,
**kwargs,
) -> None:
"""Initialize the mixed-input Gaussian Process model.
@@ -115,7 +63,7 @@ def __init__(
**kwargs,
)
# Initialize the Weisfeiler-Lehman kernel or use a default one
self._wl_kernel = wl_kernel or TorchWLKernel(n_iter=5, normalize=True)
self._wl_kernel = wl_kernel

# Preprocess the training graphs into a compatible format and compute the graph
# kernel matrix
@@ -137,69 +85,7 @@

def __call__(self, X: Tensor, graphs: list[nx.Graph] | None = None, **kwargs):
"""Custom __call__ method that retrieves train graphs if not explicitly passed."""
print("__call__", X.shape, len(graphs) if graphs is not None else None) # noqa: T201
if graphs is None: # Use stored graphs from train_inputs if not provided
graphs = self._train_inputs[1]
return self.forward(X, graphs)

def forward(self, X: Tensor, graphs: list[nx.Graph]) -> MultivariateNormal:
"""Forward pass to compute the Gaussian Process distribution for given inputs.

This combines the numerical/categorical kernel with the graph kernel
to compute the joint covariance matrix.

Args:
X (Tensor): Input tensor for numerical and categorical features.
graphs (list[nx.Graph]): List of input graphs.

Returns:
MultivariateNormal: The Gaussian Process distribution for the inputs.
"""
if len(X) != len(graphs):
raise ValueError(
f"Number of feature vectors ({len(X)}) must match "
f"number of graphs ({len(graphs)})"
)
if not all(isinstance(g, nx.Graph) for g in graphs):
raise TypeError("Expected input type is a list of NetworkX graphs.")

# Process the new graph inputs into a compatible dataset
proc_graphs = GraphDataset.from_networkx(graphs)

# Compute the kernel matrix for the new graphs
K_new = self._wl_kernel(proc_graphs)
K_new = K_new.to(dtype=X.dtype)

# Combine the graph kernel with the numerical/categorical kernel (if present)
if self.num_cat_kernel is not None:
K_num_cat = self.num_cat_kernel(X)

# Ensure K_new matches K_num_cat dimensions
if K_num_cat.dim() > 2:
batch_size = K_num_cat.size(0)
target_size = K_num_cat.size(1)

# Resize K_new if needed
if K_new.size(-1) != target_size:
K_new_resized = torch.zeros(
*K_new.shape[:-2], target_size, target_size,
dtype=K_new.dtype,
device=K_new.device
)
K_new_resized[..., :K_new.size(-2), :K_new.size(-1)] = K_new
K_new = K_new_resized

if K_new.dim() < K_num_cat.dim():
K_new = K_new.unsqueeze(0).expand(batch_size, target_size,
target_size)

# Convert to dense tensor if needed
if hasattr(K_num_cat, "to_dense"):
K_num_cat = K_num_cat.to_dense()

K_combined = K_num_cat + K_new
else:
K_combined = K_new

# Compute the mean using the mean module and construct the GP distribution
mean_x = self.mean_module(X)
return MultivariateNormal(mean_x, K_combined)
return self.forward(X)
125 changes: 79 additions & 46 deletions grakel_replace/mixed_single_task_gp_usage_example.py
@@ -1,38 +1,48 @@
from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from itertools import product
from typing import TYPE_CHECKING

import networkx as nx
import torch
from botorch import fit_gpytorch_mll
from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel
from botorch.models.gp_regression_mixed import CategoricalKernel, Kernel, ScaleKernel
from gpytorch import ExactMarginalLogLikelihood
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels import AdditiveKernel, MaternKernel
from grakel_replace.mixed_single_task_gp import MixedSingleTaskGP
from grakel_replace.optimize import optimize_acqf_graph
from grakel_replace.torch_wl_kernel import TorchWLKernel

TRAIN_CONFIGS = 10
if TYPE_CHECKING:
from gpytorch.distributions.multivariate_normal import MultivariateNormal

TRAIN_CONFIGS = 50
TEST_CONFIGS = 10
TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS

N_NUMERICAL = 2
N_CATEGORICAL = 2
N_CATEGORICAL_VALUES_PER_CATEGORY = 3
N_GRAPH = 2
N_CATEGORICAL = 1
N_CATEGORICAL_VALUES_PER_CATEGORY = 2
N_GRAPH = 1
assert N_GRAPH == 1, "This example only supports a single graph feature"
Comment on lines +28 to +29 (Contributor Author):
We can assume we'll only ever have 1 graph parameter for now. In the future, if we need more, we could have a kernel per graph hyperparameter.
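A hedged sketch of that future direction, assuming one lookup list per graph feature and one index column each (lookups and first_graph_col are illustrative names, not part of this PR):

# Hypothetical extension, not in this PR: one wrapped WL kernel per graph column.
graph_kernels = [
    ScaleKernel(
        TorchWLKernel(
            graph_lookup=lookups[i],  # assumed per-feature graph list
            n_iter=5,
            normalize=True,
            active_dims=(first_graph_col + i,),  # that feature's index column
        )
    )
    for i in range(N_GRAPH)
]
kernels.extend(graph_kernels)  # combined later via AdditiveKernel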


kernels = []

# Create numerical and categorical features
X = torch.empty(size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL), dtype=torch.float64)
X = torch.empty(
size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL + N_GRAPH),
dtype=torch.float64,
)
Comment on lines +34 to +37 (Contributor Author):
The + N_GRAPH is now where the indices for the graph lookup will go.
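To make the resulting layout concrete, a small illustrative sketch (the slices below mirror the code in this diff; the last column is only filled with indices further down):

# Illustrative column layout of X after this change:
#   X[:, :N_NUMERICAL]                               -> numerical features
#   X[:, N_NUMERICAL : N_NUMERICAL + N_CATEGORICAL]  -> categorical codes
#   X[:, -N_GRAPH:]                                  -> float row indices into `graphs`
num_part = X[:, :N_NUMERICAL]
cat_part = X[:, N_NUMERICAL : N_NUMERICAL + N_CATEGORICAL]
graph_idx_part = X[:, -N_GRAPH:]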

if N_NUMERICAL > 0:
X[:, :N_NUMERICAL] = torch.rand(
size=(TOTAL_CONFIGS, N_NUMERICAL),
dtype=torch.float64,
)

if N_CATEGORICAL > 0:
X[:, N_NUMERICAL:] = torch.randint(
X[:, N_NUMERICAL : N_NUMERICAL + N_CATEGORICAL] = torch.randint(
0,
N_CATEGORICAL_VALUES_PER_CATEGORY,
size=(TOTAL_CONFIGS, N_CATEGORICAL),
@@ -45,8 +55,21 @@
G = nx.erdos_renyi_graph(n=5, p=0.5) # Random graph with 5 nodes
graphs.append(G)

# Assign a new index column to the graphs
X[:, -1] = torch.arange(TOTAL_CONFIGS, dtype=torch.float64)

# Create random target values
y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64)
y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64) + 0.5

# Split into train and test sets
train_x = X[:TRAIN_CONFIGS]
train_graphs = graphs[:TRAIN_CONFIGS]
train_y = y[:TRAIN_CONFIGS].unsqueeze(-1) # Add dimension for botorch

test_x = X[TRAIN_CONFIGS:]
test_graphs = graphs[TRAIN_CONFIGS:]
test_y = y[TRAIN_CONFIGS:].unsqueeze(-1)


# Setup kernels for numerical and categorical features
if N_NUMERICAL > 0:
@@ -68,47 +91,56 @@
)
kernels.append(hamming)

# Combine numerical and categorical kernels
combined_num_cat_kernel = AdditiveKernel(*kernels) if kernels else None
if N_GRAPH > 0:
wl_kernel = ScaleKernel(
TorchWLKernel(
graph_lookup=train_graphs,
n_iter=5,
normalize=True,
active_dims=(X.shape[1] - 1,), # Last column
)
)
kernels.append(wl_kernel)
Comment on lines +94 to +103 (Contributor Author):
Now that TorchWLKernel inherits from Kernel, we can use it just like any other kernel, including wrapping it in a ScaleKernel.

Importantly, the graph_lookup we pass in is what the integer column at index X.shape[1] - 1 will refer to.
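A minimal sketch of the assumed lookup mechanics (the helper below is hypothetical, not the kernel's actual code): the float index column selected via active_dims is cast back to integers and used to fetch graphs from graph_lookup.

def lookup_graphs(idx_col: torch.Tensor, graph_lookup: list[nx.Graph]) -> list[nx.Graph]:
    """Hypothetical helper mirroring what TorchWLKernel presumably does
    with the column selected by active_dims."""
    return [graph_lookup[int(i)] for i in idx_col.long().tolist()]

# e.g. lookup_graphs(train_x[:, -1], train_graphs) recovers the training graphs.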


# Create WL kernel for graphs
wl_kernel = TorchWLKernel(n_iter=5, normalize=True)

# Split into train and test sets
train_x = X[:TRAIN_CONFIGS]
train_graphs = graphs[:TRAIN_CONFIGS]
train_y = y[:TRAIN_CONFIGS].unsqueeze(-1) # Add dimension for botorch
# Combine numerical and categorical kernels
kernel = AdditiveKernel(*kernels)
Contributor Author:
Can now dump the TorchWLKernel in here with the rest.


test_x = X[TRAIN_CONFIGS:]
test_graphs = graphs[TRAIN_CONFIGS:]
test_y = y[TRAIN_CONFIGS:].unsqueeze(-1)
from botorch.models import SingleTaskGP

# Initialize the mixed GP
gp = MixedSingleTaskGP(
train_X=train_x,
train_graphs=train_graphs,
train_Y=train_y,
num_cat_kernel=combined_num_cat_kernel,
wl_kernel=wl_kernel,
)
gp = SingleTaskGP(train_X=train_x, train_Y=train_y, covar_module=kernel)
Comment on lines +109 to +112 (Contributor Author):
... which also means we can use the default SingleTaskGP without needing to create our own.


# Compute the posterior distribution
multivariate_normal: MultivariateNormal = gp.forward(train_x, train_graphs)
print("Posterior distribution:", multivariate_normal)
# The wl_kernel will use the indices to index into the training graphs it is holding
# on to...
multivariate_normal: MultivariateNormal = gp.forward(train_x)


# Making predictions on test data
with torch.no_grad():
posterior = gp.forward(test_x, test_graphs)
# Now the wl_kernel needs to be aware of the test graphs
@contextmanager
def set_graph_lookup(_gp: SingleTaskGP, new_graphs: list[nx.Graph]) -> Iterator[None]:
kernel_prev_graphs: list[tuple[Kernel, list[nx.Graph]]] = []
for kern in _gp.covar_module.sub_kernels():
if isinstance(kern, TorchWLKernel):
kernel_prev_graphs.append((kern, kern.graph_lookup))
kern.set_graph_lookup(new_graphs)

yield

for _kern, _prev_graphs in kernel_prev_graphs:
_kern.set_graph_lookup(_prev_graphs)
Comment on lines +121 to +133 (Contributor Author):
Alright, here's the hack that makes it work. We iterate through the kernels of the GP and, if we find a TorchWLKernel, we change its graph_lookup.
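One caveat of the context manager as written: if the body of the with block raises, the original graph_lookup is never restored. A variant of the same logic with try/finally (a sketch reusing the imports above, not part of this PR) would be exception-safe:

@contextmanager
def set_graph_lookup_safe(_gp: SingleTaskGP, new_graphs: list[nx.Graph]) -> Iterator[None]:
    kernel_prev_graphs: list[tuple[Kernel, list[nx.Graph]]] = []
    for kern in _gp.covar_module.sub_kernels():
        if isinstance(kern, TorchWLKernel):
            kernel_prev_graphs.append((kern, kern.graph_lookup))
            kern.set_graph_lookup(new_graphs)
    try:
        yield
    finally:
        # Restore even if the caller's block raised.
        for _kern, _prev_graphs in kernel_prev_graphs:
            _kern.set_graph_lookup(_prev_graphs)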



with torch.no_grad(), set_graph_lookup(gp, train_graphs + test_graphs):
Contributor Author:
... and use it here.

posterior = gp.forward(test_x)
predictions = posterior.mean
uncertainties = posterior.variance.sqrt()
covar = posterior.covariance_matrix

print("\nMean:", predictions)
print("Variance:", uncertainties)

# =============== Fitting the GP using botorch ===============

print("\nFitting the GP model using botorch...")

mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)
@@ -124,8 +156,10 @@
# Define bounds
bounds = torch.tensor(
[
[0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL,
[1.0] * N_NUMERICAL + [float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL
[0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL + [-1.0] * N_GRAPH,
[1.0] * N_NUMERICAL
+ [float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL
+ [len(X) - 1] * N_GRAPH,
Comment on lines 157 to +162 (Contributor Author):
Because the bounds need to be of the correct size, we extend them. You'll notice that I put [-1.0] as the lower bound. I don't like this, but it lets me quickly hack in that we can use the index -1 to refer to the last graph in the graph_lookup; you'll see it used later.

]
)
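A quick illustration of that -1 trick, assuming Python-style negative indexing carries through the kernel's lookup:

# Assumed behaviour: -1.0 in the graph column resolves to the last graph
# in the lookup, exactly like Python's negative list indexing.
idx = int(torch.tensor([-1.0]).item())
assert graphs[idx] is graphs[len(graphs) - 1]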

@@ -142,21 +176,20 @@
fixed_cats = [{col: i} for i in choice_indices]
else:
fixed_cats = [
dict(zip(cats_per_column.keys(), combo))
dict(zip(cats_per_column.keys(), combo, strict=False))
for combo in product(*cats_per_column.values())
]


print("------------------") # noqa: T201
# Use the graph-optimized acquisition function
best_candidate, best_score = optimize_acqf_graph(
acq_function=acq_function,
bounds=bounds,
fixed_features_list=fixed_cats,
train_graphs=train_graphs,
num_graph_samples=20,
num_restarts=10,
raw_samples=10,
num_graph_samples=2,
num_restarts=2,
raw_samples=16,
q=1,
)

print("Best candidate:", best_candidate)
print("Acquisition score:", best_score)