前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >神经网络高维互信息计算Python实现(MINE)

神经网络高维互信息计算Python实现(MINE)

作者头像
里克贝斯
发布2021-05-21 10:16:49
2K2
发布2021-05-21 10:16:49
举报
文章被收录于专栏:图灵技术域

论文

Belghazi, Mohamed Ishmael, et al. “Mutual information neural estimation.” International Conference on Machine Learning. 2018.

利用神经网络的梯度下降法可以实现快速高维连续随机变量之间互信息的估计,上述论文提出了Mutual Information Neural Estimator (MINE)。NN在维度和样本量上都是线性可伸缩的,MI的计算可以通过反向传播进行训练。


核心


Python实现

现有github上的代码无法计算和估计高维随机变量,只能计算一维随机变量,下面的代码给出的修改方案能够计算真实和估计高维随机变量的真实互信息。

其中,为了计算理论的真实互信息,我们不直接暴力求解矩阵(耗时,这也是为什么要有MINE的原因),我们采用给定生成随机变量的参数计算理论互信息。

SIGNAL_NOISE = 0.2 SIGNAL_POWER = 3

12

SIGNAL_NOISE = 0.2SIGNAL_POWER = 3

完整代码基于pytorch

Python

代码语言:javascript
复制
# Name: MINE_simple
# Author: Reacubeth
# Time: 2020/12/15 18:49
# Mail: noverfitting@gmail.com
# Site: www.omegaxyz.com
# *_*coding:utf-8 *_*

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt


SIGNAL_NOISE = 0.2
SIGNAL_POWER = 3

data_dim = 3
num_instances = 20000


def gen_x(num, dim):
    return np.random.normal(0., np.sqrt(SIGNAL_POWER), [num, dim])


def gen_y(x, num, dim):
    return x + np.random.normal(0., np.sqrt(SIGNAL_NOISE), [num, dim])


def true_mi(power, noise, dim):
    return dim * 0.5 * np.log2(1 + power/noise)


mi = true_mi(SIGNAL_POWER, SIGNAL_NOISE, data_dim)
print('True MI:', mi)


hidden_size = 10
n_epoch = 500


class MINE(nn.Module):
    def __init__(self, hidden_size=10):
        super(MINE, self).__init__()
        self.layers = nn.Sequential(nn.Linear(2 * data_dim, hidden_size),
                                    nn.ReLU(),
                                    nn.Linear(hidden_size, 1))

    def forward(self, x, y):
        batch_size = x.size(0)
        tiled_x = torch.cat([x, x, ], dim=0)
        idx = torch.randperm(batch_size)

        shuffled_y = y[idx]
        concat_y = torch.cat([y, shuffled_y], dim=0)
        inputs = torch.cat([tiled_x, concat_y], dim=1)
        logits = self.layers(inputs)

        pred_xy = logits[:batch_size]
        pred_x_y = logits[batch_size:]
        loss = - np.log2(np.exp(1)) * (torch.mean(pred_xy) - torch.log(torch.mean(torch.exp(pred_x_y))))
        # compute loss, you'd better scale exp to bit
        return loss


model = MINE(hidden_size)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
plot_loss = []
all_mi = []
for epoch in tqdm(range(n_epoch)):
    x_sample = gen_x(num_instances, data_dim)
    y_sample = gen_y(x_sample, num_instances, data_dim)

    x_sample = torch.from_numpy(x_sample).float()
    y_sample = torch.from_numpy(y_sample).float()

    loss = model(x_sample, y_sample)

    model.zero_grad()
    loss.backward()
    optimizer.step()
    all_mi.append(-loss.item())


fig, ax = plt.subplots()
ax.plot(range(len(all_mi)), all_mi, label='MINE Estimate')
ax.plot([0, len(all_mi)], [mi, mi], label='True Mutual Information')
ax.set_xlabel('training steps')
ax.legend(loc='best')
plt.show()

结果

变量维度为1

变量维度为3

需要指出的是在计算最终的互信息时需要将基数e转为基数2。如果只是求得一个比较值,在真实使用的过程中可以省略。


参考

https://github.com/mzgubic/MINE

互信息公式及概述 列向量互信息计算通用MATLAB代码

相关文章

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2020-12-18,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 论文
  • 核心
  • Python实现
  • 参考
    • 相关文章
    相关产品与服务
    灰盒安全测试
    腾讯知识图谱(Tencent Knowledge Graph,TKG)是一个集成图数据库、图计算引擎和图可视化分析的一站式平台。支持抽取和融合异构数据,支持千亿级节点关系的存储和计算,支持规则匹配、机器学习、图嵌入等图数据挖掘算法,拥有丰富的图数据渲染和展现的可视化方案。
    领券
    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档