definition

torch.matmul is a PyTorch function that computes the matrix product of two tensors. Its behavior depends on the dimensionality of the inputs, and it accepts tensors of different shapes as long as those shapes are compatible for matrix multiplication. The function can be called in two ways:

  1. torch.matmul(tensor1, tensor2): This computes the matrix product of tensor1 and tensor2. If both tensors are 2-D, the operation is an ordinary matrix multiplication. If both are 1-D, the result is their dot product, and if exactly one is 1-D, the result is a vector-matrix or matrix-vector product. If either tensor has more than two dimensions, the last two dimensions are treated as the matrices and all leading dimensions are treated as batch dimensions, which are broadcast against each other.
  2. tensor1.matmul(tensor2): This is the equivalent tensor method. It returns the same result as torch.matmul(tensor1, tensor2).
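
The dimensionality rules above can be checked with a short sketch. The variable names and shapes here are chosen purely for illustration. It also uses the @ operator, which PyTorch tensors support as another spelling of the same operation.

```python
import torch

# Illustrative shapes only; any compatible sizes would do.
vec = torch.randn(4)          # 1-D tensor
mat = torch.randn(3, 4)       # 2-D tensor
batch = torch.randn(2, 3, 4)  # 3-D tensor: a batch of two 3x4 matrices

dot = torch.matmul(vec, vec)            # 1-D with 1-D: dot product, a 0-D tensor
mat_vec = torch.matmul(mat, vec)        # 2-D with 1-D: matrix-vector product, shape (3,)
mat_mat = torch.matmul(mat, mat.T)      # 2-D with 2-D: matrix product, shape (3, 3)
broadcast = torch.matmul(batch, mat.T)  # 3-D with 2-D: mat.T is broadcast over the batch, shape (2, 3, 3)

# The method form and the @ operator compute the same product.
same_method = batch.matmul(mat.T)
same_operator = batch @ mat.T

for name, t in [("dot", dot), ("mat_vec", mat_vec), ("mat_mat", mat_mat), ("broadcast", broadcast)]:
    print(name, tuple(t.shape))
```

Note that when one operand has fewer batch dimensions than the other, it is reused for every matrix in the batch rather than copied by hand.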

torch.matmul is widely used in deep learning and other numerical applications that involve linear algebra. In neural network architectures it often appears in the forward pass of the model, for example when a fully connected layer multiplies its input by a weight matrix.
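
As a rough sketch of the forward-pass use case, the snippet below computes the output of a single fully connected layer by hand and compares it with torch.nn.functional.linear. The layer sizes and variable names are assumptions made for illustration, not taken from any particular model.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical layer sizes, chosen only for this example.
batch_size, in_features, out_features = 8, 16, 4

x = torch.randn(batch_size, in_features)         # a batch of input vectors
weight = torch.randn(out_features, in_features)  # weight stored as (out, in), as nn.Linear does
bias = torch.randn(out_features)

# Forward pass of a linear layer: y = x W^T + b
output = torch.matmul(x, weight.T) + bias

# Reference result from PyTorch's built-in linear function.
reference = F.linear(x, weight, bias)

print(output.shape)
print(torch.allclose(output, reference, atol=1e-4))
```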

example

Here's an example of using torch.matmul to perform batch matrix multiplication on two tensors of shape (batch_size, n, m) and (batch_size, m, p), where batch_size is the number of matrices in each batch and n, m, and p are the matrix dimensions:

import torch

# Create two random tensors of shape (batch_size, n, m) and (batch_size, m, p)
batch_size = 2
n = 3
m = 4
p = 5
tensor1 = torch.randn(batch_size, n, m)
tensor2 = torch.randn(batch_size, m, p)

# Perform batch matrix multiplication using torch.matmul
result = torch.matmul(tensor1, tensor2)

# Check the shape of the result tensor
print(result.shape)  # Output: torch.Size([2, 3, 5])

In this example, tensor1 and tensor2 are random tensors with shapes (2, 3, 4) and (2, 4, 5) respectively. Each is a batch of two matrices, of size (3, 4) and (4, 5) respectively. torch.matmul multiplies the matrices pairwise along the batch dimension, so the result has shape (2, 3, 5): a batch of two (3, 5) matrices, where result[i] is the product of tensor1[i] and tensor2[i].
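
One way to convince yourself of this is to compare the batched result with an explicit loop over the batch. The following is a minimal sketch that reuses the shapes from the example above. The names per_matrix and matches are introduced here for illustration.

```python
import torch

tensor1 = torch.randn(2, 3, 4)
tensor2 = torch.randn(2, 4, 5)

# Batched product in a single call.
result = torch.matmul(tensor1, tensor2)

# The same computation done one matrix pair at a time, then stacked back into a batch.
per_matrix = torch.stack([torch.matmul(a, b) for a, b in zip(tensor1, tensor2)])

# Compare with a small tolerance, since floating-point results may differ slightly.
matches = torch.allclose(result, per_matrix, atol=1e-5)
print(result.shape, per_matrix.shape, matches)
```

The single batched call is generally preferable to the loop, since it lets PyTorch dispatch the whole batch at once.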