1 引言
回归分析技术旨在推断因变量和一个或多个自变量之间的关系,通常应用于监督学习领域。回归分析的方法有很多,可以应用在不同的情景之中,本文主要就其中的线性回归方法展开介绍.
本文的主要目的是尽可能清晰直观的介绍线性回归三种实现方式的原理和快速实现,所有模型的开发实现仅涉及Python和Numpy,这里采用生成的数据方便大家更好的理解和可视化。
2 线性相关
线性回归可以理解为一种统计分析过程,它推断因变量和一个或多个自变量之间为线性关系。通常来说,我们采用线性相关系数来描述X和Y的线性关系,公式如下:
上述公式中:
- μ表示随机变量的期望
- σ表示随机变量的标准差
- E表示期望操作
代码实现:
def correlation(X, Y):
mu_X, mu_Y = X.mean(), Y.mean()
cov = ((X - mu_X)*(Y - mu_Y)).mean()
sigma_X, sigma_Y = X.std(), Y.std()
return cov/(sigma_X*sigma_Y)
生成几张测试图,方便我们直观理解线性相关系数,图像如下:
3 简单线性回归
3.1 定义
简单线性回归对二维空间样本点进行处理,其中一个代表因变量Y,另一个表示自变量X,公式如下:
上式中,m表示斜率,b表示截距。可以用以下公式计算:
这里 和
3.2 实现
代码实现如下:
class linearRegression_simple(object):
def __init__(self):
# Init clas
self._m = 0
self._b = 0
def fit(self, X, y):
# Train model
X_, y_ = X.mean(), y.mean()
num = ((X - X_)*(y - y_)).sum()
den = ((X - X_)**2).sum()
self._m = num/den
self._b = y_ - self._m*X_
def pred(self, x):
# Predict
x = np.array(x)
return self._m*x + self._b
结果如下:
3.3 误差评估
这里一般采用均方差(MSE)作为我们的评价标准,其计算公式如下:
进而我们可以得到上述例子的回归误差如下:
4 多元线性回归
4.1 定义
多元线性回归对n维样本点进行处理,其中y代表因变量,另一个或多个代表自变量 , 公式如下:
4.2 实现
代码实现如下:
class linearRegression_multiple(object):
def __init__(self):
self._m = 0
self._b = 0
def fit(self, X, y):
X_, y_ = X.mean(axis=0), y.mean(axis=0)
num = ((X - X_)*(y - y_)).sum(axis=0)
den = ((X - X_)**2).sum(axis = 0)
self._m = num/den
self._b = y_ - (self._m*X_).sum()
def pred(self, x):
return (self._m*x).sum(axis=1) + self._b
假设我们数据为3维空间的超平面数据,包含一个因变量和两个自变量,采用上述算法,结果如下:
4.3 误差评估
这里依旧采用MSE作为评价标准,可视化后的误差结果如下:
5 梯度下降法
5.1 定义
在线性回归领域采用梯度下降法的目的在于最小化斜率m和截距b的误差,公式如下:
通常我们通过计算上述函数的梯度,如下:
上述公式中,相关梯度的详细计算如下:
5.2 实现
我们用python实现上述过程,代码如下:
class linearRegression_GD(object):
def __init__(self, mo=0, bo=0, rate=0.001):
self._m = mo # initial value for m
self._b = bo # initial value for b
self.rate = rate # iteration's rate
def fit_step(self, X, y):
n = X.size
dm = (2/n)*np.sum(-x*(y - (self._m*x + self._b)))
db = (2/n)*np.sum(-(y - (self._m*x + self._b)))
self._m -= dm*self.rate
self._b -= db*self.rate
def pred(self, x):
x = np.array(x)
return self._m*x + self._b
5.3 实现结果
这里我们采用学习率lr=0.01,共计迭代3072次,实现结果如下:
lrgd = linearRegression_GD(rate=0.01)
# Synthetic data 3
x, x_, y = synthData3()
iterations = 3072
for i in range(iterations):
lrgd.fit_step(x, y)
y_ = lrgd.pred(x)
可以看到,随着迭代次数的增加,误差逐渐减小,最后趋于稳定.
6 画图的重要性
1973年,统计学家F.J. Anscombe构造出了四组奇特的数据。它告诉人们,在分析数据之前,描绘数据所对应的图像有多么的重要。
我们使用简单线性回归,对上述四组数据进行处理,结果如下:
对应的误差分析如下:
[分析]
1)通过上图我们可以得到以下结论:
- 这四组数据中,x值的平均数都是9.0,y值的平均数都是7.5;
- x值的方差都是10.0,y值的方差都是3.75;
- 它们的相关系数都是0.82,线性回归方程都是y=3+0.5x。
单从上述这些统计数字上看来,四组数据所反映出的实际情况非常相近,而事实上,这四组数据有着天壤之别。
2)我们画出四组数据的图像,仔细观察后,可以得到
- 第一组数据的图像是大多人看到上述统计数字的第一反应,是最“正常”的一组数据;
- 第二组数据所反映的事实上是一个精确的二次函数关系,只是在错误地应用了线性模型后,各项统计数字与第一组数据恰好都相同;
- 第三组数据描述的是一个精确的线性关系,只是这里面有一个异常值,它导致了上述各个统计数字,尤其是相关系数的偏差;
- 第四组数据则是一个更极端的例子,其异常值导致了平均数、方差、相关系数、线性回归线等所有统计数字全部发生偏差。
7 总结
本文重点介绍了线性回归的三种实现方式,给出了具体示例和相应的处理代码,同时给出了一组数据用来说明画图对数据分析的重要性。