[spark] DAGScheduler 提交stage源码解析

DAGScheduler在划分完Stage后([spark] DAGScheduler划分stage源码解析 ),将会通过submitStage(finalStage)来提交stage:

 /**
  * Submits a stage, first recursively submitting any parent stages that are still missing.
  *
  * If no active job needs the stage, the stage is aborted instead.
  */
 private def submitStage(stage: Stage): Unit = {
    activeJobForStage(stage) match {
      case None =>
        abortStage(stage, "No active job for stage " + stage.id, None)

      case Some(activeJobId) =>
        logDebug(s"submitStage($stage)")
        // Only act on stages that are not already waiting, running or failed.
        val alreadyTracked =
          waitingStages(stage) || runningStages(stage) || failedStages(stage)
        if (!alreadyTracked) {
          // Parent stages that have not finished computing. Per the original note, a parent
          // counts as finished when _numAvailableOutputs == numPartitions, i.e. the number of
          // available outputs equals the number of partitions.
          // Sorted by ascending stage id because earlier stages have smaller ids.
          val missingParents = getMissingParentStages(stage).sortBy(_.id)
          logDebug(s"missing: $missingParents")
          if (missingParents.nonEmpty) {
            // Some parents are not ready: submit them recursively, then park this stage
            // in waitingStages until the parents have finished.
            missingParents.foreach(submitStage)
            waitingStages += stage
          } else {
            // No dependencies, or every dependency is ready: submit the tasks.
            logInfo(s"Submitting $stage (${stage.rdd}), which has no missing parents")
            submitMissingTasks(stage, activeJobId)
          }
        }
    }
  }

看看getMissingParentStages的实现:

/**
 * Collects the parent stages of `stage` that are not yet available.
 *
 * Starting from `stage.rdd`, the RDD graph is walked through narrow dependencies; each
 * shuffle dependency met along the way yields a parent map stage, which is reported when
 * it is not available.
 */
private def getMissingParentStages(stage: Stage): List[Stage] = {
    // Parent stages that are not yet available.
    val unavailableParents = new HashSet[Stage]
    // RDDs that have already been examined.
    val seenRdds = new HashSet[RDD[_]]
    // RDDs still to be examined. An explicit stack is used instead of recursion to
    // prevent StackOverflowError caused by recursively visiting.
    val pendingRdds = new Stack[RDD[_]]

    pendingRdds.push(stage.rdd)
    while (pendingRdds.nonEmpty) {
      val rdd = pendingRdds.pop()
      // `add` returns true only the first time this RDD is seen.
      if (seenRdds.add(rdd)) {
        // If every partition is cached there is no need to look at the parents at all.
        val hasUncachedPartitions = getCacheLocs(rdd).contains(Nil)
        if (hasUncachedPartitions) {
          rdd.dependencies.foreach {
            case shuffleDep: ShuffleDependency[_, _, _] =>
              // Shuffle (wide) dependency: look up its map stage and report it if it
              // is not available yet.
              val parentMapStage = getShuffleMapStage(shuffleDep, stage.firstJobId)
              if (!parentMapStage.isAvailable) {
                unavailableParents += parentMapStage
              }
            case narrowDep: NarrowDependency[_] =>
              // Narrow dependency: keep walking up the lineage.
              pendingRdds.push(narrowDep.rdd)
          }
        }
      }
    }
    unavailableParents.toList
  }

若当前stage没有任何依赖或者所有依赖都已经准备好,则通过submitMissingTasks来提交task,看看具体实现:

/**
 * Builds and submits the tasks for the partitions of `stage` that still need computing.
 *
 * Steps visible in this excerpt (lines consisting only of dots mark code elided by the
 * article):
 *  1. Clear the stage's pending partitions and find the partitions to compute.
 *  2. Mark the stage as running.
 *  3. Compute the preferred locations for each partition to compute.
 *  4. Start a new stage attempt and post a SparkListenerStageSubmitted event.
 *  5. Serialize the stage's (rdd, shuffleDep) or (rdd, func) pair and broadcast it.
 *  6. Create one ShuffleMapTask or ResultTask per partition to compute.
 *  7. If any tasks were created, submit them to the task scheduler as a TaskSet.
 *
 * @param stage the stage whose missing tasks should be submitted
 * @param jobId id used to look up the job properties and passed on to the TaskSet
 */
