AttentionモデルとCNNモデルの判断根拠を比較してみる


0.目次

1.やりたいこと

・sin波とcos波の分類を行う機械学習モデルを作成し、その予測の判断根拠を可視化する。
・分類を行う機械学習モデルは、Self-AttentionモデルとCNNモデルを作成する。Self-Attentionモデルではattention-weightを可視化することで、CNNモデルではGradCAMを使用することで可視化し、モデル毎の判断根拠を比較する。

2.使用するライブラリ

import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
import seaborn as sns
sns.set()

3.使用するデータ

sin波とcos波にノイズを加えたデータをそれぞれ100個作成し、sin波であれば[1, 0]のラベルをcon波であれば[0, 1]のラベルを付加します。

class Dataset:
    def __init__(self, sequential_length=100, num_point=100):
        # ノイズを加えたsin波とcos波のデータ生成
        np.random.seed(42)  # 乱数シードの設定

        self.sequential_length=sequential_length
        self.num_point = num_point

        # sin波にノイズを加えたデータの生成
        sin_data = np.sin(np.linspace(0, 2 * np.pi, self.num_point))  # sin波の生成
        sin_noisy_data = []
        for _ in range(self.sequential_length):
            noisy_data = sin_data + np.random.normal(0, 0.1, size=sin_data.shape)  # ノイズの追加
            sin_noisy_data.append(noisy_data)
        sin_noisy_data = np.concatenate(sin_noisy_data)
        sin_noisy_data = sin_noisy_data.reshape(self.sequential_length, -1)
        

        # cos波にノイズを加えたデータの生成
        cos_data = np.cos(np.linspace(0, 2 * np.pi, self.num_point))  # cos波の生成
        cos_noisy_data = []
        for _ in range(self.sequential_length):
            noisy_data = cos_data + np.random.normal(0, 0.1, size=cos_data.shape)  # ノイズの追加
            cos_noisy_data.append(noisy_data)
        cos_noisy_data = np.concatenate(cos_noisy_data)
        cos_noisy_data = cos_noisy_data.reshape(self.sequential_length, -1)
        

        # データの結合とシャッフル
        data = np.concatenate((sin_noisy_data, cos_noisy_data))
        data = data.reshape(2 * self.sequential_length, -1, 1)
        
        sin_label1 = np.ones((self.sequential_length, 1))
        sin_label2 = np.zeros((self.sequential_length, 1))
        sin_label = np.concatenate((sin_label1, sin_label2), axis=1)

        cos_label1 = np.zeros((self.sequential_length, 1))
        cos_label2 = np.ones((self.sequential_length, 1))
        cos_label = np.concatenate((cos_label1, cos_label2), axis=1)

        labels = np.concatenate((sin_label, cos_label))
        
        #torchに変換する
        #[batch_size, 1, sequential_length]
        self.X = torch.tensor(data, dtype=torch.float32).transpose(1,2)
        self.t = torch.tensor(labels, dtype=torch.float32)
    
    def __len__(self):

        return len(self.X)

    def __getitem__(self, index):

        return self.X[index], self.t[index]

以下のようなデータが生成されます。

4.Self-Attentionモデルの学習/判断根拠の可視化

4.1.Self-Attention

まず、Attention(注意機構)について簡単に解説します。(※詳細な解説は他の記事に譲ります。)Attentionは自然言語処理の分類や翻訳のタスクにおいて、単語間の関連付け、翻訳元と翻訳先の単語との関連付けを動的に行うことで入力情報を全体を見渡して重要な情報を抽出することが可能で自然言語処理や画像分類のタスクに変革をもたらしました。自然言語処理のタスクにおいては、それまでのLSTM等のRNNの欠点である並列計算ができない課題を克服しました。画像分類のタスクでは、CNNの欠点である離れた情報の関連付けを行うことが可能になり、画像分類の予測精度にも寄与しました。例えば、以下の画像のように位置的に離れた情報を関連付けることが出来ます。

Transfomerの中では、Self-Attention機構とSorceTarget-Attentionが実装され、現在注目されているChatGPTにも組み込まれてます。
以下にSelf-Attentionの仕組み/計算方法を簡単に示します。

