Cross-attentionによる特徴量結合の実装
フックを用いた中間層の活性値抽出:
ネットワークのソースコードを変更することなく、任意の中間層の出力を捕捉するための最も洗練された方法は、PyTorchのフック機能を利用することである。特にregister_forward_hookは、フォワードパス中に特定のモジュールの出力を取得するための強力なツールである 。以下に、特定のレイヤー(例:pretrained_model.layer)にフックをアタッチし、その出力を辞書に保存するコード例を示す。
code: _.py
# 中間特徴を保存するためのグローバルな辞書
intermediate_features = {}
def get_features(name):
def hook(model, input, output):
intermediate_featuresname = output.detach() return hook
# ASTモデルの特定のTransformerブロックにフックを登録
# 'ast.encoder.layer' はモデルの構造に依存する
pretrained_model.ast.encoder.layer.register_forward_hook(get_features('ast_layer_10'))
CrossAttentionの実装例: 残差接続を含む
code: _.py
import torch
import torch.nn as nn
class CrossAttentionFusion(nn.Module):
def __init__(self, cnn_dim, pt_dim, embed_dim, num_heads, dropout=0.1):
super().__init__()
# 各特徴を共通の埋め込み次元に射影する層
self.q_proj = nn.Linear(cnn_dim, embed_dim)
self.k_proj = nn.Linear(pt_dim, embed_dim)
self.v_proj = nn.Linear(pt_dim, embed_dim)
# Multi-head Cross-Attention層
self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
# Feed-Forward Network
self.ffn = nn.Sequential(
nn.Linear(embed_dim, embed_dim * 4),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(embed_dim * 4, embed_dim)
)
# Layer Normalization
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x_cnn, x_pt):
# 入力形状: x_cnn: (B, T_cnn, C_cnn), x_pt: (B, T_pt, C_pt)
# Tはシーケンス長、Cはチャンネル数/次元数
# 1. 特徴を共通次元に射影
q = self.q_proj(x_cnn)
k = self.k_proj(x_pt)
v = self.v_proj(x_pt)
# 2. Cross-Attentionと残差接続
attn_output, _ = self.attention(query=q, key=k, value=v)
x = self.norm1(q + self.dropout(attn_output)) # qに残差接続
# 3. Feed-Forward Networkと残差接続
ffn_output = self.ffn(x)
x = self.norm2(x + self.dropout(ffn_output))
return x
モデル構築
最終的な親クラスとして、nn.Moduleを継承したハイブリッドモデルを定義する。このクラスは、インスタンス化の際にCNNバックボーン、事前学習済みモデル、そして1つ以上のCrossAttentionFusionモジュールを内部に保持する。
この親クラスのforwardメソッドが、全体の処理フローを統括する。
入力されたスペクトログラムを、CNNと事前学習済みモデルの両方のバックボーンにそれぞれ入力する。
事前に登録したフックを通じて、指定した中間層の特徴マップをintermediate_features辞書から取得する。
取得した特徴マップをCrossAttentionFusionモジュールに入力し、融合された特徴表現を得る。この際、特徴マップの形状(特にシーケンス長)をAttention層が扱えるように、適切に変形(例:reshape, permute)する必要がある場合がある。
融合された特徴マップを、後続の層(例:RNNや最終的な分類層)に入力し、最終的な音響イベントの予測結果を出力する。