文章目录

  • 一. 线性回归概述
  • 二. 线性回归的数学
  • 2.1 线性回归数学表达式
  • 2.2 误差
  • 2.2.1 误差的公式
  • 2.2.2 似然函数
  • 2.2.3 评估方法
  • 三. 求解最小二乘法
  • 3.1 矩阵式求解
  • 3.2 梯度下降法(GD)
  • 3.2.1 为什么要使用梯度下降
  • 3.2.2 梯度概念
  • 3.2.3 梯度下降法实验
  • 3.2.4 参数更新
  • 参考:


一. 线性回归概述

有监督分两类:

  1. 回归
    得到一个预测值,银行能借给你多少钱
  2. 分类
    得到一个类别值,银行是否借钱给你

一个例子:
数据: 工资和年龄(2个特征)

目标: 预测银行会贷款多少钱(标签)

考虑: 工资和年龄都会影响最终银行贷款的结果,那么它们各自有多啊的影响呢? (参数)

ggplot2 线性回归添加置信区间 线性回归的置信区间_线性回归

找到最好的拟合曲线,需要找到每个变量的权重值

ggplot2 线性回归添加置信区间 线性回归的置信区间_python_02

二. 线性回归的数学

2.1 线性回归数学表达式

θ0是偏置量,因为不是所有点都是从原点开始的,所以有偏置量。

θ1和θ2 可以理解为权值,因为每个变量对结果影响的程度不一样。

ggplot2 线性回归添加置信区间 线性回归的置信区间_机器学习_03

2.2 误差

此时我们已经有了拟合平面,但是很多点到拟合平面之间还是存在一定的距离,这个距离我们可以理解为我们回归模型的误差。

ggplot2 线性回归添加置信区间 线性回归的置信区间_机器学习_04

2.2.1 误差的公式

第一个公式通过回归模型得到的预测值与误差。

第二个公式是 高斯分布的公式。

ggplot2 线性回归添加置信区间 线性回归的置信区间_python_05

2.2.2 似然函数

似然函数,通过样本估计参数值(权重值)

ggplot2 线性回归添加置信区间 线性回归的置信区间_ggplot2 线性回归添加置信区间_06

因为这个地方求的是误差之和,所以肯定是越小越好,前面log是正数,所以减号后面的值越大越好,这样两者相减才能得到一个比较小的值。

ggplot2 线性回归添加置信区间 线性回归的置信区间_机器学习_07

现在就将目标转换为求解最小二乘法。

2.2.3 评估方法

ggplot2 线性回归添加置信区间 线性回归的置信区间_线性回归_08

三. 求解最小二乘法

从上面的推导可以得出结论:要求让似然函数越大越好,可转化为求θ取某个值时使J(θ)最小的问题。

ggplot2 线性回归添加置信区间 线性回归的置信区间_线性回归_09

求解最小二乘法的方法一般为两种:矩阵式、梯度下降法。

3.1 矩阵式求解

数据集含有m个样本,每个样本有n个特征时:

数据x可以写成m*(n+1)维的矩阵(+1是添加一列1,用于与截断b相乘);

θ则为n+1维的列向量(+1是截断b);

y为m维的列向量代表每m个样本结果的预测值。

  矩阵式的推导如下所示:

ggplot2 线性回归添加置信区间 线性回归的置信区间_数据分析_10

对于矩阵来说, ggplot2 线性回归添加置信区间 线性回归的置信区间_机器学习_11

让J(θ)对θ求偏导,当偏导等于零时,则这个θ就是极值点。ggplot2 线性回归添加置信区间 线性回归的置信区间_机器学习_12代表X矩阵的转置,ggplot2 线性回归添加置信区间 线性回归的置信区间_机器学习_12与X的乘积一定会得到一个对称阵。

另外存在公式: ggplot2 线性回归添加置信区间 线性回归的置信区间_机器学习_14 等于 ggplot2 线性回归添加置信区间 线性回归的置信区间_线性回归_15

ggplot2 线性回归添加置信区间 线性回归的置信区间_ggplot2 线性回归添加置信区间_16的逆矩阵为:ggplot2 线性回归添加置信区间 线性回归的置信区间_机器学习_17

0 = θ - ggplot2 线性回归添加置信区间 线性回归的置信区间_机器学习_17 * ggplot2 线性回归添加置信区间 线性回归的置信区间_ggplot2 线性回归添加置信区间_19,转换等式得到:θ=ggplot2 线性回归添加置信区间 线性回归的置信区间_机器学习_20

这种方法存在的问题:不存在学习的过程;矩阵求逆不是一个必然成功的行为(存在不可逆);

ggplot2 线性回归添加置信区间 线性回归的置信区间_机器学习_21

3.2 梯度下降法(GD)

目标函数:

ggplot2 线性回归添加置信区间 线性回归的置信区间_线性回归_22


对于多元线性回归来说,拟合函数为:

ggplot2 线性回归添加置信区间 线性回归的置信区间_机器学习_23


由于目标函数是由m个样本累加得到的,因此可以求一个平均得到损失函数:

ggplot2 线性回归添加置信区间 线性回归的置信区间_python_24

1. 对损失函数求偏导数,批量梯度下降:

ggplot2 线性回归添加置信区间 线性回归的置信区间_python_25

容易得到最优解,但是每次考虑所有样本,执行速度很慢。

2. 每次只用一个样本,随机梯度下降:

ggplot2 线性回归添加置信区间 线性回归的置信区间_线性回归_26

去除累加操作,每次抽样一个样本来计算,速度快,结果不准。

3. 每次更新选择一部分数据,小批量梯度下降法:

ggplot2 线性回归添加置信区间 线性回归的置信区间_数据分析_27

3.2.1 为什么要使用梯度下降

当得到一个目标函数时,通常是不能直接求解的,线性回归能求出结果在机器学习中是一个特例。

机器学习常规套路:交给机器一堆数据,然后告诉它使用什么样的学习方式(目标函数),然后它朝着这个方向去学习。

算法优化:一步步完成迭代,每次优化一点点,积累起来就能获得大成功。

3.2.2 梯度概念

在一元函数中叫做求导,在多元函数中就叫做求梯度。梯度下降是一个最优化算法,通俗的来讲也就是沿着梯度下降的方向来求出一个函数的极小值。比如一元函数中,加速度减少的方向,总会找到一个点使速度达到最小。

通常情况下,数据不可能完全符合我们的要求,所以很难用矩阵去求解,所以机器学习就应该用学习的方法,因此我们采用梯度下降,不断迭代,沿着梯度下降的方向来移动,求出极小值。

梯度下降法包括批量梯度下降法和随机梯度下降法(SGD)以及二者的结合mini批量下降法(通常与SGD认为是同一种,常用于深度学习中)。

3.2.3 梯度下降法实验

对于梯度下降,我们可以形象地理解为一个人下山的过程。假设现在有一个人在山上,现在他想要走下山,但是他不知道山底在哪个方向,怎么办呢?显然我们可以想到的是,一定要沿着山高度下降的地方走,不然就不是下山而是上山了。山高度下降的方向有很多,选哪个方向呢?这个人比较有冒险精神,他选择最陡峭的方向,即山高度下降最快的方向。现在确定了方向,就要开始下山了。

又有一个问题来了,在下山的过程中,最开始选定的方向并不总是高度下降最快的地方。这个人比较聪明,他每次都选定一段距离,每走一段距离之后,就重新确定当前所在位置的高度下降最快的地方。这样,这个人每次下山的方向都可以近似看作是每个距离段内高度下降最快的地方。

现在我们将这个思想引入线性回归,在线性回归中,我们要找到参数矩阵 θ 使得损失函数 J(θ) 最小。如果把损失函数J(θ)看作是这座山,山底不就是损失函数最小的地方吗,那我们求解参数矩阵 θ 的过程,就是人走到山底的过程。

ggplot2 线性回归添加置信区间 线性回归的置信区间_线性回归_28

如图所示,这是一元线性回归(即假设函数)

ggplot2 线性回归添加置信区间 线性回归的置信区间_数据分析_29


的损失函数图像,一开始我们选定一个起始点(通常是

ggplot2 线性回归添加置信区间 线性回归的置信区间_线性回归_30


),然后沿着这个起始点开始,沿着这一点处损失函数下降最快的方向(即该点的梯度负方向)走一小步,走完一步之后,到达第二个点,然后我们又沿着第二个点的梯度负方向走一小步,到达第三个点,以此类推,直到我们到底局部最低点。为什么是局部最低点呢?因为我们到达的这个点的梯度为 0 向量(通常是和 0 向量相差在某一个可接受的范围内),这说明这个点是损失函数的极小值点,并不一定是最小值点。

ggplot2 线性回归添加置信区间 线性回归的置信区间_机器学习_31

从梯度下降法的思想,我们可以看到,最后得到的局部最低点与我们选定的起始点有关。通常情况下,如果起始点不同,最后得到的局部最低点也会不一样。

3.2.4 参数更新

每次更新参数的操作:

ggplot2 线性回归添加置信区间 线性回归的置信区间_数据分析_32

其中α为学习率(步长),对结果会产生巨大的影响,调节学习率这个超参数是建模中的重要内容。

选择方法:从小的开始,不行再小(一般是0.01,然后再调整)。

批处理数量:32、64、128比较常用,很多时候还要考虑内存和效率。

ggplot2 线性回归添加置信区间 线性回归的置信区间_线性回归_33

因为J(θ)是凸函数,所以GD求出的最优解是全局最优解。批量梯度下降法是求出整个数据集的梯度,再去更新θ,所以每次迭代都是在求全局最优解。

参考:

  1. https://study.163.com/course/introduction.htm?courseId=1003590004#/courseDetail?tab=1