基本-矩阵乘法¶
规则¶
对于两个张量 𝐴 和 𝐵,如果 𝐴 的最后一个维度大小与 𝐵 的倒数第二个维度大小相同,则可以进行矩阵乘法。
计算后的张量形状为:剩余维度 + 𝐴 的所有前维度 + 𝐵 的最后一维。
二维向量¶
若矩阵 𝐴 的形状为 (𝑚,𝑛),矩阵 𝐵 的形状为 (𝑛,𝑝),则它们的乘积 𝐶 是一个形状为 (𝑚,𝑝) 的矩阵。
高维向量¶
- 在深度学习中,矩阵乘法被扩展为支持高维张量的批次操作:
批次维度广播 :批次维度会自动匹配和广播。
最后两维的矩阵相乘 :仅对张量的最后两维执行矩阵乘法,其他维度被广播保留。
批次维度广播¶
右对齐比较:先从右到左逐个维度对齐两个张量的形状。
兼容性条件:两个维度相等,或其中一个为 1,可以广播。否则,无法广播,操作会报错。
示例 1:批次维度完全匹配:
import torch
A = torch.randn(2, 3, 4) # 形状: [2, 3, 4]
B = torch.randn(2, 4, 5) # 形状: [2, 4, 5]
C = torch.matmul(A, B)
print(C.shape) # 输出: [2, 3, 5]
示例 2:批次维度广播:
A = torch.randn(1, 3, 4) # 形状: [1, 3, 4]
B = torch.randn(2, 4, 5) # 形状: [2, 4, 5]
C = torch.matmul(A, B)
print(C.shape) # 输出: [2, 3, 5]
示例 3:批次维度无法广播:
A = torch.randn(3, 3, 4) # 形状: [3, 3, 4]
B = torch.randn(2, 4, 5) # 形状: [2, 4, 5]
C = torch.matmul(A, B) # 报错:批次维度 [3] 和 [2] 不兼容
计算公式¶
\[C[i,j] = \sum_{k=1}^n{A[i,k] \cdot B[k,j]}\]