逻辑回归不是回归
由线性回归所知,回归主要处理连续型变量。而逻辑回归处理类别型变量,因此用于分类问题,就是用回归的办法来做分类。
举例
我们可以利用一个人饮食、睡眠等因素,预测这个人肿瘤的大小(回归问题),然后利用肿瘤大小阈值判定恶性还是良性(分类问题)
我们先利用简单线性回归的思路,利用肿瘤大小,预测这个肿瘤是良性还是恶性。
由上图所示,x轴为肿瘤大小,0代表良性肿瘤,1代表恶性肿瘤。
红×代表一个个例,粉线是回归出的函数图像,绿线是区分良性恶性的阈值
取一个概率为0.5的肿瘤大小作为阈值,假如值为y,肿瘤大小超过y为恶性肿瘤(恶性肿瘤概率大于0.5),小于y为良性,这就实现了一个简单二分类。但要是有离群值(异常值)的出现呢?
这时候就会出现错误了,因为我们拟合的回归函数太直了,要利用非线性的函数,才能更好地拟合这些数据。数学公式
x代表数据, Θ代表参数,就是多元线性回归的估计方程,只是参数换了个字符
为了描述方便,我们将上述公式用下面这个表示
上面这些公式是描述线性回归的,是直的,那我们要整成非线性怎么弄呢?利用sigmod函数(还有其他函数),公式如下:
因此,我们的公式变成了非线性的函数:
函数输出的结果是个例预测为1的概率,预测为0的概率只需用1减去hx即可
y=1的概率:
y=0的概率:
两个式子合起来:yi为个例的ground truth 标签(此例标签取值只有0和1,1时1-yi为0,0时yi=0)
例如预测个例1,预测其为1的概率为0.7,因此0的概率为0.3,其gt标签为1,因此P=0.7,就是说我们有0.7的概率预测正确
同样,预测个例2,预测其为1的概率为0.7,因此0的概率为0.3,其gt标签为0,因此P=0.3,就是说我们有0.3的概率预测正确
损失函数
现在问题来了,我们到现在只说了逻辑回归是咋工作的,还没说怎么才能拟合出一个好的函数,即不要忘了我们要得到一个合适的Θ
回顾一下,我们要得到一个好的线性回归函数,要保证的是找到合适的Θ使误差平方和最小
同样,逻辑回归是找到合适的Θ,使下面这个公式最小(交叉熵损失函数)
这都是些啥啊?我们一步一步推出来推导过程
首先,上面我们已经知道了个例预测正确的概率怎么算
对于m个个例的样本总体,都预测正确的概率:
此时我要保证这个概率越大越好,越大说明我们总体预测的越准确。
因此我们的目的是计算出一组Θ,使总体预测正确的概率最大。第二,由于连乘不好计算,因此转换为log,用加法表示
得到的这个函数越大,证明我们得到的Θ就越好.因为在函数最优化的时候习惯让一个函数越小越好,所以我们在前边加一个负号.得到公式如下:
推导结束~梯度下降
那么这Θ怎么求呢?我们选择使用梯度下降的算法
α为学习率
简单说,梯度下降就是求偏导
交叉熵损失函数的梯度和最小二乘的梯度形式上完全相同,仅仅是交叉熵的hx经过了sigmod函数处理
最后Θ更新法则如下:
Python实现
import numpy as np
import random
# 创建数据
def genData(numPoints,bias,variance):
x = np.zeros(shape=(numPoints,2))
y = np.zeros(shape=(numPoints))
for i in range(0,numPoints):
x[i][0]=1
x[i][1]=i
y[i]=(i+bias)+random.uniform(0,1)+variance
return x,y
# 梯度下降
def gradientDescent(x,y,theta,alpha,m,numIterations):
xTran = np.transpose(x)
for i in range(numIterations):
hypothesis = np.dot(x,theta)
loss = hypothesis-y
cost = np.sum(loss**2)/(2*m)
gradient=np.dot(xTran,loss)/m
theta = theta-alpha*gradient
print ("Iteration %d | cost :%f" %(i,cost))
return theta
x,y = genData(100, 25, 10)
print("x:", x)
print("y:", y)
m,n = np.shape(x)
n_y = np.shape(y)
print("m:"+str(m)+" n:"+str(n)+" n_y:"+str(n_y))
numIterations = 100000
alpha = 0.0005
theta = np.ones(n)
theta= gradientDescent(x, y, theta, alpha, m, numIterations)
print(theta)