Spark Broadcast Explained

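Preface
Some large read-only datasets are needed by many RDD operations, for example an in-memory lookup table. Shared variables can improve performance for these. By default a copy of such a variable is kept for every task, which wastes resources, so lookups are usually done through a shared variable instead. I use this often in my own code but had never studied it closely. Reading the Spark RDD API source was a good opportunity to analyze broadcast in depth.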

The broadcast code also touches Spark's underlying storage code, such as BlockManager and BlockId.

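Introduction
Broadcast variables let the programmer keep one read-only copy of a variable on each machine, rather than shipping a copy with every task. This is an efficient way to give every node a copy of a large input dataset. Spark also tries to use efficient broadcast algorithms to reduce communication cost.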

Base classes

abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Logging {

This abstract class has two implementations, HttpBroadcast and TorrentBroadcast. They correspond to two network protocols: HTTP and BitTorrent.

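The BroadcastFactory interface initializes and creates the different kinds of broadcast variables. SparkContext uses the BroadcastFactory implementation specified by the user to instantiate broadcast variables for the whole Spark job.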

trait BroadcastFactory {

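It declares the following methods: initialize, newBroadcast, unbroadcast and stop.

The interface likewise has two implementations: HttpBroadcastFactory and TorrentBroadcastFactory.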

BroadcastManager handles the initialization, removal and management of broadcast variables:

private[spark] class BroadcastManager(
    val isDriver: Boolean,
    conf: SparkConf,
    securityManager: SecurityManager)
  extends Logging {

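Its methods are initialize, stop, newBroadcast and unbroadcast. Its fields are the initialized flag, the broadcastFactory instance and the nextBroadcastId counter. The full class is listed later.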


bitTorrent-like broadcast

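First, a brief introduction to the BitTorrent protocol.

BitTorrent is a content-distribution protocol developed by Bram Cohen. It combines an efficient software distribution system with P2P technology to share large files, such as a movie or a TV show. Every user also acts as a redistribution node that provides uploads. An ordinary download server serves each user who sends a download request itself; BitTorrent works differently. The distributor, or the holder of the file, sends it to one of the users, who forwards it to other users. Users keep exchanging the parts of the file they already have until everyone's download is complete. This lets the download server handle requests for many large files at once without consuming a lot of bandwidth.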

First, TorrentBroadcastFactory:

class TorrentBroadcastFactory extends BroadcastFactory {
  override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { }

  override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
    new TorrentBroadcast[T](value_, id)
  }

  override def stop() { }

  /**
   * Remove all persisted state associated with the torrent broadcast with the given ID.
   * @param removeFromDriver Whether to remove state from the driver.
   * @param blocking Whether to block until unbroadcasted
   */
  override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
    TorrentBroadcast.unpersist(id, removeFromDriver, blocking)
  }
}

 

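Its methods are initialize, newBroadcast, stop and unbroadcast.

Note that initialize and stop are empty and do nothing. newBroadcast simply constructs a TorrentBroadcast, and unbroadcast delegates to TorrentBroadcast.unpersist.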

TorrentBroadcast is the core class:

private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {

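It extends Broadcast and is private to the spark package.

Code structure: pay attention to the methods in the companion object TorrentBroadcast. These are blockifyObject, unBlockifyObject and unpersist, all of which are used below.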

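Now let us analyze the class in detail.

This class is a BitTorrent-like implementation of org.apache.spark.broadcast.Broadcast. The mechanism works as follows.

The driver splits the serialized object into many small pieces and stores them in its own BlockManager. Each executor first tries to get these pieces from its local BlockManager. If they are not there, it fetches them from the driver or from other executors. Once an executor has obtained a piece, it stores it in its own BlockManager, where other executors can fetch it.

This prevents the driver from becoming the system bottleneck when it has to send out multiple copies of the broadcast data (one per executor). That bottleneck is exactly what happens with the HTTP-based org.apache.spark.broadcast.HttpBroadcast.

On initialization, the TorrentBroadcast object reads SparkEnv.get.conf.

_value holds the broadcast's value on executors. It is produced by readBroadcastBlock, which reads the blocks stored on the driver and/or other executors. On the driver, the value is read lazily from the BlockManager only when it is actually needed.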
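The bottleneck argument can be made concrete with a toy simulation in plain Scala, with no Spark dependency. This is a simplified model under stated assumptions, not Spark's scheduler:

- node 0 stands for the driver and starts with every piece;
- in each round, every executor fetches one missing piece, in its own random order, from a randomly chosen node that already holds it;
- the names TorrentSim and simulate are hypothetical.

```scala
import scala.collection.mutable
import scala.util.Random

// Toy model, not Spark code: node 0 is the driver and starts with every piece.
// Each round, every executor fetches one missing piece (in its own random order)
// from a randomly chosen node that already holds that piece, then starts serving it.
object TorrentSim {
  /** Returns how many piece requests each node served; index 0 is the driver. */
  def simulate(numExecutors: Int, numPieces: Int, seed: Long): Array[Int] = {
    val rnd = new Random(seed)
    val holders = Array.fill(numPieces)(mutable.Set(0)) // piece -> nodes that hold it
    val served = new Array[Int](numExecutors + 1)
    val pending = Array.tabulate[List[Int]](numExecutors + 1) { n =>
      if (n == 0) Nil else rnd.shuffle((0 until numPieces).toList)
    }
    while (pending.exists(_.nonEmpty)) {
      for (exec <- 1 to numExecutors if pending(exec).nonEmpty) {
        val piece = pending(exec).head
        pending(exec) = pending(exec).tail
        val sources = holders(piece).toVector
        val src = sources(rnd.nextInt(sources.size))
        served(src) += 1
        holders(piece) += exec // this executor can now serve the piece to others
      }
    }
    served
  }
}
```

For example, TorrentSim.simulate(8, 16, 42L) returns per-node serve counts. The driver must serve each piece at least once. The remaining requests, however, spread over the executors instead of all hitting a single server.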

@transient private lazy val _value: T = readBroadcastBlock()

setConf:

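setConf reads two settings from the configuration. The first is whether broadcast data should be compressed (spark.broadcast.compress, default true). The second is the size of each piece (spark.broadcast.blockSize, default 4m).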

private def setConf(conf: SparkConf) {
  compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
    Some(CompressionCodec.createCodec(conf))
  } else {
    None
  }
  // Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided
  blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024
}
setConf(SparkEnv.get.conf)
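The comment about getSizeAsKb says that a bare number with no unit is treated as KiB, and the result is then multiplied by 1024 to get bytes. The sketch below illustrates that arithmetic. SizeConf, sizeAsKb and blockSizeBytes are hypothetical helpers, not Spark's parser, and only the suffixes k, m, g and t are handled.

```scala
// Hypothetical helper, not Spark's implementation: reads a size string "as KiB",
// where a bare number with no unit is interpreted as KiB (the compatibility rule).
object SizeConf {
  private val pattern = """(?i)^\s*(\d+)\s*([kmgt]?)\s*$""".r

  def sizeAsKb(s: String): Long = s match {
    case pattern(num, unit) =>
      val n = num.toLong
      unit.toLowerCase match {
        case "" | "k" => n // no unit: treated as KiB
        case "m"      => n * 1024L
        case "g"      => n * 1024L * 1024L
        case "t"      => n * 1024L * 1024L * 1024L
      }
    case _ => throw new NumberFormatException(s"Invalid size string: $s")
  }

  /** Block size in bytes, computed like setConf: the KiB value times 1024. */
  def blockSizeBytes(setting: Option[String]): Int =
    (sizeAsKb(setting.getOrElse("4m")) * 1024).toInt
}
```

With the default "4m" this yields 4096 KiB, that is 4194304 bytes per piece.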

writeBlocks:

/**
 * Divide the object into multiple blocks and put those blocks in the block manager.
 * @param value the object to divide
 * @return number of blocks this broadcast variable is divided into
 */
private def writeBlocks(value: T): Int = {
  // Store a copy of the broadcast variable in the driver so that tasks run on the driver
  // do not create a duplicate copy of the broadcast variable's value.
  SparkEnv.get.blockManager.putSingle(broadcastId, value, StorageLevel.MEMORY_AND_DISK,
    tellMaster = false)
  val blocks =
    TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
  blocks.zipWithIndex.foreach { case (block, i) =>
    SparkEnv.get.blockManager.putBytes(
      BroadcastBlockId(id, "piece" + i),
      block,
      StorageLevel.MEMORY_AND_DISK_SER,
      tellMaster = true)
  }
  blocks.length
}

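The first statement calls putSingle with four arguments: the broadcast ID, the value to store (obj), the storage level, and whether to report the block to the master.

This keeps a copy of the broadcast value on the driver, so tasks that run on the driver do not have to create a copy of their own.

As noted earlier, the class has a private companion object TorrentBroadcast. The second statement uses that object's blockifyObject method.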

def blockifyObject[T: ClassTag](
    obj: T,
    blockSize: Int,
    serializer: Serializer,
    compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
  val bos = new ByteArrayChunkOutputStream(blockSize)
  val out: OutputStream = compressionCodec.map(c => c.compressedOutputStream(bos)).getOrElse(bos)
  val ser = serializer.newInstance()
  val serOut = ser.serializeStream(out)
  serOut.writeObject[T](obj).close()
  bos.toArrays.map(ByteBuffer.wrap)
}

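Its parameters are the object to split, blockSize (4 MB by default), the serializer and the compression codec. The object is serialized, and optionally compressed, into a ByteArrayChunkOutputStream, which splits the bytes into chunks of at most blockSize. The chunks are returned as an Array[ByteBuffer].

Back in writeBlocks, the blocks are then iterated with zipWithIndex. putBytes stores each one in the BlockManager as bytes, under the ID BroadcastBlockId(id, "piece" + i) with storage level MEMORY_AND_DISK_SER. Because tellMaster = true, the driver knows where every piece lives.

Finally, writeBlocks returns the number of blocks, that is, how many pieces were written.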
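The blockify/unblockify idea can be sketched with JDK stand-ins. Java serialization replaces Spark's Serializer and GZIP replaces CompressionCodec; both are assumptions made to keep the example self-contained. The object name Blockify is hypothetical. The inverse is included so that the round trip can be checked.

```scala
import java.io._
import java.nio.ByteBuffer
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

// Sketch only: Java serialization and GZIP stand in for Spark's serializer and codec.
object Blockify {
  def blockify(obj: Serializable, blockSize: Int, compress: Boolean): Array[ByteBuffer] = {
    val bos = new ByteArrayOutputStream()
    val out: OutputStream = if (compress) new GZIPOutputStream(bos) else bos
    val oos = new ObjectOutputStream(out)
    oos.writeObject(obj)
    oos.close() // flushes and finishes the GZIP stream as well
    // Cut the serialized (possibly compressed) bytes into pieces of at most blockSize.
    bos.toByteArray.grouped(blockSize).map(chunk => ByteBuffer.wrap(chunk)).toArray
  }

  def unblockify[T](blocks: Array[ByteBuffer], compress: Boolean): T = {
    require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
    val streams = new java.util.Vector[InputStream]()
    blocks.foreach { b =>
      val dup = b.duplicate() // leave the caller's buffer position untouched
      val arr = new Array[Byte](dup.remaining())
      dup.get(arr)
      streams.add(new ByteArrayInputStream(arr))
    }
    // Concatenate the pieces in index order, then decompress and deserialize.
    val is: InputStream = new SequenceInputStream(streams.elements())
    val in: InputStream = if (compress) new GZIPInputStream(is) else is
    val ois = new ObjectInputStream(in)
    try ois.readObject().asInstanceOf[T] finally ois.close()
  }
}
```

Splitting happens after compression, as in blockifyObject, so each piece holds at most blockSize bytes of compressed data.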


readBlocks:

/** Fetch torrent blocks from the driver and/or other executors. */
private def readBlocks(): Array[ByteBuffer] = {
 // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
 // to the driver, so other executors can pull these chunks from this executor as well.
 val blocks = new Array[ByteBuffer](numBlocks)
 val bm = SparkEnv.get.blockManager

 for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
   val pieceId = BroadcastBlockId(id, "piece" + pid)
   logDebug(s"Reading piece $pieceId of $broadcastId")
   // First try getLocalBytes because there is a chance that previous attempts to fetch the
   // broadcast blocks have already fetched some of the blocks. In that case, some blocks
   // would be available locally (on this executor).
   def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId)
   def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block =>
     // If we found the block from remote executors/driver's BlockManager, put the block
     // in this executor's BlockManager.
     SparkEnv.get.blockManager.putBytes(
       pieceId,
       block,
       StorageLevel.MEMORY_AND_DISK_SER,
       tellMaster = true)
     block
   }
   val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse(
     throw new SparkException(s"Failed to get $pieceId of $broadcastId"))
   blocks(pid) = block
 }
 blocks
}

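readBlocks uses the BlockManager to fetch all pieces from the driver and/or executors, visiting them in random order (Random.shuffle). For each piece it first tries the local BlockManager. If the piece is not there, it fetches it from the driver or another executor. It then stores the piece in the current executor's BlockManager with tellMaster = true, so that other executors can fetch it from here. If neither lookup succeeds, a SparkException is thrown.

In short, the value is obtained in three steps. First, iterate over the piece IDs of the given broadcast ID. Next, fetch each block with the BlockManager's getLocalBytes and getRemoteBytes. Finally, decompress and deserialize the assembled blocks; this last step happens in readBroadcastBlock via unBlockifyObject.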
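The local-first, remote-second logic of readBlocks can be mirrored with a toy block manager, under these assumptions:

- ToyBlockManager and ReadBlocks are hypothetical names, not Spark's API;
- remote fetching is simulated by a function argument;
- piece IDs are plain strings modeled on BroadcastBlockId(id, "piece" + i).

```scala
import scala.collection.mutable
import scala.util.Random

// Hypothetical stand-in for a BlockManager: a local map plus a function that
// simulates fetching a block from the driver or another executor.
class ToyBlockManager(remote: String => Option[Array[Byte]]) {
  val local = mutable.Map.empty[String, Array[Byte]]
  var remoteFetches = 0

  def getLocalBytes(id: String): Option[Array[Byte]] = local.get(id)

  def getRemoteBytes(id: String): Option[Array[Byte]] = {
    val r = remote(id)
    if (r.isDefined) remoteFetches += 1
    r
  }
}

object ReadBlocks {
  /** Same shape as readBlocks: random piece order, local lookup first, then remote. */
  def readBlocks(bm: ToyBlockManager, broadcastId: Long, numBlocks: Int): Array[Array[Byte]] = {
    val blocks = new Array[Array[Byte]](numBlocks)
    for (pid <- Random.shuffle((0 until numBlocks).toList)) {
      val pieceId = s"broadcast_${broadcastId}_piece$pid"
      def getLocal = bm.getLocalBytes(pieceId)
      def getRemote = bm.getRemoteBytes(pieceId).map { block =>
        bm.local(pieceId) = block // cache locally so other executors could fetch it here
        block
      }
      // orElse is by-name, so the remote fetch only happens when the local lookup fails.
      blocks(pid) = getLocal.orElse(getRemote).getOrElse(
        throw new IllegalStateException(s"Failed to get $pieceId"))
    }
    blocks
  }
}
```

A piece already present locally is never fetched again. This is why an executor that retries a failed read only pulls the pieces it is still missing.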

readBroadcastBlock:

This is where the value of the broadcast is actually read:

private def readBroadcastBlock(): T = Utils.tryOrIOException {
  TorrentBroadcast.synchronized {
    setConf(SparkEnv.get.conf)
    SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
      case Some(x) =>
        x.asInstanceOf[T]

      case None =>
        logInfo("Started reading broadcast variable " + id)
        val startTimeMs = System.currentTimeMillis()
        val blocks = readBlocks()
        logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))

        val obj = TorrentBroadcast.unBlockifyObject[T](
          blocks, SparkEnv.get.serializer, compressionCodec)
        // Store the merged copy in BlockManager so other tasks on this executor don't
        // need to re-fetch it.
        SparkEnv.get.blockManager.putSingle(
          broadcastId, obj, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
        obj
    }
  }
}

def unBlockifyObject[T: ClassTag](
    blocks: Array[ByteBuffer],
    serializer: Serializer,
    compressionCodec: Option[CompressionCodec]): T = {
  require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
  val is = new SequenceInputStream(
    blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration)
  val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
  val ser = serializer.newInstance()
  val serIn = ser.deserializeStream(in)
  val obj = serIn.readObject[T]()
  serIn.close()
  obj
}

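This mirrors the write path, so I will not go through it line by line. readBroadcastBlock first looks for the whole object under broadcastId in the local BlockManager. Only if it is missing does it do three things: call readBlocks, decompress and deserialize the pieces with unBlockifyObject, and store the merged object locally with putSingle. The key difference from writing is that this read actually decompresses and deserializes the broadcast value. In user code it corresponds to the broadcast variable's value method, and evaluating the lazy _value mentioned above is what triggers it.

Removing a broadcast comes in a full and a partial form; the difference is whether the data on the driver is removed too. Broadcast.unpersist keeps the driver's copy, while Broadcast.destroy removes it as well. Both end up in TorrentBroadcast.unpersist, with removeFromDriver set to false and true respectively.

TorrentBroadcastFactory, introduced at the start, mainly handles initializing, stopping and instantiating TorrentBroadcast. The factory itself is created by BroadcastManager, which instantiates whichever factory is configured, HTTP or torrent.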
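The interplay between the @transient lazy val, the synchronized block and the executor-wide cache can be modeled without Spark. In the sketch below:

- ExecutorCache plays the role of the executor's BlockManager;
- assemble stands in for readBlocks plus unBlockifyObject;
- ToyBroadcast and ExecutorCache are hypothetical names.

```scala
import scala.collection.mutable

// Hypothetical model of one executor JVM: a shared cache of fully assembled values.
object ExecutorCache {
  val store = mutable.Map.empty[Long, Any]
  var assemblies = 0
}

class ToyBroadcast[T](val id: Long, assemble: () => T) extends Serializable {
  // Evaluated at most once per handle, and only when value is first used.
  @transient private lazy val _value: T = readBroadcastBlock()

  def value: T = _value

  private def readBroadcastBlock(): T = ExecutorCache.synchronized {
    ExecutorCache.store.get(id) match {
      case Some(x) => x.asInstanceOf[T] // another task already assembled it
      case None =>
        ExecutorCache.assemblies += 1
        val obj = assemble()           // stands in for readBlocks + unBlockifyObject
        ExecutorCache.store(id) = obj  // like putSingle(broadcastId, obj, ...)
        obj
    }
  }
}
```

Two handles with the same ID model two tasks on one executor. Only the first triggers the assembly; the second finds the merged copy in the cache.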
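The difference between the partial and full removal can be shown with a toy cluster of block-ID sets. ToyCluster and removeBroadcast here are hypothetical and only model the removeFromDriver flag; they are not Spark's BlockManagerMaster. The sketch assumes Scala 2.13 or later, for filterInPlace.

```scala
import scala.collection.mutable

// Hypothetical model: one set of block IDs per node. unpersist keeps the driver's
// blocks; destroy removes them everywhere.
class ToyCluster(numExecutors: Int) {
  val driver = mutable.Set.empty[String]
  val executors = Vector.fill(numExecutors)(mutable.Set.empty[String])

  def removeBroadcast(id: Long, removeFromDriver: Boolean): Unit = {
    val prefix = s"broadcast_$id"
    // Match "broadcast_1" and "broadcast_1_piece0", but not "broadcast_10".
    def belongs(b: String) = b == prefix || b.startsWith(prefix + "_")
    executors.foreach(_.filterInPlace(b => !belongs(b)))
    if (removeFromDriver) driver.filterInPlace(b => !belongs(b))
  }

  def unpersist(id: Long): Unit = removeBroadcast(id, removeFromDriver = false)
  def destroy(id: Long): Unit = removeBroadcast(id, removeFromDriver = true)
}
```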

BroadcastManager:

private[spark] class BroadcastManager(
    val isDriver: Boolean,
    conf: SparkConf,
    securityManager: SecurityManager)
  extends Logging {

  private var initialized = false
  private var broadcastFactory: BroadcastFactory = null
  initialize()
  // Called by SparkContext or Executor before using Broadcast
  private def initialize() {
    synchronized {
      if (!initialized) {
        val broadcastFactoryClass =
          conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")

        broadcastFactory =
          Utils.classForName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]

        // Initialize appropriate BroadcastFactory and BroadcastObject
        broadcastFactory.initialize(isDriver, conf, securityManager)

        initialized = true
      }
    }
  }

  def stop() {
    broadcastFactory.stop()
  }

  private val nextBroadcastId = new AtomicLong(0)

  def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
  }

  def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
    broadcastFactory.unbroadcast(id, removeFromDriver, blocking)
  }
}

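BroadcastManager instantiates the BroadcastFactory class named in the configuration (spark.broadcast.factory). For performance reasons the default is TorrentBroadcastFactory.

Its functions cover initializing the broadcast environment, creating new broadcast instances (with IDs taken from nextBroadcastId), and stopping and removing broadcasts.

BroadcastManager is instantiated in SparkEnv.scala: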
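The selection logic can be sketched in a self-contained way: pick a factory by a configuration key with a default, and hand out IDs from an AtomicLong. Spark itself instantiates the named class reflectively (Utils.classForName). The registry map below is a substitute chosen so that the example needs no Spark classes. The Toy* names are hypothetical.

```scala
import java.util.concurrent.atomic.AtomicLong

// Hypothetical sketch: a registry map replaces Spark's reflective instantiation.
trait ToyFactory { def newBroadcast(value: Any, id: Long): String }

class ToyTorrentFactory extends ToyFactory {
  def newBroadcast(value: Any, id: Long): String = s"TorrentBroadcast($id)"
}

class ToyHttpFactory extends ToyFactory {
  def newBroadcast(value: Any, id: Long): String = s"HttpBroadcast($id)"
}

class ToyBroadcastManager(conf: Map[String, String]) {
  private val registry: Map[String, () => ToyFactory] = Map(
    "org.apache.spark.broadcast.TorrentBroadcastFactory" -> (() => new ToyTorrentFactory),
    "org.apache.spark.broadcast.HttpBroadcastFactory" -> (() => new ToyHttpFactory))

  // The torrent factory is the default when the key is absent.
  private val factory: ToyFactory = registry(
    conf.getOrElse("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory"))()

  // IDs are unique and increasing, even if several threads create broadcasts at once.
  private val nextBroadcastId = new AtomicLong(0)

  def newBroadcast(value: Any): String =
    factory.newBroadcast(value, nextBroadcastId.getAndIncrement())
}
```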

val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)

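SparkEnv holds all the runtime environment objects needed by a running Spark instance (master or worker). These include the serializer, the Akka actor system, the BlockManager and the map output tracker. Spark code currently reaches SparkEnv through a global variable, so all threads can access the same SparkEnv. Once a SparkContext has been created, it can be accessed via SparkEnv.get.

SparkContext:

Each individual broadcast is created in SparkContext.scala: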
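The "global variable" access pattern can be sketched as a process-wide holder: one thread sets it, and every other thread reads it through get. ToyEnv is a hypothetical, much-reduced stand-in, and @volatile is used so that a value written by one thread is visible to readers on other threads.

```scala
// Hypothetical sketch of a global environment holder, not Spark's SparkEnv.
final case class ToyEnv(serializerName: String, isDriver: Boolean)

object ToyEnv {
  @volatile private var env: ToyEnv = null

  def set(e: ToyEnv): Unit = { env = e }
  def get: ToyEnv = env
}
```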

/**
 * Broadcast a read-only variable to the cluster, returning a
 * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
 * The variable will be sent to each cluster only once.
 */
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
  assertNotStopped()
  if (classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass)) {
    // This is a warning instead of an exception in order to avoid breaking user programs that
    // might have created RDD broadcast variables but not used them:
    logWarning("Can not directly broadcast RDDs; instead, call collect() and "
      + "broadcast the result (see SPARK-5063)")
  }
  val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
  val callSite = getCallSite
  logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
  cleaner.foreach(_.registerBroadcastForCleanup(bc))
  bc
}

This is the entry point we call from application code, for example:

val bcMiddleTime = sc.broadcast(mapMiddleTime)

Here mapMiddleTime is the value we want to broadcast.

httpBroadcast
Below is a brief analysis of httpBroadcast.

HttpBroadcastFactory is similar to TorrentBroadcastFactory. The difference is that for HTTP broadcast, initialize and stop are actually implemented.

The HttpBroadcast class:

/**
 * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses HTTP server
 * as a broadcast mechanism. The first time a HTTP broadcast variable (sent as part of a
 * task) is deserialized in the executor, the broadcasted data is fetched from the driver
 * (through a HTTP server running at the driver) and stored in the BlockManager of the
 * executor to speed up future accesses.
 */
private[spark] class HttpBroadcast[T: ClassTag](
    @transient var value_ : T, isLocal: Boolean, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {

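httpBroadcast implements broadcast over HTTP. A broadcast variable is shipped to an executor as part of a task. The first time it is deserialized there, the executor fetches the broadcast data from the HTTP server running on the driver, then caches it in its own BlockManager to speed up later accesses.

Code structure: the value is first stored synchronously in the driver's BlockManager.

Then, when running on a cluster (not in local mode), the write function of the HttpBroadcast singleton is called: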
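The "fetch on first deserialization" behavior relies on a @transient field plus a custom readObject hook. The sketch below reproduces that trick with plain Java serialization, under these assumptions:

- DriverServer stands in for the driver's HTTP server, and LocalStore for the executor's BlockManager;
- ToyHttpBroadcast, DriverServer, LocalStore and Ship are all hypothetical names.

```scala
import java.io._
import scala.collection.mutable

// Hypothetical stand-ins: the driver's file server and an executor's block store.
object DriverServer { val files = mutable.Map.empty[Long, String]; var requests = 0 }
object LocalStore { val blocks = mutable.Map.empty[Long, String] }

final class ToyHttpBroadcast(@transient var value: String, val id: Long) extends Serializable {
  // Called by Java serialization when the object is deserialized on the "executor".
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject() // restores id; the transient value is still null
    LocalStore.synchronized {
      value = LocalStore.blocks.getOrElse(id, {
        DriverServer.requests += 1
        val v = DriverServer.files(id) // stands in for an HTTP fetch from the driver
        LocalStore.blocks(id) = v      // cache for later tasks on this executor
        v
      })
    }
  }
}

object Ship {
  /** Serialize then deserialize, as happens when a task is sent to an executor. */
  def roundTrip[A](a: A): A = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(a)
    oos.close()
    val ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
    try ois.readObject().asInstanceOf[A] finally ois.close()
  }
}
```

The value itself never travels with the task. Only the first deserialization per executor goes to the "server"; later ones hit the local cache.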

HttpBroadcast.synchronized {
  SparkEnv.get.blockManager.putSingle(
    blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
}

if (!isLocal) {
  HttpBroadcast.write(id, value_)
}

The HttpBroadcast singleton looks like this:

private[broadcast] object HttpBroadcast extends Logging {
  private var initialized = false
  private var broadcastDir: File = null
  private var compress: Boolean = false
  private var bufferSize: Int = 65536
  private var serverUri: String = null
  private var server: HttpServer = null
  private var securityManager: SecurityManager = null

  // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist
  private val files = new TimeStampedHashSet[File]
  private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt
  private var compressionCodec: CompressionCodec = null
  private var cleaner: MetadataCleaner = null

  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
    synchronized {
      if (!initialized) {
        bufferSize = conf.getInt("spark.buffer.size", 65536)
        compress = conf.getBoolean("spark.broadcast.compress", true)
        securityManager = securityMgr
        if (isDriver) {
          createServer(conf)
          conf.set("spark.httpBroadcast.uri", serverUri)
        }
        serverUri = conf.get("spark.httpBroadcast.uri")
        cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf)
        compressionCodec = CompressionCodec.createCodec(conf)
        initialized = true
      }
    }
  }

  def stop() {
    synchronized {
      if (server != null) {
        server.stop()
        server = null
      }
      if (cleaner != null) {
        cleaner.cancel()
        cleaner = null
      }
      compressionCodec = null
      initialized = false
    }
  }

  private def createServer(conf: SparkConf) {
    broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf), "broadcast")
    val broadcastPort = conf.getInt("spark.broadcast.port", 0)
    server =
      new HttpServer(conf, broadcastDir, securityManager, broadcastPort, "HTTP broadcast server")
    server.start()
    serverUri = server.uri
    logInfo("Broadcast server started at " + serverUri)
  }

  def getFile(id: Long): File = new File(broadcastDir, BroadcastBlockId(id).name)

  private def write(id: Long, value: Any) {
    val file = getFile(id)
    val fileOutputStream = new FileOutputStream(file)
    Utils.tryWithSafeFinally {
      val out: OutputStream = {
        if (compress) {
          compressionCodec.compressedOutputStream(fileOutputStream)
        } else {
          new BufferedOutputStream(fileOutputStream, bufferSize)
        }
      }
      val ser = SparkEnv.get.serializer.newInstance()
      val serOut = ser.serializeStream(out)
      Utils.tryWithSafeFinally {
        serOut.writeObject(value)
      } {
        serOut.close()
      }
      files += file
    } {
      fileOutputStream.close()
    }
  }

  
  /**
   * Remove all persisted blocks associated with this HTTP broadcast on the executors.
   * If removeFromDriver is true, also remove these persisted blocks on the driver
   * and delete the associated broadcast file.
   */
  def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = synchronized {
    SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
    if (removeFromDriver) {
      val file = getFile(id)
      files.remove(file)
      deleteBroadcastFile(file)
    }
  }

  /**
   * Periodically clean up old broadcasts by removing the associated map entries and
   * deleting the associated files.
   */
  private def cleanup(cleanupTime: Long) {
    val iterator = files.internalMap.entrySet().iterator()
    while(iterator.hasNext) {
      val entry = iterator.next()
      val (file, time) = (entry.getKey, entry.getValue)
      if (time < cleanupTime) {
        iterator.remove()
        deleteBroadcastFile(file)
      }
    }
  }

  private def deleteBroadcastFile(file: File) {
    try {
      if (file.exists) {
        if (file.delete()) {
          logInfo("Deleted broadcast file: %s".format(file))
        } else {
          logWarning("Could not delete broadcast file: %s".format(file))
        }
      }
    } catch {
      case e: Exception =>
        logError("Exception while deleting broadcast file: %s".format(file), e)
    }
  }

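The write function works in four steps:

1. It creates a file in the broadcastDir directory, named after BroadcastBlockId(id).name.
2. It opens a FileOutputStream on that file and wraps it in an OutputStream: a compressed stream if compression is enabled, otherwise a BufferedOutputStream of bufferSize bytes.
3. It serializes the value into the file with SparkEnv's serializer.
4. It adds the file to files, a TimeStampedHashSet[File] used for periodic cleanup.

doUnpersist and doDestroy work like their TorrentBroadcast counterparts. The difference is that HttpBroadcast.unpersist also deletes the broadcast file when it removes the broadcast from the driver.

As mentioned above, TorrentBroadcast does not really implement initialize and stop, whereas HttpBroadcast implements both.

initialize first reads bufferSize (spark.buffer.size, default 65536) and the compression flag (spark.broadcast.compress, default true) from the configuration. It then checks whether it is running on the driver. If so, it starts the HTTP service there. createServer creates a temporary directory named broadcast to hold the broadcast files and starts an HttpServer named "HTTP broadcast server". The server's URI is then published as spark.httpBroadcast.uri.

On the driver, HttpBroadcast never reads the file back: the value is simply the value_ passed in when the broadcast was created. Because value_ is @transient, it is not shipped with the task. On an executor, it is filled in when the broadcast object is deserialized (in the class's readObject hook). The executor first checks its local BlockManager and otherwise fetches the file from the driver's HTTP server, as the class Scaladoc above describes.

Next, initialize creates the MetadataCleaner instance and the compression codec:

cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf)

The MetadataCleaner constructor takes the type of metadata to clean and the cleanup function. The instance starts a background timer task that periodically removes old, stale data. The cleanup function passed in here deletes old broadcast files, that is, entries in files whose timestamp is older than the cutoff.

stop stops the HTTP server, cancels the cleaner, drops the compression codec and resets the initialized flag.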
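The write path can be sketched with JDK stand-ins: GZIP replaces Spark's CompressionCodec, and Java serialization replaces SparkEnv.get.serializer. Both are assumptions. FileBroadcast is a hypothetical name, and its file names are modeled on BroadcastBlockId(id).name. A matching read is included so the example can be checked.

```scala
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

// Sketch only: GZIP and Java serialization stand in for Spark's codec and serializer.
object FileBroadcast {
  def getFile(dir: File, id: Long): File = new File(dir, s"broadcast_$id")

  def write(dir: File, id: Long, value: Serializable, compress: Boolean,
            bufferSize: Int = 65536): File = {
    val file = getFile(dir, id)
    val fileOut = new FileOutputStream(file)
    try {
      val out: OutputStream =
        if (compress) new GZIPOutputStream(fileOut, bufferSize)
        else new BufferedOutputStream(fileOut, bufferSize)
      val objOut = new ObjectOutputStream(out)
      try objOut.writeObject(value) finally objOut.close() // flushes the whole chain
    } finally fileOut.close() // closing twice is harmless
    file
  }

  def read[T](dir: File, id: Long, compress: Boolean): T = {
    val fileIn = new FileInputStream(getFile(dir, id))
    try {
      val buffered = new BufferedInputStream(fileIn)
      val in: InputStream = if (compress) new GZIPInputStream(buffered) else buffered
      new ObjectInputStream(in).readObject().asInstanceOf[T]
    } finally fileIn.close()
  }
}
```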
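The timestamp-based cleanup can be sketched as follows. TimestampedFiles is a hypothetical stand-in for TimeStampedHashSet[File] plus the cleanup function. Insertion times are passed in explicitly so the behavior is deterministic; a real cleaner would call cleanup from a timer.

```scala
import java.io.File
import java.util.concurrent.ConcurrentHashMap

// Hypothetical stand-in: remember when each file was added, and delete the ones
// older than a cutoff.
class TimestampedFiles {
  private val map = new ConcurrentHashMap[File, java.lang.Long]()

  def add(f: File, now: Long): Unit = { map.put(f, java.lang.Long.valueOf(now)); () }
  def size: Int = map.size

  def cleanup(cutoff: Long): Unit = {
    val it = map.entrySet().iterator()
    while (it.hasNext) {
      val entry = it.next()
      if (entry.getValue.longValue < cutoff) {
        it.remove() // drop the entry, then try to delete the file itself
        val file = entry.getKey
        if (file.exists && !file.delete())
          System.err.println(s"Could not delete broadcast file: $file")
      }
    }
  }
}
```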
