Skip to content

Commit 544a1c3

Browse files
authored
Merge pull request #46 from antvis/feat/add-one-hot
feat: add one-hot data preprocessing
2 parents 646a2bc + 0096424 commit 544a1c3

8 files changed

+142
-8
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
# ChangeLog
2+
3+
#### 0.1.18
4+
5+
- feat: add one-hot data preprocessing
6+
27
#### 0.1.17
38

49
- feat: add cosine-similarity algorithm and nodes-cosine-similarity algorithm;

packages/graph/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@antv/algorithm",
3-
"version": "0.1.17",
3+
"version": "0.1.18",
44
"description": "graph algorithm",
55
"keywords": [
66
"graph",

packages/graph/src/louvain.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ import { clone } from '@antv/util';
22
import getAdjMatrix from './adjacent-matrix';
33
import { NodeConfig, ClusterData, GraphData, ClusterMap } from './types';
44
import Vector from './utils/vector';
5-
import { getPropertyWeight } from './utils/node-properties';
5+
import { getAllProperties } from './utils/node-properties';
6+
import { oneHot } from './utils/data-preprocessing';
67

78
const getModularity = (
89
nodes: NodeConfig[],
@@ -118,8 +119,10 @@ const louvain = (
118119
node.properties.nodeType = nodeTypeInfo.findIndex(nodeType => nodeType === node.nodeType);
119120
})
120121
}
121-
// 所有节点属性特征向量集合
122-
allPropertiesWeight = getPropertyWeight(nodes);
122+
// 所有节点属性集合
123+
const properties = getAllProperties(nodes);
124+
// 所有节点属性one-hot特征向量集合
125+
allPropertiesWeight = oneHot(properties);
123126
}
124127

125128
let uniqueId = 1;

packages/graph/src/nodes-cosine-similarity.ts

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,30 @@
11
import { clone } from '@antv/util';
22
import { NodeConfig } from './types';
3-
import { getPropertyWeight } from './utils/node-properties';
3+
import { getAllProperties } from './utils/node-properties';
4+
import { oneHot } from './utils/data-preprocessing';
45
import cosineSimilarity from './cosine-similarity';
56
/**
67
* nodes-cosine-similarity算法 基于节点属性计算余弦相似度(基于种子节点寻找相似节点)
78
* @param nodes 图节点数据
89
* @param seedNode 种子节点
10+
* @param involvedKeys 参与计算的key集合
11+
* @param uninvolvedKeys 不参与计算的key集合
912
*/
1013
const nodesCosineSimilarity = (
1114
nodes: NodeConfig[] = [],
1215
seedNode: NodeConfig,
16+
involvedKeys: string[] = [],
17+
uninvolvedKeys: string[] = [],
1318
): {
1419
allCosineSimilarity: number[],
1520
similarNodes: NodeConfig[],
1621
} => {
1722
const similarNodes = clone(nodes.filter(node => node.id !== seedNode.id));
1823
const seedNodeIndex = nodes.findIndex(node => node.id === seedNode.id);
19-
// 所有节点属性特征向量集合
20-
const allPropertiesWeight = getPropertyWeight(nodes);
24+
// 所有节点属性集合
25+
const properties = getAllProperties(nodes);
26+
// 所有节点属性one-hot特征向量集合
27+
const allPropertiesWeight = oneHot(properties, involvedKeys, uninvolvedKeys);
2128
// 种子节点属性
2229
const seedNodeProperties = allPropertiesWeight[seedNodeIndex];
2330

packages/graph/src/types.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ export interface DegreeType {
4949
}
5050
}
5151

52+
export interface PlainObject {
53+
[key: string]: any;
54+
}
5255