系列データABC(何らかの文章と思ったら良い)をSelf-Attentionにより単語間の関連付けを行うことを考えます。次のステップで計算を行います。
1. 各単語を数値に変換するEmbeddingを行います。
2.Embeddingした各単語に対してさらにEmbeddingを行い、Query、Key、Valueを作成します。
3.QueryとKeyの内積を計算し、Queryに設定した単語(注目する単語、他の単語と関連付けした量を算出したい単語)とKeyに設定した単語の類似度を計算する。
4.類似度でValueを重みづけを行う。
5.重みづけしたValue'(Weighted-Value)をすべての単語に対して計算したの後に総和をとる。
以上のステップにより他の単語との関連付けを動的に行う。学習の中でEmbeddingレイヤーの重みが最適化され、分類精度を上げるのにどこに注目すればよいか(attentionすればよいか)を学習する。学習後に得られる各単語間の類似はAttentionWeightと呼ばれ、機械学習モデルの予測の判断根拠を可視化する際に有用である。AttentionはXAIの一面を持つことで説明性を付加してくれる。今回はこのAttentionWeightを用いて、判断根拠を可視化する。

こちらの記事を見ると、Attentionの仕組み、計算方法が良く分かると思います。 jalammar.github.io

4.2.学習モデルの定義

それでは、Sel-Attentionモデルを構築、学習を行い、Attention-weigthを可視化してみます。
まず、Self-Attentionモデルのアーキテクチャーを下記のように定義します。 入力データ(sin波、cos波)をEmbbedingします。その後、Self-Attentionに使用するため、Query, Key, ValueをEmbeddingを行い作成します。作成したQuery, Key, Valueを使用してSelf-Attention計算を行い、最後にFlatten処理を行い、Affineレイヤーを結合し、出力層につながる構造になっています。

class AttentionClasifier(nn.Module):

    def __init__(self, dim_emb:int=10, num_heads:int=1, sequential_length=100):
        super(Clasifier, self).__init__()

        #埋め込み次元
        self.dim_emb = dim_emb
        #アテンションヘッドの数
        self.num_heads = num_heads
        #各ヘッドの次元数
        self.head_dim = dim_emb // num_heads
        #
        self.sqrt_dh = self.head_dim**0.5

        #埋め込み
        self.emb = nn.Linear(1, dim_emb)
        #query埋め込み
        self.query = nn.Linear(dim_emb, dim_emb, bias=False)
        #key埋め込み
        self.key = nn.Linear(dim_emb, dim_emb, bias=False)
        #value埋め込み
        self.value = nn.Linear(dim_emb, dim_emb, bias=False)
        #
        self.w_o = nn.Linear(dim_emb, dim_emb)
        #平坦化
        self.flatten = nn.Flatten()
        #
        self.affine = nn.Linear(sequential_length * dim_emb, 2)

    def forward(self, x):

        batch_size, _ , sequence_length = x.size()
        #形状変換
        #(batch_size, 1, sequence_length) -> (batch_size, sequence_length, 1)
        x = x.transpose(1,2)
        #埋め込み
        ## (batch_size, sequence_length, 1) -> (batch_size, sequence_length, dim_emb)
        x = self.emb(x)
        #query, key, value
        ## (batch_size, sequence_length, dim_emb) -> (batch_size, sequence_length, dim_emb)
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        #q, k, vをヘッドに分ける
        ## (batch_size, sequence_length, dim_emb) -> (batch_size, sequence_length, num_heads, dim_emb//num_heads)
        q = q.view(batch_size, sequence_length, self.num_heads, self.head_dim)
        k = k.view(batch_size, sequence_length, self.num_heads, self.head_dim)
        v = v.view(batch_size, sequence_length, self.num_heads, self.head_dim)

        #セルフアテンションができるように変換
        ## (batch_size, sequence_length, num_heads, dim_emb//num_heads) -> (batch_size, num_heads, sequential_length, dim_emb//num_heads)
        q = q.transpose(1,2)
        k = k.transpose(1,2)
        v = v.transpose(1,2)

        #内積
        ## (batch_size, num_heads, sequential_length, dim_emb//num_heads) -> (batch_size, num_heads, dim_emb//num_heads, sequential_length)
        k_T = k.transpose(2,3)
        ## (batch_size, num_heads, sequential_length, dim_emb//num_heads) * (batch_size, num_heads, dim_emb//num_heads, sequential_length) -> (batch_size, num_heads, sequential_length, seqential_length)
        dots = (q @k_T) / self.sqrt_dh

        #列方向にソフトマックス
        attn = F.softmax(dots, dim=-1)

        #加重和
        ## (batch_size, num_heads, sequential_lenght, sequential_length) * (batch_size, num_heads, sequential_length, dim_emb//num_heads) -> (batch_size, num_heads, sequential_length, dim_emb//num_heads)
        out = attn @ v
        ## (batch_size, num_heads, sequential_length, dim_emb//num_heads) -> (batch_size, sequential_lengh, num_heads, dim_emb//num_heads)
        out = out.transpose(1,2)
        ## (batch_size, sequential_lengh, num_heads, dim_emb//num_heads) -> (batch_size, sequential_lengh, dim_emb)
        out = out.reshape(batch_size, sequence_length, self.dim_emb)

        out = self.w_o(out)
        out = self.flatten(out)
        out = self.affine(out)
       
        return out, attn
    

