主页

索引

模块索引

搜索页面

基本-矩阵乘法

规则

  • 对于两个张量 𝐴 和 𝐵,如果 𝐴 的最后一个维度大小与 𝐵 的倒数第二个维度大小相同,则可以进行矩阵乘法。

  • 计算后的张量形状为:剩余维度 + 𝐴 的所有前维度 + 𝐵 的最后一维。

二维向量

  • 若矩阵 𝐴 的形状为 (𝑚,𝑛),矩阵 𝐵 的形状为 (𝑛,𝑝),则它们的乘积 𝐶 是一个形状为 (𝑚,𝑝) 的矩阵。

高维向量

  • 在深度学习中,矩阵乘法被扩展为支持高维张量的批次操作:
    • 批次维度广播 :批次维度会自动匹配和广播。

    • 最后两维的矩阵相乘 :仅对张量的最后两维执行矩阵乘法,其他维度被广播保留。

批次维度广播

  • 右对齐比较:先从右到左逐个维度对齐两个张量的形状。

  • 兼容性条件:两个维度相等,或其中一个为 1,可以广播。否则,无法广播,操作会报错。

示例 1:批次维度完全匹配:

import torch

A = torch.randn(2, 3, 4)    # 形状: [2, 3, 4]
B = torch.randn(2, 4, 5)    # 形状: [2, 4, 5]

C = torch.matmul(A, B)
print(C.shape)  # 输出: [2, 3, 5]

示例 2:批次维度广播:

A = torch.randn(1, 3, 4)    # 形状: [1, 3, 4]
B = torch.randn(2, 4, 5)    # 形状: [2, 4, 5]

C = torch.matmul(A, B)
print(C.shape)  # 输出: [2, 3, 5]

示例 3:批次维度无法广播:

A = torch.randn(3, 3, 4)    # 形状: [3, 3, 4]
B = torch.randn(2, 4, 5)    # 形状: [2, 4, 5]

C = torch.matmul(A, B)  # 报错:批次维度 [3] 和 [2] 不兼容

计算公式

\[C[i,j] = \sum_{k=1}^n{A[i,k] \cdot B[k,j]}\]

主页

索引

模块索引

搜索页面