![]()
![]()
![]()
該工作已被機器學習領域頂級會議 ICML 2026 錄用,論文題目 “PRISM: Parallel Residual Iterative Sequence Model”。
一、背景:從無限背包到有限背包
(一)Transformer 的無限背包與線性注意力的有限背包
![]()
![]()
背包容量有限,每來一個新 token,模型必須決定往里寫什么、同時擦掉什么。這個 "寫與擦" 的規則,決定了有限背包模型的天花板。但在深入討論 "寫與擦" 之前,我們先要回答一個更基本的問題。
(二)有限背包本質上是 RNN,為何還能并行?
確實如此,有限背包模型的數學形式本質上就是 RNN:
![]()
![]()
關鍵在于一個數學技巧:Parallel Scan(并行前綴掃描)。
![]()
![]()
![]()
![]()
(三)為什么并行這么重要?GPU 的 "搬運工" 瓶頸
一個常見的誤解是將 "串行慢" 歸因于更多的浮點運算。實際上,瓶頸在別處。現代 GPU 的計算核心(Tensor Core / CUDA Core)算力極為充沛,A100 GPU 每秒能做 312 萬億次浮點運算(312 TFLOPS)。真正的瓶頸不是 "算",而是 "搬"。
GPU 的存儲分為兩層:
- HBM(High Bandwidth Memory,高帶寬顯存):容量大(40-80 GB),但讀寫速度 "慢"(約 2 TB/s)。模型參數、state 矩陣 S、中間 activation 都存在這里。
- SRAM(片上緩存):容量小(每個 SM 約 192 KB),但讀寫速度極快(約 19 TB/s,快 10 倍)。GPU 的計算核心只能直接訪問 SRAM。
打個比方:SRAM 像工作臺(小但觸手可及),HBM 像倉庫(大但每次取貨要走一趟)。
所以每一次計算都要經歷一個 "搬運" 流程:把數據從 HBM 搬進 SRAM,在 SRAM 里算完,再把結果搬回 HBM。這個搬運的時間往往遠超計算本身,這就是所謂的 memory-bound(存儲帶寬瓶頸)。
![]()
![]()
能否適配parallel scan 不僅是算法設計上的美學選擇,更直接決定了 10-100 倍的實際運行速度差異。
(四)Rank-1 寫入的瓶頸
以 GDN (Gated DeltaNet)為代表的線性注意力模型,每個 token 對 S 做的是一次 rank-1 更新:
![]()
![]()
如果一個 token 攜帶的語義是多維度的(它同時是某個句法結構的成分、某個語義角色的載體、某個 topic 的關鍵詞),rank-1 的一行寫入無法同時在這些維度上做精細調整。信息在壓縮寫入時不可避免地丟失。
核心矛盾:背包有限,每次卻只允許寫一行。這是當前所有線性復雜度模型的共有瓶頸。
(五)TTT 的突破與代價
既然 rank-1 寫入太淺,一個自然的想法是:讓模型學會更深的寫入規則。
TTT(Test-Time Training)系列工作采取了一種根本性不同的策略:把記憶狀態從一個 linear 矩陣 S 升級為一個 MLP 的權重矩陣。每來一個 token,對 MLP 的權重做多步梯度下降(multi-step GD),逐步精煉寫入內容。這帶來了顯著的質量提升。
![]()
![]()
二、分析:TTT-MLP 為什么效果好,但速度慢?
在設計 PRISM 之前,我們首先深入分析 TTT-MLP 的梯度結構,弄清楚它的高表達力到底從何而來。
(一)步長 × 殘差 × 方向 模式的涌現
![]()
每步更新具有一個結構模式:
![]()
TTT-MLP 的高表達力正來自這個 步長 × 殘差 × 方向 模式:多步殘差遞減提供了優化深度(depth),W? 多行提供多個方向則提供了表達寬度(width /rank-L)(即同時修改 S 矩陣的 L 個獨立維度)。
(二)高表達力與串行是同一根因的兩面
![]()
具體來說,它造成了兩個維度的串行瓶頸:
1. Token 間串行(Inter-token Seriality)
![]()
![]()
2. Step 間串行(Intra-step Seriality)
瓶頸 C(方向與殘差的同步):在多步 GD 中,第 l+1 步的寫入方向必須等待第 l 步的權重更新完畢才能確定,殘差也必須等上一步算完才能得到,強制引入一個無法展開的循環。
瓶頸 C 是最核心的矛盾:它同時是 rank-L 表達力的載體和步間串行的根源。因此消除瓶頸 C 不能簡單取消迭代,必須在取消同步耦合的同時保留多方向和殘差遞減帶來的表達力。
三、方法:PRISM 的設計與實現
基于上述分析,PRISM 的策略非常明確:在兼容 parallel scan 的線性狀態 S 上顯式重建 TTT-MLP 的 步長 × 殘差 × 方向 模式,然后分維度消除串行。
(一)核心迭代形式:步長 × 殘差 × 方向
PRISM 顯式構造了 TTT-MLP 的多步迭代模式:
![]()
![]()
與 TTT-MLP 的對應關系:
![]()
![]()
(二)消除 Token 間串行:A/B 分離 + 局部 Anchor 代理
![]()
![]()
至此,序列級別的 parallel scan 已完全恢復。anchor 讓不同 token 的迭代可以同時啟動,但每個 token 內部的 L 步之間仍需順序執行(瓶頸 C)。
(三)消除 Step 間串行:解耦鏈 + 閉合式預計算
解決瓶頸 C。因為有了 anchor,兩條鏈自然解耦:
![]()
![]()
![]()
由此多步迭代推算得到閉合式:
![]()
L 步的串行循環被消解為單步閉合式計算。整個多步梯度下降計算過程可以編譯成一個 fused kernel,數據只需要從 HBM 搬進 SRAM 一次。
(四)架構全貌與 GDN 退化
多步梯度下降計算過程的原始產出是 L 個 rank-1 迭代計算:
![]()
![]()
![]()
PRISM 可以視為一種多步殘差擬合計算過程,L=1 時精確退化為 GDN。 后續步只是在第一步的基礎上追加非線性修正,且可以使用 low rank 網絡增量,額外參數量不超過基礎模型的 10%。
四、實驗結果
(一)序列推薦
在公開序列推薦基準 Amazon 上,PRISM 表現與 Transformer baseline 效果接近,超過大多數線性注意力類方法。計算效率方面,PRISM 與 GDN 同級,比 TTT-MLP 快 174 倍。
![]()
(二)語言建模(基于 SlimPajama 2B 訓練,130M 參數)
在更大規模的語言建模實驗上(SlimPajama 2B tokens, Mistral tokenizer),PRISM 同樣取得了全面領先:
![]()
PRISM 在 WikiText PPL、LAMBADA PPL 和 9 項 Zero-Shot 下游任務平均準確率上均為最優,領先 GDN 3.2 個百分點。
(三)組件消融
![]()
訓練 PPL 差異極小,但下游泛化差異巨大。單步 solver (L=1) 的訓練 PPL 幾乎等于完整版,但 Avg ACC 下跌 2.9 個百分點 ——rank-L 的真正價值不在 next-token prediction 上,而在需要精確長程檢索的下游任務上。
![]()
五、延伸思考
(一)有限背包終究有限,混合架構也許是必然
![]()
從 PRISM 的視角看,這個直覺有一個很好的技術解釋。PRISM 用短卷積(ShortConv)計算的局部 anchor 替代全局狀態 S 來近似殘差。由于短卷積窗口通常只覆蓋最近 3-4 個 token,對于需要跨越數千步的長程依賴,近似質量必然下降。
如果在 PRISM 層之間穿插少量 Transformer 層,后者就充當了一種全局的、非線性的歷史狀態精確計算器,能補償 anchor 在長程上的近似誤差。從這個角度看,Transformer 本身就是 ShortConv anchor 的 "全局升級版":ShortConv 用固定窗口的局部卷積近似歷史狀態,Transformer 用全局 attention 精確算歷史狀態。
![]()
(二)線性注意力的 LoRA?
PRISM 的最終形式有一個有趣的結構特征:
![]()
這個 "基礎迭代過程 + low rank 旁路" 的形式,跟 LoRA(Low-Rank Adaptation) 非常相似,這啟發了一個微調場景下的有趣思路。
LoRA 的核心思想是:凍結預訓練好的大模型權重,只在關鍵層旁邊加一條 low-rank 旁路來做微調。受 PRISM 形式的啟發,我們可以設想一種面向 Linear Attention / SSM 模型的參數高效微調方法:對已訓練好的模型,凍結基礎迭代過程,只在寫入支路上增加一條 PRISM 風格的殘差擬合旁路,此外,這條旁路有閉合式(不增加訓練時間),而且第一步退化為原模型的標準寫入(不破壞預訓練知識)。這意味著它滿足 LoRA 的兩個關鍵要求:參數高效和不損害原模型能力。
結語
PRISM 驗證了 "寫入前思考" 范式在線性注意力模型中的可行性:通過分析 TTT-MLP 的梯度結構揭示 步長 × 殘差 × 方向 迭代模式,在線性狀態上顯式重建該模式并通過 anchor 代理和閉合式預計算實現完全并行。最終架構極簡 ——GDN + 非線性旁路,訓練速度與 GDN 同級,參數增量不到 10%。在推薦和語言建模兩個場景上的驗證表明,這是一項通用的線性注意力增強技術。未來我們將進一步探索 PRISM 在更大參數規模上的 scaling 行為和推薦系統上的應用效果,以及其作為線性注意力模型參數高效微調方法的實際效果。
參考文獻:
[1] Sun et al. “Learning to (Learn at Test Time): RNNs with Expressive Hidden States.” NeurIPS 2024.
[2] Yang et al. “Gated Delta Networks with Pairwise Tokenized Graphs.” NeurIPS 2024.
[3] Katharopoulos et al. “Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention.” ICML 2020.
特別聲明:以上內容(如有圖片或視頻亦包括在內)為自媒體平臺“網易號”用戶上傳并發布,本平臺僅提供信息存儲服務。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.