GCN(Graph Convolutional Networks)を使用して化合物の物性値を予測してみた
0.はじめに
化合物の物性予測の際には、記述子を用いたり、フィンガープリントを用いたり等アプローチがいくつかあります。 今回、化合物の構造データ(グラフデータ)から畳み込み演算により特徴を抽出しながらモデルを構築するGCNN(Graph Convolutional Neural Network)を使用して化合物の水溶解度予測の実装をしてみました。 実装には、Pytorch Geometricを使用しました。インストール方法等に関しての説明は割愛させて頂きます。
masao-ds.hatenablog.com こちらの記事では記述子を使用して同様の化合物データの水溶解度予測をしています。
1.やりたいこと
化合物の構造データ(グラフデータ)を入力してその化合物の水溶解度を予測するモデルを構築します。
2.使用するデータセット及び形式
こちら の水溶解度データを使用します。ダウンロードしたデータを下記のようなcsv形式に変換します。
No | SMILES | logS |
---|---|---|
sample_1 | CC(N)=O | 1.58 |
sample_2 | CNN | 1.34 |
sample_3 | CC(=O)O | 1.22 |
sample_4 | C1CCNC1 | 1.15 |
sample_5 | NC(=O)NO | 1.12 |
3.GNN(Graph Neural Network)、GCN(Graph Convolutional Networks)
3.1 GNN(Graph Neural Network)
グラフニューラルネットワーク(Graph Neural Network, GNN)は、グラフ構造を持つデータに対して効果的な機械学習手法です。グラフはノード(頂点)とエッジ(辺)からなるデータ表現であり、GNNはこれらのデータを入力として受け取り、ノードやエッジ間の相互作用をモデル化します。今回扱う化合物データをノードを原子、エッジを化学結合とみなすとグラフデータに相当します。
GNNでは、各ノードの数値化された特徴表現を用いて、グラフ分類、ノード分類、リンク予測などを解決するための予測や分類が行えます。
今回扱う化合物データの物性予測では1つのグラフ構造(化合物)に1つの値(物性値)が紐づき、それを回帰する問題のためグラフ分類(回帰)問題になります。
GNNの利点の一つは、グラフ構造の情報を効果的に抽出できることです。GNNは、ノードやエッジの局所的な特徴だけでなく、ネットワーク全体の構造や相互作用を次節で示すGCN(Graph Convolutional Networks)で考慮することができます。そのため、グラフデータにおいては、ノード間の依存関係や隣接関係を考慮する必要があるさまざまなタスクにおいて、GNNは優れた性能を発揮することが期待できます。
3.2 GCN(Graph Convolutional Networks)
GNN(Graph Neural Network)では、グラフ畳み込み(Graph Convolution)によりグラフデータの特徴を抽出します。
グラフ畳み込み(Graph Convolution)は、グラフデータに対して畳み込み操作を適用する手法です。通常の畳み込みは、画像などのグリッド状のデータに対して行われますが、グラフ畳み込みは非ユニフォームなグラフ構造に対しても適用することができます。
以下にグラフ畳み込みによる特徴抽出の流れを示します。
- ノードの特徴表現の更新: グラフ畳み込みでは、各ノードの特徴表現を更新するために近傍のノードとの情報を利用します。各ノードは初期的に特徴ベクトルを持ち、周囲のノードとの情報を統合することで新しい特徴表現を計算します。
- 近傍ノードの集約: グラフ畳み込みでは、各ノードの近傍ノード(隣接するノード)との情報を集約します。この集約は、隣接するノードの特徴表現を重み付けして統合することで行われます。一般的な手法として、隣接ノードの特徴を線形結合したり、メッセージパッシングと呼ばれる手法を用いて情報を伝播させる方法があります。
- フィルタリング: グラフ畳み込みでは、集約された情報に対してフィルタリング操作を適用します。フィルタリングは、ノードの特徴表現を調整するために使用されます。一般的には、フィルタリング操作には、パラメータ化された重み行列を使用します。
これらの手順を複数の階層で反復的に行うことで、グラフ畳み込みは特徴表現の階層的な抽象化を実現します。情報は近傍ノードを介して伝播し、隣接関係やグラフの構造を考慮しながら特徴表現が更新されます。
Pytorch Geometricには複数のグラフ畳み込みレイヤーが実装されていますが、今回はGCNConv()レイヤーを使用します。
GCNConvレイヤーは下記の論文の内容が実装されているようです。
https://arxiv.org/pdf/1609.02907.pdf
また、下記の記事では上記の論文で提案されているGCNConvレイヤーについて簡単にまとめられています。
qiita.com
GCNConvレイヤーでは、ノード(原子)の情報とノード同士の隣接関係を使用して畳み込み演算を行い、グラフ構造の特徴を抽出するシンプルなグラフ畳み込みです。
そのほかにも、エッジに重みをつけて畳み込みする(ノードの結合の強さを考慮して学習する)等のGCNレイヤーが提案されており、Pytorch Geometricでも実装されています。
4. 化合物の水溶解度予測モデルの構築
以下の流れで水溶解度予測モデルの学習をしていきます。
- データセットの作成:SMILESデータをMolオブジェクトに変換し、ノード情報及びノードの隣接関係を数値化してPytorch Gemotricで扱える形に変換する。
- GCNN(Graph Convolutional Neural Network)の学習:GCNNを定義し、学習を行う。
4.1 ライブラリのインポート
import numpy as np import pandas as pd from rdkit import Chem import torch from torch_geometric.data import Dataset, Data from torch.nn.functional import one_hot import os import seaborn as sns sns.set()
4.2 データセットの作成
PyTorch Geometricのグラフデータを作成するためのクラスtorch_geometric.data.Datasetを使用してSMILESデータをグラフデータに変換します。
pytorch-geometric.readthedocs.io 上記がPyTorch Geometricが提供しているDatasetクラスのソースコードです。
@property def raw_file_names(self) -> Union[str, List[str], Tuple]: r"""The name of the files in the :obj:`self.raw_dir` folder that must be present in order to skip downloading.""" raise NotImplementedError @property def processed_file_names(self) -> Union[str, List[str], Tuple]: r"""The name of the files in the :obj:`self.processed_dir` folder that must be present in order to skip processing.""" raise NotImplementedError def download(self): r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" raise NotImplementedError def process(self): r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" raise NotImplementedError @abstractmethod def len(self) -> int: r"""Returns the number of graphs stored in the dataset.""" raise NotImplementedError @abstractmethod def get(self, idx: int) -> BaseData: r"""Gets the data object at index :obj:`idx`.""" raise NotImplementedError
このあたりの関数には、何も定義されていませんので、今回扱う問題に対応できるようにオーバーライドしてカスタマイズします。
class MoleculeDataset(Dataset): def __init__(self, root, filename, test=False, transform=None, pre_transform=None): #データを読み込むファイル名を指定する。 self.filename = filename super(MoleculeDataset, self).__init__(root, transform, pre_transform) @property def raw_file_names(self): return self.filename @property def processed_file_names(self): self.data = pd.read_csv(self.raw_paths[0], index_col=0).reset_index() #グラフデータに変換したものをバイナリ化して格納するファイル名を指定する。 return [f'data_{i}.pt' for i in list(self.data.index)] def download(self): pass def len(self): return self.data.shape[0] #インスタンス化した後に各グラフデータ(化合物)を呼び出す際に使用する。 #インスタンス[idx]の形式で各グラフデータを呼び出せる。(特殊メソッド__getitem__の中にgeメソッドがあるため)) def get(self, idx): idx = idx + 1 data = torch.load(os.path.join(self.processed_dir, f'data_sample_{idx}.pt')) return data #分子グラフを作成する関数を記述する。 def process(self): #データをcsvファイルから読み込む。 self.data = pd.read_csv(self.raw_paths[0], index_col=0) #一つずつSMILESデータを読み込んで、Molオブジェクトに変換し、ノード情報及び隣接関係を数値化する。 for index, mol in self.data.iterrows(): mol_obj = Chem.MolFromSmiles(mol[0]) #ノード(原子)情報を数値化 node_feats = self._get_node_features(mol_obj) #エッジ(化学結合)の情報を数値化 edge_feats = self._get_edge_features(mol_obj) #原子の隣接関係を数値化 edge_index = self._get_adjacency_info(mol_obj) #教師ラベル label = self._get_labels(mol[1]) structure_id = [[mol[0]]] #Dataクラスをインスタン化してノード情報、エッジ情報、隣接関係をまとめてグラフデータ化する。 data = Data(x=node_feats, edge_index=edge_index, edge_attr=edge_feats, y=label, structure_id=structure_id) #グラフデータを保存する。 torch.save(data, os.path.join(self.processed_dir, f'data_{index}.pt'))
processメソッドの中でcsvファイル(./dataset/raw/に配置)を読み込んでノード情報、エッジ情報、隣接関係を数値化して化合物ごとにデータを作成し、バイナリ化してデータを格納します。 processメソッドはインスタンス化する際に呼び出され、インスタンス化した時点でグラフデータが揃います。 また、processメソッド内で使用するノード情報量等を算出する関数を同一クラス内で定義しておきます。 ノード情報量や隣接関係等はrdkitを使用します。
def _get_node_features(self, mol): all_node_feats = [] for atom in mol.GetAtoms(): node_feats = [] node_feats = one_hot(torch.tensor(atom.GetAtomicNum()-1), num_classes=113) node_feats = node_feats.tolist() node_feats.append(atom.GetDegree()) node_feats.append(atom.GetFormalCharge()) node_feats.append(atom.GetHybridization()) node_feats.append(atom.GetIsAromatic()) node_feats.append(atom.GetTotalNumHs()) node_feats.append(atom.GetNumRadicalElectrons()) node_feats.append(atom.IsInRing()) node_feats.append(atom.GetChiralTag()) all_node_feats.append(node_feats) all_node_feats = np.asarray(all_node_feats) return torch.tensor(all_node_feats, dtype=torch.float) def _get_edge_features(self, mol): all_edge_feats = [] for bond in mol.GetBonds(): edge_feats = [] edge_feats.append(torch.tensor(bond.GetBondTypeAsDouble())) edge_feats.append(bond.IsInRing()) all_edge_feats += [edge_feats, edge_feats] all_edge_feats = np.asarray(all_edge_feats) return torch.tensor(all_edge_feats, dtype=torch.float) def _get_adjacency_info(self, mol): edge_indices = [] for bond in mol.GetBonds(): i = bond.GetBeginAtomIdx() j = bond.GetEndAtomIdx() edge_indices += [[i, j], [j,i]] edge_indices = torch.tensor(edge_indices) edge_indices = edge_indices.t().to(torch.long).view(2, -1) return edge_indices def _get_labels(self, label): label = np.asarray([label]) label = torch.tensor(label, dtype=torch.float32) return torch.reshape(label, (1,1))
ここまででデータセット作成クラスの定義は完了です。以下のコードでデータセットをインスタンス化します。
dataset = MoleculeDataset(root='./dataset/', filename="molecules.csv")
下記のようにコンソールに表示されれば完了です。
Processing... Done!
また、下記のようにインスタンスに対して
dataset[0]
とすると、インスタンス内の[]で指定したインデックスの化合物データのグラフデータが得られます。
Data(x=[4, 121], edge_index=[2, 6], edge_attr=[6, 2], y=[1, 1], structure_id=[1])
4.3 学習用データへの変換
from torch_geometric.loader import DataLoader dataset = dataset.shuffle() #学習データとテストデータに分割する。 from sklearn.model_selection import train_test_split dataset_train, dataset_test = train_test_split(dataset, test_size=0.2) #バッチサイズ batch_size = 64 loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True) loader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)
4.4 GCNN(Graph Convolutional Neural Network)の定義
2層のグラフ畳み込み層を設定しました。畳み込みの後はglobal_mean_poolで固定長に変化し、Affine繋げます。
import torch import torch.nn as nn from torch_geometric.nn import GCNConv from torch_geometric.nn import global_mean_pool #ノードの隠れ層の数 n_h = 64 class GCN(nn.Module): def __init__(self): super().__init__() self.conv1 = GCNConv(dataset.num_node_features, n_h) self.conv2 = GCNConv(n_h, n_h) self.conv3 = GCNConv(n_h, n_h) self.fc1 = nn.Linear(n_h, n_h//2) self.fc2 = nn.Linear(n_h//2, 1) self.relu = nn.ReLU() self.dropout = nn.Dropout(p=0.5) def forward(self, data): x = data.x edge_index = data.edge_index batch = data.batch x = self.conv1(x, edge_index) x = self.relu(x) x = self.conv2(x, edge_index) x = self.relu(x) x = self.conv3(x, edge_index) x = global_mean_pool(x, batch) x = self.dropout(x) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x #ネットワークのインスタンス化 net = GCN()
4.5 GCNN(Graph Convolutional Neural Network)の学習
100epochの学習を行い、その際のロスのデータを保存します。
from torch import optim def eval(loader, loss_fn): for data in loader: out = net(data) loss = loss_fn(out, data.y) return loss loss_fn = nn.L1Loss() record_loss_test = [] record_loss_train = [] optimizer = optim.AdamW(net.parameters()) for epoch in range(100): #学習モード net.train() for data in loader_train: optimizer.zero_grad() out = net(data) loss = loss_fn(out, data.y) loss.backward() optimizer.step() #評価モード net.eval() loss_train = eval(loader_train, loss_fn) loss_test = eval(loader_test, loss_fn) record_loss_train.append(loss_train.item()) record_loss_test.append(loss_test.item()) if (epoch+1)%10==0: print('Epoch:', epoch+1, "MAE_train:", loss_train.item(), "MAE_test:", loss_test.item())
学習完了後、学習曲線を描画します。
import matplotlib.pyplot as plt plt.plot(range(len(record_loss_train)), record_loss_train, label="Train") plt.plot(range(len(record_loss_test)), record_loss_test, label="Test") plt.legend() plt.xlabel('Epochs') plt.ylabel('Loss') plt.show()
4.6 予測精度の確認
バリデーションデータに対する予測を行い、実測値と予測値をプロットして精度を確認してみます。
net.eval() # ネットワークを推論モードに切り替える predictions = [] labels = [] with torch.no_grad(): for data in loader_test: label = data.y output = net(data) # ネットワークにテストデータを入力して予測結果を取得 labels.append(label) predictions.append(output) predictions = torch.cat(predictions, dim=0) # 予測結果を結合して1つのテンソルにする labels = torch.cat(labels, dim=0) plt.scatter(labels.float(), predictions.float()) plt.xlabel('Measured') plt.ylabel('Predicted') plt.show()
それなりの精度で予測できていそうです。
5 最後に
今回、GCN(Graph Convolutional Networks)を使用して化合物データの水溶解予測をしてみました。 そのほかのGNNも色々調べて、実装していきたいと思います。 今期からMI(マテリアルインフォマティクス)テーマに参画させてもらっているので、MI関連技術調査に力を入れていきたいです。