Skip to content

Commit

Permalink
Merge pull request #10 from histocartography/feature/model_zoo/gja
Browse files Browse the repository at this point in the history
Feature/model zoo/gja
  • Loading branch information
afoncubierta authored May 20, 2021
2 parents b3b8b20 + 6a22f9f commit 06bc671
Show file tree
Hide file tree
Showing 18 changed files with 94 additions and 436 deletions.
2 changes: 2 additions & 0 deletions histocartography/interpretability/grad_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ def _get_weights(self, class_idx, scores):
# Backpropagate
self._backprop(scores, class_idx)

self.backward_hook = list(reversed(self.backward_hook))

# Compute alpha
grad_2 = [f.pow(2) for f in self.backward_hook]
grad_3 = [f.pow(3) for f in self.backward_hook]
Expand Down
2 changes: 1 addition & 1 deletion histocartography/ml/layers/dense_gin_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def forward(self, adj, h):
adj = adj / degree

if self.add_self:
adj = adj + torch.eye(adj.size(1)).to(adj.device)
adj = adj.float() + torch.eye(adj.size(1)).to(adj.device)

# adjust h dim
if len(h.shape) < 3:
Expand Down
2 changes: 1 addition & 1 deletion histocartography/ml/layers/gin_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _compute_adj_lrp(self, relevance_score):

adjacency_matrix = torch.clamp(self.adjacency_matrix, min=0)
if self.agg_type == 'mean':
adjacency_matrix = torch.div(adjacency_matrix, self.in_degrees)
adjacency_matrix = torch.div(adjacency_matrix, self.in_degrees.to(adjacency_matrix.device))
adjacency_matrix = adjacency_matrix + \
torch.eye(self.adjacency_matrix.shape[0]).to(relevance_score.device)
rel_unnorm = torch.mm(self.input_features, adjacency_matrix.t()) + 1e-9
Expand Down
11 changes: 4 additions & 7 deletions histocartography/ml/layers/pna_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,11 @@ def forward(self, g, h):

# graph and batch normalization
if self.graph_norm:
if isinstance(
g, dgl.DGLGraph) or isinstance(
g, dgl.DGLHeteroGraph):
num_nodes = [g.number_of_nodes()]
else:
if hasattr(g, 'batch_num_nodes'):
num_nodes = g.batch_num_nodes
snorm_n = torch.FloatTensor(list(itertools.chain(
*[[np.sqrt(1 / n)] * n for n in num_nodes]))).to(h.device)
else:
num_nodes = [g.number_of_nodes()]
snorm_n = torch.FloatTensor(list(itertools.chain(*[[np.sqrt(1 / n)] * n for n in num_nodes]))).to(h.device)
h = h * snorm_n[:, None]
if self.batch_norm:
h = self.batchnorm_h(h)
Expand Down
133 changes: 18 additions & 115 deletions histocartography/ml/models/zoo.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
# The models follow the naming convention:
# <dataset_name>_<model_type>_<number_of_classes>_classes_<gnn_type>.pt
# e.g., in bracs_tggnn_5_classes_pna:
# e.g., in bracs_tggnn_7_classes_pna:
# -btrained on bracs dataset
# - with tggnn (Tissue Graph GNN) model
# - and 5 classes
# - and 7 classes
# - with PNA GNN layers

MODEL_NAME_TO_URL = {
# CG-GNN
'bracs_cggnn_3_classes_gin.pt': 'https://ibm.box.com/shared/static/pozkx0ngqjxdr34v5tpmckthts8m12u3.pt',
'bracs_cggnn_5_classes_pna.pt': 'https://ibm.box.com/shared/static/8yj33c4o5precry9sev5g6585vmijeso.pt',
'bracs_cggnn_5_classes_gin.pt': 'https://ibm.box.com/shared/static/uy2xeovpo001mb1edkwms5ccz1sqhdpz.pt',
'bracs_cggnn_7_classes_pna.pt': 'https://ibm.box.com/shared/static/i4xixoglstzkif53rc2cm4b568ei5wnf.pt',
'bracs_cggnn_7_classes_pna.pt': 'https://ibm.box.com/shared/static/i4xixoglstzkif53rc2cm4b568ei5wnf.pt',
# TG-GNN
'bracs_tggnn_3_classes_gin.pt': 'https://ibm.box.com/shared/static/aoogy0516lsp9vaxgw1tr9mdu5nycvvb.pt',
'bracs_tggnn_5_classes_pna.pt': 'https://ibm.box.com/shared/static/qvdr4j12fwo7zuute6wsgrmvxfxe1o8w.pt',
'bracs_tggnn_7_classes_pna.pt': 'https://ibm.box.com/shared/static/19q7kk2humvc6a8qedzg8rs5bny6qvrf.pt',
'bracs_tggnn_7_classes_pna.pt': 'https://ibm.box.com/shared/static/19q7kk2humvc6a8qedzg8rs5bny6qvrf.pt',
# HACT
'bracs_hact_5_classes_pna.pt': 'https://ibm.box.com/shared/static/efar14ic4mc13u5kidn6q23xzshh9oao.pt',
'bracs_hact_7_classes_pna.pt': 'https://ibm.box.com/shared/static/5v44c33cipdy7c2dhajkrfaywyrh2a5o.pt',
'bracs_hact_7_classes_pna.pt': 'https://ibm.box.com/shared/static/5v44c33cipdy7c2dhajkrfaywyrh2a5o.pt',
}

