前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >LightGCN模型部分代码解读

LightGCN模型部分代码解读

作者头像
秋枫学习笔记
发布2022-09-19 11:24:55
9750
发布2022-09-19 11:24:55
举报
文章被收录于专栏:秋枫学习笔记

代码地址:

pytorch版:https://github.com/gusye1234/LightGCN-PyTorch

tensorflow版:https://github.com/kuandeng/LightGCN

本文对LightGCN模型部分的代码进行了解读,对相应部分进行了简单的注释帮助大家理解。笔者第一次尝试代码阅读分享,有什么不足之处或者建议可以给我留言哦,感谢。

Dropout

在图上实施dropout,以一定概率忽略一部分边

代码语言:javascript
复制
def __dropout_x(self, x, keep_prob):
        # 获取self.Graph中的大小,下标和值,Graph采用稀疏矩阵的表示方法SparseTensor
        size = x.size()
        index = x.indices().t()
        values = x.values()
        # 通过rand得到len(values)数量的随机数,加上keep_prob
        random_index = torch.rand(len(values)) + keep_prob
        # 通过对这些数字取int使得小于1的为0,在通过bool()将0->false,大于等于1的取True
        random_index = random_index.int().bool()
        # 利用上面得到的True,False数组选取下标,从而dropout了为False的下标
        index = index[random_index]
        # 由于dropout在训练和测试过程中的不一致,所以需要除以p
        values = values[random_index]/keep_prob
        # 得到新的graph
        g = torch.sparse.FloatTensor(index.t(), values, size)
        return g
    
    def __dropout(self, keep_prob):
        if self.A_split:
            graph = []
            for g in self.Graph:
                graph.append(self.__dropout_x(g, keep_prob))
        else:
            graph = self.__dropout_x(self.Graph, keep_prob)
        return graph

消息传播

computer函数是LightGCN类中用于进行图信息传播的实现方法,整体上通过在整个图上进行矩阵计算得到所有用户和商品的embedding。

代码语言:javascript
复制
def computer(self):
        """
        propagate methods for lightGCN
        """       
        # 得到所有用户和所有商品的embedding
        users_emb = self.embedding_user.weight
        items_emb = self.embedding_item.weight
        all_emb = torch.cat([users_emb, items_emb])
        # torch.split(all_emb , [self.num_users, self.num_items])
        embs = [all_emb]
        # 判断是否需要dropout
        if self.config['dropout']:
            if self.training:
                print("droping")
                g_droped = self.__dropout(self.keep_prob)
            else:
                g_droped = self.Graph 
        else:
            g_droped = self.Graph 
        # 根据层数对图进行信息传播和聚合考虑n-hop
        # 通过稀疏矩阵乘法对Graph进行n_layers次的计算
        for layer in range(self.n_layers):
            if self.A_split:
                temp_emb = []
                for f in range(len(g_droped)):
                    temp_emb.append(torch.sparse.mm(g_droped[f], all_emb))
                side_emb = torch.cat(temp_emb, dim=0)
                all_emb = side_emb
            else:
                all_emb = torch.sparse.mm(g_droped, all_emb)
            embs.append(all_emb)
        embs = torch.stack(embs, dim=1)
        #print(embs.size())
        # 对每一层得到的输出求均值,以此将不同层的信息进行融合
        light_out = torch.mean(embs, dim=1)
        users, items = torch.split(light_out, [self.num_users, self.num_items])
        return users, items

损失构建

在computer函数计算得到所有用户和商品经过消息传播后的embedding之后,getEmbedding根据当前用户和商品查询出需要用到的embedding以及当前用户和商品的原始embedding,即未经GCN的embedding。

传播后的embedding用于计算bpr损失,原始embedding用于计算L2正则项。

代码语言:javascript
复制
def getEmbedding(self, users, pos_items, neg_items):
        # 得到需要计算相似度的用户和商品的embedding
        all_users, all_items = self.computer()
        users_emb = all_users[users]
        pos_emb = all_items[pos_items]
        neg_emb = all_items[neg_items]
        # 没经过传播的embedding,用于后续正则项计算
        users_emb_ego = self.embedding_user(users)
        pos_emb_ego = self.embedding_item(pos_items)
        neg_emb_ego = self.embedding_item(neg_items)
        return users_emb, pos_emb, neg_emb, users_emb_ego, pos_emb_ego, neg_emb_ego
    
    def bpr_loss(self, users, pos, neg):
        (users_emb, pos_emb, neg_emb, 
        userEmb0, posEmb0, negEmb0) = self.getEmbedding(users.long(), pos.long(), neg.long())
        # 这个损失计算的是LightGCN论文中损失函数中的正则项,即做了一个L2正则
        reg_loss = (1/2)*(userEmb0.norm(2).pow(2) + 
                         posEmb0.norm(2).pow(2) +
                         negEmb0.norm(2).pow(2))/float(len(users))
        # 通过乘法计算用户和商品的相似度
        pos_scores = torch.mul(users_emb, pos_emb)
        pos_scores = torch.sum(pos_scores, dim=1)
        neg_scores = torch.mul(users_emb, neg_emb)
        neg_scores = torch.sum(neg_scores, dim=1)
        # pair-wise的排序损失
        loss = torch.mean(torch.nn.functional.softplus(neg_scores - pos_scores))
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-12-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 秋枫学习笔记 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • pytorch版:https://github.com/gusye1234/LightGCN-PyTorch
  • 本文对LightGCN模型部分的代码进行了解读,对相应部分进行了简单的注释帮助大家理解。笔者第一次尝试代码阅读分享,有什么不足之处或者建议可以给我留言哦,感谢。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档