From 43879ca13aa261e6a775178c7015b8cf758053ab Mon Sep 17 00:00:00 2001 From: maobaolong Date: Thu, 9 Jan 2025 20:46:04 +0800 Subject: [PATCH] [MINOR] improvement(spark-client): Refactor RssShuffleManager for spark v2 and v3 to reduce redundant code (#2330) ### What changes were proposed in this pull request? Refactor the code duplicated between the Spark 2 and Spark 3 shuffle managers into a shared base class. ### Why are the changes needed? This reduces redundant code and simplifies development in the RSS spark-client modules. Spark-version-independent code belongs in the RssShuffleManagerBase class, so that each version-specific RssShuffleManager class only maintains the Spark-API-related code. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - Existing UTs. - Manually tested on our test cluster with both Spark v2 and Spark v3. --- .../manager/RssShuffleManagerBase.java | 452 ++++++++++++++++- .../spark/shuffle/RssShuffleManager.java | 409 ++------------- .../spark/shuffle/RssShuffleManager.java | 480 ++---------------- 3 files changed, 521 insertions(+), 820 deletions(-) diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index c1fc5b68eb..d869c64fe6 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -17,6 +17,7 @@ package org.apache.uniffle.shuffle.manager; +import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.ArrayList; @@ -28,6 +29,9 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; @@ -40,6 +44,7 @@ import com.google.common.collect.Maps; import com.google.common.collect.Sets; import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.security.UserGroupInformation; import org.apache.spark.MapOutputTracker; @@ -58,6 +63,8 @@ import org.apache.spark.shuffle.handle.ShuffleHandleInfo; import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo; import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo; +import org.apache.spark.shuffle.writer.AddBlockEvent; +import org.apache.spark.shuffle.writer.DataPusher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -65,6 +72,7 @@ import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.client.factory.CoordinatorClientFactory; import org.apache.uniffle.client.factory.ShuffleManagerClientFactory; +import org.apache.uniffle.client.impl.FailedBlockSendTracker; import org.apache.uniffle.client.impl.grpc.CoordinatorGrpcRetryableClient; import org.apache.uniffle.client.request.RssFetchClientConfRequest; import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest; @@ -83,17 +91,38 @@ import org.apache.uniffle.common.config.RssClientConf; import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.exception.RssException; +import org.apache.uniffle.common.rpc.GrpcServer; import 
org.apache.uniffle.common.rpc.StatusCode; +import org.apache.uniffle.common.util.BlockIdLayout; import org.apache.uniffle.common.util.Constants; import org.apache.uniffle.common.util.ExpiringCloseableSupplier; +import org.apache.uniffle.common.util.JavaUtils; import org.apache.uniffle.common.util.RetryUtils; +import org.apache.uniffle.common.util.RssUtils; +import org.apache.uniffle.common.util.ThreadUtils; import org.apache.uniffle.shuffle.BlockIdManager; +import static org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED; +import static org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM; +import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED; +import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED; +import static org.apache.uniffle.common.config.RssBaseConf.RPC_SERVER_PORT; import static org.apache.uniffle.common.config.RssClientConf.HADOOP_CONFIG_KEY_PREFIX; +import static org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE; import static org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_REMOTE_STORAGE_USE_LOCAL_CONF_ENABLED; public abstract class RssShuffleManagerBase implements RssShuffleManagerInterface, ShuffleManager { private static final Logger LOG = LoggerFactory.getLogger(RssShuffleManagerBase.class); + protected final int dataTransferPoolSize; + protected final int dataCommitPoolSize; + protected final int dataReplica; + protected final int dataReplicaWrite; + protected final int dataReplicaRead; + protected final boolean dataReplicaSkipEnabled; + protected final Map<String, Set<Long>> taskToSuccessBlockIds; + protected final Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker; + private Set<String> failedTaskIds = Sets.newConcurrentHashSet(); + private AtomicBoolean isInitialized = new AtomicBoolean(false); private Method unregisterAllMapOutputMethod; private Method registerShuffleMethod; @@ -107,7 +136,8 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac protected int maxConcurrencyPerPartitionToWrite; protected String clientType; - protected SparkConf sparkConf; + protected final SparkConf sparkConf; + protected final RssConf rssConf; protected Map<Integer, Integer> shuffleIdToPartitionNum; protected Map<Integer, Integer> shuffleIdToNumMapTasks; protected Supplier<ShuffleManagerClient> managerClientSupplier; @@ -126,9 +156,227 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac protected boolean partitionReassignEnabled; protected boolean shuffleManagerRpcServiceEnabled; - public RssShuffleManagerBase() { + protected boolean heartbeatStarted = false; + protected final long heartbeatInterval; + protected final long heartbeatTimeout; + protected String user; + protected String uuid; + protected ScheduledExecutorService heartBeatScheduledExecutorService; + protected final int maxFailures; + protected final boolean speculation; + protected final BlockIdLayout blockIdLayout; + private ShuffleManagerGrpcService service; + protected GrpcServer shuffleManagerServer; + protected DataPusher dataPusher; + + public RssShuffleManagerBase(SparkConf conf, boolean isDriver) { LOG.info( "Uniffle {} version: {}", this.getClass().getName(), Constants.VERSION_AND_REVISION_SHORT); + this.sparkConf = conf; + checkSupported(sparkConf); + boolean supportsRelocation = + Optional.ofNullable(SparkEnv.get()) + .map(env -> env.serializer().supportsRelocationOfSerializedObjects()) + .orElse(true); + if (!supportsRelocation) { + LOG.warn( 
"RSSShuffleManager requires a serializer which supports relocations of serialized object. Please set " + + "spark.serializer to org.apache.spark.serializer.KryoSerializer instead"); + } + this.user = sparkConf.get("spark.rss.quota.user", "user"); + this.uuid = sparkConf.get("spark.rss.quota.uuid", Long.toString(System.currentTimeMillis())); + this.dynamicConfEnabled = sparkConf.get(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED); + + // fetch client conf and apply them if necessary + if (isDriver && this.dynamicConfEnabled) { + fetchAndApplyDynamicConf(sparkConf); + } + RssSparkShuffleUtils.validateRssClientConf(sparkConf); + + // convert spark conf to rss conf after fetching dynamic client conf + this.rssConf = RssSparkConfig.toRssConf(sparkConf); + RssUtils.setExtraJavaProperties(rssConf); + + // set & check replica config + this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA); + this.dataReplicaWrite = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE); + this.dataReplicaRead = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ); + this.dataReplicaSkipEnabled = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED); + LOG.info( + "Check quorum config [" + + dataReplica + + ":" + + dataReplicaWrite + + ":" + + dataReplicaRead + + ":" + + dataReplicaSkipEnabled + + "]"); + RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, dataReplicaRead); + + this.maxConcurrencyPerPartitionToWrite = rssConf.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE); + + this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE); + + // configure block id layout + this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4); + this.speculation = sparkConf.getBoolean("spark.speculation", false); + // configureBlockIdLayout requires maxFailures and speculation to be initialized + configureBlockIdLayout(sparkConf, rssConf); + this.blockIdLayout = BlockIdLayout.from(rssConf); + + this.dataTransferPoolSize = sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE); + this.dataCommitPoolSize = sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE); + + // External shuffle service is not supported when using remote shuffle service + sparkConf.set("spark.shuffle.service.enabled", "false"); + sparkConf.set("spark.dynamicAllocation.shuffleTracking.enabled", "false"); + sparkConf.set(RssSparkConfig.RSS_ENABLED.key(), "true"); + LOG.info("Disable external shuffle service in RssShuffleManager."); + sparkConf.set("spark.sql.adaptive.localShuffleReader.enabled", "false"); + LOG.info("Disable local shuffle reader in RssShuffleManager."); + // If we store shuffle data in distributed filesystem or in a disaggregated + // shuffle cluster, we don't need shuffle data locality + sparkConf.set("spark.shuffle.reduceLocality.enabled", "false"); + LOG.info("Disable shuffle data locality in RssShuffleManager."); + + taskToSuccessBlockIds = JavaUtils.newConcurrentMap(); + taskToFailedBlockSendTracker = JavaUtils.newConcurrentMap(); + this.shuffleIdToPartitionNum = JavaUtils.newConcurrentMap(); + this.shuffleIdToNumMapTasks = JavaUtils.newConcurrentMap(); + + // stage retry for write/fetch failure + rssStageRetryForFetchFailureEnabled = + rssConf.get(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED); + rssStageRetryForWriteFailureEnabled = + rssConf.get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED); + if (rssStageRetryForFetchFailureEnabled || rssStageRetryForWriteFailureEnabled) { + rssStageRetryEnabled = true; + List logTips = new ArrayList<>(); + if (rssStageRetryForWriteFailureEnabled) { + 
logTips.add("write"); + } + if (rssStageRetryForWriteFailureEnabled) { + logTips.add("fetch"); + } + LOG.info( + "Activate the stage retry mechanism that will resubmit stage on {} failure", + StringUtils.join(logTips, "/")); + } + + this.partitionReassignEnabled = rssConf.get(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED); + // The feature of partition reassign is exclusive with multiple replicas and stage retry. + if (partitionReassignEnabled) { + if (dataReplica > 1) { + throw new RssException( + "The feature of task partition reassign is incompatible with multiple replicas mechanism."); + } + } + this.blockIdSelfManagedEnabled = rssConf.getBoolean(RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED); + this.shuffleManagerRpcServiceEnabled = + partitionReassignEnabled || rssStageRetryEnabled || blockIdSelfManagedEnabled; + + if (isDriver) { + heartBeatScheduledExecutorService = + ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat"); + if (shuffleManagerRpcServiceEnabled) { + LOG.info("stage resubmit is supported and enabled"); + // start shuffle manager server + rssConf.set(RPC_SERVER_PORT, 0); + ShuffleManagerServerFactory factory = new ShuffleManagerServerFactory(this, rssConf); + service = factory.getService(); + shuffleManagerServer = factory.getServer(service); + try { + shuffleManagerServer.start(); + // pass this as a spark.rss.shuffle.manager.grpc.port config, so it can be propagated to + // executor properly. + sparkConf.set( + RssSparkConfig.RSS_SHUFFLE_MANAGER_GRPC_PORT, shuffleManagerServer.getPort()); + } catch (Exception e) { + LOG.error("Failed to start shuffle manager server", e); + throw new RssException(e); + } + } + } + if (shuffleManagerRpcServiceEnabled) { + getOrCreateShuffleManagerClientSupplier(); + } + + // Start heartbeat thread. 
+ this.heartbeatInterval = sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL); + this.heartbeatTimeout = + sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(), heartbeatInterval / 2); + heartBeatScheduledExecutorService = + ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat"); + + this.shuffleWriteClient = createShuffleWriteClient(); + registerCoordinator(); + + LOG.info("Rss data pusher is starting..."); + int poolSize = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE); + int keepAliveTime = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE); + this.dataPusher = + new DataPusher( + shuffleWriteClient, + taskToSuccessBlockIds, + taskToFailedBlockSendTracker, + failedTaskIds, + poolSize, + keepAliveTime); + this.partitionReassignMaxServerNum = + rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM); + this.shuffleHandleInfoManager = new ShuffleHandleInfoManager(); + this.rssStageResubmitManager = new RssStageResubmitManager(); } + + @VisibleForTesting + protected RssShuffleManagerBase( + SparkConf conf, + boolean isDriver, + DataPusher dataPusher, + Map<String, Set<Long>> taskToSuccessBlockIds, + Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker) { + this.sparkConf = conf; + this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE); + this.rssConf = RssSparkConfig.toRssConf(sparkConf); + this.dataDistributionType = rssConf.get(RssClientConf.DATA_DISTRIBUTION_TYPE); + this.blockIdLayout = BlockIdLayout.from(rssConf); + this.maxConcurrencyPerPartitionToWrite = rssConf.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE); + this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4); + this.speculation = sparkConf.getBoolean("spark.speculation", false); + // configureBlockIdLayout requires maxFailures and speculation to be initialized + configureBlockIdLayout(sparkConf, rssConf); + this.heartbeatInterval = sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL); + this.heartbeatTimeout = + sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(), heartbeatInterval / 2); + this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA); + this.dataReplicaWrite = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE); + this.dataReplicaRead = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ); + this.dataReplicaSkipEnabled = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED); + LOG.info( + "Check quorum config [" + + dataReplica + + ":" + + dataReplicaWrite + + ":" + + dataReplicaRead + + ":" + + dataReplicaSkipEnabled + + "]"); + RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, dataReplicaRead); + + this.dataTransferPoolSize = sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE); + this.dataCommitPoolSize = sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE); + this.shuffleWriteClient = createShuffleWriteClient(); + + this.taskToSuccessBlockIds = taskToSuccessBlockIds; + this.heartBeatScheduledExecutorService = null; + this.taskToFailedBlockSendTracker = taskToFailedBlockSendTracker; + this.dataPusher = dataPusher; + this.partitionReassignMaxServerNum = + rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM); + this.shuffleHandleInfoManager = new ShuffleHandleInfoManager(); + this.rssStageResubmitManager = new RssStageResubmitManager(); } public BlockIdManager getBlockIdManager() { @@ -145,14 +393,35 @@ public BlockIdManager getBlockIdManager() { @Override public boolean unregisterShuffle(int shuffleId) { - if (blockIdManager != null) { - blockIdManager.remove(shuffleId); + try { + if (blockIdManager != null) { +
blockIdManager.remove(shuffleId); + } + if (SparkEnv.get().executorId().equals("driver")) { + shuffleWriteClient.unregisterShuffle(getAppId(), shuffleId); + shuffleIdToPartitionNum.remove(shuffleId); + shuffleIdToNumMapTasks.remove(shuffleId); + if (service != null) { + service.unregisterShuffle(shuffleId); + } + } + } catch (Exception e) { + LOG.warn("Errors on unregistering from remote shuffle-servers", e); } return true; } - /** See static overload of this method. */ - public abstract void configureBlockIdLayout(SparkConf sparkConf, RssConf rssConf); + /** + * Derives block id layout config from maximum number of allowed partitions. Computes the number + * of required bits for partition id and task attempt id and reserves remaining bits for sequence + * number. + * + * @param sparkConf Spark config providing max partitions + * @param rssConf Rss config to amend + */ + public void configureBlockIdLayout(SparkConf sparkConf, RssConf rssConf) { + configureBlockIdLayout(sparkConf, rssConf, maxFailures, speculation); + } /** * Derives block id layout config from maximum number of allowed partitions. This value can be set @@ -344,7 +613,10 @@ private static void configureBlockIdLayoutFromLayoutConfig( } /** See static overload of this method. */ - public abstract long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo); + public long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo) { + return getTaskAttemptIdForBlockId( + mapIndex, attemptNo, maxFailures, speculation, blockIdLayout.taskAttemptIdBits); + } /** * Provides a task attempt id to be used in the block id, that is unique for a shuffle stage. @@ -809,7 +1081,8 @@ public MutableShuffleHandleInfo reassignOnBlockSendFailure( LOG.info( "Register the new partition->servers assignment on reassign. {}", newServerToPartitions); - registerShuffleServers(id.get(), shuffleId, newServerToPartitions, getRemoteStorageInfo()); + registerShuffleServers( + getAppId(), shuffleId, newServerToPartitions, getRemoteStorageInfo()); } LOG.info( @@ -852,8 +1125,76 @@ public void stop() { && managerClientSupplier instanceof ExpiringCloseableSupplier) { ((ExpiringCloseableSupplier) managerClientSupplier).close(); } + if (heartBeatScheduledExecutorService != null) { + heartBeatScheduledExecutorService.shutdownNow(); + } + if (shuffleWriteClient != null) { + // Unregister shuffle before closing shuffle write client. + shuffleWriteClient.unregisterShuffle(getAppId()); + shuffleWriteClient.close(); + } + if (dataPusher != null) { + try { + dataPusher.close(); + } catch (IOException e) { + LOG.warn("Errors on closing data pusher", e); + } + } + + if (shuffleManagerServer != null) { + try { + shuffleManagerServer.stop(); + } catch (InterruptedException e) { + // ignore + LOG.info("shuffle manager server is interrupted during stop"); + } + } + } + + /** @return the unique spark id for rss shuffle */ + @Override + public String getAppId() { + return id.get(); + } + + @Override + public int getPartitionNum(int shuffleId) { + return shuffleIdToPartitionNum.getOrDefault(shuffleId, 0); } + /** + * @param shuffleId the shuffle id to query + * @return the num of map tasks for current shuffle with shuffle id. 
+ */ + @Override + public int getNumMaps(int shuffleId) { + return shuffleIdToNumMapTasks.getOrDefault(shuffleId, 0); + } + + @VisibleForTesting + public void addSuccessBlockIds(String taskId, Set blockIds) { + if (taskToSuccessBlockIds.get(taskId) == null) { + taskToSuccessBlockIds.put(taskId, Sets.newHashSet()); + } + taskToSuccessBlockIds.get(taskId).addAll(blockIds); + } + + @VisibleForTesting + public void addFailedBlockSendTracker( + String taskId, FailedBlockSendTracker failedBlockSendTracker) { + taskToFailedBlockSendTracker.putIfAbsent(taskId, failedBlockSendTracker); + } + + /** Create the shuffleWriteClient. */ + protected abstract ShuffleWriteClient createShuffleWriteClient(); + + /** + * Check whether the configuration is supported. + * + * @param sparkConf the sparkConf + */ + protected void checkSupported(SparkConf sparkConf) {} + /** * Creating the shuffleAssignmentInfo from the servers and partitionIds * @@ -931,7 +1272,7 @@ private Map> requestShuffleAssignment( try { ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments( - id.get(), + getAppId(), shuffleId, partitionNum, partitionNumPerRange, @@ -949,7 +1290,7 @@ private Map> requestShuffleAssignment( response = reassignmentHandler.apply(response); } registerShuffleServers( - id.get(), shuffleId, response.getServerToPartitionRanges(), getRemoteStorageInfo()); + getAppId(), shuffleId, response.getServerToPartitionRanges(), getRemoteStorageInfo()); return response.getPartitionToServers(); } catch (Throwable throwable) { throw new RssException("registerShuffle failed!", throwable); @@ -1131,4 +1472,95 @@ public Map sparkConfToMap(SparkConf sparkConf) { public ShuffleWriteClient getShuffleWriteClient() { return shuffleWriteClient; } + + protected synchronized void startHeartbeat() { + shuffleWriteClient.registerApplicationInfo(getAppId(), heartbeatTimeout, user); + if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false) && !heartbeatStarted) { + heartBeatScheduledExecutorService.scheduleAtFixedRate( + () -> { + try { + String appId = getAppId(); + shuffleWriteClient.sendAppHeartbeat(appId, heartbeatTimeout); + LOG.info("Finish send heartbeat to coordinator and servers"); + } catch (Exception e) { + LOG.warn("Fail to send heartbeat to coordinator and servers", e); + } + }, + heartbeatInterval / 2, + heartbeatInterval, + TimeUnit.MILLISECONDS); + heartbeatStarted = true; + } + } + + public void clearTaskMeta(String taskId) { + taskToSuccessBlockIds.remove(taskId); + taskToFailedBlockSendTracker.remove(taskId); + } + + @VisibleForTesting + protected void registerCoordinator() { + String coordinators = sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key()); + LOG.info("Start Registering coordinators {}", coordinators); + shuffleWriteClient.registerCoordinators( + coordinators, + this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX), + this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX)); + } + + public Set getFailedBlockIds(String taskId) { + FailedBlockSendTracker blockIdsFailedSendTracker = getBlockIdsFailedSendTracker(taskId); + if (blockIdsFailedSendTracker == null) { + return Collections.emptySet(); + } + return blockIdsFailedSendTracker.getFailedBlockIds(); + } + + public Set getSuccessBlockIds(String taskId) { + Set result = taskToSuccessBlockIds.get(taskId); + if (result == null) { + result = Collections.emptySet(); + } + return result; + } + + public FailedBlockSendTracker getBlockIdsFailedSendTracker(String taskId) { + return taskToFailedBlockSendTracker.get(taskId); 
+ } + + public boolean markFailedTask(String taskId) { + LOG.info("Mark the task: {} failed.", taskId); + failedTaskIds.add(taskId); + return true; + } + + public boolean isValidTask(String taskId) { + return !failedTaskIds.contains(taskId); + } + + @VisibleForTesting + public void setDataPusher(DataPusher dataPusher) { + this.dataPusher = dataPusher; + } + + public DataPusher getDataPusher() { + return dataPusher; + } + + @VisibleForTesting + public Map> getTaskToSuccessBlockIds() { + return taskToSuccessBlockIds; + } + + @VisibleForTesting + public Map getTaskToFailedBlockSendTracker() { + return taskToFailedBlockSendTracker; + } + + public CompletableFuture sendData(AddBlockEvent event) { + if (dataPusher != null && event != null) { + return dataPusher.send(event); + } + return new CompletableFuture<>(); + } } diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index d628272438..8e6b8dfca3 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -17,15 +17,10 @@ package org.apache.spark.shuffle; -import java.io.IOException; -import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; import scala.Option; import scala.Tuple2; @@ -34,7 +29,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Sets; -import org.apache.commons.lang3.StringUtils; import org.apache.hadoop.conf.Configuration; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; @@ -47,8 +41,6 @@ import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo; import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo; import org.apache.spark.shuffle.reader.RssShuffleReader; -import org.apache.spark.shuffle.writer.AddBlockEvent; -import org.apache.spark.shuffle.writer.DataPusher; import org.apache.spark.shuffle.writer.RssShuffleWriter; import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManagerId; @@ -56,205 +48,20 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.uniffle.client.impl.FailedBlockSendTracker; +import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.client.util.ClientUtils; import org.apache.uniffle.common.RemoteStorageInfo; import org.apache.uniffle.common.ShuffleServerInfo; -import org.apache.uniffle.common.config.RssClientConf; -import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.exception.RssFetchFailedException; -import org.apache.uniffle.common.rpc.GrpcServer; -import org.apache.uniffle.common.util.BlockIdLayout; -import org.apache.uniffle.common.util.JavaUtils; -import org.apache.uniffle.common.util.RssUtils; -import org.apache.uniffle.common.util.ThreadUtils; import org.apache.uniffle.shuffle.RssShuffleClientFactory; import org.apache.uniffle.shuffle.manager.RssShuffleManagerBase; -import org.apache.uniffle.shuffle.manager.ShuffleManagerGrpcService; -import org.apache.uniffle.shuffle.manager.ShuffleManagerServerFactory; - -import static org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED; 
-import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED; -import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED; -import static org.apache.uniffle.common.config.RssBaseConf.RPC_SERVER_PORT; -import static org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE; public class RssShuffleManager extends RssShuffleManagerBase { - private static final Logger LOG = LoggerFactory.getLogger(RssShuffleManager.class); - private final long heartbeatInterval; - private final long heartbeatTimeout; - private ScheduledExecutorService heartBeatScheduledExecutorService; - private Map> taskToSuccessBlockIds = JavaUtils.newConcurrentMap(); - private Map taskToFailedBlockSendTracker = - JavaUtils.newConcurrentMap(); - private final int dataReplica; - private final int dataReplicaWrite; - private final int dataReplicaRead; - private final boolean dataReplicaSkipEnabled; - private final int dataTransferPoolSize; - private final int dataCommitPoolSize; - private Set failedTaskIds = Sets.newConcurrentHashSet(); - private boolean heartbeatStarted = false; - private final int maxFailures; - private final boolean speculation; - private final BlockIdLayout blockIdLayout; - private final String user; - private final String uuid; - private DataPusher dataPusher; - private GrpcServer shuffleManagerServer; - private ShuffleManagerGrpcService service; public RssShuffleManager(SparkConf sparkConf, boolean isDriver) { - if (sparkConf.getBoolean("spark.sql.adaptive.enabled", false)) { - throw new IllegalArgumentException( - "Spark2 doesn't support AQE, spark.sql.adaptive.enabled should be false."); - } - this.sparkConf = sparkConf; - this.user = sparkConf.get("spark.rss.quota.user", "user"); - this.uuid = sparkConf.get("spark.rss.quota.uuid", Long.toString(System.currentTimeMillis())); - this.dynamicConfEnabled = sparkConf.get(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED); - - // fetch client conf and apply them if necessary - if (isDriver && this.dynamicConfEnabled) { - fetchAndApplyDynamicConf(sparkConf); - } - RssSparkShuffleUtils.validateRssClientConf(sparkConf); - - // configure block id layout - this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4); - this.speculation = sparkConf.getBoolean("spark.speculation", false); - RssConf rssConf = RssSparkConfig.toRssConf(sparkConf); - RssUtils.setExtraJavaProperties(rssConf); - // configureBlockIdLayout requires maxFailures and speculation to be initialized - configureBlockIdLayout(sparkConf, rssConf); - this.blockIdLayout = BlockIdLayout.from(rssConf); - - // set & check replica config - this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA); - this.dataReplicaWrite = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE); - this.dataReplicaRead = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ); - this.dataTransferPoolSize = sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE); - this.dataReplicaSkipEnabled = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED); - this.maxConcurrencyPerPartitionToWrite = - RssSparkConfig.toRssConf(sparkConf).get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE); - LOG.info( - "Check quorum config [{}:{}:{}:{}]", - dataReplica, - dataReplicaWrite, - dataReplicaRead, - dataReplicaSkipEnabled); - RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, dataReplicaRead); - - this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE); - this.heartbeatInterval = 
sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL); - this.heartbeatTimeout = - sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(), heartbeatInterval / 2); - int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX); - long retryIntervalMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX); - int heartBeatThreadNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM); - this.dataCommitPoolSize = sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE); - int unregisterThreadPoolSize = - sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE); - int unregisterTimeoutSec = sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC); - int unregisterRequestTimeoutSec = - sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC); - // External shuffle service is not supported when using remote shuffle service - sparkConf.set("spark.shuffle.service.enabled", "false"); - LOG.info("Disable external shuffle service in RssShuffleManager."); - // If we store shuffle data in distributed filesystem or in a disaggregated - // shuffle cluster, we don't need shuffle data locality - sparkConf.set("spark.shuffle.reduceLocality.enabled", "false"); - LOG.info("Disable shuffle data locality in RssShuffleManager."); - - this.shuffleIdToPartitionNum = JavaUtils.newConcurrentMap(); - this.shuffleIdToNumMapTasks = JavaUtils.newConcurrentMap(); - // stage retry for write/fetch failure - rssStageRetryForFetchFailureEnabled = - rssConf.get(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED); - rssStageRetryForWriteFailureEnabled = - rssConf.get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED); - if (rssStageRetryForFetchFailureEnabled || rssStageRetryForWriteFailureEnabled) { - rssStageRetryEnabled = true; - List logTips = new ArrayList<>(); - if (rssStageRetryForWriteFailureEnabled) { - logTips.add("write"); - } - if (rssStageRetryForWriteFailureEnabled) { - logTips.add("fetch"); - } - LOG.info( - "Activate the stage retry mechanism that will resubmit stage on {} failure", - StringUtils.join(logTips, "/")); - } - this.partitionReassignEnabled = rssConf.getBoolean(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED); - this.blockIdSelfManagedEnabled = rssConf.getBoolean(RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED); - this.shuffleManagerRpcServiceEnabled = - partitionReassignEnabled || rssStageRetryEnabled || blockIdSelfManagedEnabled; - if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false)) { - if (isDriver) { - heartBeatScheduledExecutorService = - ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat"); - if (shuffleManagerRpcServiceEnabled) { - LOG.info("stage resubmit is supported and enabled"); - // start shuffle manager server - rssConf.set(RPC_SERVER_PORT, 0); - ShuffleManagerServerFactory factory = new ShuffleManagerServerFactory(this, rssConf); - service = factory.getService(); - shuffleManagerServer = factory.getServer(service); - try { - shuffleManagerServer.start(); - // pass this as a spark.rss.shuffle.manager.grpc.port config, so it can be propagated to - // executor properly. 
- sparkConf.set( - RssSparkConfig.RSS_SHUFFLE_MANAGER_GRPC_PORT, shuffleManagerServer.getPort()); - } catch (Exception e) { - LOG.error("Failed to start shuffle manager server", e); - throw new RssException(e); - } - } - } - if (shuffleManagerRpcServiceEnabled) { - getOrCreateShuffleManagerClientSupplier(); - } - this.shuffleWriteClient = - RssShuffleClientFactory.getInstance() - .createShuffleWriteClient( - RssShuffleClientFactory.newWriteBuilder() - .blockIdSelfManagedEnabled(blockIdSelfManagedEnabled) - .managerClientSupplier(managerClientSupplier) - .clientType(clientType) - .retryMax(retryMax) - .retryIntervalMax(retryIntervalMax) - .heartBeatThreadNum(heartBeatThreadNum) - .replica(dataReplica) - .replicaWrite(dataReplicaWrite) - .replicaRead(dataReplicaRead) - .replicaSkipEnabled(dataReplicaSkipEnabled) - .dataTransferPoolSize(dataTransferPoolSize) - .dataCommitPoolSize(dataCommitPoolSize) - .unregisterThreadPoolSize(unregisterThreadPoolSize) - .unregisterTimeSec(unregisterTimeoutSec) - .unregisterRequestTimeSec(unregisterRequestTimeoutSec) - .rssConf(rssConf)); - registerCoordinator(); - - // for non-driver executor, start a thread for sending shuffle data to shuffle server - LOG.info("RSS data pusher is starting..."); - int poolSize = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE); - int keepAliveTime = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE); - this.dataPusher = - new DataPusher( - shuffleWriteClient, - taskToSuccessBlockIds, - taskToFailedBlockSendTracker, - failedTaskIds, - poolSize, - keepAliveTime); - } - this.shuffleHandleInfoManager = new ShuffleHandleInfoManager(); - this.rssStageResubmitManager = new RssStageResubmitManager(); + super(sparkConf, isDriver); } // This method is called in Spark driver side, @@ -380,42 +187,6 @@ public ShuffleHandle registerShuffle( return new RssShuffleHandle(shuffleId, appId, numMaps, dependency, hdlInfoBd); } - private void startHeartbeat() { - shuffleWriteClient.registerApplicationInfo(appId, heartbeatTimeout, user); - if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false) && !heartbeatStarted) { - heartBeatScheduledExecutorService.scheduleAtFixedRate( - () -> { - try { - shuffleWriteClient.sendAppHeartbeat(appId, heartbeatTimeout); - LOG.info("Finish send heartbeat to coordinator and servers"); - } catch (Exception e) { - LOG.warn("Fail to send heartbeat to coordinator and servers", e); - } - }, - heartbeatInterval / 2, - heartbeatInterval, - TimeUnit.MILLISECONDS); - heartbeatStarted = true; - } - } - - @VisibleForTesting - protected void registerCoordinator() { - String coordinators = sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key()); - LOG.info("Registering coordinators {}", coordinators); - shuffleWriteClient.registerCoordinators( - coordinators, - this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX), - this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX)); - } - - public CompletableFuture sendData(AddBlockEvent event) { - if (dataPusher != null && event != null) { - return dataPusher.send(event); - } - return new CompletableFuture<>(); - } - // This method is called in Spark executor, // getting information from Spark driver via the ShuffleHandle. @Override @@ -447,24 +218,6 @@ public ShuffleWriter getWriter( } } - /** - * Derives block id layout config from maximum number of allowed partitions. Computes the number - * of required bits for partition id and task attempt id and reserves remaining bits for sequence - * number. 
- * - * @param sparkConf Spark config providing max partitions - * @param rssConf Rss config to amend - */ - public void configureBlockIdLayout(SparkConf sparkConf, RssConf rssConf) { - configureBlockIdLayout(sparkConf, rssConf, maxFailures, speculation); - } - - @Override - public long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo) { - return getTaskAttemptIdForBlockId( - mapIndex, attemptNo, maxFailures, speculation, blockIdLayout.taskAttemptIdBits); - } - // This method is called in Spark executor, // getting information from Spark driver via the ShuffleHandle. @Override @@ -563,44 +316,6 @@ public ShuffleReader getReader( return null; } - @Override - public boolean unregisterShuffle(int shuffleId) { - try { - super.unregisterShuffle(shuffleId); - if (SparkEnv.get().executorId().equals("driver")) { - shuffleWriteClient.unregisterShuffle(appId, shuffleId); - shuffleIdToNumMapTasks.remove(shuffleId); - shuffleIdToPartitionNum.remove(shuffleId); - if (service != null) { - service.unregisterShuffle(shuffleId); - } - } - } catch (Exception e) { - LOG.warn("Errors on unregistering from remote shuffle-servers", e); - } - return true; - } - - @Override - public void stop() { - super.stop(); - if (heartBeatScheduledExecutorService != null) { - heartBeatScheduledExecutorService.shutdownNow(); - } - if (dataPusher != null) { - try { - dataPusher.close(); - } catch (IOException e) { - LOG.warn("Errors on closing data pusher", e); - } - } - if (shuffleWriteClient != null) { - // Unregister shuffle before closing shuffle write client. - shuffleWriteClient.unregisterShuffle(appId); - shuffleWriteClient.close(); - } - } - @Override public ShuffleBlockResolver shuffleBlockResolver() { throw new RssException("RssShuffleManager.shuffleBlockResolver is not implemented"); @@ -631,93 +346,17 @@ private Roaring64NavigableMap getExpectedTasks( return taskIdBitmap; } - public Set getFailedBlockIds(String taskId) { - FailedBlockSendTracker blockIdsFailedSendTracker = getBlockIdsFailedSendTracker(taskId); - if (blockIdsFailedSendTracker == null) { - return Collections.emptySet(); - } - return blockIdsFailedSendTracker.getFailedBlockIds(); - } - - public Set getSuccessBlockIds(String taskId) { - Set result = taskToSuccessBlockIds.get(taskId); - if (result == null) { - result = Collections.emptySet(); - } - return result; - } - - @VisibleForTesting - public void addSuccessBlockIds(String taskId, Set blockIds) { - if (taskToSuccessBlockIds.get(taskId) == null) { - taskToSuccessBlockIds.put(taskId, Sets.newHashSet()); - } - taskToSuccessBlockIds.get(taskId).addAll(blockIds); - } - - @VisibleForTesting - public void addFailedBlockSendTracker( - String taskId, FailedBlockSendTracker failedBlockSendTracker) { - taskToFailedBlockSendTracker.putIfAbsent(taskId, failedBlockSendTracker); - } - - public void clearTaskMeta(String taskId) { - taskToSuccessBlockIds.remove(taskId); - taskToFailedBlockSendTracker.remove(taskId); - } - - @VisibleForTesting - public SparkConf getSparkConf() { - return sparkConf; - } - @VisibleForTesting public void setAppId(String appId) { this.appId = appId; } - public boolean markFailedTask(String taskId) { - LOG.info("Mark the task: {} failed.", taskId); - failedTaskIds.add(taskId); - return true; - } - - public boolean isValidTask(String taskId) { - return !failedTaskIds.contains(taskId); - } - - public DataPusher getDataPusher() { - return dataPusher; - } - - public void setDataPusher(DataPusher dataPusher) { - this.dataPusher = dataPusher; - } - /** @return the unique spark id for 
rss shuffle */ @Override public String getAppId() { return appId; } - /** - * @param shuffleId the shuffleId to query - * @return the num of partitions(a.k.a reduce tasks) for shuffle with shuffle id. - */ - @Override - public int getPartitionNum(int shuffleId) { - return shuffleIdToPartitionNum.getOrDefault(shuffleId, 0); - } - - /** - * @param shuffleId the shuffle id to query - * @return the num of map tasks for current shuffle with shuffle id. - */ - @Override - public int getNumMaps(int shuffleId) { - return shuffleIdToNumMapTasks.getOrDefault(shuffleId, 0); - } - private Roaring64NavigableMap getShuffleResult( String clientType, Set shuffleServerInfoSet, @@ -740,10 +379,6 @@ private Roaring64NavigableMap getShuffleResult( } } - public FailedBlockSendTracker getBlockIdsFailedSendTracker(String taskId) { - return taskToFailedBlockSendTracker.get(taskId); - } - private ShuffleServerInfo assignShuffleServer(int shuffleId, String faultyShuffleServerId) { Set faultyServerIds = Sets.newHashSet(faultyShuffleServerId); faultyServerIds.addAll(rssStageResubmitManager.getServerIdBlackList()); @@ -754,4 +389,44 @@ private ShuffleServerInfo assignShuffleServer(int shuffleId, String faultyShuffl } return null; } + + @Override + protected ShuffleWriteClient createShuffleWriteClient() { + int unregisterThreadPoolSize = + sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE); + int unregisterTimeoutSec = sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC); + int unregisterRequestTimeoutSec = + sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC); + long retryIntervalMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX); + int heartBeatThreadNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM); + + final int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX); + return RssShuffleClientFactory.getInstance() + .createShuffleWriteClient( + RssShuffleClientFactory.newWriteBuilder() + .blockIdSelfManagedEnabled(blockIdSelfManagedEnabled) + .managerClientSupplier(managerClientSupplier) + .clientType(clientType) + .retryMax(retryMax) + .retryIntervalMax(retryIntervalMax) + .heartBeatThreadNum(heartBeatThreadNum) + .replica(dataReplica) + .replicaWrite(dataReplicaWrite) + .replicaRead(dataReplicaRead) + .replicaSkipEnabled(dataReplicaSkipEnabled) + .dataTransferPoolSize(dataTransferPoolSize) + .dataCommitPoolSize(dataCommitPoolSize) + .unregisterThreadPoolSize(unregisterThreadPoolSize) + .unregisterTimeSec(unregisterTimeoutSec) + .unregisterRequestTimeSec(unregisterRequestTimeoutSec) + .rssConf(rssConf)); + } + + @Override + protected void checkSupported(SparkConf sparkConf) { + if (sparkConf.getBoolean("spark.sql.adaptive.enabled", false)) { + throw new IllegalArgumentException( + "Spark2 doesn't support AQE, spark.sql.adaptive.enabled should be false."); + } + } } diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index 37692a05a8..5e2a941029 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -17,16 +17,10 @@ package org.apache.spark.shuffle; -import java.io.IOException; -import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; -import java.util.concurrent.CompletableFuture; 
-import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -38,7 +32,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Sets; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; -import org.apache.commons.lang3.StringUtils; import org.apache.hadoop.conf.Configuration; import org.apache.spark.MapOutputTracker; import org.apache.spark.ShuffleDependency; @@ -53,7 +46,6 @@ import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo; import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo; import org.apache.spark.shuffle.reader.RssShuffleReader; -import org.apache.spark.shuffle.writer.AddBlockEvent; import org.apache.spark.shuffle.writer.DataPusher; import org.apache.spark.shuffle.writer.RssShuffleWriter; import org.apache.spark.sql.internal.SQLConf; @@ -64,6 +56,7 @@ import org.slf4j.LoggerFactory; import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking; +import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.client.impl.FailedBlockSendTracker; import org.apache.uniffle.client.util.ClientUtils; import org.apache.uniffle.common.RemoteStorageInfo; @@ -73,235 +66,16 @@ import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.exception.RssFetchFailedException; -import org.apache.uniffle.common.rpc.GrpcServer; -import org.apache.uniffle.common.util.BlockIdLayout; -import org.apache.uniffle.common.util.JavaUtils; import org.apache.uniffle.common.util.RssUtils; -import org.apache.uniffle.common.util.ThreadUtils; import org.apache.uniffle.shuffle.RssShuffleClientFactory; import org.apache.uniffle.shuffle.manager.RssShuffleManagerBase; -import org.apache.uniffle.shuffle.manager.ShuffleManagerGrpcService; -import org.apache.uniffle.shuffle.manager.ShuffleManagerServerFactory; - -import static org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED; -import static org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM; -import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED; -import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED; -import static org.apache.uniffle.common.config.RssBaseConf.RPC_SERVER_PORT; -import static org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE; public class RssShuffleManager extends RssShuffleManagerBase { private static final Logger LOG = LoggerFactory.getLogger(RssShuffleManager.class); - private final long heartbeatInterval; - private final long heartbeatTimeout; - private final int dataReplica; - private final int dataReplicaWrite; - private final int dataReplicaRead; - private final boolean dataReplicaSkipEnabled; - private final int dataTransferPoolSize; - private final int dataCommitPoolSize; - private final Map> taskToSuccessBlockIds; - private final Map taskToFailedBlockSendTracker; - private ScheduledExecutorService heartBeatScheduledExecutorService; - private boolean heartbeatStarted = false; - private final BlockIdLayout blockIdLayout; - private final int maxFailures; - private final boolean speculation; - private String user; - private String uuid; - private Set failedTaskIds = Sets.newConcurrentHashSet(); - private DataPusher dataPusher; - private 
ShuffleManagerGrpcService service; - private GrpcServer shuffleManagerServer; public RssShuffleManager(SparkConf conf, boolean isDriver) { - this.sparkConf = conf; - boolean supportsRelocation = - Optional.ofNullable(SparkEnv.get()) - .map(env -> env.serializer().supportsRelocationOfSerializedObjects()) - .orElse(true); - if (!supportsRelocation) { - LOG.warn( - "RSSShuffleManager requires a serializer which supports relocations of serialized object. Please set " - + "spark.serializer to org.apache.spark.serializer.KryoSerializer instead"); - } - this.user = sparkConf.get("spark.rss.quota.user", "user"); - this.uuid = sparkConf.get("spark.rss.quota.uuid", Long.toString(System.currentTimeMillis())); - this.dynamicConfEnabled = sparkConf.get(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED); - - // fetch client conf and apply them if necessary - if (isDriver && this.dynamicConfEnabled) { - fetchAndApplyDynamicConf(sparkConf); - } - RssSparkShuffleUtils.validateRssClientConf(sparkConf); - - // set & check replica config - this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA); - this.dataReplicaWrite = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE); - this.dataReplicaRead = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ); - this.dataReplicaSkipEnabled = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED); - LOG.info( - "Check quorum config [" - + dataReplica - + ":" - + dataReplicaWrite - + ":" - + dataReplicaRead - + ":" - + dataReplicaSkipEnabled - + "]"); - RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, dataReplicaRead); - - this.heartbeatInterval = sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL); - this.heartbeatTimeout = - sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(), heartbeatInterval / 2); - final int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX); - this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE); + super(conf, isDriver); this.dataDistributionType = getDataDistributionType(sparkConf); - RssConf rssConf = RssSparkConfig.toRssConf(sparkConf); - RssUtils.setExtraJavaProperties(rssConf); - this.maxConcurrencyPerPartitionToWrite = rssConf.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE); - this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4); - this.speculation = sparkConf.getBoolean("spark.speculation", false); - // configureBlockIdLayout requires maxFailures and speculation to be initialized - configureBlockIdLayout(sparkConf, rssConf); - this.blockIdLayout = BlockIdLayout.from(rssConf); - this.dataTransferPoolSize = sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE); - this.dataCommitPoolSize = sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE); - // External shuffle service is not supported when using remote shuffle service - sparkConf.set("spark.shuffle.service.enabled", "false"); - sparkConf.set("spark.dynamicAllocation.shuffleTracking.enabled", "false"); - sparkConf.set(RssSparkConfig.RSS_ENABLED.key(), "true"); - LOG.info("Disable external shuffle service in RssShuffleManager."); - sparkConf.set("spark.sql.adaptive.localShuffleReader.enabled", "false"); - LOG.info("Disable local shuffle reader in RssShuffleManager."); - // If we store shuffle data in distributed filesystem or in a disaggregated - // shuffle cluster, we don't need shuffle data locality - sparkConf.set("spark.shuffle.reduceLocality.enabled", "false"); - LOG.info("Disable shuffle data locality in RssShuffleManager."); - taskToSuccessBlockIds = JavaUtils.newConcurrentMap(); - taskToFailedBlockSendTracker = 
JavaUtils.newConcurrentMap(); - this.shuffleIdToPartitionNum = JavaUtils.newConcurrentMap(); - this.shuffleIdToNumMapTasks = JavaUtils.newConcurrentMap(); - this.partitionReassignEnabled = rssConf.get(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED); - - // stage retry for write/fetch failure - rssStageRetryForFetchFailureEnabled = - rssConf.get(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED); - rssStageRetryForWriteFailureEnabled = - rssConf.get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED); - if (rssStageRetryForFetchFailureEnabled || rssStageRetryForWriteFailureEnabled) { - rssStageRetryEnabled = true; - List logTips = new ArrayList<>(); - if (rssStageRetryForWriteFailureEnabled) { - logTips.add("write"); - } - if (rssStageRetryForWriteFailureEnabled) { - logTips.add("fetch"); - } - LOG.info( - "Activate the stage retry mechanism that will resubmit stage on {} failure", - StringUtils.join(logTips, "/")); - } - - // The feature of partition reassign is exclusive with multiple replicas and stage retry. - if (partitionReassignEnabled) { - if (dataReplica > 1) { - throw new RssException( - "The feature of task partition reassign is incompatible with multiple replicas mechanism."); - } - } - - this.blockIdSelfManagedEnabled = rssConf.getBoolean(RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED); - this.shuffleManagerRpcServiceEnabled = - partitionReassignEnabled || rssStageRetryEnabled || blockIdSelfManagedEnabled; - if (isDriver) { - heartBeatScheduledExecutorService = - ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat"); - if (shuffleManagerRpcServiceEnabled) { - LOG.info("stage resubmit is supported and enabled"); - // start shuffle manager server - rssConf.set(RPC_SERVER_PORT, 0); - ShuffleManagerServerFactory factory = new ShuffleManagerServerFactory(this, rssConf); - service = factory.getService(); - shuffleManagerServer = factory.getServer(service); - try { - shuffleManagerServer.start(); - // pass this as a spark.rss.shuffle.manager.grpc.port config, so it can be propagated to - // executor properly. 
- sparkConf.set( - RssSparkConfig.RSS_SHUFFLE_MANAGER_GRPC_PORT, shuffleManagerServer.getPort()); - } catch (Exception e) { - LOG.error("Failed to start shuffle manager server", e); - throw new RssException(e); - } - } - } - if (shuffleManagerRpcServiceEnabled) { - getOrCreateShuffleManagerClientSupplier(); - } - int unregisterThreadPoolSize = - sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE); - int unregisterTimeoutSec = sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC); - int unregisterRequestTimeoutSec = - sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC); - long retryIntervalMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX); - int heartBeatThreadNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM); - shuffleWriteClient = - RssShuffleClientFactory.getInstance() - .createShuffleWriteClient( - RssShuffleClientFactory.newWriteBuilder() - .blockIdSelfManagedEnabled(blockIdSelfManagedEnabled) - .managerClientSupplier(managerClientSupplier) - .clientType(clientType) - .retryMax(retryMax) - .retryIntervalMax(retryIntervalMax) - .heartBeatThreadNum(heartBeatThreadNum) - .replica(dataReplica) - .replicaWrite(dataReplicaWrite) - .replicaRead(dataReplicaRead) - .replicaSkipEnabled(dataReplicaSkipEnabled) - .dataTransferPoolSize(dataTransferPoolSize) - .dataCommitPoolSize(dataCommitPoolSize) - .unregisterThreadPoolSize(unregisterThreadPoolSize) - .unregisterTimeSec(unregisterTimeoutSec) - .unregisterRequestTimeSec(unregisterRequestTimeoutSec) - .rssConf(rssConf)); - registerCoordinator(); - - LOG.info("Rss data pusher is starting..."); - int poolSize = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE); - int keepAliveTime = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE); - this.dataPusher = - new DataPusher( - shuffleWriteClient, - taskToSuccessBlockIds, - taskToFailedBlockSendTracker, - failedTaskIds, - poolSize, - keepAliveTime); - this.partitionReassignMaxServerNum = - rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM); - this.shuffleHandleInfoManager = new ShuffleHandleInfoManager(); - this.rssStageResubmitManager = new RssStageResubmitManager(); - } - - public CompletableFuture sendData(AddBlockEvent event) { - if (dataPusher != null && event != null) { - return dataPusher.send(event); - } - return new CompletableFuture<>(); - } - - @VisibleForTesting - protected static ShuffleDataDistributionType getDataDistributionType(SparkConf sparkConf) { - RssConf rssConf = RssSparkConfig.toRssConf(sparkConf); - if ((boolean) sparkConf.get(SQLConf.ADAPTIVE_EXECUTION_ENABLED()) - && !rssConf.containsKey(RssClientConf.DATA_DISTRIBUTION_TYPE.key())) { - return ShuffleDataDistributionType.LOCAL_ORDER; - } - - return rssConf.get(RssClientConf.DATA_DISTRIBUTION_TYPE); } // For testing only @@ -312,72 +86,7 @@ protected static ShuffleDataDistributionType getDataDistributionType(SparkConf s DataPusher dataPusher, Map> taskToSuccessBlockIds, Map taskToFailedBlockSendTracker) { - this.sparkConf = conf; - this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE); - RssConf rssConf = RssSparkConfig.toRssConf(sparkConf); - this.dataDistributionType = rssConf.get(RssClientConf.DATA_DISTRIBUTION_TYPE); - this.blockIdLayout = BlockIdLayout.from(rssConf); - this.maxConcurrencyPerPartitionToWrite = rssConf.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE); - this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4); - this.speculation = sparkConf.getBoolean("spark.speculation", 
false); - // configureBlockIdLayout requires maxFailures and speculation to be initialized - configureBlockIdLayout(sparkConf, rssConf); - this.heartbeatInterval = sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL); - this.heartbeatTimeout = - sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(), heartbeatInterval / 2); - this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA); - this.dataReplicaWrite = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE); - this.dataReplicaRead = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ); - this.dataReplicaSkipEnabled = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED); - LOG.info( - "Check quorum config [" - + dataReplica - + ":" - + dataReplicaWrite - + ":" - + dataReplicaRead - + ":" - + dataReplicaSkipEnabled - + "]"); - RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, dataReplicaRead); - - int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX); - long retryIntervalMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX); - int heartBeatThreadNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM); - this.dataTransferPoolSize = sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE); - this.dataCommitPoolSize = sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE); - int unregisterThreadPoolSize = - sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE); - int unregisterTimeoutSec = sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC); - int unregisterRequestTimeoutSec = - sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC); - shuffleWriteClient = - RssShuffleClientFactory.getInstance() - .createShuffleWriteClient( - RssShuffleClientFactory.getInstance() - .newWriteBuilder() - .clientType(clientType) - .retryMax(retryMax) - .retryIntervalMax(retryIntervalMax) - .heartBeatThreadNum(heartBeatThreadNum) - .replica(dataReplica) - .replicaWrite(dataReplicaWrite) - .replicaRead(dataReplicaRead) - .replicaSkipEnabled(dataReplicaSkipEnabled) - .dataTransferPoolSize(dataTransferPoolSize) - .dataCommitPoolSize(dataCommitPoolSize) - .unregisterThreadPoolSize(unregisterThreadPoolSize) - .unregisterTimeSec(unregisterTimeoutSec) - .unregisterRequestTimeSec(unregisterRequestTimeoutSec) - .rssConf(rssConf)); - this.taskToSuccessBlockIds = taskToSuccessBlockIds; - this.heartBeatScheduledExecutorService = null; - this.taskToFailedBlockSendTracker = taskToFailedBlockSendTracker; - this.dataPusher = dataPusher; - this.partitionReassignMaxServerNum = - rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM); - this.shuffleHandleInfoManager = new ShuffleHandleInfoManager(); - this.rssStageResubmitManager = new RssStageResubmitManager(); + super(conf, isDriver, dataPusher, taskToSuccessBlockIds, taskToFailedBlockSendTracker); } // This method is called in Spark driver side, @@ -527,17 +236,6 @@ public ShuffleWriter getWriter( context); } - @Override - public void configureBlockIdLayout(SparkConf sparkConf, RssConf rssConf) { - configureBlockIdLayout(sparkConf, rssConf, maxFailures, speculation); - } - - @Override - public long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo) { - return getTaskAttemptIdForBlockId( - mapIndex, attemptNo, maxFailures, speculation, blockIdLayout.taskAttemptIdBits); - } - public void setPusherAppId(RssShuffleHandle rssShuffleHandle) { // todo: this implement is tricky, we should refactor it if (id.get() == null) { @@ -845,127 +543,52 @@ private Roaring64NavigableMap getExpectedTasksByRange( return 
@@ -845,127 +543,52 @@ private Roaring64NavigableMap getExpectedTasksByRange(
     return taskIdBitmap;
   }
 
-  @Override
-  public boolean unregisterShuffle(int shuffleId) {
-    try {
-      super.unregisterShuffle(shuffleId);
-      if (SparkEnv.get().executorId().equals("driver")) {
-        shuffleWriteClient.unregisterShuffle(id.get(), shuffleId);
-        shuffleIdToPartitionNum.remove(shuffleId);
-        shuffleIdToNumMapTasks.remove(shuffleId);
-        if (service != null) {
-          service.unregisterShuffle(shuffleId);
-        }
-      }
-    } catch (Exception e) {
-      LOG.warn("Errors on unregistering from remote shuffle-servers", e);
-    }
-    return true;
-  }
-
   @Override
   public ShuffleBlockResolver shuffleBlockResolver() {
     throw new RssException("RssShuffleManager.shuffleBlockResolver is not implemented");
   }
 
   @Override
-  public void stop() {
-    super.stop();
-    if (heartBeatScheduledExecutorService != null) {
-      heartBeatScheduledExecutorService.shutdownNow();
-    }
-    if (shuffleWriteClient != null) {
-      // Unregister shuffle before closing shuffle write client.
-      shuffleWriteClient.unregisterShuffle(getAppId());
-      shuffleWriteClient.close();
-    }
-    if (dataPusher != null) {
-      try {
-        dataPusher.close();
-      } catch (IOException e) {
-        LOG.warn("Errors on closing data pusher", e);
-      }
-    }
-
-    if (shuffleManagerServer != null) {
-      try {
-        shuffleManagerServer.stop();
-      } catch (InterruptedException e) {
-        // ignore
-        LOG.info("shuffle manager server is interrupted during stop");
-      }
-    }
-  }
+  protected ShuffleWriteClient createShuffleWriteClient() {
+    int unregisterThreadPoolSize =
+        sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
+    int unregisterTimeoutSec = sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC);
+    int unregisterRequestTimeoutSec =
+        sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
+    long retryIntervalMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
+    int heartBeatThreadNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
 
-  public void clearTaskMeta(String taskId) {
-    taskToSuccessBlockIds.remove(taskId);
-    taskToFailedBlockSendTracker.remove(taskId);
+    final int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
+    return RssShuffleClientFactory.getInstance()
+        .createShuffleWriteClient(
+            RssShuffleClientFactory.newWriteBuilder()
+                .blockIdSelfManagedEnabled(blockIdSelfManagedEnabled)
+                .managerClientSupplier(managerClientSupplier)
+                .clientType(clientType)
+                .retryMax(retryMax)
+                .retryIntervalMax(retryIntervalMax)
+                .heartBeatThreadNum(heartBeatThreadNum)
+                .replica(dataReplica)
+                .replicaWrite(dataReplicaWrite)
+                .replicaRead(dataReplicaRead)
+                .replicaSkipEnabled(dataReplicaSkipEnabled)
+                .dataTransferPoolSize(dataTransferPoolSize)
+                .dataCommitPoolSize(dataCommitPoolSize)
+                .unregisterThreadPoolSize(unregisterThreadPoolSize)
+                .unregisterTimeSec(unregisterTimeoutSec)
+                .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
+                .rssConf(rssConf));
   }
 
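The removed `stop()` above (now inherited from the base class) encodes a teardown order that is easy to get wrong: the app must unregister from the shuffle servers while the write client is still open, and only then close the client, the pusher, and the local RPC server. A generic sketch of that pattern, with hypothetical `client`/`pusher`/`unregister` stand-ins rather than the real Uniffle types:

```java
import java.io.Closeable;
import java.io.IOException;

// Hypothetical stand-ins; the point is the ordering mirrored from the removed
// stop(): unregister while the client can still reach the servers, then
// release resources, continuing past individual close failures.
final class ShutdownSketch {
  static void stop(Closeable client, Closeable pusher, Runnable unregister) {
    try {
      unregister.run(); // must happen before client.close()
    } catch (RuntimeException e) {
      System.err.println("unregister failed: " + e);
    }
    for (Closeable c : new Closeable[] {client, pusher}) {
      try {
        if (c != null) {
          c.close();
        }
      } catch (IOException e) {
        // best-effort close: log and continue so later resources still close
        System.err.println("close failed: " + e);
      }
    }
  }
}
```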
   @VisibleForTesting
-  protected void registerCoordinator() {
-    String coordinators = sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
-    LOG.info("Start Registering coordinators {}", coordinators);
-    shuffleWriteClient.registerCoordinators(
-        coordinators,
-        this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX),
-        this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX));
-  }
-
-  private synchronized void startHeartbeat() {
-    shuffleWriteClient.registerApplicationInfo(id.get(), heartbeatTimeout, user);
-    if (!heartbeatStarted) {
-      heartBeatScheduledExecutorService.scheduleAtFixedRate(
-          () -> {
-            try {
-              String appId = id.get();
-              shuffleWriteClient.sendAppHeartbeat(appId, heartbeatTimeout);
-              LOG.info("Finish send heartbeat to coordinator and servers");
-            } catch (Exception e) {
-              LOG.warn("Fail to send heartbeat to coordinator and servers", e);
-            }
-          },
-          heartbeatInterval / 2,
-          heartbeatInterval,
-          TimeUnit.MILLISECONDS);
-      heartbeatStarted = true;
-    }
-  }
-
-  public Set<Long> getFailedBlockIds(String taskId) {
-    FailedBlockSendTracker blockIdsFailedSendTracker = getBlockIdsFailedSendTracker(taskId);
-    if (blockIdsFailedSendTracker == null) {
-      return Collections.emptySet();
-    }
-    return blockIdsFailedSendTracker.getFailedBlockIds();
-  }
-
-  public Set<Long> getSuccessBlockIds(String taskId) {
-    Set<Long> result = taskToSuccessBlockIds.get(taskId);
-    if (result == null) {
-      result = Collections.emptySet();
+  protected static ShuffleDataDistributionType getDataDistributionType(SparkConf sparkConf) {
+    RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
+    if ((boolean) sparkConf.get(SQLConf.ADAPTIVE_EXECUTION_ENABLED())
+        && !rssConf.containsKey(RssClientConf.DATA_DISTRIBUTION_TYPE.key())) {
+      return ShuffleDataDistributionType.LOCAL_ORDER;
     }
-    return result;
-  }
 
-  /** @return the unique spark id for rss shuffle */
-  @Override
-  public String getAppId() {
-    return id.get();
-  }
-
-  @Override
-  public int getPartitionNum(int shuffleId) {
-    return shuffleIdToPartitionNum.getOrDefault(shuffleId, 0);
-  }
-
-  /**
-   * @param shuffleId the shuffle id to query
-   * @return the num of map tasks for current shuffle with shuffle id.
-   */
-  @Override
-  public int getNumMaps(int shuffleId) {
-    return shuffleIdToNumMapTasks.getOrDefault(shuffleId, 0);
+    return rssConf.get(RssClientConf.DATA_DISTRIBUTION_TYPE);
   }
 
   static class ReadMetrics extends ShuffleReadMetrics {
@@ -1019,16 +642,6 @@ public void setAppId(String appId) {
     this.id = new AtomicReference<>(appId);
   }
 
-  public boolean markFailedTask(String taskId) {
-    LOG.info("Mark the task: {} failed.", taskId);
-    failedTaskIds.add(taskId);
-    return true;
-  }
-
-  public boolean isValidTask(String taskId) {
-    return !failedTaskIds.contains(taskId);
-  }
-
   private Roaring64NavigableMap getShuffleResultForMultiPart(
       String clientType,
       Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
@@ -1050,23 +663,4 @@ private Roaring64NavigableMap getShuffleResultForMultiPart(
           managerClientSupplier, e, sparkConf, appId, shuffleId, stageAttemptId, failedPartitions);
     }
   }
-
-  public FailedBlockSendTracker getBlockIdsFailedSendTracker(String taskId) {
-    return taskToFailedBlockSendTracker.get(taskId);
-  }
-
-  @VisibleForTesting
-  public void setDataPusher(DataPusher dataPusher) {
-    this.dataPusher = dataPusher;
-  }
-
-  @VisibleForTesting
-  public Map<String, Set<Long>> getTaskToSuccessBlockIds() {
-    return taskToSuccessBlockIds;
-  }
-
-  @VisibleForTesting
-  public Map<String, FailedBlockSendTracker> getTaskToFailedBlockSendTracker() {
-    return taskToFailedBlockSendTracker;
-  }
 }
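The `startHeartbeat()` removed in the last hunks is a good illustration of the driver-side heartbeat loop the base class now owns: the first beat fires after half an interval, subsequent beats once per interval, and every exception is caught inside the task. A minimal, self-contained sketch of that scheduling pattern (the interval value and the println stand-ins are illustrative, not Uniffle APIs):

```java
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

// Minimal sketch of the scheduling pattern in the removed startHeartbeat():
// first beat after heartbeatInterval / 2, then one beat per interval.
public final class HeartbeatSketch {
  public static void main(String[] args) throws InterruptedException {
    final long heartbeatInterval = 1000L; // ms; illustrative value only
    ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
    scheduler.scheduleAtFixedRate(
        () -> {
          try {
            System.out.println("send heartbeat"); // stands in for sendAppHeartbeat(...)
          } catch (Exception e) {
            // If this were allowed to propagate, scheduleAtFixedRate would
            // silently cancel all future runs.
            System.err.println("heartbeat failed: " + e);
          }
        },
        heartbeatInterval / 2,
        heartbeatInterval,
        TimeUnit.MILLISECONDS);
    Thread.sleep(3 * heartbeatInterval);
    scheduler.shutdownNow();
  }
}
```

The try/catch inside the lambda is the load-bearing detail: `ScheduledExecutorService.scheduleAtFixedRate` cancels all future executions if any run throws, which is why the original code logs and swallows every exception.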