definition
torch.matmul is a function in the PyTorch library that performs matrix multiplication between two tensors. It accepts tensors of different shapes and numbers of dimensions, as long as their shapes are compatible for matrix multiplication. It can be called in two ways:
torch.matmul(tensor1, tensor2): This performs matrix multiplication between tensor1 and tensor2. If both tensors are 2-D, the result is an ordinary matrix product. If both are 1-D, the result is their dot product. If either tensor has more than two dimensions, the function performs batched matrix multiplication. The last two dimensions of each tensor hold the matrices, and all leading dimensions are treated as batch dimensions, which are broadcast against each other when they differ.
tensor1.matmul(tensor2): This is the equivalent tensor method. It multiplies the tensor it is called on by another tensor and has the same functionality as torch.matmul(tensor1, tensor2).
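Below is a minimal sketch of how torch.matmul treats inputs of different dimensionality, following the behavior described in the PyTorch documentation for torch.matmul. The tensors, their shapes, and the variable names (a, b, A, B, v, X, Y, and so on) are illustrative only. The last case shows batch dimensions being broadcast: batch shapes (7, 1) and (6,) combine to (7, 6).

```python
import torch

# 1-D @ 1-D: dot product, returns a 0-dim (scalar) tensor
a = torch.randn(3)
b = torch.randn(3)
dot = torch.matmul(a, b)
print(dot.shape)  # torch.Size([])

# 2-D @ 2-D: ordinary matrix product, (3, 4) @ (4, 5) -> (3, 5)
A = torch.randn(3, 4)
B = torch.randn(4, 5)
AB = torch.matmul(A, B)
print(AB.shape)  # torch.Size([3, 5])

# 1-D @ 2-D: the vector acts as a (1, 4) row; the added dim is removed afterwards
v = torch.randn(4)
vB = torch.matmul(v, B)
print(vB.shape)  # torch.Size([5])

# Batched: the last two dims are the matrices, the leading dims are batch dims.
# Batch dims (7, 1) and (6,) broadcast to (7, 6).
X = torch.randn(7, 1, 3, 4)
Y = torch.randn(6, 4, 5)
XY = torch.matmul(X, Y)
print(XY.shape)  # torch.Size([7, 6, 3, 5])

# The tensor method gives the same result as the function form
print(torch.allclose(A.matmul(B), AB))  # True
```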
torch.matmul is widely used in deep learning and other numerical applications that involve linear algebra. In neural networks it appears throughout the forward pass, for example in linear (fully connected) layers, which multiply a batch of inputs by a weight matrix.
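As one illustration of matrix multiplication in a forward pass, the sketch below reproduces a torch.nn.Linear layer by hand. It relies on the nn.Linear convention that weight has shape (out_features, in_features), so the layer computes the input times the transposed weight, plus the bias. The layer sizes and the random seed are arbitrary choices for this example.

```python
import torch

torch.manual_seed(0)

# Linear layer mapping 4 input features to 2 output features (sizes are arbitrary)
layer = torch.nn.Linear(in_features=4, out_features=2)

# A batch of 3 input vectors, one per row
x = torch.randn(3, 4)

with torch.no_grad():
    # weight has shape (2, 4), so transpose it to (4, 2): (3, 4) @ (4, 2) -> (3, 2)
    manual = torch.matmul(x, layer.weight.T) + layer.bias
    expected = layer(x)

print(manual.shape)  # torch.Size([3, 2])
# Small tolerance because the layer may use a fused kernel with different rounding
print(torch.allclose(manual, expected, atol=1e-5))  # True
```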
example
Here's an example of using torch.matmul to perform batch matrix multiplication on two tensors of shapes (batch_size, n, m) and (batch_size, m, p), where batch_size is the number of matrices in the batch and n, m, and p are the dimensions of the matrices:
import torch
# Create two random tensors of shape (batch_size, n, m) and (batch_size, m, p)
batch_size = 2
n = 3
m = 4
p = 5
tensor1 = torch.randn(batch_size, n, m)
tensor2 = torch.randn(batch_size, m, p)
# Perform batch matrix multiplication using torch.matmul
result = torch.matmul(tensor1, tensor2)
# Check the shape of the result tensor
print(result.shape) # Output: torch.Size([2, 3, 5])
In this example, tensor1 and tensor2 are random tensors with shapes (2, 3, 4) and (2, 4, 5), each holding a batch of two matrices of size (3, 4) and (4, 5), respectively. torch.matmul multiplies the corresponding matrices in the two batches, producing a tensor of shape (2, 3, 5): a batch of two result matrices of size (3, 5).
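To make "batch matrix multiplication" concrete, the sketch below recomputes the example one matrix pair at a time and compares the results. It also compares against torch.bmm, a separate PyTorch function that computes batched products for 3-D tensors with matching batch sizes (unlike torch.matmul, it does not broadcast). The seed and the tolerance are assumptions chosen only for this check.

```python
import torch

torch.manual_seed(0)
batch_size, n, m, p = 2, 3, 4, 5
tensor1 = torch.randn(batch_size, n, m)
tensor2 = torch.randn(batch_size, m, p)

result = torch.matmul(tensor1, tensor2)

# Multiply each pair of matrices separately, then stack the products along dim 0
looped = torch.stack(
    [torch.matmul(tensor1[i], tensor2[i]) for i in range(batch_size)]
)

print(result.shape)  # torch.Size([2, 3, 5])
print(torch.allclose(result, looped, atol=1e-5))  # True

# torch.bmm computes the same batched product for 3-D inputs
print(torch.allclose(result, torch.bmm(tensor1, tensor2), atol=1e-5))  # True
```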