MODEL_NAME_TO_CONFIG = {
Expand All @@ -43,31 +40,6 @@
'hidden_dim': 128,
}
},
'bracs_cggnn_5_classes_pna.pt': {
'node_dim': 514,
'gnn_params': {
'layer_type': "pna_layer",
'output_dim': 64,
'num_layers': 3,
'readout_op': "concat",
'readout_type': "mean",
'aggregators': "mean max min std",
'scalers': "identity amplification attenuation",
'avg_d': 4,
'dropout': 0.,
'graph_norm': True,
'batch_norm': True,
'towers': 1,
'pretrans_layers': 1,
'posttrans_layers': 1,
'divide_input': True,
'residual': True,
},
'classification_params': {
'num_layers': 2,
'hidden_dim': 128,
}
},
'bracs_cggnn_5_classes_gin.pt': {
'node_dim': 514,
'gnn_params': {
Expand All @@ -93,8 +65,8 @@
'gnn_params': {
'layer_type': "pna_layer",
'output_dim': 64,
'num_layers': 5,
'readout_op': "concat",
'num_layers': 3,
'readout_op': "lstm",
'readout_type': "mean",
'aggregators': "mean max min std",
'scalers': "identity amplification attenuation",
Expand All @@ -105,8 +77,8 @@
'towers': 1,
'pretrans_layers': 1,
'posttrans_layers': 1,
'divide_input': True,
'residual': True,
'divide_input': False,
'residual': False,
},
'classification_params': {
'num_layers': 2,
Expand Down Expand Up @@ -134,63 +106,12 @@
'hidden_dim': 128,
}
},
'bracs_tggnn_5_classes_pna.pt': {
'node_dim': 514,
'gnn_params': {
'layer_type': "pna_layer",
'output_dim': 64,
'num_layers': 5,
'readout_op': "concat",
'readout_type': "mean",
'aggregators': "mean max min std",
'scalers': "identity amplification attenuation",
'avg_d': 4,
'dropout': 0.,
'graph_norm': True,
'batch_norm': True,
'towers': 1,
'pretrans_layers': 1,
'posttrans_layers': 1,
'divide_input': True,
'residual': True,
},
'classification_params': {
'num_layers': 2,
'hidden_dim': 128,
}
},
'bracs_tggnn_7_classes_pna.pt': {
'node_dim': 514,
'gnn_params': {
'layer_type': "pna_layer",
'output_dim': 64,
'num_layers': 3,
'readout_op': "concat",
'readout_type': "mean",
'aggregators': "mean max min std",
'scalers': "identity amplification attenuation",
'avg_d': 4,
'dropout': 0.,
'graph_norm': True,
'batch_norm': True,
'towers': 1,
'pretrans_layers': 1,
'posttrans_layers': 1,
'divide_input': True,
'residual': True,
},
'classification_params': {
'num_layers': 2,
'hidden_dim': 128,
}
},
# HACT
'bracs_hact_5_classes_pna.pt': {
'cg_node_dim': 514,
'cg_gnn_params': {
'layer_type': "pna_layer",
'output_dim': 64,
'num_layers': 4,
'readout_op': "lstm",
'readout_type': "mean",
'aggregators': "mean max min std",
Expand All @@ -202,39 +123,21 @@
'towers': 1,
'pretrans_layers': 1,
'posttrans_layers': 1,
'divide_input': True,
'residual': True,
},
'tg_node_dim': 514,
'tg_gnn_params': {
'layer_type': "pna_layer",
'output_dim': 64,
'num_layers': 4,
'readout_op': "lstm",
'readout_type': "mean",
'aggregators': "mean max min std",
'scalers': "identity amplification attenuation",
'avg_d': 4,
'dropout': 0.,
'graph_norm': True,
'batch_norm': True,
'towers': 1,
'pretrans_layers': 1,
'posttrans_layers': 1,
'divide_input': True,
'residual': True,
'divide_input': False,
'residual': False,
},
'classification_params': {
'num_layers': 2,
'hidden_dim': 128,
}
},
# HACT
'bracs_hact_7_classes_pna.pt': {
'cg_node_dim': 514,
'cg_gnn_params': {
'layer_type': "pna_layer",
'output_dim': 64,
'num_layers': 4,
'num_layers': 3,
'readout_op': "lstm",
'readout_type': "mean",
'aggregators': "mean max min std",
Expand All @@ -246,14 +149,14 @@
'towers': 1,
'pretrans_layers': 1,
'posttrans_layers': 1,
'divide_input': True,
'residual': True,
'divide_input': False,
'residual': False,
},
'tg_node_dim': 514,
'tg_gnn_params': {
'layer_type': "pna_layer",
'output_dim': 64,
'num_layers': 4,
'num_layers': 3,
'readout_op': "lstm",
'readout_type': "mean",
'aggregators': "mean max min std",
Expand All @@ -265,8 +168,8 @@
'towers': 1,
'pretrans_layers': 1,
'posttrans_layers': 1,
'divide_input': True,
'residual': True,
'divide_input': False,
'residual': False,
},
'classification_params': {
'num_layers': 2,
Expand Down
Loading

0 comments on commit 06bc671

Please sign in to comment.