Skip to content

Commit 6a7f24c

Browse files
committed
concatenate2d (#14)
1 parent 4faf17e commit 6a7f24c

File tree

4 files changed

+186
-2
lines changed

4 files changed

+186
-2
lines changed

src/merge/Concatenate.js

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
* @author syt123450 / https://github.com/syt123450
33
*/
44

5+
import { MergedLayer2d } from "../layer/abstract/MergedLayer2d";
56
import { MergedLayer3d } from "../layer/abstract/MergedLayer3d";
67

78
function Concatenate( layerList, config ) {
@@ -44,6 +45,14 @@ function Concatenate( layerList, config ) {
4445

4546
} else if ( layerList[ 0 ].layerDimension === 2 ) {
4647

48+
return new MergedLayer2d( {
49+
50+
operator: operatorType,
51+
mergedElements: layerList,
52+
userConfig: userConfig
53+
54+
} );
55+
4756
} else if ( layerList[ 0 ].layerDimension === 3 ) {
4857

4958
return new MergedLayer3d( {

src/merge/factory/StrategyFactory.js

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import { Subtract2d } from "../strategy/Subtract2d";
1414
import { Maximum2d } from "../strategy/Maximum2d";
1515
import { Average2d } from "../strategy/Average2d";
1616
import { Multiply2d } from "../strategy/Multiply2d";
17+
import { Concatenate2d } from "../strategy/Concatenate2d";
1718

1819
let StrategyFactory = ( function() {
1920

@@ -59,7 +60,7 @@ let StrategyFactory = ( function() {
5960

6061
} else if ( operator === "concatenate" ) {
6162

62-
63+
return new Concatenate2d( mergedElements );
6364

6465
} else if ( operator === "subtract" ) {
6566

src/merge/strategy/Concatenate2d.js

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
import { Strategy2d } from "../abstract/Strategy2d";
2+
3+
function Concatenate2d( mergedElements ) {
4+
5+
Strategy2d.call( this, mergedElements );
6+
7+
this.strategyType = "Concatenate2d";
8+
9+
}
10+
11+
Concatenate2d.prototype = Object.assign( Object.create( Strategy2d.prototype ), {
12+
13+
validate: function() {
14+
15+
let inputShape = this.mergedElements[ 0 ].outputShape;
16+
17+
for ( let i = 0; i < this.mergedElements.length; i ++ ) {
18+
19+
let layerShape = this.mergedElements[ i ].outputShape;
20+
21+
if ( layerShape[ 0 ] !== inputShape[ 0 ] ) {
22+
23+
return false;
24+
25+
}
26+
27+
}
28+
29+
return true;
30+
31+
},
32+
33+
getOutputShape: function() {
34+
35+
let width = this.mergedElements[ 0 ].outputShape[ 0 ];
36+
let depth = 0;
37+
38+
for (let i = 0; i < this.mergedElements.length; i ++) {
39+
40+
depth += this.mergedElements[ i ].outputShape[ 1 ];
41+
42+
}
43+
44+
return [ width, depth ];
45+
46+
},
47+
48+
getRelativeElements: function( selectedElement ) {
49+
50+
let curveElements = [];
51+
let straightElements = [];
52+
53+
if ( selectedElement.elementType === "aggregationElement" ) {
54+
55+
let request = {
56+
57+
all: true
58+
59+
};
60+
61+
for ( let i = 0; i < this.mergedElements.length; i++ ) {
62+
63+
let relativeResult = this.mergedElements[ i ].provideRelativeElements( request );
64+
let relativeElements = relativeResult.elementList;
65+
66+
if ( this.mergedElements[ i ].layerIndex === this.layerIndex - 1 ) {
67+
68+
for ( let j = 0; j < relativeElements.length; j ++ ) {
69+
70+
straightElements.push( relativeElements[ j ] );
71+
72+
}
73+
74+
} else {
75+
76+
if ( relativeResult.isOpen ) {
77+
78+
for ( let j = 0; j < relativeElements.length; j ++ ) {
79+
80+
straightElements.push( relativeElements[ j ] );
81+
82+
}
83+
84+
} else {
85+
86+
for ( let j = 0; j < relativeElements.length; j ++ ) {
87+
88+
curveElements.push( relativeElements[ j ] );
89+
90+
}
91+
92+
}
93+
94+
}
95+
96+
}
97+
98+
} else if ( selectedElement.elementType === "gridLine" ) {
99+
100+
let gridIndex = selectedElement.gridIndex;
101+
102+
let relativeLayer;
103+
104+
for ( let i = 0; i < this.mergedElements.length; i ++ ) {
105+
106+
let layerDepth = this.mergedElements[ i ].outputShape[ 1 ];
107+
108+
if ( layerDepth > gridIndex ) {
109+
110+
relativeLayer = this.mergedElements[ i ];
111+
break;
112+
113+
} else {
114+
115+
gridIndex -= layerDepth;
116+
117+
}
118+
119+
}
120+
121+
let request = {
122+
123+
index: gridIndex
124+
125+
};
126+
127+
let relativeResult = relativeLayer.provideRelativeElements( request );
128+
let relativeElements = relativeResult.elementList;
129+
130+
if ( relativeLayer.layerIndex === this.layerIndex - 1 ) {
131+
132+
for ( let i = 0; i < relativeElements.length; i ++ ) {
133+
134+
straightElements.push( relativeElements[ i ] );
135+
136+
}
137+
138+
} else {
139+
140+
if ( relativeResult.isOpen ) {
141+
142+
for ( let i = 0; i < relativeElements.length; i ++ ) {
143+
144+
straightElements.push( relativeElements[ i ] );
145+
146+
}
147+
148+
} else {
149+
150+
for ( let i = 0; i < relativeElements.length; i ++ ) {
151+
152+
curveElements.push( relativeElements[ i ] );
153+
154+
}
155+
156+
}
157+
158+
}
159+
160+
161+
}
162+
163+
return {
164+
165+
straight: straightElements,
166+
curve: curveElements
167+
168+
};
169+
170+
}
171+
172+
} );
173+
174+
export { Concatenate2d };

test/test.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474

7575
model.add( conv1d2 );
7676

77-
let addLayer = TSP.layers.Multiply( [ conv1d1, conv1d2 ] );
77+
let addLayer = TSP.layers.Concatenate( [ conv1d1, conv1d2 ] );
7878

7979
model.add( addLayer );
8080

0 commit comments

Comments
 (0)