
Spark DAGScheduler Source Code Walkthrough, Part 1: Stage Division

Original post by 幽鸿 · Published 2020-03-31 11:16:31

First, let's look at DAGScheduler's job entry point, runJob:

/**
 * Run an action job on the given RDD and pass all the results to the resultHandler function as
 * they arrive.
 * @param rdd target RDD to run tasks on
 * @param func a function to run on each partition of the RDD
 * @param partitions set of partitions to run on; some jobs may not want to compute on all
 *   partitions of the target RDD, e.g. for operations like first()
 * @param callSite where in the user program this job was called
 * @param resultHandler callback to pass each result to
 * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name
 *
 * @throws Exception when the job fails
 */
def runJob[T, U](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: CallSite,
    resultHandler: (Int, U) => Unit,
    properties: Properties): Unit = {
  val start = System.nanoTime
  // Key line: submit the job and obtain a JobWaiter
  val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
  // Note: Do not call Await.ready(future) because that calls `scala.concurrent.blocking`,
  // which causes concurrent SQL executions to fail if a fork-join pool is used. Note that
  // due to idiosyncrasies in Scala, `awaitPermission` is not actually used anywhere so it's
  // safe to pass in null here. For more detail, see SPARK-13747.
  val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
  waiter.completionFuture.ready(Duration.Inf)(awaitPermission)
  waiter.completionFuture.value.get match {
    case scala.util.Success(_) =>
      logInfo("Job %d finished: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
    case scala.util.Failure(exception) =>
      logInfo("Job %d failed: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
      // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
      val callerStackTrace = Thread.currentThread().getStackTrace.tail
      exception.setStackTrace(exception.getStackTrace ++ callerStackTrace)
      throw exception
  }
}
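To see this entry point in action, here is a minimal, hypothetical driver program (the app name and master URL are illustrative, not from the original post): any RDD action such as count() goes through SparkContext.runJob, which delegates to the DAGScheduler.runJob above and blocks on the JobWaiter's completionFuture.

import org.apache.spark.{SparkConf, SparkContext}

object RunJobDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("RunJobDemo").setMaster("local[2]"))
    val rdd = sc.parallelize(1 to 100, 4)
    // count() is an action: it calls SparkContext.runJob, which in turn calls
    // DAGScheduler.runJob and blocks until the JobWaiter's future completes.
    println(rdd.count())
    sc.stop()
  }
}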

runJob hands the job to the scheduler via submitJob:

/**
 * Submit an action job to the scheduler.
 *
 * @param rdd target RDD to run tasks on
 * @param func a function to run on each partition of the RDD
 * @param partitions set of partitions to run on; some jobs may not want to compute on all
 *   partitions of the target RDD, e.g. for operations like first()
 * @param callSite where in the user program this job was called
 * @param resultHandler callback to pass each result to
 * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name
 *
 * @return a JobWaiter object that can be used to block until the job finishes executing
 *         or can be used to cancel the job.
 *
 * @throws IllegalArgumentException when partitions ids are illegal
 */
def submitJob[T, U](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: CallSite,
    resultHandler: (Int, U) => Unit,
    properties: Properties): JobWaiter[U] = {
  // Check to make sure we are not launching a task on a partition that does not exist.
  val maxPartitions = rdd.partitions.length
  partitions.find(p => p >= maxPartitions || p < 0).foreach { p =>
    throw new IllegalArgumentException(
      "Attempting to access a non-existent partition: " + p + ". " +
        "Total number of partitions: " + maxPartitions)
  }

  val jobId = nextJobId.getAndIncrement()
  if (partitions.size == 0) {
    // Return immediately if the job is running 0 tasks
    return new JobWaiter[U](this, jobId, 0, resultHandler)
  }

  assert(partitions.size > 0)
  val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
  // Key lines: post a JobSubmitted event to the event loop
  val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
  eventProcessLoop.post(JobSubmitted(
    jobId, rdd, func2, partitions.toArray, callSite, waiter,
    SerializationUtils.clone(properties)))
  waiter
}
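The partition check at the top of submitJob is easy to observe. A hypothetical snippet (assuming a live SparkContext named sc, e.g. in spark-shell): asking for a partition index outside 0 until rdd.partitions.length fails fast, before any event is posted.

val rdd = sc.parallelize(1 to 10, 4)  // 4 partitions, valid indices 0..3
try {
  // This SparkContext.runJob overload takes an explicit partition list and
  // forwards it to DAGScheduler.submitJob, which validates it synchronously.
  sc.runJob(rdd, (it: Iterator[Int]) => it.sum, Seq(99))
} catch {
  case e: IllegalArgumentException =>
    // "Attempting to access a non-existent partition: 99. Total number of partitions: 4"
    println(e.getMessage)
}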

JobSubmitted here is implemented as a case class (a DAGSchedulerEvent): posting it to eventProcessLoop hands the job to the scheduler's event loop, which dispatches the event to handleJobSubmitted.
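For reference, a sketch of that case class as defined in DAGSchedulerEvent.scala (reconstructed from the Spark 2.x source, so treat the details as approximate; the field order matches the post call above):

private[scheduler] case class JobSubmitted(
    jobId: Int,
    finalRDD: RDD[_],
    func: (TaskContext, Iterator[_]) => _,
    partitions: Array[Int],
    callSite: CallSite,
    listener: JobListener,
    properties: Properties = null)
  extends DAGSchedulerEvent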

When the scheduler processes the submitted job, it creates a ResultStage, like so:

finalStage = newResultStage(finalRDD, func, partitions, jobId, callSite)

This creates a stage and places it into the scheduler's HashMaps for tracking:

stageIdToStage(id) = stage
updateJobIdStageIdMaps(jobId, stage)

Step two: create an ActiveJob from finalStage:

val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)

Step three: put the job into the scheduler's caches:

jobIdToActiveJob(jobId) = job
activeJobs += job
finalStage.setActiveJob(job)

Step four, and this is the key one: submit the stage:

submitStage(finalStage)

Here we go: next comes the heart of it all, the stage division itself:

/** Recursively compute parent stages, starting from the last stage. */
private def submitStage(stage: Stage) {
  val jobId = activeJobForStage(stage)
  if (jobId.isDefined) {
    logDebug("submitStage(" + stage + ")")
    if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
      val missing = getMissingParentStages(stage).sortBy(_.id)
      logDebug("missing: " + missing)
      // The recursion bottoms out at the first stage, which has no parent stages;
      // every other stage ends up parked in waitingStages
      if (missing.isEmpty) {
        logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
        // Submitting a stage creates a batch of tasks, one per partition
        submitMissingTasks(stage, jobId.get)
      } else {
        for (parent <- missing) {
          // The neat part: recursively submit each missing parent stage first
          submitStage(parent)
        }
        // Paired with the recursion above: parents are submitted first, and this
        // stage waits its turn in the waitingStages queue
        waitingStages += stage
      }
    }
  } else {
    abortStage(stage, "No active job for stage " + stage.id, None)
  }
}
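To make the recursion concrete, consider a hypothetical word count (again assuming a live SparkContext sc; the stage numbers are illustrative):

val counts = sc.parallelize(Seq("a", "b", "a"), 2)
  .map((_, 1))        // narrow dependency: stays in the same stage
  .reduceByKey(_ + _) // ShuffleDependency: forces a stage boundary
counts.collect()
// submitStage(ResultStage 1)
//   getMissingParentStages -> List(ShuffleMapStage 0)  // map side of the shuffle
//   submitStage(ShuffleMapStage 0)  // no missing parents: its tasks are submitted
//   waitingStages += ResultStage 1  // runs once ShuffleMapStage 0 completes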

Note that finding the parent stages is implemented with an explicit Stack rather than call-stack recursion:

// The core of stage division
private def getMissingParentStages(stage: Stage): List[Stage] = {
  val missing = new HashSet[Stage]
  val visited = new HashSet[RDD[_]]
  // An explicit Stack holds the RDDs still to be visited (instead of recursing)
  val waitingForVisit = new Stack[RDD[_]]
  def visit(rdd: RDD[_]) {
    if (!visited(rdd)) {
      visited += rdd
      val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil)
      if (rddHasUncachedPartitions) {
        // Walk this RDD's dependencies
        for (dep <- rdd.dependencies) {
          dep match {
            // Wide (shuffle) dependency: marks a stage boundary
            case shufDep: ShuffleDependency[_, _, _] =>
              val mapStage = getShuffleMapStage(shufDep, stage.firstJobId)
              if (!mapStage.isAvailable) {
                missing += mapStage
              }
            // Narrow dependency: remains inside the current stage
            case narrowDep: NarrowDependency[_] =>
              // Push the parent RDD back onto the stack and keep traversing
              waitingForVisit.push(narrowDep.rdd)
          }
        }
      }
    }
  }
  // Seed the stack with this stage's final RDD
  waitingForVisit.push(stage.rdd)
  while (waitingForVisit.nonEmpty) {
    // Pop the next RDD and process it with the nested visit method above
    visit(waitingForVisit.pop())
  }
  missing.toList
}
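The two dependency kinds that visit pattern-matches on can be inspected straight from the RDD API. A hypothetical spark-shell session:

val pairs = sc.parallelize(1 to 10, 2).map(x => (x % 2, x))
val summed = pairs.reduceByKey(_ + _)
println(pairs.dependencies)   // List(org.apache.spark.OneToOneDependency@...)  -- narrow
println(summed.dependencies)  // List(org.apache.spark.ShuffleDependency@...)   -- wide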

Pay particular attention to the wide-dependency case: inside getShuffleMapStage, the stage for the shuffle dependency is created:

val stage = newOrUsedShuffleStage(shuffleDep, firstJobId)

The main work here is creating the ShuffleMapStage that carries the shuffle dependency:
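A condensed sketch of newOrUsedShuffleStage, reconstructed from memory of the Spark 1.6/2.0-era source (the map-output recovery branch is summarized in a comment rather than spelled out):

private def newOrUsedShuffleStage(
    shuffleDep: ShuffleDependency[_, _, _],
    firstJobId: Int): ShuffleMapStage = {
  val rdd = shuffleDep.rdd
  val numTasks = rdd.partitions.length
  // One ShuffleMapStage per shuffle; its tasks write the map-side output.
  val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, firstJobId, rdd.creationSite)
  if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
    // The shuffle ran before (e.g. a retry): copy the map-output locations
    // that are still available into the new stage so those tasks are skipped.
  } else {
    // First time this shuffle is seen: register it with the MapOutputTracker,
    // which will later serve map-output locations to the reduce side.
    mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length)
  }
  stage
}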

The net effect is that the final stage is not a ShuffleMapStage (it becomes the ResultStage), while every stage upstream of it is a ShuffleMapStage. This is the whole stage-division rule: for a given stage, if every dependency of its last RDD is narrow, no new stage is created; if the stage depends on some RDD through a wide (shuffle) dependency, a new stage is created from that RDD and returned immediately as a missing parent.
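The resulting layout is easy to visualize with RDD.toDebugString, where each "+-" indentation step marks a shuffle, i.e. a stage boundary (output abbreviated from a hypothetical spark-shell run):

val counts = sc.parallelize(Seq("a", "b", "a"), 2).map((_, 1)).reduceByKey(_ + _)
println(counts.toDebugString)
// (2) ShuffledRDD[2] at reduceByKey ...              <- the final ResultStage
//  +-(2) MapPartitionsRDD[1] at map ...              <- ShuffleMapStage: everything
//     |  ParallelCollectionRDD[0] at parallelize ...    before the shuffle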

One more core point remains, task creation, but for length reasons that deserves an article of its own.

Original statement: This article was published on the Tencent Cloud Developer Community with the author's authorization and may not be reproduced without permission.

For infringement concerns, please contact cloudcommunity@tencent.com for removal.

