// Matrix Multiplication Routine
//
// Computes C = A*B for matrices stored row by row in flat float arrays.
//   A = left input matrix  (m x p)
//   B = right input matrix (p x n)
//   m = number of rows in A
//   p = number of columns in A = number of rows in B
//   n = number of columns in B
//   C = output matrix = A*B (m x n)
static void matrix_multiply(float* A, float* B, int m, int p, int n, float* C)
{
    int row, col, inner;
    float* cell;

    row = 0;
    while (row < m)
    {
        for (col = 0; col < n; ++col)
        {
            // The output cell is cleared and then built up in place as the
            // dot product of row `row` of A with column `col` of B.
            cell = &C[n*row + col];
            *cell = 0;
            for (inner = 0; inner < p; ++inner)
            {
                *cell = *cell + A[p*row + inner] * B[n*inner + col];
            }
        }
        ++row;
    }
}

// Matrix Addition Routine
//
// Computes C = A+B element by element on row-major flat float arrays.
//   A = first input matrix  (m x n)
//   B = second input matrix (m x n)
//   m = number of rows in A and B
//   n = number of columns in A and B
//   C = output matrix = A+B (m x n)
static void matrix_addition(float* A, float* B, int m, int n, float* C)
{
    int row, col;

    for (row = 0; row < m; ++row)
    {
        for (col = 0; col < n; ++col)
        {
            // Flat offset of element (row, col); shared by all three matrices.
            const int at = n*row + col;
            C[at] = A[at] + B[at];
        }
    }
}

// Matrix Subtraction Routine
//
// Computes C = A-B element by element on row-major flat float arrays.
//   A = input matrix being subtracted from (m x n)
//   B = input matrix being subtracted      (m x n)
//   m = number of rows in A and B
//   n = number of columns in A and B
//   C = output matrix = A-B (m x n)
static void matrix_subtraction(float* A, float* B, int m, int n, float* C)
{
    int row, col;

    for (row = 0; row < m; ++row)
    {
        for (col = 0; col < n; ++col)
        {
            // Flat offset of element (row, col); shared by all three matrices.
            const int at = n*row + col;
            C[at] = A[at] - B[at];
        }
    }
}

// Matrix Transpose Routine
//
// Writes the transpose of A into C; both are row-major flat float arrays.
//   A = input matrix (m x n)
//   m = number of rows in A
//   n = number of columns in A
//   C = output matrix = the transpose of A (n x m)
static void matrix_transpose(float* A, int m, int n, float* C)
{
    int src_row, src_col;

    for (src_row = 0; src_row < m; ++src_row)
    {
        for (src_col = 0; src_col < n; ++src_col)
        {
            // Element (src_row, src_col) of A becomes element
            // (src_col, src_row) of C, whose rows are m elements long.
            C[m*src_col + src_row] = A[n*src_row + src_col];
        }
    }
}


// Matrix Inversion Routine
//
// Inverts a square matrix with the Gauss Jordan method, using row
// interchanges so that each pivot is the largest-magnitude candidate in its
// column.
//
//   A        = input matrix (n x n), row-major; it is NOT modified
//   n        = dimension of A
//   AInverse = output, receives the inverted matrix (n x n), row-major;
//              must not overlap A, because it is initialised to the identity
//              while A is still being read
//
// Returns 1 on success.  Returns 0 on failure: n is not positive, the
// scratch buffer could not be allocated, or a zero pivot was met (A is
// singular).  On failure the contents of AInverse are not a valid inverse.
static int matrix_inversion(float* A, int n, float* AInverse)
{
    int i, j, iPass, imx, icol, irow;
    float temp, pivot, factor;
    float* ac;

    if (n <= 0) return 0;

    // All elimination is carried out on this scratch copy so that the
    // caller's matrix survives the call.
    ac = (float*)calloc((size_t)n * (size_t)n, sizeof(float));
    if (ac == NULL) return 0;

    // Start with AInverse = identity and ac = A.
    for (i = 0; i < n; i++)
    {
        for (j = 0; j < n; j++)
        {
            AInverse[n*i+j] = 0;
            ac[n*i+j] = A[n*i+j];
        }
        AInverse[n*i+i] = 1;
    }

    // The current pivot row is iPass.
    // For each pass, first find the maximum element in the pivot column.
    for (iPass = 0; iPass < n; iPass++)
    {
        imx = iPass;
        for (irow = iPass; irow < n; irow++)
        {
            if (fabs(ac[n*irow+iPass]) > fabs(ac[n*imx+iPass])) imx = irow;
        }

        // Interchange the elements of row iPass and row imx in both the
        // working copy and AInverse.  Columns left of iPass in the working
        // copy were already eliminated, so they need not be swapped.
        if (imx != iPass)
        {
            for (icol = 0; icol < n; icol++)
            {
                temp = AInverse[n*iPass+icol];
                AInverse[n*iPass+icol] = AInverse[n*imx+icol];
                AInverse[n*imx+icol] = temp;
                if (icol >= iPass)
                {
                    temp = ac[n*iPass+icol];
                    ac[n*iPass+icol] = ac[n*imx+icol];
                    ac[n*imx+icol] = temp;
                }
            }
        }

        // The current pivot is now ac[iPass][iPass].  A zero pivot after
        // pivoting means the whole remaining column is zero: A is singular.
        // The pivot itself is tested (not a running product of pivots) so
        // that tiny-but-valid pivots cannot underflow into a false failure.
        pivot = ac[n*iPass+iPass];
        if (pivot == 0)
        {
            free(ac);
            return 0;
        }

        for (icol = 0; icol < n; icol++)
        {
            // Normalize the pivot row by dividing by the pivot element.
            AInverse[n*iPass+icol] = AInverse[n*iPass+icol] / pivot;
            if (icol >= iPass) ac[n*iPass+icol] = ac[n*iPass+icol] / pivot;
        }

        for (irow = 0; irow < n; irow++)
        {
            // Subtract a multiple of the pivot row from every other row.
            // The multiple is chosen so that the element on the pivot
            // column becomes 0.
            if (irow == iPass) continue;
            factor = ac[n*irow+iPass];
            for (icol = 0; icol < n; icol++)
            {
                AInverse[n*irow+icol] -= factor * AInverse[n*iPass+icol];
                ac[n*irow+icol] -= factor * ac[n*iPass+icol];
            }
        }
    }

    free(ac);
    return 1;
}

