前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >一份半监督学习的指南-伪标签学习

一份半监督学习的指南-伪标签学习

作者头像
致Great
发布2020-12-16 17:56:14
1.3K0
发布2020-12-16 17:56:14
举报
文章被收录于专栏:程序生活程序生活

1 引言

在ML中,有3种机器学习方法-监督学习、无监督学习和强化学习技术。 我们所知道的监督学习是指数据带有标签的情况, 无监督学习是仅存在数据而没有标签的情况,强化学习算法的思路非常简单,以游戏为例,如果在游戏中采取某种策略可以取得较高的得分,那么就进一步“强化”这种策略,以期继续取得较好的结果。

想象一下这样一种情况,在训练中,标记数据的数量更少,而未标记数据的数量更多。 一种称为半监督学习( [Semi-Supervised Learning],SSL)的新技术,它是监督学习和非监督学习的混合体。 顾名思义,半监督学习中同时存在一组标记的训练数据和另一组未标记的训练数据。 我们可以将这种情况想像成Google图片或Facebook通过其面孔(数据)识别出图片中的人物并根据该人物先前存储的图像生成建议名称(标签)的情况。

在本文中,我们将讨论如何使用半监督学习技术生成伪标签。

2 Pseudo-Labelling 伪标签

伪标签是使用标记的数据模型预测未标记数据并进行标记的过程。 首先,模型已经训练了包含标签的数据集,该模型用于为未标记的数据集生成伪标签。 最后,将数据集和标签(原始标签和伪标签)组合在一起以进行最终模型训练。 之所以称为(意味着虚幻),是因为它们可能是真实标签,也可能不是真实标签,并且是通过我们基于类似的数据模型生成的标签。

该方法的主旨思想其实很简单。首先,在标签数据上训练模型,然后使用经过训练的模型来预测无标签数据的标签,从而创建伪标签。此外,将标签数据和新生成的伪标签数据结合起来作为新的训练数据。

3 Python 实现

在这个例子中,我们使用了sklearn中的breast cancer数据集。我们知道整个已经包含了标签,但我们要修改它,将数据分成两部分,一部分有标签,另一部分没有标签。我们将从经过训练的带标签数据模型中为未带标签的数据生成我们自己的标签,然后最后使用两者合并的数据集来训练最终的模型。

3.1 数据集

Breast cancer dataset是预测肿瘤是良性(B)还是恶性(M)的分类问题。前两列为1)id和2)diagnosis(标签):

代码语言:javascript
复制
a)radius_mean(从中心到外围点的距离的平均值)
b)texture_mean(灰度值的标准偏差)
c)perimeter_mean(周长)
d)area_mean(面积)
e)smoothness_mean(半径长度的局部变化)
f)compactness_mean(周长^ 2 /面积– 1.0)
g)concavity_mean(轮廓凹部的严重程度)
h) concave points_mean(轮廓的凹面部分的数量)

3.2 导入包

代码语言:javascript
复制
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

3.3 加载数据集

代码语言:javascript
复制
X,y = load_breast_cancer(True)
X.shape
代码语言:javascript
复制
(569, 30)

3.4 分割数据集

代码语言:javascript
复制
x_train,x_test,y_train,_ = train_test_split(X,y,test_size=.6)
x_train.shape,y_train.shape,x_test.shape
代码语言:javascript
复制
((227, 30), (227,), (342, 30)

3.5 训练模型

代码语言:javascript
复制
model1 = RandomForestClassifier()
history = model1.fit(x_train,y_train)
history
代码语言:javascript
复制
RandomForestRegressor(bootstrap=True, ccp_alpha=0.0, criterion=’mse’,
max_depth=None, max_features=’auto’, max_leaf_nodes=None,
max_samples=None, min_impurity_decrease=0.0,
min_impurity_split=None, min_samples_leaf=1,
min_samples_split=2, min_weight_fraction_leaf=0.0,
n_estimators=100, n_jobs=None, oob_score=False,
random_state=None, verbose=0, warm_start=False)

3.6 评分

代码语言:javascript
复制
model1.score(x_train,y_train)
代码语言:javascript
复制
1.0

3.7 预测

代码语言:javascript
复制
y_new = model1.predict(x_test)
y_new.shape
代码语言:javascript
复制
(342,)

合并数据集

代码语言:javascript
复制
final_X = np.concatenate((x_train,x_test))
final_X.shape
代码语言:javascript
复制
(569, 30)

合并原始标签与伪标签

代码语言:javascript
复制
final_Y = np.concatenate((y_train,y_test))
final_Y.shape
代码语言:javascript
复制
(569,)

基于合并的数据集训练最终模型

代码语言:javascript
复制
model2 = RandomForestRegressor()
model2.fit(final_X,final_Y)
model2.score(final_X,final_Y)
代码语言:javascript
复制
1.0

4 结论

伪标签的实现到此为止,大家可以根据自己的想法去比赛中尝试吧。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1 引言
  • 2 Pseudo-Labelling 伪标签
  • 3 Python 实现
    • 3.1 数据集
      • 3.2 导入包
        • 3.3 加载数据集
          • 3.4 分割数据集
            • 3.5 训练模型
              • 3.6 评分
                • 3.7 预测
                  • 4 结论
                  领券
                  问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档