JTNN
JTNN: Junction Tree Variational Autoencoder for Molecular Graph Generation
JTNN uses the junction tree algorithm to decompose a molecular graph into a tree. The model then encodes the tree and the graph into two separate latent vectors, z_G and z_T.

JTNN is an autoencoder model designed to learn latent representations of molecular graphs. These representations can be used for downstream tasks such as property prediction or molecule optimization.
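Before the full walkthrough, the short sketch below illustrates the reparameterization step that is applied to each of the two latent channels further down. It is only a toy illustration: the 28-dimensional z_T and z_G vectors follow from the latent_size = 56 used by the pretrained model below, split evenly between the tree and graph channels.

import torch

def reparameterize(mean, log_var):
    # Standard VAE reparameterization: z = mean + exp(log_var / 2) * epsilon,
    # with epsilon drawn from a standard normal distribution.
    epsilon = torch.randn_like(mean)
    return mean + torch.exp(log_var / 2) * epsilon

# Toy illustration: JTNN_ZINC uses latent_size = 56, split evenly into a
# tree vector z_T and a graph vector z_G of 28 dimensions each.
z_T = reparameterize(torch.zeros(1, 28), torch.zeros(1, 28))
z_G = reparameterize(torch.zeros(1, 28), torch.zeros(1, 28))
print(z_T.shape, z_G.shape)  # torch.Size([1, 28]) torch.Size([1, 28])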
Visualizing the neighbor molecules of a given molecule with JTNN
import torch
from torch.utils.data import DataLoader, Subset

import argparse
import rdkit

from dgl import model_zoo
from dgl.model_zoo.chem.jtnn import JTNNDataset, cuda, JTNNCollator

from rdkit.Chem import MolFromSmiles, MolToSmiles
from rdkit.Chem import Draw

Batching data with DataLoader
# Silence RDKit warnings in each DataLoader worker process
def worker_init_fn(id_):
    lg = rdkit.RDLogger.logger()
    lg.setLevel(rdkit.RDLogger.CRITICAL)
worker_init_fn(None)
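For reference, RDKit also provides a one-line way to silence this logging in the current process. This is an optional alternative, not part of the original script:

from rdkit import RDLogger

# Disable all RDKit application-level log messages.
RDLogger.DisableLog('rdApp.*')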
dataset = JTNNDataset(data="test", vocab="vocab", training=False)
vocab_file = dataset.vocab_file

hidden_size = 450
latent_size = 56
depth = 3

model = model_zoo.chem.load_pretrained("JTNN_ZINC")
model = model.cuda()

print("Model #Params: %dK" %
      (sum([x.nelement() for x in model.parameters()]) / 1000,))
MAX_EPOCH = 100
PRINT_ITER = 20

# Add noise in the latent space and decode neighbors of molecule `idx`
def reconstruct(idx):
    dataset.training = False
    dataloader = DataLoader(
        Subset(dataset, [idx]),
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=JTNNCollator(dataset.vocab, False),
        drop_last=True,
        worker_init_fn=worker_init_fn)

    # Just an example of molecule decoding; in reality you may want to sample
    # tree and molecule vectors.
    acc = 0.0
    tot = 0
    print(len(dataset))
    for it, batch in enumerate(dataloader):
        gt_smiles = batch['mol_trees'][0].smiles
        # print(gt_smiles)
        model.move_to_cuda(batch)
        _, tree_vec, mol_vec = model.encode(batch)
        tree_mean = model.T_mean(tree_vec)
        # Following Mueller et al.
        tree_log_var = -torch.abs(model.T_var(tree_vec))
        mol_mean = model.G_mean(mol_vec)
        # Following Mueller et al.
        mol_log_var = -torch.abs(model.G_var(mol_vec))

        # Reparameterization: z = mean + exp(log_var / 2) * epsilon
        epsilon = torch.randn(1, model.latent_size // 2).cuda()
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
        epsilon = torch.randn(1, model.latent_size // 2).cuda()
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

        # Decode a 5 x 5 grid of latent points around the encoded molecule,
        # spanned by two random noise directions.
        xy_range = torch.linspace(-2, 2, 5)
        mean_noise, var_noise = torch.zeros_like(tree_vec), torch.ones_like(tree_vec)
        noise1 = torch.normal(mean_noise, var_noise)
        noise2 = torch.normal(mean_noise, var_noise)
        ms = []
        for i in range(5):
            for j in range(5):
                noise = noise1 * xy_range[i] + noise2 * xy_range[j]
                s = model.decode(tree_vec + noise, mol_vec + noise)
                ms += [s]
        return ms, gt_smiles

Get the given molecule
ms, smiles = reconstruct(3)
# prints 5000 (the size of the test set)
Get the neighbor molecules
ms_draw = [MolFromSmiles(s) for s in set(ms) if s is not None]
smiles
# 'CCOC(=O)[C@@H]1CCCN(C(=O)c2nc(-c3ccc(C)cc3)n3c2CCCCC3)C1'
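Since different decoder outputs can be distinct SMILES strings for the same molecule, an optional refinement (an addition on top of the original code, using the already-imported MolToSmiles) is to deduplicate by canonical SMILES before drawing:

# Parse each decoded SMILES, then keep one molecule per canonical SMILES.
mols = [MolFromSmiles(s) for s in ms if s is not None]
ms_draw = list({MolToSmiles(m): m for m in mols if m is not None}.values())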
Draw the given molecule
img = Draw.MolsToGridImage([MolFromSmiles(smiles)], molsPerRow=5, subImgSize=(250, 150))
img

Draw the neighbor molecules
img = Draw.MolsToGridImage(ms_draw, molsPerRow=5, subImgSize=(250, 150))
img
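Outside a notebook, where img is not rendered inline, the grid can be written to disk. This assumes the default behavior in which MolsToGridImage returns a PIL image; the filename is just an example:

# Save the neighbor grid to a PNG file for use outside the notebook.
img.save("jtnn_neighbors.png")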
作者&编辑丨王建民
审稿 | 李牧非
研究方向丨药物设计、生物医药大数据