Skip to content

Commit ab32168

Browse files
Merge pull request #9651 from IoannisPanagiotas/node2vec-consistency-conc1
Guarantee node2vec is deterministic for concurrency 1 (and set seed)
2 parents f61be96 + 68eb654 commit ab32168

File tree

8 files changed

+53
-34
lines changed

8 files changed

+53
-34
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/NegativeSampleProducer.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,25 @@
2121

2222
import org.neo4j.gds.collections.ha.HugeLongArray;
2323

24-
import java.util.concurrent.ThreadLocalRandom;
24+
import java.util.SplittableRandom;
2525

2626
public class NegativeSampleProducer {
2727

2828
private final HugeLongArray contextNodeDistribution;
2929
private final long cumulativeProbability;
30+
private SplittableRandom splittableRandom;
3031

3132
public NegativeSampleProducer(
32-
HugeLongArray contextNodeDistribution
33+
HugeLongArray contextNodeDistribution,
34+
long randomSeed
3335
) {
36+
this.splittableRandom = new SplittableRandom(randomSeed);
3437
this.contextNodeDistribution = contextNodeDistribution;
3538
this.cumulativeProbability = contextNodeDistribution.get(contextNodeDistribution.size() - 1);
3639
}
3740

3841
public long next() {
39-
long index = contextNodeDistribution.binarySearch(ThreadLocalRandom.current().nextLong(cumulativeProbability));
40-
42+
long index = contextNodeDistribution.binarySearch(splittableRandom.nextLong(cumulativeProbability));
4143
if (index < contextNodeDistribution.size() - 1) {
4244
index++;
4345
}

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecModel.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
*/
2020
package org.neo4j.gds.embeddings.node2vec;
2121

22+
import org.neo4j.gds.collections.ha.HugeLongArray;
2223
import org.neo4j.gds.collections.ha.HugeObjectArray;
2324
import org.neo4j.gds.core.concurrency.Concurrency;
2425
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
@@ -40,7 +41,6 @@
4041

4142
public class Node2VecModel {
4243

43-
private final NegativeSampleProducer negativeSamples;
4444

4545
private final HugeObjectArray<FloatVector> centerEmbeddings;
4646
private final HugeObjectArray<FloatVector> contextEmbeddings;
@@ -114,7 +114,6 @@ public class Node2VecModel {
114114
this.walks = walks;
115115
this.randomWalkProbabilities = randomWalkProbabilities;
116116
this.progressTracker = progressTracker;
117-
this.negativeSamples = new NegativeSampleProducer(randomWalkProbabilities.negativeSamplingDistribution());
118117
this.randomSeed = maybeRandomSeed.orElseGet(() -> new SplittableRandom().nextLong());
119118

120119
var random = new Random();
@@ -145,18 +144,20 @@ Node2VecResult train() {
145144
var positiveSampleProducer = new PositiveSampleProducer(
146145
walks.iterator(partition.startNode(), partition.nodeCount()),
147146
randomWalkProbabilities.positiveSamplingProbabilities(),
148-
windowSize
147+
windowSize,
148+
Optional.of(randomSeed)
149149
);
150150

151151
return new TrainingTask(
152152
centerEmbeddings,
153153
contextEmbeddings,
154154
positiveSampleProducer,
155-
negativeSamples,
155+
randomWalkProbabilities.negativeSamplingDistribution(),
156156
learningRate,
157157
negativeSamplingRate,
158158
embeddingDimension,
159-
progressTracker
159+
progressTracker,
160+
randomSeed
160161
);
161162
}
162163
);
@@ -226,16 +227,17 @@ private TrainingTask(
226227
HugeObjectArray<FloatVector> centerEmbeddings,
227228
HugeObjectArray<FloatVector> contextEmbeddings,
228229
PositiveSampleProducer positiveSampleProducer,
229-
NegativeSampleProducer negativeSampleProducer,
230+
HugeLongArray negativeSamples,
230231
float learningRate,
231232
int negativeSamplingRate,
232233
int embeddingDimensions,
233-
ProgressTracker progressTracker
234+
ProgressTracker progressTracker,
235+
long randomSeed
234236
) {
235237
this.centerEmbeddings = centerEmbeddings;
236238
this.contextEmbeddings = contextEmbeddings;
237239
this.positiveSampleProducer = positiveSampleProducer;
238-
this.negativeSampleProducer = negativeSampleProducer;
240+
this.negativeSampleProducer = new NegativeSampleProducer(negativeSamples, randomSeed + Thread.currentThread().getId());
239241
this.learningRate = learningRate;
240242
this.negativeSamplingRate = negativeSamplingRate;
241243

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/PositiveSampleProducer.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import org.neo4j.gds.collections.ha.HugeDoubleArray;
2323

2424
import java.util.Iterator;
25+
import java.util.Optional;
26+
import java.util.SplittableRandom;
2527
import java.util.concurrent.ThreadLocalRandom;
2628

2729
import static org.neo4j.gds.mem.BitUtil.ceilDiv;
@@ -40,11 +42,13 @@ public class PositiveSampleProducer {
4042
private int contextWordIndex;
4143
private int currentWindowStart;
4244
private int currentWindowEnd;
45+
private SplittableRandom probabilitySupplier;
4346

4447
PositiveSampleProducer(
4548
Iterator<long[]> walks,
4649
HugeDoubleArray samplingProbabilities,
47-
int windowSize
50+
int windowSize,
51+
Optional<Long> maybeRandomSeed
4852
) {
4953
this.walks = walks;
5054
this.samplingProbabilities = samplingProbabilities;
@@ -55,6 +59,10 @@ public class PositiveSampleProducer {
5559
this.currentWalk = new long[0];
5660
this.centerWordIndex = -1;
5761
this.contextWordIndex = 1;
62+
probabilitySupplier = maybeRandomSeed
63+
.map(seed -> new SplittableRandom(Thread.currentThread().getId() + seed))
64+
.orElseGet(() -> new SplittableRandom(ThreadLocalRandom.current().nextLong()));
65+
5866
}
5967

6068
public boolean next(long[] buffer) {
@@ -134,7 +142,7 @@ private int filter(long[] walk) {
134142
}
135143

136144
private boolean shouldPickNode(long nodeId) {
137-
return ThreadLocalRandom.current().nextDouble(0, 1) < samplingProbabilities.get(nodeId);
145+
return probabilitySupplier.nextDouble(0,1) < samplingProbabilities.get(nodeId);
138146
}
139147

140148
// We need to adjust the window size for a given center word to ignore filtered nodes that might occur in the window

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/NegativeSampleProducerTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ void shouldProduceSamplesAccordingToNodeDistribution() {
4747

4848
RandomWalkProbabilities probabilityComputer = builder.build();
4949

50-
var sampler = new NegativeSampleProducer(probabilityComputer.negativeSamplingDistribution());
50+
var sampler = new NegativeSampleProducer(probabilityComputer.negativeSamplingDistribution(),0);
5151

5252
Map<Long, Integer> distribution = IntStream
5353
.range(0, 1300)

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecMemoryEstimateDefinitionTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void shouldEstimateMemory() {
4141

4242
MemoryEstimationAssert.assertThat(memoryEstimation)
4343
.memoryRange(1000, new Concurrency(1))
44-
.hasSameMinAndMaxEqualTo(7_688_464L);
44+
.hasSameMinAndMaxEqualTo(7688456L);
4545
}
4646

4747
}

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecModelTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
*/
2020
package org.neo4j.gds.embeddings.node2vec;
2121

22-
import org.junit.jupiter.api.Disabled;
2322
import org.junit.jupiter.api.Test;
2423
import org.junit.jupiter.params.ParameterizedTest;
2524
import org.junit.jupiter.params.provider.ValueSource;
@@ -141,7 +140,8 @@ void testModel() {
141140
);
142141
}
143142

144-
@Disabled("The order of the randomWalks + its usage in the training is not deterministic yet.")
143+
// @Disabled("The order of the randomWalks + its usage in the training is not deterministic yet.")
144+
//We can only guarantee consstency for concurrency 1
145145
@ParameterizedTest
146146
@ValueSource(ints = {0, 1, 4})
147147
void randomSeed(int iterations) {
@@ -168,7 +168,7 @@ void randomSeed(int iterations) {
168168
nodeId -> nodeId,
169169
nodeCount,
170170
trainParameters,
171-
new Concurrency(4),
171+
new Concurrency(1),
172172
Optional.of(1337L),
173173
walks,
174174
probabilitiesBuilder.build(),
@@ -179,7 +179,7 @@ void randomSeed(int iterations) {
179179
nodeId -> nodeId,
180180
nodeCount,
181181
trainParameters,
182-
new Concurrency(4),
182+
new Concurrency(1),
183183
Optional.of(1337L),
184184
walks,
185185
probabilitiesBuilder.build(),

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,16 @@
1919
*/
2020
package org.neo4j.gds.embeddings.node2vec;
2121

22-
import org.assertj.core.api.SoftAssertions;
2322
import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension;
2423
import org.assertj.core.data.Offset;
25-
import org.junit.jupiter.api.Disabled;
2624
import org.junit.jupiter.api.Test;
2725
import org.junit.jupiter.api.extension.ExtendWith;
2826
import org.junit.jupiter.params.ParameterizedTest;
2927
import org.junit.jupiter.params.provider.Arguments;
3028
import org.junit.jupiter.params.provider.CsvSource;
3129
import org.junit.jupiter.params.provider.EnumSource;
3230
import org.junit.jupiter.params.provider.MethodSource;
31+
import org.junit.jupiter.params.provider.ValueSource;
3332
import org.neo4j.gds.NodeLabel;
3433
import org.neo4j.gds.Orientation;
3534
import org.neo4j.gds.RelationshipType;
@@ -231,9 +230,10 @@ void failOnNegativeWeights() {
231230

232231
}
233232

234-
@Disabled("The order of the randomWalks + its usage in the training is not deterministic yet.")
235-
@Test
236-
void randomSeed(SoftAssertions softly) {
233+
//"The order of the randomWalks + its usage in the training is not deterministic yet. Can guarantee only for concurrency 1")
234+
@ParameterizedTest
235+
@ValueSource(ints= {1})
236+
void randomSeed(int concurrency) {
237237

238238

239239
int embeddingDimension = 2;
@@ -242,7 +242,7 @@ void randomSeed(SoftAssertions softly) {
242242

243243
var embeddings = new Node2Vec(
244244
graph,
245-
new Concurrency(4),
245+
new Concurrency(concurrency),
246246
NO_SOURCE_NODES,
247247
Optional.of(1337L),
248248
1000,
@@ -253,7 +253,7 @@ void randomSeed(SoftAssertions softly) {
253253

254254
var otherEmbeddings = new Node2Vec(
255255
graph,
256-
new Concurrency(4),
256+
new Concurrency(concurrency),
257257
NO_SOURCE_NODES,
258258
Optional.of(1337L),
259259
1000,
@@ -263,7 +263,7 @@ void randomSeed(SoftAssertions softly) {
263263
).compute().embeddings();
264264

265265
for (long node = 0; node < graph.nodeCount(); node++) {
266-
softly.assertThat(otherEmbeddings.get(node)).isEqualTo(embeddings.get(node));
266+
assertThat(otherEmbeddings.get(node)).isEqualTo(embeddings.get(node));
267267
}
268268
}
269269

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/PositiveSampleProducerTest.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import java.util.ArrayList;
3030
import java.util.Collection;
3131
import java.util.List;
32+
import java.util.Optional;
3233
import java.util.stream.LongStream;
3334
import java.util.stream.Stream;
3435

@@ -59,7 +60,8 @@ void doesNotCauseStackOverflow() {
5960
var sampleProducer = new PositiveSampleProducer(
6061
walks.iterator(0, nbrOfWalks),
6162
HugeDoubleArray.of(LongStream.range(0, nbrOfWalks).mapToDouble((l) -> 1.0).toArray()),
62-
10
63+
10,
64+
Optional.empty()
6365
);
6466

6567
var counter = 0L;
@@ -86,7 +88,8 @@ void doesNotCauseStackOverflowDueToBadLuck() {
8688
var sampleProducer = new PositiveSampleProducer(
8789
walks.iterator(0, nbrOfWalks),
8890
probabilities,
89-
10
91+
10,
92+
Optional.empty()
9093
);
9194
// does not overflow the stack = passes test
9295

@@ -109,7 +112,8 @@ void doesNotAttemptToFetchOutsideBatch() {
109112
var sampleProducer = new PositiveSampleProducer(
110113
walks.iterator(0, nbrOfWalks / 2),
111114
HugeDoubleArray.of(LongStream.range(0, nbrOfWalks).mapToDouble((l) -> 1.0).toArray()),
112-
10
115+
10,
116+
Optional.empty()
113117
);
114118

115119
var counter = 0L;
@@ -133,7 +137,8 @@ void shouldProducePairsWith(
133137
PositiveSampleProducer producer = new PositiveSampleProducer(
134138
walks.iterator(0, walks.size()),
135139
centerNodeProbabilities,
136-
windowSize
140+
windowSize,
141+
Optional.empty()
137142
);
138143
while (producer.next(buffer)) {
139144
actualPairs.add(Pair.of(buffer[0], buffer[1]));
@@ -155,7 +160,8 @@ void shouldProducePairsWithBounds() {
155160
PositiveSampleProducer producer = new PositiveSampleProducer(
156161
walks.iterator(0, 2),
157162
centerNodeProbabilities,
158-
3
163+
3,
164+
Optional.empty()
159165
);
160166
while (producer.next(buffer)) {
161167
actualPairs.add(Pair.of(buffer[0], buffer[1]));
@@ -200,7 +206,8 @@ void shouldRemoveDownsampledWordFromWalk() {
200206
PositiveSampleProducer producer = new PositiveSampleProducer(
201207
walks.iterator(0, walks.size()),
202208
centerNodeProbabilities,
203-
3
209+
3,
210+
Optional.empty()
204211
);
205212

206213
while (producer.next(buffer)) {

0 commit comments

Comments
 (0)