SMOTE算法处理非平衡数据与结果评估

  • 算法的提出背景:
  • 不平衡数据的常见处理方法
  • SMOTE算法的原理
  • SMOTE算法的python代码实现


算法的提出背景:

在实际应用中,针对 分类问题中类别型的因变量可能存在严重的偏倚,即类别之间的比例严重失调。如:(1)欺诈问题中,欺诈类情况在样本集中毕竟占少数;(2)客户流失问题中,非忠实的客户往往也是占很少一部分;(3)在某营销活动的响应问题中,真正参与活动的客户也同样只是少部分。
在上述问题中,由于数据存在严重的不平衡,导致预测出的结论往往也是有偏的,即:分类结果会偏向于较多观测的类。为了解决数据的非平衡问题,2002年Chawla提出了 SMOTE算法,即合成少数过采样技术,该技术 是目前处理非平衡数据的常用手段

不平衡数据的常见处理方法

不平衡的样本会影响模型的评估效果,严重的会带来过拟合的结果。所以我们需要让正负样本在训练过程中拥有相同话语权或权重。在这里,称数据集中样本较多的一类称为“大众类”(majority class),样本较少的一类称为“小众类”(minority class)。

对于不平衡样本的处理做法总结如下

非平衡数据如何进行面板回归_过拟合


(1)常规处理方法有上采样和下采样两种,具体而言:

  • 上采样(Oversampling,过采样):通过复制多份小众类,使得一些样本在小众类中反复出现,这会导致过拟合改进解法:使用 数据合成 的方法基于已有数据生成更多的样本,其中SMOTE最为常见;或者可以通过 数据加权 的方法来解决问题,但其难点在于如何合理设置权重
  • 下采样(Undersampling,欠采样):通过选取部分大众类,使得部分样本信息丢失而导致欠拟合改进解法:使用 EasyEnsembleBalanceCascade 两种改进方法;

(2)除了上述情形,对于正负样本极不平衡的情况,我们也可以视其为 异常值检测(Outlier Detection)一分类(One Class Learning) 问题。经典的工具包有One-class SVM等。
(3)以上方法着重于处理数据,但同时可以选用适于不平衡样本的模型,比如XGBoost 。

SMOTE算法的原理

基本概念

SMOTE 全称是Synthetic Minority Oversampling Technique,即:合成少数类过采样技术。它是基于随机过采样算法的一种改进方案,由于随机过采样算法采取简单复制样本的策略来增加少数类样本,这样容易产生模型过拟合的问题,即使得模型学习到的信息过于特别(Specific)而不够泛化(General)。

算法流程

SMOTE算法的 基本思想 是对少数类样本进行分析并根据少数类样本人工合成新样本添加到数据集中。SMOTE的 算法流程 如下:

(1)对于少数类中每一个样本 非平衡数据如何进行面板回归_数据集_02,以欧氏距离为标准计算它到少数类样本集非平衡数据如何进行面板回归_过拟合_03中所有样本的距离,得到其 非平衡数据如何进行面板回归_数据_04近邻

(2)根据样本不平衡比例设置一个采样比例以确定采样倍率非平衡数据如何进行面板回归_过拟合_05,对于每一个少数类样本 非平衡数据如何进行面板回归_数据集_02,从其 非平衡数据如何进行面板回归_数据集_07 近邻中随机选择若干个样本,假设选择的近邻为 非平衡数据如何进行面板回归_数据集_08

(3)对于每一个随机选出的近邻 非平衡数据如何进行面板回归_数据集_08,分别与原样本按照如下的公式构建新的样本 。

非平衡数据如何进行面板回归_过拟合_10

非平衡数据如何进行面板回归_非平衡数据如何进行面板回归_11


SMOTE算法的缺陷

(1)在近邻选择时,存在一定的盲目性;

(2)无法克服非平衡数据集的数据分布问题,容易产生分布边缘化问题。

SMOTE算法的python代码实现

Nearest Neighbors的算法原理详见链接

# SMOTE算法及其python实现
import random
from sklearn.neighbors import NearestNeighbors
import numpy as np

class Smote:
    def __init__(self,samples,N=10,k=5):
        self.n_samples, self.n_attrs=samples.shape
        self.N = N   #采样倍率N
        self.k = k   #k近邻
        self.samples = samples
        self.newindex = 0
        # self.synthetic = np.zeros((self.n_samples*N, self.n_attrs))
        
    def over_sampling(self):
        N = int(self.N/100)
        self.synthetic = np.zeros((self.n_samples*N,self.n_attrs))
        neighbors=NearestNeighbors(n_neighbors=self.k).fit(self.samples)
        print('neighbors',neighbors)
        for i in range(len(self.samples)):
            print('samples',self.samples[i])
            #Finds the K-neighbors of a point.
            # reshape(-1,1) 将self.samples[i]变成只有一列,行数不限定的np.array
            nnarray=neighbors.kneighbors(self.samples[i].reshape((1,-1)),return_distance=False)[0]  
            print('nna',nnarray)
            self._populate(N,i,nnarray)
        return self.synthetic
    
    # for each minority class sample i ,choose N of the k nearest neighbors and generate N synthetic samples.
    def _populate(self,N,i,nnarray):
        for j in range(N):
            print('j',j)
            #random.randint(a,b)用于生成一个指定范围内的整数。其中参数a是下限,参数b是上限,生成的随机数n: a <= n <= b。
            nn=random.randint(0,self.k-1)   #包括end
            dif=self.samples[nnarray[nn]]-self.samples[i]
            gap=random.random()  #random.random()方法返回一个随机数,其在0至1的范围之内
            self.synthetic[self.newindex]=self.samples[i]+gap*dif
            self.newindex+=1
            print(self.newindex)
    
