Multi-class Logistic Regression

  • 1. softmax函数
  • 2. 与二元Logistic回归的关系
  • 3. 误差函数
  • 3.1 多元回归的1-of-K表示(one-hot)
  • 3.2 训练样本集的似然函数
  • 3.3 交叉熵误差函数
  • 4. 最大似然估计
  • 代码实现(mnist数据集)

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归_02摘记 一文中对二元二元和多元logistic回归 二元 多元logistic回归_回归_03回归进行了详细的介绍,本文主要描述采用 二元和多元logistic回归 二元 多元logistic回归_机器学习_04 函数实现多元二元和多元logistic回归 二元 多元logistic回归_回归_03回归:这实际上是用一个(不含隐藏层的)单层神经网络来实现多元分类,其输出函数采用的是 二元和多元logistic回归 二元 多元logistic回归_机器学习_04 函数。
二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归

1. softmax函数

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归对于某个输入 二元和多元logistic回归 二元 多元logistic回归_算法_09,其对应的 二元和多元logistic回归 二元 多元logistic回归_机器学习_04 输出为向量值 二元和多元logistic回归 二元 多元logistic回归_回归_11,且满足 二元和多元logistic回归 二元 多元logistic回归_机器学习_12

二元和多元logistic回归 二元 多元logistic回归_回归_13 分类问题 二元和多元logistic回归 二元 多元logistic回归_样本集_14 中使用 二元和多元logistic回归 二元 多元logistic回归_机器学习_04 函数 二元和多元logistic回归 二元 多元logistic回归_回归_16 表示输出值分量

二元和多元logistic回归 二元 多元logistic回归_样本集_17 二元和多元logistic回归 二元 多元logistic回归_回归_18

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归

二元和多元logistic回归 二元 多元logistic回归_算法_20

此处,其实是采用激活函数为softmax的感知器模型:二元和多元logistic回归 二元 多元logistic回归_样本集_21

二元和多元logistic回归 二元 多元logistic回归_回归_22 对于多元Logistic回归: 二元和多元logistic回归 二元 多元logistic回归_算法_23

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归  其中,二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归_25

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归
二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归  若记 二元和多元logistic回归 二元 多元logistic回归_算法_28二元和多元logistic回归 二元 多元logistic回归_算法_29,则 二元和多元logistic回归 二元 多元logistic回归_机器学习_30

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归
二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归  【为了方便描述】可以略掉 ‘二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归_33’ 号,直接写成: 二元和多元logistic回归 二元 多元logistic回归_回归_34

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归
二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归_36输出值分量 二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归_37 描述成后验概率的形式:

二元和多元logistic回归 二元 多元logistic回归_样本集_17 二元和多元logistic回归 二元 多元logistic回归_算法_39
二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归

2. 与二元Logistic回归的关系

对比二元Logistic回归

  • 二元和多元logistic回归 二元 多元logistic回归_回归_41正例的概率:二元和多元logistic回归 二元 多元logistic回归_回归_42
  • 二元和多元logistic回归 二元 多元logistic回归_回归_41负例的概率:二元和多元logistic回归 二元 多元logistic回归_样本集_44

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归_46 时,二元和多元logistic回归 二元 多元logistic回归_机器学习_04 函数实际上等同于二元 二元和多元logistic回归 二元 多元logistic回归_机器学习_48 回归(假设 二元和多元logistic回归 二元 多元logistic回归_算法_49):

二元和多元logistic回归 二元 多元logistic回归_样本集_17 二元和多元logistic回归 二元 多元logistic回归_样本集_51

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归二元和多元logistic回归 二元 多元logistic回归_机器学习_53,那么类后验概率就是二元 二元和多元logistic回归 二元 多元logistic回归_机器学习_48 回归中情形。
二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归

