Skip to content

Commit 53939b0

Browse files
authored
feat: label propagation algorithm (#78)
* fix: remove Vector import from utils * feat: v5 algorithm label propagation * test: unit test of label propagation * fix: fix lint * fix: remove default export * fix: fix lint * fix: Eng annotation * fix: change plain object to map * chore: remove useless function
1 parent db6c435 commit 53939b0

File tree

8 files changed

+310
-40
lines changed

8 files changed

+310
-40
lines changed

__tests__/data/label-propagation-test-data.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import { Graph } from "@antv/graphlib";
2+
import { labelPropagation } from '../../packages/graph/src';
3+
import { dataTransformer } from "../utils/data";
4+
import labelPropagationTestData from '../data/label-propagation-test-data.json';
5+
6+
7+
// Smoke tests for labelPropagation: each case only checks that at least one
// cluster and at least one aggregated cluster edge are produced.
describe('label propagation', () => {
  it('simple label propagation', () => {
    // Three 5-node cliques (0-4, 5-9, 10-14) joined by a few bridge edges,
    // plus repeated self-loops on node 0 and a duplicated 10 -> 0 edge.
    const oldData = {
      nodes: [
        { id: '0' }, { id: '1' }, { id: '2' }, { id: '3' }, { id: '4' },
        { id: '5' }, { id: '6' }, { id: '7' }, { id: '8' }, { id: '9' },
        { id: '10' }, { id: '11' }, { id: '12' }, { id: '13' }, { id: '14' },
      ],
      edges: [
        { source: '0', target: '1' }, { source: '0', target: '2' }, { source: '0', target: '3' }, { source: '0', target: '4' },
        { source: '1', target: '2' }, { source: '1', target: '3' }, { source: '1', target: '4' },
        { source: '2', target: '3' }, { source: '2', target: '4' },
        { source: '3', target: '4' },
        { source: '0', target: '0' },
        { source: '0', target: '0' },
        { source: '0', target: '0' },

        { source: '5', target: '6', weight: 5 }, { source: '5', target: '7' }, { source: '5', target: '8' }, { source: '5', target: '9' },
        { source: '6', target: '7' }, { source: '6', target: '8' }, { source: '6', target: '9' },
        { source: '7', target: '8' }, { source: '7', target: '9' },
        { source: '8', target: '9' },

        { source: '10', target: '11' }, { source: '10', target: '12' }, { source: '10', target: '13' }, { source: '10', target: '14' },
        { source: '11', target: '12' }, { source: '11', target: '13' }, { source: '11', target: '14' },
        { source: '12', target: '13' }, { source: '12', target: '14' },
        { source: '13', target: '14', weight: 5 },

        { source: '0', target: '5' },
        { source: '5', target: '10' },
        { source: '10', target: '0' },
        { source: '10', target: '0' },
      ]
    };
    const data = dataTransformer(oldData);
    const graph = new Graph(data);
    const clusteredData = labelPropagation(graph, false, 'weight');
    expect(clusteredData.clusters.length).not.toBe(0);
    expect(clusteredData.clusterEdges.length).not.toBe(0);
  });

  it('label propagation with large graph', () => {
    const data = dataTransformer(labelPropagationTestData);
    const graph = new Graph(data);
    const clusteredData = labelPropagation(graph, false, 'weight');
    expect(clusteredData.clusters.length).not.toBe(0);
    expect(clusteredData.clusterEdges.length).not.toBe(0);
  });
});

__tests__/utils/data.ts

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ import { INode, IEdge } from '../../packages/graph/src/types';
66
* @return {{nodes:INode[],edges:IEdge[]}} new data
77
*/
88
export const dataTransformer = (data: {
9-
nodes: { id: ID; [key: string]: any }[];
10-
edges: { source: ID; target: ID; [key: string]: any }[];
9+
nodes: { id: ID;[key: string]: any }[];
10+
edges: { source: ID; target: ID;[key: string]: any }[];
1111
}): { nodes: INode[]; edges: IEdge[] } => {
1212
const { nodes, edges } = data;
1313
return {
@@ -22,31 +22,31 @@ export const dataTransformer = (data: {
2222
};
2323
};
2424

25-
export const dataPropertiesTransformer = (data: { nodes: { id: NodeID, [key: string]: any }[], edges: { source: NodeID, target: NodeID, [key: string]: any }[] }): { nodes: INode[], edges: IEdge[] } => {
26-
const { nodes, edges } = data;
27-
return {
28-
nodes: nodes.map((n) => {
29-
const { id, properties, ...rest } = n;
30-
return { id, data: { ...properties, ...rest } };
31-
}),
32-
edges: edges.map((e, i) => {
33-
const { id, source, target, ...rest } = e;
34-
return { id: id ? id : `edge-${i}`, target, source, data: rest };
35-
}),
36-
};
25+
/**
 * Converts legacy graph data whose nodes carry a nested `properties` object
 * into the `{ id, data }` node / `{ id, source, target, data }` edge shape.
 *
 * For each node, the fields of `properties` and all remaining top-level fields
 * (except `id`) are merged into `data`; top-level fields win on key clashes.
 * For each edge, everything except `id`, `source` and `target` goes into `data`.
 * An edge whose `id` is missing (undefined, null or empty string) is assigned
 * `edge-<index>`; a numeric id of `0` is kept as-is.
 *
 * @param data legacy data with `nodes` and `edges` arrays
 * @returns the transformed nodes and edges
 */
export const dataPropertiesTransformer = (data: {
  nodes: { id: ID, [key: string]: any }[],
  edges: { source: ID, target: ID, [key: string]: any }[],
}): { nodes: INode[], edges: IEdge[] } => {
  const { nodes, edges } = data;
  return {
    nodes: nodes.map((n) => {
      const { id, properties, ...rest } = n;
      return { id, data: { ...properties, ...rest } };
    }),
    edges: edges.map((e, i) => {
      const { id, source, target, ...rest } = e;
      // A plain truthiness check would throw away the valid numeric id 0.
      const hasId = id !== undefined && id !== null && id !== '';
      return { id: hasId ? id : `edge-${i}`, target, source, data: rest };
    }),
  };
};
3838

3939

40-
export const dataLabelDataTransformer = (data: { nodes: { id: NodeID, [key: string]: any }[], edges: { source: NodeID, target: NodeID, [key: string]: any }[] }): { nodes: INode[], edges: IEdge[] } => {
41-
const { nodes, edges } = data;
42-
return {
43-
nodes: nodes.map((n) => {
44-
const { id, label, data } = n;
45-
return { id, data: { label, ...data } };
46-
}),
47-
edges: edges.map((e, i) => {
48-
const { id, source, target, ...rest } = e;
49-
return { id: id ? id : `edge-${i}`, target, source, data: rest };
50-
}),
51-
};
40+
/**
 * Converts legacy graph data whose nodes carry a top-level `label` and a
 * nested `data` object into the `{ id, data }` node /
 * `{ id, source, target, data }` edge shape.
 *
 * For each node, `label` is folded into `data` (fields of the node's own
 * `data` object win if it also defines `label`); other top-level node fields
 * are dropped. For each edge, everything except `id`, `source` and `target`
 * goes into `data`. An edge whose `id` is missing (undefined, null or empty
 * string) is assigned `edge-<index>`; a numeric id of `0` is kept as-is.
 *
 * @param data legacy data with `nodes` and `edges` arrays
 * @returns the transformed nodes and edges
 */
export const dataLabelDataTransformer = (data: {
  nodes: { id: ID, [key: string]: any }[],
  edges: { source: ID, target: ID, [key: string]: any }[],
}): { nodes: INode[], edges: IEdge[] } => {
  const { nodes, edges } = data;
  return {
    nodes: nodes.map((n) => {
      // Renamed on destructuring so it does not shadow the `data` parameter.
      const { id, label, data: nodeData } = n;
      return { id, data: { label, ...nodeData } };
    }),
    edges: edges.map((e, i) => {
      const { id, source, target, ...rest } = e;
      // A plain truthiness check would throw away the valid numeric id 0.
      const hasId = id !== undefined && id !== null && id !== '';
      return { id: hasId ? id : `edge-${i}`, target, source, data: rest };
    }),
  };
};

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"build:ci": "pnpm -r run build:ci",
2222
"prepare": "husky install",
2323
"test": "jest",
24-
"test_one": "jest ./__tests__/unit/k-means.spec.ts",
24+
"test_one": "jest ./__tests__/unit/label-propagation.spec.ts",
2525
"coverage": "jest --coverage",
2626
"build:site": "vite build",
2727
"deploy": "gh-pages -d site/dist",

packages/graph/src/detect-cycle.ts

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ export const detectDirectedCycle = (
6262
return true;
6363
},
6464
};
65-
for (let key of Object.keys(unvisitedSet)) {
65+
for (const key of Object.keys(unvisitedSet)) {
6666
depthFirstSearch(graph, key, callbacks, true, false);
6767
}
6868
return cycle;
@@ -193,8 +193,9 @@ export const detectAllDirectedCycle = (
193193
adjList: { [key: ID]: number[] }
194194
) => {
195195
let closed = false; // whether a path is closed
196-
if (nodeIds && include === false && nodeIds.indexOf(node.id) > -1)
196+
if (nodeIds && !include && nodeIds.indexOf(node.id) > -1) {
197197
return closed;
198+
}
198199
path.push(node);
199200
blocked.add(node);
200201
const neighbors = adjList[node.id];
@@ -277,7 +278,7 @@ export const detectAllDirectedCycle = (
277278
// 对自环情况 (点连向自身) 特殊处理:记录自环,但不加入adjList
278279
if (
279280
neighbor === node.id &&
280-
!(include === false && nodeIds.indexOf(node.id) > -1)
281+
!(!include && nodeIds.indexOf(node.id) > -1)
281282
) {
282283
allCycles.push({ [node.id]: node });
283284
} else {
@@ -306,8 +307,9 @@ export const detectAllDirectedCycle = (
306307
});
307308
const startNode = idx2Node[minIdx];
308309
// StartNode is not in the specified node to include. End the search ahead of time.
309-
if (nodeIds && include && nodeIds.indexOf(startNode.id) === -1)
310+
if (nodeIds && include && nodeIds.indexOf(startNode.id) === -1) {
310311
return allCycles;
312+
}
311313
circuit(startNode, startNode, adjList);
312314
nodeIdx = minIdx + 1;
313315
} else {

packages/graph/src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ export * from './connected-component';
1313
export * from './mst';
1414
export * from './k-means';
1515
export * from './detect-cycle';
16+
export * from './label-propagation';
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import { uniqueId } from "@antv/util";
2+
import { ClusterData, INode, IEdge, Graph, Matrix } from "./types";
3+
import { ID } from "@antv/graphlib";
4+
5+
/**
 * Builds an unweighted adjacency matrix for the graph.
 *
 * Row/column `i` corresponds to the i-th node returned by `graph.getAllNodes()`.
 * `matrix[s][t]` is set to 1 when an edge runs from node `s` to node `t`; when
 * `directed` is false the mirrored entry `matrix[t][s]` is set as well.
 * Rows start out empty, so pairs without an edge are left unset (sparse)
 * rather than filled with 0. Edge data (e.g. weights) is not consulted, and
 * parallel edges do not accumulate.
 *
 * @param graph The graph to read nodes and edges from.
 * @param directed Whether edges should only be recorded in their own direction.
 * @returns One (possibly sparse) row per node.
 * @throws Error if the graph returns no node list.
 */
function getAdjMatrix(graph: Graph, directed: boolean) {
  const nodes = graph.getAllNodes();
  if (!nodes) {
    throw new Error("invalid nodes data!");
  }

  // node id -> row/column index in the matrix
  const nodeIndex = new Map<string | number, number>();
  const matrix: Matrix[] = [];
  nodes.forEach((node, i) => {
    nodeIndex.set(node.id, i);
    matrix.push([]);
  });

  const edges = graph.getAllEdges();
  if (edges) {
    edges.forEach(({ source, target }) => {
      const sIndex = nodeIndex.get(source);
      const tIndex = nodeIndex.get(target);
      // Ignore edges whose endpoints are not part of the node list.
      if (sIndex === undefined || tIndex === undefined) return;
      matrix[sIndex][tIndex] = 1;
      if (!directed) {
        matrix[tIndex][sIndex] = 1;
      }
    });
  }
  return matrix;
}
38+
39+
/**
 * Performs label propagation clustering on the given graph.
 *
 * Every node starts in its own cluster. On each pass, every node sums the
 * adjacency values of its neighbors per neighbor cluster and moves to the
 * cluster with the highest sum. A node stays put only when its own cluster is
 * the single best candidate; if its own cluster ties with others, it moves to
 * one of the others. Ties between candidates are broken with `Math.random()`,
 * so results may differ between runs. Passes repeat until a pass moves no node
 * or `maxIteration` passes have run.
 *
 * The propagation step uses the entries of the matrix returned by
 * `getAdjMatrix`; `weightPropertyName` is only used when aggregating the
 * original edges into cluster edges.
 *
 * @param graph The graph object representing the nodes and edges.
 * @param directed A boolean indicating whether the graph is directed or not. Default is false.
 * @param weightPropertyName The name of the edge data property summed into the `weight` of each
 *   cluster edge. Missing, zero or non-numeric values count as 1. Default is 'weight'.
 * @param maxIteration The maximum number of iterations for label propagation. Default is 1000.
 * @returns The non-empty clusters, the aggregated cluster edges (one per ordered
 *   source-cluster/target-cluster pair, with summed `weight` and a `count` of merged edges),
 *   and the node-to-cluster mapping.
 */
export const labelPropagation = (
  graph: Graph,
  directed: boolean = false,
  weightPropertyName: string = "weight",
  maxIteration: number = 1000
): ClusterData => {
  const nodes = graph.getAllNodes();
  const edges = graph.getAllEdges();

  // Start with one singleton cluster per node.
  const clusters = new Map<string, { id: string; nodes: INode[] }>();
  const nodeToCluster = new Map<ID, string>();
  nodes.forEach((node) => {
    const cid: string = uniqueId();
    nodeToCluster.set(node.id, cid);
    clusters.set(cid, { id: cid, nodes: [node] });
  });

  // neighbors: node id -> (neighbor id -> adjacency value)
  const adjMatrix = getAdjMatrix(graph, directed);
  const neighbors = new Map<ID, Map<ID, number>>();
  adjMatrix.forEach((row, i) => {
    const rowNeighbors = new Map<ID, number>();
    row.forEach((entry, j) => {
      if (!entry) return;
      rowNeighbors.set(nodes[j].id, entry);
    });
    neighbors.set(nodes[i].id, rowNeighbors);
  });

  for (let iter = 0; iter < maxIteration; iter++) {
    let changed = false;
    nodes.forEach((node) => {
      const currentClusterId = nodeToCluster.get(node.id);
      const nodeNeighbors = neighbors.get(node.id);
      if (currentClusterId === undefined || !nodeNeighbors) return;

      // Total adjacency value towards each neighboring cluster.
      // Map.forEach passes (value, key): the weight comes first, the id second.
      const neighborClusters = new Map<string, number>();
      nodeNeighbors.forEach((neighborWeight, neighborId) => {
        const neighborClusterId = nodeToCluster.get(neighborId);
        if (neighborClusterId === undefined) return;
        const soFar = neighborClusters.get(neighborClusterId) ?? 0;
        neighborClusters.set(neighborClusterId, soFar + neighborWeight);
      });

      // Collect every cluster that reaches the maximum total.
      let maxWeight = -Infinity;
      let bestClusterIds: string[] = [];
      neighborClusters.forEach((clusterWeight, clusterId) => {
        if (clusterWeight > maxWeight) {
          maxWeight = clusterWeight;
          bestClusterIds = [clusterId];
        } else if (clusterWeight === maxWeight) {
          bestClusterIds.push(clusterId);
        }
      });

      // The node's own cluster is never a move target; with no other
      // candidate left the node stays where it is.
      const candidates = bestClusterIds.filter((id) => id !== currentClusterId);
      if (!candidates.length) return;

      const randomIdx = Math.floor(Math.random() * candidates.length);
      const bestCluster = clusters.get(candidates[randomIdx]);
      const selfCluster = clusters.get(currentClusterId);
      if (!bestCluster || !selfCluster) return;

      changed = true;
      // remove from origin cluster
      const position = selfCluster.nodes.indexOf(node);
      if (position >= 0) selfCluster.nodes.splice(position, 1);
      // move the node to the best cluster
      bestCluster.nodes.push(node);
      nodeToCluster.set(node.id, bestCluster.id);
    });
    if (!changed) break;
  }

  // Aggregate original edges into one edge per (source cluster, target cluster).
  const clusterEdges: IEdge[] = [];
  const clusterEdgeMap = new Map<string, IEdge>();
  // Declared outside the loop so every cluster edge gets a distinct id.
  let nextEdgeId = 0;
  edges.forEach((edge) => {
    const sourceClusterId = nodeToCluster.get(edge.source);
    const targetClusterId = nodeToCluster.get(edge.target);
    // Skip edges whose endpoints are not nodes of this graph.
    if (sourceClusterId === undefined || targetClusterId === undefined) return;

    const weight = Number(edge.data?.[weightPropertyName]) || 1;
    const key = `${sourceClusterId}---${targetClusterId}`;
    const existing = clusterEdgeMap.get(key);
    if (existing) {
      existing.data.weight = (existing.data.weight as number) + weight;
      existing.data.count = (existing.data.count as number) + 1;
      return;
    }
    const newEdge: IEdge = {
      id: nextEdgeId++,
      source: sourceClusterId,
      target: targetClusterId,
      data: {
        weight,
        count: 1,
      },
    };
    clusterEdgeMap.set(key, newEdge);
    clusterEdges.push(newEdge);
  });

  // Clusters that lost all their nodes are dropped from the result.
  const clustersArray = Array.from(clusters.values()).filter(
    (cluster) => cluster.nodes.length > 0
  );
  return {
    clusters: clustersArray,
    clusterEdges,
    nodeToCluster,
  };
};

0 commit comments

Comments
 (0)