private def submitMissingTasks(stage: Stage, jobId: Int) {
    stage.pendingPartitions.clear()
    // Determine which partitions still need to be computed.
    val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()
    val properties = jobIdToActiveJob(jobId).properties
    runningStages += stage  // Mark the stage as running.
    ......
    // Compute the preferred locations for each task, keyed by the id from partitionsToCompute.
    val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
      stage match {
        case s: ShuffleMapStage =>
          // The id is passed to getPreferredLocs directly as the partition.
          partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
        case s: ResultStage =>
          val job = s.activeJob.get
          partitionsToCompute.map { id =>
            // The id is first translated through s.partitions before the lookup.
            val p = s.partitions(id)
            (id, getPreferredLocs(stage.rdd, p))
          }.toMap
      }
    } catch {
        ...
    }

    // Register a new attempt for this stage and notify listeners that it was submitted.
    stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq)
    listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))

    // Serialized task payload; assigned inside the try block below.
    var taskBinary: Broadcast[Array[Byte]] = null
    try {
      // ShuffleMapStage serializes (rdd, shuffleDep); ResultStage serializes (rdd, func).
      val taskBinaryBytes: Array[Byte] = stage match {
        case stage: ShuffleMapStage =>
          JavaUtils.bufferToArray(
            closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
        case stage: ResultStage =>
          JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
      }
      taskBinary = sc.broadcast(taskBinaryBytes)
    } catch {
       ...
    }
    // Create one task per partition to compute; the task type depends on the stage type.
    val tasks: Seq[Task[_]] = try {
      stage match {
        case stage: ShuffleMapStage =>
          partitionsToCompute.map { id =>
            val locs = taskIdToLocations(id)
            val part = stage.rdd.partitions(id)
            new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
              taskBinary, part, locs, stage.latestInfo.taskMetrics, properties)
          }
        case stage: ResultStage =>
          val job = stage.activeJob.get
          partitionsToCompute.map { id =>
            // Translate the id through stage.partitions to find the RDD partition.
            val p: Int = stage.partitions(id)
            val part = stage.rdd.partitions(p)
            val locs = taskIdToLocations(id)
            new ResultTask(stage.id, stage.latestInfo.attemptId,
              taskBinary, part, locs, id, properties, stage.latestInfo.taskMetrics)
          }
      }
    } catch {
      ...
    }
    if (tasks.size > 0) {
      // Record the partitions now pending, then hand the tasks over as a single TaskSet.
      stage.pendingPartitions ++= tasks.map(_.partitionId)
      taskScheduler.submitTasks(new TaskSet(
        tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties))
      stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
    } else {  
        ...
    }
  }

下面将对每个步骤详细讲解: stage.findMissingPartitions获取需要计算的分区,不同的stage有不同的实现:

// ShuffleMapStage
// A partition still needs computing when outputLocs holds no entry for it;
// partitions that have been computed are recorded in outputLocs.
override def findMissingPartitions(): Seq[Int] = {
    val uncomputed = (0 until numPartitions).filter { partitionId =>
      outputLocs(partitionId).isEmpty
    }
    // Sanity check: the result must agree with the available-output counter.
    val expectedMissing = numPartitions - _numAvailableOutputs
    assert(uncomputed.size == expectedMissing,
      s"${uncomputed.size} missing, expected $expectedMissing")
    uncomputed
  }

// ResultStage
// A partition is missing until the active job has recorded it as finished.
 override def findMissingPartitions(): Seq[Int] = {
    val currentJob = activeJob.get
    val allPartitionIds = 0 until currentJob.numPartitions
    allPartitionIds.filterNot(partitionId => currentJob.finished(partitionId))
  }

