5488 words
27 minutes
交叉注意力概述&科研工作推進

一、前言#

因為將 MoE(Mixture-of-Experts)不符合我們的論文想法,MoE最大的特點是 Top-K 的設計產生的稀疏計算,但是我們認為所有的專家都是有用的,即使是讓 Top-K失效,使用 softmax的做法,也已經讓整個架構混亂,所以我想到了使用交叉注意力(Cross-Attention),作為我們新的方向。

二、概念#

2.1 一切的基礎:注意力機制是什麼?#

attention.png

TIP

Scaled Dot-Product Attention(縮放點積注意力):輸入包括維度為dkd_k的查詢(queries)和鍵(keys),以及維度為dvd_v的值(values)。我們計算查詢與所有鍵的點積,每個點積結果都除以dk\sqrt{d_k},然後應用softmax函數,以得到注意力分數。

在我們深入探討之前,先用一句話理解「注意力機制」的本質。

  • 核心思想: 人類在觀察事物時,並不會均等地關注所有細節,而是會將「注意力」集中在最關鍵的部分。機器學習中的注意力機制,就是模仿這種行為,讓模型在處理數據時,能夠動態地為輸入的不同部分分配不同的「重要性」權重
  • 口語化比喻: 想像在一個嘈雜的雞尾酒會上,雖然身邊有幾十個人在說話,但可以集中注意力,只聽清你面前朋友的聲音,而忽略其他噪音。你的「聽覺注意力」幫助你從海量信息中篩選出了最重要的部分。

2.2 自注意力 (Self-Attention) - 模型的「內部反思」#

Self-Attention.png

這是理解交叉注意力的基礎,也是標準Transformer(如BERT、GPT)的核心。

  • 背景與要解決的問題: 在自注意力出現之前,處理像「The cat sat on the mat」這樣的句子,模型很難理解詞與詞之間的遠距離依賴關係。例如,「sat」(坐)這個動作的執行者是「cat」,地點是「mat」。如何讓「sat」這個詞在被處理時,能「意識到」它與「cat」和「mat」的強關聯呢?

  • 核心思想: 讓一個序列「自己對自己」進行注意力計算。 序列中的每一個元素(token),都會回頭審視序列中的所有其他元素(包括自己),並判斷「誰對我理解自己最重要?」,然後從那些重要的元素身上「吸收」信息。

  • 技術實現: 在自注意力中,一個序列X同時扮演了三種角色。它的查詢(Query)、鍵(Key)、值(Value)都來源於自身:

    • Q = W_q * X
    • K = W_k * X
    • V = W_v * X

2.3 交叉注意力 (Cross-Attention) - 模型的「跨界訪談」#

Cross-Attention.png

這正是我們設計「交叉注意力融合先驗」時所使用的核心技術。

  • 背景與要解決的問題: 自注意力完美地解決了單一信息流(如一段文本)的內部理解問題。但如果我們有兩種不同來源的信息,並希望它們之間產生交互呢?

    • 經典例子:看圖說話(Image Captioning)。 模型需要根據一張圖像(來源A),生成一段文字描述(來源B)。當模型要生成「貓」這個詞時,它如何知道應該去關注圖像中的那隻貓,而不是旁邊的沙發?自注意力只能讓文字關注文字,或者讓圖像像素關注圖像像素,無法實現這種跨模態的對齊。
    • 我們的問題: 我們有多個專家(來源A, B, C…)給出的去噪預測結果,還有一個帶噪聲的原始圖像x_t(來源X)。我們如何讓x_t中的每個像素,去「參考」所有專家的預測結果,並智能地融合它們?
  • 核心思想: 讓一個序列作為「查詢方」,去「審視」和「借用」另一個完全不同的序列的信息。 實現了兩組信息之間的定向查詢和信息提取。

  • 口語化比喻:「記者採訪專家團」

    • 查詢序列 (Query Sequence): 扮演記者的角色。在我們的算法中,就是帶噪聲的圖像x_t。它代表著一系列「問題」——「我這個像素點到底應該是什麼樣子?」
    • 鍵/值序列 (Key/Value Sequence): 扮演被採訪的專家團的角色。在我們的算法中,就是所有N個專家的去噪預測結果。它們代表著一系列「答案」或「觀點」。
    • 採訪過程:
      1. 記者(x_t中的某個像素)帶著自己的問題(Query)走向專家團。
      2. 他快速瀏覽每位專家的名牌和簡介(Key),判斷哪幾位專家可能與自己的問題最相關。
      3. 然後,他會重點採訪這幾位最相關的專家,聽取他們的詳細回答(Value)。
      4. 最後,記者將採訪到的所有信息,根據他對專家們的信任程度進行加權總結,形成自己的最終報導(融合後的像素值)。
  • 技術實現: 在交叉注意力中,Query來自一個序列A,而Key和Value來自另一個序列B:

    • Q = W_q * A (例如,A是x_t)
    • K = W_k * B (例如,B是所有專家的輸出)
    • V = W_v * B
