I was looking for an implementation of Strassen's algorithm in C and found the code shown at the end of this post.
To use the multiply function:
void multiply(int n, matrix a, matrix b, matrix c, matrix d); /* c = a * b for order-n matrices; d is scratch space of the same order and layout, overwritten */
which multiplies the two matrices a and b and stores the result in c (d is an intermediate, scratch matrix). The matrices a and b must be of the following type:
typedef union _matrix { double **d; /* leaf (n <= BREAK): n row pointers */ union _matrix **p; /* inner node (n > BREAK): four order-n/2 quadrants 11, 12, 21, 22 */ } *matrix;
I dynamically allocated four matrices a, b, c, d (two-dimensional arrays of doubles) and assigned their addresses to the _matrix.d field:
#include "strassen.h" #define SIZE 50 int main(int argc, char *argv[]) { double ** matA, ** matB, ** matC, ** matD; union _matrix ma, mb, mc, md; int i = 0, j = 0, n; matA = (double **) malloc(sizeof(double *) * SIZE); for (i = 0; i < SIZE; i++) matA[i] = (double *) malloc(sizeof(double) * SIZE);
This code compiles successfully, but it produces an error whenever n > BREAK.
strassen.c:
#include "strassen.h"

/*
 * Strassen matrix multiplication on a recursive ("quad-tree") matrix layout.
 *
 * Layout expected for a matrix m of order n (union _matrix, see strassen.h):
 *   - n <= BREAK : leaf. m->d is an array of n row pointers, each row holding
 *                  at least n doubles; element (i, j) is m->d[i][j].
 *   - n >  BREAK : inner node. m->p points to four sub-matrices of order n/2:
 *                  p[0] = top-left (11), p[1] = top-right (12),
 *                  p[2] = bottom-left (21), p[3] = bottom-right (22),
 *                  each laid out recursively by the same rule.
 *
 * Consequences for callers:
 *   - A plain double ** stored in .d is only valid when n <= BREAK. For
 *     larger n the code reads .p instead and treats the row pointers as
 *     sub-matrix pointers, which is undefined behaviour (typically a crash).
 *   - Every halving must be exact: n must equal m * 2^k with m <= BREAK
 *     (e.g. 16, 24, 32, 64 with BREAK == 8). For n = 50, 50/2 = 25 is odd
 *     and 25/2 truncates to 12, so one row and one column of every order-25
 *     quadrant are silently ignored.
 *
 * add() and sub() are defined before multiply() so that multiply() never
 * calls an undeclared function, even if strassen.h has no prototypes.
 */

/*
 * c = a + b for matrices of order n.
 * c may be the same matrix as a or b, because the update is element-wise.
 */
void add(int n, matrix a, matrix b, matrix c)
{
    if (n <= BREAK) {
        double **p = a->d, **q = b->d, **r = c->d;
        int i, j;

        for (i = 0; i < n; i++) {
            for (j = 0; j < n; j++) {
                r[i][j] = p[i][j] + q[i][j];
            }
        }
    } else {
        n /= 2;
        add(n, a11, b11, c11);
        add(n, a12, b12, c12);
        add(n, a21, b21, c21);
        add(n, a22, b22, c22);
    }
}

/*
 * c = a - b for matrices of order n.
 * c may be the same matrix as a or b, because the update is element-wise.
 */
void sub(int n, matrix a, matrix b, matrix c)
{
    if (n <= BREAK) {
        double **p = a->d, **q = b->d, **r = c->d;
        int i, j;

        for (i = 0; i < n; i++) {
            for (j = 0; j < n; j++) {
                r[i][j] = p[i][j] - q[i][j];
            }
        }
    } else {
        n /= 2;
        sub(n, a11, b11, c11);
        sub(n, a12, b12, c12);
        sub(n, a21, b21, c21);
        sub(n, a22, b22, c22);
    }
}

/*
 * c = a * b for matrices of order n. Orders <= BREAK use the schoolbook
 * triple loop; larger orders use Strassen's 7-multiplication scheme.
 *
 * d is scratch storage. It must be a distinct matrix of order n with the
 * same recursive layout, and its contents are overwritten. c must not be
 * the same matrix as a, b or d.
 *
 * With the usual Strassen products
 *   M1 = (A11 + A22)(B11 + B22)   M5 = (A11 + A12) B22
 *   M2 = (A21 + A22) B11          M6 = (A21 - A11)(B11 + B12)
 *   M3 = A11 (B12 - B22)          M7 = (A12 - A22)(B21 + B22)
 *   M4 = A22 (B21 - B11)
 * the result is
 *   C11 = M1 + M4 - M5 + M7       C12 = M3 + M5
 *   C21 = M2 + M4                 C22 = M1 - M2 + M3 + M6
 * The products are accumulated directly in the quadrants of c, with the
 * quadrants of d as temporaries (and as scratch for the recursive calls).
 */
