From 4c39534179ad1280c4502bd8914a1e890fff96e1 Mon Sep 17 00:00:00 2001 From: baoloongmao Date: Wed, 11 Dec 2024 09:39:29 +0800 Subject: [PATCH] Add UT for DelegationRssShuffleManager#getReader --- .../shuffle/DelegationRssShuffleManager.java | 4 ++- .../DelegationRssShuffleManagerTest.java | 30 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java index dfdf587c33..ff2a205ebe 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java @@ -21,6 +21,7 @@ import java.util.Map; import java.util.Set; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Maps; import org.apache.commons.lang3.StringUtils; import org.apache.hadoop.security.UserGroupInformation; @@ -175,7 +176,8 @@ private boolean tryAccessCluster() { return false; } - private ShuffleManager createShuffleManagerInExecutor() throws RssException { + @VisibleForTesting + protected ShuffleManager createShuffleManagerInExecutor() throws RssException { ShuffleManager shuffleManager; // get useRSS from spark conf boolean useRSS = sparkConf.get(RssSparkConfig.RSS_ENABLED); diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/DelegationRssShuffleManagerTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/DelegationRssShuffleManagerTest.java index 092869e31d..0abd9888c1 100644 --- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/DelegationRssShuffleManagerTest.java +++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/DelegationRssShuffleManagerTest.java @@ -23,6 +23,7 @@ import org.apache.spark.shuffle.sort.SortShuffleManager; import org.junit.jupiter.api.Test; +import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.storage.util.StorageType; import static org.apache.uniffle.common.rpc.StatusCode.ACCESS_DENIED; @@ -30,6 +31,10 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; public class DelegationRssShuffleManagerTest extends RssShuffleManagerTestBase { @@ -131,6 +136,31 @@ public void testTryAccessCluster() throws Exception { assertCreateSortShuffleManager(secondConf); } + @Test + public void testGetReader() throws Exception { + ShuffleReader mockReader = mock(ShuffleReader.class); + ShuffleManager mockShuffleManager = mock(ShuffleManager.class); + doReturn(mockReader) + .when(mockShuffleManager) + .getReader(any(), anyInt(), anyInt(), anyInt(), anyInt(), any(), any()); + DelegationRssShuffleManager delegationRssShuffleManager; + SparkConf conf = new SparkConf(); + conf.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "m1:8001,m2:8002"); + delegationRssShuffleManager = + new DelegationRssShuffleManager(conf, false) { + @Override + protected ShuffleManager createShuffleManagerInExecutor() throws RssException { + return mockShuffleManager; + } + }; + assertEquals(mockShuffleManager, delegationRssShuffleManager.getDelegate()); + ShuffleReader reader = + delegationRssShuffleManager.getReader( + new BaseShuffleHandle(0, null), 1, 1, 1, 1, null, null); + + assertEquals(mockReader, reader); + } + private void assertCreateSortShuffleManager(SparkConf conf) throws Exception { DelegationRssShuffleManager delegationRssShuffleManager = new DelegationRssShuffleManager(conf, true);