基本介紹
最近在做一個文件archive的事情,其中需要對目錄文件下的索引排序,最開始是用的內部歸并排序,這在目錄里面文件還比較少的時候,沒什么大問題;但是發現有一個目錄下的文件數太多,無法正常排序,因為那樣會OOM;所以就打算先通過rdd里面的sortByKey來先將文件分段排序然后再整合到目標文件中。所以我寫下了這么一段代碼:
sc.parallelize(data)
.flatMap(dealFunction)
.sortByKey(_._1)
.someOtherOperations
sortByKey 主要用途就是將目標tuples根據key值在不同的range段排序:比如有原始數據((5, 5), (4, 4), (3, 3), (3, 3), (2, 2), (4, 4), (5, 5), (6, 6))。我們希望在兩個range partition中排序,那么最終的結果為(((2, 2), (3,3), (3,3)), ((4, 4),(4, 4),(5, 5), (5, 5), (6, 6)))。那么如何確定各個段的邊界呢?那么這里就有一個統計學原理知識,就是先對數據抽樣,然后根據抽樣數據來決定各個段的邊界以此來保證段中的數據盡量均勻。
sortByKey的基本用法我就不介紹了,這里主要來講講里面的一些具體實現及一些為新手所不能理解的地方:
- sortByKey被認為是一個transformation, 但是我在看spark UI的時候,卻為sortByKey產生了一個job,因為稍微了解spark的同學都知道只有action才會產生job。
- 我的flatMap中的dealFunction函數被調用了兩次。
原理分析
經過Google及閱讀源碼發現:
def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length)
: RDD[(K, V)] = self.withScope
{
val part = new RangePartitioner(numPartitions, self, ascending)
new ShuffledRDD[K, V, V](self, part)
.setKeyOrdering(if (ascending) ordering else ordering.reverse)
}
def getPartition(key: Any): Int = {
val k = key.asInstanceOf[K]
var partition = 0
if (rangeBounds.length <= 128) {
// If we have less than 128 partitions naive search
while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
partition += 1
}
} else {
// Determine which binary search method to use only once.
partition = binarySearch(rangeBounds, k)
// binarySearch either returns the match location or -[insertion point]-1
if (partition < 0) {
partition = -partition-1
}
if (partition > rangeBounds.length) {
partition = rangeBounds.length
}
}
if (ascending) {
partition
} else {
rangeBounds.length - partition
}
}
在上述代碼中我們可以看到在Partitioner 里面的getPartition函數中,是根據key在rangeBounds里面的位置來判斷對應的key是處于哪一個range partition中的,那么我們來看一下rangeBounds的生成。其基本思路就是根據一些參數來決定抽樣的樣本數量,并獲取樣本數來劃分range段的邊界。
代碼如下:
private var rangeBounds: Array[K] = {
if (partitions <= 1) {
Array.empty
} else {
// This is the sample size we need to have roughly balanced output partitions, capped at 1M.
// Cast to double to avoid overflowing ints or longs
val sampleSize = math.min(samplePointsPerPartitionHint.toDouble * partitions, 1e6)
// Assume the input partitions are roughly balanced and over-sample a little bit.
val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt
val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition) // colect sample map key from target rdd
if (numItems == 0L) {
Array.empty
} else {
// If a partition contains much more than the average number of items, we re-sample from it
// to ensure that enough items are collected from that partition.
val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
val candidates = ArrayBuffer.empty[(K, Float)]
val imbalancedPartitions = mutable.Set.empty[Int]
sketched.foreach { case (idx, n, sample) =>
if (fraction * n > sampleSizePerPartition) {
imbalancedPartitions += idx
} else {
// The weight is 1 over the sampling probability.
val weight = (n.toDouble / sample.length).toFloat
for (key <- sample) {
candidates += ((key, weight))
}
}
}
if (imbalancedPartitions.nonEmpty) {
// Re-sample imbalanced partitions with the desired sampling probability.
val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
val seed = byteswap32(-rdd.id - 1)
val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
val weight = (1.0 / fraction).toFloat
candidates ++= reSampled.map(x => (x, weight))
}
RangePartitioner.determineBounds(candidates, math.min(partitions, candidates.size))
}
}
}
def sketch[K : ClassTag](
rdd: RDD[K],
sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = {
val shift = rdd.id
// val classTagK = classTag[K] // to avoid serializing the entire partitioner object
val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
val seed = byteswap32(idx ^ (shift << 16))
val (sample, n) = SamplingUtils.reservoirSampleAndCount(
iter, sampleSizePerPartition, seed)
Iterator((idx, n, sample))
}.collect() // here generated a job
val numItems = sketched.map(_._2).sum
(numItems, sketched)
}
獲取樣本sample的函數為sketch,我們可以看到這段代碼中有一個collect操作,所以這就不難解釋我們的疑惑1了,因為在sample的過程中有一個collect會產生一個job。
那么第二個疑惑是為啥呢?產生job就產生job唄,為啥我之前的flatMap里面的操作你要執行兩次?
那這里就要回到spark中的stage概念來了,在spark中一個job會劃分為多個stage,而stage的劃分是跟wide transformation有關。flatMap是一個narrow transformation,這樣的話由于在同一個stage中,所以sortByKey中的sample job會把其所在的stage中的操作跑一遍,而外層的job會把整個所有stage都跑一遍這樣你sortByKey所在的stage中的操作就會跑兩遍,具體見圖:
在上述兩個圖中,stage1中的flatMap和stage2中的flatMap其實是同一個flatmap操作,這樣就可以解釋為啥我的flatMap中的操作為啥執行兩次了。
解決方案
如果我的flatMap的操作比較重,都是一些訪問文件的操作,那么有什么好的方法可以避免因為sample而導致的兩次執行問題嗎?
那么這里就可以介紹一下spark中的stage cache:就是在shuffle結束一個stage的時候,spark會cache住stage中的結果數據,這樣下一次如果遇到要重新運行該stage的時候可以直接拿最終的結果,而不需要重新運行完整的stage過程。所以結合上圖我們可以在stage1中的flatMap后面加一個shuffle操作來拆分一個stage,這樣下一次執行stage1的時候就可以直接獲取數據了,我們可以通過添加一個repartition來切分一些stage,以保證sortByKey的sample執行時是在一個新的stage中,這樣sample job 和 原始job可以復用一個stage1中的數據,
代碼如下:
sc.parallelize(data)
.flatMap(dealFunction)
.repartition(partitions)
.sortByKey(_._1)
.someOtherOperations
最終的結果如下:
注:因為shuffle的過程是比較耗時的,對于內存和IO也是有較高損耗的,目前這個方法是我目前能想到的比較好的解決比較重的重復操作的方法,我玩Spark的時間也比較短,如果大神能能指教一下最好了
shuffle performance-impactl#performance-impact
最后又發現可以在sortByKey之前調用一次cache或者persist進行rdd緩存。