taskIdToLocations获取task最佳计算位置,主要是通过getPreferredLocs方法实现,其核心逻辑在getPreferredLocsInternal中:

/**
 * Recursively computes the preferred locations for one partition of an RDD.
 *
 * Resolution order, as implemented below:
 *  1. If the (rdd, partition) pair was already visited, return Nil.
 *  2. If the partition has cache locations, return them.
 *  3. If the RDD reports placement preferences for the partition, return them wrapped
 *     as TaskLocation instances.
 *  4. Otherwise walk the narrow dependencies and return the first non-empty result
 *     found for a parent partition; non-narrow dependencies are ignored.
 *
 * @param rdd the RDD whose partition is being located
 * @param partition index of the partition within `rdd`
 * @param visited (rdd, partition) pairs already explored; this method adds to it
 * @return the preferred locations, or Nil when none were found
 */
private def getPreferredLocsInternal(
      rdd: RDD[_],
      partition: Int,
      visited: HashSet[(RDD[_], Int)]): Seq[TaskLocation] = {
    // If the partition has already been visited, no need to re-visit.
    // This avoids exponential path exploration.  SPARK-695
    if (!visited.add((rdd, partition))) {
      // Nil has already been returned for previously visited partitions.
      return Nil
    }
    // If the partition is cached, return the cache locations
    val cached = getCacheLocs(rdd)(partition)
    if (cached.nonEmpty) {
      return cached
    }
    // If the RDD has some placement preferences (as is the case for input RDDs), get those
    val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
    if (rddPrefs.nonEmpty) {
      return rddPrefs.map(TaskLocation(_))
    }
    // If the RDD has narrow dependencies, pick the first partition of the first narrow dependency
    // that has any placement preferences. Ideally we would choose based on transfer sizes,
    // but this will do for now.
    rdd.dependencies.foreach {
      case n: NarrowDependency[_] =>
        for (inPart <- n.getParents(partition)) {
          val locs = getPreferredLocsInternal(n.rdd, inPart, visited)
          if (locs != Nil) {
            // This `return` sits inside closures, so it is a non-local return that exits
            // getPreferredLocsInternal itself, not just the current iteration.
            return locs
          }
        }
      case _ =>
    }
    Nil
  }
  • getCacheLocs方法中,cacheLocs 维护着RDD的partitions 的 location信息,该信息是TaskLocation的实例。如果能从cacheLocs中获取到partition的location信息则直接返回;若获取不到:如果RDD的存储级别为空则返回Nil,并填入cacheLocs,否则会通过blockManagerMaster来获取持有该partition信息的 blockManager 并实例化ExecutorCacheTaskLocation放入cacheLocs中。
  • rdd.preferredLocations,该方法先尝试从checkpoint中获取partition信息,若未获取到再通过rdd的getPreferredLocations(split)方法获取,不同rdd有不同实现,如HadoopRDD即通过Hadoop InputSplit 来获取当前partition的位置。
  • 前两者都没有获取到时,则通过递归寻找parentRDD的partition的最佳位置信息。注意:只适用于窄依赖。

