"Blockwise"-Matrix multiply each 2d slice in 3d array

Alec Jacobson

January 05, 2015


I'm sure I've done this before and pretty sure I've also posted it here before, but I couldn't find it.

Suppose you have two 3D arrays:

A: an M by S by K array, formed for example by A = cat(3,A1,A2,A3, ... ,AK)
B: an S by N by K array, formed for example by B = cat(3,B1,B2,B3, ... ,BK)

and you'd like to compute a new 3D array

C: an M by N by K array, as if formed by C = cat(3, A1 * B1, A2 * B2, ..., AK * BK)

where X*Y is the usual 2D matrix multiply. In my case, K is much larger than M, S and N (K >> M,S,N).
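Written out entry by entry (using l for the column index of each slice, a naming choice for this note only), the target is:

```latex
C(i,l,k) = \sum_{j=1}^{S} A(i,j,k)\, B(j,l,k),
\qquad 1 \le i \le M,\; 1 \le l \le N,\; 1 \le k \le K
```

The per-slice loop hands the sum over j to the built-in matrix multiply once per k, whereas the later versions loop over i, j, l (or just j) and vectorize over the long k dimension instead.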

In MATLAB, a first attempt might be to write a single for loop over the last dimension, of size K:

C = zeros(M,N,K);
for k = 1:K
  % one small 2D matrix multiply per slice
  C(:,:,k) = A(:,:,k) * B(:,:,k);
end

MATLAB's notorious for loops have gotten better in the last couple of years, but with K iterations of a tiny multiply this is still slow.

Better is to unroll the 2D matrix multiply into loops over M, S and N, and vectorize over the long K dimension using elementwise vector-vector multiplication:

C = zeros(M,N,K);
for i = 1:M
  for j = 1:S
    for l = 1:N
      % accumulate A(i,j,k)*B(j,l,k) for all K slices at once
      C(i,l,:) = C(i,l,:) + A(i,j,:).*B(j,l,:);
    end
  end
end
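A quick sanity check that the unrolled loops reproduce the slice-by-slice product, up to floating-point round-off. The sizes and random inputs below are arbitrary choices for illustration, not from the timings in this post:

```matlab
% Small, arbitrary sizes (assumption: any M, S, N, K would do)
M = 2; S = 4; N = 3; K = 5;
A = rand(M,S,K);
B = rand(S,N,K);

% Reference: one 2D matrix multiply per slice
C_ref = zeros(M,N,K);
for k = 1:K
  C_ref(:,:,k) = A(:,:,k) * B(:,:,k);
end

% Unrolled: loop over M, S, N and vectorize over the K dimension
C_unrolled = zeros(M,N,K);
for i = 1:M
  for j = 1:S
    for l = 1:N
      C_unrolled(i,l,:) = C_unrolled(i,l,:) + A(i,j,:).*B(j,l,:);
    end
  end
end

% Largest absolute difference between the two results
max_err = max(abs(C_ref(:) - C_unrolled(:)));
```

The two versions sum the same S products in the same order per entry, so any difference should be at the level of round-off.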

Slightly better still is to use bsxfun to compute vectorized outer products in a single for loop:

C = zeros(M,N,K);
for j = 1:size(A,2)
  % M-by-1-by-K times 1-by-N-by-K expands to an M-by-N-by-K outer product per slice
  C = C + bsxfun(@times,A(:,j,:),B(j,:,:));
end

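This solution works especially well if S is small compared to the other dimensions, since the loop runs only S times and each iteration is a single vectorized M by N by K elementwise product.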
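Two newer MATLAB features may simplify this further; neither is part of the timings in this post. Since R2016b, elementwise operators expand singleton dimensions implicitly, so the bsxfun call can be written with .* directly, and R2020b added pagemtimes, which multiplies matching pages of two N-D arrays. A minimal sketch, assuming arbitrary sizes and random data; the pagemtimes comparison is skipped where that function is not available:

```matlab
% Arbitrary sizes for illustration
M = 3; S = 3; N = 3; K = 1000;
A = rand(M,S,K);
B = rand(S,N,K);

% bsxfun outer-product accumulation, as above
C_bsx = zeros(M,N,K);
for j = 1:S
  C_bsx = C_bsx + bsxfun(@times, A(:,j,:), B(j,:,:));
end

% Same accumulation relying on implicit expansion (R2016b and later)
C_impl = zeros(M,N,K);
for j = 1:S
  C_impl = C_impl + A(:,j,:) .* B(j,:,:);
end
impl_err = max(abs(C_impl(:) - C_bsx(:)));

% pagemtimes (R2020b and later) multiplies A(:,:,k)*B(:,:,k) for every k
if exist('pagemtimes') > 0
  C_page = pagemtimes(A, B);
  page_err = max(abs(C_page(:) - C_bsx(:)));
else
  page_err = 0; % pagemtimes not available in this environment
end
```

Whether either is faster than the bsxfun loop for K >> M,S,N would need to be measured on the target machine.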

For M = S = N = 3 and K = 10,000,000, these take 60 s, 13 s and 9 s, respectively.
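A rough harness for rerunning this comparison with tic/toc. It assumes random inputs and a reduced K of 100,000 (not the 10,000,000 above) so it finishes quickly; absolute times will depend on the machine and MATLAB version:

```matlab
% Reduced problem size (assumption) with the post's M = S = N = 3
M = 3; S = 3; N = 3; K = 100000;
A = rand(M,S,K);
B = rand(S,N,K);

% 1) One 2D multiply per slice
tic;
C1 = zeros(M,N,K);
for k = 1:K
  C1(:,:,k) = A(:,:,k) * B(:,:,k);
end
t_slices = toc;

% 2) Unrolled loops, vectorized over K
tic;
C2 = zeros(M,N,K);
for i = 1:M
  for j = 1:S
    for l = 1:N
      C2(i,l,:) = C2(i,l,:) + A(i,j,:).*B(j,l,:);
    end
  end
end
t_unrolled = toc;

% 3) bsxfun outer products, one loop over S
tic;
C3 = zeros(M,N,K);
for j = 1:S
  C3 = C3 + bsxfun(@times, A(:,j,:), B(j,:,:));
end
t_bsxfun = toc;

fprintf('per-slice: %.3f s, unrolled: %.3f s, bsxfun: %.3f s\n', ...
  t_slices, t_unrolled, t_bsxfun);
```

Timing a single run with tic/toc is noisy; repeating each measurement a few times gives a fairer picture.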