diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java index 159da8a84d..77a008c1c0 100644 --- a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java +++ b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java @@ -722,7 +722,8 @@ public void registerShuffle( ShuffleDataDistributionType distributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - RssProtos.MergeContext mergeContext) {} + RssProtos.MergeContext mergeContext, + Map properties) {} @Override public boolean sendCommit( diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java index d2aaebe045..82a98a84ef 100644 --- a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java +++ b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java @@ -508,7 +508,8 @@ public void registerShuffle( ShuffleDataDistributionType distributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - RssProtos.MergeContext mergeContext) {} + RssProtos.MergeContext mergeContext, + Map properties) {} @Override public boolean sendCommit( diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index 47f9e271de..11d81ed31d 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -34,6 +34,8 @@ import java.util.function.Supplier; import java.util.stream.Collectors; +import scala.Tuple2; + import 
com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Maps; import com.google.common.collect.Sets; @@ -1016,6 +1018,7 @@ protected void registerShuffleServers( } LOG.info("Start to register shuffleId {}", shuffleId); long start = System.currentTimeMillis(); + Map sparkConfMap = sparkConfToMap(getSparkConf()); serverToPartitionRanges.entrySet().stream() .forEach( entry -> { @@ -1028,7 +1031,8 @@ protected void registerShuffleServers( ShuffleDataDistributionType.NORMAL, maxConcurrencyPerPartitionToWrite, stageAttemptNumber, - null); + null, + sparkConfMap); }); LOG.info( "Finish register shuffleId {} with {} ms", shuffleId, (System.currentTimeMillis() - start)); @@ -1045,6 +1049,7 @@ protected void registerShuffleServers( } LOG.info("Start to register shuffleId[{}]", shuffleId); long start = System.currentTimeMillis(); + Map sparkConfMap = sparkConfToMap(getSparkConf()); Set>> entries = serverToPartitionRanges.entrySet(); entries.stream() @@ -1057,7 +1062,8 @@ protected void registerShuffleServers( entry.getValue(), remoteStorage, dataDistributionType, - maxConcurrencyPerPartitionToWrite); + maxConcurrencyPerPartitionToWrite, + sparkConfMap); }); LOG.info( "Finish register shuffleId[{}] with {} ms", @@ -1084,4 +1090,20 @@ public boolean isRssStageRetryForWriteFailureEnabled() { public boolean isRssStageRetryForFetchFailureEnabled() { return rssStageRetryForFetchFailureEnabled; } + + @VisibleForTesting + public SparkConf getSparkConf() { + return sparkConf; + } + + public Map sparkConfToMap(SparkConf sparkConf) { + Map map = new HashMap<>(); + + for (Tuple2 tuple : sparkConf.getAll()) { + String key = tuple._1; + map.put(key, tuple._2); + } + + return map; + } } diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index 70369f503a..95c89bd29d 100644 --- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -899,11 +899,6 @@ protected void registerCoordinator() { this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX)); } - @VisibleForTesting - public SparkConf getSparkConf() { - return sparkConf; - } - private synchronized void startHeartbeat() { shuffleWriteClient.registerApplicationInfo(id.get(), heartbeatTimeout, user); if (!heartbeatStarted) { diff --git a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java index 3765dc83b4..05fd55bc2e 100644 --- a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java +++ b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java @@ -716,7 +716,8 @@ public void registerShuffle( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - RssProtos.MergeContext mergeContext) {} + RssProtos.MergeContext mergeContext, + Map properties) {} @Override public boolean sendCommit( diff --git a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java index d21c7e67b7..caab46020f 100644 --- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java +++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java @@ -73,7 +73,53 @@ default void registerShuffle( dataDistributionType, maxConcurrencyPerPartitionToWrite, 0, - null); + null, + Collections.emptyMap()); + } + + default void registerShuffle( + ShuffleServerInfo shuffleServerInfo, + String appId, + int shuffleId, + List partitionRanges, + RemoteStorageInfo remoteStorage, + 
ShuffleDataDistributionType dataDistributionType, + int maxConcurrencyPerPartitionToWrite, + Map properties) { + registerShuffle( + shuffleServerInfo, + appId, + shuffleId, + partitionRanges, + remoteStorage, + dataDistributionType, + maxConcurrencyPerPartitionToWrite, + 0, + null, + properties); + } + + default void registerShuffle( + ShuffleServerInfo shuffleServerInfo, + String appId, + int shuffleId, + List partitionRanges, + RemoteStorageInfo remoteStorage, + ShuffleDataDistributionType dataDistributionType, + int maxConcurrencyPerPartitionToWrite, + int stageAttemptNumber, + MergeContext mergeContext) { + registerShuffle( + shuffleServerInfo, + appId, + shuffleId, + partitionRanges, + remoteStorage, + dataDistributionType, + maxConcurrencyPerPartitionToWrite, + stageAttemptNumber, + mergeContext, + Collections.emptyMap()); } void registerShuffle( @@ -85,7 +131,8 @@ void registerShuffle( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - MergeContext mergeContext); + MergeContext mergeContext, + Map properties); boolean sendCommit( Set shuffleServerInfoSet, String appId, int shuffleId, int numMaps); diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java index c81d3c7255..ac93d57b1e 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java @@ -565,7 +565,8 @@ public void registerShuffle( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - MergeContext mergeContext) { + MergeContext mergeContext, + Map properties) { String user = null; try { user = UserGroupInformation.getCurrentUser().getShortUserName(); @@ -583,7 +584,8 @@ public void registerShuffle( dataDistributionType, 
maxConcurrencyPerPartitionToWrite, stageAttemptNumber, - mergeContext); + mergeContext, + properties); RssRegisterShuffleResponse response = getShuffleServerClient(shuffleServerInfo).registerShuffle(request); diff --git a/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java b/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java index 6798a792cb..d856292f20 100644 --- a/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java +++ b/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java @@ -64,7 +64,8 @@ public void registerShuffle( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - RssProtos.MergeContext mergeContext) {} + RssProtos.MergeContext mergeContext, + Map properties) {} @Override public boolean sendCommit( diff --git a/common/src/main/java/org/apache/uniffle/common/util/Constants.java b/common/src/main/java/org/apache/uniffle/common/util/Constants.java index d63c2e46e8..79ceb2f10f 100644 --- a/common/src/main/java/org/apache/uniffle/common/util/Constants.java +++ b/common/src/main/java/org/apache/uniffle/common/util/Constants.java @@ -91,4 +91,6 @@ private Constants() {} public static final String DRIVER_HOST = "driver.host"; public static final String DATE_PATTERN = "yyyy-MM-dd HH:mm:ss"; + + public static final String SPARK_RSS_CONFIG_PREFIX = "spark."; } diff --git a/integration-test/spark3/src/test/java/org/apache/uniffle/test/AQESkewedJoinWithLocalOrderTest.java b/integration-test/spark3/src/test/java/org/apache/uniffle/test/AQESkewedJoinWithLocalOrderTest.java index aff3ff3e24..c6c8b07963 100644 --- a/integration-test/spark3/src/test/java/org/apache/uniffle/test/AQESkewedJoinWithLocalOrderTest.java +++ b/integration-test/spark3/src/test/java/org/apache/uniffle/test/AQESkewedJoinWithLocalOrderTest.java @@ -28,7 +28,7 @@ public class 
AQESkewedJoinWithLocalOrderTest extends AQESkewedJoinTest { @Override public void updateSparkConfCustomer(SparkConf sparkConf) { - sparkConf.set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name()); + sparkConf.set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE.name()); sparkConf.set( "spark." + RssClientConf.DATA_DISTRIBUTION_TYPE.key(), ShuffleDataDistributionType.LOCAL_ORDER.name()); diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java index dccd9f9383..98180d647c 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java @@ -198,7 +198,8 @@ private ShuffleRegisterResponse doRegisterShuffle( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - MergeContext mergeContext) { + MergeContext mergeContext, + Map properties) { ShuffleRegisterRequest.Builder reqBuilder = ShuffleRegisterRequest.newBuilder(); reqBuilder .setAppId(appId) @@ -207,7 +208,8 @@ private ShuffleRegisterResponse doRegisterShuffle( .setShuffleDataDistribution(RssProtos.DataDistribution.valueOf(dataDistributionType.name())) .setMaxConcurrencyPerPartitionToWrite(maxConcurrencyPerPartitionToWrite) .addAllPartitionRanges(toShufflePartitionRanges(partitionRanges)) - .setStageAttemptNumber(stageAttemptNumber); + .setStageAttemptNumber(stageAttemptNumber) + .putAllProperties(properties); if (mergeContext != null) { reqBuilder.setMergeContext(mergeContext); } @@ -484,7 +486,8 @@ public RssRegisterShuffleResponse registerShuffle(RssRegisterShuffleRequest requ request.getDataDistributionType(), request.getMaxConcurrencyPerPartitionToWrite(), request.getStageAttemptNumber(), - request.getMergeContext()); + 
request.getMergeContext(), + request.getProperties()); RssRegisterShuffleResponse response; RssProtos.StatusCode statusCode = rpcResponse.getStatus(); diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java index 92ed1e15e9..a2cac5367f 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java @@ -17,8 +17,11 @@ package org.apache.uniffle.client.request; +import java.util.Collections; import java.util.List; +import java.util.Map; +import com.google.common.annotations.VisibleForTesting; import org.apache.commons.lang3.StringUtils; import org.apache.uniffle.common.PartitionRange; @@ -39,7 +42,9 @@ public class RssRegisterShuffleRequest { private int stageAttemptNumber; private final MergeContext mergeContext; + private Map properties; + @VisibleForTesting public RssRegisterShuffleRequest( String appId, int shuffleId, @@ -57,7 +62,8 @@ public RssRegisterShuffleRequest( dataDistributionType, maxConcurrencyPerPartitionToWrite, 0, - null); + null, + Collections.emptyMap()); } public RssRegisterShuffleRequest( @@ -69,7 +75,8 @@ public RssRegisterShuffleRequest( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - MergeContext mergeContext) { + MergeContext mergeContext, + Map properties) { this.appId = appId; this.shuffleId = shuffleId; this.partitionRanges = partitionRanges; @@ -79,8 +86,10 @@ public RssRegisterShuffleRequest( this.maxConcurrencyPerPartitionToWrite = maxConcurrencyPerPartitionToWrite; this.stageAttemptNumber = stageAttemptNumber; this.mergeContext = mergeContext; + this.properties = properties; } + @VisibleForTesting public RssRegisterShuffleRequest( String appId, int shuffleId, @@ -97,7 +106,8 @@ 
public RssRegisterShuffleRequest( dataDistributionType, RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(), 0, - null); + null, + Collections.emptyMap()); } public RssRegisterShuffleRequest( @@ -111,7 +121,8 @@ public RssRegisterShuffleRequest( ShuffleDataDistributionType.NORMAL, RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(), 0, - null); + null, + Collections.emptyMap()); } public String getAppId() { @@ -149,4 +160,8 @@ public int getStageAttemptNumber() { public MergeContext getMergeContext() { return mergeContext; } + + public Map getProperties() { + return properties; + } } diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto index d92ec40c7a..5e8cc632d5 100644 --- a/proto/src/main/proto/Rss.proto +++ b/proto/src/main/proto/Rss.proto @@ -197,6 +197,7 @@ message ShuffleRegisterRequest { int32 maxConcurrencyPerPartitionToWrite = 7; int32 stageAttemptNumber = 8; MergeContext mergeContext = 9; + map properties = 10; } enum DataDistribution { diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java index 61fdc466a4..6fcf7e5c73 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java @@ -46,6 +46,7 @@ import org.apache.uniffle.storage.common.Storage; import org.apache.uniffle.storage.handler.api.ShuffleWriteHandlerWrapper; import org.apache.uniffle.storage.request.CreateShuffleWriteHandlerRequest; +import org.apache.uniffle.storage.util.StorageType; import static org.apache.uniffle.server.ShuffleServerConf.SERVER_MAX_CONCURRENCY_OF_ONE_PARTITION; import static org.apache.uniffle.server.ShuffleServerMetrics.COMMITTED_BLOCK_COUNT; @@ -62,6 +63,7 @@ public class ShuffleFlushManager { private final String storageType; private final int storageDataReplica; private final ShuffleServerConf shuffleServerConf; 
+ private final boolean storageTypeWithMemory; private Configuration hadoopConf; // appId -> shuffleId -> committed shuffle blockIds private Map> committedBlockIds = @@ -101,6 +103,7 @@ public ShuffleFlushManager( .mapToLong(bitmap -> bitmap.getLongCardinality()) .sum(), 2 * 60 * 1000L /* 2 minutes */); + this.storageTypeWithMemory = StorageType.withMemory(StorageType.valueOf(storageType)); } public void addToFlushQueue(ShuffleDataFlushEvent event) { @@ -194,11 +197,14 @@ public void processFlushEvent(ShuffleDataFlushEvent event) throws Exception { throw new EventRetryException(); } long endTime = System.currentTimeMillis(); - - // update some metrics for shuffle task - updateCommittedBlockIds(event.getAppId(), event.getShuffleId(), event.getShuffleBlocks()); ShuffleTaskInfo shuffleTaskInfo = shuffleServer.getShuffleTaskManager().getShuffleTaskInfo(event.getAppId()); + if (shuffleTaskInfo == null || !storageTypeWithMemory) { + // With memory storage type should never need cachedBlockIds, + // since client do not need call finish shuffle rpc + // update some metrics for shuffle task + updateCommittedBlockIds(event.getAppId(), event.getShuffleId(), event.getShuffleBlocks()); + } if (isStorageAuditLogEnabled) { AUDIT_LOGGER.info( String.format( diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java index 695b635919..c33cc69349 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java @@ -98,6 +98,7 @@ import org.apache.uniffle.storage.common.Storage; import org.apache.uniffle.storage.common.StorageReadMetrics; import org.apache.uniffle.storage.util.ShuffleStorageUtils; +import org.apache.uniffle.storage.util.StorageType; import static org.apache.uniffle.server.merge.ShuffleMergeManager.MERGE_APP_SUFFIX; @@ -322,7 +323,8 @@ public void 
registerShuffle( new RemoteStorageInfo(remoteStoragePath, remoteStorageConf), user, shuffleDataDistributionType, - maxConcurrencyPerPartitionToWrite); + maxConcurrencyPerPartitionToWrite, + req.getPropertiesMap()); if (StatusCode.SUCCESS == result && shuffleServer.isRemoteMergeEnable() && req.hasMergeContext()) { @@ -338,7 +340,8 @@ public void registerShuffle( new RemoteStorageInfo(remoteStoragePath, remoteStorageConf), user, shuffleDataDistributionType, - maxConcurrencyPerPartitionToWrite); + maxConcurrencyPerPartitionToWrite, + req.getPropertiesMap()); if (result == StatusCode.SUCCESS) { result = shuffleServer @@ -576,6 +579,18 @@ public void commitShuffleTask( String appId = req.getAppId(); int shuffleId = req.getShuffleId(); auditContext.withAppId(appId).withShuffleId(shuffleId); + org.apache.uniffle.common.StorageType storageType = + shuffleServer.getShuffleServerConf().get(ShuffleServerConf.RSS_STORAGE_TYPE); + boolean storageTypeWithMemory = + StorageType.withMemory(StorageType.valueOf(storageType.name())); + if (storageTypeWithMemory) { + String errorMessage = + String.format( + "commitShuffleTask should not be called while server-side configured StorageType to %s for appId %s", + storageType, appId); + LOG.error(errorMessage); + throw new UnsupportedOperationException(errorMessage); + } StatusCode status = verifyRequest(appId); if (status != StatusCode.SUCCESS) { auditContext.withStatusCode(status); @@ -633,6 +648,18 @@ public void finishShuffle( String appId = req.getAppId(); int shuffleId = req.getShuffleId(); auditContext.withAppId(appId).withShuffleId(shuffleId); + org.apache.uniffle.common.StorageType storageType = + shuffleServer.getShuffleServerConf().get(ShuffleServerConf.RSS_STORAGE_TYPE); + boolean storageTypeWithMemory = + StorageType.withMemory(StorageType.valueOf(storageType.name())); + if (storageTypeWithMemory) { + String errorMessage = + String.format( + "finishShuffle should not be called while server-side configured StorageType to %s for 
appId %s", + storageType, appId); + LOG.error(errorMessage); + throw new UnsupportedOperationException(errorMessage); + } StatusCode status = verifyRequest(appId); if (status != StatusCode.SUCCESS) { auditContext.withStatusCode(status); diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java index be030769ff..11ac2dff34 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; import com.google.common.collect.Sets; import org.roaringbitmap.longlong.Roaring64NavigableMap; @@ -76,6 +77,7 @@ public class ShuffleTaskInfo { private final Map shuffleDetailInfos; private final Map latestStageAttemptNumbers; + private Map properties; public ShuffleTaskInfo(String appId) { this.appId = appId; @@ -315,4 +317,20 @@ public String toString() { + shuffleDetailInfos + '}'; } + + /** + * Keeps only the entries of {@code properties} whose key contains {@code ".rss."} and stores + * them on this task info. + * + * @param properties properties supplied at shuffle registration; must not be null + */ + public void setProperties(Map properties) { + Map filteredProperties = + properties.entrySet().stream() + .filter(entry -> entry.getKey().contains(".rss.")) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + this.properties = filteredProperties; + // Log what was actually stored; the unfiltered input may carry unrelated client configuration. + LOGGER.info("{} set properties to {}", appId, filteredProperties); + } } diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java index dbc94c0072..af37646a78 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import
java.util.HashSet; import java.util.List; import java.util.Map; @@ -86,6 +87,7 @@ import org.apache.uniffle.storage.common.StorageReadMetrics; import org.apache.uniffle.storage.request.CreateShuffleReadHandlerRequest; import org.apache.uniffle.storage.util.ShuffleStorageUtils; +import org.apache.uniffle.storage.util.StorageType; import static org.apache.uniffle.server.ShuffleServerConf.CLIENT_MAX_CONCURRENCY_LIMITATION_OF_ONE_PARTITION; import static org.apache.uniffle.server.ShuffleServerConf.SERVER_MAX_CONCURRENCY_OF_ONE_PARTITION; @@ -96,6 +98,7 @@ public class ShuffleTaskManager { private static final Logger LOG = LoggerFactory.getLogger(ShuffleTaskManager.class); + private final boolean storageTypeWithMemory; private ShuffleFlushManager shuffleFlushManager; private final ScheduledExecutorService scheduledExecutorService; private final ScheduledExecutorService expiredAppCleanupExecutorService; @@ -146,6 +149,12 @@ public ShuffleTaskManager( this.shuffleBufferManager = shuffleBufferManager; this.storageManager = storageManager; this.shuffleMergeManager = shuffleMergeManager; + org.apache.uniffle.common.StorageType storageType = + conf.get(ShuffleServerConf.RSS_STORAGE_TYPE); + this.storageTypeWithMemory = + storageType == null + ? 
false + : StorageType.withMemory(StorageType.valueOf(storageType.name())); this.appExpiredWithoutHB = conf.getLong(ShuffleServerConf.SERVER_APP_EXPIRED_WITHOUT_HEARTBEAT); this.commitCheckIntervalMax = conf.getLong(ShuffleServerConf.SERVER_COMMIT_CHECK_INTERVAL_MAX); this.preAllocationExpired = conf.getLong(ShuffleServerConf.SERVER_PRE_ALLOCATION_EXPIRED); @@ -301,7 +310,8 @@ public StatusCode registerShuffle( remoteStorageInfo, user, ShuffleDataDistributionType.NORMAL, - -1); + -1, + Collections.emptyMap()); } public StatusCode registerShuffle( @@ -311,13 +321,15 @@ public StatusCode registerShuffle( RemoteStorageInfo remoteStorageInfo, String user, ShuffleDataDistributionType dataDistType, - int maxConcurrencyPerPartitionToWrite) { + int maxConcurrencyPerPartitionToWrite, + Map properties) { ReentrantReadWriteLock.WriteLock lock = getAppWriteLock(appId); lock.lock(); try { refreshAppId(appId); ShuffleTaskInfo taskInfo = shuffleTaskInfos.get(appId); + taskInfo.setProperties(properties); taskInfo.setUser(user); taskInfo.setSpecification( ShuffleSpecification.builder() @@ -520,15 +532,23 @@ public void updateCachedBlockIds( } ShuffleTaskInfo shuffleTaskInfo = shuffleTaskInfos.computeIfAbsent(appId, x -> new ShuffleTaskInfo(appId)); - Roaring64NavigableMap bitmap = - shuffleTaskInfo - .getCachedBlockIds() - .computeIfAbsent(shuffleId, x -> Roaring64NavigableMap.bitmapOf()); - long size = 0L; - synchronized (bitmap) { + // With memory storage type should never need cachedBlockIds, + // since client do not need call finish shuffle rpc + if (!storageTypeWithMemory) { + Roaring64NavigableMap bitmap = + shuffleTaskInfo + .getCachedBlockIds() + .computeIfAbsent(shuffleId, x -> Roaring64NavigableMap.bitmapOf()); + + synchronized (bitmap) { + for (ShufflePartitionedBlock spb : spbs) { + bitmap.addLong(spb.getBlockId()); + size += spb.getEncodedLength(); + } + } + } else { for (ShufflePartitionedBlock spb : spbs) { - bitmap.addLong(spb.getBlockId()); size += 
spb.getEncodedLength(); } }