a=np.array([[1,2,3],[4,5,6],[2,3,1],[2,1,2],[2,3,4],[2,3,4]])
s=Smote(a,N=1000)
s.over_sampling()
# 输出结果
>>> 
neighbors NearestNeighbors(algorithm='auto', leaf_size=30, metric='minkowski',
         metric_params=None, n_jobs=None, n_neighbors=5, p=2, radius=1.0)
samples [1 2 3]
nna [0 4 5 3 2]
j 0
1
j 1
2
j 2
3
j 3
4
j 4
5
j 5
6
j 6
7
j 7
8
j 8
9
j 9
10
samples [4 5 6]
nna [1 4 5 0 2]
j 0
11
j 1
12
j 2
13
j 3
14
j 4
15
j 5
16
j 6
17
j 7
18
j 8
19
j 9
20
samples [2 3 1]
nna [2 3 0 4 5]
j 0
21
j 1
22
j 2
23
j 3
24
j 4
25
j 5
26
j 6
27
j 7
28
j 8
29
j 9
30
samples [2 1 2]
nna [3 0 2 4 5]
j 0
31
j 1
32
j 2
33
j 3
34
j 4
35
j 5
36
j 6
37
j 7
38
j 8
39
j 9
40
samples [2 3 4]
nna [4 5 0 3 2]
j 0
41
j 1
42
j 2
43
j 3
44
j 4
45
j 5
46
j 6
47
j 7
48
j 8
49
j 9
50
samples [2 3 4]
nna [4 5 0 3 2]
j 0
51
j 1
52
j 2
53
j 3
54
j 4
55
j 5
56
j 6
57
j 7
58
j 8
59
j 9
60
array([[1.43114574, 2.43114574, 2.13770853],
       [1.505008  , 2.505008  , 1.989984  ],
       [1.82933855, 2.82933855, 3.82933855],
       [1.33369947, 2.33369947, 3.33369947],
       [1.36238991, 1.63761009, 2.63761009],
       [1.59592334, 2.59592334, 3.59592334],
       [1.52601224, 2.52601224, 1.94797552],
       [1.72191489, 2.72191489, 1.55617022],
       [1.90931557, 2.90931557, 1.18136886],
       [1.30839364, 2.30839364, 2.38321271],
       [4.        , 5.        , 6.        ],
       [3.65166415, 4.65166415, 5.65166415],
       [2.67063446, 3.67063446, 4.67063446],
       [2.00398821, 3.00398821, 4.00398821],
       [3.23754103, 4.23754103, 5.23754103],
       [4.        , 5.        , 6.        ],
       [1.96349197, 2.96349197, 3.96349197],
       [3.47063379, 4.47063379, 5.47063379],
       [2.79552331, 3.79552331, 4.79552331],
       [3.11046624, 4.11046624, 5.11046624],
       [2.        , 3.        , 3.15344032],
       [2.        , 3.        , 1.        ],
       [2.        , 1.53383837, 1.73308082],
       [1.87393359, 2.87393359, 1.25213282],
       [2.        , 2.51536734, 1.24231633],
       [2.        , 3.        , 1.        ],
       [2.        , 3.        , 1.        ],
       [2.        , 3.        , 1.        ],
       [2.        , 2.7956351 , 1.10218245],
       [2.        , 3.        , 2.76382984],
       [2.        , 2.29044423, 3.29044423],
       [2.        , 1.80864649, 1.59567676],
       [2.        , 2.02505744, 3.02505744],
       [1.99251032, 1.00748968, 2.00748968],
       [1.10108627, 1.89891373, 2.89891373],
       [2.        , 1.3940886 , 2.3940886 ],
       [2.        , 2.82132954, 3.82132954],
       [2.        , 2.07818169, 3.07818169],
       [2.        , 2.05311961, 1.47344019],
       [2.        , 1.01188768, 2.01188768],
       [1.74343118, 2.74343118, 3.74343118],
       [2.        , 3.        , 1.46135479],
       [2.        , 3.        , 3.26300097],
       [2.        , 2.50086116, 3.50086116],
       [2.        , 3.        , 4.        ],
       [2.        , 3.        , 4.        ],
       [2.        , 3.        , 4.        ],
       [2.        , 3.        , 4.        ],
       [2.        , 3.        , 4.        ],
       [1.14890347, 2.14890347, 3.14890347],
       [1.37136719, 2.37136719, 3.37136719],
       [2.        , 3.        , 3.21722991],
       [2.        , 3.        , 4.        ],
       [1.33912286, 2.33912286, 3.33912286],
       [1.47793068, 2.47793068, 3.47793068],
       [2.        , 3.        , 4.        ],
       [2.        , 2.38943522, 3.38943522],
       [2.        , 3.        , 1.65585476],
       [2.        , 1.16775435, 2.16775435],
       [1.86018336, 2.86018336, 3.86018336]])