Mamba串烧

文章发布时间:

最后更新时间:

文章总字数:
4.9k

页面浏览: 加载中...

Mamba串烧

[TOC]

Mamba: Linear-Time Sequence Modeling with Selective State Spaces

本文核心内容与思想搬运自该博文中的核心与精华:一文通透想颠覆Transformer的Mamba:从SSM、HiPPO、S4到Mamba_mamba模型-CSDN博客

状态空间与状态空间模型SSM

想象一下我们正在穿过一个迷宫,图中每个小框代表迷宫中的一个位置,并附有某个隐式的信息,例如你距离出口有多远。

image

​ 而上述迷宫可以简化建模为一个“状态空间表示state space representation”,每一个小框都含有:你当前所在的位置(当前状态current state)、下一步可以去哪里(未来可能的状态possible future states),以及哪些变化会将你带到下一个状态(向右或向左)。而描述状态的变量(在我们的示例中为 X 和 Y 坐标以及到出口的距离)可以表示为“状态向量state vectors”。

一般SSMs包括以下组成:

  • 输入序列x(t),比如在迷宫中向左和向下移动。
  • 潜在状态表示h(t),比如距离出口距离和 x/y 坐标。
  • 输出序列y(t),比如再次向左移动以更快到达出口。

然而,它不使用离散序列(如向左移动一次),而是将连续序列作为输入并预测输出序列。

image

​ SSM 假设系统(例如在 3D 空间中移动的物体)可以通过两个方程从其在时间t 时的状态进行预测「 当然,其实下面第一个方程表示成这样可能更好:h(t) = Ah(t-1) + Bx(t),不然容易引发歧义 」。

image

通过求解这些方程,可以根据观察到的数据:输入序列和先前状态,去预测系统的未来状态。

SSM关键的两个方程:状态方程和输出方程

总之,SSM的关键是找到:状态表示(state representation)h(t)以便根据输入序列预测输出序列。

image

而这两个方程也是状态空间模型的核心,且矩阵 ABCD都是可以学习的参数。

​ 第一个方程:状态方程,矩阵B与输入x(t)相乘之后,再加上矩阵A前一个状态h(t)相乘的结果
换言之,B矩阵影响输入x(t)A矩阵影响前一个状态h(t),而h(t)指的是任何给定时间t的潜在状态表示(latent state representation),而x(t)指的是某个输入「当然,还是上面那句话,表示成这样更好:h(t) = Ah(t-1) + Bx(t)」。

image

第二个方程:输出方程,描述了状态如何转换为输出(通过矩阵 C),以及输入如何影响输出(通过矩阵 D)

image

建立对SSM中两个核心方程的统一视角

最终,我们可以通过下图统一这两个方程:

image

为了进一步加深对该图的理解,我们一步一步拆解下

  1. 假设我们有一些输入信号x(t),该信号首先乘以矩阵 B

image

2.上面第一步的结果,加上:上一个状态与矩阵A相乘(矩阵A描述了所有内部状态如何连接)的结果,用来更新状态state。

image

3.然后,使用矩阵C来将状态转换为输出。

image

4.最后,再利用矩阵 D提供从输入到输出的直接信号,这通常也称为跳跃连接skip-connection。

image

5.由于矩阵 D类似于跳跃连接,因此在没有跳跃连接的情况下,SSM 通常被视为如下:

image

回到我们的简化视角,现在可以关注只矩阵ABC构建的SSM核心。

image

​ 总之,这两个方程共同旨在根据观测数据预测系统的状态,且考虑到输入一般都是连续的,因此SSM的主要表示是连续时间表示(continuous-time representation)。

从SSM到S4的升级

用了三步:离散化SSM、循环/卷积表示、基于HiPPO处理长序列。

离散数据的连续化:基于零阶保持技术做连续化并采样

