Token Merging Your ViT But Faster
Transformer 乘着最近大模型的井喷又火了一把,在NLP领域的统治地位好像已经无可撼动了,Transformer 模型使用了 Self-Attention 机制,不采用 RNN 的顺序结构,使得模型可以并行化训练,而且能够拥有全局信息。记得几年前接触CV里面的Transformer结构DETR和Vit,惊讶于他们的效果的同时始终觉得将图片tokens化的方式太不美观了,同时所需要的训练资源太大了,学生党根本尝试不了。最近看到一篇论文,提出了一种无需训练即可加速 ViT 模型,提高吞吐量的方法 Token Merging (ToMe)。ToMe 通过一种轻量化的匹配算法,逐步合并 ViT 内部的相似的 tokens,实现了在基本不损失性能的前提下,大幅提升 ViT 架构的吞吐量。
背景
与卷积神经网络 (CNN) 相比,视觉 Transformer 模型 (ViT) 有一系列优良的性质,比如:
- Transformer 模型的 Attention 模块和 MLP 模块主要有矩阵乘法这种可以加速的操作构成。
- Transformer 支持一些性能强大的自监督学习任务 (掩码图像建模 MAE 等等)。
- Transformer 适配多种模态的输入数据 (图片,文本,音频等)。
- Transformer 对于超大规模数据集 (ImageNet-22K) 的泛化性好,预训练之后的模型在下游任务中 (比如 ImageNet-1K 图像分类任务) 表现卓越。
但是在资源受限的边缘设备 (如手机和无人机) 上实际运行 Transformer 不太友好,因为 Transformer 模型又相对较大的延时。一种常见的加速视觉 Transformer 模型的方法是对 token (图片 Patch) 进行剪枝,但是 token 剪枝的缺点有:
- 需要额外的训练过程,对资源不友好。
- token 剪枝限制了模型的实用性,当 token 数量随着输入的变化而发生变化时,无法进行批处理 (Batch Inference)。为了解决这个问题,大多数 token 剪枝的工作借助了 Mask,对冗余的 token 进行遮挡。但是这样的做法并没有真正剪去这些冗余的 token,使得这些方法并不能在实际业务中真正加速。
- token 剪枝带来的信息损失限制了可以允许剪枝的 token 数量。
另一种加速 ViT 的做法是对 token (图片 Patch) 进行融合。和本文方法最接近的 Token Pooling 使用了一个缓慢的基于 k-means 的方法,但是速度较慢,不适用于现成的模型。
本文希望做一个无需训练并且兼顾性能-速度权衡的 token 融合方法。因为其无需训练的优良属性,对于大模型将会非常友好。在训练过程中使用 ToMe,可以观察到训练速度增长,总训练时间缩短了一半。
方法
Token Merging 的基本思路
Token Merging 的基本思路是在一个 ViT 模型中间插入一些 token merging 的模块,希望把这些模块植入 ViT 以后,训练和推理的速度都有提升。基本作法是在每一个层之后减少$r$ 个 token,那么一个有$ L$ 层的 Transformer 模型从头到尾减少的 token 数量就是 $Lr$ 。这个 $r $值越高,减少的 token 数量就越多,但是精度也会越差。而且值得注意的是,无论一张输入图片有多少个 tokens,都会减少$Lr $个 token,而不是像上文的 token 剪枝算法那样使得 token 的数量动态变化。为什么这么设计呢?原因就是如上文所述当 token 数量随着输入的变化而发生变化时,无法进行批处理 (Batch Inference),使得这些方法并不能在实际业务中真正加速。
如下图1所示是 Token Merging 的示意图,ToMe 的位置被插在 Attention 模块和 MLP 模块之间,因为作者希望借助 Attention 中的特征帮助决定该去融合哪些 tokens。
什么样的 tokens 是相似的
根据上面的基本思路,要考虑的第1个问题是我们应该合并哪些 tokens,即什么样的 tokens 可以被认为是相似的 tokens?一种比较直接的想法是距离比较近的 tokens 是相似的,但是并不是最优解。
如下图2所示为消融实验结果,意在探索什么样的 tokens 是相似的。消融实验使用的模型是 MAE 训练策略下得到的 ViT-L/16 预训练模型 (acc: 85.96%, im/s: 93.3),不再进行任何额外训练。使用$ r=8$ 合并,这将在网络的24层上逐渐移除 98% 的 tokens。
如左图所示为使用什么特征衡量相似度,作者发现使用 Key 来衡量相似度对性能最友好,因为 Attention 模块中的 Key 已经总结了每个 token 中包含的信息,以便用于 Attention 中的 dot-product 相似度。如右图所示为使用什么距离衡量相似度,作者发现使用余弦距离来衡量 token 之间的相似度可以获得最好的精度-速度权衡。
如下图3所示,把不同 head 的 Key 进行取平均操作,而不是拼接在一起,更有助于效率。
Token Merging 的具体步骤:二分软匹配
在定义了 tokens 的相似性之后,下面就需要一种快速的方法来确定要匹配哪些 tokens,以便在实际运行时能够快速将 tokens 的数量减少$ r$ 。这个过程对于延时的要求很高,因为在 ViT 模型中要对可能上千个 tokens 执行匹配 $L $次,所以这个匹配算法的运行时间必须完全可以忽略不计。
把 ToMe 模块输入的所有 tokens 分为相同大小的2个集合 $\mathbb{A},\mathbb{B}$ 。
把从集合$\mathbb{A}$ 中的每个 token 到 $\mathbb{B}$ 中与其最相似的 token 画一条边。
只留下最相似的$r $条边,其余删掉。
- 融合仍然相连的$r$条边 (特征取均值)。
- 把这两个集合拼在一起,得到 ToMe 模块的融合结果。
Token Merging 的后续操作:调节注意力权重
前文提到,ToMe 模块会融合 $r$个 token。在 ViT 模型里面,一个 token 代表输入图片的一个 Patch,比如输入图片有 $N $个 Patch,就是有 $N$个 token。Attention 矩阵的维度也是$ N\times N$ 的,它代表了 $N $个 Patch 之间的相关关系。但是现在我们融合了 $r $个 token 之后呢,Attention 矩阵的维度应该是 $(N−r)×(N−r)$ 的,融合了 token 之后,有的 Key 应该占的 Attention 比重大一些,因为它融合了多个 token 的信息。所以作者在这里定义了一个行向量$ {s}$ 。 $s\in \mathbb{R}^{1\times N}$ 是包含每个 token 大小 (token 所代表的 Patch 数量) 的行向量。通过上式将行向$ s$ 直接加在 Attention 矩阵上面,相当于是人为增加了有些 Key 的 attention weight,而这些 key 恰好是发生了融合的 Key。
到目前为止,已经能够直接向已经训练好的 ViT 模型中添加 ToMe 模块。使用 ToMe 模块进行训练虽然不是必须的,但是它可以减少准确度下降,并且加快训练速度。ToMe 模块本质上是 token 的均值操作,因此可以视为是一种池化操作 (Pooling)。因此,我们可以按照平均池化操作 (Average Pooling) 的方式进行反向传播。
结果
如下图13所示是在网络的结尾处的每个合并的 token 所对应的输入 Patch。可以发现,ToMe 方法造成的 token 融合的效果和分割很像。比如,在第2张图中,哈士奇的腿、身体和脸被合并到了不同的 token 中。在第3张图中,猴子的手、身体、脸、眼睛和嘴都被合并到了不同的 token 中。在最后1张图中,所有实例 (狗) 中相同的部分会被合并在一起。值得注意的是,与剪枝不同,ToMe 这种 token 融合的方法能够合并背景和前景中的大量冗余的 tokens,而且不丢失信息。
总结
ToMe 是一个无需训练并且兼顾性能-速度权衡的 token 融合方法,意在缩减 ViT 模型中大量冗余的 tokens。Token Merging 的基本思路是在一个 ViT 模型中间插入一些 token merging 的模块,希望把这些模块植入 ViT 以后,训练和推理的速度都有提升。在图像和视频中多个模型的实验结果表明,这种 token 融合的方法能够合并背景和前景中的大量冗余的 tokens,提高 ViT 模型的吞吐量,而且不丢失信息。