Inversion of matrix made of diagonal blocks

Alec Jacobson

July 20, 2014


Recently I needed to invert (a.k.a. solve against) a large matrix with a special form. It was a block matrix where each block was diagonal:

 M = [M11 M12 ... M1n
      M21 ...
      ...
      Mn1 Mn2 ... Mnn];

where each Mij was a diagonal matrix. In my case Mij=Mji, though I won't rely on this for the following.

Note that this is not the same as a "block diagonal matrix", though it can be rearranged into one by applying the same permutation to its rows and columns.
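To make the rearrangement concrete, here is a minimal sketch. The sizes n and k, the random test matrix, and the permutation vector p are my own illustrative choices, not anything from the original problem. The idea is to gather entry j of every block together, so each group of k indices forms one small dense block on the diagonal.

```matlab
% Illustrative sizes: a k-by-k grid of n-by-n diagonal blocks.
n = 4; k = 3;
% Build a test matrix whose (i,j) block is a random diagonal matrix.
M = zeros(n*k);
for i = 1:k
  for j = 1:k
    M((i-1)*n+(1:n),(j-1)*n+(1:n)) = diag(rand(n,1));
  end
end
% Reorder so that indices j, n+j, 2n+j, ... sit next to each other.
p = reshape(reshape(1:n*k,n,k)',[],1);
P = M(p,p);
% P should now be block diagonal with n dense k-by-k blocks:
% everything outside those blocks is zero.
mask = kron(eye(n),ones(k));
assert(all(P(~mask)==0));
```

Each k-by-k block of P could then be inverted independently, which is another way to see why the inverse keeps the same sparsity pattern.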

The inverse of my kind of matrix will have the same sparsity pattern. I can find it using a recursive application of the blockwise matrix inversion formula.
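For reference, this is the standard blockwise inversion formula, written with the block names A, B, C, D and the Schur complement S. It assumes that A and S are both invertible. In the recursion I have in mind, A is the leading (k-1) by (k-1) grid of blocks and D is the last diagonal block.

```latex
\[
M = \begin{bmatrix} A & B \\ C & D \end{bmatrix},
\qquad
S = D - C A^{-1} B,
\]
\[
M^{-1} =
\begin{bmatrix}
A^{-1} + A^{-1} B S^{-1} C A^{-1} & -A^{-1} B S^{-1} \\
-S^{-1} C A^{-1} & S^{-1}
\end{bmatrix}.
\]
```

Sums and products of matrices made of diagonal blocks are again made of diagonal blocks, so if the inverse of A has that form then S is diagonal and every term on the right-hand side keeps the same sparsity pattern.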

Here's a little function to do exactly that:

function N = diag_blk_inv(M,k)
  % DIAG_BLK_INV Invert a block matrix made out of square diagonal blocks. M
  % should be of the form:
  %
  %   M = [M11 M12 ... M1k
  %        M21 ...
  %        ...
  %        Mk1 Mk2 ... Mkk];
  % where each Mij is an n by n diagonal matrix.
  %
  % Inputs:
  %   M  n*k by n*k matrix made out of diagonal blocks.
  %   k  number of blocks along each dimension
  % Outputs:
  %   N  n*k by n*k matrix inverse of M
  %

  switch k
  case 1
    assert(isdiag(M));
    N = diag(sparse(1./diag(M)));
    return;
  end

  assert(size(M,1)==size(M,2),'Must be square');
  assert(rem(size(M,1),k)==0,'Must be divisible by k');
  n = size(M,1)/k;
  % Extract 4 blocks
  A = M(0*n+(1:(k-1)*n),0*n+(1:(k-1)*n));
  B = M(0*n+(1:(k-1)*n),(k-1)*n+(1:n));
  C = M((k-1)*n+(1:n),0*n+(1:(k-1)*n));
  D = M((k-1)*n+(1:n),(k-1)*n+(1:n));
  assert(isdiag(D));
  % https://en.wikipedia.org/wiki/Invertible_matrix#Blockwise_inversion
  Ainv = diag_blk_inv(A,k-1);
  % Schur complement 
  S = (D-C*Ainv*B);
  assert(isdiag(S));
  Sinv = diag_blk_inv(S,1);
  N = [Ainv + Ainv*B*Sinv*C*Ainv -Ainv*B*Sinv; ...
       -Sinv*C*Ainv              Sinv];
end
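A quick sanity check might look like the following sketch. It assumes the function above is saved as diag_blk_inv.m somewhere on the path. The sizes and the way the test matrix is built are illustrative assumptions: entry j of every block is filled from a small symmetric positive definite k-by-k matrix, which keeps A and every Schur complement invertible.

```matlab
% Illustrative sizes: k-by-k grid of n-by-n diagonal blocks.
n = 5; k = 3;
M = sparse(n*k,n*k);
for j = 1:n
  G = rand(k);
  K = G'*G + k*eye(k);    % small, well-conditioned SPD matrix
  idx = j + (0:k-1)*n;    % entry j of every block
  M(idx,idx) = K;
end
% Assumes diag_blk_inv.m (the function above) is on the path.
N = diag_blk_inv(M,k);
% N*M should be the identity up to rounding error.
err = norm(full(N*M) - eye(n*k),'fro');
assert(err < 1e-8);
```

Running spy(N) next to spy(M) is a convenient way to eyeball that the two share a sparsity pattern.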

Update: No doubt someone will see that I'm computing the inverse explicitly and call me a heretic. Actually, in the case of diagonal matrices there's some room to argue for doing this. In terms of efficiency, this is orders of magnitude faster than calling inv(M)*B or M\B. However, that seems to be because matlab's \ doesn't figure out that it could use a Cholesky decomposition. If I call chol myself (which requires M to be symmetric positive definite), then L = chol(M);L\(L'\B) is roughly the same speed as diag_blk_inv(M,k)*B: both are fast.
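Here is a minimal sketch of the chol route on a made-up symmetric positive definite matrix of this form. The sizes, the test matrix, and the right-hand side B are assumptions for illustration only. Note that chol returns an upper triangular factor R with R'*R equal to M, which is why the transposed solve comes first.

```matlab
% Illustrative sizes: k-by-k grid of n-by-n diagonal blocks.
n = 50; k = 4;
M = sparse(n*k,n*k);
for j = 1:n
  G = rand(k);
  K = G'*G + k*eye(k);    % small SPD matrix
  K = (K + K')/2;         % enforce exact symmetry
  idx = j + (0:k-1)*n;    % entry j of every block
  M(idx,idx) = K;
end
B = rand(n*k,3);          % a few right-hand sides
% Reference solution with backslash.
X_backslash = M\B;
% Cholesky route: R is upper triangular with R'*R = M.
R = chol(M);
X_chol = R\(R'\B);
% The two solutions should agree up to rounding error.
rel_diff = norm(X_chol - X_backslash,'fro')/norm(X_backslash,'fro');
assert(rel_diff < 1e-10);
```

Wrapping each solve in tic and toc on a realistically sized M is the easiest way to compare speeds on a given machine.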

Update: To be my own policeman, it probably is safer to use chol in this case. Otherwise, in general, I would need to check the scaling of M. It's unfortunate, though, because a priori I don't have any reason to believe chol will see the special structure of this matrix.