3. 误差函数

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归针对多元Logistic回归,首先要写出其误差函数。

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归假设训练样本集为 二元和多元logistic回归 二元 多元logistic回归_回归_58,其中 二元和多元logistic回归 二元 多元logistic回归_样本集_59,参数为 二元和多元logistic回归 二元 多元logistic回归_机器学习_60

二元Logistic回归 假设训练样本为 二元和多元logistic回归 二元 多元logistic回归_回归_61,其中 二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归_62,似然函数为:
 
二元和多元logistic回归 二元 多元logistic回归_算法_63
 
取“负的对数似然函数”作为误差函数,即:二元和多元logistic回归 二元 多元logistic回归_算法_64

3.1 多元回归的1-of-K表示(one-hot)

二元和多元logistic回归 二元 多元logistic回归_回归_13变量 二元和多元logistic回归 二元 多元logistic回归_回归_66 表示输入 二元和多元logistic回归 二元 多元logistic回归_算法_09

二元和多元logistic回归 二元 多元logistic回归_回归_22 引入目标向量 二元和多元logistic回归 二元 多元logistic回归_回归_69,满足 二元和多元logistic回归 二元 多元logistic回归_回归_70

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归  表示“输入 二元和多元logistic回归 二元 多元logistic回归_算法_09 属于第 二元和多元logistic回归 二元 多元logistic回归_机器学习_73 类” 或者说变量 二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归_74

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归_36向量值 二元和多元logistic回归 二元 多元logistic回归_回归_11 表示输入 二元和多元logistic回归 二元 多元logistic回归_算法_09 所对应的二元和多元logistic回归 二元 多元logistic回归_回归_78输出

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归  二元和多元logistic回归 二元 多元logistic回归_回归_80

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归  显然,二元和多元logistic回归 二元 多元logistic回归_机器学习_82
二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归

3.2 训练样本集的似然函数

二元和多元logistic回归 二元 多元logistic回归_回归_13 对于第 二元和多元logistic回归 二元 多元logistic回归_样本集_85 个训练样本 二元和多元logistic回归 二元 多元logistic回归_样本集_86,其 二元和多元logistic回归 二元 多元logistic回归_机器学习_04 输出为 二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归_88,且

二元和多元logistic回归 二元 多元logistic回归_回归_89

二元和多元logistic回归 二元 多元logistic回归_回归_22 训练样本集 二元和多元logistic回归 二元 多元logistic回归_回归_58 的似然函数 二元和多元logistic回归 二元 多元logistic回归_样本集_92

二元和多元logistic回归 二元 多元logistic回归_算法_93
二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归

3.3 交叉熵误差函数

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归定义训练样本集 二元和多元logistic回归 二元 多元logistic回归_回归_58交叉熵误差函数 二元和多元logistic回归 二元 多元logistic回归_机器学习_97

二元和多元logistic回归 二元 多元logistic回归_样本集_98

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归使用交叉熵作为误差函数,是因为:

二元和多元logistic回归 二元 多元logistic回归_回归_13 若训练样本 二元和多元logistic回归 二元 多元logistic回归_算法_101 的类别 二元和多元logistic回归 二元 多元logistic回归_算法_102,则对应的目标向量 二元和多元logistic回归 二元 多元logistic回归_算法_103 只有第 二元和多元logistic回归 二元 多元logistic回归_机器学习_73 个分量 二元和多元logistic回归 二元 多元logistic回归_机器学习_105,而其他分量 二元和多元logistic回归 二元 多元logistic回归_样本集_106

