Overview of Spark's RPC Layer Design
Spark 2.0's RPC framework is built on top of Netty, an excellent network communication framework. Before diving in, let us sort out the relationships among the RPC-related classes in Spark. To convey the RPC design more intuitively, we start from the class design, as shown in the figure below:
From the left half of the figure we can see that RPC communication revolves around three core classes: RpcEnv, RpcEndpoint, and RpcEndpointRef.
An RpcEndpoint is a communication endpoint; for example, the Master and each Worker in a Spark cluster are RpcEndpoints. To communicate with an RpcEndpoint, however, you must first obtain an RpcEndpointRef for it and send messages through that reference, and an RpcEndpointRef can only be obtained from an RpcEnv environment object.
When a client sends a message through an RpcEndpointRef, the RpcEnv first handles the message, determines which endpoint it is addressed to, and routes it to that RpcEndpoint. By default Spark uses the more efficient NettyRpcEnv, which replaced the earlier Akka-based implementation. The three classes are described in detail below.
RpcEnv
RpcEnv is the RPC environment object. It manages the whole lifecycle of RpcEndpoints: registering endpoints by name or URI, dispatching and handling messages, and stopping endpoints. An RpcEnv can only be created through an RpcEnvFactory.
RpcEnv has one core method:
def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef
This method registers an RpcEndpoint with the RpcEnv environment object, which then manages the binding between the RpcEndpoint and its RpcEndpointRef. Every RpcEndpoint must be registered under a unique name.
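As a minimal hedged sketch of registration (the rpcEnv value and the MyEndpoint class are hypothetical placeholders, not names from the Spark source):
// Hypothetical: rpcEnv is an existing RpcEnv and MyEndpoint a user-defined RpcEndpoint.
val ref: RpcEndpointRef = rpcEnv.setupEndpoint("my-endpoint", new MyEndpoint(rpcEnv))
// Registering a second endpoint under "my-endpoint" would throw, since names must be unique within an RpcEnv.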
RpcEndpoint
RpcEndpoint defines a communication endpoint in the RPC process. Besides the operations that manage an RpcEndpoint's lifecycle (constructor -> onStart -> receive* -> onStop), it also defines the event-driven behaviors an endpoint exhibits during communication (connection, disconnection, network errors). For the Spark framework, an RpcEndpoint's job is essentially to receive messages and process them.
RpcEndpoint has two core methods:
def receive: PartialFunction[Any, Unit] = {
  case _ => throw new SparkException(self + " does not implement 'receive'")
}
def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
  case _ => context.sendFailure(new SparkException(self + " won't reply anything"))
}
The receive method handles messages sent via RpcEndpointRef.send; such messages require no reply and are simply processed on the RpcEndpoint side. The receiveAndReply method handles messages sent via RpcEndpointRef.ask; after processing such a message, the RpcEndpoint must send a response back to the endpoint that called RpcEndpointRef.ask.
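To make the two call paths concrete, here is a minimal sketch of a custom endpoint (the EchoEndpoint class and the Ping/Echo message types are illustrative assumptions, not part of Spark; inside the Spark code base these traits are private[spark]):
// Illustrative only: a tiny endpoint that handles both fire-and-forget and ask-style messages.
case class Ping(id: Int)       // sent via RpcEndpointRef.send, no reply expected
case class Echo(text: String)  // sent via RpcEndpointRef.ask, a reply is expected

class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  override def receive: PartialFunction[Any, Unit] = {
    case Ping(id) => println(s"got ping $id")            // handle, but never reply
  }
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case Echo(text) => context.reply(text.toUpperCase)   // the reply goes back to the asker
  }
}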
RpcEndpointRef
RpcEndpointRef is a remote reference to an RpcEndpoint; through it, messages can be sent to the remote RpcEndpoint. The RpcEndpointRef abstract class is defined as follows:
private[spark] abstract class RpcEndpointRef(conf: SparkConf) extends Serializable with Logging {
private[this] val maxRetries = RpcUtils.numRetries(conf)
private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf)
private[this] val defaultAskTimeout = RpcUtils.askRpcTimeout(conf)
def address: RpcAddress
def name: String
def send(message: Any): Unit
def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T]
def ask[T: ClassTag](message: Any): Future[T] = ask(message, defaultAskTimeout)
def askWithRetry[T: ClassTag](message: Any): T = askWithRetry(message, defaultAskTimeout)
def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = {
... ...
}
}
In the code above, the send method sends a message without waiting for a response, i.e. send-and-forget. The ask method sends a message and waits for the remote side to respond, obtaining the result asynchronously through a Future.
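Continuing the hypothetical EchoEndpoint sketch above, client-side usage of the two styles might look like this (all names are illustrative; blocking on the Future is done here only to keep the example short):
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._

// Register the hypothetical endpoint and keep its reference.
val echoRef: RpcEndpointRef = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))

// Fire-and-forget: send returns immediately and no response is expected.
echoRef.send(Ping(1))

// Ask: returns a Future that completes when the endpoint calls context.reply(...).
val reply: Future[String] = echoRef.ask[String](Echo("hello"))
val text: String = Await.result(reply, 10.seconds)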
Creating the NettyRpcEnv in the Driver's SparkEnv
The Driver's SparkEnv is the runtime environment of the Driver in a Spark application. It creates many components, such as SecurityManager, rpcEnv, broadcastManager, mapOutputTracker, memoryManager, blockTransferService, blockManagerMaster, blockManager, metricsSystem, and so on. Since this article is about Spark's RPC mechanism, only the creation of rpcEnv and the startup of its server are covered here. We start from the create method of NettyRpcEnvFactory in NettyRpcEnv.scala.
private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
  def create(config: RpcEnvConfig): RpcEnv = {
    val sparkConf = config.conf
    // Create the serializer used for RPC messages
    val javaSerializerInstance = new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
    // Instantiate a NettyRpcEnv
    val nettyEnv = new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager)
    if (!config.clientMode) {
      val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
        nettyEnv.startServer(actualPort)
        (nettyEnv, nettyEnv.address.port)
      }
      try {
        // Start the Driver RPC service on the configured host and port
        Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
      } catch {
        case NonFatal(e) =>
          nettyEnv.shutdown()
          throw e
      }
    }
    nettyEnv
  }
}
NettyRpcEnvFactory extends RpcEnvFactory and implements its create method. The two most important steps in create are instantiating a NettyRpcEnv and starting the server.
1. Creating NettyRpcEnv
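For context, the factory is not usually called directly: on the driver, SparkEnv goes through RpcEnv.create, which builds an RpcEnvConfig and delegates to NettyRpcEnvFactory. A hedged sketch of that call in Spark 2.0.x (the host and port values are illustrative, and conf is assumed to be an existing SparkConf):
// Roughly how the driver-side RpcEnv is obtained (simplified).
val securityManager = new SecurityManager(conf)
val rpcEnv = RpcEnv.create(
  "sparkDriver",        // config.name
  "driver-host",        // config.host (illustrative)
  7077,                 // config.port (illustrative; normally taken from spark.driver.port)
  conf,
  securityManager,
  clientMode = false)   // non-client mode, so startServer is invoked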
private[netty] class NettyRpcEnv(val conf: SparkConf, javaSerializerInstance: JavaSerializerInstance, host: String, securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
// Create the transportConf
private[netty] val transportConf = SparkTransportConf.fromSparkConf(conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"), "rpc", conf.getInt("spark.rpc.io.threads", 0))
// Create the Dispatcher, which is mainly responsible for dispatching incoming messages
private val dispatcher: Dispatcher = new Dispatcher(this)
// Create the streamManager
private val streamManager = new NettyStreamManager(this)
// Create a TransportContext, used to build Netty's server and client. Spark wraps the Netty framework and uses
// TransportContext as the external entry point that connects NettyRpcEnv and related Spark code to the underlying
// communication server and client. Spark's wrapping of Netty is covered in detail later.
private val transportContext = new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this, streamManager))
private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = {
if (securityManager.isAuthenticationEnabled()) {
java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, securityManager.isSaslEncryptionEnabled()))
} else {
java.util.Collections.emptyList[TransportClientBootstrap]
}
}
// Declare a clientFactory, used to create the clients that carry outgoing communication
private val clientFactory = transportContext.createClientFactory(createClientBootstraps())
/**
* A separate client factory for file downloads. This avoids using the same RPC handler as
* the main RPC context, so that events caused by these clients are kept isolated from the
* main RPC traffic.
*
* It also allows for different configuration of certain properties, such as the number of
* connections per peer.
*/
@volatile private var fileDownloadFactory: TransportClientFactory = _
// Create a daemon single-thread scheduled executor named netty-rpc-env-timeout, used to enforce ask timeouts
val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")
// Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
// to implement non-blocking send/ask.
// TODO: a non-blocking TransportClientFactory.createClient in future
private[netty] val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool( "netty-rpc-connection", conf.getInt("spark.rpc.connect.threads", 64))
@volatile private var server: TransportServer = _
private val stopped = new AtomicBoolean(false)
/**
* A map for [[RpcAddress]] and [[Outbox]]. When we are connecting to a remote [[RpcAddress]],
* we just put messages to its [[Outbox]] to implement a non-blocking `send` method.
*/
private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]()
/**
* Remove the address's Outbox and stop it.
*/
private[netty] def removeOutbox(address: RpcAddress): Unit = {
val outbox = outboxes.remove(address)
if (outbox != null) {
outbox.stop()
}
}
// Start the TransportServer on the given port
def startServer(port: Int): Unit = {
  val bootstraps: java.util.List[TransportServerBootstrap] =
    if (securityManager.isAuthenticationEnabled()) {
      java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager))
    } else {
      java.util.Collections.emptyList()
    }
  // Start the underlying communication server through the transportContext
  server = transportContext.createServer(host, port, bootstraps)
  // Register an RpcEndpointVerifier, which lets remote callers check whether a named endpoint exists on this server
  dispatcher.registerRpcEndpoint(RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
}
@Nullable override lazy val address: RpcAddress = {
if (server != null)
RpcAddress(host, server.getPort())
else
null
}
// Override RpcEnv's setupEndpoint method, used to register an RpcEndpoint with this RpcEnv
override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
dispatcher.registerRpcEndpoint(name, endpoint)
}
def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
val addr = RpcEndpointAddress(uri)
val endpointRef = new NettyRpcEndpointRef(conf, addr, this)
val verifier = new NettyRpcEndpointRef(conf, RpcEndpointAddress(addr.rpcAddress, RpcEndpointVerifier.NAME), this)
verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(endpointRef.name)).flatMap { find =>
  if (find) {
    Future.successful(endpointRef)
  } else {
    Future.failed(new RpcEndpointNotFoundException(uri))
  }
}(ThreadUtils.sameThread)
}
override def stop(endpointRef: RpcEndpointRef): Unit = {
require(endpointRef.isInstanceOf[NettyRpcEndpointRef])
dispatcher.stop(endpointRef)
}
private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = {
if (receiver.client != null) {
message.sendWith(receiver.client)
} else {
require(receiver.address != null, "Cannot send message to client endpoint with no listen address.")
val targetOutbox = {
val outbox = outboxes.get(receiver.address)
if (outbox == null) {
val newOutbox = new Outbox(this, receiver.address)
val oldOutbox = outboxes.putIfAbsent(receiver.address, newOutbox)
if (oldOutbox == null) {
newOutbox
} else {
oldOutbox
}
} else {
outbox
}
}
if (stopped.get) {
// It's possible that we put `targetOutbox` after stopping. So we need to clean it.
outboxes.remove(receiver.address)
targetOutbox.stop()
} else {
targetOutbox.send(message)
}
}
}
private[netty] def send(message: RequestMessage): Unit = {
val remoteAddr = message.receiver.address
if (remoteAddr == address) {
// Message to a local RPC endpoint.
try {
dispatcher.postOneWayMessage(message)
}
catch {
case e: RpcEnvStoppedException => logWarning(e.getMessage)
}
} else {
// Message to a remote RPC endpoint.
postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message)))
}
}
private[netty] def createClient(address: RpcAddress): TransportClient = { clientFactory.createClient(address.host, address.port) }
private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = {
val promise = Promise[Any]()
val remoteAddr = message.receiver.address
def onFailure(e: Throwable): Unit = {
if (!promise.tryFailure(e)) {
logWarning(s"Ignored failure: $e")
}
}
def onSuccess(reply: Any): Unit = reply match {
case RpcFailure(e) => onFailure(e)
case rpcReply => if (!promise.trySuccess(rpcReply)) { logWarning(s"Ignored message: $reply") }
}
try {
if (remoteAddr == address) {
val p = Promise[Any]()
p.future.onComplete {
case Success(response) => onSuccess(response)
case Failure(e) => onFailure(e)
}(ThreadUtils.sameThread)
dispatcher.postLocalMessage(message, p)
} else {
val rpcMessage = RpcOutboxMessage(serialize(message), onFailure, (client, response) => onSuccess(deserialize[Any](client, response)))
postToOutbox(message.receiver, rpcMessage)
promise.future.onFailure {
case _: TimeoutException => rpcMessage.onTimeout()
case _ =>
}(ThreadUtils.sameThread)
}
val timeoutCancelable = timeoutScheduler.schedule(new Runnable {
override def run(): Unit = {
onFailure(new TimeoutException(s"Cannot receive any reply in ${timeout.duration}"))
}
}, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
promise.future.onComplete { v =>
timeoutCancelable.cancel(true)
}(ThreadUtils.sameThread)
} catch {
case NonFatal(e) => onFailure(e)
}
promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
}
private[netty] def serialize(content: Any): ByteBuffer = {
javaSerializerInstance.serialize(content)
}
private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = {
NettyRpcEnv.currentClient.withValue(client) {
deserialize {
() => javaSerializerInstance.deserialize[T](bytes)
}
}
}
override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = {
dispatcher.getRpcEndpointRef(endpoint)
}
override def shutdown(): Unit = {
cleanup()
}
override def awaitTermination(): Unit = {
dispatcher.awaitTermination()
}
private def cleanup(): Unit = {
if (!stopped.compareAndSet(false, true)) {
return
}
val iter = outboxes.values().iterator()
while (iter.hasNext()) {
val outbox = iter.next()
outboxes.remove(outbox.address)
outbox.stop()
}
if (timeoutScheduler != null) {
timeoutScheduler.shutdownNow()
}
if (dispatcher != null) {
dispatcher.stop()
}
if (server != null) {
server.close()
}
if (clientFactory != null) {
clientFactory.close()
}
if (clientConnectionExecutor != null) {
clientConnectionExecutor.shutdownNow()
}
if (fileDownloadFactory != null) {
fileDownloadFactory.close()
}
}
override def deserialize[T](deserializationAction: () => T): T = {
NettyRpcEnv.currentEnv.withValue(this) {
deserializationAction()
}
}
override def fileServer: RpcEnvFileServer = streamManager
override def openChannel(uri: String): ReadableByteChannel = {
val parsedUri = new URI(uri)
require(parsedUri.getHost() != null, "Host name must be defined.")
require(parsedUri.getPort() > 0, "Port must be defined.")
require(parsedUri.getPath() != null && parsedUri.getPath().nonEmpty, "Path must be defined.")
val pipe = Pipe.open()
val source = new FileDownloadChannel(pipe.source())
try {
val client = downloadClient(parsedUri.getHost(), parsedUri.getPort())
val callback = new FileDownloadCallback(pipe.sink(), source, client)
client.stream(parsedUri.getPath(), callback)
} catch {
case e: Exception =>
pipe.sink().close()
source.close()
throw e
}
source
}
private def downloadClient(host: String, port: Int): TransportClient = {
if (fileDownloadFactory == null)
synchronized {
if (fileDownloadFactory == null) {
val module = "files"
val prefix = "spark.rpc.io."
val clone = conf.clone()
// Copy any RPC configuration that is not overridden in the spark.files namespace.
conf.getAll.foreach {
case (key, value) =>
if (key.startsWith(prefix)) {
val opt = key.substring(prefix.length())
clone.setIfMissing(s"spark.$module.io.$opt", value)
}
}
val ioThreads = clone.getInt("spark.files.io.threads", 1)
val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads)
val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true)
fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps())
}
}
fileDownloadFactory.createClient(host, port)
}
private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel {
@volatile private var error: Throwable = _
def setError(e: Throwable): Unit = {
  error = e
  source.close()
}
override def read(dst: ByteBuffer): Int = {
Try(source.read(dst)) match {
case Success(bytesRead) => bytesRead
case Failure(readErr) =>
if (error != null) {
throw error
} else {
throw readErr
}
}
}
override def close(): Unit = source.close()
override def isOpen(): Boolean = source.isOpen()
}
private class FileDownloadCallback(sink: WritableByteChannel, source: FileDownloadChannel, client: TransportClient) extends StreamCallback {
override def onData(streamId: String, buf: ByteBuffer): Unit = {
while (buf.remaining() > 0) {
sink.write(buf)
}
}
override def onComplete(streamId: String): Unit = {
sink.close()
}
override def onFailure(streamId: String, cause: Throwable): Unit = {
logDebug(s"Error downloading stream $streamId.", cause)
source.setError(cause)
sink.close()
}
}
}
The newly created NettyRpcEnv is mainly responsible for registering endpoints, starting the TransportServer, obtaining RpcEndpointRefs, creating clients, and so on; its key members are the dispatcher and the transportContext.
1.1 Dispatcher
The Dispatcher's main job is to keep track of registered RpcEndpoints and to dispatch incoming messages to the appropriate RpcEndpoint for processing.
private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
// An inner class of Dispatcher that bundles an endpoint's name, the endpoint itself, its ref, and its Inbox
private class EndpointData(val name: String, val endpoint: RpcEndpoint, val ref: NettyRpcEndpointRef) {
  val inbox = new Inbox(ref, endpoint)
}
// A ConcurrentHashMap that maps each endpoint name to its EndpointData
private val endpoints = new ConcurrentHashMap[String, EndpointData]
// A ConcurrentHashMap that maps each RpcEndpoint to its RpcEndpointRef
private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]
// Track the receivers whose inboxes may contain messages.
// A blocking queue holding the EndpointData whose inboxes have pending messages. EndpointData is enqueued here
// when an endpoint is registered or unregistered, when a message is posted, and when the RpcEnv is stopped.
private val receivers = new LinkedBlockingQueue[EndpointData]
/**
* True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced immediately.
*/
@GuardedBy("this") private var stopped = false
// Register an RpcEndpoint with this RpcEnv under the given name
def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
  // Build an RpcEndpointAddress from the NettyRpcEnv's address and the given name
  val addr = RpcEndpointAddress(nettyEnv.address, name)
  // Create the corresponding NettyRpcEndpointRef
  val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
  synchronized {
    if (stopped) {
      throw new IllegalStateException("RpcEnv has been stopped")
    }
    // Create a new EndpointData (its key member is an Inbox, covered later) and register it under the given name
    if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
      throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
    }
    val data = endpoints.get(name)
    // Add the endpoint and its endpointRef to endpointRefs
    endpointRefs.put(data.endpoint, data.ref)
    // Enqueue the newly created EndpointData into receivers, for the OnStart message
    receivers.offer(data)
  }
  // Return the corresponding endpointRef
  endpointRef
}
// Get the endpointRef for the given endpoint
def getRpcEndpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointRefs.get(endpoint)
// Remove the given endpoint from endpointRefs
def removeRpcEndpointRef(endpoint: RpcEndpoint): Unit = endpointRefs.remove(endpoint)
// Should be idempotent
// Unregister the endpoint that was registered in the NettyRpcEnv under the given name
private def unregisterRpcEndpoint(name: String): Unit = {
  // Remove the corresponding EndpointData from endpoints
  val data = endpoints.remove(name)
  if (data != null) {
    // Stop the EndpointData by stopping its inbox
    data.inbox.stop()
    // Enqueue the EndpointData into receivers so the message loop threads can process its OnStop message
    receivers.offer(data)
  }
// Don't clean `endpointRefs` here because it's possible that some messages are being processed
// now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via
// `removeRpcEndpointRef`.
}
def stop(rpcEndpointRef: RpcEndpointRef): Unit = {
synchronized {
if (stopped) {
// This endpoint will be stopped by Dispatcher.stop() method.
return
}
unregisterRpcEndpoint(rpcEndpointRef.name)
}
}
/**
* Send a message to all registered [[RpcEndpoint]]s in this process.
*
* This can be used to make network events known to all end points (e.g. "a new node connected").
*/
// Send the message to every RpcEndpoint registered in this process
def postToAll(message: InboxMessage): Unit = {
val iter = endpoints.keySet().iterator()
while (iter.hasNext) {
val name = iter.next
postMessage(name, message, (e) => logWarning(s"Message $message dropped. ${e.getMessage}"))
}
}
/** Posts a message sent by a remote endpoint. */
// Wrap the request in a RemoteNettyRpcCallContext so the reply is sent back over the network
def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
val rpcCallContext = new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress)
val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e))
}
/** Posts a message sent by a local endpoint. */
// Wrap the request in a LocalNettyRpcCallContext so the reply completes the local Promise
def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {
val rpcCallContext = new LocalNettyRpcCallContext(message.senderAddress, p)
val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e))
}
/** Posts a one-way message. */
def postOneWayMessage(message: RequestMessage): Unit = {
postMessage(message.receiver.name, OneWayMessage(message.senderAddress, message.content), (e) => throw e)
}
/**
* Posts a message to a specific endpoint.
*
* @param endpointName name of the endpoint.
* @param message the message to post
* @param callbackIfStopped callback function if the endpoint is stopped.
*/
// Deliver a message to a specific endpoint. Parameters: the endpoint's name, the message itself, and a callback invoked when the endpoint is stopped or missing
private def postMessage(endpointName: String, message: InboxMessage, callbackIfStopped: (Exception) => Unit): Unit = {
val error = synchronized {
// Look up the EndpointData for endpointName
val data = endpoints.get(endpointName)
if (stopped) {
Some(new RpcEnvStoppedException())
} else if (data == null) {
Some(new SparkException(s"Could not find $endpointName."))
} else {
// Add the message to that EndpointData's inbox
data.inbox.post(message)
// Enqueue the EndpointData into receivers
receivers.offer(data)
None
}
}
// We don't need to call `onStop` in the `synchronized` block
error.foreach(callbackIfStopped)
}
def stop(): Unit = {
synchronized {
if (stopped) {
return
}
stopped = true
}
// Stop all endpoints. This will queue all endpoints for processing by the message loops.
endpoints.keySet().asScala.foreach(unregisterRpcEndpoint)
// Enqueue a message that tells the message loops to stop.
receivers.offer(PoisonPill)
threadpool.shutdown()
}
def awaitTermination(): Unit = {
threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
}
/**
* Return if the endpoint exists
*/
// Return true if an endpoint with the given name is registered
def verify(name: String): Boolean = { endpoints.containsKey(name) }
/** Thread pool used for dispatching messages. */
// Create a thread pool used to dispatch messages
private val threadpool: ThreadPoolExecutor = {
  // Read the number of dispatcher threads from the configuration
  val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads", math.max(2, Runtime.getRuntime.availableProcessors()))
  // Create the thread pool
  val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
  // Start numThreads threads, each running a MessageLoop
  for (i <- 0 until numThreads) {
    pool.execute(new MessageLoop)
  }
  pool
}
/** Message loop used for dispatching messages. */
// MessageLoop is a Runnable executed by the dispatcher threads to process pending messages
private class MessageLoop extends Runnable {
override def run(): Unit = {
try {
while (true) {
try {
// Take an EndpointData from receivers; since receivers is a LinkedBlockingQueue, the thread blocks when the queue is empty
val data = receivers.take()
// If the element taken is PoisonPill, put PoisonPill back into receivers so the other threads also stop, then exit this thread
if (data == PoisonPill) {
// Put PoisonPill back so that other MessageLoops can see it.
receivers.offer(PoisonPill)
return
}
// Call the EndpointData's inbox.process method to handle that endpoint's pending messages
data.inbox.process(Dispatcher.this)
} catch {
case NonFatal(e) => logError(e.getMessage, e)
}
}
} catch {
case ie: InterruptedException => // exit
}
}
}
/** A poison endpoint that indicates MessageLoop should exit its message loop. */
private val PoisonPill = new EndpointData(null, null, null)
}
As the code above shows, when the Dispatcher dispatches a message to an endpoint, it really hands the message to that endpoint's EndpointData, and the most important member of EndpointData is its Inbox, which is described next.
1.2 Inbox
private[netty] class Inbox(val endpointRef: NettyRpcEndpointRef, val endpoint: RpcEndpoint) extends Logging {
inbox =>
// Give this an alias so we can use it more clearly in closures.
// A LinkedList of InboxMessage holding the pending messages for this inbox
@GuardedBy("this") protected val messages = new java.util.LinkedList[InboxMessage]()
/** True if the inbox (and its associated endpoint) is stopped. */
@GuardedBy("this") private var stopped = false
/** Allow multiple threads to process messages at the same time. */
// When true, multiple threads may process this inbox's messages at the same time (enabled after OnStart for endpoints that are not ThreadSafeRpcEndpoint)
@GuardedBy("this") private var enableConcurrent = false
/** The number of threads processing messages for this inbox. */
// The number of threads currently processing messages for this inbox
@GuardedBy("this") private var numActiveThreads = 0
// OnStart should be the first message to process
// OnStart is added to messages as soon as the Inbox is constructed, so it is the first message to be processed
inbox.synchronized {
messages.add(OnStart)
}
/**
* Process stored messages.
*/
// Process the stored messages; called by the Dispatcher's MessageLoop threads
def process(dispatcher: Dispatcher): Unit = {
var message: InboxMessage = null
inbox.synchronized {
if (!enableConcurrent && numActiveThreads != 0) {
return
}
// Poll the first message from the head of the list
message = messages.poll()
// If the message is not null, increment numActiveThreads
if (message != null) {
numActiveThreads += 1
} else {
return
}
}
// Match the message and handle it accordingly
while (true) {
safelyCall(endpoint) {
message match {
case RpcMessage(_sender, content, context) =>
try {
endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
throw new SparkException(s"Unsupported message $message from ${_sender}")
})
} catch {
case NonFatal(e) =>
context.sendFailure(e)
// Throw the exception -- this exception will be caught by the safelyCall function.
// The endpoint's onError function will be called.
throw e
}
case OneWayMessage(_sender, content) =>
endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
throw new SparkException(s"Unsupported message $message from ${_sender}")
})
case OnStart =>
endpoint.onStart()
if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
inbox.synchronized {
if (!stopped) {
enableConcurrent = true
}
}
}
case OnStop =>
val activeThreads = inbox.synchronized { inbox.numActiveThreads }
assert(activeThreads == 1,
s"There should be only a single active thread but found $activeThreads threads.")
dispatcher.removeRpcEndpointRef(endpoint)
endpoint.onStop()
assert(isEmpty, "OnStop should be the last message")
case RemoteProcessConnected(remoteAddress) =>
endpoint.onConnected(remoteAddress)
case RemoteProcessDisconnected(remoteAddress) =>
endpoint.onDisconnected(remoteAddress)
case RemoteProcessConnectionError(cause, remoteAddress) =>
endpoint.onNetworkError(cause, remoteAddress)
}
}
inbox.synchronized {
// "enableConcurrent" will be set to false after `onStop` is called, so we should check it every time.
if (!enableConcurrent && numActiveThreads != 1) {
// If we are not the only one worker, exit
numActiveThreads -= 1
return
}
// Poll the next message and continue matching and handling
message = messages.poll()
if (message == null) {
numActiveThreads -= 1
return
}
}
}
}
// Add a message to the messages list
def post(message: InboxMessage): Unit = inbox.synchronized {
// If the inbox has already been stopped, drop the message instead of enqueuing it
if (stopped) {
// We already put "OnStop" into "messages", so we should drop further messages
onDrop(message)
} else {
messages.add(message)
false
}
}
def stop(): Unit = inbox.synchronized {
// The following codes should be in `synchronized` so that we can make sure "OnStop" is the last
// message
if (!stopped) {
// We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only
// thread that is processing messages. So `RpcEndpoint.onStop` can release its resources
// safely.
enableConcurrent = false
stopped = true
messages.add(OnStop)
// Note: The concurrent events in messages will be processed one by one.
}
}
// Return true if there are no pending messages
def isEmpty: Boolean = inbox.synchronized { messages.isEmpty }
/**
* Called when we are dropping a message. Test cases override this to test message dropping.
* Exposed for testing.
*/
protected def onDrop(message: InboxMessage): Unit = {
logWarning(s"Drop $message because $endpointRef is stopped")
}
/**
* Calls action closure, and calls the endpoint's onError function in the case of exceptions.
*/
private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = {
try action catch {
case NonFatal(e) =>
try endpoint.onError(e) catch {
case NonFatal(ee) => logError(s"Ignoring error", ee)
}
}
}
}
This completes the Dispatcher inside NettyRpcEnv. Its main flow is as follows (a simplified, standalone sketch of this dispatch pattern appears right after the list):
- Create the Dispatcher.
- Start a thread pool whose MessageLoop threads watch receivers for EndpointData:
  - If there is no EndpointData available, block and wait.
  - If an EndpointData is taken and it is not PoisonPill, call its Inbox's process method:
    1. Poll the first message from that EndpointData's inbox messages.
    2. Match the message and invoke the corresponding endpoint method to handle it.
  - If the EndpointData taken is PoisonPill, put PoisonPill back into receivers (so the remaining threads also stop) and exit the current thread.
- Register an endpoint with the NettyRpcEnv by name and endpoint:
  - Create the corresponding NettyRpcEndpointRef from nettyEnv.conf, the RpcEndpointAddress, and nettyEnv.
  - Create a new EndpointData from the name, endpoint, and endpointRef.
  - Add name -> EndpointData to endpoints.
  - Add endpoint -> endpointRef to endpointRefs.
  - Add the new EndpointData to receivers.
- Dispatch an InboxMessage to the corresponding EndpointData for processing:
  - Look up the EndpointData by name.
  - Add the message to that EndpointData's Inbox messages.
  - Add the EndpointData to receivers.
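The following is a stripped-down, standalone illustration of the dispatch pattern summarized above; all names (MiniDispatcher, MiniInbox, post) are our own and do not appear in Spark. It only shows the queue-plus-worker-threads-plus-PoisonPill mechanics:
import java.util.concurrent.{Executors, LinkedBlockingQueue, TimeUnit}

object MiniDispatcher {
  // A fake "inbox": in Spark this would be an EndpointData holding a real Inbox.
  case class MiniInbox(name: String, messages: List[String])
  private val PoisonPill = MiniInbox(null, Nil)
  private val receivers = new LinkedBlockingQueue[MiniInbox]()
  private val pool = Executors.newFixedThreadPool(2)

  def start(): Unit = (1 to 2).foreach { _ =>
    pool.execute(new Runnable {
      override def run(): Unit = {
        while (true) {
          val data = receivers.take()           // blocks while the queue is empty
          if (data eq PoisonPill) {
            receivers.offer(PoisonPill)         // put it back so the other loops also stop
            return
          }
          data.messages.foreach(m => println(s"[${data.name}] $m"))  // "process" the inbox
        }
      }
    })
  }

  def post(inbox: MiniInbox): Unit = receivers.offer(inbox)

  def stop(): Unit = {
    receivers.offer(PoisonPill)
    pool.shutdown()
    pool.awaitTermination(1, TimeUnit.MINUTES)
  }
}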
Now a key point about how an RpcEndpointRef is produced. When an endpoint is registered with the NettyRpcEnv under a name, an RpcEndpointAddress is first built from that name and the NettyRpcEnv's address, and a NettyRpcEndpointRef is then created from that RpcEndpointAddress, nettyEnv.conf, and the NettyRpcEnv itself. In other words, a NettyRpcEndpointRef has no direct link to the actual RpcEndpoint instance; it is generated inside the NettyRpcEnv purely from a name. When a client later sends a message through a NettyRpcEndpointRef, the NettyRpcEnv uses the name carried in the message to route it to the endpoint registered under that name for processing.
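As a hedged illustration of this purely name-based addressing (the host, port, and endpoint name below are made up, and env stands for any RpcEnv), a reference to a remote endpoint can be resolved from nothing but an address and a name; under the hood this goes through asyncSetupEndpointRefByURI and the RpcEndpointVerifier shown earlier:
// Illustrative: resolve a reference to the hypothetical "echo" endpoint running in another JVM.
val remoteRef: RpcEndpointRef = env.setupEndpointRef(RpcAddress("other-host", 7078), "echo")
remoteRef.send(Ping(2))   // routed by the remote NettyRpcEnv to the endpoint registered under "echo"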
1.3 NettyRpcEndpointRef
private[netty] class NettyRpcEndpointRef( @transient private val conf: SparkConf, endpointAddress: RpcEndpointAddress, @transient @volatile private var nettyEnv: NettyRpcEnv) extends RpcEndpointRef(conf) with Serializable with Logging {
// The TransportClient used to send messages to the remote endpoint
@transient @volatile var client: TransportClient = _
// Keep the endpoint address only if it carries a valid RpcAddress; otherwise it is null
private val _address = if (endpointAddress.rpcAddress != null) endpointAddress else null
// The endpoint's name, taken from endpointAddress
private val _name = endpointAddress.name
override def address: RpcAddress = if (_address != null) _address.rpcAddress else null
// Custom deserialization: restore nettyEnv and client from the current deserialization context
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject()
nettyEnv = NettyRpcEnv.currentEnv.value
client = NettyRpcEnv.currentClient.value
}
// Custom serialization: write only the default serializable fields
private def writeObject(out: ObjectOutputStream): Unit = {
out.defaultWriteObject()
}
override def name: String = _name
// Override RpcEndpointRef's ask method
override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
nettyEnv.ask(RequestMessage(nettyEnv.address, this, message), timeout)
}
// Override RpcEndpointRef's send method
override def send(message: Any): Unit = {
require(message != null, "Message is null")
nettyEnv.send(RequestMessage(nettyEnv.address, this, message))
}
override def toString: String = s"NettyRpcEndpointRef(${_address})"
def toURI: URI = new URI(_address.toString)
final override def equals(that: Any): Boolean = that match {
case other: NettyRpcEndpointRef => _address == other._address
case _ => false
}
final override def hashCode(): Int = if (_address == null) 0 else _address.hashCode()
}
At this point, the NettyRpcEnv, RpcEndpoint (with its Dispatcher and Inbox machinery), and NettyRpcEndpointRef pieces of Spark's RPC communication module have all been walked through.