前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >博客 | 一次LDA的项目实战(附GibbsLDA++代码解读)

博客 | 一次LDA的项目实战(附GibbsLDA++代码解读)

作者头像
AI研习社
发布2018-12-25 16:27:12
1.1K0
发布2018-12-25 16:27:12
举报
文章被收录于专栏:AI研习社AI研习社

本文原载于知乎专栏“AI的怎怎,歪歪不喜欢”AI研习社经授权转载发布。欢迎关注 邹佳敏 的知乎专栏及 AI研习社博客专栏(文末可识别社区名片直达)。

深度学习是一项目标函数的拟合技术,在绝大多数场景中,它要求实践者拥有一份可靠的标注数据,作为目标函数的采样,这恰恰是最难的部分。尤其是NLP领域,每个人的受教育水平和对语言的理解均有不同,一份可靠的标注数据更是难上加难。

因此,在缺乏标注数据,无法使用深度学习,甚至是传统分类算法的前提上,似乎只能考虑无监督的聚类方法来达成业务目标。

曾经的我非常鄙视聚类算法,认为它不够稳定。每一次聚类结果的含义都会发生变化,需要人工确认语义信息,尤其是当目标类型数过多时,非常痛苦。但和标注X万篇语料数据,同时不确定标注是否可靠的情况相比,相信拥有完美数学逻辑的LDA,就成为了我的唯一选择。

事实证明,在某些条件下,LDA简直是NLP领域的聚类神器!

一, 项目背景:

财经类的“宏观”新闻分类:以市场流动性,政经制度和地缘政治为例。

条件1:要求的目标类型少,数量可控。

二,项目实施:

1,语料确定:放弃通用语料的尝试,将目光锁定在,门户网站财经栏目下的“宏观”页签。

条件2:干净纯粹的训练数据集,输入数据噪音小,便于在训练前,对K心中有数。

(比如:已知目标3类,K选定为45,人工观察后将45个topic_id映射至目标3类)

2,目标确定:结合业务背景,明确分类目标的业务含义。

2.1,市场流动性:市场货币投放,银行间利率升降,央行放水,钱荒等;

2.2,政经制度:国改,混改,土改,税改等政府发布的改革制度等;

2.3,地缘政治:军事,打仗等。

条件3:分类目标间的内涵独立。即类间耦合弱,类内耦合强的分类目标最优。

3,工具确定:分词工具加入业务长词或通用长词,保证对聚类结果的可识别,易区分。

条件4:业务词典或通用长词词典。构建通用长词词典的小技巧:将腾讯AILab的开源词嵌入模型的单词抠出,并筛选长词,效果不错。

三,GibbsLDA++代码解读

代码语言:javascript
复制
// 代码截取自:GibbsLDA++ from http://gibbslda.sourceforge.net/

// 核心代码解读:LDA算法代码
class model {
public:
    int M;       // 语料中的文章数
    int V;       // 语料中的单词数(去重)
    int K;       // LDA的主题数

    double alpha;  // 超参数1:文章m属于主题k的先验概率
    double beta;   // 超参数2:单词w属于主题k的先验概率
    int niters;   // LDA训练迭代次数

    double * p;   // 临时变量:每篇文章的每个单词,在每次采样时,分配到每个主题下的概率
    int ** z;     // size M x doc.size():语料中第m篇文章中,第n个单词,所属的主题id
    int ** nw;     // size V x K: 语料中第v个单词,属于第k个主题的单词计数(在整个语料中,每个单词在不同的文章出现)
    int ** nd;     // size M x K:语料中第m篇文章,属于第k个主题的单词计数(在一篇文章中,每个单词只属于一个主题)
    int * nwsum;   // size K:属于第k个主题的单词个数
    int * ndsum;   // size M:属于第m篇文章的单词个数

