After spending a long stretch in RPC territory, it is time to step back and look at something else. Continuing along the thread of SparkEnv initialization, the next major component is the broadcast manager, BroadcastManager. This article covers the implementation of Spark's broadcast mechanism.
Broadcast variables are one of Spark's two kinds of shared variables (the other being accumulators). They are suited to data shared across stages on multiple nodes, especially large read-only input collections: shipping such a collection to each executor once, instead of with every task, improves efficiency.
BroadcastManager is initialized directly in SparkEnv, and its code is quite short, as shown below.
Code #11.1 - o.a.s.broadcast.BroadcastManager class
private[spark] class BroadcastManager(
    val isDriver: Boolean,
    conf: SparkConf,
    securityManager: SecurityManager)
  extends Logging {

  private var initialized = false
  private var broadcastFactory: BroadcastFactory = null

  initialize()

  private def initialize() {
    synchronized {
      if (!initialized) {
        broadcastFactory = new TorrentBroadcastFactory
        broadcastFactory.initialize(isDriver, conf, securityManager)
        initialized = true
      }
    }
  }

  def stop() {
    broadcastFactory.stop()
  }

  private val nextBroadcastId = new AtomicLong(0)

  private[broadcast] val cachedValues = {
    new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK)
  }

  def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
  }

  def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
    broadcastFactory.unbroadcast(id, removeFromDriver, blocking)
  }
}
BroadcastManager takes three constructor parameters: isDriver (whether this is the Driver node), conf (the associated SparkConf configuration), and securityManager (the associated SecurityManager). These are straightforward and need no further comment.
BroadcastManager has four member fields:

- initialized: whether the manager has been initialized;
- broadcastFactory: the broadcast factory; TorrentBroadcastFactory is its only implementation (the old HTTP-based one was removed in Spark 2.0);
- nextBroadcastId: an AtomicLong that hands out unique, monotonically increasing IDs for new broadcast variables;
- cachedValues: a ReferenceMap caching broadcast values, with hard references to keys and weak references to values, so cached values can be reclaimed by GC when memory is tight.
The initialize() method is also very simple. It checks whether the BroadcastManager has already been initialized; if not, it creates the broadcast factory TorrentBroadcastFactory, initializes it, and sets the initialized flag to true. Both the check and the assignment happen inside synchronized, so initialization runs at most once even under concurrent calls.
BroadcastManager exposes two methods: newBroadcast(), which creates a new broadcast variable, and unbroadcast(), which un-broadcasts an existing one. Both simply delegate to the methods of the same names in TorrentBroadcastFactory, so to understand the details of Spark's broadcast mechanism we have to read the TorrentBroadcastFactory sources.
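Before moving on, the lock-guarded, idempotent initialization and the atomic ID counter can be illustrated with a small language-neutral sketch (plain Python, not Spark code; the class and names are invented for illustration):

```python
import threading

class BroadcastManagerSketch:
    """Illustrative sketch only: mirrors BroadcastManager's lock-guarded,
    idempotent initialize() and its monotonically increasing broadcast IDs."""

    def __init__(self):
        self._lock = threading.Lock()
        self._initialized = False
        self._factory = None
        self._next_id = 0          # plays the role of the AtomicLong
        self._initialize()

    def _initialize(self):
        # Take the lock, then re-check the flag, so a second call is a no-op
        # even if several threads race here.
        with self._lock:
            if not self._initialized:
                self._factory = object()   # stands in for TorrentBroadcastFactory
                self._initialized = True

    def new_broadcast_id(self):
        # Equivalent of nextBroadcastId.getAndIncrement().
        with self._lock:
            bid = self._next_id
            self._next_id += 1
            return bid
```

Calling _initialize() again leaves the factory untouched, exactly as the initialized flag guarantees in the Scala version.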
Let's look at the TorrentBroadcastFactory.newBroadcast() method first.
Code #11.2 - o.a.s.broadcast.TorrentBroadcastFactory.newBroadcast() method
override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
  new TorrentBroadcast[T](value_, id)
}
As we can see, it simply (is it really that simple?) creates a TorrentBroadcast instance. This class is the true face of the "broadcast variable" we have been talking about all along, so let's study it in detail.
It has quite a few member fields.
Code #11.3 - member fields of the o.a.s.broadcast.TorrentBroadcast class
@transient private lazy val _value: T = readBroadcastBlock()
@transient private var compressionCodec: Option[CompressionCodec] = _
@transient private var blockSize: Int = _
private val broadcastId = BroadcastBlockId(id)
private val numBlocks: Int = writeBlocks(obj)
private var checksumEnabled: Boolean = false
private var checksums: Array[Int] = _
private def setConf(conf: SparkConf) {
  compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
    Some(CompressionCodec.createCodec(conf))
  } else {
    None
  }
  blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024
  checksumEnabled = conf.getBoolean("spark.broadcast.checksum", true)
}
setConf(SparkEnv.get.conf)
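The blockSize computation above converts a size string to kibibytes and then to bytes. A rough Python stand-in for getSizeAsKb (handling only the suffixes relevant here; the real SparkConf parser supports more) makes the arithmetic concrete:

```python
def size_as_kb(s: str) -> int:
    """Rough stand-in for SparkConf.getSizeAsKb: parse '4m', '512k', '1g'
    into kibibytes. A bare number is taken as KB, the method's default unit."""
    units = {"k": 1, "m": 1024, "g": 1024 * 1024}
    s = s.strip().lower()
    if s and s[-1] in units:
        return int(s[:-1]) * units[s[-1]]
    return int(s)

# setConf's computation: "4m" -> 4096 KB -> 4194304 bytes per piece
block_size = size_as_kb("4m") * 1024
```

So with the default spark.broadcast.blockSize of 4m, each piece of a broadcast variable is 4 MB.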
As the field definitions above show, constructing a TorrentBroadcast immediately calls writeBlocks() (to initialize numBlocks). Let's look at its code.
Code #11.4 - o.a.s.broadcast.TorrentBroadcast.writeBlocks() method
private def writeBlocks(value: T): Int = {
  import StorageLevel._
  val blockManager = SparkEnv.get.blockManager
  if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
    throw new SparkException(s"Failed to store $broadcastId in BlockManager")
  }
  val blocks =
    TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
  if (checksumEnabled) {
    checksums = new Array[Int](blocks.length)
  }
  blocks.zipWithIndex.foreach { case (block, i) =>
    if (checksumEnabled) {
      checksums(i) = calcChecksum(block)
    }
    val pieceId = BroadcastBlockId(id, "piece" + i)
    val bytes = new ChunkedByteBuffer(block.duplicate())
    if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
      throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
    }
  }
  blocks.length
}
This method involves the block manager, BlockManager, a foundational component of Spark's storage subsystem. We set it aside for now; it will get a very thorough analysis later. The execution logic of writeBlocks() is:

1. Store the whole broadcast value in the local BlockManager at level MEMORY_AND_DISK, with tellMaster = false so the master is not notified.
2. Call blockifyObject() to serialize the value and cut it into pieces of blockSize bytes (4 MB by default).
3. If checksums are enabled, compute a checksum for each piece via calcChecksum().
4. Store each piece in the local BlockManager at level MEMORY_AND_DISK_SER, this time with tellMaster = true, so other nodes can locate the pieces through the master.
5. Return the number of pieces.
The implementations of the blockifyObject() and calcChecksum() methods mentioned above are fairly simple, so they are not reproduced here.
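Their behavior can still be sketched. The following Python fragment (illustrative only; the real blockifyObject also serializes and optionally compresses the value first) cuts a byte payload into fixed-size pieces and checksums each one. TorrentBroadcast's calcChecksum is Adler-32 based, for which zlib.adler32 is the Python counterpart:

```python
import zlib

def blockify(data: bytes, block_size: int) -> list:
    """Cut the serialized payload into block_size pieces; the last piece
    may be shorter."""
    return [data[i:i + block_size] for i in range(0, len(data), block_size)]

def calc_checksum(block: bytes) -> int:
    """Adler-32 over the piece, mirroring TorrentBroadcast.calcChecksum."""
    return zlib.adler32(block)

payload = bytes(range(256)) * 40              # 10240 bytes of fake serialized data
blocks = blockify(payload, 4096)              # pieces of 4096, 4096 and 2048 bytes
checksums = [calc_checksum(b) for b in blocks]
```

A reader later recomputes the checksum of each fetched piece; any single-byte corruption changes the Adler-32 sum and is rejected.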
Now for the read path, starting with readBroadcastBlock(). Note from the field definitions that it is invoked lazily, the first time the broadcast value _value is accessed.
Code #11.5 - o.a.s.broadcast.TorrentBroadcast.readBroadcastBlock() method
private def readBroadcastBlock(): T = Utils.tryOrIOException {
  TorrentBroadcast.synchronized {
    val broadcastCache = SparkEnv.get.broadcastManager.cachedValues

    Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse {
      setConf(SparkEnv.get.conf)
      val blockManager = SparkEnv.get.blockManager
      blockManager.getLocalValues(broadcastId) match {
        case Some(blockResult) =>
          if (blockResult.data.hasNext) {
            val x = blockResult.data.next().asInstanceOf[T]
            releaseLock(broadcastId)

            if (x != null) {
              broadcastCache.put(broadcastId, x)
            }

            x
          } else {
            throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId")
          }
        case None =>
          logInfo("Started reading broadcast variable " + id)
          val startTimeMs = System.currentTimeMillis()
          val blocks = readBlocks()
          logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))

          try {
            val obj = TorrentBroadcast.unBlockifyObject[T](
              blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
            val storageLevel = StorageLevel.MEMORY_AND_DISK
            if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
              throw new SparkException(s"Failed to store $broadcastId in BlockManager")
            }

            if (obj != null) {
              broadcastCache.put(broadcastId, obj)
            }

            obj
          } finally {
            blocks.foreach(_.dispose())
          }
      }
    }
  }
}
Its execution logic is:

1. Look up broadcastId in the BroadcastManager's cachedValues cache; on a hit, return the cached value immediately.
2. On a miss, try blockManager.getLocalValues() to read the fully assembled value from the local BlockManager. If it is there, release the read lock on the block, put the value into the cache, and return it.
3. If the value is not available locally either, call readBlocks() to fetch the pieces, reassemble and deserialize them with unBlockifyObject(), store the result in the local BlockManager via putSingle(), put it into the cache, and finally dispose of the piece buffers.
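This three-level lookup can be condensed into a small sketch (plain Python; dicts stand in for cachedValues and the BlockManager, and a callback stands in for readBlocks() plus unBlockifyObject()):

```python
def read_broadcast_value(broadcast_id, cache, local_store, fetch_and_assemble):
    """Sketch of readBroadcastBlock()'s lookup order: process-local cache,
    then the local block store, then a piece-by-piece rebuild."""
    if broadcast_id in cache:                   # 1. cachedValues hit
        return cache[broadcast_id]
    if broadcast_id in local_store:             # 2. whole value already local
        value = local_store[broadcast_id]
        cache[broadcast_id] = value
        return value
    value = fetch_and_assemble(broadcast_id)    # 3. readBlocks + unBlockify
    local_store[broadcast_id] = value           # keep the assembled value locally
    cache[broadcast_id] = value
    return value
```

The expensive rebuild thus happens at most once per process; every later access of _value is served from cachedValues.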
The concrete implementation of the readBlocks() method is shown below.
Code #11.6 - o.a.s.broadcast.TorrentBroadcast.readBlocks() method
private def readBlocks(): Array[BlockData] = {
  val blocks = new Array[BlockData](numBlocks)
  val bm = SparkEnv.get.blockManager

  for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
    val pieceId = BroadcastBlockId(id, "piece" + pid)
    logDebug(s"Reading piece $pieceId of $broadcastId")
    bm.getLocalBytes(pieceId) match {
      case Some(block) =>
        blocks(pid) = block
        releaseLock(pieceId)
      case None =>
        bm.getRemoteBytes(pieceId) match {
          case Some(b) =>
            if (checksumEnabled) {
              val sum = calcChecksum(b.chunks(0))
              if (sum != checksums(pid)) {
                throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
                  s" $sum != ${checksums(pid)}")
              }
            }
            if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
              throw new SparkException(
                s"Failed to store $pieceId of $broadcastId in local BlockManager")
            }
            blocks(pid) = new ByteBufferBlockData(b, true)
          case None =>
            throw new SparkException(s"Failed to get $pieceId of $broadcastId")
        }
    }
  }
  blocks
}
The method first shuffles the piece indices of the broadcast data, so that concurrent readers do not all request the same piece from the same source, and then for each shuffled piece:

1. Try bm.getLocalBytes() to read the piece from the local BlockManager; on success, record it and release the read lock.
2. Otherwise fetch it from a remote node with bm.getRemoteBytes(). If checksums are enabled, verify the fetched piece against the checksum recorded at write time, and fail on a mismatch.
3. Store the remotely fetched piece in the local BlockManager with tellMaster = true, so this node itself becomes a source that other nodes can fetch from. This BitTorrent-like propagation is where the name TorrentBroadcast comes from.
4. If the piece can be obtained neither locally nor remotely, throw an exception.
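The per-piece logic can be sketched as follows (illustrative Python; a dict plays the local BlockManager, a callback plays getRemoteBytes, and Adler-32 again stands in for calcChecksum):

```python
import random
import zlib

def read_blocks(broadcast_id, num_blocks, checksums, local, fetch_remote):
    """Sketch of readBlocks(): visit the pieces in random order, prefer the
    local store, and re-publish every remotely fetched piece so this node
    becomes a new source for others -- the 'torrent' in TorrentBroadcast."""
    blocks = [None] * num_blocks
    for pid in random.sample(range(num_blocks), num_blocks):  # shuffled order
        piece_id = (broadcast_id, pid)
        if piece_id in local:                 # getLocalBytes succeeded
            blocks[pid] = local[piece_id]
            continue
        block = fetch_remote(piece_id)        # getRemoteBytes
        if block is None:
            raise RuntimeError(f"Failed to get piece {pid} of {broadcast_id}")
        if zlib.adler32(block) != checksums[pid]:
            raise RuntimeError(f"corrupt remote block {pid} of {broadcast_id}")
        local[piece_id] = block               # putBytes(..., tellMaster = true)
        blocks[pid] = block
    return blocks
```

Because every fetched piece is re-published locally, the set of sources grows as more executors read the variable, which keeps the driver from becoming the single bottleneck.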
Since prose alone can be hard to follow, the process is also described by a standard flow chart.
Figure #11.1 - Read flow of broadcast data
Starting from the initialization of the broadcast manager BroadcastManager, this article revealed the true nature of broadcast variables, the TorrentBroadcast class, and, with a first look at the block manager BlockManager, analyzed in detail how broadcast data is written and read.
— THE END —