前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >图神经网络:分子可溶性预测

图神经网络:分子可溶性预测

作者头像
Tom2Code
发布2024-01-10 15:44:56
1730
发布2024-01-10 15:44:56
举报
文章被收录于专栏:TomTom

首先介绍一下数据集的来源:

https://arxiv.org/abs/1703.00564

也是torch_geometric自带的一个数据集,专门用于图神经网络入门的开胃小菜。

之前帮一位老师处理过一些数据,所以了解到了smiles分子式(大概就张这个样子):

OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)C(O)C3O

今天的数据集介绍:ESOL 是一个小型数据集,包含 1128 种化合物的水溶性数据。该数据集可用于训练基于化学分子结构(在 SMILES 字符串编码)来预测溶解度的模型,这些数据不含原子的 3D 坐标。

废话不多说,我们直接开始今天的代码:

首先加载数据和描述数据:

代码语言:javascript
复制
from torch_geometric.datasets import MoleculeNet

data = MoleculeNet(root="D:/data/", name="ESOL")
print(data)

print("Dataset type: ", type(data))
print("Dataset features:", data.num_features)
print("Dataset target:", data.num_classes)
print("Dataset sample:", data[0])
print("Dataset Size:", len(data))

打印一下数据的shape:

代码语言:javascript
复制
print(data[0].x)
print(data[0].edge_index.T)
print(data[0].y)

print("Dataset sample smiles:", data[0]["smiles"])

输出:

然后就是搭建我们的图神经网络:

代码语言:javascript
复制
import torch
import torch.nn as nn
from torch.nn import Linear

from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool as gmp, global_max_pool as gap
embedding_size = 64

class GCN(nn.Module):
    def __init__(self):
        # Initialization
        super().__init__()
        torch.manual_seed(42)

        # GCN Layers,第1层网络将特征转化为embedding
        self.initial_conv = GCNConv(data.num_features, embedding_size)
        self.conv1 = GCNConv(embedding_size, embedding_size)
        self.conv2 = GCNConv(embedding_size, embedding_size)
        self.conv3 = GCNConv(embedding_size, embedding_size)

        # output layer,输出是一个线性层,将128维向量转成预测值,该值是个标量,维度和y相同,X2是因为下面要cat两个pooling
        self.out = Linear(embedding_size*2, 1)

    def forward(self, x , edge_index, batch_index):
        # 1st Conv layer
        hidden = self.initial_conv(x, edge_index)
        hidden = torch.tanh(hidden)

        # Other layers
        hidden = self.conv1(hidden, edge_index)
        hidden = torch.tanh(hidden)
        hidden = self.conv2(hidden, edge_index)
        hidden = torch.tanh(hidden)
        hidden = self.conv3(hidden, edge_index)
        hidden = torch.tanh(hidden)

        # Global pooling
        hidden = torch.cat((gmp(hidden, batch_index), gap(hidden, batch_index)), dim=1)

        out = self.out(hidden)

        return out, hidden

model = GCN()
print(model)

模型的结构:

定义训练集和测试集(80%和20%):

代码语言:javascript
复制
from torch_geometric.loader import DataLoader

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.0001)

data_size = len(data)
batch_size = 64
train_loader = DataLoader(data[:int(data_size*0.8)], batch_size=batch_size, shuffle=True)
test_loader = DataLoader(data[int(data_size*0.8):], batch_size=batch_size)

开始训练:

代码语言:javascript
复制
def train(data):
    for batch in train_loader:
        optimizer.zero_grad()

        pred, embeding = model(batch.x.float(), batch.edge_index, batch.batch)
        loss = torch.sqrt(loss_fn(pred, batch.y))

        loss.backward()
        optimizer.step()

    return loss, embeding

print("Start training ...")
losses = []
for epoch in range(1001):
    loss, h = train(data)
    losses.append(loss)
    if epoch%100 == 0:
        print(f"Epoch {epoch} | Training loss {loss}")

训练过程:

画图:

代码语言:javascript
复制
import matplotlib.pyplot as plt

losses_float = [float(loss) for loss in losses]
loss_indices = [i for i, l in enumerate(losses_float)]

plt.plot(loss_indices, losses_float)
plt.show()

简单的进行验证一下:

代码语言:javascript
复制
model.eval()
with torch.no_grad():
    for batch in test_loader:
        pred, _ = model(batch.x.float(), batch.edge_index, batch.batch)
        val_loss = torch.sqrt(loss_fn(pred, batch.y))
        # calculate other metrics if needed
        print(f"ground truth:{batch.y[0]}**pred:{pred[0]}")
        

输出:

可以看到效果还是不错的。

想要完整代码可以找Tom

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2024-01-05,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Tom的小院 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档