专栏首页润风拂过存甘霖更少标注的机器学习方法——主动学习(python示例)

更少标注的机器学习方法——主动学习(python示例)

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。

本文链接:https://blog.csdn.net/blmoistawinde/article/details/84994719

准备工作

假设我们正要完成手写数字识别的任务。我们可以使用著名的mnist数据集来训练这样的机器学习模型。数字示例如下:

总共有1797个数字,每个数字使用一个64维的向量表示

效果检验

先来看看使用完全数据集训练的模型能够达到什么样的效果(这里暂不区分训练测试集)

from sklearn.linear_model import LogisticRegression
clf = LogisticRegression()
clf.fit(x,y)
clf.score(x,y)
0.993322203672788

然而,要想得到完全数据,我们需要标注1797个数字,这是比较吃力的。这还只是一个很小的玩具数据集,对于更大的数据集,数量级数以百万计,我们又怎能标的过来?

那么,我们能不能只标注一小部分数据,比如说3%(50个数字左右),让它也达到不错的效果?看看效果如何。

clf2 = LogisticRegression()
chosen_ids = np.random.choice(range(len(x)),50,replace=False)
clf2.fit(x[chosen_ids],y[chosen_ids])
clf2.score(x,y)
0.7918753478018921

效果下降了很多,这还只是一个简单的任务,对于更难的任务模型的表现还会更差,那么我们就没有办法节省劳动力(偷懒)了吗?

非也,我们还有主动学习

主动学习

何谓主动学习?这里我采用一种通俗的讲法:

想象你面对百万大军,要想打败他们未必需要将其全部剿灭,有时只需要斩其上将首级即可。

主动学习做的,就是帮助我们找到那个“上将”,解决重点问题,达到事半功倍的效果。看下面的图:

左图中红绿代表两种数据。现在我们只能标注其中有限个数据来训练分类器。中间的图显示的就是随机标注的样本和得到的分界线,准确率大约为70%。而右图就是主动学习方法找到的标注点,因为这些点几乎构成了完美分界线的边界,所以使用与中图同样的样本数,它能够取得90%左右的准确率!

那么我们怎么找到这些关键的点呢?让我们从上例来考虑,上面找到的标注点都在中间区域,实际上都是相对比较容易混淆的、难以确定类别的样本。我们的目标就是找到这些样本。分类器的预测概率就可以帮助我们衡量这些特性,于是就有了下面的寻找策略:

  • RS策略:随机选择,作为对照。
  • LC策略:寻找分类器最没有信心的预测样本,即预测的最可能类别的概率也很低的样本。
  • BT策略:寻找分类器最“左右为难”的预测样本,即预测的最可能的两个类别的概率很接近的样本。

下面用代码实验和验证这些策略的效果:

我们进行10批5个样本的标注,对于不同策略选出的样本,使用直接抽取已知标签的方法来模拟手动标注。来看看随着标注的进行,模型的表现如何改善,以及其最终的效果。

batch_size = 5
rounds = 10
def RS(proba, batch_size):
    return np.random.choice(range(proba.shape[0]),batch_size,replace=False)

def LC(proba, batch_size):
    return np.argsort(np.max(proba,axis=1))[:batch_size]

def BT(proba, batch_size):
    sorted_proba = np.sort(proba,axis=1)
    return np.argsort(sorted_proba[:,-1]-sorted_proba[:,-2])[:batch_size]
strategies = {"RS":RS,"LC":LC,"BT":BT}
results = {"RS":[],"LC":[],"BT":[]}
for type0 in strategies:
    clf = LogisticRegression()
    anno_batch = RS(x,batch_size)    # 第一批标注样本只能随机选取
    x_train = x[anno_batch]
    y_train = y[anno_batch]
    for i in range(rounds-1):
        clf.fit(x_train,y_train)
        prec = clf.score(x,y)
        results[type0].append(prec)
        proba = clf.predict_proba(x)
        stategy0 = strategies[type0]      # 后面采用不同策略
        anno_batch = stategy0(proba,batch_size)
        x_train = np.concatenate([x_train,x[anno_batch]])
        y_train = np.concatenate([y_train,y[anno_batch]])
        
    prec = clf.score(x,y)
    results[type0].append(prec)

以上只是一次随机实验的结果,受每一批的具体样本影响而有一些波动。不过进行多次实验也可以看到,两种主动学习方法利用同样多的样本,却能够达到接近90%的准确率。确实比RS更胜一筹,我们可以使用这个技术来降低达到一定准确度所需的标注量。妈妈再也不用担心我手动标注数据的辛苦了!

体验手动标注

本文的完全代码,和我也写了一段真正手动标注数据的代码可以在这里找到。

有兴趣的更可以点开下面的链接,利用binder的黑科技,马上在线使用notebook亲身尝试一下,用可以接受的时间量标注,打造出一个表现良好的模型: https://mybinder.org/v2/gist/blmoistawinde/d48c41ebbc00fffa05ae498632187bd6/master?filepath=active_learning1.ipynb

参考资料: https://blog.csdn.net/qrlhl/article/details/48828589 https://blog.csdn.net/Houchaoqun_XMU/article/details/80146710

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 用HarvestText自动识别实体及人物别名,用于实体链接分析

    处理1需要文本清洗,处理2需要挖掘出人物别名,HarvestText提供了能够方便处理这些问题的接口:

    blmoistawinde
  • NeuralCoref: python的共指消解工具,向代词指代的问题进军!

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。 ...

    blmoistawinde
  • 深度学习哪家强?用数据来一较高下吧

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。 ...

    blmoistawinde
  • 【可视化】雄安新区潜力有多大?大数据告诉你

    雄安新区是继深圳经济特区和上海浦东新区之后又一具有全国意义的新区,地处北京、天津、保定腹地。范围涉及河北省雄县、容城、安新3县及周边部分区域。起步区面积约100...

    钱塘数据
  • 来自学渣的最实用面经

    前言: 1.你只需要稍微努力一点,就可以比大部分人优秀。 2.如果不能在技术上打败别人,就去从其他方面让自己出彩。 背景: 学校专业:某西北地区文科211学校...

    牛客网
  • PHP数据类型转换

            Java,c,c++等强类型语言必须先声明数据类型,java和c声明一个整数型(int a=100;)

    十月梦想
  • 你是否经常忘记 Linux 计划任务 Crontab 复杂的语法格式呢,用上这款神器后再也不用担心了!

    Linux / Unix 系统里有一个很方便的程序「例行性计划任务」(Crontab),接触过的朋友一定不陌生。Crontab 主要是让系统去执行一些固定时间要...

    iMike
  • 遇到这些 TS 问题你会头晕么?

    相信很多读者看到 let value: Fonum = 12; 这一行,TS 编译器并未提示任何错误会感到惊讶。很明显数字 12 并不是 Fonum 枚举的成员...

    semlinker
  • 工作--用户登录注册相关设计

    最近做一个网站,网站需要用户登录注册,自然也就需要一套高扩展性的用户模块设计,该篇文章记录笔者遇到问题的解决方案,希望对你有帮助。

    屈定
  • Java设计模式-策略模式

    策略模式: 定义一系列的算法, 将其一个个封装起来, 并使它们可相互替换, 使得算法可独立于使用它的客户而变化. ? (图片来源: 设计模式: 可复用面向对象...

    奋斗蒙

扫码关注云+社区

领取腾讯云代金券