Skip to content

Commit 8e90b49

Browse files
committed
feats: support complex AttrLabel for the gedlibpy module. It requires the updated GEDLIB C++ lib (commit #0ebf2f1). For float-vector node and edge labels, the AttrLabel version can be ~12x and ~37x faster than the previous GXLLabel (string) version, with and without parallelization respectively (see the experimental results in the comments of the file compare_gedlib_with_coords_in_string_and_attr_format.py).
1 parent 0d4e4ef commit 8e90b49

31 files changed

+8435
-913
lines changed

gklearn/experiments/ged/ged_model/compare_gedlib_with_coords_in_string_and_attr_format.py

Lines changed: 119 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,19 @@ def fit_model_ged(
6767
list(all_graphs[idx_edge[0]].edges)[0]].keys()
6868
)
6969

70-
from gklearn.experiments.ged.ged_model.parallel_version import GEDModel
70+
from gklearn.experiments.ged.ged_model.ged_model_parallel import GEDModel
7171

7272
if parallel is False:
7373
parallel = None
7474
elif parallel is True:
7575
parallel = 'imap_unordered'
7676

7777
model = GEDModel(
78+
# env_type=ged_options['env_type'],
7879
ed_method=ged_options['method'],
7980
edit_cost_fun=ged_options['edit_cost_fun'],
8081
init_edit_cost_constants=ged_options['edit_costs'],
82+
edit_cost_config=ged_options.get('edit_cost_config', {}),
8183
optim_method=ged_options['optim_method'],
8284
node_labels=nl_names, edge_labels=el_names,
8385
parallel=parallel,
@@ -157,7 +159,7 @@ def show_some_graphs(graphs):
157159
Show some graphs from the list of graphs.
158160
"""
159161
print(f'{INFO_TAG}Showing some graphs:')
160-
for i, g in enumerate(graphs[:5]):
162+
for i, g in enumerate(graphs[:3]):
161163
print(f'Graph {i}:')
162164
print('Number of nodes:', g.number_of_nodes())
163165
print('Number of edges:', g.number_of_edges())
@@ -177,12 +179,15 @@ def convert_graphs_coords_from_attr_to_string(graphs: List[nx.Graph]):
177179
coords = node[1]['coords']
178180
node[1]['x'] = str(coords[0])
179181
node[1]['y'] = str(coords[1])
182+
for idx in range(2, len(coords)):
183+
# If there are more than 2 coordinates, store them with extra keys:
184+
node[1][f'coord_{idx}'] = str(coords[idx])
180185
del node[1]['coords']
181186
print(f'{INFO_TAG}Converted coordinates from attribute format to string format.')
182187

183188

184189
def fit_model_attr_version(
185-
seed: int = 42, n_graphs: int = 100
190+
seed: int = 42, n_graphs: int = 100, n_emb_dim: int = 2, parallel: bool = False
186191
) -> (np.array, float):
187192
"""
188193
Fit the GED model with graphs that have coordinates on nodes in attribute format `AttrLabel`.
@@ -205,28 +210,43 @@ def fit_model_attr_version(
205210
with_continuous_n_features=True,
206211
with_continuous_e_features=False,
207212
continuous_n_feature_key='coords',
208-
continuous_n_feature_dim=2,
213+
continuous_n_feature_dim=n_emb_dim,
209214
continuous_e_feature_dim=0,
210215
seed=seed
211216
)
212217
graphs = generator.generate_graphs()
218+
# Check graph node label format:
219+
one_n_labels = graphs[0].nodes[list(graphs[0].nodes)[0]]
220+
assert 'coords' in one_n_labels and isinstance(one_n_labels['coords'], np.ndarray) and (
221+
len(one_n_labels['coords']) > 0 and one_n_labels['coords'].dtype in [
222+
np.float64, np.float32]
223+
), (
224+
'The node labels should contain "coords" key with a numpy array as value.'
225+
)
213226
print(
214227
f'{INFO_TAG}Generated {len(graphs)} graphs with coordinates in string format.'
215228
)
216229
show_some_graphs(graphs)
217230

218231
# Set GED options:
219232
ged_options = {
233+
'env_type': 'attr', # Use the attribute-based environment
220234
'method': 'BIPARTITE',
221235
'edit_cost_fun': 'GEOMETRIC',
222236
'edit_costs': [3, 3, 1, 3, 3, 1],
237+
'edit_cost_config': {
238+
'node_coord_metric': 'euclidean',
239+
'node_embed_metric': 'cosine_distance',
240+
'edge_weight_metric': 'euclidean',
241+
'edge_embed_metric': 'cosine_distance',
242+
},
223243
'optim_method': 'init',
224-
'repeats': 1
244+
'repeats': 1,
225245
}
226246

227247
fit_settings = {
228-
'parallel': None,
229-
'n_jobs': 1, # min(12, max(os.cpu_count() - 2, 0)),
248+
'parallel': parallel, # Use parallel processing if specified
249+
'n_jobs': 10, # min(12, max(os.cpu_count() - 2, 0)),
230250
'chunksize': None, # None == automatic determination
231251
'copy_graphs': True,
232252
'reorder_graphs': False,
@@ -251,7 +271,7 @@ def fit_model_attr_version(
251271

252272

253273
def fit_model_string_version(
254-
seed: int = 42, n_graphs: int = 100
274+
seed: int = 42, n_graphs: int = 100, n_emb_dim: int = 2, parallel: bool = False
255275
) -> (np.array, float):
256276
"""
257277
Fit the GED model with graphs that have coordinates on nodes in string format `GXLLabel`.
@@ -272,19 +292,26 @@ def fit_model_string_version(
272292
with_continuous_n_features=True,
273293
with_continuous_e_features=False,
274294
continuous_n_feature_key='coords',
275-
continuous_n_feature_dim=2,
295+
continuous_n_feature_dim=n_emb_dim,
276296
continuous_e_feature_dim=0,
277297
seed=seed
278298
)
279299
graphs = generator.generate_graphs()
280300
convert_graphs_coords_from_attr_to_string(graphs)
301+
# Check graph node label format:
302+
one_n_labels = graphs[0].nodes[list(graphs[0].nodes)[0]]
303+
assert 'x' in one_n_labels and 'y' in one_n_labels and isinstance(
304+
one_n_labels['x'], str) and isinstance(one_n_labels['y'], str), (
305+
'The node labels should contain "x" and "y" keys with string values.'
306+
)
281307
print(
282308
f'{INFO_TAG}Generated {len(graphs)} graphs with coordinates in string format.'
283309
)
284310
show_some_graphs(graphs)
285311

286312
# Set GED options:
287313
ged_options = {
314+
'env_type': 'gxl', # Use the GXLLabel environment
288315
'method': 'BIPARTITE',
289316
'edit_cost_fun': 'NON_SYMBOLIC',
290317
'edit_costs': [3, 3, 1, 3, 3, 1],
@@ -293,8 +320,8 @@ def fit_model_string_version(
293320
}
294321

295322
fit_settings = {
296-
'parallel': None,
297-
'n_jobs': 1, # min(12, max(os.cpu_count() - 2, 0)),
323+
'parallel': parallel, # Use parallel processing if specified
324+
'n_jobs': 10, # min(12, max(os.cpu_count() - 2, 0)),
298325
'chunksize': None, # None == automatic determination
299326
'copy_graphs': True,
300327
'reorder_graphs': False,
@@ -319,25 +346,31 @@ def fit_model_string_version(
319346

320347

321348
def compare_gedlib_with_coords_in_string_and_attr_format(
322-
seed: int = 42, n_graphs: int = 100
349+
seed: int = 42, n_graphs: int = 100, n_emb_dim: int = 2, parallel: bool = False
323350
) -> (np.array, np.array):
324351
"""
325352
Compare the output and the performance of GEDLIB with the same graphs with coordinates on nodes,
326353
but one is in string format `GXLLabel` and the other is in the complex attribute format `AttrLabel`.
327354
"""
328-
# cost_matrix_s, run_time_s = fit_model_string_version(seed=seed, n_graphs=n_graphs)
329-
cost_matrix_a, run_time_a = fit_model_attr_version(seed=seed, n_graphs=n_graphs)
330-
if not np.array_equal(cost_matrix_s, cost_matrix_a):
355+
cost_matrix_a, run_time_a = fit_model_attr_version(
356+
seed=seed, n_graphs=n_graphs, n_emb_dim=n_emb_dim, parallel=parallel
357+
)
358+
cost_matrix_s, run_time_s = fit_model_string_version(
359+
seed=seed, n_graphs=n_graphs, n_emb_dim=n_emb_dim, parallel=parallel
360+
)
361+
if not np.allclose(cost_matrix_s, cost_matrix_a, rtol=1e-9):
331362
print(
332363
f'{ISSUE_TAG}The cost matrices are not equal! '
333364
f'String version: {cost_matrix_s.shape}, '
334-
f'Attribute version: {cost_matrix_a.shape}'
365+
f'Attribute version: {cost_matrix_a.shape}, '
366+
f'Relevant tolerance: 1e-9.'
335367
)
336368
else:
337369
print(
338370
f'{SUCCESS_TAG}The cost matrices are equal! '
339371
f'String version: {cost_matrix_s.shape}, '
340-
f'Attribute version: {cost_matrix_a.shape}'
372+
f'Attribute version: {cost_matrix_a.shape}, '
373+
f'Relevant tolerance: 1e-9.'
341374
)
342375

343376
# Print the first 5 rows and columns of the matrices:
@@ -366,5 +399,72 @@ def compare_gedlib_with_coords_in_string_and_attr_format(
366399
# Test the class
367400
# feat_type = 'str'
368401
seed = 42
369-
n_graphs = 10
370-
compare_gedlib_with_coords_in_string_and_attr_format(seed=seed, n_graphs=n_graphs)
402+
n_graphs = 500
403+
n_emb_dim = 100
404+
parellel = True
405+
compare_gedlib_with_coords_in_string_and_attr_format(
406+
seed=seed, n_graphs=n_graphs, n_emb_dim=n_emb_dim, parallel=parellel
407+
)
408+
409+
# # Comparison of the two versions:
410+
#
411+
# General Settings:
412+
# - n_graphs: 500
413+
# - node numbers: 10-20
414+
# - edge numbers: 20-50
415+
# - Regenerate GEDEnv for each pair of computation (not optimized).
416+
# - Coordinates as labels of strings in GXLLabel or one label of np.array in AttrLabel,
417+
# where the latter is optimized by the Eigen C++ library for vectorized operations.
418+
#
419+
# ## Without parallelization:
420+
#
421+
# ### n_emb_dim = 2:
422+
# - String version run time: 7.4e-4 s per pair (92.3 s total).
423+
# - Attribute version run time: 5.0e-4 s per pair (62.4 s total).
424+
# The Attr version is ~ 1.5x faster than the String version.
425+
#
426+
# ### n_emb_dim = 20:
427+
# - String version run time: 5.4e-3 s per pair (675.1 s total).
428+
# - Attribute version run time: 5.5e-4 s per pair (69.0 s total).
429+
# The Attr version is ~ 10x faster than the String version.
430+
#
431+
# ### n_emb_dim = 100:
432+
# - String version run time: too long to compute (over 1 h ~ 3698.5 s).
433+
# - Attribute version run time: 8.0e-4 s per pair (99.9 s total).
434+
# The Attr version is ~ 37x faster than the String version.
435+
#
436+
# ### Conclusion:
437+
# - The Attribute version is faster than the String version.
438+
# - With the increase of the dimensionality of the coordinates (n_emb_dim):
439+
# -- Attribute version takes almost the same level of time to compute pairwise
440+
# distances (e.g., ~ 1.6x slower when n_emb_dim = 100 than 2).
441+
# -- String version becomes unusable (~ 40x slower when n_emb_dim = 100 than 2),
442+
# and ~ 37x slower than the Attribute version with n_emb_dim = 100.
443+
#
444+
# ## With parallelization (n_jobs=10):
445+
#
446+
# ### n_emb_dim = 2:
447+
# - String version run time: 3.6e-4 s per pair (45.3 s total).
448+
# - Attribute version run time: 3.6e-4 s per pair (45.3 s total).
449+
# The two versions are almost equal in terms of run time.
450+
#
451+
# ### n_emb_dim = 20:
452+
# - String version run time: 9.8e-4 s per pair (122.4 s total).
453+
# - Attribute version run time: 4.1e-4 s per pair (50.7 s total).
454+
# The Attribute version is ~ 2.4x faster than the String version.
455+
#
456+
# ### n_emb_dim = 100:
457+
# - String version run time: 5.3e-3 s per pair (664.3 s total).
458+
# - Attribute version run time: 4.4e-4 s per pair (54.3 s total).
459+
# The Attribute version is ~ 12.2x faster than the String version.
460+
#
461+
# ### Conclusion:
462+
# - The Attribute version is still faster than the String version.
463+
# - The parallelization helps to reduce the run time of both versions,
464+
# but the improvement on the String version is much more significant,
465+
# e.g., ~ 5.6x faster than the non-parallelized version with n_emb_dim = 100 (3698.5 s vs. 664.3 s).
466+
# - On the other hand, the improvement brought by parallelization is not so significant
467+
# for the Attribute version, e.g., ~ 1.8x faster than the non-parallelized version
468+
# with n_emb_dim = 100.
469+
# -- I assume the reason is that the construction of the GEDEnvAttr and the
470+
# Python-C++ interface conversion become the bottleneck of the process.

gklearn/experiments/ged/ged_model/fit_ged_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def fit_model_ged(
6767
list(all_graphs[idx_edge[0]].edges)[0]].keys()
6868
)
6969

70-
from gklearn.experiments.ged.ged_model.parallel_version import GEDModel
70+
from gklearn.experiments.ged.ged_model.ged_model_parallel import GEDModel
7171

7272
if parallel is False:
7373
parallel = None

gklearn/experiments/ged/ged_model/parallel_version.py renamed to gklearn/experiments/ged/ged_model/ged_model_parallel.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,11 @@ class GEDModel(BaseEstimator): # , ABC):
5252

5353
def __init__(
5454
self,
55+
env_type: str | None = None,
5556
ed_method='BIPARTITE',
5657
edit_cost_fun='CONSTANT',
5758
init_edit_cost_constants=[3, 3, 1, 3, 3, 1],
59+
edit_cost_config: dict = {},
5860
optim_method='init',
5961
optim_options={'y_distance': euclid_d, 'mode': 'reg'},
6062
node_labels=[],
@@ -66,12 +68,33 @@ def __init__(
6668
copy_graphs=True, # make sure it is a full deep copy. and faster!
6769
verbose=2
6870
):
69-
"""`__init__` for `GEDModel` object."""
71+
"""`__init__` for `GEDModel` object.
72+
73+
Parameters
74+
----------
75+
env_type : str, optional
76+
The type of the GED environment. Default is None. If None, try to determine
77+
the type automatically based on the given graph node / edge labels.
78+
79+
Available types are:
80+
81+
- 'attr': Attribute-based environment (with complex node and edge labels).
82+
Each node or edge can have multiple key-value label pairs, and each value can
83+
be of the following types: int, float, str, list/np.ndarray of int or float.
84+
This is the default type if no node or edge labels are provided.
85+
86+
- 'gxl' or 'str': GXLLabel environment (with string labels). Each node or
87+
edge can have multiple key-value label pairs, but all values must be strings.
88+
The type will be set to GXL only if at least one node or edge label is
89+
provided.
90+
"""
7091
# @todo: the default settings of the parameters are different from those in the self.compute method.
7192
# self._graphs = None
93+
self.env_type = env_type
7294
self.ed_method = ed_method
7395
self.edit_cost_fun = edit_cost_fun
7496
self.init_edit_cost_constants = init_edit_cost_constants
97+
self.edit_cost_config = edit_cost_config
7598
self.optim_method = optim_method
7699
self.optim_options = optim_options
77100
self.node_labels = node_labels
@@ -1079,12 +1102,15 @@ def _wrapper_compute_ged(self, itr):
10791102

10801103
def compute_ged(self, Gi, Gj, **kwargs):
10811104
"""
1082-
Compute GED between two graph according to edit_cost.
1105+
Compute GED between two graphs according to edit_cost.
10831106
"""
1107+
env_type = self.get_env_type(graph=Gi)
10841108
ged_options = {
1109+
'env_type': env_type,
10851110
'edit_cost': self.edit_cost_fun,
10861111
'method': self.ed_method,
1087-
'edit_cost_constants': self._edit_cost_constants
1112+
'edit_cost_constants': self._edit_cost_constants,
1113+
'edit_cost_config': self.edit_cost_config,
10881114
}
10891115
repeats = kwargs.get('repeats', 1)
10901116
dis, pi_forward, pi_backward = pairwise_ged(
@@ -1103,6 +1129,42 @@ def compute_ged(self, Gi, Gj, **kwargs):
11031129
return dis, None
11041130

11051131

1132+
def get_env_type(self, graph: nx.Graph | None = None):
1133+
"""
1134+
Check the environment type of the graph.
1135+
If `env_type` is set on initialization, return it.
1136+
Otherwise, check the given graph's node and edge labels to determine the type.
1137+
1138+
Only one node and one edge are checked to determine the type.
1139+
This function expects that all nodes have the same type of labels, as do all
1140+
edges.
1141+
"""
1142+
if self.env_type is not None:
1143+
return self.env_type
1144+
if graph is None:
1145+
raise ValueError(
1146+
'Graph is not provided while `env_type` not set on initialization. '
1147+
'Cannot determine environment type.'
1148+
)
1149+
# Use 'gxl' env type only if all node and edge labels are strings, and at least one
1150+
# node or edge label is present:
1151+
one_n_labels = graph.nodes[list(graph.nodes)[0]]
1152+
for k, v in one_n_labels.items():
1153+
if not isinstance(v, str):
1154+
return 'attr'
1155+
if nx.number_of_edges(graph) != 0:
1156+
one_e_labels = graph.edges[list(graph.edges)[0]]
1157+
for k, v in one_e_labels.items():
1158+
if not isinstance(v, str):
1159+
return 'attr'
1160+
if len(one_n_labels) > 0 or (
1161+
nx.number_of_edges(graph) != 0 and len(one_e_labels) > 0
1162+
):
1163+
return 'gxl'
1164+
return 'attr'
1165+
1166+
1167+
11061168
# def _compute_kernel_list(self, g1, g_list):
11071169
# start_time = time.time()
11081170

0 commit comments

Comments
 (0)