前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >如何基于Paddle快速训练一个98%准确率的抑郁文本预测模型?

如何基于Paddle快速训练一个98%准确率的抑郁文本预测模型?

作者头像
abs_zero
发布2020-11-26 15:47:25
9680
发布2020-11-26 15:47:25
举报
文章被收录于专栏:AI派

Paddle是一个比较高级的深度学习开发框架,其内置了许多方便的计算单元可供使用。

本文将讲解如何使用paddle训练、测试、推断自己的数据。

1.准备

开始之前,你要确保Python和pip已经成功安装在电脑上噢。

Windows环境下打开Cmd(开始—运行—CMD),苹果系统环境下请打开Terminal(command+空格输入Terminal),准备开始输入命令安装依赖。

然后,我们需要安装百度的paddlepaddle, 进入他们的官方网站就有详细的指引: https://www.paddlepaddle.org.cn/install/quick

根据你自己的情况选择这些选项,最后一个CUDA版本,由于本实验不需要训练数据,也不需要太大的计算量,所以直接选择CPU版本即可。选择完毕,下方会出现安装指引,不得不说,Paddlepaddle这些方面做的还是比较贴心的(就是名字起的不好)

要注意,如果你的Python3环境变量里的程序名称是Python,记得将python3 xxx 语句改为Python xxx 如下进行安装:

代码语言:javascript
复制
python -m pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple

然后你还需要安装paddlehub:

代码语言:javascript
复制
pip install -i https://mirror.baidu.com/pypi/simple paddlehub

2. 数据预处理

这次实验,我使用了8000条走饭下面的评论和8000条其他微博的正常评论作为训练集,两个分类分别使用1000条数据作为测试集。

2.1 去重去脏

在这一步,我们需要先去除重复数据,并使用正则表达式@.* 和 ^@.*\n 去除微博@的脏数据。如果你是使用Vscode的,可以使用sort lines插件去除重复数据:

如果不是Vscode,请用Python写一个脚本,遍历文件,将每一行放入集合中进行去重。比较简单,这里不赘述啦。

正则表达式去除脏数据,我这里数据量比较少,直接编辑器解决了:

2.2 分词

首先,需要对我们的文本数据进行分词,这里我们采用结巴分词的形式进行:

然后需要在分词的结果后面使用\t隔开加入标签,我这里是将有抑郁倾向的句子标为0,将正常的句子标为1. 此外,还需要将所有词语保存起来形成词典文件,每个词为一行。

并分别将训练集和测试集保存为 train.tsv 和 dev.tsv, 词典文件命名为word_dict.txt, 方便用于后续的训练。

3.训练

下载完Paddle模型源代码后,进入 models/PaddleNLP/sentiment_classification文件夹下,这里是情感文本分类的源代码部分。

在开始训练前,你需要做以下工作:

1. 将train.tsv、dev.tsv及word_dict.txt放入senta_data文件夹.

2. 设置senta_config.json的模型类型,我这里使用的是gru_net:

3. 修改run.sh相关的设置:

如果你的paddle是CPU版本的,请把use_cuda改为false。此外还有一个save_steps要修改,代表每训练多少次保存一次模型,还可以修改一下训练代数epoch,和 一次训练的样本数目 batch_size.

4. 如果你是windows系统,还要新建一个save_models文件夹,然后在里面分别以你的每训练多少次保存一次的数字再新建文件夹。。没错,这可能是因为他们开发这个框架的时候是基于linux的,他们写的保存语句在linux下会自动生成文件夹,但是windows里不会。

现在可以开始训练了,由于训练启动脚本是shell脚本,因此我们要用powershell或git bash运行指令,Vscode中可以选择默认的终端,点击Select Default Shell后选择一个除cmd外的终端即可。

输入以下语句开始训练

代码语言:javascript
复制
$ sh run.sh train

4.测试

恭喜你走到了这一步,作为奖励,这一步你只需要做两个操作。首先是将run.sh里的MODEL_PATH修改为你刚保存的模型文件夹:

我这里最后一次训练保存的文件夹是step_1200,因此填入step_1200,要依据自己的情况填入。然后一句命令就够了:

代码语言:javascript
复制
$ sh run.sh eval

可以看到我的模型准确率大概有98%,还是挺不错的。

5.预测

我们随意各取10条抑郁言论和普通言论,命名为test.txt存入senta_data文件夹中,输入以下命令进行预测:

代码语言:javascript
复制
$ sh run.sh test

这二十条句子如下,前十条是抑郁言论,后十条是普通言论:

代码语言:javascript
复制
好崩溃每天都是折磨真的生不如死
姐姐   我可以去找你吗
内心阴暗至极……
大家今晚都是因为什么没睡
既然儿子那么好     那就别生下我啊     生下我又把我扔下     让我自生自灭     这算什么
走饭小姐姐怎么办我该怎么办每天都心酸心如刀绞每天都有想要死掉的念头我不想那么痛苦了
你凭什么那么轻松就说出这种话
一闭上眼睛脑子里浮现的就是他的脸和他的各种点点滴滴好难受睡不着啊好难受为什么吃了这么多东西还是不快乐呢
以前我看到那些有手有脚的人在乞讨我都看不起他们   我觉得他们有手有脚的不应该乞讨他们完全可以凭自己的双手挣钱   但是现在我有手有脚我也想去人多的地方乞讨…我不想努力了…
熬过来吧求求你了好吗
是在说我们合肥吗?
这歌可以啊
用一个更坏的消息掩盖这一个坏消息
请尊重他人隐私这种行为必须严惩不贷
这个要转发
??保佑咱们国家各个省千万别再有出事的也别瞒报大家一定要好好的坚持到最后加油
我在家比在学校有钱   在家吃饭零食水果奶都是我妈天天给我买   每天各种水果   还可以压榨我弟跑腿   买衣服也是   水乳也是   除了化妆品反正现在也用不上   比学校的日子过得好多了
广西好看的是柳州的满城紫荆花
加油一起共同度过这次难关我们可以
平安平安老天保佑

得到结果如下:

Final test result: 0 0.999999 0.000001 0 0.994013 0.005987 0 0.997636 0.002364 0 0.999975 0.000025 0 1.000000 0.000000 0 1.000000 0.000000 0 0.999757 0.000243 0 0.999706 0.000294 0 0.999995 0.000005 0 0.998472 0.001528 1 0.000051 0.999949 1 0.000230 0.999770 1 0.230227 0.769773 1 0.000000 1.000000 1 0.000809 0.999191 1 0.000001 0.999999 1 0.009213 0.990787 1 0.000003 0.999997 1 0.000363 0.999637 1 0.000000 1.000000

第一列是预测结果(0代表抑郁文本),第二列是预测为抑郁的可能性,第三列是预测为正常微博的可能性。可以看到,基本预测正确,而且根据这个分数值,我们还可以将文本的抑郁程度分为:轻度、中度、重度,如果是重度抑郁,应当加以干预,因为其很可能会发展成自杀倾向。

我们可以根据这个模型,构建一个自杀预测监控系统,一旦发现重度抑郁的文本迹象,即可实行干预,不过这不是我们能一下子做到的事情,需要随着时间推移慢慢改进这个识别算法,并和相关机构联动实行干预。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-11-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI派 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1.准备
  • 2. 数据预处理
    • 2.1 去重去脏
      • 2.2 分词
      • 3.训练
      • 4.测试
      • 5.预测
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档