Skip to content

Commit

Permalink
Merge pull request #8932 from s1ck/moar-parameter-records
Browse files Browse the repository at this point in the history
Replace parameter classes with records pt2
  • Loading branch information
s1ck authored Apr 11, 2024
2 parents e1f7c6c + b173724 commit 15357d2
Show file tree
Hide file tree
Showing 22 changed files with 90 additions and 223 deletions.
27 changes: 13 additions & 14 deletions algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class Node2Vec extends Algorithm<Node2VecResult> {

private final Graph graph;
private final int concurrency;
private final WalkParameters walkParameters;
private final SamplingWalkParameters samplingWalkParameters;
private final List<Long> sourceNodes;
private final Optional<Long> maybeRandomSeed;
private final TrainParameters trainParameters;
Expand All @@ -52,18 +52,17 @@ public Node2Vec(
List<Long> sourceNodes,
Optional<Long> maybeRandomSeed,
int walkBufferSize,
WalkParameters walkParameters,
TrainParameters trainParameters,
Node2VecParameters node2VecParameters,
ProgressTracker progressTracker
) {
super(progressTracker);
this.graph = graph;
this.concurrency = concurrency;
this.walkParameters = walkParameters;
this.samplingWalkParameters = node2VecParameters.samplingWalkParameters();
this.walkBufferSize = walkBufferSize;
this.sourceNodes = sourceNodes;
this.maybeRandomSeed = maybeRandomSeed;
this.trainParameters = trainParameters;
this.trainParameters = node2VecParameters.trainParameters();
}

@Override
Expand All @@ -83,10 +82,10 @@ public Node2VecResult compute() {
var probabilitiesBuilder = new RandomWalkProbabilities.Builder(
graph.nodeCount(),
concurrency,
walkParameters.positiveSamplingFactor,
walkParameters.negativeSamplingExponent
samplingWalkParameters.positiveSamplingFactor(),
samplingWalkParameters.negativeSamplingExponent()
);
var walks = new CompressedRandomWalks(graph.nodeCount() * walkParameters.walksPerNode);
var walks = new CompressedRandomWalks(graph.nodeCount() * samplingWalkParameters.walksPerNode());

progressTracker.beginSubTask("RandomWalk");

Expand All @@ -97,7 +96,7 @@ public Node2VecResult compute() {
maybeRandomSeed,
concurrency,
sourceNodes,
walkParameters,
samplingWalkParameters,
walkBufferSize,
DefaultPool.INSTANCE,
progressTracker,
Expand Down Expand Up @@ -143,7 +142,7 @@ private List<Node2VecRandomWalkTask> walkTasks(
Optional<Long> maybeRandomSeed,
int concurrency,
List<Long> sourceNodes,
WalkParameters walkParameters,
SamplingWalkParameters samplingWalkParameters,
int walkBufferSize,
ExecutorService executorService,
ProgressTracker progressTracker,
Expand All @@ -164,7 +163,7 @@ private List<Node2VecRandomWalkTask> walkTasks(
tasks.add(new Node2VecRandomWalkTask(
graph.concurrentCopy(),
nextNodeSupplier,
walkParameters.walksPerNode,
samplingWalkParameters.walksPerNode(),
cumulativeWeightsSupplier,
progressTracker,
terminationFlag,
Expand All @@ -173,9 +172,9 @@ private List<Node2VecRandomWalkTask> walkTasks(
randomWalkPropabilitiesBuilder,
walkBufferSize,
randomSeed,
walkParameters.walkLength,
walkParameters.returnFactor,
walkParameters.inOutFactor
samplingWalkParameters.walkLength(),
samplingWalkParameters.returnFactor(),
samplingWalkParameters.inOutFactor()
));
}
return tasks;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ public Node2Vec build(
configuration.sourceNodes(),
configuration.randomSeed(),
configuration.walkBufferSize(),
configuration.walkParameters(),
configuration.trainParameters(),
configuration.node2VecParameters(),
progressTracker
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,23 +83,17 @@ default List<Long> sourceNodes() {

@Configuration.Ignore
default Node2VecParameters node2VecParameters() {
return new Node2VecParameters(walkParameters(), trainParameters());
}
@Configuration.Ignore
default WalkParameters walkParameters() {
return new WalkParameters(
walksPerNode(),
walkLength(),
returnFactor(),
inOutFactor(),
var walkParameters = walkParameters();

var samplingWalkParameters = new SamplingWalkParameters(
walkParameters.walksPerNode(),
walkParameters.walkLength(),
walkParameters.returnFactor(),
walkParameters.inOutFactor(),
positiveSamplingFactor(),
negativeSamplingExponent()
);
}

@Configuration.Ignore
default TrainParameters trainParameters() {
return new TrainParameters(
var trainParameters = new TrainParameters(
initialLearningRate(),
minLearningRate(),
iterations(),
Expand All @@ -108,5 +102,7 @@ default TrainParameters trainParameters() {
embeddingDimension(),
embeddingInitializer()
);

return new Node2VecParameters(samplingWalkParameters, trainParameters);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ public Node2VecMemoryEstimateDefinition(Node2VecParameters parameters) {

@Override
public MemoryEstimation memoryEstimation() {
int walksPerNode = parameters.walkParameters().walksPerNode;
int walkLength = parameters.walkParameters().walkLength;
int walksPerNode = parameters.samplingWalkParameters().walksPerNode();
int walkLength = parameters.samplingWalkParameters().walkLength();
int embeddingDimension = parameters.trainParameters().embeddingDimension();
return MemoryEstimations.builder(Node2Vec.class)
.perNode("random walks", (nodeCount) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@
import org.neo4j.gds.annotation.Parameters;

@Parameters
public record Node2VecParameters(WalkParameters walkParameters, TrainParameters trainParameters) {
public record Node2VecParameters(SamplingWalkParameters samplingWalkParameters, TrainParameters trainParameters) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,12 @@
import org.neo4j.gds.annotation.Parameters;

@Parameters
public class WalkParameters extends org.neo4j.gds.traversal.WalkParameters {
final double negativeSamplingExponent;
final double positiveSamplingFactor;

public WalkParameters(
int walksPerNode,
int walkLength,
double returnFactor,
double inOutFactor,
double positiveSamplingFactor,
double negativeSamplingExponent
) {
super(walksPerNode, walkLength, returnFactor, inOutFactor);
this.negativeSamplingExponent = negativeSamplingExponent;
this.positiveSamplingFactor = positiveSamplingFactor;
}
public record SamplingWalkParameters(
int walksPerNode,
int walkLength,
double returnFactor,
double inOutFactor,
double positiveSamplingFactor,
double negativeSamplingExponent
) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ private LongUnaryOperator initComponents() {

// run WCC to determine components
progressTracker.beginSubTask();
var wccParameters = WccParameters.create(0D, null, concurrency);
var wccParameters = new WccParameters(0D, concurrency);
Wcc wcc = new WccAlgorithmFactory<>().build(graph, wccParameters, ProgressTracker.NULL_TRACKER);
DisjointSetStruct disjointSets = wcc.compute();
progressTracker.endSubTask();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ default double delta() {

@Configuration.Ignore
default SteinerTreeParameters toParameters() {
return SteinerTreeParameters.create(concurrency(), sourceNode(), targetNodes(), delta(), applyRerouting());
return new SteinerTreeParameters(concurrency(), sourceNode(), targetNodes(), delta(), applyRerouting());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,54 +24,11 @@
import java.util.List;

@Parameters
public final class SteinerTreeParameters {
static SteinerTreeParameters create(int concurrency, long sourceNode, List<Long> targetNodes, double delta, boolean applyRerouting) {
return new SteinerTreeParameters(
concurrency,
sourceNode,
targetNodes,
delta,
applyRerouting
);
}

private final int concurrency;
private final long sourceNode;
private final List<Long> targetNodes;
private final double delta;
private final boolean applyRerouting;

private SteinerTreeParameters(
int concurrency,
long sourceNode,
List<Long> targetNodes,
double delta,
boolean applyRerouting
) {
this.concurrency = concurrency;
this.sourceNode = sourceNode;
this.targetNodes = targetNodes;
this.delta = delta;
this.applyRerouting = applyRerouting;
}

public int concurrency() {
return concurrency;
}

public long sourceNode() {
return sourceNode;
}

public List<Long> targetNodes() {
return targetNodes;
}

public double delta() {
return delta;
}

public boolean applyRerouting() {
return applyRerouting;
}
public record SteinerTreeParameters(
int concurrency,
long sourceNode,
List<Long> targetNodes,
double delta,
boolean applyRerouting
) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ public RandomWalkTask get() {
nextNodeSupplier,
cumulativeWeightSupplier,
walks,
walkParameters.walksPerNode,
walkParameters.walkLength,
walkParameters.returnFactor,
walkParameters.inOutFactor,
walkParameters.walksPerNode(),
walkParameters.walkLength(),
walkParameters.returnFactor(),
walkParameters.inOutFactor(),
randomSeed,
progressTracker,
terminationFlag
Expand Down
16 changes: 1 addition & 15 deletions algo/src/main/java/org/neo4j/gds/traversal/WalkParameters.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,6 @@

import org.neo4j.gds.annotation.Parameters;

/**
* Parameter object holding Random Walk parameters.
*/
@Parameters
public class WalkParameters {
public final int walksPerNode;
public final int walkLength;
public final double returnFactor;
public final double inOutFactor;

public WalkParameters(int walksPerNode, int walkLength, double returnFactor, double inOutFactor) {
this.walksPerNode = walksPerNode;
this.walkLength = walkLength;
this.returnFactor = returnFactor;
this.inOutFactor = inOutFactor;
}
public record WalkParameters(int walksPerNode, int walkLength, double returnFactor, double inOutFactor) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,6 @@ static TriangleCountBaseConfig of(CypherMapWrapper userInput) {

@Configuration.Ignore
default TriangleCountParameters toParameters() {
return TriangleCountParameters.create(concurrency(), maxDegree());
return new TriangleCountParameters(concurrency(), maxDegree());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,5 @@
import org.neo4j.gds.annotation.Parameters;

@Parameters
public final class TriangleCountParameters {

public static TriangleCountParameters create(int concurrency, long maxDegree) {
return new TriangleCountParameters(concurrency, maxDegree);
}
private final int concurrency;
private final long maxDegree;

private TriangleCountParameters(int concurrency, long maxDegree) {
this.concurrency = concurrency;
this.maxDegree = maxDegree;
}

public int concurrency() {
return concurrency;
}

public long maxDegree() {
return maxDegree;
}
public record TriangleCountParameters(int concurrency, long maxDegree) {
}
9 changes: 3 additions & 6 deletions algo/src/main/java/org/neo4j/gds/wcc/WccBaseConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,14 @@
import org.neo4j.gds.config.RelationshipWeightConfig;
import org.neo4j.gds.config.SeedConfig;

import java.util.Optional;

public interface WccBaseConfig extends AlgoBaseConfig, SeedConfig, ConsecutiveIdsConfig, RelationshipWeightConfig {

default double threshold() {
return 0D;
}

@Configuration.Ignore
default boolean hasThreshold() {
return !Double.isNaN(threshold()) && threshold() > 0;
}

@Configuration.Check
default void validate() {
if (threshold() > 0 && relationshipWeightProperty().isEmpty()) {
Expand All @@ -45,6 +42,6 @@ default void validate() {

@Configuration.Ignore
default WccParameters toParameters() {
return WccParameters.create(threshold(), seedProperty(), concurrency());
return new WccParameters(threshold(), Optional.ofNullable(seedProperty()), concurrency());
}
}
Loading

0 comments on commit 15357d2

Please sign in to comment.