Median of Two Sorted Arrays
這是一個leetcode上的算法題目,標記為hard。具體描述如下:
There are two sorted arrays nums1 and nums2 of size m and n respectively. Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).
leetcode上要求實現接口如下:
class Solution(object):
def findMedianSortedArrays(self, nums1, nums2):
"""
:type nums1: List[int]
:type nums2: List[int]
:rtype: float
"""
在做題之前,首先要明白什么是中位數。以下是來自某度的解釋:
中位數(又稱中值,英語:Median),統計學中的專有名詞,代表一個樣本、種群或概率分布中的一個數值,其可將數值集合劃分為相等的上下兩部分。對于有限的數集,可以通過把所有觀察值高低排序后找出正中間的一個作為中位數。如果觀察值有偶數個,通常取最中間的兩個數值的平均數作為中位數。
舉個栗子,數列[1,2,2,3,3,4]的中位數為2.5。兩個序列的中位數則是將兩個數組merge到同一個序列中,然后取中位數。栗子又來了[1,3,7,8]和[2,4,5,6]的中位數是4.5。
看到算法中提到的時間復雜度為log(m+n),很明顯,這里需要二分搜索。
二分搜索的挑戰
Knuth在其鴻篇巨制The Art of Computer Programming, volume3,Sorting and Searching的6.2.1節曾指出,雖然第一篇二分搜索論文在1946年就發表了,但第一個沒有錯誤的二分搜索程序卻直到1962年才出現。所以,的確很難。繼續閱讀下文之前,各位讀者不妨先拿出紙筆,花幾分鐘粗略設計一下這個題目的算法。
設計實現
回到題目,設兩個有序數組為A和B,長度分別為m、n,如何才能最快的找到中位數呢。不失一般性,可以假定m 小于等于 n,原因是A和B的中位數必然等于B和A的中位數,原因是AB的順序并不影響AB組合后的序列,因此得證。
確定邊界
二分查找,每一步都需要縮小搜索的范圍,那么不難想到,本題一個可以縮小搜索范圍的做法是通過比較兩個數組各自的中位數m1,m2,有以下三種情況來區分(數組中冒號是借用python切片表達):
- m1 == m2, 中大獎,直接返回m1(或m2)
- m1 < m2, 返回數組A[?:?]和B[?:?] 的中位數
- m1 > m2, 返回數組A[?:?] 和 B[?:?]的中位數
上述列表中的邊界中有諸多問號,這也是二分查找的關鍵之一。我們通過分析分別填上,考慮下面兩個因素:
- 中位數在序列包含元素的奇偶性上表現不同:如果序列元素個數為奇數作為中位數是數組中的某數,否則是兩個數的平均值。所以在處理邊界上,也是和奇偶相關的。
- 范圍縮小的一致性:這里指的并不是等比例,而是具體的數字。即當數組A減少了n個元素時,數組B也必須減少n個,否則結果肯定是不對的,試想A=[2,2], B=[1,2,3...8,9]。
基于這兩點,
if nums1 and nums2:
m1 = findMedianOfSingleSortedArray(nums1)
m2 = findMedianOfSingleSortedArray(nums2)
if m1 == m2:
return m1
if m1 < m2:
if len1 % 2 == 0:
return self.findMedianSortedArrays(nums1[len1 / 2 - 1:], nums2[:len2 - len1 / 2 + 1])
return self.findMedianSortedArrays(nums1[len1 / 2:], nums2[: len2 - len1 / 2])
else:
if len1 % 2 == 0:
return self.findMedianSortedArrays(nums1[:len1 / 2 + 1 ], nums2[len1 / 2 - 1:])
return self.findMedianSortedArrays(nums1[:len1 / 2 + 1 ], nums2[len1 / 2:])
邊界確定后,下一步就是要找到結束條件了。
結束條件
開始之前,還是慣例,希望讀者能先思考一下。何時終止二分查找。
同時,這里先插播一個有意思的感覺:做英語選擇題拿不太準時,比如第一感覺選A,然后修改了C,結果改錯了。然后英語老師強調第一感覺很重要 (其實頗有點孕婦效應的感覺)。 數學的題目第一感覺選B,后來仔細推理,原來D才是正確答案,數學老師強調的是千萬別信第一感覺。當然不排除邪惡的數學出題人故意挖坑給大家。回到這個題目,其實的確也是個數學題,小坑呢,也是有的:
當范圍逐漸縮小,是否最終是其中一個數組變為空,然后計算另外一個數組的中位數就可以了呢?估計我不說你也能猜出來,這個想法是錯的。原因就在于,有可能中位數包含在了逐漸縮小的范圍中,尤其是在最后變為空的之前一段時間。
如果不為空,那各個數組包含多少元素時應該結束呢。仔細一想,這取決與最后的中位數到底需要幾個數字才能算出來,答案是,當序列總數是偶數時,需要2個,奇數時,需要1個。嘗試解釋一下原因:
假設[2,3]和[0,1,5,6]這種情況,[2,3]是不可再縮減的,因為一旦再縮小范圍,那么中位數就無法得出了。同理,可得出奇數時需要1個的結論。
那么結束條件,也不難得出,
if len1 == 1:
if len2 % 2 == 0:
return least_1_even(nums1, nums2)
else:
return least_1_odd(nums1, nums2)
if len1 == 2:
if len2 % 2 == 0:
return least_2_even(nums1, nums2)
else:
return least_2_odd(nums1, nums2)
簡單解釋一下,為何只需要判斷第一個數組長度是否到達了1和2。因為之前假定A的長度小于B,并且我們在縮小范圍時,總是縮小相同的數目。所以必然是A先達到1、2。不需要關注B的長度。一個數組長度為1或2,另外一個數組長度為奇數或偶數需要不同處理,這里都是苦力活,即比較A中的1個或2個數字與B的中位數的大小關系。具體的參考代碼,邏輯比較簡單,但需要非常仔細。千萬不要遺漏任何一種情況。
def least_2_even(nums1, nums2):
a, b = nums1[0],nums1[1]
len2 = len(nums2)
p, q = nums2[len2 / 2 - 1], nums2[len2 / 2]
if b <= q and p <= a:
return (a + b) / 2.
if b <= p:
if len2 > 2:
return (max(b, nums2[len2 / 2 - 2]) + p) / 2.
return (b + p) / 2.
if a >= q:
if len2 > 2:
return (min(a, nums2[len2 / 2 + 1]) + q) / 2.
return (a + q) / 2.
if a <=p and q <=b:
return (p + q) / 2.
if p <= a <= q <= b:
return (a + q) / 2.
return (b + p) / 2.
def least_2_odd(nums1, nums2):
a, b = nums1[0],nums1[1]
len2 = len(nums2)
m = nums2[len2 / 2]
if a <= m <= b:
return m
if b <= m:
if len2 > 1:
return max(b, nums2[len2 / 2 - 1])
return b
if m <= a:
if len2 > 1:
return min(a, nums2[len2 / 2 + 1])
return a
def least_1_odd(nums1, nums2):
a = nums1[0]
len2 = len(nums2)
m = nums2[len2 / 2]
if len2 > 1:
if a >= nums2[len2 / 2]:
return (min(a, nums2[len2 / 2 + 1]) + m) / 2.
else:
return (max(a, nums2[len2 / 2 - 1]) + m) / 2.
else:
return (a + m) / 2.
def least_1_even(nums1, nums2):
a = nums1[0]
len2 = len(nums2)
p, q = nums2[len2 / 2 - 1], nums2[len2 / 2]
if p <= a <= q:
return a
if a < p:
return p
return q
例外
先別急著提交代碼,還有一種異常情況需要處理呢。
- A or B 為空:直接返回不空的數組的中位數
- A and B都為空,題目不存在這種情況,忽略
代碼:
if not nums1:
return findMedianOfSingleSortedArray(nums2)
if not nums2:
return findMedianOfSingleSortedArray(nums1)
附上工具函數:
def findMedianOfSingleSortedArray(l):
''' Array l must not be []'''
length = len(l)
if length % 2 == 0:
return (l[length/2] + l[length /2 - 1])/2.
return l[length/2]
OK,完成。
時間復雜度分析
對于數組A來講,算法在每一步都能夠縮短搜索范圍一半,直至元素個數為2。這需要log(m)次操作,當兩個數組元素分別個數為2和n-m+2時,需要常數次比較即可獲得中位數,因此該算法時間復雜度為O(log(min(m,n))),是快于題目給出的O(log(m+n))的。(對于python的切片操作,可以有等價的傳遞數組index的O(1)的方式替代,故此處省略不記。)
總結
正如前面提到的,二分搜索的代碼其實坑很多,整數除2到底是取上限還是下限,類似off-by-one則是每次寫都會碰到。但我們只要時刻銘記二分搜索中的三個部分,就能夠強迫自己保持清醒了:
- 初始化: 循環不變式(loop invariant)始終為真
- 保持: 如果在某次迭代開始時以及循環執行時,不變式都為真,那么循環執行完畢不變式依然為真
- 終止: 循環能夠終止,并且得到期望結果。
所有代碼在我的github上