Skip to content

Commit 646a2bc

Browse files
authored
Merge pull request #45 from antvis/feat/add-cosine-similarity
feat: add nodesCosineSimilarity算法-基于节点属性计算余弦相似度
2 parents 46e7616 + 51eb973 commit 646a2bc

10 files changed

+293
-63
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
# ChangeLog
2+
#### 0.1.17
3+
4+
- feat: add consine-similarity algorithm and nodes-consine-similarity algorithm;
25

36
#### 0.1.16
47

packages/graph/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@antv/algorithm",
3-
"version": "0.1.16",
3+
"version": "0.1.17",
44
"description": "graph algorithm",
55
"keywords": [
66
"graph",
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import Vector from './utils/vector';
2+
/**
3+
* cosine-similarity算法 计算余弦相似度
4+
* @param item 元素
5+
* @param targetItem 目标元素
6+
*/
7+
const cosineSimilarity = (
8+
item: number[],
9+
targetItem: number[],
10+
): number => {
11+
// 目标元素向量
12+
const targetItemVector = new Vector(targetItem);
13+
// 目标元素向量的模长
14+
const targetNodeNorm2 = targetItemVector.norm2();
15+
// 元素向量
16+
const itemVector = new Vector(item);
17+
// 元素向量的模长
18+
const itemNorm2 = itemVector.norm2();
19+
// 计算元素向量和目标元素向量的点积
20+
const dot = targetItemVector.dot(itemVector);
21+
const norm2Product = targetNodeNorm2 * itemNorm2;
22+
// 计算元素向量和目标元素向量的余弦相似度
23+
const cosineSimilarity = norm2Product ? dot / norm2Product : 0;
24+
return cosineSimilarity;
25+
}
26+
27+
export default cosineSimilarity;

packages/graph/src/index.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import labelPropagation from './label-propagation';
1212
import louvain from './louvain';
1313
import iLouvain from './i-louvain';
1414
import kCore from './k-core';
15+
import cosineSimilarity from './cosine-similarity';
16+
import nodesCosineSimilarity from './nodes-cosine-similarity';
1517
import minimumSpanningTree from './mts';
1618
import pageRank from './pageRank';
1719
import GADDI from './gaddi';
@@ -42,6 +44,8 @@ export {
4244
louvain,
4345
iLouvain,
4446
kCore,
47+
cosineSimilarity,
48+
nodesCosineSimilarity,
4549
minimumSpanningTree,
4650
pageRank,
4751
getNeighbors,
@@ -71,6 +75,8 @@ export default {
7175
louvain,
7276
iLouvain,
7377
kCore,
78+
cosineSimilarity,
79+
nodesCosineSimilarity,
7480
minimumSpanningTree,
7581
pageRank,
7682
getNeighbors,

packages/graph/src/louvain.ts

Lines changed: 2 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { clone } from '@antv/util';
22
import getAdjMatrix from './adjacent-matrix';
33
import { NodeConfig, ClusterData, GraphData, ClusterMap } from './types';
44
import Vector from './utils/vector';
5-
import { secondReg, dateReg } from './constants/time';
5+
import { getPropertyWeight } from './utils/node-properties';
66

77
const getModularity = (
88
nodes: NodeConfig[],
@@ -28,59 +28,6 @@ const getModularity = (
2828
return modularity;
2929
}
3030

31-
// 获取所有属性并排序
32-
const getAllSortProperties = (nodes: NodeConfig[] = []) => {
33-
const propertyKeyInfo = {};
34-
nodes.forEach(node => {
35-
Object.keys(node.properties).forEach(propertyKey => {
36-
// 目前过滤只保留可以转成数值型的或日期型的, todo: 统一转成one-hot特征向量
37-
if (!`${node.properties[propertyKey]}`.match(secondReg) &&
38-
!`${node.properties[propertyKey]}`.match(dateReg) &&
39-
isNaN(Number(node.properties[propertyKey])) || propertyKey === 'id') {
40-
if (propertyKeyInfo.hasOwnProperty(propertyKey)) {
41-
delete propertyKeyInfo[propertyKey];
42-
}
43-
return;
44-
}
45-
if (propertyKeyInfo.hasOwnProperty(propertyKey)) {
46-
propertyKeyInfo[propertyKey] += 1;
47-
} else {
48-
propertyKeyInfo[propertyKey] = 1;
49-
}
50-
})
51-
})
52-
53-
// 取top50的属性
54-
const sortKeys = Object.keys(propertyKeyInfo).sort((a,b)=>{
55-
return propertyKeyInfo[b] - propertyKeyInfo[a];
56-
});
57-
return sortKeys.length < 100 ? sortKeys : sortKeys.slice(0, 100);
58-
}
59-
60-
const processProperty = (properties, propertyKeys) => propertyKeys.map(key => {
61-
if (properties.hasOwnProperty(key)) {
62-
// 可以转成数值的直接转成数值
63-
if (!isNaN(Number(properties[key]))) {
64-
return Number(properties[key]);
65-
}
66-
// 时间型的转成时间戳
67-
if (properties[key].match(secondReg) || properties[key].match(dateReg)) {
68-
// @ts-ignore
69-
return Number(Date.parse(new Date(properties[key]))) / 1000;
70-
}
71-
}
72-
return 0;
73-
})
74-
75-
// 获取属性特征权重
76-
const getPropertyWeight = (propertyKeys, nodes) => {
77-
let allPropertiesWeight = [];
78-
for (let i = 0; i < nodes.length; i++) {
79-
allPropertiesWeight[i] = processProperty(nodes[i].properties, propertyKeys);
80-
}
81-
return allPropertiesWeight;
82-
}
83-
8431
// 模块惯性度,衡量属性相似度
8532
const getInertialModularity = (
8633
nodes: NodeConfig[] = [],
@@ -171,9 +118,8 @@ const louvain = (
171118
node.properties.nodeType = nodeTypeInfo.findIndex(nodeType => nodeType === node.nodeType);
172119
})
173120
}
174-
const propertyKeys = getAllSortProperties(nodes);
175121
// 所有节点属性特征向量集合
176-
allPropertiesWeight = getPropertyWeight(propertyKeys, nodes);
122+
allPropertiesWeight = getPropertyWeight(nodes);
177123
}
178124

179125
let uniqueId = 1;
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import { clone } from '@antv/util';
2+
import { NodeConfig } from './types';
3+
import { getPropertyWeight } from './utils/node-properties';
4+
import cosineSimilarity from './cosine-similarity';
5+
/**
6+
* nodes-cosine-similarity算法 基于节点属性计算余弦相似度(基于种子节点寻找相似节点)
7+
* @param nodes 图节点数据
8+
* @param seedNode 种子节点
9+
*/
10+
const nodesCosineSimilarity = (
11+
nodes: NodeConfig[] = [],
12+
seedNode: NodeConfig,
13+
): {
14+
allCosineSimilarity: number[],
15+
similarNodes: NodeConfig[],
16+
} => {
17+
const similarNodes = clone(nodes.filter(node => node.id !== seedNode.id));
18+
const seedNodeIndex = nodes.findIndex(node => node.id === seedNode.id);
19+
// 所有节点属性特征向量集合
20+
const allPropertiesWeight = getPropertyWeight(nodes);
21+
// 种子节点属性
22+
const seedNodeProperties = allPropertiesWeight[seedNodeIndex];
23+
24+
const allCosineSimilarity: number[] = [];
25+
similarNodes.forEach((node, index) => {
26+
if (node.id !== seedNode.id) {
27+
// 节点属性
28+
const nodeProperties = allPropertiesWeight[index];
29+
// 计算节点向量和种子节点向量的余弦相似度
30+
const cosineSimilarityValue = cosineSimilarity(nodeProperties, seedNodeProperties);
31+
allCosineSimilarity.push(cosineSimilarityValue);
32+
node.cosineSimilarity = cosineSimilarityValue;
33+
}
34+
});
35+
36+
// 将返回的节点按照余弦相似度大小排序
37+
similarNodes.sort((a, b) => b.cosineSimilarity - a.cosineSimilarity);
38+
return { allCosineSimilarity, similarNodes };
39+
}
40+
41+
export default nodesCosineSimilarity;
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import { NodeConfig } from '../types';
2+
import { secondReg, dateReg } from '../constants/time';
3+
4+
// 获取所有属性并排序
5+
export const getAllSortProperties = (nodes: NodeConfig[] = [], n: number = 100) => {
6+
const propertyKeyInfo = {};
7+
nodes.forEach(node => {
8+
if (!node.properties) {
9+
return;
10+
}
11+
Object.keys(node.properties).forEach(propertyKey => {
12+
// 目前过滤只保留可以转成数值型的或日期型的, todo: 统一转成one-hot特征向量或者embedding
13+
if (propertyKey === 'id' || !`${node.properties[propertyKey]}`.match(secondReg) &&
14+
!`${node.properties[propertyKey]}`.match(dateReg) &&
15+
isNaN(Number(node.properties[propertyKey]))) {
16+
if (propertyKeyInfo.hasOwnProperty(propertyKey)) {
17+
delete propertyKeyInfo[propertyKey];
18+
}
19+
return;
20+
}
21+
if (propertyKeyInfo.hasOwnProperty(propertyKey)) {
22+
propertyKeyInfo[propertyKey] += 1;
23+
} else {
24+
propertyKeyInfo[propertyKey] = 1;
25+
}
26+
})
27+
})
28+
29+
// 取top50的属性
30+
const sortKeys = Object.keys(propertyKeyInfo).sort((a,b) => propertyKeyInfo[b] - propertyKeyInfo[a]);
31+
return sortKeys.length < n ? sortKeys : sortKeys.slice(0, n);
32+
}
33+
34+
const processProperty = (properties, propertyKeys) => propertyKeys.map(key => {
35+
if (properties.hasOwnProperty(key)) {
36+
// 可以转成数值的直接转成数值
37+
if (!isNaN(Number(properties[key]))) {
38+
return Number(properties[key]);
39+
}
40+
// 时间型的转成时间戳
41+
if (properties[key].match(secondReg) || properties[key].match(dateReg)) {
42+
// @ts-ignore
43+
return Number(Date.parse(new Date(properties[key]))) / 1000;
44+
}
45+
}
46+
return 0;
47+
})
48+
49+
// 获取属性特征权重
50+
export const getPropertyWeight = (nodes: NodeConfig[]) => {
51+
const propertyKeys = getAllSortProperties(nodes);
52+
let allPropertiesWeight = [];
53+
for (let i = 0; i < nodes.length; i++) {
54+
allPropertiesWeight[i] = processProperty(nodes[i].properties, propertyKeys);
55+
}
56+
return allPropertiesWeight;
57+
}
58+
59+
export default {
60+
getAllSortProperties,
61+
getPropertyWeight,
62+
}

packages/graph/src/utils/vector.ts

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class Vector {
2121
}
2222
if (this.arr.length === otherArr.length) {
2323
let res = [];
24-
for(let key in this.arr) {
24+
for (let key in this.arr) {
2525
res[key] = this.arr[key] + otherArr[key];
2626
}
2727
return new Vector(res);
@@ -38,7 +38,7 @@ class Vector {
3838
}
3939
if (this.arr.length === otherArr.length) {
4040
let res = [];
41-
for(let key in this.arr) {
41+
for (let key in this.arr) {
4242
res[key] = this.arr[key] - otherArr[key];
4343
}
4444
return new Vector(res);
@@ -47,15 +47,15 @@ class Vector {
4747

4848
avg(length) {
4949
let res = [];
50-
for(let key in this.arr) {
50+
for (let key in this.arr) {
5151
res[key] = this.arr[key] / length;
5252
}
5353
return new Vector(res);
5454
}
5555

5656
negate() {
5757
let res = [];
58-
for(let key in this.arr) {
58+
for (let key in this.arr) {
5959
res[key] = - this.arr[key];
6060
}
6161
return new Vector(res);
@@ -69,7 +69,7 @@ class Vector {
6969
}
7070
if (this.arr.length === otherArr.length) {
7171
let res = 0;
72-
for(let key in this.arr) {
72+
for (let key in this.arr) {
7373
res += Math.pow(this.arr[key] - otherVector.arr[key], 2);
7474
}
7575
return res;
@@ -83,11 +83,38 @@ class Vector {
8383
cloneArr.sort((a, b) => a - b);
8484
const max = cloneArr[cloneArr.length - 1];
8585
const min = cloneArr[0];
86-
for(let key in this.arr) {
86+
for (let key in this.arr) {
8787
res[key] = (this.arr[key] - min) / (max - min);
8888
}
8989
return new Vector(res);
9090
}
91+
92+
// 2范数 or 模长
93+
norm2() {
94+
if (!this.arr?.length) {
95+
return 0;
96+
}
97+
let res = 0;
98+
for (let key in this.arr) {
99+
res += Math.pow(this.arr[key], 2);
100+
}
101+
return Math.sqrt(res);
102+
}
103+
104+
// 两个向量的点积
105+
dot(otherVector) {
106+
const otherArr = otherVector.arr;
107+
if (!this.arr?.length || !otherArr?.length) {
108+
return 0;
109+
}
110+
if (this.arr.length === otherArr.length) {
111+
let res = 0;
112+
for (let key in this.arr) {
113+
res += this.arr[key] * otherVector.arr[key];
114+
}
115+
return res;
116+
}
117+
}
91118
}
92119

93120
export default Vector;
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import { cosineSimilarity } from '../../src';
2+
3+
describe('cosineSimilarity abnormal demo: ', () => {
4+
it('item contains only zeros: ', () => {
5+
const item = [0, 0, 0];
6+
const targetTtem = [3, 1, 1];
7+
const cosineSimilarityValue = cosineSimilarity(item, targetTtem);
8+
expect(cosineSimilarityValue).toBe(0);
9+
});
10+
it('targetTtem contains only zeros: ', () => {
11+
const item = [3, 5, 2];
12+
const targetTtem = [0, 0, 0];
13+
const cosineSimilarityValue = cosineSimilarity(item, targetTtem);
14+
expect(cosineSimilarityValue).toBe(0);
15+
});
16+
it('item and targetTtem both contains only zeros: ', () => {
17+
const item = [0, 0, 0];
18+
const targetTtem = [0, 0, 0];
19+
const cosineSimilarityValue = cosineSimilarity(item, targetTtem);
20+
expect(cosineSimilarityValue).toBe(0);
21+
});
22+
});
23+
24+
describe('cosineSimilarity normal demo: ', () => {
25+
it('demo similar: ', () => {
26+
const item = [30, 0, 100];
27+
const targetTtem = [32, 1, 120];
28+
const cosineSimilarityValue = cosineSimilarity(item, targetTtem);
29+
expect(cosineSimilarityValue).toBeGreaterThanOrEqual(0);
30+
expect(cosineSimilarityValue).toBeLessThan(1);
31+
expect(Number(cosineSimilarityValue.toFixed(3))).toBe(0.999);
32+
});
33+
it('demo dissimilar: ', () => {
34+
const item = [10, 300, 2];
35+
const targetTtem = [1, 2, 30];
36+
const cosineSimilarityValue = cosineSimilarity(item, targetTtem);
37+
expect(cosineSimilarityValue).toBeGreaterThanOrEqual(0);
38+
expect(cosineSimilarityValue).toBeLessThan(1);
39+
expect(Number(cosineSimilarityValue.toFixed(3))).toBe(0.074);
40+
});
41+
});

0 commit comments

Comments
 (0)