获取到task最佳位置后,会根据stage的类型将不同的信息序列化成二进制并广播到每个executor:如果是ShuffleMapStage,广播该stage的finalRDD和stage的shuffleDep;如果是ResultStage,广播该stage的finalRDD和stage.func。也就是说,task的实际执行逻辑会被序列化到taskBinary中,并broadcast到每个executor上。

 // Excerpt: serialize the stage's payload and publish it through sc.broadcast.
 // taskBinary starts as null and is assigned inside the try block.
 var taskBinary: Broadcast[Array[Byte]] = null
    try {
      // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
      // For ResultTask, serialize and broadcast (rdd, func).
      val taskBinaryBytes: Array[Byte] = stage match {
        case stage: ShuffleMapStage =>
          JavaUtils.bufferToArray(
            closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
        case stage: ResultStage =>
          JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
      }

      taskBinary = sc.broadcast(taskBinaryBytes)
    } catch {
      // No exception handlers are shown in this excerpt.
    }

根据不同的stage生成不同的类型task,每个partition对应一个task且每个task都包含目标partition的location信息,最终所有tasks将被包装成taskSet进行提交。

// Excerpt: build one task per partition to compute, then submit them as a TaskSet.
stage match {
        case stage: ShuffleMapStage =>
          // One ShuffleMapTask per partition id, using that id's preferred locations.
          partitionsToCompute.map { id =>
            val locs = taskIdToLocations(id)
            val part = stage.rdd.partitions(id)
            new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
              taskBinary, part, locs, stage.latestInfo.taskMetrics, properties)
          }
        case stage: ResultStage =>
          val job = stage.activeJob.get
          // One ResultTask per id; the id is translated through stage.partitions to
          // find the RDD partition, while the id itself is also passed to the task.
          partitionsToCompute.map { id =>
            val p: Int = stage.partitions(id)
            val part = stage.rdd.partitions(p)
            val locs = taskIdToLocations(id)
            new ResultTask(stage.id, stage.latestInfo.attemptId,
              taskBinary, part, locs, id, properties, stage.latestInfo.taskMetrics)
          }
      }

// Wrap all tasks in a TaskSet and hand it to the task scheduler.
taskScheduler.submitTasks(new TaskSet(
        tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties)) 
    // NOTE(review): this closing brace has no opening counterpart in this excerpt.
    }

至此,DAGScheduler已经完成对stage的划分并以taskSet的形式提交给TaskScheduler,接着由TaskScheduler来提交和管理tasks,后续将会推出。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏牛肉圆粉不加葱

[Spark源码剖析] DAGScheduler提交stage

DAGScheduler通过调用submitStage来提交stage,实现如下:

742
来自专栏微信公众号:Java团长

Java类加载器详解(下)

这个类中定义了一个加密和解密的算法,很简单的,就是将字节和oxff异或一下即可,而且这个算法是加密和解密的都可以用,很是神奇呀!

1503
来自专栏数据科学与人工智能

【Spark研究】Lambda表达式让Spark编程更容易

近日,Databricks官方网站发表了一篇博文,用示例说明了lambda表达式如何让Spark编程更容易。文章开头即指出,Spark的主要目标之一是使编写大数...

3205
来自专栏浪淘沙

KafKa 代码实现

1933
来自专栏菩提树下的杨过

Jboss EAP:native management API学习

上一节已经学习了CLI命令行来控制JBOSS,如果想在程序中以编码方式来控制JBOSS,可以参考下面的代码,实际上在前面的文章,用代码控制Jboss上的Data...

2079
来自专栏JavaEdge

Mybatis#BaseExecutor源码解析BaseExecutor源码解析

BaseExecutor是Executor的一个子类,是一个抽象类,实现接口Executor的部分方法,并提供了三个抽象方法

1186
来自专栏个人分享

Spark常用函数(源码阅读六)

  源码层面整理下我们常用的操作RDD数据处理与分析的函数,从而能更好的应用于工作中。

992
来自专栏小勇DW3

Mybatis使用动态代理实现拦截器功能

  拦截器顾名思义为拦截某个功能的一个武器,在众多框架中均有“拦截器”。这个Plugin有什么用呢?或者说拦截器有什么用呢?可以想想拦截器是怎么实现的。Plug...

4092
来自专栏java达人

Spring aop 的代理机制

Spring aop 是通过代理实现的,代理有静态代理,jdk动态代理和cglib动态代理,代理就像我们生活中的房产中介,你不直接与房主,银行接触,而是通过中介...

2079
来自专栏Java帮帮-微信公众号-技术文章全总结

Java面试系列17-编程题-读取服务器字符、实现序列化、计数器、1000阶乘、n出列问题等

一,Java的通信编程,编程题(或问答),用JAVA SOCKET编程,读服务器几个字符,再写入本地显示? Server端程序: package test;...

4698

扫码关注云+社区

领取腾讯云代金券