"Blockwise"-Matrix multiply each 2d slice in 3d array

Alec Jacobson

January 05, 2015


I'm sure I've done this before and pretty sure I've also posted it here before, but I couldn't find it.

Suppose you have two 3D arrays:

A an M by S by K array formed for example by A = cat(3,A1,A2,A3, ... ,AK)
B an S by N by K array formed for example by B = cat(3,B1,B2,B3, ... ,BK)

and you'd like to compute a new 3d array C such that

C an M by N by K array as if formed by C = cat(3, A1 * B1, A2 * B2, ..., AK * BK)

where X*Y is the usual 2d matrix multiply. In my case, K >> M,S,N.
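To make the setup concrete, here is a minimal sketch with K = 3 that builds A, B, and the target C directly from the definition above. The sizes and the names A1, B1, ..., C_ref are only illustrative assumptions.

```matlab
% Hypothetical small sizes, chosen only for illustration.
M = 2; S = 3; N = 4;

% Three random slices for each input.
A1 = rand(M,S); A2 = rand(M,S); A3 = rand(M,S);
B1 = rand(S,N); B2 = rand(S,N); B3 = rand(S,N);

% Stack the slices along the third dimension.
A = cat(3,A1,A2,A3);               % M by S by 3
B = cat(3,B1,B2,B3);               % S by N by 3

% The desired result, built slice by slice from the definition.
C_ref = cat(3,A1*B1,A2*B2,A3*B3);  % M by N by 3
```

Any faster method should reproduce C_ref up to round-off.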

In matlab, a first attempt might be to write a single for loop over the last dimension of size K:

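[M,S,K] = size(A); N = size(B,2);  % dimensions as defined above
C = zeros(M,N,K);
for k = 1:K
  C(:,:,k) = A(:,:,k) * B(:,:,k);  % ordinary 2D multiply of slice k
end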

Matlab's notorious for loops have gotten better in the last couple years, but this is still slow.

Better is to unroll the 2D matrix multiply into explicit loops over M, S and N, and take advantage of vectorized, elementwise multiplication along the third dimension:

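[M,S,K] = size(A); N = size(B,2);  % dimensions as defined above
C = zeros(M,N,K);
for i = 1:M
  for j = 1:S
    for k = 1:N
      % all K slices at once: accumulate A(i,j)*B(j,k) into C(i,k)
      C(i,k,:) = C(i,k,:) + A(i,j,:).*B(j,k,:);
    end
  end
end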

Slightly better still is to use bsxfun to compute vectorized outer products in a single for loop:

C = zeros(size(A,1),size(B,2),size(A,3));
for j = 1:size(A,2)
  % column j of every A slice times row j of every B slice
  C = C + bsxfun(@times,A(:,j,:),B(j,:,:));
end

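This solution works especially well if S is small compared to the other dimensions, since the loop runs only S times and each iteration is a single vectorized operation over an M by N by K array.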
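A quick way to convince yourself that the bsxfun loop computes the same thing as the slice-by-slice multiply is to compare the two on small random inputs. This is only a sanity-check sketch; the sizes and the names C_ref and err are assumptions made for the example.

```matlab
% Small hypothetical sizes so the check runs instantly.
M = 3; S = 2; N = 4; K = 6;
A = rand(M,S,K);
B = rand(S,N,K);

% Reference: one ordinary matrix multiply per slice.
C_ref = zeros(M,N,K);
for k = 1:K
  C_ref(:,:,k) = A(:,:,k) * B(:,:,k);
end

% bsxfun version: accumulate S slice-wise outer products.
C = zeros(M,N,K);
for j = 1:S
  % A(:,j,:) is M by 1 by K and B(j,:,:) is 1 by N by K;
  % bsxfun expands both to M by N by K before multiplying.
  C = C + bsxfun(@times,A(:,j,:),B(j,:,:));
end

% Largest absolute difference; should be at round-off level.
err = max(abs(C(:) - C_ref(:)));
```

The two results differ only by floating-point round-off, because the sums over S are accumulated in a different order.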

For M=S=N=3 and K=10000000, these three approaches take 60 secs, 13 secs, and 9 secs, respectively.
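If you want to reproduce the comparison on your own machine, a rough timing sketch with tic/toc might look like the following. K is deliberately smaller here than in the numbers above (an assumption, so the script finishes quickly), and the variable names t1, t2, t3 and C1, C2, C3 are made up for the example. Timings will vary with hardware and MATLAB version.

```matlab
% Smaller K than in the post so the script finishes quickly.
M = 3; S = 3; N = 3; K = 100000;
A = rand(M,S,K);
B = rand(S,N,K);

% 1) Loop over slices, one 2D multiply each.
tic;
C1 = zeros(M,N,K);
for k = 1:K
  C1(:,:,k) = A(:,:,k) * B(:,:,k);
end
t1 = toc;

% 2) Triple loop, vectorized along the third dimension.
tic;
C2 = zeros(M,N,K);
for i = 1:M
  for j = 1:S
    for k = 1:N
      C2(i,k,:) = C2(i,k,:) + A(i,j,:).*B(j,k,:);
    end
  end
end
t2 = toc;

% 3) Single loop over S with bsxfun.
tic;
C3 = zeros(M,N,K);
for j = 1:S
  C3 = C3 + bsxfun(@times,A(:,j,:),B(j,:,:));
end
t3 = toc;

fprintf('slice loop: %g s, triple loop: %g s, bsxfun: %g s\n',t1,t2,t3);
```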