    // 隐层参数:M*V个参数 >> M*K+K*V个参数,降维的本质所在
    double ** theta;   // size M x K:文档-主题概率分布:语料中第m篇文章,属于第k个主题的概率
    double ** phi;     // size K x V:主题-单词概率分布:语料中第v个单词,属于第k个主题的概率public:
    // train初始化:加载输入语料,为每个单词随机选取一个主题id,并初始化z,nw,nd,nwsum和ndsum变量(统计计数的方式)和其他变量(置零)
    int init_est();

    // train核心逻辑:删除非核心代码,更清晰
    void estimate() {
        // 从第last_iter处,开始训练,兼容estc方法
        // 比如,目标迭代1000次,但在第100次后保存模型,后续可直接加载模型,从第101次开始训练)
        int last_iter = liter;
        // 迭代niters次。每次迭代,遍历全部语料(M篇文章,每篇文章length个单词)
        for (liter = last_iter + 1; liter <= niters + last_iter; liter++) {
            // 对第m篇文章的第n个单词,采样其所属的主题id,即z[m][n]
            for (int m = 0; m < M; m++) {
                for (int n = 0; n < ptrndata->docs[m]->length; n++) {
                    // 源码注释:LDA算法介绍中,通常使用z_i来代表z[m][n]
                    // (z_i = z[m][n]) sample from p(z_i|z_-i, w)

                    // !!! Gibbs采样的核心逻辑:为每篇文章的每个单词,迭代采样其属于的topic,即主题id
                    int topic = sampling(m, n);

                    // 更新z变量:LDA真正的模型输出
                    // 因为z变量可以将nd,nw,ndsum和nwsum都还原出来,而theta和phi又可以从nd,nw,ndsum和nwsum还原
                    z[m][n] = topic;
                }
            }
        }
        // 根据nd,ndsum和alpha,计算theta变量:文档-主题概率分布,无普适性,用于展示每篇文档的主题概率
        compute_theta();
        // 根据nw,nwsum和beta,计算phi变量:主题-单词概率分布,语料中每个单词所属的主题概率,有普适性,也可作为LDA模型输出
        compute_phi();
        // 保存模型:在GibbsLDA++代码中,最核心的是z变量,即*.tassign文件
        save_model(utils::generate_model_name(-1));
    }

    // Gibbs采样核心逻辑
    int sampling(int m, int n) {
        // remove z_i from the count variables
        int topic = z[m][n];
        int w = ptrndata->docs[m]->words[n];

        // 新的一轮采样前,自减上一次采样的计数
        nw[w][topic] -= 1;
        nd[m][topic] -= 1;
        nwsum[topic] -= 1;
        ndsum[m] -= 1;

        // 真正的采样逻辑
        double Vbeta = V * beta;
        double Kalpha = K * alpha;

        // 基于狄利克雷-多项分布的Gibbs采样,千言万语就化作这个简单的公式,数学真是神奇!
        // 建议参考资料:有先后顺序
        // 1,https://www.cnblogs.com/pinard/p/6831308.html
        // 2,LDA数学八卦.pdf
        for (int k = 0; k < K; k++) {
            // 神奇,神奇,神奇,神奇!
            p[k] = (nw[w][k] + beta) / (nwsum[k] + Vbeta) *
                   (nd[m][k] + alpha) / (ndsum[m] + Kalpha);
        }

        // 根据已计算出的p,随机挑出一个最优可能的主题id
        // 一种常见的方法,在word2vec的负采样中也有使用
        for (int k = 1; k < K; k++) {
            p[k] += p[k - 1];
        }
        // 通过画线段的方式,很容易理解:概率越大,所属的线段越长,越有可能被随机选中,但不绝对
        double u = ((double)random() / RAND_MAX) * p[K - 1];
        for (topic = 0; topic < K; topic++) {
            if (p[topic] > u) {
                break;
            }
        }

        // 基于最新一轮采样的结果,更新计数
        nw[w][topic] += 1;
        nd[m][topic] += 1;
        nwsum[topic] += 1;
        ndsum[m] += 1;

        return topic;
    }
};
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2018-12-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI研习社 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档