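Matrix multiplication is a fundamental operation in linear algebra. This article explains the basic idea behind Strassen's matrix multiplication algorithm and walks through an implementation in C.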
1. Ordinary Matrix Multiplication
First, let's look at how ordinary matrix multiplication is computed.
Consider the $n \times n$ matrices

$$
A=\begin{bmatrix} a_{11} & a_{12} & \cdots & a_{1n} \\ a_{21} & a_{22} & \cdots & a_{2n} \\ \vdots & \vdots & \ddots & \vdots \\ a_{n1} & a_{n2} & \cdots & a_{nn} \end{bmatrix},\qquad
B=\begin{bmatrix} b_{11} & b_{12} & \cdots & b_{1n} \\ b_{21} & b_{22} & \cdots & b_{2n} \\ \vdots & \vdots & \ddots & \vdots \\ b_{n1} & b_{n2} & \cdots & b_{nn} \end{bmatrix}.
$$
Let $C = A \cdot B$. The entry of $C$ in row $i$, column $j$ is

$$
C_{ij}=\sum_{k=1}^{n}a_{ik}b_{kj}=a_{i1}b_{1j}+a_{i2}b_{2j}+\cdots+a_{in}b_{nj}
$$
Multiplying two $n \times n$ matrices gives another $n \times n$ matrix. Each of its $n^2$ entries takes $n$ multiplications and $n-1$ additions, so the whole product needs $n^3$ multiplications and $n^3 - n^2$ additions. The time complexity is therefore $\Theta(n^3)$.
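The formula for $C_{ij}$ translates directly into a triple loop. The sketch below also counts the scalar operations so that the totals above can be checked. It is an illustration, not the article's code. The function name `naive_multiply_counted` and the flat row-major array layout are assumptions; the program in section 4 uses `int **` matrices instead.

```c
#include <assert.h>

/* Illustrative sketch (hypothetical name): ordinary product C = A * B of two
 * n x n matrices stored row-major in flat arrays. It also counts scalar
 * multiplications and additions, so the n^3 and n^3 - n^2 totals can be checked. */
void naive_multiply_counted(int n, const int *A, const int *B, int *C,
                            long *mults, long *adds)
{
    assert(n >= 1);
    *mults = 0;
    *adds = 0;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            /* first term a_i1 * b_1j: one multiplication, no addition yet */
            int sum = A[i * n] * B[j];
            (*mults)++;
            /* remaining n - 1 terms: one multiplication and one addition each */
            for (int k = 1; k < n; k++) {
                sum += A[i * n + k] * B[k * n + j];
                (*mults)++;
                (*adds)++;
            }
            C[i * n + j] = sum;
        }
    }
}
```

For $n = 4$, for example, the counters end at 64 multiplications and 48 additions, which agrees with $n^3$ and $n^3 - n^2$.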
2. Block Matrix Multiplication
Apply the divide-and-conquer idea from the article "Solving the Maximum Subarray Problem with Divide and Conquer" and partition $A$, $B$ and $C$ into block matrices:

$$
A=\begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22}\end{bmatrix},\qquad
B=\begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22}\end{bmatrix},\qquad
C=\begin{bmatrix} C_{11} & C_{12} \\ C_{21} & C_{22}\end{bmatrix}
$$

Each block is an $\frac{n}{2} \times \frac{n}{2}$ matrix. For the split to divide evenly at every level of the recursion, $n$ must be a power of 2.
Then

$$
\begin{aligned}
C_{11} &= A_{11}B_{11}+A_{12}B_{21}\\
C_{12} &= A_{11}B_{12}+A_{12}B_{22}\\
C_{21} &= A_{21}B_{11}+A_{22}B_{21}\\
C_{22} &= A_{21}B_{12}+A_{22}B_{22}
\end{aligned}
$$
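The eight block products above can be written as a recursive C function. This is a hedged illustration, not the article's code. It works on flat row-major arrays and addresses each quadrant through pointer offsets instead of copying it. It also performs the four block additions by accumulating each product directly into the matching block of C. The names `block_mul_add` and `block_multiply` are hypothetical, and n is assumed to be a power of 2.

```c
#include <assert.h>
#include <string.h>

/* Hypothetical helper: C += A * B for n x n blocks that live inside larger
 * row-major matrices whose rows are ld ints apart. n must be a power of two. */
static void block_mul_add(int n, const int *A, const int *B, int *C, int ld)
{
    if (n == 1) {                      /* recursion bottoms out at single entries */
        C[0] += A[0] * B[0];
        return;
    }
    int h = n / 2;
    /* quadrant offsets: X11 = X, X12 = X + h, X21 = X + h*ld, X22 = X + h*ld + h */
    const int *A11 = A, *A12 = A + h, *A21 = A + h * ld, *A22 = A + h * ld + h;
    const int *B11 = B, *B12 = B + h, *B21 = B + h * ld, *B22 = B + h * ld + h;
    int *C11 = C, *C12 = C + h, *C21 = C + h * ld, *C22 = C + h * ld + h;

    /* the 8 half-size products; the block additions happen by accumulating into C */
    block_mul_add(h, A11, B11, C11, ld);  block_mul_add(h, A12, B21, C11, ld);
    block_mul_add(h, A11, B12, C12, ld);  block_mul_add(h, A12, B22, C12, ld);
    block_mul_add(h, A21, B11, C21, ld);  block_mul_add(h, A22, B21, C21, ld);
    block_mul_add(h, A21, B12, C22, ld);  block_mul_add(h, A22, B22, C22, ld);
}

/* Hypothetical wrapper: C = A * B for n x n row-major matrices, n a power of two */
void block_multiply(int n, const int *A, const int *B, int *C)
{
    assert(n >= 1 && (n & (n - 1)) == 0);
    memset(C, 0, sizeof(int) * (size_t)n * (size_t)n);
    block_mul_add(n, A, B, C, n);
}
```

Each call with $n > 1$ makes exactly 8 recursive calls of half the size, one for each of the eight block products listed above.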
Computing the $n \times n$ matrix $C$ this way takes 8 multiplications and 4 additions of $\frac{n}{2} \times \frac{n}{2}$ blocks. The block multiplications are handled by recursive calls, which continue until each block contains a single element. The 4 block additions take $\Theta(n^2)$ time. The overall running time $T(n)$ therefore satisfies

$$
T(n)=8T\left(\frac{n}{2}\right)+\Theta(n^2),\quad n>1,\qquad T(1)=\Theta(1)
$$

This recurrence solves to $T(n)=\Theta(n^{\log_2 8})=\Theta(n^3)$. In other words, blocking alone gives the same time complexity as ordinary matrix multiplication.
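A recursion-tree sum fills in the step from the recurrence to $\Theta(n^3)$. This sketch makes simplifying assumptions: $n = 2^k$, $T(1) = c$, and the $\Theta(n^2)$ term is exactly $cn^2$ for a constant $c$. Level $i$ of the tree has $8^i$ subproblems of order $n/2^i$:

$$
\begin{aligned}
T(n) &= \sum_{i=0}^{k-1} 8^i \, c\left(\frac{n}{2^i}\right)^2 + 8^k \, T(1) \\
     &= c\,n^2 \sum_{i=0}^{k-1} 2^i + c\,8^{\log_2 n} \\
     &= c\,n^2 (n-1) + c\,n^3 = \Theta(n^3).
\end{aligned}
$$

In master-theorem terms, $a = 8$ and $b = 2$, so $n^{\log_b a} = n^3$ grows polynomially faster than the $n^2$ combine cost and the leaves dominate. The same sum with 7 in place of 8 is dominated by the $7^k = n^{\log_2 7}$ leaves.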
3. Strassen Matrix Multiplication
Building on this block partition, Strassen's algorithm reduces the number of block multiplications by introducing the following seven intermediate matrices:

$$
\begin{aligned}
P &= (A_{11}+A_{22})(B_{11}+B_{22})\\
Q &= (A_{21}+A_{22})B_{11}\\
R &= A_{11}(B_{12}-B_{22})\\
S &= A_{22}(B_{21}-B_{11})\\
T &= (A_{11}+A_{12})B_{22}\\
U &= (A_{21}-A_{11})(B_{11}+B_{12})\\
V &= (A_{12}-A_{22})(B_{21}+B_{22})
\end{aligned}
$$
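Computing them takes 7 multiplications and 10 additions or subtractions of $\frac{n}{2} \times \frac{n}{2}$ blocks. The blocks of $C$ are then obtained from these seven matrices: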
$$
\begin{aligned}
C_{11} &= P+S-T+V\\
C_{12} &= R+T\\
C_{21} &= Q+S\\
C_{22} &= P+R-Q+U
\end{aligned}
$$
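One way to check these four formulas is to apply Strassen's step to a 2×2 matrix whose blocks are single numbers and compare the result with the ordinary product. The function name `strassen_2x2` and the use of `long` are illustrative choices, not taken from the article. The intermediate names P to V follow the text.

```c
#include <assert.h>

/* Illustrative check (hypothetical name): one Strassen step on a 2x2 matrix
 * whose "blocks" are single numbers. a, b, c are row-major {x11, x12, x21, x22}. */
void strassen_2x2(const long a[4], const long b[4], long c[4])
{
    long a11 = a[0], a12 = a[1], a21 = a[2], a22 = a[3];
    long b11 = b[0], b12 = b[1], b21 = b[2], b22 = b[3];

    /* 7 multiplications, 10 additions/subtractions */
    long P = (a11 + a22) * (b11 + b22);
    long Q = (a21 + a22) * b11;
    long R = a11 * (b12 - b22);
    long S = a22 * (b21 - b11);
    long T = (a11 + a12) * b22;
    long U = (a21 - a11) * (b11 + b12);
    long V = (a12 - a22) * (b21 + b22);

    /* 8 more additions/subtractions */
    c[0] = P + S - T + V;   /* C11 */
    c[1] = R + T;           /* C12 */
    c[2] = Q + S;           /* C21 */
    c[3] = P + R - Q + U;   /* C22 */
}
```

In every product above, the left factor comes from $A$ and the right factor from $B$. The identities therefore never rely on multiplication being commutative, which is why they still hold when the entries are matrix blocks and the step can be applied recursively.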
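Assembling $C$ from the seven products takes another 8 additions. The whole step therefore uses 7 block multiplications and 18 block additions. The additions still cost only $\Theta(n^2)$ in total, so the running time $T(n)$ satisfies

$$
T(n)=7T\left(\frac{n}{2}\right)+\Theta(n^2),\quad n>1,\qquad T(1)=\Theta(1)
$$

This recurrence solves to $T(n)=\Theta(n^{\log_2 7})$. Since $\log_2 7 \approx 2.807$, this is $O(n^{2.81})$.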
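The gap is easy to see for $n = 2^k$. Plain blocking performs $8^k = n^3$ scalar multiplications, while Strassen performs $7^k = 7^{\log_2 n} = n^{\log_2 7}$. The sketch below uses a hypothetical helper `scalar_mults` that counts only multiplications and ignores additions. It also lets the recursion stop at a cutoff order, below which the ordinary algorithm takes over.

```c
#include <assert.h>

/* Hypothetical helper: number of scalar multiplications when an n x n product
 * is split recursively into `branches` half-size products (8 for plain
 * blocking, 7 for Strassen) until the order is <= cutoff, where the ordinary
 * n^3 algorithm is used. n must be a power of two; additions are not counted. */
long long scalar_mults(long long n, long long cutoff, int branches)
{
    assert(n >= 1 && (n & (n - 1)) == 0 && cutoff >= 1);
    if (n <= cutoff) {
        return n * n * n;               /* ordinary algorithm at the leaves */
    }
    return branches * scalar_mults(n / 2, cutoff, branches);
}
```

For $n = 512$, `scalar_mults(512, 1, 8)` gives 134,217,728 and `scalar_mults(512, 1, 7)` gives 40,353,607, roughly 3.3 times fewer. With a cutoff of 2, matching the base case of the code in section 4, Strassen needs 46,118,408. These counts ignore the extra additions and the memory traffic, so they do not translate directly into running time.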
4. Code
```c
#include <stdio.h>
#include <stdlib.h>

#define N 512 /* order of the matrices; both inputs are square N x N matrices */

/* the recursive splitting below only works when N is a power of 2 */
#if N < 1 || (N & (N - 1)) != 0
#error "N must be a power of 2"
#endif

/* all arithmetic is done in int, so large input values can overflow */

/* Allocate a size x size int matrix as an array of row pointers */
static int **new_matrix(int size)
{
    int **M = malloc(sizeof(int *) * size);
    if (M == NULL) {
        perror("malloc");
        exit(EXIT_FAILURE);
    }
    for (int i = 0; i < size; i++) {
        M[i] = malloc(sizeof(int) * size);
        if (M[i] == NULL) {
            perror("malloc");
            exit(EXIT_FAILURE);
        }
    }
    return M;
}

/* Release a matrix created by new_matrix */
static void free_matrix(int size, int **M)
{
    for (int i = 0; i < size; i++) {
        free(M[i]);
    }
    free(M);
}

/* Read size x size whitespace-separated integers from path; returns 0 on success */
static int read_matrix(const char *path, int size, int **M)
{
    FILE *fp = fopen(path, "r");
    if (fp == NULL) {
        return -1;
    }
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < size; j++) {
            if (fscanf(fp, "%d", &M[i][j]) != 1) {
                fclose(fp);
                return -1;
            }
        }
    }
    fclose(fp);
    return 0;
}

void plus(int size, int **A, int **B, int **C) /* computes A + B -> C */
{
    int i, j;
    for (i = 0; i < size; i++) {
        for (j = 0; j < size; j++) {
            C[i][j] = A[i][j] + B[i][j];
        }
    }
}

void minus(int size, int **A, int **B, int **C) /* computes A - B -> C */
{
    int i, j;
    for (i = 0; i < size; i++) {
        for (j = 0; j < size; j++) {
            C[i][j] = A[i][j] - B[i][j];
        }
    }
}

void multiply(int size, int **A, int **B, int **C) /* ordinary algorithm: A * B -> C */
{
    int i, j, k;
    for (i = 0; i < size; i++) {
        for (j = 0; j < size; j++) {
            C[i][j] = 0;
            for (k = 0; k < size; k++) {
                C[i][j] += A[i][k] * B[k][j];
            }
        }
    }
}

void strassen(int size, int **A, int **B, int **C) /* Strassen algorithm: A * B -> C */
{
    int half = size / 2;
    int i, j;

    if (size <= 2) {
        /* recursion base: at order 2 (or 1) use the ordinary algorithm directly */
        multiply(size, A, B, C);
        return;
    }

    /* the four sub-blocks of A, B and C */
    int **A11 = new_matrix(half), **A12 = new_matrix(half), **A21 = new_matrix(half), **A22 = new_matrix(half);
    int **B11 = new_matrix(half), **B12 = new_matrix(half), **B21 = new_matrix(half), **B22 = new_matrix(half);
    int **C11 = new_matrix(half), **C12 = new_matrix(half), **C21 = new_matrix(half), **C22 = new_matrix(half);
    /* AA / BB hold intermediate sums and differences of the sub-blocks of A / B */
    int **AA = new_matrix(half), **BB = new_matrix(half);
    /* the seven products; in the notation of section 3:
     * P1 = R, P2 = T, P3 = Q, P4 = S, P5 = P, P6 = V, P7 = -U */
    int **P1 = new_matrix(half), **P2 = new_matrix(half), **P3 = new_matrix(half), **P4 = new_matrix(half);
    int **P5 = new_matrix(half), **P6 = new_matrix(half), **P7 = new_matrix(half);

    /* copy A and B into their sub-blocks */
    for (i = 0; i < half; i++) {
        for (j = 0; j < half; j++) {
            A11[i][j] = A[i][j];
            A12[i][j] = A[i][j + half];
            A21[i][j] = A[i + half][j];
            A22[i][j] = A[i + half][j + half];
            B11[i][j] = B[i][j];
            B12[i][j] = B[i][j + half];
            B21[i][j] = B[i + half][j];
            B22[i][j] = B[i + half][j + half];
        }
    }

    /* seven recursive multiplications */
    minus(half, B12, B22, BB);          /* P1 = A11 * (B12 - B22) */
    strassen(half, A11, BB, P1);
    plus(half, A11, A12, AA);           /* P2 = (A11 + A12) * B22 */
    strassen(half, AA, B22, P2);
    plus(half, A21, A22, AA);           /* P3 = (A21 + A22) * B11 */
    strassen(half, AA, B11, P3);
    minus(half, B21, B11, BB);          /* P4 = A22 * (B21 - B11) */
    strassen(half, A22, BB, P4);
    plus(half, A11, A22, AA);           /* P5 = (A11 + A22) * (B11 + B22) */
    plus(half, B11, B22, BB);
    strassen(half, AA, BB, P5);
    minus(half, A12, A22, AA);          /* P6 = (A12 - A22) * (B21 + B22) */
    plus(half, B21, B22, BB);
    strassen(half, AA, BB, P6);
    minus(half, A11, A21, AA);          /* P7 = (A11 - A21) * (B11 + B12) */
    plus(half, B11, B12, BB);
    strassen(half, AA, BB, P7);

    /* combine; plus/minus work element by element, so the output may alias an input */
    plus(half, P5, P4, C11);            /* C11 = P5 + P4 - P2 + P6 */
    minus(half, C11, P2, C11);
    plus(half, C11, P6, C11);
    plus(half, P1, P2, C12);            /* C12 = P1 + P2 */
    plus(half, P3, P4, C21);            /* C21 = P3 + P4 */
    plus(half, P5, P1, C22);            /* C22 = P5 + P1 - P3 - P7 */
    minus(half, C22, P3, C22);
    minus(half, C22, P7, C22);

    /* merge the four sub-blocks of C back into C */
    for (i = 0; i < half; i++) {
        for (j = 0; j < half; j++) {
            C[i][j] = C11[i][j];
            C[i][j + half] = C12[i][j];
            C[i + half][j] = C21[i][j];
            C[i + half][j + half] = C22[i][j];
        }
    }

    /* release the temporary matrices */
    free_matrix(half, A11); free_matrix(half, A12); free_matrix(half, A21); free_matrix(half, A22);
    free_matrix(half, B11); free_matrix(half, B12); free_matrix(half, B21); free_matrix(half, B22);
    free_matrix(half, C11); free_matrix(half, C12); free_matrix(half, C21); free_matrix(half, C22);
    free_matrix(half, AA);  free_matrix(half, BB);
    free_matrix(half, P1);  free_matrix(half, P2);  free_matrix(half, P3);  free_matrix(half, P4);
    free_matrix(half, P5);  free_matrix(half, P6);  free_matrix(half, P7);
}

int main(void)
{
    int **A = new_matrix(N);
    int **B = new_matrix(N);
    int **C = new_matrix(N);
    int i, j;

    /* read matrices A and B from files */
    if (read_matrix("la.txt", N, A) != 0) {
        printf("Cannot get matrix A!\n");
        return EXIT_FAILURE;
    }
    if (read_matrix("lb.txt", N, B) != 0) {
        printf("Cannot get matrix B!\n");
        return EXIT_FAILURE;
    }

    strassen(N, A, B, C); /* Strassen algorithm: A * B -> C */

    /* print the result to the screen and also write it, tab-separated, to a file */
    FILE *fc = fopen("lc.txt", "w");
    if (fc == NULL) {
        printf("Cannot write matrix C!\n");
        return EXIT_FAILURE;
    }
    for (i = 0; i < N; i++) {
        for (j = 0; j < N; j++) {
            printf("%d ", C[i][j]);
            if (j != N - 1) {
                fprintf(fc, "%d\t", C[i][j]);
            } else {
                fprintf(fc, "%d\n", C[i][j]);
            }
        }
        printf("\n");
    }
    fclose(fc);

    free_matrix(N, A);
    free_matrix(N, B);
    free_matrix(N, C);
    return 0;
}
```
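5. Results
The program reads matrices $A$ and $B$ from la.txt and lb.txt; each file must contain $N \times N$ whitespace-separated integers. It prints the product $C$ to the screen and also writes it to lc.txt.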
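The article does not include la.txt and lb.txt. One way to create them is a small helper like the hypothetical `write_random_matrix` below. It writes an $n \times n$ matrix of random non-negative integers in a whitespace-separated layout that the program's `fscanf` loop can read. The function name, the value bound, and the use of `rand()` are all assumptions made for this sketch.

```c
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>

/* Hypothetical helper: write an n x n matrix of random integers in [0, bound)
 * to fp, one row per line, tab-separated (the same layout used for lc.txt).
 * Returns 0 on success, -1 on a write error. */
int write_random_matrix(FILE *fp, int n, int bound)
{
    assert(fp != NULL && n > 0 && bound > 0);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            int value = rand() % bound;
            if (fprintf(fp, "%d%c", value, j != n - 1 ? '\t' : '\n') < 0) {
                return -1;
            }
        }
    }
    return 0;
}
```

For example, open la.txt with `fopen("la.txt", "w")`, call `write_random_matrix(fp, 512, 10)`, close the file, and repeat for lb.txt. Small values keep every entry of $C$ well inside `int` range: with entries below 10, each entry of the product is at most $512 \times 9 \times 9 = 41472$.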