二元和多元logistic回归 二元 多元logistic回归_回归_22 在训练过程中,二元和多元logistic回归 二元 多元logistic回归_机器学习_108 是训练样本 二元和多元logistic回归 二元 多元logistic回归_算法_101 所对应 二元和多元logistic回归 二元 多元logistic回归_机器学习_04 输出的第 二元和多元logistic回归 二元 多元logistic回归_机器学习_73 个分量(训练样本的正确类别 二元和多元logistic回归 二元 多元logistic回归_机器学习_73

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归_36 如果正确类别 二元和多元logistic回归 二元 多元logistic回归_样本集_114 二元和多元logistic回归 二元 多元logistic回归_机器学习_108 越大,二元和多元logistic回归 二元 多元logistic回归_机器学习_116

二元和多元logistic回归 二元 多元logistic回归_算法_117 理想情况下,正确类别 二元和多元logistic回归 二元 多元logistic回归_样本集_114 二元和多元logistic回归 二元 多元logistic回归_机器学习_119,那么交叉熵为 二元和多元logistic回归 二元 多元logistic回归_算法_120,也就是没有训练误差。

也可以采用均方误差 二元和多元logistic回归 二元 多元logistic回归_回归_121

4. 最大似然估计

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归为了求出参数 二元和多元logistic回归 二元 多元logistic回归_回归_123,同样采用最大似然估计。

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归可以将训练样本集分成 二元和多元logistic回归 二元 多元logistic回归_算法_125 个子集 二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归_126,第 二元和多元logistic回归 二元 多元logistic回归_机器学习_73 个子集 二元和多元logistic回归 二元 多元logistic回归_机器学习_128 中的所有样本 二元和多元logistic回归 二元 多元logistic回归_算法_101 的类别都为 二元和多元logistic回归 二元 多元logistic回归_算法_102,对应的目标向量 二元和多元logistic回归 二元 多元logistic回归_算法_103 都满足 二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归_132,由误差函数的表达式:

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归_133

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归二元和多元logistic回归 二元 多元logistic回归_机器学习_135 求参数 二元和多元logistic回归 二元 多元logistic回归_算法_136 的偏导分为两个部分:
二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归
二元和多元logistic回归 二元 多元logistic回归_回归_13二元和多元logistic回归 二元 多元logistic回归_机器学习_135 的第 二元和多元logistic回归 二元 多元logistic回归_机器学习_73 个分量 二元和多元logistic回归 二元 多元logistic回归_机器学习_141 求参数 二元和多元logistic回归 二元 多元logistic回归_算法_136

二元和多元logistic回归 二元 多元logistic回归_机器学习_143

二元和多元logistic回归 二元 多元logistic回归_回归_22二元和多元logistic回归 二元 多元logistic回归_机器学习_135 的第 二元和多元logistic回归 二元 多元logistic回归_回归_146 个分量 二元和多元logistic回归 二元 多元logistic回归_机器学习_147 求参数 二元和多元logistic回归 二元 多元logistic回归_算法_136

二元和多元logistic回归 二元 多元logistic回归_机器学习_149

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归综合起来,两个公式可以表示为:

二元和多元logistic回归 二元 多元logistic回归_算法_151

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归采用梯度下降法时,权值更新公式为:

二元和多元logistic回归 二元 多元logistic回归_样本集_153

二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归其中 二元和多元logistic回归 二元 多元logistic回归_算法_155 为梯度下降法的步长。
二元和多元logistic回归 二元 多元logistic回归_二元和多元logistic回归

代码实现(mnist数据集)

import numpy as np
from dataset.mnist import load_mnist

def softmax_train(train,target,alpha,num): 
    xhat = np.concatenate((train,np.ones((len(train),1))),axis=1)
    nparam = len(xhat.T) #785
    beta = np.random.rand(nparam,10)    #785x10
    for i in range(num):
        wtx = np.dot(xhat,beta)
        wtx1 = wtx - np.max(wtx,axis=1).reshape(len(train),1)
        e_wtx = np.exp(wtx1)
        yx = e_wtx/np.sum(e_wtx,axis=1).reshape(len(xhat),1)
        print('  #'+str(i+1)+' : '+str(cross_entropy(yx,target)))
        t1 = target - yx
        t2 = np.dot(xhat.T, t1)
        beta = beta + alpha*t2
    
    return beta
    
def cross_entropy(yx,t):  
    sum1 = np.sum(yx*t,axis=1)
    ewx = np.log(sum1+0.000001)
    return -np.sum(ewx)/len(yx)
    
def classification(test, beta, test_t):
    xhat = np.concatenate((test,np.ones((len(test),1))),axis=1)    
    wtx = np.dot(xhat,beta)
    output = np.where(wtx==np.max(wtx,axis=1).reshape((len(test),1)))[1]
    
    print("Percentage Correct: ",np.where(output==test_t)[0].shape[0]/len(test))
    return np.array(output,dtype=np.uint8) 

if __name__ == '__main__': 
    (x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)

    nread = 60000
    train_in = x_train[:nread,:]
    train_tgt = np.zeros((nread,10))    

    test_in = x_test[:10000,:]
    test_t = t_test[:10000]
  
    for i in range(nread):
        train_tgt[i,t_train[i]] = 1
        
    beta = softmax_train(train_in,train_tgt,0.001,60)
    print(beta)
    result = classification(test_in, beta, test_t)

测试结果:
#1 : 5.626381119337011
#2 : 5.415158063701459
#3 : 10.959830171565791
#4 : 8.062787294189338
#5 : 7.4643357380759765
#6 : 9.070059164063883
#7 : 9.81079287953052
#8 : 7.13921201579068
#9 : 7.176904417794094
#10 : 4.607102717465571
#11 : 3.9215536116316625
#12 : 4.199011112147004
#13 : 4.135313269465135
#14 : 3.214738972020379
#15 : 2.804664146283606
#16 : 2.901161881757491
#17 : 2.9996749271603456
#18 : 2.609904566490558
#19 : 2.6169338357951197
#20 : 2.538795429964946
#21 : 2.7159497447897256
#22 : 2.634980803678192
#23 : 2.974848646434367
#24 : 3.1286179795674154
#25 : 3.2208869228881407
#26 : 2.548910343301664
#27 : 2.5298981152704743
#28 : 2.3826001247525035
#29 : 2.4498572463653243
#30 : 2.3521370651353837
#31 : 2.4309032741212664
#32 : 2.366133209606206
#33 : 2.4462922376053364
#34 : 2.3850487760328933
#35 : 2.4481429887352792
#36 : 2.370067560256672
#37 : 2.376729198498193
#38 : 2.297488373847759
#39 : 2.265126273640295
#40 : 2.258495714414137
#41 : 2.327524884607823
#42 : 2.3130200962416128
#43 : 2.290046983208286
#44 : 2.1465196716967805
#45 : 2.0969060851949677
#46 : 1.8901858209971119
#47 : 1.844354795879705
#48 : 1.6340799726564934
#49 : 1.60064459794013
#50 : 1.4667008762515674
#51 : 1.4453938385590863
#52 : 1.3767004735390218
#53 : 1.359619935503484
#54 : 1.3153462460865966
#55 : 1.309895715988472
#56 : 1.2799649790773286
#57 : 1.2807586745656392
#58 : 1.2559139323742572
#59 : 1.2582212637839076
#60 : 1.237819660093416

权值:
[[7.69666472e-01 2.16009202e-01 9.81729719e-01 … 5.32453082e-01
7.88719040e-01 5.14326954e-01]
[3.90401951e-01 5.84040914e-01 7.94883641e-01 … 8.02009249e-01
3.29345264e-02 6.70861290e-01]
[8.69075434e-02 8.43381782e-01 4.77683466e-01 … 8.71965798e-01
4.47018470e-04 5.07498017e-01]

[7.96129468e-01 6.14364951e-01 8.32783158e-01 … 6.53493763e-01
2.06235991e-01 8.60469591e-01]
[1.67070291e-01 3.23211147e-02 2.41519794e-01 … 6.56026583e-01
5.98396521e-01 5.42304452e-01]
[8.43299673e-01 6.22843596e-01 6.05652099e-02 … 1.10339403e-01
1.61855811e-01 3.29385438e-01]]

识别率:
Percentage Correct: 0.9037