diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java index dada1765ce..fcf6a3d9ca 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java @@ -366,14 +366,16 @@ public void sendShuffleData( ShuffleServerGrpcMetrics.SEND_SHUFFLE_DATA_METHOD, transportTime); } } - int requireSize = shuffleServer.getShuffleTaskManager().getRequireBufferSize(requireBufferId); + int requireSize = + shuffleServer.getShuffleTaskManager().getRequireBufferSize(appId, requireBufferId); StatusCode ret = StatusCode.SUCCESS; String responseMessage = "OK"; if (req.getShuffleDataCount() > 0) { ShuffleServerMetrics.counterTotalReceivedDataSize.inc(requireSize); ShuffleTaskManager manager = shuffleServer.getShuffleTaskManager(); - PreAllocatedBufferInfo info = manager.getAndRemovePreAllocatedBuffer(requireBufferId); + PreAllocatedBufferInfo info = + manager.getAndRemovePreAllocatedBuffer(appId, requireBufferId); boolean isPreAllocated = info != null; if (!isPreAllocated) { String errorMsg = diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java index f90ec7c8a9..0124ef3746 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java @@ -46,6 +46,7 @@ import com.google.common.collect.Queues; import com.google.common.collect.Range; import com.google.common.collect.Sets; +import org.apache.commons.collections.MapUtils; import org.apache.commons.collections4.CollectionUtils; import org.roaringbitmap.longlong.Roaring64NavigableMap; import org.slf4j.Logger; @@ -114,7 +115,9 @@ public class ShuffleTaskManager { private Map> partitionsToBlockIds; private final ShuffleBufferManager shuffleBufferManager; private Map shuffleTaskInfos = JavaUtils.newConcurrentMap(); - private Map requireBufferIds = JavaUtils.newConcurrentMap(); + // appId -> {requireBufferId -> PreAllocatedBufferInfo} + private Map> appIdToRequireBufferIdsMap = + JavaUtils.newConcurrentMap(); private Thread clearResourceThread; private BlockingQueue expiredAppIdQueue = Queues.newLinkedBlockingQueue(); private final Cache appLocks; @@ -319,8 +322,12 @@ public StatusCode cacheShuffleData( return shuffleBufferManager.cacheShuffleData(appId, shuffleId, isPreAllocated, spd); } - public PreAllocatedBufferInfo getAndRemovePreAllocatedBuffer(long requireBufferId) { - return requireBufferIds.remove(requireBufferId); + public PreAllocatedBufferInfo getAndRemovePreAllocatedBuffer(String appId, long requireBufferId) { + Map requireBufferIdMap = appIdToRequireBufferIdsMap.get(appId); + if (requireBufferIdMap == null) { + return null; + } + return requireBufferIdMap.remove(requireBufferId); } public void releasePreAllocatedSize(long requireSize) { @@ -328,8 +335,8 @@ public void releasePreAllocatedSize(long requireSize) { } @VisibleForTesting - void removeAndReleasePreAllocatedBuffer(long requireBufferId) { - PreAllocatedBufferInfo info = getAndRemovePreAllocatedBuffer(requireBufferId); + void removeAndReleasePreAllocatedBuffer(String appId, long requireBufferId) { + PreAllocatedBufferInfo info = getAndRemovePreAllocatedBuffer(appId, requireBufferId); if (info != null) { releasePreAllocatedSize(info.getRequireSize()); } @@ -541,9 +548,18 @@ public long requireBuffer( public long requireBuffer(String appId, int requireSize) { if (shuffleBufferManager.requireMemory(requireSize, true)) { long requireId = requireBufferId.incrementAndGet(); - requireBufferIds.put( - requireId, - new PreAllocatedBufferInfo(appId, requireId, System.currentTimeMillis(), requireSize)); + ReentrantReadWriteLock.WriteLock appLock = getAppWriteLock(appId); + try { + // preAllocatedBufferCheck will obtain lock and remove the empty appId + appLock.lock(); + Map requireBufferMaps = + appIdToRequireBufferIdsMap.computeIfAbsent(appId, x -> JavaUtils.newConcurrentMap()); + requireBufferMaps.put( + requireId, + new PreAllocatedBufferInfo(appId, requireId, System.currentTimeMillis(), requireSize)); + } finally { + appLock.unlock(); + } return requireId; } else { LOG.warn("Failed to require buffer, require size: {}", requireSize); @@ -829,6 +845,12 @@ public void removeResources(String appId, boolean checkAppExpired) { partitionsToBlockIds.remove(appId); shuffleBufferManager.removeBuffer(appId); shuffleFlushManager.removeResources(appId); + Map requireBufferIdsMap = appIdToRequireBufferIdsMap.get(appId); + if (requireBufferIdsMap != null) { + for (PreAllocatedBufferInfo info : requireBufferIdsMap.values()) { + removeAndReleasePreAllocatedBuffer(appId, info.getRequireId()); + } + } String operationMsg = String.format("removing storage data for appId:%s", appId); withTimeoutExecution( @@ -896,25 +918,51 @@ public void refreshAppId(String appId) { private void preAllocatedBufferCheck() { try { long current = System.currentTimeMillis(); - List removeIds = Lists.newArrayList(); - for (PreAllocatedBufferInfo info : requireBufferIds.values()) { - if (current - info.getTimestamp() > preAllocationExpired) { - removeIds.add(info.getRequireId()); + for (Map.Entry> entry : + appIdToRequireBufferIdsMap.entrySet()) { + String appId = entry.getKey(); + if (MapUtils.isEmpty(entry.getValue())) { + ReentrantReadWriteLock.WriteLock appLock = getAppWriteLock(appId); + try { + appLock.lock(); + // After double check, remove empty map related this appId from + // appIdToRequireBufferIdsMap + if (MapUtils.isEmpty(entry.getValue())) { + // Keep single point remove appId from appIdToRequireBufferIdsMap + appIdToRequireBufferIdsMap.remove(appId); + continue; + } + } finally { + appLock.unlock(); + } } - } - for (Long requireId : removeIds) { - PreAllocatedBufferInfo info = requireBufferIds.remove(requireId); - if (info != null) { - // move release memory code down to here as the requiredBuffer could be consumed during - // removing processing. - shuffleBufferManager.releaseMemory(info.getRequireSize(), false, true); - LOG.warn( - "Remove expired preAllocatedBuffer[id={}] that required by app: {}", - requireId, - info.getAppId()); - ShuffleServerMetrics.counterPreAllocatedBufferExpired.inc(); - } else { - LOG.info("PreAllocatedBuffer[id={}] has already be used", requireId); + List toRemoveIds = Lists.newArrayList(); + for (PreAllocatedBufferInfo info : entry.getValue().values()) { + if (current - info.getTimestamp() > preAllocationExpired) { + toRemoveIds.add(info.getRequireId()); + } + } + List removedIds = Lists.newArrayList(); + List usedIds = Lists.newArrayList(); + for (Long requireId : toRemoveIds) { + PreAllocatedBufferInfo info = getAndRemovePreAllocatedBuffer(appId, requireId); + if (info != null) { + // move release memory code down to here as the requiredBuffer could be consumed during + // removing processing. + shuffleBufferManager.releaseMemory(info.getRequireSize(), false, true); + removedIds.add(requireId); + ShuffleServerMetrics.counterPreAllocatedBufferExpired.inc(); + } else { + usedIds.add(requireId); + } + if (removedIds.size() > 0) { + LOG.info( + "Remove expired preAllocatedBuffer[id={}] for app[{}], removedIds: {}, usedIds: {}", + requireId, + appId, + removedIds, + usedIds); + } } } } catch (Exception e) { @@ -922,8 +970,12 @@ private void preAllocatedBufferCheck() { } } - public int getRequireBufferSize(long requireId) { - PreAllocatedBufferInfo pabi = requireBufferIds.get(requireId); + public int getRequireBufferSize(String appId, long requireId) { + Map requireBufferIdMap = appIdToRequireBufferIdsMap.get(appId); + if (requireBufferIdMap == null) { + return 0; + } + PreAllocatedBufferInfo pabi = requireBufferIdMap.get(requireId); if (pabi == null) { return 0; } @@ -940,8 +992,8 @@ public Set getAppIds() { } @VisibleForTesting - Map getRequireBufferIds() { - return requireBufferIds; + Supplier> getRequireBufferIdSizeByAppId(String appId) { + return () -> appIdToRequireBufferIdsMap.get(appId); } @VisibleForTesting diff --git a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java index 35eeb429df..21183ff129 100644 --- a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java +++ b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java @@ -121,7 +121,7 @@ public void handleSendShuffleDataRequest(TransportClient client, SendShuffleData // thread, // otherwise we need to release the required size. PreAllocatedBufferInfo info = - shuffleTaskManager.getAndRemovePreAllocatedBuffer(requireBufferId); + shuffleTaskManager.getAndRemovePreAllocatedBuffer(appId, requireBufferId); int requireSize = info == null ? 0 : info.getRequireSize(); int requireBlocksSize = requireSize - req.encodedLength() < 0 ? 0 : requireSize - req.encodedLength(); diff --git a/server/src/test/java/org/apache/uniffle/server/ShuffleTaskManagerTest.java b/server/src/test/java/org/apache/uniffle/server/ShuffleTaskManagerTest.java index c9cbdf49a6..3b97a9c0d8 100644 --- a/server/src/test/java/org/apache/uniffle/server/ShuffleTaskManagerTest.java +++ b/server/src/test/java/org/apache/uniffle/server/ShuffleTaskManagerTest.java @@ -28,6 +28,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; import com.google.common.collect.Lists; import com.google.common.collect.Maps; @@ -365,15 +366,17 @@ public void writeProcessTest() throws Exception { StringUtils.EMPTY); final List expectedBlocks1 = Lists.newArrayList(); final List expectedBlocks2 = Lists.newArrayList(); - final Map bufferIds = shuffleTaskManager.getRequireBufferIds(); + // Since requireBuffer doesn't specify the appId, "EMPTY" is used instead. + final Supplier> bufferIds = + shuffleTaskManager.getRequireBufferIdSizeByAppId("EMPTY"); shuffleTaskManager.requireBuffer(10); shuffleTaskManager.requireBuffer(10); shuffleTaskManager.requireBuffer(10); - assertEquals(3, bufferIds.size()); + assertEquals(3, bufferIds.get().size()); // required buffer should be clear if it doesn't receive data after timeout Thread.sleep(6000); - assertEquals(0, bufferIds.size()); + assertEquals(0, bufferIds.get() == null ? 0 : bufferIds.get().size()); shuffleTaskManager.commitShuffle(appId, shuffleId); @@ -381,17 +384,17 @@ public void writeProcessTest() throws Exception { ShufflePartitionedData partitionedData0 = createPartitionedData(1, 1, 35); expectedBlocks1.addAll(Lists.newArrayList(partitionedData0.getBlockList())); long bufferId = shuffleTaskManager.requireBuffer(35); - assertEquals(1, bufferIds.size()); - PreAllocatedBufferInfo pabi = bufferIds.get(bufferId); + assertEquals(1, bufferIds.get().size()); + PreAllocatedBufferInfo pabi = bufferIds.get().get(bufferId); assertEquals(35, pabi.getRequireSize()); StatusCode sc = shuffleTaskManager.cacheShuffleData(appId, shuffleId, true, partitionedData0); shuffleTaskManager.updateCachedBlockIds(appId, shuffleId, partitionedData0.getBlockList()); // the required id won't be removed in shuffleTaskManager, it is removed in Grpc service - assertEquals(1, bufferIds.size()); + assertEquals(1, bufferIds.get().size()); assertEquals(StatusCode.SUCCESS, sc); shuffleTaskManager.commitShuffle(appId, shuffleId); // manually release the pre allocate buffer - shuffleTaskManager.removeAndReleasePreAllocatedBuffer(bufferId); + shuffleTaskManager.removeAndReleasePreAllocatedBuffer(appId, bufferId); ShuffleFlushManager shuffleFlushManager = shuffleServer.getShuffleFlushManager(); assertEquals( @@ -404,7 +407,7 @@ public void writeProcessTest() throws Exception { sc = shuffleTaskManager.cacheShuffleData(appId, shuffleId, true, partitionedData1); shuffleTaskManager.updateCachedBlockIds(appId, shuffleId, partitionedData1.getBlockList()); assertEquals(StatusCode.SUCCESS, sc); - shuffleTaskManager.removeAndReleasePreAllocatedBuffer(bufferId); + shuffleTaskManager.removeAndReleasePreAllocatedBuffer(appId, bufferId); waitForFlush(shuffleFlushManager, appId, shuffleId, 2 + 1); // won't flush for partition 1-1 @@ -421,7 +424,7 @@ public void writeProcessTest() throws Exception { bufferId = shuffleTaskManager.requireBuffer(30); sc = shuffleTaskManager.cacheShuffleData(appId, shuffleId, true, partitionedData3); shuffleTaskManager.updateCachedBlockIds(appId, shuffleId, partitionedData3.getBlockList()); - shuffleTaskManager.removeAndReleasePreAllocatedBuffer(bufferId); + shuffleTaskManager.removeAndReleasePreAllocatedBuffer(appId, bufferId); assertEquals(StatusCode.SUCCESS, sc); // flush for partition 2-2 @@ -430,7 +433,7 @@ public void writeProcessTest() throws Exception { bufferId = shuffleTaskManager.requireBuffer(35); sc = shuffleTaskManager.cacheShuffleData(appId, shuffleId, true, partitionedData4); shuffleTaskManager.updateCachedBlockIds(appId, shuffleId, partitionedData4.getBlockList()); - shuffleTaskManager.removeAndReleasePreAllocatedBuffer(bufferId); + shuffleTaskManager.removeAndReleasePreAllocatedBuffer(appId, bufferId); assertEquals(StatusCode.SUCCESS, sc); shuffleTaskManager.commitShuffle(appId, shuffleId); @@ -444,7 +447,7 @@ public void writeProcessTest() throws Exception { bufferId = shuffleTaskManager.requireBuffer(70); sc = shuffleTaskManager.cacheShuffleData(appId, shuffleId, true, partitionedData5); assertEquals(StatusCode.SUCCESS, sc); - shuffleTaskManager.removeAndReleasePreAllocatedBuffer(bufferId); + shuffleTaskManager.removeAndReleasePreAllocatedBuffer(appId, bufferId); // 2 new blocks should be committed waitForFlush(shuffleFlushManager, appId, shuffleId, 2 + 1 + 3 + 2); @@ -460,7 +463,7 @@ public void writeProcessTest() throws Exception { bufferId = shuffleTaskManager.requireBuffer(70); sc = shuffleTaskManager.cacheShuffleData(appId, shuffleId, true, partitionedData7); assertEquals(StatusCode.SUCCESS, sc); - shuffleTaskManager.removeAndReleasePreAllocatedBuffer(bufferId); + shuffleTaskManager.removeAndReleasePreAllocatedBuffer(appId, bufferId); // 2 new blocks should be committed waitForFlush(shuffleFlushManager, appId, shuffleId, 2 + 1 + 3 + 2 + 2);