Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Possible PyTorch implementation of WL kernel #153

Open
wants to merge 25 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
36fc3bd
Add a PyTorch implementation of WL kernel
vladislavalerievich Oct 28, 2024
b0d3842
Fix imports
vladislavalerievich Oct 29, 2024
f87abd6
Remove redundant copy
vladislavalerievich Oct 29, 2024
358fbb7
Increase precision for allclose
vladislavalerievich Oct 29, 2024
de140b6
Fix calculation for graphs with reordered edges
vladislavalerievich Oct 29, 2024
08c7aea
Increase test coverage
vladislavalerievich Oct 29, 2024
6f07858
Improve readability of TorchWLKernel
vladislavalerievich Oct 30, 2024
896f461
Add additional comments to TorchWLKernel
vladislavalerievich Oct 30, 2024
383e924
Add MixedSingleTaskGP to process graphs
vladislavalerievich Nov 8, 2024
65666a3
Refactor WLKernelWrapper into a standalone WLKernel class.
vladislavalerievich Nov 20, 2024
7fa9432
Update tests
vladislavalerievich Nov 20, 2024
4227f22
Add a check for empty inputs
vladislavalerievich Nov 20, 2024
f194bd2
Improve and combine tests
vladislavalerievich Nov 20, 2024
a104840
Update WLKernel
vladislavalerievich Nov 21, 2024
246f9f6
Add acquisition function with graph sampling
vladislavalerievich Nov 21, 2024
770c626
Add a custom __call__ method to pass graphs during optimization
vladislavalerievich Nov 21, 2024
8bf7ea7
Update MixedSingleTaskGP
vladislavalerievich Dec 7, 2024
84d0104
Remove not used argument
vladislavalerievich Dec 7, 2024
d63239a
Update sample_graphs
vladislavalerievich Dec 7, 2024
3db3f89
Handle different batch dimensions
vladislavalerievich Dec 7, 2024
f69ddbe
Set num_restarts=10
vladislavalerievich Dec 7, 2024
1c4cc83
Add acquisition function
vladislavalerievich Dec 7, 2024
dab9a8c
Update WLKernel
vladislavalerievich Dec 7, 2024
2999582
Make train_inputs private
vladislavalerievich Dec 7, 2024
ad55030
Update tests
vladislavalerievich Dec 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions grakel_replace/mixed_single_task_gp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from botorch.models import SingleTaskGP
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels import AdditiveKernel
from gpytorch.module import Module
from grakel_replace.torch_wl_kernel import GraphDataset, TorchWLKernel

if TYPE_CHECKING:
import networkx as nx
from torch import Tensor


class MixedSingleTaskGP(SingleTaskGP):
"""A Gaussian Process model that handles numerical, categorical, and graph inputs.

This class extends BoTorch's SingleTaskGP to work with hybrid input spaces containing:
- Numerical features
- Categorical features
- Graph structures

It uses the Weisfeiler-Lehman (WL) kernel for graph inputs and combines it with
standard kernels for numerical/categorical features using an additive kernel structure

Attributes:
_wl_kernel (TorchWLKernel): The Weisfeiler-Lehman kernel for graph similarity
_train_graphs (List[nx.Graph]): Training set graph instances
_K_graph (Tensor): Pre-computed graph kernel matrix for training data
num_cat_kernel (Optional[Module]): Kernel for numerical/categorical features
"""

def __init__(
self,
train_X: Tensor, # Shape: (n_samples, n_numerical_categorical_features)
train_graphs: list[nx.Graph], # List of n_samples graph instances
train_Y: Tensor, # Shape: (n_samples, n_outputs)
train_Yvar: Tensor | None = None, # Shape: (n_samples, n_outputs) or None
num_cat_kernel: Module | None = None,
wl_kernel: TorchWLKernel | None = None,
**kwargs # Additional arguments passed to SingleTaskGP
) -> None:
"""Initialize the mixed input Gaussian Process model.

Args:
train_X: Training data tensor for numerical and categorical features
train_graphs: List of training graphs
train_Y: Target values
train_Yvar: Observation noise variance (optional)
num_cat_kernel: Kernel for numerical/categorical features (optional)
wl_kernel: Custom Weisfeiler-Lehman kernel instance (optional)
**kwargs: Additional arguments for SingleTaskGP initialization
"""
# Initialize parent class with initial covar_module
super().__init__(
train_X=train_X,
train_Y=train_Y,
train_Yvar=train_Yvar,
covar_module=num_cat_kernel or self._graph_kernel_wrapper(),
**kwargs
)

# Initialize WL kernel with default parameters if not provided
self._wl_kernel = wl_kernel or TorchWLKernel(n_iter=5, normalize=True)
self._train_graphs = train_graphs