​ 由于除了连续的输入之外,还会通常碰到离散的输入(如文本序列),因此如果模型也能处理离散化数据则再好不过。怎么做到呢?好在 可以利用零阶保持技术(Zero-order hold technique)

image

  1. 首先,每次收到离散信号时,我们都会保留其值,直到收到新的离散信号,如此操作导致的结果就是创建了 SSM 可以使用的连续信号
  2. 保持该值的时间由一个新的可学习参数表示,称为步长(siz)——\Delta ,它代表输入的阶段性保持(resolution)。
  3. 有了连续的输入信号后,便可以生成连续的输出,并且仅根据输入的时间步长对值进行采样。

image

这些采样值就是我们的离散输出,且 可以按如下方式做零阶保持

image

​ 它们共同使我们能够从连续 SSM 转变为离散SSM,使得不再是函数到函数x(t) → y(t),而是序列到序列,所以你看到离散化的SSM时,不再带参数t了

image

​ 这里,矩阵A和 B现在表示模型的离散参数(且这里 使用 而不是 来表示离散时间步长)。
​ 注意:我们在保存时,仍然保存 矩阵 A的连续形式(而非离散化版本),只是在训练过程中,连续表示被离散化(During training, the continuous representation is discretized)。

循环结构表示The Recurrent Representation

总之,离散 SSM 允许 可以用离散时间步长重新表述问题。

image

在每个时间步,都会涉及到隐藏状态的更新(比如h_k取决于Bx_kAh_{k-1}的共同作用结果,然后通过Ch_k预测输出y_k)。

image

为方便大家理解其中的细节,我再展开一下 y_2

image

有没有眼前一亮?如此,便可以RNN的结构来处理。

image

然后 可以这样展开(其中,h_k始终是Bx_kAh_{k-1}的共同作用之下更新的)。

image

卷积结构表示The Convolution Representation

在经典的图像识别任务中,我们用过滤器(即卷积核kernels)来导出聚合特征,而SSM也可以表示成卷积的形式。

image

由于我们处理的是文本而不是图像,因此我们需要一维视角

image

而用来表示这个“过滤器”的内核源自 SSM 公式。

image

1.与卷积一样,我们可以使用 SSM 内核来检查每组token并计算输出

image

2.内核将移动一次以执行下一步的计算。

image

3.最后一步,我们可以看到内核的完整效果:

image

至于上图中的y_2是咋计算得到的,别忘了我上面推导出来的。

image

​ 总结一下,将 SSM 表示为卷积的一个主要好处是它可以像卷积神经网络CNN一样进行并行训练。然而,由于内核大小固定,它们的推理不如 RNN 那样快速。

image

最终,SSM可以视为从输入信号到输出信号的参数化映射。

1.SSMs可以当做是RNN与CNN的结合「These models can be interpreted as acombination of recurrent neural networks (RNNs) and convolutional neural networks (CNNs)」,即推理用RNN,训练用CNN。

image