4.3.Dataloaderの作成/学習

続いて、PytorchのDataloaderを使用してバッチ学習用にデータを作成し、学習モデルアーキテクチャーをインスタンス化します。その後、学習を行います。

#データセットの作成
dataset = Dataset()
#バッチサイズ
batch_size = 8
#バッチ学習のためにDataloaderに変換する
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

#学習モデルをインスタンス化
cls_attr= AttentionClasifier()

# 損失関数と最適化手法の定義
learning_rate = 1e-3
num_epochs = 1000
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cls_attr.parameters(), lr=learning_rate)

# 学習ループ
total_step = len(loader)
for epoch in range(num_epochs):
    for i, (data, labels) in enumerate(loader):
       
        # フォワードパス
        outputs, attn = cls_attr(data)
        loss = criterion(outputs, labels)

        # バックワードパスと最適化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 10 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{total_step}], Loss: {loss.item():.4f}")

学習を行うと、lossが更新されます。lossが最終的に0になり、完全に分類されます。

Epoch [1/1000], Step [10/25], Loss: 0.7461
Epoch [1/1000], Step [20/25], Loss: 0.6778
Epoch [2/1000], Step [10/25], Loss: 0.5947
Epoch [2/1000], Step [20/25], Loss: 0.4390
Epoch [3/1000], Step [10/25], Loss: 0.0641
Epoch [3/1000], Step [20/25], Loss: 0.0084
Epoch [4/1000], Step [10/25], Loss: 0.0014
Epoch [4/1000], Step [20/25], Loss: 0.0008
Epoch [5/1000], Step [10/25], Loss: 0.0005
Epoch [5/1000], Step [20/25], Loss: 0.0005
Epoch [6/1000], Step [10/25], Loss: 0.0004
Epoch [6/1000], Step [20/25], Loss: 0.0004
Epoch [7/1000], Step [10/25], Loss: 0.0003
Epoch [7/1000], Step [20/25], Loss: 0.0003
Epoch [8/1000], Step [10/25], Loss: 0.0003
Epoch [8/1000], Step [20/25], Loss: 0.0002
Epoch [9/1000], Step [10/25], Loss: 0.0002
Epoch [9/1000], Step [20/25], Loss: 0.0002
Epoch [10/1000], Step [10/25], Loss: 0.0002
Epoch [10/1000], Step [20/25], Loss: 0.0002
Epoch [11/1000], Step [10/25], Loss: 0.0002
Epoch [11/1000], Step [20/25], Loss: 0.0001
Epoch [12/1000], Step [10/25], Loss: 0.0001
Epoch [12/1000], Step [20/25], Loss: 0.0002
Epoch [13/1000], Step [10/25], Loss: 0.0001
...
Epoch [999/1000], Step [10/25], Loss: -0.0000
Epoch [999/1000], Step [20/25], Loss: -0.0000
Epoch [1000/1000], Step [10/25], Loss: -0.0000
Epoch [1000/1000], Step [20/25], Loss: -0.0000

4.4.Self-Attentionによる判断根拠の可視化

AttentionWeightを可視化してみます。

#データを取得
data_ = dataset.X
#AttentionWeightを取得
cls_attr.eval()
_, attn = cls_attr(data_) 
data_ = data_.reshape(-1, 100)
#データ番号
case = 101
#プロット
attn_ = attn[case].reshape(100, 100)
plt.imshow(attn_.detach().numpy())
plt.show()

以下にcos波とsin波のAttentionMapの比較を示します。 学習の中で波の山の位置と谷の位置に関連があるという情報を獲得しています。面白いですね。Attentionmapと波形を重ねてみます。

#データ
data_ = dataset.X
#AttentionWeight
cls_attr.eval()
_, attn = cls_attr(data_) 
data_ = data_.reshape(-1, 100)
#データ番号
case = 1
#プロット
fig, ax = plt.subplots()
# カラーマップを作成
cmap = plt.get_cmap('Blues')
attn_ = attn[case].reshape(100, 100)
attn_list = attn_.detach().numpy().sum(axis=0)
attn_list = (attn_list - attn_list.min()) / (attn_list.max() - attn_list.min())
attn_list = attn_list.tolist()
#AttentionWeightで色付け
for i in range(100):
    ax.axvspan(xmin=i, xmax=i+1, color=cmap(attn_list[i]), alpha=0.25)