TIP

交叉注意力是一種在一些現代自然語言處理(NLP)任務的架構中使用的機制,如 Transformer 模型。交叉注意力的思想是使一個序列能夠“關注”另一個序列。在許多場景中,這可能很有用,例如在機器翻譯中,將輸入序列(源語言)的部分與輸出序列(目標語言)的部分對齊是有益的。

交叉注意力的機制與 Transformer 模型中使用的自注意力機制非常相似,但是交叉注意力是一個序列關注另一個序列,而不是自己。

  • Self-Attention:我關注我自己
  • Cross Attention:我關注另一個人

2.4 多頭注意力(Multi-Head Attention)-模型的「專家會診」#

Multi-Head_Attention.png

2.4.1 背景與要解決的問題#

我們之前用「團隊內部會議」比喻了自注意力。在這個會議上,每個詞元(團隊成員)都會去徵求其他所有人的意見。

但這裡存在一個潛在的問題:如果只進行一次籠統的意見徵求,模型可能會**「抓不住重點」**。

  • 單頭注意力的局限性: 一次自注意力計算(單頭)就像讓模型從一個單一的、綜合的角度去理解詞與詞之間的關係。它可能會學會一種「平均的」、「最常見的」關聯模式。例如,在處理「The cat chased the mouse」時,它可能很好地學會了「cat」和「chased」以及「chased」和「mouse」之間的動作關係。但如果句子變得更複雜,比如「The cat, which was quick, chased the mouse that stole the cheese」,單一的注意力可能很難同時捕捉到「cat -> was quick」的描述關係,「mouse -> stole cheese」的從句關係,以及「chased」的動作關係。

  • 問題: 如何讓模型能夠在一次計算中,同時從多個不同的角度、不同的維度去審視和理解序列內部的複雜關係?

2.4.2 核心思想:「分而治之,再行綜合」#

多頭注意力的思想非常優雅,它採用了「分而治之」的策略。與其讓一個「全科醫生」做一次全面的、籠統的體檢,不如請來一組「專科醫生」,每個人都從自己最擅長的領域進行檢查,最後再把所有專家的報告匯總起來,形成一份全面的體檢報告。

  • 工作流程:
    1. 分頭 (Splitting into Heads): 模型不會直接使用一組大的QKV(查詢、鍵、值)矩陣,而是將它們的維度切分成多個(例如8個或12個)更小的、獨立的「頭」。每一組小的QKV就代表一個「專科醫生」。
    2. 並行注意力計算 (Parallel Attention): 這8個「頭」會並行地、獨立地進行各自的自注意力計算。
    3. 各有所長 (Learning Different Aspects): 由於每個頭的權重是獨立初始化的,它們在訓練過程中會逐漸學會關注不同方面的關係。
      • 頭1 可能變成了「語法專家」,專門關注主謂賓結構。
      • 頭2 可能變成了「距離專家」,專門關注相鄰詞元之間的關係。
      • 頭3 可能變成了「語義專家」,專門關注同義詞或反義詞之間的關係。
      • 頭4 可能變成了「指代專家」,專門關注代詞it到底指向了前面哪個名詞。
    4. 綜合結果 (Concatenation & Projection): 當所有頭都完成了自己的注意力計算後,它們會得到各自的輸出結果。模型會將這8個結果拼接在一起,然後再通過一個最終的線性變換,將它們融合成一個統一的、富含多維度信息的最終輸出。