# Convert graphs to required format and compute kernel matrix
self._train_graph_dataset = GraphDataset.from_networkx(train_graphs)
self._K_train = self._wl_kernel(self._train_graph_dataset)

if num_cat_kernel is not None:
# Create additive kernel combining numerical/categorical and graph kernels
combined_kernel = AdditiveKernel(
num_cat_kernel,
self._graph_kernel_wrapper()
)
self.covar_module = combined_kernel

self.num_cat_kernel = num_cat_kernel

def _graph_kernel_wrapper(self) -> Module:
"""Creates a GPyTorch-compatible kernel module wrapping the WL kernel.

This wrapper allows the WL kernel to be used within the GPyTorch framework
by providing a forward method that returns the pre-computed kernel matrix.

Returns:
Module: A GPyTorch kernel module wrapping the WL kernel computation
"""

class WLKernelWrapper(Module):
def __init__(self, parent: MixedSingleTaskGP):
super().__init__()
self.parent = parent

def forward(
self,
x1: Tensor,
x2: Tensor | None = None,
diag: bool = False,
last_dim_is_batch: bool = False
) -> Tensor:
"""Compute the kernel matrix for the graph inputs.

Args:
x1: First input tensor (unused, required for interface compatibility)
x2: Second input tensor (must be None)
diag: Whether to return only diagonal elements
last_dim_is_batch: Whether the last dimension is a batch dimension

Returns:
Tensor: Pre-computed graph kernel matrix

Raises:
NotImplementedError: If x2 is not None (cross-covariance not implemented)
"""
if x2 is None:
return self.parent._K_train

# Compute cross-covariance between train and test graphs
test_dataset = GraphDataset.from_networkx(self.parent._test_graphs)
return self.parent._wl_kernel(
self.parent._train_graph_dataset,
test_dataset
)

return WLKernelWrapper(self)
vladislavalerievich marked this conversation as resolved.
Show resolved Hide resolved

def forward(self, X: Tensor, graphs: list[nx.Graph]) -> MultivariateNormal:
"""Forward pass computing the GP distribution for given inputs.

Computes the kernel matrix for both numerical/categorical features and graphs,
combines them if both are present, and returns the resulting GP distribution.

Args:
X: Input tensor for numerical and categorical features
graphs: List of input graphs

Returns:
MultivariateNormal: GP distribution for the given inputs
"""
if len(X) != len(graphs):
raise ValueError(
f"Number of feature vectors ({len(X)}) must match "
f"number of graphs ({len(graphs)})"
)

# Process new graphs and compute kernel matrix
proc_graphs = GraphDataset.from_networkx(graphs)
K_new = self._wl_kernel(proc_graphs) # Shape: (n_samples, n_samples)

# If we have both numerical/categorical and graph features
if self.num_cat_kernel is not None:
# Compute kernel for numerical/categorical features
K_num_cat = self.num_cat_kernel(X)
# Add the kernels (element-wise addition)
K_combined = K_num_cat + K_new
else:
K_combined = K_new

# Compute mean using the mean module
mean_x = self.mean_module(X)

return MultivariateNormal(mean_x, K_combined)
102 changes: 102 additions & 0 deletions grakel_replace/mixed_single_task_gp_usage_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import networkx as nx
import torch
from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels import AdditiveKernel, MaternKernel
from grakel_replace.mixed_single_task_gp import MixedSingleTaskGP
from grakel_replace.torch_wl_kernel import TorchWLKernel

TRAIN_CONFIGS = 10
TEST_CONFIGS = 10
TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS

N_NUMERICAL = 2
N_CATEGORICAL = 2
N_CATEGORICAL_VALUES_PER_CATEGORY = 3
N_GRAPH = 2

kernels = []

# Create numerical and categorical features
X = torch.empty(size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL), dtype=torch.float64)
if N_NUMERICAL > 0:
X[:, :N_NUMERICAL] = torch.rand(
size=(TOTAL_CONFIGS, N_NUMERICAL),
dtype=torch.float64,
)

if N_CATEGORICAL > 0:
X[:, N_NUMERICAL:] = torch.randint(
0,
N_CATEGORICAL_VALUES_PER_CATEGORY,
size=(TOTAL_CONFIGS, N_CATEGORICAL),
dtype=torch.float64,
)

# Create random graph architectures
graphs = []
for _ in range(TOTAL_CONFIGS):
G = nx.erdos_renyi_graph(n=5, p=0.5) # Random graph with 5 nodes
graphs.append(G)

# Create random target values
y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64)

