矩陣乘法Strassen算法

閱讀經典——《算法導論》03

矩陣乘法是種極其耗時的運算。

以C = A ? B為例,其中A和B都是 n x n 的矩陣。根據矩陣乘法的定義,計算過程如下:

SQUARE-MATRIX-MULTIPLY(A, B)
n = A.rows
let C be a new nxn matrix
for i = 1 to n
    for j = 1 to n
        c[i][j] = 0
        for k = 1 to n
            c[i][j] += a[i][k] * b[k][j]
return C

由于存在三層循環,它的時間復雜度將達到O(n3)。

這是一個很可怕的數字。但是,憑著科學家們的智慧,這個數正在一步步下降。本文介紹經典的Strassen算法,該算法將時間復雜度降低到O(nlg7) ≈ O(n2.81)。別小看這個細微的改進,當n非常大時,該算法將比平凡算法節約大量時間。

分治法

Strassen算法基于分治的思想,因此我們首先考慮一個簡單的分治策略。

每個 n x n 的矩陣都可以分割為四個 n/2 x n/2 的矩陣:

<small>(式3-1)</small>


因此可以將公式C = A ? B改寫為

<small>(式3-2)</small>


于是上式就等價于如下四個公式:

<small>(式3-3)</small>
C11 = A11 ? B11 + A12 ? B21
C12 = A11 ? B12 + A12 ? B22
C21 = A21 ? B11 + A22 ? B21
C22 = A21 ? B12 + A22 ? B22

每個公式需要計算兩次矩陣乘法和一次矩陣加法,使用T(n)表示 n x n 矩陣乘法的時間復雜度,那么我們可以根據上面的分解得到一個遞推公式。

T(n) = 8T(n/2) + Θ(n2)

其中,8T(n/2)表示8次矩陣乘法,而且相乘的矩陣規模降到了n/2。Θ(n2)表示4次矩陣加法的時間復雜度以及合并C矩陣的時間復雜度。

要想計算出T(n)并不復雜,可以采用畫遞歸樹的方式計算,或采用下一篇文章中講的“主方法”直接計算。結果是

T(n) = Θ(n3)

可見,簡單的分治策略并沒有起到加速運算的效果。

Strassen算法

1969年,Volker Strassen發表文章提出一種漸進快于平凡算法的矩陣相乘算法,引起巨大轟動。在此之前,很少人敢設想一個算法能漸近快于平凡算法。矩陣乘法的漸近上界自此被改進了。

讓我們回頭觀察前面使用分治策略的時候為什么無法提高速度。

因為分解后的問題包含了8次矩陣相乘和4次矩陣相加,就是這8次矩陣相乘導致了速度不能提升。于是我們想到能不能減少矩陣相乘的次數,取而代之的是矩陣相加的次數增加。Strassen正是利用了這一點。

現在,我們來看一下Strassen算法的原理。

仍然把每個矩陣分割為4份,然后創建如下10個中間矩陣:

S1 = B12 - B22
S2 = A11 + A12
S3 = A21 + A22
S4 = B21 - B11
S5 = A11 + A22
S6 = B11 + B22
S7 = A12 - A22
S8 = B21 + B22
S9 = A11 - A21
S10 = B11 + B12

接著,計算7次矩陣乘法:

P1 = A11 ? S1
P2 = S2 ? B22
P3 = S3 ? B11
P4 = A22 ? S4
P5 = S5 ? S6
P6 = S7 ? S8
P7 = S9 ? S10

最后,根據這7個結果就可以計算出C矩陣:
C11 = P5 + P4 - P2 + P6
C12 = P1 + P2
C21 = P3 + P4
C22 = P5 + P1 - P3 - P7

是不是很神奇呢?話說我第一次看到這個算法的時候真的是驚呆了,10個S矩陣和7個P矩陣究竟是怎么湊出來的,簡直不可思議。

我們可以把P矩陣和S矩陣展開,并帶入最后的式子計算,會發現恰好是公式3中的四個式子。也就是說,Strassen為了計算公式3,繞了一大圈,用了更多的步驟,成功的把計算量變成了7個矩陣乘法和18個矩陣加法。雖然矩陣加法增加了好幾倍,而矩陣乘法只減小了1個,但在數量級面前,18個加法仍然漸進快于1個乘法。這就是該算法的精妙之處。

同樣地,我們可以寫出Strassen算法的遞推公式:

T(n) = 7T(n/2) + Θ(n2)

使用遞歸樹或主方法可以計算出結果:

T(n) = Θ(nlg7) ≈ Θ(n2.81)

下圖展示了平凡算法和Strassen算法的性能差異,n越大,Strassen算法節約的時間越多。

性能比較

小技巧:如何計算n是否為2的冪

在矩陣分解的過程中,我遇到了這樣一個問題:如何判斷一個 n x n 的矩陣是否能恰好分解為4個大小相同的矩陣。它的本質是判斷n是否為2的冪。

最先想到的方法是不斷除以2,直到余數不為0時判斷當前的被除數是否為1,是則為2的冪,否則不是2的冪。這相當于通過右移檢查n的二進制形式是否為1000...0。

但這種方式有些繁瑣,需要循環判斷。為了提高效率,我發現有位高手用下面這行代碼解決了這個問題:

n & (n - 1) == 0

沒錯,只需要一行代碼,而且只做了一次加法運算和一次與運算,效率大大提高。其原理也很容易解釋,把nn-1的二進制形式寫出來一看就明白了。假設n=0010 0000,那么n-1 = 0001 1111,相與得到

  0010 0000
& 0001 1111
------------
  0000 0000

恰好是0。只要把n中右邊的任意一個0換成1,結果都不再是0。

還有一種類似的方法:

(n & -n) == n

本質和前面是一樣的。據說后一種做法來自JDK,但我沒有考證到。

參考資料

計算機算法:Strassen矩陣相乘算法 Stoimen

Gaussian Elimination is not Optimal Strassen

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容