TIP

關鍵點:多頭,不是一種新的注意力類型,而是對基礎注意力(無論是自注意力還是交叉注意力)的一種「實現方式」或「性能升級包」。

2.4.3 口語化比喻:「團隊分組討論」#

如果說自注意力是「團隊內部會議」,那麼多頭自注意力就是對這個會議的升級:

  • 舊模式(單頭): 所有團隊成員在一個大圓桌上七嘴八舌地討論,信息混雜。
  • 新模式(多頭): 會議主持人(模型)說:「我們現在分成幾個小組來討論!A組專門討論技術細節,B組專門討論市場策略,C組專門討論時間規劃…」
    • 每個小組(一個「頭」)內部都會進行充分的討論(一次獨立的自注意力計算)。
    • 討論結束後,每個小組的組長會上台匯報自己小組的結論(一個「頭」的輸出)。
    • 最後,主持人將所有小組的結論匯總起來,形成一個全面的、考慮周詳的最終決策(多頭注意力的最終輸出)。

2.4.4 誤解:多頭注意力依然是選取部分專家嗎#

多頭注意力 (Multi-Head Attention) 並不是在「選取部分專家」。它會利用所有專家的信息。

「選取部分專家」這個動作,是屬於我們之前討論的稀疏混合專家 (MoE) / Top-K路由的專屬職責。而我們要設計的、更強大的「交叉注意力融合網絡」,其核心思想正是拋棄了「N選K」的硬性選擇,轉而讓模型能夠智能地、柔性地利用所有專家的信息

三、 思想比較#

特性自注意力 (Self-Attention)交叉注意力 (Cross-Attention)多頭自注意力 (Multi-Head Self-Attention)我們的MoE (Top-K Routing)
核心思想內部反思 / 自我關聯跨界訪談 / 信息查詢專家會診 / 分組內部反思專家投票 / 稀疏選擇
信息來源單一序列兩個不同的序列單一序列一個輸入序列,N個專家網絡
Q, K, V 來源Q, K, V 全部來自同一序列Q來自序列A,K和V來自序列BQ, K, V 全部來自同一序列,但被切分成多頭並行計算Q來自輸入,K來自輸入,V來自被選中的K個專家
解決的問題理解單一序列內部的上下文和長程依賴對齊和融合兩種不同來源的信息多個子空間中,同時捕捉不同類型的上下文關聯在保持計算稀疏性的前提下,動態選擇最相關的專家
典型應用Transformer編碼器、GPT模型圖像描述生成、多模態融合、我們設計的專家融合網絡現代所有Transformer的編碼器 (BERT, ViT) 和解碼器 (GPT)Switch Transformer, Mixtral
計算模式密集計算 (序列內所有元素交互)密集計算 (兩個序列間所有元素交互)密集計算稀疏計算 (只計算K個專家)

四、反思當前科研工作#

4.1 從MoE轉移到交叉注意力#

現階段的唯一目標是追求極致的正確率,而不考慮計算效率,那麼傳統的、為了節省計算而設計的稀疏MoE確實不是最佳選擇。

  • 所有先驗都有用: 每個預訓練的先驗模型都包含了獨特的知識,理想情況下我們希望能全部利用,而不是只做「N選K」的選擇題。
  • 單純係數權重的局限性: w1E1+w2E2...w1*E1 + w2*E2... 這種對整個圖像使用同一組權重的加權平均,確實不夠智能。它無法處理您比喻的那種情況——「爸爸的鼻子」和「媽媽的眼睛」,即一個先驗在圖像的A區域表現好,另一個先驗在B區域表現好。

核心思想

我們可以將這個問題重新定義為:

對於帶噪聲的圖像 xtx_t 中的每一個像素(或每一個小區域),我們如何「查詢」所有專家(先驗)在該像素上的預測結果,並根據 xtx_t 自身的特徵,智能地決定如何將這些專家的預測結果「融合」成最終的預測?

**交叉注意力(Cross-Attention)**正是解決這個問題的完美工具。

  • 查詢 (Query): 來自帶噪聲的圖像xtx_t。它代表了「我這個位置需要什麼樣的先驗信息來去噪?」
  • 鍵 (Key) 和 值 (Value): 來自所有N個專家的去噪預測結果 E1(xt),E2(xt),...,EN(xt){E_1(x_t), E_2(x_t), ..., E_N(x_t)}。它們代表了「我們每個專家分別提供了什麼樣的去噪方案(Value),以及這個方案的特徵是什麼(Key)」。

通過計算xtx_t的Query和所有專家輸出的Key之間的相似度,模型可以為每個像素動態地生成一組融合權重,然後用這組權重去加權求和所有專家的Value。

口語化解釋: 這相當於,對於圖像中「鼻子」區域的每一個像素,模型會看一眼所有專家的預測結果,然後說:「根據我看到的噪聲情況,專家A(爸爸)關於鼻子的預測看起來最靠譜,專家B(媽媽)的次之,專家C的完全不相關。」 於是,它可能會給出一個融合權重 [0.8,0.15,0.05][0.8, 0.15, 0.05]。而在處理「眼睛」區域的像素時,它可能會給出完全不同的權重 [0.1,0.8,0.1][0.1, 0.8, 0.1]。這就實現了想要的「各取所需」。

4.2 更進一步—多頭交叉注意力(Multi-Head Cross-Attention)#

我們是否可以在交叉注意力的基礎上再進一步演進成多頭交叉?

這個演進過程就是:

  1. 我們首先確定了我們的任務需要交叉注意力,因為我們有兩個不同的信息源:一個是**「問題」(帶噪聲的圖像x_t),另一個是「潛在答案庫」**(所有專家的輸出)。
  2. 然後,為了讓我們的「提問-回答」過程更全面、更強大,我們決定不只進行一次籠統的交叉注意力計算,而是採用多頭機制來實現它。
  3. 於是,多頭交叉注意力誕生了。它讓x_t中的每個像素,能夠同時從多個不同的角度(比如從「結構」角度、從「紋理」角度、從「平滑度」角度)去「查詢」所有專家的輸出,並為每個角度都生成一組融合權重,最後再將所有角度的結論綜合起來。
TIP

想像一下,你是CEO,手下有8位專家顧問(您的8個先驗模型)。現在遇到了一個複雜問題(一張帶噪聲的圖像x_t),你需要集合團隊的智慧來解決。

多頭交叉注意力融合

  • 工作流程:
    1. 全員參與 (執行所有專家): 您邀請所有8位專家全部就位,並讓他們每個人都對當前的問題(x_t)給出一份完整的分析報告(E_i(x_t))。
    2. 分組會診 (多頭機制): 您不是只有一位決策者,而是有一個由**4位副總(4個「頭」)**組成的決策委員會。
    3. 獨立審閱 (交叉注意力):
      • 每一位副總(每一個「頭」)都會拿到所有8位專家的分析報告。
      • 副總A(頭1)可能從「技術可行性」角度審閱,他可能會發現專家1和專家5的報告最有價值。
      • 副總B(頭2)可能從「市場風險」角度審閱,他可能會覺得專家2和專家7的報告更有參考意義。
      • 重點: 每一位副總都會獨立地、從自己的專業視角出發,對所有8份報告進行一次全面的評估和加權。
    4. 最終決策 (結果融合): 最後,4位副總坐在一起,將他們各自綜合出的結論(每個頭的輸出)匯總起來,形成公司最終的、最全面的決策(最終的去噪圖像)。
  • 核心特點:
    • 密集融合: 所有專家都參與了計算和貢獻。
    • 正確率導向: 不計計算成本,目標是從所有可用的信息中,榨取出最優的結果。
    • 多維度分析: 多頭機制確保了融合過程是從多個不同的「視角」進行的,而不是單一的、片面的。