void multiply(int n, matrix a, matrix b, matrix c, matrix d)
{
    if (n <= BREAK) {
        double sum, **p = a->d, **q = b->d, **r = c->d;
        int i, j, k;

        for (i = 0; i < n; i++) {
            for (j = 0; j < n; j++) {
                sum = 0.;
                for (k = 0; k < n; k++) {
                    sum += p[i][k] * q[k][j];
                }
                r[i][j] = sum;
            }
        }
    } else {
        n /= 2;

        sub(n, a12, a22, d11);              /* d11 = A12 - A22            */
        add(n, b21, b22, d12);              /* d12 = B21 + B22            */
        multiply(n, d11, d12, c11, d21);    /* c11 = M7                   */

        sub(n, a21, a11, d11);              /* d11 = A21 - A11            */
        add(n, b11, b12, d12);              /* d12 = B11 + B12            */
        multiply(n, d11, d12, c22, d21);    /* c22 = M6                   */

        add(n, a11, a12, d11);              /* d11 = A11 + A12            */
        multiply(n, d11, b22, c12, d12);    /* c12 = M5                   */
        sub(n, c11, c12, c11);              /* c11 = M7 - M5              */

        sub(n, b21, b11, d11);              /* d11 = B21 - B11            */
        multiply(n, a22, d11, c21, d12);    /* c21 = M4                   */
        add(n, c21, c11, c11);              /* c11 = M4 - M5 + M7         */

        sub(n, b12, b22, d11);              /* d11 = B12 - B22            */
        multiply(n, a11, d11, d12, d21);    /* d12 = M3                   */
        add(n, d12, c12, c12);              /* c12 = M3 + M5    (final)   */
        add(n, d12, c22, c22);              /* c22 = M3 + M6              */

        add(n, a21, a22, d11);              /* d11 = A21 + A22            */
        multiply(n, d11, b11, d12, d21);    /* d12 = M2                   */
        add(n, d12, c21, c21);              /* c21 = M2 + M4    (final)   */
        sub(n, c22, d12, c22);              /* c22 = M3 + M6 - M2         */

        add(n, a11, a22, d11);              /* d11 = A11 + A22            */
        add(n, b11, b22, d12);              /* d12 = B11 + B22            */
        multiply(n, d11, d12, d21, d22);    /* d21 = M1                   */
        add(n, d21, c11, c11);              /* c11 = M1 + M4 - M5 + M7    */
        add(n, d21, c22, c22);              /* c22 = M1 - M2 + M3 + M6    */
    }
}
strassen.h:
#ifndef STRASSEN_H
#define STRASSEN_H

/*
 * Leaf order. Matrices of order <= BREAK are stored as plain row arrays and
 * multiplied with the schoolbook triple loop. Larger ones are split into
 * four quadrants. The value can be overridden at compile time
 * (e.g. -DBREAK=16), but every translation unit that builds or consumes
 * matrices must use the same value, because it determines the layout.
 */
#ifndef BREAK
#define BREAK 8
#endif

/*
 * Recursive ("quad-tree") matrix of order n:
 *   - n <= BREAK : d is an array of n row pointers; element (i, j) is d[i][j].
 *   - n >  BREAK : p points to four sub-matrices of order n/2 in the order
 *                  top-left, top-right, bottom-left, bottom-right
 *                  (11, 12, 21, 22), each laid out by the same rule.
 * Which member is meaningful depends only on n, so a plain double ** in .d is
 * not a valid argument when n > BREAK. Every halving must be exact
 * (n == m * 2^k with m <= BREAK).
 *
 * NOTE(review): the tag _matrix starts with an underscore and is therefore a
 * reserved identifier at file scope. It is kept for source compatibility with
 * existing users of `union _matrix`.
 */
typedef union _matrix {
    double **d;         /* leaf: n row pointers of at least n doubles each */
    union _matrix **p;  /* inner node: four quadrants of order n/2 */
} *matrix;

/* c = a * b; d is scratch of the same order and layout (overwritten). */
void multiply(int n, matrix a, matrix b, matrix c, matrix d);
/* c = a + b, element-wise; c may be a or b. */
void add(int n, matrix a, matrix b, matrix c);
/* c = a - b, element-wise; c may be a or b. */
void sub(int n, matrix a, matrix b, matrix c);

/*
 * Quadrant accessors used by the implementation. They expand to expressions
 * on variables literally named a, b, c and d in the scope where they are
 * used. Because they are object-like macros with short names, any includer
 * that uses an identifier such as a11 will have it replaced.
 */
#define a11 (a->p[0])
#define a12 (a->p[1])
#define a21 (a->p[2])
#define a22 (a->p[3])
#define b11 (b->p[0])
#define b12 (b->p[1])
#define b21 (b->p[2])
#define b22 (b->p[3])
#define c11 (c->p[0])
#define c12 (c->p[1])
#define c21 (c->p[2])
#define c22 (c->p[3])
#define d11 (d->p[0])
#define d12 (d->p[1])
#define d21 (d->p[2])
#define d22 (d->p[3])

#endif /* STRASSEN_H */
My question is: how do I use the multiply function, i.e. how should the matrix arguments be built?
strassen.h
strassen.c