【论文笔记】线性注意力:Learning to (Learn at Test Time): RNNs with Expressive Hidden States

参考文献:Learning to (Learn at Test Time): RNNs with Expressive Hidden States

动机(Motivation)

自注意力在长上下文中表现良好,但具有二次复杂度(无论是transformer还是mamba都涉及矩阵的乘积运算)。现有的 RNN 层具有线性复杂度,但它们在长上下文中的性能受到隐藏状态表达能力的限制。因此,希望提出一种序列建模层,具有线性复杂性和富有表现力的隐藏状态。

创新点

关键思想是将隐藏状态本身作为 机器学习模型,将更新规则作为自监督学习的步骤。由于隐藏状态甚至可以通过测试序 列上的训练来更新,因此我们的层称为Test-Time Training(TTT)层。

方法论

所有序列建模层都可以从将历史上下文存储到隐藏状态的角度来看待,如图 4 所示。

RNN:压缩启发式(compression heuristic),简单来说就是,在隐藏状态层将历史上下文进行压缩存储,性能收到隐藏状态表达能力的限制,因为只会存储固定大小的内容

self-attention:隐藏状态随着t增长,将计算所得的Key-value缓存下来,从而实现显式的存储历史上下文而不进行压缩。时间消耗会随着隐藏状态的增长而增长。

期望实现的目标:更好的压缩启发式,将数千甚至数百万个令牌压缩到隐藏状态,以有效捕获其底层结构和关系。

将TTT用于更新隐藏状态

启发:参数学习的过程可以看作是将大量的数据集压缩成模型权重的过程。而自监督学习训练的模型可以捕获训练数据背后的底层结构和关系。

关键思想:将历史上下文x_1,x_2,…,x_t压缩到隐藏状态$s_t$,这一过程通过自监督学习进行,即将历史上下文x视为没有label的数据集,而将隐藏状态视为模型,也就是将隐藏状态视作自监督模型产生的模型f的权重$W_t$。f可以是任意模型,包括线性模型或者神经网络。

那么输出规则$z_t=f(x_t;W_t)$

更新规则:$Wt=W_{t-1}-\eta\nabla\ell(W_{t-1};x_t)$,其中$\eta$为学习率,$\ell$为自监督损失

压缩启发式通常需要选择记住或者遗忘哪些输入,从更新规则不难看出,TTT能够记住一些产生较大梯度的输入,直观地讲,这些输入使 W 学到了很多东西。

对于$\ell$的选择,一种选择是进行重构,$\ell(W;x_t)=|f(\tilde{x}_t;W)-x_t|^2$

与其他 RNN 层和自注意力机制一样,算法映射输入序列到输出序列可以使用上面的隐藏状态、更新规则和输出规则编程到序列建模层的前向传递中。由于在测试过程中,对于每个输入序列,我们的新层仍然训练不同的权重序列,因此被称为TTT(Test-Time Training)

使用TTT层训练网络

TTT 层与 RNN 层和自注意力机制具有相同的接口,因此可以在任何更大的网络架构 中进行替换。

代码描述

我们将训练较大的网络称为外循环,将每个 TTT 层中的进行的$W$更新的训练称为内循环。区别是在内循环中,梯度$\nabla\ell$可以更新W,也就是模型f的参数,而在外循环中则是更新网络其余部分的参数,$\theta_{\mathrm{rest}}$,在接下来的内容中,外循环的参数使用不用下标的$\theta$表示。

TTT的自监督学习任务(最重要的部分)

创新:不使用人类先验的自监督任务,而是采用更加端到端的方法---直接优化自监督任务以实现下一个token预测的最终目标。

具体来说,将自监督学习任务作为外循环的一部分。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

wufen_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值