plt.plot(np.linspace(0, 100, 100), data_[case])
plt.show()

Self-Attentionモデルは山と谷の位置を見て予測をしていることが分かります。

5.CNNモデルの学習/判断根拠の可視化

5.1.GradCAM

GradCAM(Gradient-weighted Class Activation Mapping)は、畳み込みニューラルネットワーク(CNN)を用いた画像認識タスクにおいて、ネットワークの予測結果を可視化するための手法です。GradCAMは、モデルが画像内のどの領域に注目して予測を行ったのかを視覚的に理解するために使用されます。GradCAMの処理手順は以下のようになります。
1.まず、対象となる画像をネットワークに入力し、予測結果を得ます。
2.ネットワークの最終畳み込み層の出力(特徴マップ)に対して、予測結果に対する勾配を求めます。勾配は、予測クラスに対する特徴マップの各ピクセルの重要度を表します。
3.勾配を特徴マップの各チャンネルに適用し、チャンネルごとに重要度を求めます。この重要度は、各チャンネルが予測にどれだけ寄与しているかを示します。
4.チャンネルごとの重要度を特徴マップに重み付けして足し合わせ、GradCAMマップを生成します。GradCAMマップは、予測クラスに対して重要な領域を強調したヒートマップとして表現されます。
5.GradCAMマップを元の入力画像のサイズにアップサンプリングします。
6.最後に、GradCAMマップを元の入力画像に重ね合わせることで、モデルが予測に使用した重要な領域を可視化します。
このようにして生成されたGradCAMマップは、ネットワークの予測の根拠を可視化するために利用され、画像認識モデルの解釈性や信頼性を向上させるのに役立ちます。

5.2.学習モデルの定義

一般的な1次元のCNNモデルを定義します。

class Conv1dClasifier(nn.Module):

    def __init__(self, num_filter, kernel_size):
        super(Conv1dClasifier, self).__init__()
        self.num_filter = num_filter
        self.kernel_size = kernel_size

        self.conv1 = nn.Conv1d(1, num_filter, kernel_size)
        self.pool1 = nn.MaxPool1d(2)
        self.conv2 = nn.Conv1d(num_filter, num_filter*2, kernel_size)
        self.pool2 = nn.MaxPool1d(2)
        self.flatten = nn.Flatten()
        self.affine1 = nn.Linear(368, 100)
        self.affine2 = nn.Linear(100, 10)
        self.affine3 = nn.Linear(10, 2)

        
    def forward(self, x):

        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.flatten(x)
        x = F.relu(self.affine1(x))
        x = F.relu(self.affine2(x))
        x = self.affine3(x)

        return x
    

5.3.Dataloaderの作成/学習

#学習及び評価
#データセットの作成
dataset = Dataset()
#バッチサイズ
batch_size = 8
#バッチ学習のためにDataloaderに変換する
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

#学習モデルをインスタンス化
cls = Conv1dClasifier(num_filter=8, kernel_size=3)

# 損失関数と最適化手法の定義
learning_rate = 1e-3
num_epochs = 1000
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cls.parameters(), lr=learning_rate)

# 学習ループ
total_step = len(loader)
for epoch in range(num_epochs):
    for i, (data, labels) in enumerate(loader):
       
        # フォワードパス
        outputs = cls(data)
        loss = criterion(outputs, labels)

        # バックワードパスと最適化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 10 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{total_step}], Loss: {loss.item():.4f}")

学習を行うと、lossが更新されます。lossが最終的に0になり、完全に分類されます。

