本章主要内容不是为 pytorch 的所有方法进行详细的讲解,而是提供对 pytorch 的一些必要的、角度不一样的理解。
# 1.tensor 的数据类型
tensor 又称张量,可以认为是计算的基本单元,以浮点数的方式存放在 GPU 中,可以用来存储几乎所有东西,比如参数,梯度,激活值,优化器状态等。
float32:

用 32 位存放一个 tensor 是默认的存储格式,float32 可以简称为 FP32,又称单精度浮点数,或者全精度浮点数(在深度学习里的叫法)。
FP32 由 1 符号位,8 指数位和 23 尾数位(分数位)构成,一个 FP32 数据占 4 个字节。其能表示的范围大小由指数位决定,分数的精度由尾数位决定。
** 一个张量所占内存空间 = 其元素总数 × 单元素所占字节数。** 假设 x = torch.zeros (4,8),x 使用 FP32 格式存储,则 x 所占内存大小为 4 * 8 * 4 = 128 字节。
FP32 优点是动态范围和分数精度足够大,能适用大多数深度学习场景,但问题是占据较大的内存空间。
float16:

为了减小内存开销,很容易想到的办法是减少存储位数,于是有了 float16,又称 FP16。
与全精度对应地,FP16 也叫半精度。
符号位依然为 1 位,指数位和尾数位各自位数变为原来的一半。这相当于将原先的 FP32 缩小为一半的比例,因此内存大大减小了,训练速度也能变快很多。具体而言,一个 FP16 数据占 2 个字节。
但是,单纯地缩放比例为 FP16 造成了比较大的问题。
假设 x = torch.tensor ([1e-8], dtype=torch.float16),那么执行 assert x == 0 你会发现是能通过的,也就是说 x 是 0。这明显是不对的,那么为什么会这样?答案是下溢。1e-8 小于 FP16 能表示的最小正数,因此直接被压扁为 0。
这说明 float16 对非常小的数动态范围不够,训练里如果出现很小的梯度或激活,就可能直接被压扁成 0,造成数值不稳定、梯度爆炸、消失。相比之下 float32 指数位更多,动态范围更大,所以不容易在这个量级下溢。
bfloat16:

这是谷歌大脑于 2018 年提出的新存储格式,又称 bfp16。
它在 float16 的基础上进行改进,很大程度上避免了动态范围不够导致的溢出问题。
bfp16 使用与 FP16 相同的内存空间(2 字节),用 3 个尾数位扩展指数位,损失了分数精度,带来了与 FP32 相当的动态范围。然而,损失的这部分分数精度并不会造成很大的影响,在深度学习中,人们更加关注动态范围。
因此,如果 x = torch.tensor ([1e-8], dtype=torch.float16),那么执行 assert x != 0 是通过的。
FP8:

2022 年,FP8 被提出了,NVIDIA 在 H100 中加入了对 FP8 的支持。
FP8 有 E4M3(范围 [-448, 448])和 E5M2(范围 [-57344, 57344])两种格式。
具体细节可以参考这篇论文:FP8 Formats for Deep Learning。
如何选择?
- FP32 具有更高精度和动态范围,但需要很多内存。
- 用 FP8、bfp16 训练可能会带来不稳定性,但是加快了训练速度,减少了内存开销。
- 人们更愿意用 bfp16 而不是 fp16。
最好的策略就是混合精度训练。
# 2. 一些 tensor 碎碎念
认识 tensor:
在 PyTorch 里,tensor 本质上可以理解为:指向一段连续内存的指针 + 一组描述如何索引这段内存的元数据,元数据主要包括:
- shape:每一维的长度,比如 x.shape == (4, 4)。
- dtype:每个元素占多少字节,即上面提到的数据类型。
- stride:步长,这是描述沿着每个维度移动 1 个索引时,内存地址要跳过多少个元素。
假设有一个二维 tensor 变量 x = torch.tensor ([[0,1,2,3],[4,5,6,7],[8,9,10,11],[12,13,14,15]])。通常说 x 的大小,实际上指的是 x 索引的内存空间占用,通过 x.numel () * x.element_size () 来计算,即元素个数乘以每元素字节数,而 x.size () 和 x.shape 只表示 tensor 的维度长度和形状。
对于一个多维的 tensor,每一维度都有一个步长 stride。当你把第 i 维的索引增加 1,其他维不变时,在底层存储里要向前跳过多少个元素。

