• Einstein summation

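Some background: `np.einsum` labels every axis of every operand with a subscript. Labels listed after `->` are kept (in that order) in the output; labels that appear in the inputs but not in the output are summed over.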

numpy

import numpy as np

# Transpose: the input subscripts "ij" are written back out as "ji", which
# reorders the axes. Every label appears in the output, so nothing is summed.
a = np.arange(6).reshape((2, 3))
#array([[0, 1, 2],
#       [3, 4, 5]])
a.T
#array([[0, 3],
#       [1, 4],
#       [2, 5]])
np.einsum('ij->ji', a)
#array([[0, 3],
#       [1, 4],
#       [2, 5]])
  • Take diagonal elements
# A 5x5 matrix whose diagonal entries are a[k, k] for k = 0..4.
a = np.arange(5 * 5).reshape((5, 5))
#array([[ 0,  1,  2,  3,  4],
#       [ 5,  6,  7,  8,  9],
#       [10, 11, 12, 13, 14],
#       [15, 16, 17, 18, 19],
#       [20, 21, 22, 23, 24]])
# Three equivalent spellings of the diagonal. In einsum, repeating a subscript
# on one operand ("ii") walks the diagonal, and keeping it in the output
# ("->i") stops it from being summed away.
np.diag(a)
np.einsum('ii->i', a)
np.einsum(a, [0, 0], [0])  # sublist form: integer axis labels instead of letters
# array([ 0,  6, 12, 18, 24])
  • Sum along each row (row sums, `axis=1`)
# Row sums: "i" survives into the output and "j" does not, so each row is
# summed across its columns. For the 5x5 `a` built in the diagonal example
# this gives array([ 10,  35,  60,  85, 110]).
a.sum(1)
np.einsum('ij->i', a)
np.einsum(a, [0, 1], [0])  # sublist form of 'ij->i'
  • Sum down each column (column sums, `axis=0`)
# Column sums: "j" survives into the output and "i" does not, so each column
# is summed down its rows. For the 5x5 `a` built in the diagonal example this
# gives array([50, 55, 60, 65, 70]).
a.sum(axis=0)
np.einsum('ij->j', a)
# Sublist form of 'ij->j': label the axes in order [0, 1] (as in the row-sum
# example) and keep only label 1. The spelling np.einsum(a, [1, 0], [0])
# computes the same thing, but its reversed labels obscure the correspondence
# with the string form.
np.einsum(a, [0, 1], [1])
  • Matrix product
# Matrix product C[i, k] = sum_j A[i, j] * B[j, k]: "j" is shared by both
# operands and absent from the output, so it is contracted.
a = np.arange(6).reshape((2, 3))
b = np.arange(12).reshape((3, 4))
np.dot(a, b)
np.einsum('ij,jk->ik', a, b)
# array([[20, 23, 26, 29],
#        [56, 68, 80, 92]])


# Product with a transposed left operand (A.T @ B). Instead of transposing
# first, einsum can simply swap that operand's subscripts ("ji" for "ij").
a = a.transpose()  # a is now 3x2
np.dot(a.T, b)
np.einsum('ji,jk->ik', a, b)
# same result as above, because a.T is the original 2x3 matrix

pytorch