5356
export interface IAlgorithm {
5457
getAdjMatrix: (graphData: GraphData, directed?: boolean) => Matrix[],
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import { isEmpty, uniq } from '@antv/util';
2+
import { PlainObject } from '../types';
3+
4+
/**
 * Collects, for every relevant key, the distinct non-empty values found in the data set.
 *
 * A value counts as empty when it is `undefined` or the empty string. Keys whose
 * values are all empty are left out of the result.
 *
 * @param dataList records to scan; `null`/`undefined` entries are ignored
 * @param involvedKeys when non-empty, only these keys are considered; otherwise every
 *   key that appears in the data is considered
 * @param uninvolvedKeys keys that are always left out of the result
 * @returns a map from key to that key's distinct values, in order of first appearance
 */
export const getAllKeyValueMap = (
  dataList: PlainObject[],
  involvedKeys?: string[],
  uninvolvedKeys?: string[],
): Record<string, any[]> => {
  // Drop missing rows up front so the scans below never dereference null/undefined.
  const records = (dataList || []).filter(data => data !== null && data !== undefined);
  const excludedKeys = new Set<string>(uninvolvedKeys || []);

  // A Set keeps first-seen order and de-duplicates without re-copying an array per row.
  const candidateKeys = new Set<string>();
  if (involvedKeys?.length) {
    // Explicitly requested keys take precedence over whatever the data contains.
    involvedKeys.forEach(key => candidateKeys.add(key));
  } else {
    records.forEach(data => {
      Object.keys(data).forEach(key => candidateKeys.add(key));
    });
  }

  const allKeyValueMap: Record<string, any[]> = {};
  candidateKeys.forEach(key => {
    // Skip excluded keys before scanning the data for their values.
    if (excludedKeys.has(key)) {
      return;
    }
    const distinctValues = new Set<unknown>();
    records.forEach(data => {
      const value = data[key];
      if (value !== undefined && value !== '') {
        distinctValues.add(value);
      }
    });
    if (distinctValues.size) {
      allKeyValueMap[key] = Array.from(distinctValues);
    }
  });

  return allKeyValueMap;
}

/**
 * One-hot encodes a data set.
 *
 * Each retained key contributes one slot per distinct value of that key. A row gets
 * a `1` in the slot matching its own value and `0` everywhere else, so a row whose
 * value for a key is missing or empty gets only zeros for that key.
 *
 * @param dataList records to encode; `null`/`undefined` entries produce an all-zero row
 * @param involvedKeys when non-empty, only these keys are encoded
 * @param uninvolvedKeys keys that are never encoded
 * @returns one code vector per input row, at the same index as the row
 */
export const oneHot = (
  dataList: PlainObject[],
  involvedKeys?: string[],
  uninvolvedKeys?: string[],
): number[][] => {
  // Distinct values per key define the layout shared by every code vector.
  const allKeyValueMap = getAllKeyValueMap(dataList, involvedKeys, uninvolvedKeys);
  const encodedKeys = Object.keys(allKeyValueMap);

  return (dataList || []).map(data => {
    const code: number[] = [];
    encodedKeys.forEach(key => {
      // A missing row has no value for any key, so it matches no slot.
      const keyValue = data === null || data === undefined ? undefined : data[key];
      allKeyValueMap[key].forEach(value => {
        code.push(value === keyValue ? 1 : 0);
      });
    });
    return code;
  });
}
70+
71+
export default {
72+
getAllKeyValueMap,
73+
oneHot,
74+
}

packages/graph/src/utils/node-properties.ts

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,20 @@ export const getPropertyWeight = (nodes: NodeConfig[]) => {
5656
return allPropertiesWeight;
5757
}
5858

59+
/**
 * Collects the value stored under `key` on each node.
 *
 * Nodes without a value under `key` are skipped, so the result can be shorter than
 * `nodes`.
 * NOTE(review): callers that index the result by node position should confirm that
 * every node carries `key`, otherwise positions shift.
 *
 * @param nodes graph nodes to read from; `null`/`undefined` entries are skipped
 * @param key name of the node field that holds the properties; defaults to `properties`
 * @returns the `key` values of the nodes that have one, in node order
 */
export const getAllProperties = (nodes: Array<Record<string, any>>, key = 'properties'): any[] => {
  const allProperties: any[] = [];
  (nodes || []).forEach(node => {
    // Guard on the requested field, not the hard-coded `properties`, so that a
    // custom key never yields `undefined` entries or drops nodes that do have it.
    if (!node || !node[key]) {
      return;
    }
    allProperties.push(node[key]);
  });
  return allProperties;
}
70+
5971
export default {
6072
getAllSortProperties,
6173
getPropertyWeight,
62-
}
74+
getAllProperties
75+
}

packages/graph/tests/unit/nodesCosineSimilarity-spec.ts

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,33 @@ describe('nodesCosineSimilarity normal demo', () => {
7474
expect(data).toBeLessThanOrEqual(1);
7575
})
7676
});
77+
78+
79+
it('demo use involvedKeys: ', () => {
80+
const involvedKeys = ['amount', 'city'];
81+
const { nodes } = propertiesGraphData;
82+
const { allCosineSimilarity, similarNodes } = nodesCosineSimilarity(nodes as NodeConfig[], nodes[16], involvedKeys);
83+
expect(allCosineSimilarity.length).toBe(16);
84+
expect(similarNodes.length).toBe(16);
85+
allCosineSimilarity.forEach(data => {
86+
expect(data).toBeGreaterThanOrEqual(0);
87+
expect(data).toBeLessThanOrEqual(1);
88+
})
89+
expect(Number(Math.max.apply(null, allCosineSimilarity).toString().match(/^\d+(?:\.\d{0,2})?/))).toBe(0.99);
90+
expect(similarNodes[0].id).toBe('node-11');
91+
});
92+
93+
it('demo use uninvolvedKeys: ', () => {
94+
const uninvolvedKeys = ['amount'];
95+
const { nodes } = propertiesGraphData;
96+
const { allCosineSimilarity, similarNodes } = nodesCosineSimilarity(nodes as NodeConfig[], nodes[16], [], uninvolvedKeys);
97+
expect(allCosineSimilarity.length).toBe(16);
98+
expect(similarNodes.length).toBe(16);
99+
allCosineSimilarity.forEach(data => {
100+
expect(data).toBeGreaterThanOrEqual(0);
101+
expect(data).toBeLessThanOrEqual(1);
102+
})
103+
expect(Number(Math.max.apply(null, allCosineSimilarity).toString().match(/^\d+(?:\.\d{0,2})?/))).toBe(0.66);
104+
expect(similarNodes[0].id).toBe('node-11');
105+
});
77106
});

0 commit comments

Comments
 (0)