閱讀經典——《算法導論》03
矩陣乘法是種極其耗時的運算。
以C = A ? B為例,其中A和B都是 n x n 的矩陣。根據矩陣乘法的定義,計算過程如下:
SQUARE-MATRIX-MULTIPLY(A, B)
n = A.rows
let C be a new nxn matrix
for i = 1 to n
for j = 1 to n
c[i][j] = 0
for k = 1 to n
c[i][j] += a[i][k] * b[k][j]
return C
由于存在三層循環,它的時間復雜度將達到O(n3)。
這是一個很可怕的數字。但是,憑著科學家們的智慧,這個數正在一步步下降。本文介紹經典的Strassen算法,該算法將時間復雜度降低到O(nlg7) ≈ O(n2.81)。別小看這個細微的改進,當n非常大時,該算法將比平凡算法節約大量時間。
分治法
Strassen算法基于分治的思想,因此我們首先考慮一個簡單的分治策略。
每個 n x n 的矩陣都可以分割為四個 n/2 x n/2 的矩陣:
<small>(式3-1)</small>
因此可以將公式C = A ? B改寫為
<small>(式3-2)</small>
于是上式就等價于如下四個公式:
<small>(式3-3)</small>
C11 = A11 ? B11 + A12 ? B21
C12 = A11 ? B12 + A12 ? B22
C21 = A21 ? B11 + A22 ? B21
C22 = A21 ? B12 + A22 ? B22
每個公式需要計算兩次矩陣乘法和一次矩陣加法,使用T(n)表示 n x n 矩陣乘法的時間復雜度,那么我們可以根據上面的分解得到一個遞推公式。
T(n) = 8T(n/2) + Θ(n2)
其中,8T(n/2)表示8次矩陣乘法,而且相乘的矩陣規模降到了n/2。Θ(n2)表示4次矩陣加法的時間復雜度以及合并C矩陣的時間復雜度。
要想計算出T(n)并不復雜,可以采用畫遞歸樹的方式計算,或采用下一篇文章中講的“主方法”直接計算。結果是
T(n) = Θ(n3)
可見,簡單的分治策略并沒有起到加速運算的效果。
Strassen算法
1969年,Volker Strassen發表文章提出一種漸進快于平凡算法的矩陣相乘算法,引起巨大轟動。在此之前,很少人敢設想一個算法能漸近快于平凡算法。矩陣乘法的漸近上界自此被改進了。
讓我們回頭觀察前面使用分治策略的時候為什么無法提高速度。
因為分解后的問題包含了8次矩陣相乘和4次矩陣相加,就是這8次矩陣相乘導致了速度不能提升。于是我們想到能不能減少矩陣相乘的次數,取而代之的是矩陣相加的次數增加。Strassen正是利用了這一點。
現在,我們來看一下Strassen算法的原理。
仍然把每個矩陣分割為4份,然后創建如下10個中間矩陣:
S1 = B12 - B22
S2 = A11 + A12
S3 = A21 + A22
S4 = B21 - B11
S5 = A11 + A22
S6 = B11 + B22
S7 = A12 - A22
S8 = B21 + B22
S9 = A11 - A21
S10 = B11 + B12
接著,計算7次矩陣乘法:
P1 = A11 ? S1
P2 = S2 ? B22
P3 = S3 ? B11
P4 = A22 ? S4
P5 = S5 ? S6
P6 = S7 ? S8
P7 = S9 ? S10
最后,根據這7個結果就可以計算出C矩陣:
C11 = P5 + P4 - P2 + P6
C12 = P1 + P2
C21 = P3 + P4
C22 = P5 + P1 - P3 - P7
是不是很神奇呢?話說我第一次看到這個算法的時候真的是驚呆了,10個S矩陣和7個P矩陣究竟是怎么湊出來的,簡直不可思議。
我們可以把P矩陣和S矩陣展開,并帶入最后的式子計算,會發現恰好是公式3中的四個式子。也就是說,Strassen為了計算公式3,繞了一大圈,用了更多的步驟,成功的把計算量變成了7個矩陣乘法和18個矩陣加法。雖然矩陣加法增加了好幾倍,而矩陣乘法只減小了1個,但在數量級面前,18個加法仍然漸進快于1個乘法。這就是該算法的精妙之處。
同樣地,我們可以寫出Strassen算法的遞推公式:
T(n) = 7T(n/2) + Θ(n2)
使用遞歸樹或主方法可以計算出結果:
T(n) = Θ(nlg7) ≈ Θ(n2.81)
下圖展示了平凡算法和Strassen算法的性能差異,n越大,Strassen算法節約的時間越多。
小技巧:如何計算n是否為2的冪
在矩陣分解的過程中,我遇到了這樣一個問題:如何判斷一個 n x n 的矩陣是否能恰好分解為4個大小相同的矩陣。它的本質是判斷n是否為2的冪。
最先想到的方法是不斷除以2,直到余數不為0時判斷當前的被除數是否為1,是則為2的冪,否則不是2的冪。這相當于通過右移檢查n的二進制形式是否為1000...0。
但這種方式有些繁瑣,需要循環判斷。為了提高效率,我發現有位高手用下面這行代碼解決了這個問題:
n & (n - 1) == 0
沒錯,只需要一行代碼,而且只做了一次加法運算和一次與運算,效率大大提高。其原理也很容易解釋,把n
和n-1
的二進制形式寫出來一看就明白了。假設n=0010 0000
,那么n-1 = 0001 1111
,相與得到
0010 0000
& 0001 1111
------------
0000 0000
恰好是0。只要把n中右邊的任意一個0換成1,結果都不再是0。
還有一種類似的方法:
(n & -n) == n
本質和前面是一樣的。據說后一種做法來自JDK,但我沒有考證到。
參考資料
計算機算法:Strassen矩陣相乘算法 Stoimen
Gaussian Elimination is not Optimal Strassen