Epoch [1/1000], Step [10/25], Loss: 0.6089
Epoch [1/1000], Step [20/25], Loss: 0.4347
Epoch [2/1000], Step [10/25], Loss: 0.4945
Epoch [2/1000], Step [20/25], Loss: 0.1950
Epoch [3/1000], Step [10/25], Loss: 0.5618
Epoch [3/1000], Step [20/25], Loss: 0.1849
Epoch [4/1000], Step [10/25], Loss: 0.4518
Epoch [4/1000], Step [20/25], Loss: 0.2676
Epoch [5/1000], Step [10/25], Loss: 0.3489
Epoch [5/1000], Step [20/25], Loss: 0.2580
Epoch [6/1000], Step [10/25], Loss: 0.2527
Epoch [6/1000], Step [20/25], Loss: 0.1663
Epoch [7/1000], Step [10/25], Loss: 0.3255
Epoch [7/1000], Step [20/25], Loss: 0.2407
Epoch [8/1000], Step [10/25], Loss: 0.1572
Epoch [8/1000], Step [20/25], Loss: 0.3102
Epoch [9/1000], Step [10/25], Loss: 0.3041
Epoch [9/1000], Step [20/25], Loss: 0.3000
Epoch [10/1000], Step [10/25], Loss: 0.3676
Epoch [10/1000], Step [20/25], Loss: 0.2177
Epoch [11/1000], Step [10/25], Loss: 0.3556
Epoch [11/1000], Step [20/25], Loss: 0.2806
Epoch [12/1000], Step [10/25], Loss: 0.1376
Epoch [12/1000], Step [20/25], Loss: 0.3395
Epoch [13/1000], Step [10/25], Loss: 0.2662
...
Epoch [999/1000], Step [10/25], Loss: -0.0000
Epoch [999/1000], Step [20/25], Loss: -0.0000
Epoch [1000/1000], Step [10/25], Loss: -0.0000
Epoch [1000/1000], Step [20/25], Loss: -0.0000

5.4.GradCAMによる判断根拠の可視化

GradCAMを行うクラスを定義します。

class GradCAM:
    def __init__(self, model, feature_layer):
        self.model = model
        self.feature_layer = feature_layer
        self.model.eval()
        self.feature_grad = None
        self.feature_map = None
        self.hooks = []

        # 最終層逆伝播時の勾配を記録する
        def save_feature_grad(module, in_grad, out_grad):
            self.feature_grad = out_grad[0]
        self.hooks.append(self.feature_layer.register_backward_hook(save_feature_grad))

        # 最終層の出力 Feature Map を記録する
        def save_feature_map(module, inp, outp):
            self.feature_map = outp[0]
        self.hooks.append(self.feature_layer.register_forward_hook(save_feature_map))

    def forward(self, x):
        return self.model(x)

    def backward_on_target(self, output, target):
        self.model.zero_grad()
        one_hot_output = torch.zeros([1, output.size()[-1]])
        print(one_hot_output.shape)
        one_hot_output[0][target] = 1
        output.backward(gradient=one_hot_output, retain_graph=True)

    def clear_hook(self):
        for hook in self.hooks:
            hook.remove()

上記で定義したGradCAMクラスをインスタンス化し、cam値を算出します。

gradcam = GradCAM(cls, cls.conv2)
model_output = gradcam.forward(dataset.X[0].unsqueeze(0))
target = model_output.argmax(1).item()

gradcam.backward_on_target(model_output, target)

# Get feature gradient
feature_grad = gradcam.feature_grad.data.numpy()[0]
# Get weights from gradient
weights = np.mean(feature_grad, axis=(1))  # Take averages for each gradient
# Get features outputs
feature_map = gradcam.feature_map.data.numpy()
gradcam.clear_hook()

# Get cam
cam = np.sum((weights * feature_map.T), axis=1).T
cam = np.maximum(cam, 0)  # apply ReLU to cam

#リサイズ
cam = np.interp(np.linspace(0,1, 100), np.linspace(0,1,47), cam)
#正規化
cam = (cam - cam.min()) / (cam.max() - cam.min())
cam

GradCAMの実装はこちらを参考にしました。 tech.jxpress.net

cam値が求まったので、波形とcam値を重ねてみます。

#データ
data_ = dataset.X.reshape(-1, 100)
#データ番号
case = 101
fig, ax = plt.subplots()
# カラーマップを作成
cmap = plt.get_cmap('Blues')

for i in range(100):
    ax.axvspan(xmin=i, xmax=i+1, color=cmap(cam[i]), alpha=0.25)

plt.plot(np.linspace(0, 100, 100), data_[case])
plt.show()

Self-attentionモデルとは異なり、sin波、cos波に依らず、sin波の山と谷の位置に注目して判断していることが伺えます。CNNモデルでは注目する位置は変わっていませんが、Self-Attentionモデルは波形によって注目する位置が変わっていて、面白いですね。

6.まとめ

今回、Self-AttentionモデルとCNNモデルを作成し、sin波とcos波の分類を行い、その判断根拠の可視化と両モデルの判断根拠を比較してみました。Self-Attentionモデルでは学習の中で山と谷の位置に関係があるという情報を自動で獲得していて非常に興味深かったです。

7.参考