前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >对比学习Python实现

对比学习Python实现

作者头像
里克贝斯
发布2022-01-15 11:52:44
9350
发布2022-01-15 11:52:44
举报
文章被收录于专栏:图灵技术域图灵技术域

对比学习是一种通过对比正反两个例子来学习表征的自监督学习方法。对于自监督对比学习,下一个等式是对比损失:

\[ \mathcal{L}_{i,j} = - \log \frac{exp(\textbf{z}_i \cdot \textbf{z}_j / \tau)}{\sum_{k=1,k\neq i}^{2N}exp(\textbf{z}_i \cdot \textbf{z}_k / \tau)} \]
\[ \mathcal{L}_{i,j} = - \log \frac{exp(\textbf{z}_i \cdot \textbf{z}_j / \tau)}{\sum_{k=1,k\neq i}^{2N}exp(\textbf{z}_i \cdot \textbf{z}_k / \tau)} \]

在很多情况下,对比学习只需要对每一个样本生成一个正样本,同一个batch内的其他样本作为负样本,实现如下:

代码语言:txt
复制
def contrastive_loss(x, x_aug, T):
    """
    :param x: the hidden vectors of original data
    :param x_aug: the positive vector of the auged data
    :param T: temperature
    :return: loss
    """
    batch_size, _ = x.size()
    x_abs = x.norm(dim=1)
    x_aug_abs = x_aug.norm(dim=1)
 
    sim_matrix = torch.einsum('ik,jk->ij', x, x_aug) / torch.einsum('i,j->ij', x_abs, x_aug_abs)
    sim_matrix = torch.exp(sim_matrix / T)
    pos_sim = sim_matrix[range(batch_size), range(batch_size)]
    loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
    loss = - torch.log(loss).mean()
    return loss

如果要用生成的负样本进行对比,代码如下:

代码语言:txt
复制
def info_nce_loss(self, features):
    labels = torch.cat([torch.arange(self.args.batch_size) for i in range(self.args.n_views)], dim=0)
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
    labels = labels.to(self.args.device)
 
    features = F.normalize(features, dim=1)
 
    similarity_matrix = torch.matmul(features, features.T)
    # assert similarity_matrix.shape == (
    #     self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
    # assert similarity_matrix.shape == labels.shape
 
    # discard the main diagonal from both: labels and similarities matrix
    mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
    # assert similarity_matrix.shape == labels.shape
 
    # select and combine multiple positives
    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
 
    # select only the negatives the negatives
    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
 
    logits = torch.cat([positives, negatives], dim=1)
    labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)
 
    logits = logits / self.args.temperature
    return logits, labels
 
self.criterion = torch.nn.CrossEntropyLoss()
loss = self.criterion(logits, labels)
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2022-01-13,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档