Skip to content

Commit 0ba4d87

Browse files
committed
refactoring for keras 3.0
1 parent 8ddcd03 commit 0ba4d87

File tree

4 files changed

+170
-121
lines changed

4 files changed

+170
-121
lines changed

kgcnn/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
# main package
1+
# Main package version.
22
__kgcnn_version__ = "4.0.0"
33

4+
# Global definition of index order and axis.
45
__indices_axis__ = 0
56
__index_receive__ = 0
67
__index_send__ = 1

kgcnn/io/loader.py

Lines changed: 0 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,104 +0,0 @@
1-
import numpy as np
2-
import tensorflow as tf
3-
from typing import Union, List
4-
from kgcnn.data.utils import ragged_tensor_from_nested_numpy
5-
from kgcnn.data.base import MemoryGraphDataset
6-
ks = tf.keras
7-
8-
9-
class GraphBatchLoader(ks.utils.Sequence):
10-
r"""Example (minimal) implementation of a graph batch loader based on :obj:`ks.utils.Sequence` ."""
11-
12-
def __init__(self,
13-
data: Union[List[dict], MemoryGraphDataset],
14-
inputs: Union[dict, List[dict]],
15-
outputs: Union[dict, List[dict]],
16-
batch_size: int = 32,
17-
shuffle: bool = False,
18-
device: str = None):
19-
"""Initialization with data and input information.
20-
21-
Args:
22-
data (list, MemoryGraphDataset): Any iterable data that implements indexing operator for graph instance.
23-
Each graph instance must implement indexing operator for named property.
24-
inputs (dict, list): List of dictionaries that specify graph properties in list via 'name' key.
25-
The dict-items match the tensor input for :obj:`tf.keras.layers.Input` layers.
26-
Required dict-keys should be 'name' and 'ragged'.
27-
Optionally shape information can be included via 'shape'.
28-
E.g.: `[{'name': 'edge_indices', 'ragged': True}, {...}, ...]`.
29-
outputs (dict, list): List of dictionaries that specify graph properties in list via 'name' key.
30-
Required dict-keys should be 'name' and 'ragged'.
31-
Optionally shape information can be included via 'shape'.
32-
E.g.: `[{'name': 'graph_labels', 'ragged': False}, {...}, ...]`.
33-
batch_size (int): Batch size. Default is 32.
34-
shuffle (bool): Whether to shuffle data. Default is False.
35-
device (str): Device to make tensor on. For multiprocessing this can cause deadlocks if e.g. set on GPU.
36-
Example for CPU would be '/cpu:0' .
37-
"""
38-
self.data = data
39-
self.inputs = inputs
40-
self.outputs = outputs
41-
self.batch_size = batch_size
42-
self.shuffle = shuffle
43-
self.indices = np.arange(len(data))
44-
self.device = device
45-
self._shuffle_indices()
46-
47-
def __len__(self):
48-
"""Denotes the number of batches per epoch"""
49-
return int(np.ceil(len(self.data) / float(self.batch_size)))
50-
51-
def __getitem__(self, index):
52-
"""Generate one batch of data"""
53-
# Generate indexes of the batch
54-
batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
55-
56-
# Generate data
57-
if self.device is not None:
58-
with tf.device(self.device):
59-
x_model, y_model = self._data_generation(batch_indices)
60-
else:
61-
x_model, y_model = self._data_generation(batch_indices)
62-
63-
return x_model, y_model
64-
65-
def _shuffle_indices(self):
66-
if self.shuffle:
67-
np.random.shuffle(self.indices)
68-
69-
def on_epoch_end(self):
70-
"""Updates after each epoch"""
71-
self._shuffle_indices()
72-
73-
@staticmethod
74-
def _to_tensor(item: Union[np.ndarray, list], is_ragged: bool):
75-
if is_ragged:
76-
return ragged_tensor_from_nested_numpy(item)
77-
else:
78-
return tf.constant(np.array(item))
79-
80-
def _get_data(self, index):
81-
# Accessing data method. E.g. loading from file etc.
82-
return self.data[int(index)]
83-
84-
def _data_generation(self, batch_indices: Union[np.ndarray, list]):
85-
"""Generates data containing batch_size samples"""
86-
graphs = [self._get_data(i) for i in batch_indices]
87-
# Inputs
88-
inputs = self.inputs if not isinstance(self.inputs, dict) else [self.inputs]
89-
x_inputs = []
90-
for i in inputs:
91-
data_list = [g[i["name"]] for g in graphs]
92-
is_ragged = i["ragged"] if "ragged" in i else False
93-
x_inputs.append(self._to_tensor(data_list, is_ragged))
94-
y_outputs = []
95-
# Outputs
96-
outputs = self.outputs if not isinstance(self.outputs, dict) else [self.outputs]
97-
for i in outputs:
98-
data_list = [g[i["name"]] for g in graphs]
99-
is_ragged = i["ragged"] if "ragged" in i else False
100-
y_outputs.append(self._to_tensor(data_list, is_ragged))
101-
# Check return type.
102-
x_inputs = x_inputs if not isinstance(self.inputs, dict) else x_inputs[0]
103-
y_outputs = y_outputs if not isinstance(self.outputs, dict) else y_outputs[0]
104-
return x_inputs, y_outputs

0 commit comments

Comments
 (0)