多頭注意力是深度學習、自然語言處理、Transformer、大語言模型和多模態模型中非常核心的一個術語。它用來描述一種把注意力機制分成多個“注意力頭”,讓模型從不同角度同時理解上下文關系的方法。換句話說,多頭注意力是在回答:模型怎樣同時關注一句話中的多種關系,而不是只用一種注意力模式理解全部內容。
如果說自注意力機制讓每個 Token 可以根據相關性關注其他 Token,那么多頭注意力則進一步讓模型用多個注意力頭分別學習不同類型的關系。例如,一個頭可能關注主謂關系,一個頭可能關注指代關系,一個頭可能關注局部短語搭配,還有一個頭可能關注長距離依賴。
因此,多頭注意力常用于 Transformer、BERT、GPT、大語言模型、機器翻譯、文本生成、圖像 Transformer、多模態模型和擴散模型中的條件建模,是理解現代人工智能模型結構的重要基礎概念。
一、基本概念:什么是多頭注意力
多頭注意力(Multi-Head Attention)是在注意力機制基礎上的擴展。
![]()
圖 1:多頭注意力結構總覽
普通自注意力會根據 Q、K、V 計算一次注意力:
其中:
? Q 表示 Query,查詢
? K 表示 Key,鍵
? V 表示 Value,值
? d_k 表示 Key 向量維度
? softmax 用于得到注意力權重
多頭注意力不是只計算一次注意力,而是把輸入映射到多個不同的子空間中,分別計算多組注意力。
每一組注意力稱為一個注意力頭。
可以簡單理解為:
線性映射得到最終輸出從通俗角度看:單頭注意力像一個讀者只用一種視角讀句子。多頭注意力像多個讀者同時讀同一句話,每個人關注不同線索,最后把意見合并。
例如,對句子:
我喜歡機器學習,因為它能發現數據中的規律。不同注意力頭可能關注:
? “它”和“機器學習”的指代關系
? “發現”和“規律”的動賓關系
? “因為”引導的因果關系
? 相鄰詞之間的局部搭配
這就是多頭注意力的基本思想。
二、為什么需要多頭注意力
多頭注意力之所以重要,是因為語言和圖像中的關系往往不是單一的。
一句話中可能同時存在:
? 語法關系
? 指代關系
? 修飾關系
? 因果關系
? 局部搭配
? 長距離依賴
如果只用一個注意力頭,模型只能形成一套注意力分布。這可能不足以同時表達多種關系。
多頭注意力通過多個頭,讓模型可以并行學習不同的注意力模式。
從通俗角度看:一個老師看作文,可能重點看語法;另一個老師可能重點看邏輯;第三個老師可能重點看用詞。多個視角合在一起,判斷會更全面。
多頭注意力的作用包括:
? 從多個角度建模上下文關系
? 提高模型表達能力
? 讓不同注意力頭學習不同模式
? 支持 Transformer 并行計算
? 增強模型對復雜序列的理解能力
例如,在機器翻譯中,一個注意力頭可能關注詞語對齊,另一個注意力頭可能關注句法結構,另一個注意力頭可能關注長距離依賴。
因此,多頭注意力不是簡單增加計算量,而是讓模型獲得更豐富的表示能力。
三、從單頭注意力到多頭注意力
理解多頭注意力,可以先從單頭注意力開始。
![]()
圖 2:單頭注意力與多頭注意力
1、單頭注意力
單頭注意力只計算一組 Q、K、V。
對于輸入 X,可以先得到:
然后計算:
其中:
? X 表示輸入序列表示
? W_Q、W_K、W_V 是可學習參數
? O 表示注意力輸出
單頭注意力可以讓每個 token 根據上下文更新表示。
但它只有一套注意力權重。
2、多頭注意力
多頭注意力會為每個頭準備不同的線性變換。
第 i 個注意力頭可以寫為:
其中:
? head_i 表示第 i 個注意力頭
? W_i^Q、W_i^K、W_i^V 表示第 i 個頭對應的參數
? i 表示注意力頭編號
如果有 h 個注意力頭,就會得到:
head?, head?, ..., head_h
這些頭分別從不同子空間中計算注意力。
最后,把它們拼接起來:
再經過輸出線性變換:
其中:
? H 表示拼接后的多頭結果
? W_O 表示輸出投影矩陣
? O 表示多頭注意力的最終輸出
完整公式常寫為:
從通俗角度看:多頭注意力先“分頭理解”,再“合并意見”。
四、注意力頭:每個頭到底在學什么
注意力頭不是人工指定功能的模塊,而是模型在訓練過程中自動學習出來的子結構。
每個注意力頭都有自己的一組 Q、K、V 變換矩陣。這意味著不同頭可以把同一個輸入投影到不同特征空間中。
例如,對同一句話:
小明把書放進書包,因為它很重。不同注意力頭可能學習到不同關系:
? 某個頭關注“它”與“書”的指代關系
? 某個頭關注“放進”和“書包”的動作關系
? 某個頭關注相鄰詞之間的局部關系
? 某個頭關注句子整體結構
從通俗角度看:注意力頭像多個觀察角度。同一個 Token,在不同注意力頭中會被用不同方式理解。
需要注意:注意力頭的含義不是訓練前人工規定的。
我們不能說“第 1 個頭一定學語法,第 2 個頭一定學指代”。
更準確地說:不同注意力頭具有學習不同關系模式的能力,但具體學到什么取決于數據、任務和模型訓練結果。
五、多頭注意力中的維度變化
多頭注意力不僅是概念,也涉及張量形狀變化。
假設輸入序列表示為:
其中:
? N 表示 batch size
? L 表示序列長度
? d_model 表示模型隱藏維度
如果有 h 個注意力頭,通常每個頭的維度為:
其中:
? d_head 表示每個注意力頭的維度
? h 表示注意力頭數量
例如:
h = 8那么:
也就是說,每個頭使用 64 維子空間計算注意力。
多頭計算后,h 個頭會被拼接回 d_model 維:
從通俗角度看:多頭注意力不是簡單把總維度無限增大,而是把原來的表示維度拆成多個子空間分別計算,最后再合并回來。
這也是為什么注意力頭數量通常要能整除 d_model。
六、多頭注意力與自注意力、交叉注意力
多頭注意力可以用于不同類型的注意力結構。最常見的是:
? 多頭自注意力
? 多頭交叉注意力
![]()
圖 3:多頭注意力的三種常見形式
1、多頭自注意力
在多頭自注意力中,Q、K、V 都來自同一個序列。
例如,在文本理解中:
輸入句子 → Q、K、V可以寫為:
其中 Q、K、V 都由 X 生成。
這種結構用于讓同一序列內部的 token 互相建模關系。
例如:句子中的每個 token 關注同一句子中的其他 token。
2、多頭交叉注意力
在多頭交叉注意力中,Q 來自一個序列,K 和 V 來自另一個序列。
例如,在編碼器—解碼器結構中:
編碼器輸出結果 → K、V可以理解為:解碼器在生成時,去關注編碼器提供的信息。
在圖文模型中,也可能出現:
圖像特征作為 K、V從通俗角度看:自注意力是“自己內部互相看”,交叉注意力是“一個序列去看另一個序列”。多頭機制可以同時用于這兩類注意力。
七、多頭注意力在 Transformer 中的位置
在 Transformer 中,多頭注意力是核心模塊之一。
一個典型 Transformer 層通常包含:
殘差連接 + LayerNorm在編碼器中,通常使用多頭自注意力。
在解碼器中,常見兩類注意力:
? 帶因果掩碼的多頭自注意力
? 多頭交叉注意力
1、編碼器中的多頭自注意力
編碼器處理完整輸入序列,每個 token 通常可以關注輸入中的所有 token。
例如,在文本理解任務中:
我 / 喜歡 / 機器學習每個 token 都可以參考其他 token。
2、解碼器中的因果多頭自注意力
在生成式模型中,當前位置不能看到未來 token。因此,需要使用因果掩碼。
例如,生成第 3 個 token 時,只能參考第 1、2、3 個 token,不能參考第 4 個 token。
從通俗角度看:因果掩碼防止模型“偷看答案”。
3、解碼器中的多頭交叉注意力
在機器翻譯等編碼器—解碼器模型中,解碼器還會通過交叉注意力關注編碼器輸出。
例如:
目標語言生成 → 解碼器通過交叉注意力參考源句信息這使模型在生成目標語言時能夠不斷查看源語言內容。
八、多頭注意力的優勢、局限與使用注意事項
1、多頭注意力的主要優勢
多頭注意力最大的優勢是可以從多個角度建模上下文。
不同頭可以學習不同關系模式,使模型表達能力更強。
其次,多頭注意力支持并行計算。
它不像 RNN 那樣必須按時間步依次處理,適合在 GPU 上高效訓練。
再次,多頭注意力可以用于多種數據類型。
文本、圖像 patch、語音片段、多模態特征都可以通過多頭注意力建模關系。
從通俗角度看,多頭注意力的優勢在于:讓模型不是只用一種視角理解輸入,而是用多個視角綜合判斷。
2、多頭注意力的主要局限
多頭注意力也有局限。
首先,計算成本較高。
標準注意力需要計算 L × L 的注意力矩陣,序列越長,成本越高。
可以近似理解為:
其中:
? L 表示序列長度
? O(L2) 表示計算量隨序列長度平方級增長
其次,注意力頭數量不是越多越好。
如果頭太少,表達能力可能不足。
如果頭太多,每個頭的維度太小,單個頭的表達能力也可能受限。
再次,注意力頭不一定都學到有用信息。
有些頭可能冗余,有些頭可能關注模式并不清晰。
此外,注意力圖不能被簡單等同于人類解釋。
某個頭關注某個 token,并不一定意味著它就是人類意義上的因果解釋。
3、使用多頭注意力時需要注意的問題
使用多頭注意力時,需要注意:
? d_model 通常要能被注意力頭數量 h 整除
? 每個頭的維度通常是 d_model / h
? 多頭結果需要拼接并經過輸出投影
? 自注意力中 Q、K、V 來自同一序列
? 交叉注意力中 Q 和 K、V 可來自不同序列
? 生成式模型需要因果掩碼
? 長序列會帶來較高計算和顯存成本
? 注意力頭數量需要結合模型規模和任務選擇
從實踐角度看,多頭注意力是 Transformer 的核心,但它的效果并不只由“頭數”決定,還與模型深度、隱藏維度、訓練數據和任務目標有關。
九、Python 示例
下面給出幾個簡單示例,用來幫助理解多頭注意力的基本使用。
示例 1:使用 PyTorch 的 MultiheadAttention
輸出形狀通常為:
注意力權重形狀: torch.Size([2, 5, 5])這個例子中:
? embed_dim = 16
? num_heads = 4
? 每個頭的維度為 16 / 4 = 4
? 輸出形狀仍然保持 batch × seq_len × embed_dim
示例 2:理解每個頭的維度
print("每個頭的維度:", d_head)輸出:
每個頭的維度:64這說明:
也就是說,768 維表示會被拆成 12 個 64 維的注意力頭進行計算。
示例 3:多頭交叉注意力
print("注意力權重形狀:", attn_weights.shape) # (2,6,10) 每個 query 對源序列的注意力分布這里:
? 目標序列長度為 6
? 源序列長度為 10
? 注意力權重形狀為 6 × 10
這表示目標序列中的每個位置,都可以關注源序列中的 10 個位置。
示例 4:帶因果掩碼的多頭自注意力
輸出:
這個例子中:
? 第 1 個 token 只能看自己
? 第 2 個 token 可以看第 1、2 個 token
? 第 3 個 token 可以看第 1、2、3 個 token
? 后面的 token 不能被提前看到
這正是生成式語言模型中常用的因果自注意力思想。
示例 5:查看不同頭的注意力權重
默認情況下,PyTorch 可能返回對所有頭平均后的注意力權重。如果希望查看每個頭的權重,可以設置:
print("每個頭的注意力權重形狀:", attn_weights.shape) 輸出形狀通常為:
每個頭的注意力權重形狀: torch.Size([2, 4, 5, 5])其中:
? 2 表示 batch size
? 4 表示注意力頭數量
? 5 × 5 表示每個頭中的 token 關注矩陣
這可以幫助觀察不同注意力頭是否學到了不同關注模式。
小結
多頭注意力是在注意力機制基礎上引入多個注意力頭,讓模型從多個角度理解上下文關系。每個頭使用不同的 Q、K、V 投影,分別計算注意力,再將結果拼接并線性映射。它是 Transformer 的核心結構之一。對初學者而言,可以把多頭注意力理解為:讓模型同時用多個視角閱讀同一段內容,再綜合這些視角形成更豐富的表示。
“點贊有美意,贊賞是鼓勵”
特別聲明:以上內容(如有圖片或視頻亦包括在內)為自媒體平臺“網易號”用戶上傳并發布,本平臺僅提供信息存儲服務。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.