比如,stride [0] 指的是第 0 维(行)的步长,这里 x 的 stride [0]==4,那么 x [0][0] 跳到 x [1][0] 即跳一行,相当于在内存中要跳 4 个元素。对于 x.stride [1]==1,x [0][0] 跳到 x [0][1] 相当于在内存只要往右移动一个元素。
假设有这样的 tensor:x [r][c],假设内存从 0 开始,那么可以在底层一维存储中找到它的位置:offset = r * stride[0] + c * stride[1]。
tensor 的内存布局:
既然我们知道了创建的 tensor 变量实际上是一张有关内存的视图,那么很自然知道 pytorch 的很多操作并没有直接在数据上进行拷贝,而是在视图上进行操作和变换。
也就是说,很多操作如 transpose,permute,select,切片等返回的只是新的 view,它们通常共享同一块存储,只是改变了 stride,shape,offset,因此几乎是 O (1) 的。只有当一个操作需要改变数据在内存中的物理排列,或者需要生成新的数据时,才会发生拷贝或新分配操作,如 contiguous ()、大多 reshape 操作、clone () 等。
但并不是所有形状变换都能只靠修改视图完成,比如 transpose 或者 permute 之后,tensor 往往变成非连续的(non-contiguous),这是因为 stride 变了,这时候使用 view () 对其形状进行变换时就报错。像 view () 这种操作要求 tensor 的内存布局满足特定条件(通常要是 contiguous),否则就无法在不拷贝的情况下重解释同一块内存。
jaxtyping 标记法:
过往我们通过 x = torch.ones (2, 2, 1, 3) 创建一个 tensor,现在通过 jaxtyping 可以这样创建:x: Float [torch.Tensor, "batch seq heads hidden"] = torch.ones (2, 2, 1, 3)。
jaxtyping 将维度分别命名为 batch seq heads hidden,在后续用 einops 操作张量时,能更清晰地表达维度含义并减少维度错误。
einops:
einops 是一个用于 tensor 变换与计算的库,它的灵感来自爱因斯坦求和记法,支持对维度进行命名,并对其进行操作。
最常用的操作是 einsum。
假设有两个 tensor,x: Float [torch.Tensor, "batch seq1 hidden"] = torch.ones (2, 3, 4),y: Float [torch.Tensor, "batch seq2 hidden"] = torch.ones (2, 3, 4)。过往我们用 z = x @ y.transpose (-2, -1) 为 x 和 y 进行矩阵乘法,现在我们可以使用 einsum,即 z = einsum (x, y, "batch seq1 hidden, batch seq2 hidden -> batch seq1 seq2")。核心操作的是后面两个维度,如果你嫌其他维度写的繁琐,你可以用三个点替代:z = einsum (x, y, "... seq1 hidden, ... seq2 hidden -> ... seq1 seq2")。在这里,batch 被三个点替代了。
einsum 是用于张量乘法、求和的通用计算接口,而 reduce 则是对某些维度做聚合(sum/mean/max 等),也是很常见。
假设有 x: Float [torch.Tensor, "batch seq hidden"] = torch.ones (2, 3, 4),如何对最后一维做平均?以往我们使用 y = x.mean (dim=-1),如今可以使用 y = reduce (x, "... hidden -> ...", "mean") 得到形状为 "batch seq" 的 tensor,每一个元素都是在 hidden 维度上做平均计算的值。
有时候想要把一个维度拆成两个,或者重新将两个维度编排为一个,就需要用到 rearrange。
假设有 x: Float [torch.Tensor, "batch seq total_hidden"] = torch.ones (2, 3, 8),total_hidden 实际上是两个维度的乘积:heads * hidden1。那么 x = rearrange (x, "... (heads hidden1) -> ... heads hidden1", heads=2) 可以通过指定其中一维度 heads,将 total_hidden 拆为 heads 维和 hidden1 维。此时有另一个 w: Float [torch.Tensor, "hidden1 hidden2"] = torch.ones (4, 4),就能执行矩阵乘法:x = einsum (x, w, "... hidden1, hidden1 hidden2 -> ... hidden2")。最后,可以用 rearrange 将拆出来的 heads 与运算得到的 hidden2 合并为原来的 total_hidden:x = rearrange (x, "... heads hidden2 -> ... (heads hidden2)")。
# 3.FLOPs、FLOPS、MFU
FLOPs 和 FLOPS 有什么区别?
FLOPs 指的是一个算法需要做多少次浮点数运算(主要是浮点数加法和浮点数乘法),被衡量为一个算法的时间复杂度。
FLOPS 又可以写作 FLOP/s,指的是一个机器一秒钟能做多少次浮点数运算,被用于衡量硬件的性能。
对于一个硬件在设计完毕后,通常有一个理论峰值 FLOPS,但是实际运行往往不能达到这个理论峰值。因此对于一个硬件,任何运行的时刻都有一个实际 FLOPS,那么可以引入 MFU 作为衡量硬件发挥性能的程度。
定义 MFU = 实际 FLOPS / 理论峰值 FLOPS。通常,当 MFU >= 0.5 时,我们就说这个硬件已经非常好地发挥了其性能,尤其当一个算法的计算由矩阵乘法所主导。
# 4. 一些运算所需 FLOPs
逐元素操作:
逐元素操作有矩阵加法、矩阵点乘等,这些操作主要聚焦于两个形状相同的矩阵,并对它们进行浮点数运算。
比如两个 m * n 的矩阵进行加法所需的 FLOPs 可以理解为 m * n 次浮点数加法操作次数。对于矩阵点乘操作的 FLOPs,同样可以理解为 m * n 次浮点数乘法的操作次数。
因此可以认为,对于两个 m * n 的矩阵进行逐元素操作的 FLOPs 可以认为是 m * n,即 O (m * n) 复杂度。
矩阵乘法:
有形状为 [B, D] 的矩阵 A,以及形状为 [D, K] 的矩阵 B,矩阵乘法需要多少 FLOPs?
矩阵乘法的过程,可以拆解为:从矩阵 A 的某一行开始,与矩阵 B 的每一列进行向量内积,得到的结果作为新矩阵的第一行,然后对矩阵 A 的下一行继续执行这种操作,循环直到矩阵 A 遍历完。向量内积的过程是两个向量对应每个分量进行相乘后累加的结果,因此浮点数乘法和加法都有涉及。
上面的例子中,从矩阵 A 拿出一个行向量,有 D 个元素,从矩阵 B 拿出一个列向量,也是 D 个元素。对向量进行内积,需要 D 次浮点数乘法和 D 次浮点数加法,总共 2 * D 次 FLOPs。
矩阵 A 的一个行向量需要与矩阵 B 的所有列向量进行内积,才得到新矩阵的一行。矩阵 B 有 K 列,因此进行了 K 次内积,总共 2 * D * K 次 FLOPs。而矩阵 A 有 B 行,因此还需要在此基础上重复 B 次运算才能得到整个新矩阵,因此一个矩阵乘法需要 2 * B * D * K 次 FLOPs。
# 5. 前反向传播的 FLOPs
小例子:
假如有 1024 块 H100,规定单块 H100 的算力是 (1979e12) / 2 flops/s,用 15T 的 tokens 训练一个 70B 的大模型,需要多少天?
对于模型的一个参数,需要见遍 15e12 个 token(数据点),那么 70B 的参数,就至少需要计算 70e9*15e12 次。
在 1 个 epoch 的情况下,前向传播计算了 270e915e12 次。这是因为对于矩阵里的一个元素,需要做乘法与加法两次运算。也就是说一个参数需要与一个 token 运算两次,那么 70e9 个参数,15e12 个 token,就要运算 270e915e12 次。
对于反向传播,则需要 470e915e12 次计算。反向传播可以看成要完成两个任务:更新该层参数,把梯度往前传。
第一个任务,优化器得到后一层的梯度∂L/∂y,开始计算。如果要更新网络在该层的参数 W,需要优化器知道损失对参数的导数∂L/∂W,以线性层为例,∂L/∂W = x⊤ * ∂L/∂y。因此更新该层参数也是需要一次矩阵乘法,可知需要 270e915e12 次运算。
第二个任务在更新完该层参数后进行。前一层看到的 “输出” 是这一层的输入 x,它需要自己的梯度来更新自己那一层的参数,也就是说要把∂L/∂x 传递给前一层当作它的∂L/∂y 来更新参数,而∂L/∂x = W⊤ * ∂L/∂y。这也是一次矩阵乘法,需要 270e915e12 次运算。
因此加上反向传播的 470e915e12,一个 epoch 总共需要 670e915e12 次计算。
一天的计算量 FLOPs = 一块 GPU 一秒的计算量 * GPU 的 mfu * GPU 数量 * 时间(即 606024)。那么所需训练的天数就是总计算量 / 一天的计算量。
假设 H100 的 mfu=0.5,那么可以计算出时间大约为 144 天(四舍五入到个位)。
抽象到感知机模型:
假设有一个两层的线性模型,输入为 x = torch.ones (B, D, device='cuda:0'),第一层权重:w1 = torch.randn (D, D, device=device, requires_grad=True),第二层权重:w2 = torch.randn (D, K, device=device, requires_grad=True)。
前向传播的过程是,x 经过 w1 计算出 h1,h1 的形状是 B * D,h1 经过 w2 计算出输出 h2,h2 的形状是 B * K,然后用 h2 计算出损失 loss(假设使用均方损失),公式化为:h1 = x @ w1,h2 = h1 @ w2,loss = h2.pow (2).mean ()
那么前向传播所需计算量 FLOPs 即为矩阵乘法的计算量 FLOPs(如果忽略激活函数的计算量),因此这里所需 FLOPs 是两个矩阵乘法的 FLOPs:2 * (B * D * D) + 2 * (B * D * K) = 2 * B * D * (D + K)
那么如果有更多层?注意到 2 * B * D * (D + K) 最后括号内实际上是两层的权重参数和,而 B * D 实际上是 x 矩阵有多少个元素,即有多少数据点(token)。所以很自然知道,无论多少层,前向传播的 FLOPs 近似为 2 * 数据点 * 参数量。
继续回到两层的情况,如何计算反向传播的 FLOPs?
像上面的小例子提到的,反向传播需要计算每一层参数的梯度,以及传给前一层的梯度。因此参考前向传播的链条过程,反向传播需要依次计算 h2.grad = d loss /d h2,w2.grad = d loss /d w2,h1.grad = d loss /d h1,w1.grad = d loss /d w1。
从前面的例子知道,如果求该层权重参数 W 的梯度 dW,需要用该层的输入 x 与下一层传来的梯度进行矩阵乘法;如果要求传给前一层的梯度 dx,就需要用该层的权重 W 与下一层传来的梯度进行矩阵乘法,即分别浓缩为两个等式:∂L/∂W = x⊤ * ∂L/∂y 与∂L/∂x = W⊤ * ∂L/∂y。
那么你会问了,h2.grad 怎么计算?求向前传的梯度难道不是用该层参数 W 与下一层传来的梯度进行矩阵乘法吗?但是 h2 已经是输出,不存在该层参数 W 和下一层传来的梯度,它是作为反向传播算梯度的起点。
实际上,用该层参数 W 与下一层传来的梯度进行矩阵乘法求向前传的梯度,是在线性层反转时出现的,对于输出 h2 的梯度的单独计算,要回归 h2.grad = d loss /d h2 这个求导公式本身。实际上,求输出 h2 的梯度主要是逐元素的计算,并不涉及矩阵乘法,因此这种计算的代价相对于矩阵乘法通常很小,经常被当低阶项。
因此,让我们来计算一下两层所需 FLOPs。计算 w2.grad 是一次矩阵乘法,需要:2 * B * D * K,计算 h1.grad 也是一次矩阵乘法,需要 2 * B * D * K,在第二层总共需要 4 * B * D * K 。在第一层计算 w1.grad 需要 2 * D * B * D ,x.grad(如果继续反传):2 * B * D * D,第一层合计:4 * B * D * D,两层合计 4 * B * D * (D + K)。
因此无论多少层,反向传播的 FLOPs 可以近似为 4 * 数据点 * 参数量。
其实很多时候并不需要 x.grad,因此看似反向传播的 FLOPs 还要少 2 * B * D * D,上面的结论不太对?
如果模型的层数比较少,确实减少这部分会带来较大的变化。但是在模型层数都较多,模型较大的情况,少这一小部分对整体 FLOPs 没多大影响。因此在当下模型普遍层数较深的情况下,少了计算 x.grad 部分的 FLOPs,对整体 FLOPs 没多大影响,依然近似为 4 * 数据点 * 参数量。
因此,训练所需总 FLOPs≈(2+4)(数据点)(参数)。
抽象到 Transformer:
上面抽象到感知机的结论依然适用于 Transformer。对于纯解码器的 Transformer,除去 embedding 部分(tokenization 部分以及 position encoding 操作),前向传播所需的 FLOPs 依然近似于 2 * (数据点) * (参数量),反向传播所需的 FLOPs 近似于 4 * (数据点) * (参数量),训练总 FLOPs 近似于 6 * (数据点) * (参数量)。
关于具体的推导,可以参考这篇博客:https://www.adamcasson.com/posts/transformer-flops#counting-flops-in-transformers
# 7. 随机性
随机性出现在许多地方:参数初始化、dropout、数据排序等。
为了提高可复现性,建议确定一个随机种子来控制随机性,并在这三个常见的地方一次性设置好:torch.manual_seed (seed)、 np.random.seed (seed)、random.seed (seed)
确定性在 debug 时也特别有用,这样你就能追踪到错误。
# 8.Memmap 数据加载
经过 tokenization 处理后的数据是整数序列,有种做法是将这些数据存储为 numpy arrays,比如有这样的 tokens 序列:[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],可以转换为 numpy 的 array 数据格式(通常被 tokenization 来实现):orig_data = np.array ([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=np.int32)。
然后用 orig_data.tofile ("data.npy") 写出原始二进制文件(更像 .bin/.dat,.npy 只是一种后缀示例),读取时需要自己知道 dtype(以及形状)。
在需要训练数据时,可以用 numpy 的 memmap 操作来 “懒加载” 这些数据:data = np.memmap ("data.npy", dtype=np.int32)。它不会一次性把整个数据文件读入内存,而是把数据文件映射到虚拟内存中,访问到某个片段时由操作系统按需加载对应页,从而实现 “懒加载”。这对于超大规模语料(例如数十亿 token)非常有用。
# 9. 参数初始化
如果参数没有初始化会怎么样?
如果某一层的权重全是 0,那么这层输出一开始全是 0。
同一层里如果很多神经元权重完全一样(尤其全 0),它们在前向得到一样的输出、反向得到一样的梯度,更新也一模一样 —— 等于这些神经元永远学成同一个东西,模型容量被浪费。
如果初始化不当,尺度不对会怎么样?
假设初始化输入和参数:x = nn.Parameter (torch.randn (input_dim)),w = nn.Parameter (torch.randn (input_dim, output_dim)),然后计算输出 output = x @ w。
torch.randn 将创建一个均值为 0,方差为 1(标准正态分布)的变量。x @ w,实际上是 x 与 w 的每一列进行向量内积,方差为 1 的分量对应相乘后方差还是为 1,因为 x 维度为 input_dim,也就是有 input_dim 个元素,因此分量相乘后累加得到的结果方差为 input_dim(方差的相加相乘规律),也就是输入的维度。
那么,一个标准差与输入维度成正比的输出,一定会随着网络逐渐变深,变得越来越大,让后续层的激活、残差、softmax 更容易进入数值不稳定区间,反向传播时梯度也会被这种尺度放大或缩小,导致爆炸或消失,在训练上表现就是不稳定、loss 抖,甚至 NaN。
初始化的目的:
让信号在网络里传播时,方差不要随着层数系统性放大或缩小,所以常见初始化(Xavier/Kaiming 等)会让权重方差跟输入维度成反比,比如 Xavier 初始化:w = randn (...) /sqrt (input_dim)。
这样就让输出的尺度大致与输入无关,训练更稳定。
# 10. 参数量计算与内存占用
你有 8 块 H100,一块 H100 的内存是 80e9 bytes,如果你使用 AdamW 优化算法,你可以训练多大的模型?
每个参数在实际训练时通常要同时存 4 个相同形状的 tensor,如果使用默认的 FP32(float32),每个 tensor 都是 4 字节。参数本身存 4 字节,梯度 4 字节,以及 AdamW 的两份优化器状态(即一阶动量 / 梯度 EMA 和二阶动量 / 梯度平方 EMA),每份 4 个字节。因此,一个参数在训练时占 4+4+(4+4)字节。
因此用总的内存字节数除以一个参数训练时占多少字节,就能得到训练的所允许的参数总量,即模型大小,可知能训练 80e9 * 8 / (4+4+4+4) ≈ 40e9,即 40B 大小的模型。
用 Pytorch 计算模型参数量:
模型的参数通常存储在 nn.Parameter 对象中,因此借助 Pytorch 计算模型的参数量并不难,见如下计算模型参数量的函数:

除了 model.parameters () 方法能直接得到模型的参数列表,还可以用 model.state_dict ().items () 得到更加详细的模型参数情况:

model.state_dict ().items () 返回模型每层权重名,以及其对应的参数,均为列表。
训练内存占用:
训练时,GPU 内存的占用主要来自:模型参数,前向传播过程中产生的中间变量,梯度以及优化器存储的状态。
模型参数不用说,GPU 内存需要存储模型的参数,这是一部分内存占用,近似于模型参数量 * bytes_per_element。
前向传播过程中产生的中间变量也占据相当一部分内存。从前面计算反向传播的 FLOPs 内容中,计算某层权重参数的梯度需要用到该层中间变量,并与下一层传上来的梯度进行矩阵乘法。因此在反向传播过程中,需要前向传播保留中间变量。
在模型的某一隐藏层,假设网络产生的中间输出形状是 [B, D],B 是 batch size,D 是隐藏维度。模型需要保存这个中间变量,就需要占用 B * D * bytes_per_element 的内存。
那么,对于一整个模型,将每一层这样的中间变量都存储起来,近似需要 B * D * num_layers * bytes_per_element 的内存,num_layers 是模型层数。
这忽略了很多细节,比如还可能要存输入、最后一层输出、以及每层内部的其他临时量等。另外,如果层之间的隐藏维度 D 设计地不一致时,这样将隐藏层的维度都统一为 D 的计算也会带来一定的偏差。而且,某些非线性激活函数,如 GELU,反向传播计算梯度时也会用到中间变量。所以有时候这部分的显存占用会比模型参数带来的显存占用更高。
反向传播的产物就是梯度张量,然后存储起来交给优化器更新参数。每个可训练参数几乎都对应一个同形状的梯度张量,因此这部分的显存占用是模型参数量 * bytes_per_element。
最后一个显存占用的大头就是优化器,优化器主要存储每个参数的 “历史状态”,这些状态通常至少和参数一样大,甚至更大。以 AdamW 为例,更新一个参数,大致需要存储参数梯度的一阶动量(梯度的指数滑动平均)以及二阶动量(梯度平方的指数滑动平均),这两份可以认为和参数、梯度一样大。这部分的显存占用通常和选取的优化器、dtype 的大小有关。
Pinned Memory:
Pinned memory,也称页锁定内存,指的是把 CPU 端的一块内存锁定为不可被操作系统换页的 “固定内存”。GPU 通过 DMA 从主机内存读数据时,要求那块内存地址稳定,被锁定的内存满足这一点,而且在 pinned Memory 里,拷贝可以和 GPU 计算并行,使得拷贝更高效。
通常设置 DataLoader (..., pin_memory=True) 让 batch 在 CPU 端以 pinned 形式产生,当然也可以通过 x = x.pin_memory () 手动设置,然后再用 x.to ("cuda", non_blocking=True) 异步装载到 GPU 内存中。
