前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >数据不平衡之SMOTE算法

数据不平衡之SMOTE算法

作者头像
全栈程序员站长
发布2022-06-28 10:14:35
6050
发布2022-06-28 10:14:35
举报
文章被收录于专栏:全栈程序员必看

大家好,又见面了,我是你们的朋友全栈君。

在企业的数据分析中,很少会遇到正负样本数据比例平衡的状况。通常情况是,绝大多数为正样本,而只有极少数(几个或者十几个)负样本。在这种情况下,不论是用LR,SVM或者基于提升方法的随机森林,直接用该数据集进行学习的效果都不会太好,原因是这些方法的学习结果都会偏向于样本较多的一类。另一个方面,对学习结果进行评估时,假如正样本占95%,负样本仅占5%,这样甚至不需要学习,直接把所有新样本预测为正,准确率就可以达到95%,而召回率却很低。因此,在学习一个模型前,处理不平衡的数据是十分必要的。

怎么让不平衡的数据变平衡呢?主要有两个方法,一是欠抽样,顾名思义就是删除正样本(以正样本占绝大多数为例)中的样本,删除的数量根据负样本的数量而定,这种方法的缺点也很明显,会删除正样本所带的信息,当正负样本的比例悬殊时,需要删除较多的正样本数量,这会减少很多正样本携带的信息。因此,最常用的方法是第二种——过抽样。

一种过抽样的方法是随机采样,采用简单随机复制样本来增加负样本的数量。这样容易产生模型的过拟合问题,即使得模型学习到的信息过于特别而不够泛化。另一种过抽样的方法就是要介绍的SMOTE算法,其基本思想是对少数类样本进行分析并根据少数类样本人工合成新样本添加到数据集中,算法流程如下。

(1)对于少数类中每一个样本x,以欧氏距离为标准计算它到少数类样本集中所有样本的距离,得到其k近邻。 (2)根据样本不平衡比例设置一个采样比例以确定采样倍率N,对于每一个少数类样本x,从其k近邻中随机选择若干个样本,假设选择的近邻为o。

(3)对于每一个随机选出的近邻o,分别与原样本按照公式o(new)=o+rand(0,1)*(x-o)构建新的样本。

一个简单的python代码如下

# %load smote.py

import random from sklearn.neighbors import NearestNeighbors import numpy as np

class Smote: def __init__(self,samples,N=1,k=5): self.n_samples,self.n_attrs=samples.shape self.N=N self.k=k self.samples=samples self.newindex=0 # self.synthetic=np.zeros((self.n_samples*N,self.n_attrs))

def over_sampling(self): N=int(self.N) self.synthetic = np.zeros((self.n_samples * N, self.n_attrs)) neighbors=NearestNeighbors(n_neighbors=self.k).fit(self.samples) print(‘neighbors’,neighbors) for i in range(len(self.samples)): nnarray=neighbors.kneighbors(self.samples[i].reshape(1,-1),return_distance=False)[0] #print nnarray self._populate(N,i,nnarray) return self.synthetic # for each minority class samples,choose N of the k nearest neighbors and generate N synthetic samples. def _populate(self,N,i,nnarray): for j in range(N): nn=random.randint(0,self.k-1) dif=self.samples[nnarray[nn]]-self.samples[i] gap=random.random() self.synthetic[self.newindex]=self.samples[i]+gap*dif self.newindex+=1 a=np.array([[1,2,3],[4,5,6],[2,3,1],[2,1,2],[2,3,4],[2,3,4]]) s=Smote(a,N=2) #a为少数数据集,N为倍率,即从k-邻居中取出几个样本点 print(s.over_sampling())

发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/151097.html原文链接:https://javaforall.cn

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档