引言
Deep Graph Library (DGL) 是一个在图上做深度学习的框架。在0.3.1版本中,DGL支持了基于PyTorch的化学模型库。如何生成分子图是我感兴趣的。
环境准备
分子生成与Junction Tree VAE
分子生成
候选药用化合物的数量估计为10 ^ {23} -10 ^ {60} ,但是合成所有这些化合物是不现实的,每年都会发现新的化合物。到目前为止,仅合成了大约10 ^ 8 。
设计新化合物,考虑其合成方法,在药物发现的过程中尝试实际合成的化合物需要大量的时间和金钱,故AI药物发现具有了原始动机。药物发现的的目标是产生对疾病有效的药物,副作用更少且易合成
Junction Tree VAE
JT-VAE (junction tree variational autoencoder)
JT-VAE同时考虑了分子的两种图表示:分子图和联合树。在分子图中,我们把原子作为节点,化学键作为边。在联合树中,我们将分子图中的一些子结构看作节点。”
基于DGL的分子图生成
导入库
import dglfrom dgl import model_zoofrom dgl.model_zoo.chem.jtnn import JTNNDataset, cuda, JTNNCollatorimport rdkitfrom rdkit import Chemfrom rdkit.Chem import Draw, MolFromSmiles, MolToSmilesimport torchfrom torch.utils.data import DataLoader, Subsetfrom tqdm.notebook import tqdm
数据预处理
dataset = JTNNDataset(data="test", vocab="vocab", training=False)dataset.training = False
载入数据
dataset.data = ['CN1C=NC2=C1C(=O)N(C(=O)N2C)C', 'CCN(CC)C(=O)C1CN(C2CC3=CNC4=CC=CC(=C34)C2=C1)C']
使用Dataloader批次化处理和获取数据
def worker_init_fn(id_): lg= rdkit.RDLogger.logger() lg.setLevel(rdkit.RDLogger.CRITICAL)
worker_init_fn(None)dataset.training = Falsedataloader = DataLoader( Subset(dataset, [0,1]), batch_size=1, shuffle=False, num_workers=0, collate_fn=JTNNCollator(dataset.vocab, False), drop_last=True, worker_init_fn=worker_init_fn)
可视化数据集中的数据
Draw.MolsToGridImage([MolFromSmiles(s) for s in dataset.data], molsPerRow=4,subImgSize=(250,150))
加载模型
JT-VAE的训练需要很长时间。DGL提供了预先训练好的模型供用户使用。
model = model_zoo.chem.load_pretrained('JTNN_ZINC')model = cuda(model)
分子表示的插值
首先,对应于咖啡因和麦角酸二乙酰胺的潜在变量ž小号Ť 甲ř Ť,žg ^Ø 一个大号
tree_vec[0]
和mol_vec[0]
咖啡因,tree_vec[1]
以及mol_vec[1]
麦角酸二乙酰胺。
tree_vecs, mol_vecs = [], []for batch in dataloader: model.move_to_cuda(batch) _, tree_vec, mol_vec = model.encode(batch) tree_vec, mol_vec, _, _ = model.sample(tree_vec, mol_vec) # reparam. trick tree_vecs.append(tree_vec) mol_vecs.append(mol_vec)
确定了起点和终点,对通过分割线获得的点进行顺序解码,解码的输出为SMILES
tree_diff = tree_vecs[1] - tree_vecs[0] mol_diff = mol_vecs[1] - mol_vecs[0]smiles = []num_mols = 100 tree_st, mol_st = tree_vecs[0], mol_vecs[0]
for i in tqdm(range(num_mols)): s = model.decode(tree_st+tree_diff/(num_mols-1)*i, mol_st+mol_diff/(num_mols-1)*i) smiles.append(s)
按顺序显示生成的100个分子
mols = []for s in smiles: if s is None: continue mol = MolFromSmiles(s) if mol is not None: mols.append(mol)Draw.MolsToGridImage(mols, molsPerRow=4, subImgSize=(250,150))
连续输出相同的分子,插值不平滑; 终点未恢复为麦角酸二乙酰胺。
参考资料
作者&编辑丨王建民
研究方向丨药物设计、生物医药大数据