排序是工程中必不可少的功能,很多編程語言SDK都提供了排序相關的實現。作為軟件工程師,我們在學習各類排序算法的同時,是否有思考過,如何去實現一個工業級的排序算法?如果你是Go語言的作者之一,該如何去實現一種能適應多種情況的排序算法?
Go SDK中排序相關的實現主要在sort/sort.go
中,本文主要基于該文件進行相關實現的分析。
首先來看看Go對排序接口的定義,利用Go的interface特性可以輕松實現多種數據類型的排序功能。想要調用sort包的排序功能我們需要實現這個排序接口,排序接口主要定義了三個方法:
-
Len() int
: 返回傳入數據的總數 -
Less(i, j int) bool
: 返回數組中下標為i的數據是否小于下標為j的數據 -
Swap(i, j int)
: 表示執行交換數組中下標為i的數據和下標為j的數據
// A type, typically a collection, that satisfies sort.Interface can be
// sorted by the routines in this package. The methods require that the
// elements of the collection be enumerated by an integer index.
type Interface interface {
// Len is the number of elements in the collection.
Len() int
// Less reports whether the element with
// index i should sort before the element with index j.
Less(i, j int) bool
// Swap swaps the elements with indexes i and j.
Swap(i, j int)
}
了解了包中對sort接口的定義后,再來看看sort包對外提供的主要接口Sort,源碼如下:
// Sort sorts data.
// It makes one call to data.Len to determine n, and O(n*log(n)) calls to
// data.Less and data.Swap. The sort is not guaranteed to be stable.
func Sort(data Interface) {
n := data.Len()
quickSort(data, 0, n, maxDepth(n))
}
如注釋所說,當我們調用Sort方法時,該方法會調用一次data.Len()
,之后會以O(n*log(n))
的時間復雜度調用data.Less
和data.Swap
。我們可以看到,Sort內部調用了包私有的quickSort方法,也就是我們熟悉的快排,同時傳了4個參數,學過快排的同學都能理解前三個參數的含義,但是我們還看到了一個陌生的函數調用maxDepth(n)
,這里的depth究竟代表什么呢?所以先探究一下這個函數,代碼如下:
// maxDepth returns a threshold at which quicksort should switch
// to heapsort. It returns 2*ceil(lg(n+1)).
func maxDepth(n int) int {
var depth int
for i := n; i > 0; i >>= 1 {
depth++
}
return depth * 2
}
簡單來說,maxDepth方法返回的深度表示了數據的量級,qiuckSort方法會根據這個量級選擇使用快排還是堆排序,學過堆排序的同學都知道,堆排序的時間復雜度穩定在O(nlogn),有時候比快排還穩定,但是堆排序對數據是跳著訪問的,對CPU緩存不友好。
了解了maxDepth方法以后就可以來看看quickSort的源碼了
func quickSort(data Interface, a, b, maxDepth int) {
for b-a > 12 { // Use ShellSort for slices <= 12 elements
if maxDepth == 0 {
heapSort(data, a, b)
return
}
maxDepth--
mlo, mhi := doPivot(data, a, b)
// Avoiding recursion on the larger subproblem guarantees
// a stack depth of at most lg(b-a).
if mlo-a < b-mhi {
quickSort(data, a, mlo, maxDepth)
a = mhi // i.e., quickSort(data, mhi, b)
} else {
quickSort(data, mhi, b, maxDepth)
b = mlo // i.e., quickSort(data, a, mlo)
}
}
if b-a > 1 {
// Do ShellSort pass with gap 6
// It could be written in this simplified form cause b-a <= 12
for i := a + 6; i < b; i++ {
if data.Less(i, i-6) {
data.Swap(i, i-6)
}
}
insertionSort(data, a, b)
}
}
這里代碼的實現方式比較好理解,首先對于數組元素大于12個的情況會在快排和堆排之間選擇,除此之外的情況會使用希爾排序(間隔為6)和插入排序進行排序。
包中對于heapSort的實現中規中矩,使用從上往下堆化的方式建堆。這里就不詳細介紹,對于快排的實現方式,有的同學就發現不同了,這里調用了一個尋找分區點的函數doPivot,但是doPivot返回了兩個值(這里就利用了Go中函數可以有多個返回值的特性)。同時這里可以看到返回mlo,mhi以后并沒有繼續遞歸地在左右分區查找,而是做了一個比較,原因也正如注釋所說,由于使用了遞歸的方式實現排序,就必須要考慮到棧溢出的問題,所以對分區的兩半,把數量多的放到下一次循環繼續切分循環,小的直接遞歸。這里也表明了調用quickSort的最高棧深度為log(b-a),也就是log(n)。
接下來可以看看doPivot函數,為什么會返回兩個分區點呢?因為mlo到mhi之間的數已經被確定了位置,這里考慮到取中位數的時候數組出現大量重復的數會影響到排序性能的問題,可以發現Go作者對這種情況的解決方式充滿著智慧。具體代碼如下:
func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
m := int(uint(lo+hi) >> 1) // 首先用位運算的方式求中間點,防止溢出
if hi-lo > 40 {
// 多數取中
// Tukey's ``Ninther,'' median of three medians of three.
s := (hi - lo) / 8
medianOfThree(data, lo, lo+s, lo+2*s)
medianOfThree(data, m, m-s, m+s)
medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
}
medianOfThree(data, lo, m, hi-1)
// 接下來要對數據達成以下劃分結果
// data[lo] = pivot (set up by ChoosePivot)
// data[lo < i < a] < pivot
// data[a <= i < b] <= pivot
// data[b <= i < c] unexamined
// data[c <= i < hi-1] > pivot
// data[hi-1] >= pivot
pivot := lo
a, c := lo+1, hi-1
for ; a < c && data.Less(a, pivot); a++ {
}
b := a
for {
for ; b < c && !data.Less(pivot, b); b++ { // data[b] <= pivot
}
for ; b < c && data.Less(pivot, c-1); c-- { // data[c-1] > pivot
}
if b >= c {
break
}
// data[b] > pivot; data[c-1] <= pivot
data.Swap(b, c-1)
b++
c--
}
// 如果data[c <= i < hi-1] > pivot,hi-c<3 這表明數據中有重復的數,
// 這里保守一些,認為hi-c<5 為邊界,如果重復的數較多,
// 會以直接掃描跳過的方式把pivot左右兩邊的區間縮小
// If hi-c<3 then there are duplicates (by property of median of nine).
// Let's be a bit more conservative, and set border to 5.
protect := hi-c < 5
if !protect && hi-c < (hi-lo)/4 {
// Lets test some points for equality to pivot
dups := 0
if !data.Less(pivot, hi-1) { // data[hi-1] = pivot
data.Swap(c, hi-1)
c++
dups++
}
if !data.Less(b-1, pivot) { // data[b-1] = pivot
b--
dups++
}
// m-lo = (hi-lo)/2 > 6
// b-lo > (hi-lo)*3/4-1 > 8
// ==> m < b ==> data[m] <= pivot
if !data.Less(m, pivot) { // data[m] = pivot
data.Swap(m, b-1)
b--
dups++
}
// if at least 2 points are equal to pivot, assume skewed distribution
protect = dups > 1
}
if protect {
// Protect against a lot of duplicates
// Add invariant:
// data[a <= i < b] unexamined
// data[b <= i < c] = pivot
for {
for ; a < b && !data.Less(b-1, pivot); b-- { // data[b] == pivot
}
for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
}
if a >= b {
break
}
// data[a] == pivot; data[b-1] < pivot
data.Swap(a, b-1)
a++
b--
}
}
// Swap pivot into middle
data.Swap(pivot, b-1)
return b - 1, c
}