JTNN
JTNN: Junction Tree Variational Autoencoder for Molecular Graph Generation
JTNN 首先利用连接树（junction tree）分解，将分子图转化为一棵由化学子结构组成的树；然后，模型将树和图分别编码为两个独立的隐向量 $z_T$ 和 $z_G$。
JTNN 是一种变分自编码器（VAE）模型，旨在学习分子图的隐空间表示。这些表示可用于下游任务，例如分子性质预测或分子优化。
基于JTNN可视化给定分子的邻居分子
import torchfrom torch.utils.data importDataLoader, Subset
import argparsefrom dgl import model_zooimport rdkitfrom dgl.examples.pytorch.model_zoo.chem.generative_models import jtnnfrom dgl.model_zoo.chem.jtnn importJTNNDataset, cuda, JTNNCollatorfrom rdkit.ChemimportMolFromSmiles, MolToSmilesfrom rdkit.ChemimportDraw
使用 DataLoader 进行数据批处理
def worker_init_fn(id_):
    """Raise RDKit's log threshold to CRITICAL in the calling process.

    Used as a ``DataLoader`` ``worker_init_fn`` and also called directly
    once at start-up. The worker id ``id_`` is ignored.
    """
    rd_log = rdkit.RDLogger
    rd_log.logger().setLevel(rd_log.CRITICAL)
# Quiet RDKit in the main process as well.
worker_init_fn(None)

# Test split of the JTNN dataset, loaded in non-training mode.
dataset = JTNNDataset(data="test", vocab="vocab", training=False)
vocab_file = dataset.vocab_file

# Model hyper-parameters; not referenced elsewhere in the visible script.
hidden_size = 450
latent_size = 56
depth = 3

# Pretrained JTNN weights (ZINC) from the DGL model zoo, moved to the GPU.
model = model_zoo.chem.load_pretrained("JTNN_ZINC").cuda()

# Report the parameter count in thousands (%d truncates the fraction).
print("Model #Params: %dK" % (sum(p.nelement() for p in model.parameters()) / 1000,))

# Presumably leftovers from a training script; not used in the visible code.
MAX_EPOCH = 100
PRINT_ITER = 20
# Add noise: decode a grid of latent-space perturbations around one molecule.
def reconstruct(idx, grid_size=5, radius=2.0):
    """Decode the latent-space neighbourhood of dataset molecule ``idx``.

    The molecule is encoded with the module-level ``model``. A latent sample
    is drawn around the encoder means with the reparameterization trick,
    ``z = mean + exp(log_var / 2) * eps``. Two random Gaussian directions are
    then scaled by every pair of coefficients from
    ``torch.linspace(-radius, radius, grid_size)``, added to the sample, and
    decoded. This produces ``grid_size ** 2`` decodes in row-major order.

    Uses the module-level ``dataset``, ``model`` and ``worker_init_fn``.

    Args:
        idx: Index of the molecule in ``dataset``.
        grid_size: Number of coefficients per direction. Defaults to 5, the
            original 5x5 grid.
        radius: Largest absolute coefficient applied to each direction.
            Defaults to 2.0.

    Returns:
        ``(ms, gt_smiles)``:
            ``ms``: the list of ``model.decode`` results. The caller treats
                these as SMILES strings, or ``None`` when decoding fails.
            ``gt_smiles``: the SMILES of the input molecule.
    """
    dataset.training = False
    dataloader = DataLoader(
        Subset(dataset, [idx]),
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=JTNNCollator(dataset.vocab, False),
        drop_last=True,
        worker_init_fn=worker_init_fn,
    )
    # Prints the dataset size (kept: the article shows this output).
    print(len(dataset))

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # The Subset holds exactly one molecule and batch_size is 1, so there
        # is exactly one batch.
        batch = next(iter(dataloader))
        gt_smiles = batch['mol_trees'][0].smiles
        model.move_to_cuda(batch)
        _, tree_vec, mol_vec = model.encode(batch)

        tree_mean = model.T_mean(tree_vec)
        tree_log_var = -torch.abs(model.T_var(tree_vec))  # Following Mueller et al.
        mol_mean = model.G_mean(mol_vec)
        mol_log_var = -torch.abs(model.G_var(mol_vec))  # Following Mueller et al.

        # Reparameterization: std = exp(log_var / 2). True division is
        # required; floor division (`//`) on the float log-variance floors
        # it and yields the wrong standard deviation. randn_like draws eps
        # on the same device/dtype as the mean.
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * torch.randn_like(tree_mean)
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * torch.randn_like(mol_mean)

        coeffs = torch.linspace(-radius, radius, grid_size)
        # Two random standard-normal directions. tree_vec and mol_vec share
        # a shape, because both were built from same-shaped samples in the
        # original code.
        noise1 = torch.randn_like(tree_vec)
        noise2 = torch.randn_like(tree_vec)

        ms = []
        for a in coeffs:
            for b in coeffs:
                noise = noise1 * a + noise2 * b
                # The second argument is the molecule-graph latent. The
                # original passed tree_vec twice, discarding mol_vec.
                # NOTE(review): assumes decode(tree_vec, mol_vec) order.
                ms.append(model.decode(tree_vec + noise, mol_vec + noise))
    return ms, gt_smiles
获取给定分子
# Decode the neighbourhood of dataset molecule #3 and keep its SMILES.
ms, smiles = reconstruct(idx=3)
5000
获取邻居分子
# Neighbour molecules to draw:
# - drop failed decodes (None);
# - de-duplicate while keeping the grid order (set() iteration order for
#   strings depends on the hash seed and can change between runs);
# - drop SMILES that RDKit cannot parse (MolFromSmiles returns None).
ms_draw = [
    mol
    for mol in (MolFromSmiles(s) for s in dict.fromkeys(ms) if s is not None)
    if mol is not None
]
smiles
'CCOC(=O)[C@@H]1CCCN(C(=O)c2nc(-c3ccc(C)cc3)n3c2CCCCC3)C1'
绘制给定分子
# Render the query molecule as a one-cell grid image (notebook display).
img = Draw.MolsToGridImage(
    [MolFromSmiles(smiles)],
    molsPerRow=5,
    subImgSize=(250, 150),
)
img
绘制邻居分子
# Render the decoded neighbour molecules, five per row (notebook display).
img = Draw.MolsToGridImage(
    ms_draw,
    molsPerRow=5,
    subImgSize=(250, 150),
)
img
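参考资料

1. Wengong Jin, Regina Barzilay, Tommi Jaakkola. Junction Tree Variational Autoencoder for Molecular Graph Generation. ICML 2018. arXiv:1802.04364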
作者&编辑丨王建民
审稿 | 李牧非
研究方向丨药物设计、生物医药大数据