核函数为什么能把 Attention 写成状态方程
By Chaa
这一部分的核心问题是:
为什么使用可分解核函数后,attention 可以写成类似 RNN 的状态递推形式,而原版 softmax attention 不可以?
关键不在于“核函数”这个名字,而在于:
κ(q,k)
能不能被拆成:
κ(q,k)=ϕ(q)⊤ϕ(k)
如果能拆,而且 ϕ 是有限维特征映射,那么历史 key-value 信息就可以提前聚合成一个固定大小状态。
1. 状态方程需要什么条件
我们希望把 attention 写成:
St=St−1+update(kt,vt)
然后当前输出只依赖当前 query 和状态:
yt=read(qt,St)
这要求一个很关键的条件:
历史 token 写入状态时,不能依赖未来的 query。
也就是说,第 j 个历史 token 写入状态时,只能用自己的:
kj, vj
不能等某个未来 query qi 出现之后,才知道这个历史 token 应该怎么参与计算。
因此,要想提前维护状态,attention 的相似度必须能拆成:
κ(qi,kj)=ϕ(qi)⊤ϕ(kj)
这样,和 query 有关的部分是:
ϕ(qi)
和历史 key/value 有关的部分是:
ϕ(kj), vj
二者就可以分离。
2. Kernel Attention 的一般形式
把 attention 写成 kernel 形式:
yi=∑j=1iκ(qi,kj)∑j=1iκ(qi,kj)vj
在 causal language modeling 中,第 i 个 token 只能看前面和当前位置,所以求和范围是:
j=1,…,i
如果 kernel 可以分解:
κ(qi,kj)=ϕ(qi)⊤ϕ(kj)
那么代入可得:
yi=∑j=1iϕ(qi)⊤ϕ(kj)∑j=1iϕ(qi)⊤ϕ(kj)vj
因为:
ϕ(qi)
和求和下标 j 无关,所以可以提出求和:
yi=ϕ(qi)⊤(∑j=1iϕ(kj))ϕ(qi)⊤(∑j=1iϕ(kj)vj⊤)
3. 状态形式从哪里来
定义状态矩阵:
Si=j=1∑iϕ(kj)vj⊤
定义归一化状态:
zi=j=1∑iϕ(kj)
那么输出可以写成:
yi=ϕ(qi)⊤ziϕ(qi)⊤Si
更重要的是,Si 和 zi 都可以递推维护:
Si=Si−1+ϕ(ki)vi⊤
zi=zi−1+ϕ(ki)
这就是 linear attention 能写成状态方程的原因。
它本质上把历史信息压缩到了两个状态里:
Si
和:
zi
之后每个 query 只需要从状态中读取,而不需要重新扫描全部历史 token。
4. 为什么原版 Softmax Attention 不能这样写
原版 softmax attention 是:
yi=∑j=1iexp(qi⊤kj)∑j=1iexp(qi⊤kj)vj
这里的相似度是:
κ(qi,kj)=exp(qi⊤kj)
问题在于:
exp(qi⊤kj)
把 qi 和 kj 强烈耦合在一起。
对于每一个新的 query qi,都要重新计算:
exp(qi⊤k1), exp(qi⊤k2), …, exp(qi⊤ki)
然后才能得到分母:
ℓ=1∑iexp(qi⊤kℓ)
也就是说,softmax 的每一行归一化分布都依赖当前 query 和所有历史 key 的比较结果。
所以原版 softmax attention 必须为每个 query 重新构造一行 attention distribution:
[αi1,αi2,…,αii]
这就无法提前把历史 key-value 精确压缩成一个固定大小状态。
5. 更准确地说:Softmax Kernel 可以拆,但需要无限维
需要注意,softmax 的非归一化 kernel:
exp(q⊤k)
并不是完全不能写成内积形式。
理论上,它可以写成:
exp(q⊤k)=ϕ(q)⊤ϕ(k)
但问题是,精确的 ϕ 通常是无限维的。
看一维情况最清楚。设 q,k 都是标量:
exp(qk)
泰勒展开为:
exp(qk)=r=0∑∞r!(qk)r
也就是:
exp(qk)=1+qk+2!q2k2+3!q3k3+⋯
这可以写成两个无限维向量的内积:
exp(qk)=[1, q, 2!q2, 3!q3,…]⊤[1, k, 2!k2, 3!k3,…]
所以:
ϕ(q)=[1, q, 2!q2, 3!q3,…]
这是无限维特征映射。
多维情况也是类似的,会包含所有阶数的交互特征:
1,qa,qaqb,qaqbqc,…
也就是一阶、二阶、三阶,一直到无限阶。
6. 为什么无限维会导致无法状态化
linear attention 想维护的是有限大小状态:
Si=j=1∑iϕ(kj)vj⊤
如果:
ϕ(kj)∈Rdϕ
那么状态大小是:
Si∈Rdϕ×dv
只要 dϕ 固定,状态大小就不随上下文长度增长。
但如果 softmax kernel 的精确 ϕ 是无限维,那么状态会变成:
Si∈R∞×dv
这在实际计算中不可行。
所以原版 softmax attention 不能被精确地写成可计算的有限维状态方程。
7. Linear Attention 实际做了什么
实际 linear attention 通常有两种路线。
第一种是直接换一个有限维可分解 kernel:
κ(q,k)=ϕ(q)⊤ϕ(k)
其中:
ϕ(q),ϕ(k)∈Rdϕ
这样可以精确得到状态递推:
Si=Si−1+ϕ(ki)vi⊤
zi=zi−1+ϕ(ki)
yi=ϕ(qi)⊤ziϕ(qi)⊤Si
第二种是近似 softmax kernel:
exp(q⊤k)≈ϕ(q)⊤ϕ(k)
其中 ϕ 是有限维的近似特征映射。
这种方法不是和原版 softmax 完全等价,而是在做近似。
8. 核函数和状态方程的关系
可以把整个逻辑压缩成三步。
第一步,选择一个可分解 kernel:
κ(q,k)=ϕ(q)⊤ϕ(k)
第二步,把历史 key-value 聚合成状态:
Si=j=1∑iϕ(kj)vj⊤
zi=j=1∑iϕ(kj)
第三步,用当前 query 从状态中读取:
yi=ϕ(qi)⊤ziϕ(qi)⊤Si
因此,kernel 让 attention 状态化的根本原因是:
它把 query 部分和 history 部分拆开了。
9. 与原版 Softmax 的本质区别
原版 softmax attention:
yi=∑j=1iexp(qi⊤kj)∑j=1iexp(qi⊤kj)vj
它的权重必须在当前 query 出现后,通过当前 query 与所有历史 key 的比较来决定。
可分解 kernel attention:
yi=ϕ(qi)⊤(∑j=1iϕ(kj))ϕ(qi)⊤(∑j=1iϕ(kj)vj⊤)
它可以把所有只和历史有关的部分提前累积。
所以二者的差别是:
softmax attention=query-specific full scan
linear attention=state update + state read
10. 最终总结
简单来说:
exp(q⊤k)
可以被看成 kernel 点积,但精确分解通常需要无限维特征:
exp(q⊤k)=ϕ(q)⊤ϕ(k),ϕ is infinite-dimensional
而 linear attention 需要的是有限维状态:
St∈Rdϕ×dv
所以原版 softmax attention 不能被精确地转化为可计算的有限维状态方程。
linear attention 的做法是:
exp(q⊤k)≈ϕ(q)⊤ϕ(k)
或者直接换用一个有限维可分解 kernel。
这带来的结果是:
St=St−1+ϕ(kt)vt⊤
zt=zt−1+ϕ(kt)
yt=ϕ(qt)⊤ztϕ(qt)⊤St
因此,linear attention 能状态化,不是因为“核函数”这个概念本身神奇,而是因为它使用了有限维可分解的特征映射,把 query-dependent 的相似度计算拆成了:
query part×history part
从而让历史部分可以提前累积成状态。