# Setup kernels for numerical and categorical features
if N_NUMERICAL > 0:
matern = ScaleKernel(
MaternKernel(
nu=2.5,
ard_num_dims=N_NUMERICAL,
active_dims=tuple(range(N_NUMERICAL)),
),
)
kernels.append(matern)

if N_CATEGORICAL > 0:
hamming = ScaleKernel(
CategoricalKernel(
ard_num_dims=N_CATEGORICAL,
active_dims=tuple(range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)),
),
)
kernels.append(hamming)

# Combine numerical and categorical kernels
combined_num_cat_kernel = AdditiveKernel(*kernels) if kernels else None

# Create WL kernel for graphs
wl_kernel = TorchWLKernel(n_iter=5, normalize=True)
Comment on lines +72 to +75
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm guessing there's no way to really make it such that we could pass the TorchWLKernel to the AdditiveKernel, i.e. you would use it just like any other kernel type?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could define a WLKernel class that extends gpytorch.kernels.Kernel and use that class instead of TorchWLKernel.


# Split into train and test sets
train_x = X[:TRAIN_CONFIGS]
train_graphs = graphs[:TRAIN_CONFIGS]
train_y = y[:TRAIN_CONFIGS].unsqueeze(-1) # Add dimension for botorch

test_x = X[TRAIN_CONFIGS:]
test_graphs = graphs[TRAIN_CONFIGS:]
test_y = y[TRAIN_CONFIGS:].unsqueeze(-1)

# Initialize the mixed GP
gp = MixedSingleTaskGP(
train_X=train_x,
train_graphs=train_graphs,
train_Y=train_y,
num_cat_kernel=combined_num_cat_kernel,
wl_kernel=wl_kernel,
)

# Compute the posterior distribution
multivariate_normal: MultivariateNormal = gp.forward(train_x, train_graphs)
print("Posterior distribution:", multivariate_normal)

# Making predictions on test data
with torch.no_grad():
posterior = gp.forward(test_x, test_graphs)
predictions = posterior.mean
uncertainties = posterior.variance.sqrt()
covar = posterior.covariance_matrix

print("\nMean:", predictions)
print("Variance:", uncertainties)
print("Covariance matrix:", covar)
81 changes: 81 additions & 0 deletions grakel_replace/single_task_gp_usage_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import torch
from botorch.models import SingleTaskGP
from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels import AdditiveKernel, MaternKernel

TRAIN_CONFIGS = 10
TEST_CONFIGS = 10
TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS

N_NUMERICAL = 2
N_CATEGORICAL = 2
N_CATEGORICAL_VALUES_PER_CATEGORY = 3

kernels = []

# Create some random encoded hyperparameter configurations
X = torch.empty(size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL), dtype=torch.float64)
if N_NUMERICAL > 0:
X[:, :N_NUMERICAL] = torch.rand(
size=(TOTAL_CONFIGS, N_NUMERICAL),
dtype=torch.float64,
)

if N_CATEGORICAL > 0:
X[:, N_NUMERICAL:] = torch.randint(
0,
N_CATEGORICAL_VALUES_PER_CATEGORY,
size=(TOTAL_CONFIGS, N_CATEGORICAL),
dtype=torch.float64,
)

y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64)

if N_NUMERICAL > 0:
matern = ScaleKernel(
MaternKernel(
nu=2.5,
ard_num_dims=N_NUMERICAL,
active_dims=tuple(range(N_NUMERICAL)),
),
)
kernels.append(matern)

if N_CATEGORICAL > 0:
hamming = ScaleKernel(
CategoricalKernel(
ard_num_dims=N_CATEGORICAL,
active_dims=tuple(range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)),
),
)
kernels.append(hamming)


combined_num_cat_kernel = AdditiveKernel(*kernels)

train_x = X[:TRAIN_CONFIGS]
train_y = y[:TRAIN_CONFIGS]

test_x = X[TRAIN_CONFIGS:]
test_y = y[TRAIN_CONFIGS:]

K_matrix = combined_num_cat_kernel.forward(train_x, train_x)
print(
"K_matrix: ", K_matrix.to_dense()
)

train_y = train_y.unsqueeze(-1)
test_y = test_y.unsqueeze(-1)

gp = SingleTaskGP(
train_X=train_x,
train_Y=train_y,
mean_module=None, # We can leave it as the default it uses which is `ConstantMean`
covar_module=combined_num_cat_kernel,
)

multivariate_normal: MultivariateNormal = gp.forward(train_x)
print("Mean:", multivariate_normal.mean)
print("Variance:", multivariate_normal.variance)
print("Covariance matrix:", multivariate_normal.covariance_matrix)
Loading
Loading