|
1 |
| -import numpy as np |
2 |
| -import tensorflow as tf |
3 |
| -from typing import Union, List |
4 |
| -from kgcnn.data.utils import ragged_tensor_from_nested_numpy |
5 |
| -from kgcnn.data.base import MemoryGraphDataset |
6 |
| -ks = tf.keras |
7 |
| - |
8 |
| - |
9 |
class GraphBatchLoader(ks.utils.Sequence):
    r"""Example (minimal) implementation of a graph batch loader based on :obj:`ks.utils.Sequence` .

    Each batch is assembled on demand in :obj:`__getitem__` by reading the selected graphs from ``data`` and
    converting the properties named in ``inputs`` and ``outputs`` into tensors.
    """

    def __init__(self,
                 data: Union[List[dict], MemoryGraphDataset],
                 inputs: Union[dict, List[dict]],
                 outputs: Union[dict, List[dict]],
                 batch_size: int = 32,
                 shuffle: bool = False,
                 device: Union[str, None] = None):
        """Initialization with data and input information.

        Args:
            data (list, MemoryGraphDataset): Any iterable data that implements indexing operator for graph instance.
                Each graph instance must implement indexing operator for named property.
            inputs (dict, list): List of dictionaries that specify graph properties in list via 'name' key.
                The dict-items match the tensor input for :obj:`tf.keras.layers.Input` layers.
                Required dict-keys should be 'name' and 'ragged'.
                Optionally shape information can be included via 'shape'.
                E.g.: `[{'name': 'edge_indices', 'ragged': True}, {...}, ...]`.
            outputs (dict, list): List of dictionaries that specify graph properties in list via 'name' key.
                Required dict-keys should be 'name' and 'ragged'.
                Optionally shape information can be included via 'shape'.
                E.g.: `[{'name': 'graph_labels', 'ragged': False}, {...}, ...]`.
            batch_size (int): Batch size. Default is 32.
            shuffle (bool): Whether to shuffle data. Default is False.
            device (str): Device to make tensor on. For multiprocessing this can cause deadlocks if e.g. set on GPU.
                Example for CPU would be '/cpu:0' .

        Raises:
            ValueError: If `batch_size` is not a positive number.
        """
        # Let the Keras base class set up its own state; newer Keras versions expect this call.
        super().__init__()
        # Fail early: a zero batch size would otherwise only surface as a division error in `__len__`.
        if batch_size <= 0:
            raise ValueError("`batch_size` must be a positive integer, but got %r." % (batch_size,))
        self.data = data
        self.inputs = inputs
        self.outputs = outputs
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Positions into `data`; this array (not `data` itself) is permuted when shuffling is enabled.
        self.indices = np.arange(len(data))
        self.device = device
        self._shuffle_indices()

    def __len__(self):
        """Denotes the number of batches per epoch. The last batch may hold fewer than `batch_size` graphs."""
        return int(np.ceil(len(self.data) / float(self.batch_size)))

    def __getitem__(self, index):
        """Generate one batch of data.

        Args:
            index (int): Number of the batch within the current epoch.

        Returns:
            tuple: Pair `(x_model, y_model)` as produced by :obj:`_data_generation`.
        """
        # Generate indexes of the batch
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]

        # Generate data, optionally pinning tensor creation to the requested device.
        if self.device is None:
            return self._data_generation(batch_indices)
        with tf.device(self.device):
            return self._data_generation(batch_indices)

    def _shuffle_indices(self):
        """Permute the index array in place if shuffling was requested; otherwise do nothing."""
        if self.shuffle:
            np.random.shuffle(self.indices)

    def on_epoch_end(self):
        """Updates after each epoch"""
        self._shuffle_indices()

    @staticmethod
    def _to_tensor(item: Union[np.ndarray, list], is_ragged: bool):
        """Convert the per-graph values of one property into a single tensor.

        Args:
            item (np.ndarray, list): Values of one property, one entry per graph of the batch.
            is_ragged (bool): Whether to convert via :obj:`ragged_tensor_from_nested_numpy` instead of
                stacking into a dense :obj:`tf.constant`.

        Returns:
            The result of :obj:`ragged_tensor_from_nested_numpy` (presumably a ragged tensor, see that helper)
            if `is_ragged` is true, else a :obj:`tf.constant` built from `np.array(item)`.
        """
        if is_ragged:
            return ragged_tensor_from_nested_numpy(item)
        return tf.constant(np.array(item))

    def _get_data(self, index):
        """Return the graph stored at position `index` of the data."""
        # Accessing data method. E.g. loading from file etc.
        return self.data[int(index)]

    def _build_tensors(self, graphs: list, specs: Union[dict, List[dict]]):
        """Build the tensor(s) described by `specs` from the graphs of one batch.

        Args:
            graphs (list): Graph instances of the batch; each is indexed by property name.
            specs (dict, list): One property dictionary or a list of them. Each needs the key 'name';
                a missing 'ragged' key is treated as False.

        Returns:
            A single tensor if `specs` is a dictionary, otherwise a list with one tensor per entry of `specs`.
        """
        is_single = isinstance(specs, dict)
        spec_list = [specs] if is_single else specs
        tensors = [
            self._to_tensor([g[spec["name"]] for g in graphs], spec.get("ragged", False))
            for spec in spec_list
        ]
        # Mirror the structure of the specification: a bare dict yields a bare tensor.
        return tensors[0] if is_single else tensors

    def _data_generation(self, batch_indices: Union[np.ndarray, list]):
        """Generates data containing batch_size samples.

        Args:
            batch_indices (np.ndarray, list): Positions in the data of the graphs that form the batch.

        Returns:
            tuple: Pair `(x_inputs, y_outputs)`. Each element is a list of tensors, or a single tensor
            if the corresponding `inputs` or `outputs` specification is a dictionary.
        """
        graphs = [self._get_data(i) for i in batch_indices]
        x_inputs = self._build_tensors(graphs, self.inputs)
        y_outputs = self._build_tensors(graphs, self.outputs)
        return x_inputs, y_outputs
0 commit comments