結論

  • MoE Top-K: 是在決定誰能參與計算的階段,進行專家級別的選擇
  • 多頭交叉注意力: 是在所有專家都已完成計算之後,進行信息級別的融合。其中的「多頭」指的是融合的視角有多個,而不是「被融合的專家只有幾個」。

4.3 交叉多頭注意力與多頭注意力對比#

特性基礎自注意力 (Self-Attention)多頭自注意力 (Multi-Head Self-Attention)多頭交叉注意力 (Multi-Head Cross-Attention)
核心思想內部反思專家會診 / 分組內部反思跨界專家團訪談
信息來源單一序列單一序列兩個不同的序列
Q, K, V 來源Q, K, V 全部來自同一序列Q, K, V 全部來自同一序列,但被切分成多頭並行計算Q來自序列A,K和V來自序列B,同樣被切分成多頭並行計算
解決的問題捕捉序列內部的基本上下文關聯多個子空間中,同時捕捉不同類型的上下文關聯多個子空間中,對齊和融合兩種不同來源的信息
典型應用簡單的注意力模型現代所有Transformer的編碼器 (BERT, ViT) 和解碼器 (GPT)圖像描述生成、多模態融合、我們設計的專家融合網絡

4.4 多頭交叉注意力 (Multi-Head Cross-Attention) 內部架構詳圖#

目標: 讓一個「查詢」序列(Query Sequence),能夠從多個不同的角度,去「審視」和「提取」另一個「信息」序列(Key/Value Sequence)中的內容。

輸入:

  • 查詢序列 A (例如,帶噪聲的圖像 x_t 的特徵)
  • 信息序列 B (例如,所有專家輸出的特徵)

Multi-Head_Cross-Attention.png

架構圖流程詳解

  1. 初始線性投影 (Initial Linear Projection):

    • 模塊接收兩個不同的序列,AB
    • 它首先通過三個獨立的線性層,將A投影成Query (Q),將B投影成Key (K)Value (V)。這是「交叉」的關鍵,Q和K/V來自不同的源頭。
  2. 分頭 (Split into Heads):

    • 為了實現多角度分析,模型會將剛剛生成的大Q, K, V向量,在特徵維度上「切」成多份(例如8個頭)。現在我們就有了8套小型的Q1, K1, V1, Q2, K2, V2, …
  3. 並行注意力計算 (Parallel Attention Calculation):

    • 接下來的計算對於每個頭都是獨立且並行發生的。我們以「頭1」為例:
      • 計算注意力分數:Q1K1進行矩陣乘法(MatMul),得到一個分數矩陣,表示A中的每個元素對B中每個元素的關注程度。
      • 歸一化: 對分數進行縮放(Scale,通常是除以維度的平方根以穩定梯度),然後用Softmax函數將其轉換為總和為1的注意力權重。
      • 加權求和: 用計算出的注意力權重,去對V1進行加權求和。這一步是真正的「信息提取」,它根據權重從V1中篩選並融合出最重要的信息。
      • 輸出out1 「頭1」完成它的分析,得到一個輸出向量out1
  4. 拼接與最終投影 (Concatenate & Final Projection):

    • 當所有8個頭都完成了它們的並行計算後,我們會得到8個獨立的輸出向量out1, out2, ..., out8
    • 模型將這8個向量**拼接(Concatenate)**在一起,形成一個大的特徵向量。
    • 最後,這個大的特徵向量會再經過一個線性層(Linear_Output)進行一次最終的融合和降維,得到整個多頭交叉注意力模塊的最終輸出。

這個最終輸出,就是一個序列A的新表示,但它現在已經智能地、從多個角度融合了來自序列B的豐富信息。

Reference#

交叉注意力概述&科研工作推進
https://huangno1.github.io/posts/cross_attention_intro_and_research_work/
Author
HuangNO1
Published at
2025-07-15
License
CC BY-NC-SA 4.0