![]()
機(jī)器之心編輯部
5 月 22 日,Tri Dao 在社交媒體上轉(zhuǎn)發(fā)了 Han Guo 的一條推文。他還寫道:「經(jīng)過一些數(shù)學(xué)重寫,結(jié)果發(fā)現(xiàn) Transformer 的所有內(nèi)容都是一系列 GEMM + epilogue(矩陣乘法加尾聲)。給定一些優(yōu)化的原語,LLM(以及新手)就可以為所有 Transformer 操作編寫光速內(nèi)核!」
![]()
Tri Dao 是 FlashAttention 系列的核心作者之一,而這條推文則指向了他們當(dāng)天發(fā)布的一篇論文:CODA
![]()
- 論文標(biāo)題:CODA: Rewriting Transformer Blocks as GEMM-Epilogue Programs
- 論文地址:https://arxiv.org/abs/2605.19269
- 代碼地址:https://github.com/HanGuo97/coda-kernels
這個(gè)名字,讀起來像「終曲」,念起來像「CUDA」。來自 MIT、普林斯頓、Together AI 和 Meta 的研究者,試圖用一套新的編程抽象,把 Transformer 訓(xùn)練里那些鮮少被人關(guān)注、卻持續(xù)消耗時(shí)間的「散碎計(jì)算」,系統(tǒng)性地消化掉。
背景:訓(xùn)練大模型的「偷懶稅」
要理解 CODA 在解決什么問題,先要明白大模型訓(xùn)練的時(shí)間都去哪了。
在一塊英偉達(dá) H100 上訓(xùn)練一個(gè) LLaMA-3 風(fēng)格的 1B 參數(shù)模型,大部分人會(huì)直覺地認(rèn)為:時(shí)間都花在矩陣乘法和注意力計(jì)算上,畢竟那才是「真正的計(jì)算」。這個(gè)直覺大體上沒錯(cuò):矩陣乘法(GEMM)和注意力確實(shí)占據(jù)了主要算力
![]()
但如果你打開性能分析器仔細(xì)看,會(huì)發(fā)現(xiàn)還有一批「小算子」在安靜地消耗著時(shí)間:歸一化(RMSNorm)、激活函數(shù)(SwiGLU、RoPE)、殘差加法、跨層規(guī)約……它們單個(gè)計(jì)算量不大,卻頻繁地把大型中間張量從顯存里搬進(jìn)搬出。
![]()
這就是所謂的「內(nèi)存帶寬瓶頸」:好比一個(gè)廚藝絕頂?shù)膹N師,但每做一道菜都要把食材從遠(yuǎn)處的倉庫搬來、用完再送回去,而不是放在手邊的臺(tái)面上。廚師的手速再快,等待搬運(yùn)的時(shí)間也是真實(shí)的浪費(fèi)。
更糟糕的是,隨著英偉達(dá)的 FP8、FP4 等低精度格式讓矩陣計(jì)算越來越快,這些「搬運(yùn)」操作的相對成本反而在上升:矩陣乘法加速了,但張量搬進(jìn)搬出的成本并沒有同比縮短。
論文中有一組數(shù)據(jù)很直觀:在 H100 上用 TorchTitan 訓(xùn)練 1B 參數(shù)模型時(shí),非矩陣乘法操作占據(jù)了相當(dāng)一部分的端到端運(yùn)行時(shí)間,且隨著 FP8 精度的引入,這一比例還會(huì)進(jìn)一步凸顯。
現(xiàn)有的編程框架對此幾乎無能為力。PyTorch 把 Transformer 的計(jì)算表達(dá)成一串算子序列,算子之間有清晰的邊界。這種邊界對于自動(dòng)微分(autograd)非常友好,卻恰好阻止了跨算子的融合優(yōu)化:每一個(gè)算子邊界,往往就是一次不必要的顯存寫回。
CODA:「尾聲」里藏著寶藏
CODA 的出發(fā)點(diǎn)是一個(gè)樸素的觀察。
在 GPU 上,一個(gè)高性能的矩陣乘法(GEMM)內(nèi)核在結(jié)構(gòu)上分為兩個(gè)部分:主循環(huán)(mainloop)負(fù)責(zé)核心的矩陣分塊乘加計(jì)算,尾聲(epilogue)負(fù)責(zé)在結(jié)果寫回顯存之前做一些收尾處理,比如加偏置、類型轉(zhuǎn)換、簡單縮放。
![]()
尾聲存在的意義,在于此時(shí)矩陣乘法的輸出還「活在」片上寄存器里,還沒有落地到全局顯存。這是一個(gè)短暫的黃金窗口:如果能在這個(gè)時(shí)刻多做一些計(jì)算,就可以完全省掉一次顯存寫入再讀出的往返。
CODA 的核心洞察是:Transformer 里那些內(nèi)存密集型操作,其實(shí)很多可以被代數(shù)地重新參數(shù)化,塞進(jìn)這個(gè)「尾聲」窗口里執(zhí)行。
這需要一點(diǎn)數(shù)學(xué)技巧。以最常見的 GEMM-RMSNorm-GEMM 模式為例:一個(gè)矩陣乘法的結(jié)果,經(jīng)過殘差加法、RMS 歸一化,然后再做另一個(gè)矩陣乘法。傳統(tǒng)做法是三個(gè)獨(dú)立算子串行執(zhí)行,中間結(jié)果兩次落地顯存。
![]()
CODA 團(tuán)隊(duì)發(fā)現(xiàn),RMS 歸一化中的行縮放因子 r,因?yàn)槭敲啃泄蚕淼臉?biāo)量,它和后面的矩陣乘法滿足交換律:可以把 r 的應(yīng)用從「第二個(gè) GEMM 之前」推遲到「第二個(gè) GEMM 的尾聲」。推遲之后,第一個(gè) GEMM 的尾聲只需要計(jì)算局部的「分塊均方根」(partial RMS),由一個(gè)極輕量的輔助規(guī)約內(nèi)核合并,而完整的 RMSNorm 計(jì)算消失了。
類似的重新參數(shù)化,對 SwiGLU、RoPE(旋轉(zhuǎn)位置編碼)、交叉熵?fù)p失等操作同樣適用,甚至對反向傳播也成立。論文中有一個(gè)定理證明:只要前向尾聲是「分塊局部」的,反向傳播就自動(dòng)繼承相同的結(jié)構(gòu)。具體請?jiān)L問原論文查看。
五種「積木」和一套「樂高語言」
CODA 不是一個(gè)具體的融合內(nèi)核,而是一套編程抽象。
它固定住經(jīng)過專家優(yōu)化的 GEMM 主循環(huán),然后在尾聲位置暴露五類可組合的基本原語:
- 逐元素變換(residual 加法、激活函數(shù)、RoPE)
- 向量加載與存儲(chǔ)(廣播 RMSNorm 權(quán)重)
- 矩陣分塊加載與存儲(chǔ)(保存中間激活供反向傳播使用)
- 分塊規(guī)約(局部均方根、分塊 log-sum-exp)
- 有狀態(tài)變換(在線歸一化所需的 max 和 sum-exp 統(tǒng)計(jì))
用這五類積木,一個(gè)標(biāo)準(zhǔn) Transformer 的前向和反向傳播中、除注意力之外的幾乎全部操作都可以被覆蓋。
更有意思的是這套抽象對「誰來寫代碼」的寬容度。論文在實(shí)驗(yàn)中評估了兩種實(shí)現(xiàn)模式:一種是人工程序員撰寫,另一種是用 Claude Code 來生成 —— 給定 CODA 的原語說明、若干示例和實(shí)現(xiàn)日志,由 AI 完成大部分內(nèi)核代碼,人工輕度監(jiān)督。
兩種模式的性能表現(xiàn)均達(dá)到了較高水平。Tri Dao 在推文中說「LLM 以及新手就可以編寫光速內(nèi)核」,這正是論文實(shí)驗(yàn)結(jié)果在現(xiàn)實(shí)層面的映射。
實(shí)驗(yàn)結(jié)果
CODA 的基準(zhǔn)測試選擇的是較為苛刻的對手:cuBLAS 加上 torch.compile,以及專為 LLM 優(yōu)化的 Liger Kernel 和 FlashInfer。
論文對每個(gè)內(nèi)核評估了兩種實(shí)現(xiàn):CODA (LLM)由 Claude Code 生成,研究者提供原語說明、若干示例和一份持續(xù)更新的實(shí)現(xiàn)技巧日志,AI 完成主體代碼,人工做輕度監(jiān)督;CODA (Human)由人工程序員獨(dú)立編寫,使用同樣的高層重參數(shù)化思路,但不依賴 CODA 原語集本身。兩組結(jié)果都與 cuBLAS + torch.compile、Liger Kernel、FlashInfer 等優(yōu)化庫進(jìn)行對比。
在單算子層面,以 GEMM-RMSNorm-GEMM 這一典型模式為例,CODA 在對應(yīng) 1B、7B、70B 三個(gè)模型規(guī)模的隱藏維度下均實(shí)現(xiàn)了對 cuBLAS + PyTorch 基線的超越。SwiGLU、RoPE、交叉熵等尾聲組合也有類似表現(xiàn)。
LLM 生成的內(nèi)核在大多數(shù)基準(zhǔn)上與人工手寫版本不相上下,個(gè)別配置下甚至略有超越。這在 GPU 內(nèi)核優(yōu)化這個(gè)歷來門檻極高的領(lǐng)域,是一個(gè)頗為罕見的結(jié)論。
![]()
![]()
![]()
反向傳播的收益尤為突出:GEMM-Residual-PartialRMS-GEMM 的反向內(nèi)核相比基線加速幅度可達(dá) 1.6 至 1.8 倍,SwiGLU 反向也有約 1.4 至 1.6 倍的提升。這個(gè)方向上,LLM 與人工實(shí)現(xiàn)的差距同樣微小。這并不奇怪:反向傳播天然涉及更多中間張量的存取,尾聲融合的收益就更大;而 CODA 的原語設(shè)計(jì)足夠清晰,使得 AI 模型能夠正確地完成組合。
![]()
在完整 Transformer 層的端到端基準(zhǔn)中,CODA 的前向加速在不同規(guī)模下約為 5% 至 20%,在較大模型尺寸(對應(yīng) 70B 規(guī)模的隱藏維度)下效果更為顯著。
數(shù)值精度方面,CODA 的重參數(shù)化調(diào)整了 RMSNorm 縮放因子的應(yīng)用時(shí)機(jī),但實(shí)驗(yàn)表明其數(shù)值誤差與 PyTorch 參考實(shí)現(xiàn)相當(dāng),在某些配置下誤差甚至更小 —— 得益于 GEMM 主循環(huán)本身具有更高精度的累加器。
CODA 能做什么:一張速查單
在進(jìn)入更大的視角之前,先把 CODA 的能力邊界說清楚。
- 覆蓋范圍:標(biāo)準(zhǔn) Transformer(如 LLaMA 架構(gòu))的前向和反向傳播中,除注意力和詞嵌入之外的幾乎全部計(jì)算,包括 RMSNorm、殘差加法、SwiGLU 激活、RoPE 旋轉(zhuǎn)位置編碼、交叉熵?fù)p失,以及上述操作的反向梯度計(jì)算。
- 加速效果:在對應(yīng) 1B 至 70B 規(guī)模的隱藏維度下,單算子層面相比 cuBLAS + torch.compile 基線有不同程度的提升,其中反向傳播收益最為顯著(部分內(nèi)核可達(dá) 1.6 倍以上);完整 Transformer 層的端到端前向加速約為 5% 至 20%,在較大模型尺寸下效果更突出。
- 誰能用:CODA 基于 CuTeDSL(NVIDIA CUTLASS 的 Python DSL)實(shí)現(xiàn),支持人工程序員和 AI 模型兩種內(nèi)核編寫方式,且兩種方式均能達(dá)到高性能。
- 當(dāng)前限制:目前僅支持單 GPU 場景,不涉及分布式訓(xùn)練;重參數(shù)化主要針對標(biāo)準(zhǔn) Transformer 架構(gòu),其他架構(gòu)的適用性有待驗(yàn)證。
結(jié)語
CODA 并非孤立的工作。它是一類思想的具體實(shí)現(xiàn):在 GPU 上,真正的優(yōu)化空間往往不在「算什么」,而在「怎么搬」。
FlashAttention 讓注意力計(jì)算「住進(jìn)」了片上內(nèi)存,CODA 試圖讓歸一化和激活函數(shù)也「住進(jìn)去」。Triton 降低了寫自定義內(nèi)核的門檻,ThunderKittens、TileLang 等進(jìn)一步在不同層次上探索這一空間。這些工作共同指向同一個(gè)方向:把 PyTorch 算子圖的表達(dá)便利性,與接近手寫 CUDA 的執(zhí)行效率,真正統(tǒng)一在一套可編程的框架里。
Tri Dao 推文的最后一句話值得再回味:「LLM 以及新手就可以為所有 Transformer 操作編寫光速內(nèi)核。」這背后有一個(gè)更深的邏輯:當(dāng)編程抽象設(shè)計(jì)得足夠好,AI 模型本身就可以參與到自身訓(xùn)練基礎(chǔ)設(shè)施的優(yōu)化中。這個(gè)循環(huán),才是 CODA 最耐人尋味的地方。
從這個(gè)角度看,「CODA」這個(gè)名字或許另有深意。在古典音樂中,Coda 是樂曲末尾收束全篇的段落。在這里,它是 GEMM 內(nèi)核的「尾聲」—— 而寫好這段尾聲,或許正是 Transformer 訓(xùn)練系統(tǒng)效率提升的下一個(gè)重要章節(jié)。
特別聲明:以上內(nèi)容(如有圖片或視頻亦包括在內(nèi))為自媒體平臺(tái)“網(wǎng)易號”用戶上傳并發(fā)布,本平臺(tái)僅提供信息存儲(chǔ)服務(wù)。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.