Skip to content

Commit

Permalink
Add inference method (#106)
Browse files Browse the repository at this point in the history
Co-authored-by: Bart <bart@myotis>
  • Loading branch information
bart1 and Bart authored Jun 26, 2023
1 parent b04ad01 commit ed7cbaa
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
27 changes: 27 additions & 0 deletions bird_cloud_gnn/gnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import os
import dgl
import numpy as np
import torch.nn.functional as F
from dgl.dataloading import GraphDataLoader
from dgl.nn.pytorch.conv import GraphConv
from torch import nn
from torch import optim
Expand Down Expand Up @@ -99,3 +101,28 @@ def evaluate(self, test_dataloader):

accuracy = num_correct / num_tests
return accuracy

def infer(self, dataset, batch_size=1024):
"""
Using the model do inference on a dataset.
Args:
dataset: A `RadarDataSet` where for each graph inference needs to be done.
Returns:
labels: A numpy array with infered labels for each graph
"""
self.eval()
dataloader = GraphDataLoader(
shuffle=False,
dataset=dataset,
batch_size=batch_size,
drop_last=False,
)
labels = np.array([])
for batched_graph, _ in dataloader:
pred = (
self(batched_graph, batched_graph.ndata["x"].float()).argmax(1).numpy()
)
labels = np.concatenate([labels, pred])
return labels
5 changes: 5 additions & 0 deletions tests/test_gnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def test_gnn_model(dataset_fixture):
model.fit(train_dataloader)
model.evaluate(test_dataloader)

assert len(model.infer(dataset_fixture, batch_size=30)) == len(dataset_fixture)
assert (
(model.infer(dataset_fixture) == 1) | (model.infer(dataset_fixture) == 0)
).all()


class TestBasicBehaviour:
"""Set of tests for field access, inequality of classes and expected exceptions"""
Expand Down

0 comments on commit ed7cbaa

Please sign in to comment.