Skip to content

Commit

Permalink
Memory store segment leader election (#1533)
Browse files Browse the repository at this point in the history
The memory storage backend didn't implement any if the IDistributedStorage leader election methods.
This was ok until we started using leader election for segment scheduling, which doesn't require to run in a distributed mode.
As a consequence, Reaper using the mem store would schedule one new segment at each poll, even if the replicas were already busy processing another segment.

This PR introduces a class which manages locks on replicas for segments and moves the required methods from IDistributedStorage to IStorage so they can be implemented in the memory storage implementation.
  • Loading branch information
adejanovski authored Dec 18, 2024
1 parent ee85307 commit 664236b
Show file tree
Hide file tree
Showing 14 changed files with 395 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,7 @@ private void abortSegmentsWithNoLeaderNonIncremental(RepairRun repairRun, Collec
if (context.storage instanceof IDistributedStorage || !repairRunners.containsKey(repairRun.getId())) {
// When multiple Reapers are in use, we can get stuck segments when one instance is rebooted
// Any segment in RUNNING or STARTED state but with no leader should be killed
Set<UUID> leaders = context.storage instanceof IDistributedStorage
? ((IDistributedStorage) context.storage).getLockedSegmentsForRun(repairRun.getId())
: Collections.emptySet();
Set<UUID> leaders = context.storage.getLockedSegmentsForRun(repairRun.getId());

Collection<RepairSegment> orphanedSegments = runningSegments
.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -869,10 +869,8 @@ private boolean takeLead(RepairSegment segment) {
? ((IDistributedStorage) context.storage).takeLead(leaderElectionId)
: true;
} else {
result = context.storage instanceof IDistributedStorage
? ((IDistributedStorage) context.storage).lockRunningRepairsForNodes(this.repairRunner.getRepairRunId(),
segment.getId(), segment.getReplicas().keySet())
: true;
result = context.storage.lockRunningRepairsForNodes(this.repairRunner.getRepairRunId(),
segment.getId(), segment.getReplicas().keySet());
}
if (!result) {
context.metricRegistry.counter(MetricRegistry.name(SegmentRunner.class, "takeLead", "failed")).inc();
Expand All @@ -895,10 +893,8 @@ private boolean renewLead(RepairSegment segment) {
}
return result;
} else {
boolean resultLock2 = context.storage instanceof IDistributedStorage
? ((IDistributedStorage) context.storage).renewRunningRepairsForNodes(this.repairRunner.getRepairRunId(),
segment.getId(), segment.getReplicas().keySet())
: true;
boolean resultLock2 = context.storage.renewRunningRepairsForNodes(this.repairRunner.getRepairRunId(),
segment.getId(), segment.getReplicas().keySet());
if (!resultLock2) {
context.metricRegistry.counter(MetricRegistry.name(SegmentRunner.class, "renewLead", "failed")).inc();
releaseLead(segment);
Expand All @@ -912,13 +908,14 @@ private boolean renewLead(RepairSegment segment) {
private void releaseLead(RepairSegment segment) {
try (Timer.Context cx
= context.metricRegistry.timer(MetricRegistry.name(SegmentRunner.class, "releaseLead")).time()) {
if (context.storage instanceof IDistributedStorage) {
if (repairUnit.getIncrementalRepair() && !repairUnit.getSubrangeIncrementalRepair()) {

if (repairUnit.getIncrementalRepair() && !repairUnit.getSubrangeIncrementalRepair()) {
if (context.storage instanceof IDistributedStorage) {
((IDistributedStorage) context.storage).releaseLead(leaderElectionId);
} else {
((IDistributedStorage) context.storage).releaseRunningRepairsForNodes(this.repairRunner.getRepairRunId(),
segment.getId(), segment.getReplicas().keySet());
}
} else {
context.storage.releaseRunningRepairsForNodes(this.repairRunner.getRepairRunId(),
segment.getId(), segment.getReplicas().keySet());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import io.cassandrareaper.storage.operations.IOperationsDao;

import java.util.List;
import java.util.Set;
import java.util.UUID;


Expand All @@ -62,23 +61,6 @@ public interface IDistributedStorage extends IDistributedMetrics {

void releaseLead(UUID leaderId);

boolean lockRunningRepairsForNodes(
UUID repairId,
UUID segmentId,
Set<String> replicas);

boolean renewRunningRepairsForNodes(
UUID repairId,
UUID segmentId,
Set<String> replicas);

boolean releaseRunningRepairsForNodes(
UUID repairId,
UUID segmentId,
Set<String> replicas);

Set<UUID> getLockedSegmentsForRun(UUID runId);

int countRunningReapers();

List<UUID> getRunningReapers();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
import io.cassandrareaper.storage.repairunit.IRepairUnitDao;
import io.cassandrareaper.storage.snapshot.ISnapshotDao;

import java.util.Set;
import java.util.UUID;

import io.dropwizard.lifecycle.Managed;

/**
Expand All @@ -34,6 +37,23 @@
public interface IStorageDao extends Managed,
IMetricsDao {

boolean lockRunningRepairsForNodes(
UUID repairId,
UUID segmentId,
Set<String> replicas);

boolean renewRunningRepairsForNodes(
UUID repairId,
UUID segmentId,
Set<String> replicas);

boolean releaseRunningRepairsForNodes(
UUID repairId,
UUID segmentId,
Set<String> replicas);

Set<UUID> getLockedSegmentsForRun(UUID runId);

boolean isStorageConnected();

IEventsDao getEventsDao();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import io.cassandrareaper.storage.events.IEventsDao;
import io.cassandrareaper.storage.events.MemoryEventsDao;
import io.cassandrareaper.storage.memory.MemoryStorageRoot;
import io.cassandrareaper.storage.memory.ReplicaLockManagerWithTtl;
import io.cassandrareaper.storage.metrics.MemoryMetricsDao;
import io.cassandrareaper.storage.repairrun.IRepairRunDao;
import io.cassandrareaper.storage.repairrun.MemoryRepairRunDao;
Expand All @@ -46,9 +47,11 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;

import com.google.common.io.Files;
import org.eclipse.serializer.persistence.types.PersistenceFieldEvaluator;
import org.eclipse.store.storage.embedded.types.EmbeddedStorage;
import org.eclipse.store.storage.embedded.types.EmbeddedStorageManager;
Expand All @@ -61,8 +64,9 @@
*/
public final class MemoryStorageFacade implements IStorageDao {

// Default time to live of leads taken on a segment
private static final long DEFAULT_LEAD_TTL = 90_000;
private static final Logger LOG = LoggerFactory.getLogger(MemoryStorageFacade.class);

/** Field evaluator to find transient attributes. This is needed to deal with persisting Guava collections objects
* that sometimes use the transient keyword for some of their implementation's backing stores**/
private static final PersistenceFieldEvaluator TRANSIENT_FIELD_EVALUATOR =
Expand All @@ -85,8 +89,9 @@ public final class MemoryStorageFacade implements IStorageDao {
);
private final MemorySnapshotDao memSnapshotDao = new MemorySnapshotDao();
private final MemoryMetricsDao memMetricsDao = new MemoryMetricsDao();
private final ReplicaLockManagerWithTtl replicaLockManagerWithTtl;

public MemoryStorageFacade(String persistenceStoragePath) {
public MemoryStorageFacade(String persistenceStoragePath, long leadTime) {
LOG.info("Using memory storage backend. Persistence storage path: {}", persistenceStoragePath);
this.embeddedStorage = EmbeddedStorage.Foundation(Paths.get(persistenceStoragePath))
.onConnectionFoundation(
Expand All @@ -103,10 +108,19 @@ public MemoryStorageFacade(String persistenceStoragePath) {
LOG.info("Loading existing data from persistence storage");
this.memoryStorageRoot = (MemoryStorageRoot) this.embeddedStorage.root();
}
this.replicaLockManagerWithTtl = new ReplicaLockManagerWithTtl(leadTime);
}

public MemoryStorageFacade() {
this("/tmp/" + UUID.randomUUID().toString());
this(Files.createTempDir().getAbsolutePath(), DEFAULT_LEAD_TTL);
}

public MemoryStorageFacade(String persistenceStoragePath) {
this(persistenceStoragePath, DEFAULT_LEAD_TTL);
}

public MemoryStorageFacade(long leadTime) {
this(Files.createTempDir().getAbsolutePath(), leadTime);
}

@Override
Expand Down Expand Up @@ -296,4 +310,25 @@ public Collection<RepairSegment> getRepairSegmentsByRunId(UUID runId) {
public Map<UUID, DiagEventSubscription> getSubscriptionsById() {
return this.memoryStorageRoot.getSubscriptionsById();
}

@Override
public boolean lockRunningRepairsForNodes(UUID runId, UUID segmentId, Set<String> replicas) {
return replicaLockManagerWithTtl.lockRunningRepairsForNodes(runId, segmentId, replicas);
}

@Override
public boolean renewRunningRepairsForNodes(UUID runId, UUID segmentId, Set<String> replicas) {
return replicaLockManagerWithTtl.renewRunningRepairsForNodes(runId, segmentId, replicas);
}

@Override
public boolean releaseRunningRepairsForNodes(UUID runId, UUID segmentId, Set<String> replicas) {
LOG.info("Releasing locks for runId: {}, segmentId: {}, replicas: {}", runId, segmentId, replicas);
return replicaLockManagerWithTtl.releaseRunningRepairsForNodes(runId, segmentId, replicas);
}

@Override
public Set<UUID> getLockedSegmentsForRun(UUID runId) {
return replicaLockManagerWithTtl.getLockedSegmentsForRun(runId);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* Copyright 2024-2024 DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.cassandrareaper.storage.memory;

import java.util.Collections;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

import com.google.common.annotations.VisibleForTesting;

public class ReplicaLockManagerWithTtl {

private final ConcurrentHashMap<String, LockInfo> replicaLocks = new ConcurrentHashMap<>();
private final ConcurrentHashMap<UUID, Set<UUID>> repairRunToSegmentLocks = new ConcurrentHashMap<>();
private final Lock lock = new ReentrantLock();

private final long ttlMilliSeconds;

public ReplicaLockManagerWithTtl(long ttlMilliSeconds) {
this.ttlMilliSeconds = ttlMilliSeconds;
// Schedule cleanup of expired locks
ScheduledExecutorService lockCleanupScheduler = Executors.newScheduledThreadPool(1);
lockCleanupScheduler.scheduleAtFixedRate(this::cleanupExpiredLocks, 1, 1, TimeUnit.SECONDS);
}

private String getReplicaLockKey(String replica, UUID runId) {
return replica + runId;
}

public boolean lockRunningRepairsForNodes(UUID runId, UUID segmentId, Set<String> replicas) {
lock.lock();
try {
long currentTime = System.currentTimeMillis();
// Check if any replica is already locked by another runId
boolean anyReplicaLocked = replicas.stream()
.map(replica -> replicaLocks.get(getReplicaLockKey(replica, runId)))
.anyMatch(lockInfo -> lockInfo != null
&& lockInfo.expirationTime > currentTime && lockInfo.runId.equals(runId));

if (anyReplicaLocked) {
return false; // Replica is locked by another runId and not expired
}

// Lock the replicas for the given runId and segmentId
long expirationTime = currentTime + ttlMilliSeconds;
replicas.forEach(replica ->
replicaLocks.put(getReplicaLockKey(replica, runId), new LockInfo(runId, expirationTime))
);

// Update runId to segmentId mapping
repairRunToSegmentLocks.computeIfAbsent(runId, k -> ConcurrentHashMap.newKeySet()).add(segmentId);
return true;
} finally {
lock.unlock();
}
}

public boolean renewRunningRepairsForNodes(UUID runId, UUID segmentId, Set<String> replicas) {
lock.lock();
try {
long currentTime = System.currentTimeMillis();

// Check if all replicas are already locked by this runId
boolean allReplicasLocked = replicas.stream()
.map(replica -> replicaLocks.get(getReplicaLockKey(replica, runId)))
.allMatch(lockInfo -> lockInfo != null && lockInfo.runId.equals(runId)
&& lockInfo.expirationTime > currentTime);

if (!allReplicasLocked) {
return false; // Some replica is not validly locked by this runId
}

// Renew the lock by extending the expiration time
long newExpirationTime = currentTime + ttlMilliSeconds;
replicas.forEach(replica ->
replicaLocks.put(getReplicaLockKey(replica, runId), new LockInfo(runId, newExpirationTime))
);

// Ensure the segmentId is linked to the runId
repairRunToSegmentLocks.computeIfAbsent(runId, k -> ConcurrentHashMap.newKeySet()).add(segmentId);
return true;
} finally {
lock.unlock();
}
}

public boolean releaseRunningRepairsForNodes(UUID runId, UUID segmentId, Set<String> replicas) {
lock.lock();
try {
// Remove the lock for replicas
replicas.stream()
.map(replica -> getReplicaLockKey(replica, runId))
.map(replicaLocks::get)
.filter(lockInfo -> lockInfo != null && lockInfo.runId.equals(runId))
.forEach(lockInfo -> replicaLocks.remove(getReplicaLockKey(lockInfo.runId.toString(), runId)));

// Remove the segmentId from the runId mapping
Set<UUID> segments = repairRunToSegmentLocks.get(runId);
if (segments != null) {
segments.remove(segmentId);
if (segments.isEmpty()) {
repairRunToSegmentLocks.remove(runId);
}
}
return true;
} finally {
lock.unlock();
}
}

public Set<UUID> getLockedSegmentsForRun(UUID runId) {
return repairRunToSegmentLocks.getOrDefault(runId, Collections.emptySet());
}

@VisibleForTesting
public void cleanupExpiredLocks() {
lock.lock();
try {
long currentTime = System.currentTimeMillis();

// Remove expired locks from replicaLocks
replicaLocks.entrySet().removeIf(entry -> entry.getValue().expirationTime <= currentTime);

// Clean up runToSegmentLocks by removing segments with no active replicas
repairRunToSegmentLocks.entrySet().removeIf(entry -> {
UUID runId = entry.getKey();
Set<UUID> segments = entry.getValue();

// Retain only active segments
segments.removeIf(segmentId -> {
boolean active = replicaLocks.values().stream()
.anyMatch(info -> info.runId.equals(runId));
return !active;
});
return segments.isEmpty();
});
} finally {
lock.unlock();
}
}

// Class to store lock information
private static class LockInfo {
UUID runId;
long expirationTime;

LockInfo(UUID runId, long expirationTime) {
this.runId = runId;
this.expirationTime = expirationTime;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,7 @@ public void a_new_daily_repair_schedule_is_added_for_the_last_added_cluster_and_
params.put("intensity", "0.9");
params.put("scheduleDaysBetween", "1");
params.put("scheduleTriggerTime", DateTime.now().plusSeconds(1).toString());
params.put("segmentCountPerNode", "1");
ReaperTestJettyRunner runner = RUNNERS.get(RAND.nextInt(RUNNERS.size()));
Response response = runner.callReaper("POST", "/repair_schedule", Optional.of(params));
int responseStatus = response.getStatus();
Expand Down
Loading

0 comments on commit 664236b

Please sign in to comment.