Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
adejanovski committed Dec 13, 2024
1 parent 00f3981 commit ae5b482
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@
*/
public final class MemoryStorageFacade implements IStorageDao {

private static final long DEFAULT_LEAD_TIME = 90;
// Default time to live of leads taken on a segment
private static final long DEFAULT_LEAD_TTL = 90_000;
private static final Logger LOG = LoggerFactory.getLogger(MemoryStorageFacade.class);
/** Field evaluator to find transient attributes. This is needed to deal with persisting Guava collections objects
* that sometimes use the transient keyword for some of their implementation's backing stores**/
private static final PersistenceFieldEvaluator TRANSIENT_FIELD_EVALUATOR =
(clazz, field) -> !field.getName().startsWith("_");
private static final UUID REAPER_INSTANCE_ID = UUID.randomUUID();

private final EmbeddedStorageManager embeddedStorage;
private final MemoryStorageRoot memoryStorageRoot;
Expand All @@ -89,7 +89,7 @@ public final class MemoryStorageFacade implements IStorageDao {
);
private final MemorySnapshotDao memSnapshotDao = new MemorySnapshotDao();
private final MemoryMetricsDao memMetricsDao = new MemoryMetricsDao();
private final ReplicaLockManagerWithTtl repairRunLockManager;
private final ReplicaLockManagerWithTtl replicaLockManagerWithTtl;

public MemoryStorageFacade(String persistenceStoragePath, long leadTime) {
LOG.info("Using memory storage backend. Persistence storage path: {}", persistenceStoragePath);
Expand All @@ -108,15 +108,15 @@ public MemoryStorageFacade(String persistenceStoragePath, long leadTime) {
LOG.info("Loading existing data from persistence storage");
this.memoryStorageRoot = (MemoryStorageRoot) this.embeddedStorage.root();
}
this.repairRunLockManager = new ReplicaLockManagerWithTtl(leadTime);
this.replicaLockManagerWithTtl = new ReplicaLockManagerWithTtl(leadTime);
}

/**
 * Creates a facade persisting to a freshly created temporary directory, using
 * {@code DEFAULT_LEAD_TTL} as the time to live of segment leads.
 */
public MemoryStorageFacade() {
  this(Files.createTempDir().getAbsolutePath(), DEFAULT_LEAD_TTL);
}

/**
 * Creates a facade persisting to the given path, using {@code DEFAULT_LEAD_TTL} as the time to live
 * of segment leads.
 *
 * @param persistenceStoragePath directory used by the embedded persistence storage
 */
public MemoryStorageFacade(String persistenceStoragePath) {
  this(persistenceStoragePath, DEFAULT_LEAD_TTL);
}

public MemoryStorageFacade(long leadTime) {
Expand Down Expand Up @@ -313,22 +313,22 @@ public Map<UUID, DiagEventSubscription> getSubscriptionsById() {

/**
 * Takes locks on the given replicas for a segment of a repair run by delegating to the replica lock
 * manager.
 *
 * @return the result reported by the lock manager
 */
@Override
public boolean lockRunningRepairsForNodes(UUID runId, UUID segmentId, Set<String> replicas) {
  return replicaLockManagerWithTtl.lockRunningRepairsForNodes(runId, segmentId, replicas);
}

/**
 * Renews the locks held on the given replicas for a segment of a repair run by delegating to the
 * replica lock manager.
 *
 * @return the result reported by the lock manager
 */
@Override
public boolean renewRunningRepairsForNodes(UUID runId, UUID segmentId, Set<String> replicas) {
  return replicaLockManagerWithTtl.renewRunningRepairsForNodes(runId, segmentId, replicas);
}

/**
 * Logs the release request, then releases the locks held on the given replicas by delegating to the
 * replica lock manager.
 *
 * @return the result reported by the lock manager
 */
@Override
public boolean releaseRunningRepairsForNodes(UUID runId, UUID segmentId, Set<String> replicas) {
  LOG.info("Releasing locks for runId: {}, segmentId: {}, replicas: {}", runId, segmentId, replicas);
  return replicaLockManagerWithTtl.releaseRunningRepairsForNodes(runId, segmentId, replicas);
}

/**
 * Returns the segments the replica lock manager currently tracks as locked for the given run.
 */
@Override
public Set<UUID> getLockedSegmentsForRun(UUID runId) {
  return replicaLockManagerWithTtl.getLockedSegmentsForRun(runId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,92 +31,102 @@
public class ReplicaLockManagerWithTtl {

private final ConcurrentHashMap<String, LockInfo> replicaLocks = new ConcurrentHashMap<>();
private final ConcurrentHashMap<UUID, Set<UUID>> runToSegmentLocks = new ConcurrentHashMap<>();
private final Lock lock = new ReentrantLock();
private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
private final ConcurrentHashMap<UUID, Set<UUID>> repairRunToSegmentLocks = new ConcurrentHashMap<>();
private final ConcurrentHashMap<UUID, ReentrantLock> runIdLocks = new ConcurrentHashMap<>();

private final long ttlSeconds; // 1 minute
private final long ttlMilliSeconds;

/**
 * Creates a lock manager whose replica locks expire {@code ttlMilliSeconds} after being taken or
 * renewed, and starts a background task that runs {@link #cleanupExpiredLocks()} once per second.
 *
 * @param ttlMilliSeconds time to live of a replica lock, in milliseconds
 */
public ReplicaLockManagerWithTtl(long ttlMilliSeconds) {
  this.ttlMilliSeconds = ttlMilliSeconds;
  // The cleanup executor is never shut down, so its worker must be a daemon thread: a regular
  // thread would keep the JVM alive for every lock manager instance ever created.
  ScheduledExecutorService lockCleanupScheduler = Executors.newScheduledThreadPool(1, task -> {
    Thread worker = new Thread(task, "replica-lock-cleanup");
    worker.setDaemon(true);
    return worker;
  });
  // Schedule cleanup of expired locks
  lockCleanupScheduler.scheduleAtFixedRate(this::cleanupExpiredLocks, 1, 1, TimeUnit.SECONDS);
}

/**
 * Builds the key under which a replica's lock is stored: the replica name immediately followed by
 * the textual form of the run id.
 */
private String getReplicaLockKey(String replica, UUID runId) {
  StringBuilder key = new StringBuilder();
  key.append(replica);
  key.append(runId);
  return key.toString();
}

/**
 * Returns the lock guarding the given repair run, registering a new {@link ReentrantLock} for the
 * run the first time it is requested.
 */
private Lock getLockForRunId(UUID runId) {
  final ReentrantLock perRunLock = runIdLocks.computeIfAbsent(runId, unused -> new ReentrantLock());
  return perRunLock;
}

public boolean lockRunningRepairsForNodes(UUID runId, UUID segmentId, Set<String> replicas) {
Lock lock = getLockForRunId(runId);
lock.lock();
try {
long currentTime = System.currentTimeMillis();
// Check if any replica is already locked by another runId
for (String replica : replicas) {
LockInfo lockInfo = replicaLocks.get(getReplicaLockKey(replica, runId));
if (lockInfo != null && lockInfo.expirationTime > currentTime && lockInfo.runId.equals(runId)) {
return false; // Replica is locked by another runId and not expired
}
boolean anyReplicaLocked = replicas.stream()
.map(replica -> replicaLocks.get(getReplicaLockKey(replica, runId)))
.anyMatch(lockInfo -> lockInfo != null
&& lockInfo.expirationTime > currentTime && lockInfo.runId.equals(runId));

if (anyReplicaLocked) {
return false; // Replica is locked by another runId and not expired
}

// Lock the replicas for the given runId and segmentId
long expirationTime = currentTime + (ttlSeconds * 1000);
for (String replica : replicas) {
replicaLocks.put(getReplicaLockKey(replica, runId), new LockInfo(runId, expirationTime));
}
long expirationTime = currentTime + ttlMilliSeconds;
replicas.forEach(replica ->
replicaLocks.put(getReplicaLockKey(replica, runId), new LockInfo(runId, expirationTime))
);

// Update runId to segmentId mapping
runToSegmentLocks.computeIfAbsent(runId, k -> ConcurrentHashMap.newKeySet()).add(segmentId);
repairRunToSegmentLocks.computeIfAbsent(runId, k -> ConcurrentHashMap.newKeySet()).add(segmentId);
return true;
} finally {
lock.unlock();
}
}

public boolean renewRunningRepairsForNodes(UUID runId, UUID segmentId, Set<String> replicas) {
Lock lock = getLockForRunId(runId);
lock.lock();
try {
long currentTime = System.currentTimeMillis();

// Check if all replicas are already locked by this runId
for (String replica : replicas) {
LockInfo lockInfo = replicaLocks.get(getReplicaLockKey(replica, runId));
if (lockInfo == null || !lockInfo.runId.equals(runId) || lockInfo.expirationTime <= currentTime) {
return false; // Some replica is not validly locked by this runId
}
boolean allReplicasLocked = replicas.stream()
.map(replica -> replicaLocks.get(getReplicaLockKey(replica, runId)))
.allMatch(lockInfo -> lockInfo != null && lockInfo.runId.equals(runId)
&& lockInfo.expirationTime > currentTime);

if (!allReplicasLocked) {
return false; // Some replica is not validly locked by this runId
}

// Renew the lock by extending the expiration time
long newExpirationTime = currentTime + (ttlSeconds * 1000);
for (String replica : replicas) {
replicaLocks.put(getReplicaLockKey(replica, runId), new LockInfo(runId, newExpirationTime));
}
long newExpirationTime = currentTime + ttlMilliSeconds;
replicas.forEach(replica ->
replicaLocks.put(getReplicaLockKey(replica, runId), new LockInfo(runId, newExpirationTime))
);

// Ensure the segmentId is linked to the runId
runToSegmentLocks.computeIfAbsent(runId, k -> ConcurrentHashMap.newKeySet()).add(segmentId);
repairRunToSegmentLocks.computeIfAbsent(runId, k -> ConcurrentHashMap.newKeySet()).add(segmentId);
return true;
} finally {
lock.unlock();
}
}

public boolean releaseRunningRepairsForNodes(UUID runId, UUID segmentId, Set<String> replicas) {
Lock lock = getLockForRunId(runId);
lock.lock();
try {
// Remove the lock for replicas
for (String replica : replicas) {
LockInfo lockInfo = replicaLocks.get(getReplicaLockKey(replica, runId));
if (lockInfo != null && lockInfo.runId.equals(runId)) {
replicaLocks.remove(getReplicaLockKey(replica, runId));
}
}
replicas.stream()
.map(replica -> getReplicaLockKey(replica, runId))
.map(replicaLocks::get)
.filter(lockInfo -> lockInfo != null && lockInfo.runId.equals(runId))
.forEach(lockInfo -> replicaLocks.remove(getReplicaLockKey(lockInfo.runId.toString(), runId)));

// Remove the segmentId from the runId mapping
Set<UUID> segments = runToSegmentLocks.get(runId);
Set<UUID> segments = repairRunToSegmentLocks.get(runId);
if (segments != null) {
segments.remove(segmentId);
if (segments.isEmpty()) {
runToSegmentLocks.remove(runId);
repairRunToSegmentLocks.remove(runId);
}
}
return true;
Expand All @@ -126,20 +136,20 @@ public boolean releaseRunningRepairsForNodes(UUID runId, UUID segmentId, Set<Str
}

/**
 * Returns the segments currently recorded as locked for the given run.
 *
 * @param runId the repair run to look up
 * @return the segment ids linked to the run, or an empty set when the run has no entry
 */
public Set<UUID> getLockedSegmentsForRun(UUID runId) {
  return repairRunToSegmentLocks.getOrDefault(runId, Collections.emptySet());
}

@VisibleForTesting
public void cleanupExpiredLocks() {
lock.lock();
runIdLocks.values().forEach(Lock::lock);
try {
long currentTime = System.currentTimeMillis();

// Remove expired locks from replicaLocks
replicaLocks.entrySet().removeIf(entry -> entry.getValue().expirationTime <= currentTime);

// Clean up runToSegmentLocks by removing segments with no active replicas
runToSegmentLocks.entrySet().removeIf(entry -> {
repairRunToSegmentLocks.entrySet().removeIf(entry -> {
UUID runId = entry.getKey();
Set<UUID> segments = entry.getValue();

Expand All @@ -152,7 +162,7 @@ public void cleanupExpiredLocks() {
return segments.isEmpty();
});
} finally {
lock.unlock();
runIdLocks.values().forEach(Lock::unlock);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@

public final class RepairRunnerHangingTest {

private static final long LEAD_TIME = 1L;
private static final long LEAD_TTL = 1000L;
private static final Logger LOG = LoggerFactory.getLogger(RepairRunnerHangingTest.class);
private static final Set<String> TABLES = ImmutableSet.of("table1");
private static final List<BigInteger> THREE_TOKENS = Lists.newArrayList(
Expand Down Expand Up @@ -243,7 +243,7 @@ public void testHangingRepair() throws InterruptedException, ReaperException, JM
final double intensity = 0.5f;
final int repairThreadCount = 1;
final int segmentTimeout = 1;
final IStorageDao storage = new MemoryStorageFacade(LEAD_TIME);
final IStorageDao storage = new MemoryStorageFacade(LEAD_TTL);
storage.getClusterDao().addCluster(cluster);
RepairUnit cf = storage.getRepairUnitDao().addRepairUnit(
RepairUnit.builder()
Expand Down Expand Up @@ -398,7 +398,7 @@ public void testHangingRepairNewApi() throws InterruptedException, ReaperExcepti
final double intensity = 0.5f;
final int repairThreadCount = 1;
final int segmentTimeout = 1;
final IStorageDao storage = new MemoryStorageFacade(LEAD_TIME);
final IStorageDao storage = new MemoryStorageFacade(LEAD_TTL);
storage.getClusterDao().addCluster(cluster);
DateTimeUtils.setCurrentMillisFixed(timeRun);
RepairUnit cf = storage.getRepairUnitDao().addRepairUnit(
Expand Down Expand Up @@ -554,7 +554,7 @@ public void testDontFailRepairAfterTopologyChangeIncrementalRepair() throws Inte
final int repairThreadCount = 1;
final int segmentTimeout = 30;
final List<BigInteger> tokens = THREE_TOKENS;
final IStorageDao storage = new MemoryStorageFacade(LEAD_TIME);
final IStorageDao storage = new MemoryStorageFacade(LEAD_TTL);
AppContext context = new AppContext();
context.storage = storage;
context.config = new ReaperApplicationConfiguration();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
import static org.mockito.Mockito.when;

public final class RepairRunnerTest {
private static final long LEAD_TIME = 1L;
private static final long LEAD_TTL = 100L;
private static final Set<String> TABLES = ImmutableSet.of("table1");
private static final Duration POLL_INTERVAL = Duration.TWO_SECONDS;
private static final List<BigInteger> THREE_TOKENS = Lists.newArrayList(
Expand Down Expand Up @@ -224,7 +224,7 @@ public void testResumeRepair() throws InterruptedException, ReaperException, Mal
final int repairThreadCount = 1;
final int segmentTimeout = 30;
final List<BigInteger> tokens = THREE_TOKENS;
final IStorageDao storage = new MemoryStorageFacade(LEAD_TIME);
final IStorageDao storage = new MemoryStorageFacade(LEAD_TTL);
AppContext context = new AppContext();
context.storage = storage;
context.config = new ReaperApplicationConfiguration();
Expand Down Expand Up @@ -356,7 +356,7 @@ public void testTooManyPendingCompactions()
final int repairThreadCount = 1;
final int segmentTimeout = 30;
final List<BigInteger> tokens = THREE_TOKENS;
final IStorageDao storage = new MemoryStorageFacade(LEAD_TIME);
final IStorageDao storage = new MemoryStorageFacade(LEAD_TTL);
AppContext context = new AppContext();
context.storage = storage;
context.config = new ReaperApplicationConfiguration();
Expand Down Expand Up @@ -549,7 +549,7 @@ public void testDontFailRepairAfterTopologyChange() throws InterruptedException,
final int repairThreadCount = 1;
final int segmentTimeout = 30;
final List<BigInteger> tokens = THREE_TOKENS;
final IStorageDao storage = new MemoryStorageFacade(LEAD_TIME);
final IStorageDao storage = new MemoryStorageFacade(LEAD_TTL);
AppContext context = new AppContext();
context.storage = storage;
context.config = new ReaperApplicationConfiguration();
Expand Down Expand Up @@ -690,7 +690,7 @@ public void testSubrangeIncrementalRepair() throws InterruptedException, ReaperE
final int repairThreadCount = 1;
final int segmentTimeout = 30;
final List<BigInteger> tokens = THREE_TOKENS;
final IStorageDao storage = new MemoryStorageFacade(LEAD_TIME);
final IStorageDao storage = new MemoryStorageFacade(LEAD_TTL);
AppContext context = new AppContext();
context.storage = storage;
context.config = new ReaperApplicationConfiguration();
Expand Down Expand Up @@ -965,7 +965,7 @@ public void getNodeMetricsInLocalDcAvailabilityForLocalDcNodeTest() throws Excep
final int repairThreadCount = 1;
final int segmentTimeout = 30;
final List<BigInteger> tokens = THREE_TOKENS;
final IStorageDao storage = new MemoryStorageFacade(LEAD_TIME);
final IStorageDao storage = new MemoryStorageFacade(LEAD_TTL);
AppContext context = new AppContext();
context.storage = storage;
context.config = new ReaperApplicationConfiguration();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class ReplicaLockManagerWithTtlTest {

@BeforeEach
public void setUp() {
replicaLockManager = new ReplicaLockManagerWithTtl(1); // TTL of 60 seconds
replicaLockManager = new ReplicaLockManagerWithTtl(1000);
runId = UUID.randomUUID();
segmentId = UUID.randomUUID();
replicas = new HashSet<>(Arrays.asList("replica1", "replica2", "replica3"));
Expand Down

0 comments on commit ae5b482

Please sign in to comment.