2.总之,这类模型可以非常高效地计算为递归或卷积,在序列长度上具有线性或近线性缩放(This class of models can be computed very efficiently as either arecurrence or convolution, with linear or near-linear scaling in sequence length。

矩阵A的问题与解决策略——HiPPO

​ 我们之前在循环表示中看到的那样,矩阵A捕获先前previous状态的信息来构建新状态(h_k = A h_{k-1} + B x_k,当k = 5时,则有h_5 = A h_{4} + B x_5)

image

其实,某种意义上,算是矩阵A产生了隐藏状态(matrix A produces the hidden state)。

image

​ 由于矩阵A只记住之前的几个token和捕获迄今为止看到的每个token之间的区别,特别是在循环表示的上下文中,因为它只回顾以前的状态。

那么我们怎样才能以保留比较长的memory的方式创建矩阵A呢?

​ 答案是可以使用Hungry Hungry Hippo((High-order Polynomial Projection Operator,简称H3),HiPPO尝试将当前看到的所有输入信号压缩为系数向量(HiPPO attempts to compress all input signals it has seen thus far into a vector of coefficients)。

image

​ 它使用矩阵A构建一个“可以很好地捕获最近的token并衰减旧的token”状态表示(to build a state representation that captures recent tokens well and decays older tokens),其公式可以表示如下:

image

具体表示可以如下图所示:

image

​ 正由于HiPPO 矩阵可以产生一个隐藏状态来记住其历史(从数学上讲,它是通过跟踪Legendre polynomial的系数来实现的,这使得它能够逼近所有以前的历史),使得在被应用于循环表示和卷积表示中时,可以处理远程依赖性

​ 如此,S4的定义就出来了:序列的结构化状态空间——Structured State Space for Sequences,一类可以有效处理长序列的 SSM。

image

SSM的问题:矩阵参数固定不变,无法针对输入做针对性推理

首先,Linear Time Invariance(LTI)规定 SSM中的A、B、C始终是固定不变的参数。这意味着:

1.对于 SSM 生成的每个token,矩阵A 、B、C都是相同的(regardless of what sequence you give the SSM, the values of A,B,and C remain the same. We have a static representation that is not content-aware)
2.使得SSM无法针对输入做针对性的推理「since it treats each token equally as a result of the fixed A, B, and C matrices. This is a problem as we want the SSM to reason about the input (prompt)」

此外,如下图所示,无论输入x 是什么,矩阵 B都保持完全相同,因此与x无关。

image

同样,无论输入如何,A和C也保持固定。

image

Mamba组成结构与原理简析

​ 简言之,Mamba是一种状态空间模型(SSM),建立在更现代的适用于深度学习的结构化SSM (简称S6)基础上,与经典架构RNN有相似之处。

Mamba = 有选择处理信息 + 硬件感知算法 + 更简单SSM架构

与先前的研究相比,Mamba主要有三点创新:

1.对输入信息有选择性处理(Selection Mechanism)

相比SSM压缩所有历史记录(当然,transformer则是不压缩所有历史记录),mamba设计了一个简单的选择机制,通过“参数化

SSM的输入”,以便关注或忽略特定的输入。这样一来,模型能够过滤掉与问题无关的信息,并且可以长期记住与问题相关的信息。

2.硬件感知的算法(Hardware-aware Algorithm)

该算法采用“并行扫描算法”而非“卷积”来进行模型的循环计算,但为了减少GPU内存层次结构中不同级别之间的IO访问,

没有具体化扩展状态。

当然,这点也是受到了S5(Simplified State Space Layers for Sequence Modeling)的启发。

3.更简单的架构

将SSM架构的设计与transformer的MLP块合并为一个块(combining the design of prior SSM architectures with the MLP block of

Transformers into a single block),来简化过去的深度序列模型架构,从而得到一个包含selective state space的架构设计。

选择性状态空间模型:从S4到S6

​ 作者认为,序列建模的一个基础问题是把上下文压缩成更小的状态(We argue that a fundamental problem of sequence modeling is compressing context into a smaller state),从这个角度来看:

  • 注意力机制虽然有效果但效率不算很高,毕竟其需要显式地存储整个上下文(storing the entire context,也就是KV缓存),直接导致训练和推理消耗算力大。好比,Transformer就像人类每写一个字之前,都把前面的所有字+输入都复习一遍,所以写的慢。

  • RNN的推理和训练效率高,但性能容易受到对上下文压缩程度的限制。

On the other hand, recurrent models are efficient because they have a finite state, implying constant-time inference and linear-time

training. However, their effectiveness is limited by how well this state has compressed the context。

好比,RNN每次只参考前面固定的字数(仔细体会这句话:When generating the output, the RNN only needs to consider the previous hidden state and current input. It prevents recalculating all previous hidden states which is what a Transformer would do),写的快是快,但容易忘掉更前面的内容。

  • 而SMM的问题在于其中的矩阵A B C始终是不变的,无法针对不同的输入针对性的推理,详见上文

image

  • 最终,Mamba的解决办法是,让模型对信息有选择性处理,可以关注或忽略特定的内容,即使状态大小固定也能压缩上下文。好比,Mamba每次参考前面所有内容的一个概括,越往后写对前面内容概括得越狠,丢掉细节、保留大意。

总之,序列模型的效率与效果的权衡点在于它们对状态的压缩程度:

  • 高效的模型必须有一个小的状态(比如RNN或S4)

  • 而有效的模型必须有一个包含来自上下文的所有必要信息的状态(比如transformer)

而mamba为了兼顾效率和效果,选择性的关注必须关注的、过滤掉可以忽略的:

image

为方便大家理解,再进一步阐述mamba与其前身结构化空间模型S4的优势。

首先,在其前身S4中,其有4个参数(∆, A, B, C)

image

且它们都是固定的,不随输入变化(即与输入无关),这些参数控制了以下两个阶段:
image

  • 第一阶段(1a 1b),通常采用固定公式A = 𝑓𝐴(∆, A)和B = 𝑓𝐵(∆, A, B),将“连续参数”(∆,A,B)转化为“离散参数”(A,B),其中(𝑓𝐴, 𝑓𝐵) 称为离散化规则,且可以使用多种规则来实现这一转换。

    The first stage transforms the “continuous parameters” (∆, A, B) to “discrete parameters” (A, B) through fixed formulas A = 𝑓𝐴(∆, A) and B = 𝑓𝐵(∆, A, B), where the pair (𝑓𝐴, 𝑓𝐵) is called a discretization rule.

    例如下述方程中定义的零阶保持(ZOH):
    Various rules can be used such as the zero-order hold (ZOH) defined in equation (4).

    image.png)

  • 第二阶段(2a 2b,和3a 3b),在参数由(∆,A, B, C)变换为(A, B, C)后,模型可以用两种方式计算,即线性递归(2)或全局卷积(3)。

    After the parameters have been transformed from (∆, A, B, C) ↦ (A, B, C), the model can be computed in two ways, either as a linear recurrence (2) or a global convolution (3)

如之前所说的

模型通常使用卷积模式(3)可以进行高效的并行化训练「 其中整个输入序列提前看到,为何可以做高效的并行化呢,因为该模式能够绕过状态计算,并实现仅包含(B, L, D)的卷积核(3a),即Thus the more efficient convolution mode wasintroduced which could bypass the state computation and materializes a convolution kernel (3a) of only (𝙱, 𝙻, 𝙳)」

并切换到循环模式(2)以高效的自回归推理(其中输入每次只看到一个时间步)
the model uses the convolutional mode (3) for efficient parallelizable training (where the whole input sequence is seen ahead of time),

and switched into recurrent mode (2) for efficient autoregressive inference (wheret he inputs are seen one timestep at a time).

下面,再分析下各个变量的含义

  • 一个标量,类似遗忘门
    如sonta所说,这个量跟RNN里的gating有着深刻的联系(∆ in SSMs can be seen to play a generalized role of the RNN gating mechanism)。
    即data dependent的 Δ 跟RNN的forget gate的功能类似(step size Δ that represents the resolution of the input discretization of SSMs is the principled foundation of heuristic gating mechanisms.)
    B,起到的作用类似于:进RNN的memory
    C,起到的作用类似于:取RNN的memory

咋理解?我拿出上文第二部分的这个图 一摆,就一目了然了。

image

所以有人说,data dependent的B/C的功能跟RNN的input/output gate类似。

  • A,意味着对应这个维度的SSM来说,A在每个hidden state维度上的作用可以不相同,起到multi-scale/fine-grained gating的作用,这也是LSTM网络里面用element-wise product的原因。

​ 其次,通过之前的讲解,可知\boldsymbol{A} \in \mathbb{R}^{N \times N}, \boldsymbol{B} \in \mathbb{R}^{N \times 1}, \boldsymbol{C} \in \mathbb{R}^{1 \times N}矩阵都可以由N个数字表示(the A ∈ ℝ𝑁×𝑁, B ∈ ℝ𝑁×1 , C ∈ ℝ1×𝑁 matrices can all be represented by 𝑁 numbers.)

image

​ 为了对批量大小为、长度为、具有个通道(类似R G B三个通道)的输入序列进行操作,SSM被独立地应用于每个通道(To

operate over an input sequence 𝑥 of batch size 𝐵 and length 𝐿 with 𝐷 channels, the SSM is applied independently to each channel)。

image

​ 请注意,在这种情况下,每个输入的总隐藏状态具有维,在序列长度上计算它需要时间和内存(the total hidden state has dimension 𝐷𝑁 per input, and computing it over the sequence length requires 𝑂(𝐵𝐿𝐷𝑁) time and memory)。

​ 最后,在Mamaba中,作者让这些参数矩阵、矩阵、成为输入的函数(即可学习或可训练的),让模型能够根据输入内容自适应地调整其行为。

image

1.从S4到S6的过程中:

影响输入的B矩阵、影响状态的C矩阵的大小从原来的(D,N)「其中,D指的是输入向量的维度,比如一个颜色的变量一般有R G B三个维度,N**指SSM的隐藏层维度hidden dimension,当然 一般设的比较小」。

image

变成了(B,L,N)「这三个参数分别对应batch size、sequence length、hidden state size

image

\Delta的大小由原来的D变成了(B,L,D)
且每个位置的B矩阵、C矩阵、\Delta都不相同,这意味着对于每个输入token,现在有不同的B矩阵、C矩阵,可以解决内容感知问题

进一步,咱们通过
s_{B}(x)=\operatorname{Linear}_{N}(x)
s_{C}(x)=\operatorname{Linear}_{N}(x)
s_{\Delta}(x)=\operatorname{Linear}_{D}(x)
\tau_{\Delta}=\text { softplus }
来逐一将B, C, \Delta数据依赖化(data dependent)「其中的这个\text { Linear }_{d}(x)代表把D维的输入向量x经过一个线性层映射到d

2.虽然A没有变成data dependent,但是通过SSM的离散化操作之后,(\bar{A}, \bar{B})会经过outer product变成(B, L, N, D)的data dependent张量,算是以一种parameter efficient的方式来达到data dependent的目的。

总之,Mamba通过合并输入的序列长度和批量大小来使矩阵B和C,甚至步长Δ取决于输入(其意味着对于每个输入token,现在有不同的B和C矩阵,可以解决内容感知问题),从而达到选择性地选择将哪些内容保留在隐藏状态以及忽略哪些内容的目标。至于步长Δ,较小的步长Δ会也能做到忽略特定单词,而更多地使用先前的上下文,而较大的步长Δ会更多地关注输入单词而不是上下文。

image

并行扫描(parallel scan)算法

​ 由于A B C这些矩阵现在是动态的了,因此无法使用卷积表示来计算它们(CNN需要固定的内核),因此,我们只能使用循环表示,如此也就而失去了卷积提供的并行训练能力。

so,为了实现并行化,让我们探讨如何使用循环计算输出:

硬件感知的状态扩展:借鉴Flash Attention

简化的SSM架构

将大多数SSM架构比如H3的基础块,与现代神经网络比如transformer中普遍存在的门控MLP相结合,组成新的Mamba块,重复这个块,与归一化和残差连接结合,便构成了Mamba架构

image

顺带提一嘴,transformer quality in linear time以及mega moving average equipped gated attention的这两个工作,也用了类似的结

构:即删除transformer的ffn/glu结构

最终流程如下(图源自mamba原论文)

image

打赏
支付宝 | Alipay
微信 | Wechat