/*
static void matrix_print(float* A, int m, int n)
// Matrix print.
{
    // A = input matrix (m x n)
    // m = number of rows in A
    // n = number of columns in A
    int i, j;
    for (i=0;i<m;i++)
    {
        printf("| ");
        for(j=0;j<n;j++)
        {
            printf("%7.3f ", A[n*i+j]);
        }
        printf("|\n");
    }
}
*/

// Dimensions of the filter model used by kalman():
//   n = number of states  (xhat is an n by 1 vector)
//   m = number of inputs  (u is an m by 1 vector)
//   r = number of outputs (y is an r by 1 vector)
// NOTE(review): these are single-letter object-like macros, so from this
// point to the end of the translation unit every identifier token n, m or r
// is replaced by the number below.  Code placed after them cannot use those
// names for variables or parameters.
#define n 2
#define m 1
#define r 1

// Runs one step of a Kalman filter driven by a gyroscope rate and corrected
// by an accelerometer angle.
//
//   gyroscope_rate      = value loaded into the input vector u
//   accelerometer_angle = value loaded into the measurement vector y
//
// Returns the first element of the updated state estimate, xhat[0][0].
//
// The state estimate xhat and the error covariance P live in function-local
// statics, so every call advances the same filter instance; the function is
// therefore not reentrant.
float kalman(float gyroscope_rate, float accelerometer_angle)
{
    // A is an n by n matrix
    // B is an n by m matrix
    // C is an r by n matrix
    // Sz is an r by r matrix
    // Sw is an n by n matrix
    // xhat is an n by 1 vector
    // P is an n by n matrix
    // y is an r by 1 vector
    // u is an m by 1 vector

    // Constants.
    // NOTE(review): 0.019968 looks like the sample period in seconds - confirm.
    static float A[n][n] = {{1.0, -0.019968}, {0.0, 1.0}};
    static float B[n][m] = {{0.019968}, {0.0}};
    static float C[r][n] = {{1.0, 0.0}};
    static float Sz[r][r] = {{17.2}};
    static float Sw[n][n] = {{0.005, 0.005}, {0.005, 0.005}};

    // Persistent states.
    static float xhat[n][1] = {{0.0}, {0.0}};
    static float P[n][n] = {{0.005, 0.005}, {0.005, 0.005}};

    // Inputs.
    float u[m][1];              // Gyroscope rate.
    float y[r][1];              // Accelerometer angle.

    // Temp values.
    float AP[n][n];             // This is the matrix A*P
    float CT[n][r];             // This is the matrix C'
    float APCT[n][r];           // This is the matrix A*P*C'
    float CP[r][n];             // This is the matrix C*P
    float CPCT[r][r];           // This is the matrix C*P*C'
    float CPCTSz[r][r];         // This is the matrix C*P*C'+Sz
    float CPCTSzInv[r][r];      // This is the matrix inv(C*P*C'+Sz)
    float K[n][r];              // This is the Kalman gain.
    float Cxhat[r][1];          // This is the vector C*xhat
    float yCxhat[r][1];         // This is the vector y-C*xhat
    float KyCxhat[n][1];        // This is the vector K*(y-C*xhat)
    float Axhat[n][1];          // This is the vector A*xhat
    float Bu[n][1];             // This is the vector B*u
    float AxhatBu[n][1];        // This is the vector A*xhat+B*u
    float AT[n][n];             // This is the matrix A'
    float APAT[n][n];           // This is the matrix A*P*A'
    float APATSw[n][n];         // This is the matrix A*P*A'+Sw
    float KC[n][n];             // This is the matrix K*C
    float KCP[n][n];            // This is the matrix K*C*P
    float KCPAT[n][n];          // This is the matrix K*C*P*A'

    int i, j;                   // Indices used only to clear the gain.

    // Fill in inputs.
    u[0][0] = gyroscope_rate;
    y[0][0] = accelerometer_angle;

#if 0
    // Print various matrices.
    printf("u =\n");
    matrix_print((float*) u, m, 1);
    printf("y =\n");
    matrix_print((float*) y, r, 1);
    printf("A =\n");
    matrix_print((float*) A, n, n);
    printf("B =\n");
    matrix_print((float*) B, n, m);
    printf("State =\n");
    matrix_print((float*) xhat, n, 1);
#endif

    // Update the state estimate by extrapolating estimate with gyroscope input.
    // xhat_est = A * xhat + B * u
    // B is n by m and u is m by 1, so the shared inner dimension is m.
    matrix_multiply((float*) A, (float*) xhat, n, n, 1, (float*) Axhat);
    matrix_multiply((float*) B, (float*) u, n, m, 1, (float*) Bu);
    matrix_addition((float*) Axhat, (float*) Bu, n, 1, (float*) AxhatBu);

#if 0
    printf("State Estimate =\n");
    matrix_print((float*) AxhatBu, n, 1);
#endif

    // Compute the innovation.
    // Inn = y - C * xhat;
    matrix_multiply((float*) C, (float*) xhat, r, n, 1, (float*) Cxhat);
    matrix_subtraction((float*) y, (float*) Cxhat, r, 1, (float*) yCxhat);

#if 0
    printf("Innovation =\n");
    matrix_print((float*) yCxhat, r, 1);
#endif

    // Compute the covariance of the innovation.
    // s = C * P * C' + Sz
    matrix_transpose((float*) C, r, n, (float*) CT);
    matrix_multiply((float*) C, (float*) P, r, n, n, (float*) CP);
    matrix_multiply((float*) CP, (float*) CT, r, n, r, (float*) CPCT);
    matrix_addition((float*) CPCT, (float*) Sz, r, r, (float*) CPCTSz);

    // Compute the kalman gain matrix.
    // K = A * P * C' * inv(s)
    matrix_multiply((float*) A, (float*) P, n, n, n, (float*) AP);
    matrix_multiply((float*) AP, (float*) CT, n, n, r, (float*) APCT);
    if (matrix_inversion((float*) CPCTSz, r, (float*) CPCTSzInv))
    {
        matrix_multiply((float*) APCT, (float*) CPCTSzInv, n, r, r, (float*) K);
    }
    else
    {
        // s could not be inverted, so no meaningful gain exists.  Use a zero
        // gain: the measurement is ignored for this step and the updates
        // below reduce to xhat = xhat_est and P = A * P * A' + Sw.
        for (i = 0; i < n; i++)
            for (j = 0; j < r; j++)
                K[i][j] = 0;
    }

    // Update the state estimate.
    // xhat = xhat_est + K * Inn;
    matrix_multiply((float*) K, (float*) yCxhat, n, r, 1, (float*) KyCxhat);
    matrix_addition((float*) AxhatBu, (float*) KyCxhat, n, 1, (float*) xhat);

    // Compute the new covariance of the estimation error.
    // P = A * P * A' - K * C * P * A' + Sw
    // AP and KCP are both formed from the old P before P is overwritten by
    // the final subtraction.
    matrix_transpose((float*) A, n, n, (float*) AT);
    matrix_multiply((float*) AP, (float*) AT, n, n, n, (float*) APAT);
    matrix_addition((float*) APAT, (float*) Sw, n, n, (float*) APATSw);
    matrix_multiply((float*) K, (float*) C, n, r, n, (float*) KC);
    matrix_multiply((float*) KC, (float*) P, n, n, n, (float*) KCP);
    matrix_multiply((float*) KCP, (float*) AT, n, n, n, (float*) KCPAT);
    matrix_subtraction((float*) APATSw, (float*) KCPAT, n, n, (float*) P);

    // Return the estimate.
    return xhat[0][0];
}
