Andrew Ng's Machine Learning -- Logistic Regression Study Notes
- Preface
- 1. The Complete Code
- 2. Results
- 1. Fitted plot
- 2. Data analysis results
- 3. Notes from the Learning Process
- Dataset
Preface
This post records the process and key points of studying the logistic regression chapter of Andrew Ng's Machine Learning course. The dataset is attached at the end.
1. The Complete Code
Example
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
import scipy.optimize as opt

def get_X(df):  # extract the features and prepend an x0 column of ones
    ones = pd.DataFrame({'ones': np.ones(df.shape[0])})
    data = pd.concat([ones, df], axis=1)  # axis=1 concatenates column-wise
    return data.iloc[:, :-1].values  # all feature columns, by position; .values yields an ndarray

def get_y(df):  # read the label column
    return np.array(df.iloc[:, -1])  # converts the last column to an ndarray

def normalize_feature(df):  # feature scaling
    return df.apply(lambda column: (column - column.mean()) / column.std())  # z-score standardization

def sigmoid(z):  # the sigmoid function
    return 1 / (1 + np.exp(-z))

def cost(theta, X, y):  # cross-entropy cost function
    first = np.multiply(-y, np.log(sigmoid(X.dot(theta.T))))
    second = np.multiply((1 - y), np.log(1 - sigmoid(X.dot(theta.T))))
    return np.sum(first - second) / len(X)

def gradient(theta, X, y):  # gradient of the cost function
    return X.T @ (sigmoid(X @ theta) - y) / len(X)

def gradientDescent(theta, X, y, alpha, iters):  # fit theta by batch gradient descent
    costs = np.zeros(iters)
    for i in range(iters):  # one simultaneous update of theta per iteration
        theta = theta - alpha * gradient(theta, X, y)
        costs[i] = cost(theta, X, y)
    return theta, costs

def predict(X, theta):  # predict labels: 1 where the hypothesis is at least 0.5
    return (sigmoid(X @ theta.T) >= 0.5).astype(int)  # cast booleans to 0/1

if __name__ == '__main__':
    path = 'data/ex2data1.txt'
    data = pd.read_csv(path, header=None, names=['exam1', 'exam2', 'admitted'])
    X = get_X(data)
    y = get_y(data)
    theta = np.zeros(X.shape[1])  # a length-3 numpy.ndarray, one entry per feature
    # fit the parameters with an off-the-shelf optimizer; you can also try gradient descent
    res = opt.minimize(fun=cost, x0=theta, args=(X, y), jac=gradient, method='Newton-CG')
    print(res)
    theta_res = res.x  # the fitted theta parameters
    y_pred = predict(X, theta_res)
    print(classification_report(y, y_pred))
    # theta_res, costs = gradientDescent(theta, X, y, 0.00001, 500000)
    # plot the raw data first
    positive = data[data['admitted'] == 1]  # select the admitted rows
    negative = data[data['admitted'] == 0]  # select the rejected rows
    fig, ax = plt.subplots(figsize=(10, 5))  # get the plotting objects
    # scatter plot of the two exam scores for admitted applicants
    ax.scatter(positive['exam1'], positive['exam2'], s=30, c='b', marker='o', label='Admitted')
    # scatter plot of the two exam scores for rejected applicants
    ax.scatter(negative['exam1'], negative['exam2'], s=30, c='r', marker='x', label='Not Admitted')
    # add the legend
    ax.legend()
    # label the x and y axes
    ax.set_xlabel('Exam1 Score')
    ax.set_ylabel('Exam2 Score')
    plt.title("fitted curve vs sample")
    # plot the decision boundary
    exam_x = np.arange(X[:, 1].min(), X[:, 1].max(), 0.01)
    theta_res = -theta_res / theta_res[2]  # line coefficients: -theta_0/theta_2 and -theta_1/theta_2
    exam_y = theta_res[0] + theta_res[1] * exam_x
    plt.plot(exam_x, exam_y)
    plt.show()
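If you would rather use the hand-written gradient descent than scipy's Newton-CG, the commented-out call above is the entry point. A minimal sketch of how it might be wired in, replacing the opt.minimize call; the cost-curve plot is my own addition for checking convergence, not part of the original script:

theta0 = np.zeros(X.shape[1])
theta_res, costs = gradientDescent(theta0, X, y, 0.00001, 500000)
# plot the recorded costs: a flat tail suggests the descent has converged
fig2, ax2 = plt.subplots()
ax2.plot(np.arange(len(costs)), costs)
ax2.set_xlabel('iteration')
ax2.set_ylabel('cost')
plt.show()

With raw, un-normalized exam scores the learning rate has to stay tiny, hence the 0.00001 / 500000 pair; running normalize_feature on the inputs first would permit a much larger alpha and far fewer iterations.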
2. Results
1. Fitted plot
The scatter of the two exam scores, with the fitted decision boundary drawn over it by the plotting code above.
2. Data analysis results
The classification_report printout for the training set.
3. Notes from the Learning Process
data = pd.read_csv(path, header=None, names=['exam1', 'exam2', 'admitted'])
positive = data[data['admitted'] == 1]  # select the admitted rows
negative = data[data['admitted'] == 0]  # select the rejected rows
1. In the code above, after reading the data with pandas, indexing the DataFrame with a condition such as admitted == 1 (or == 0) returns every row that satisfies it. This is a convenient way of filtering out the target rows.
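As a quick illustration of the mechanics (the toy DataFrame below is made up for demonstration): the comparison produces a boolean Series, and indexing the frame with that Series keeps exactly the rows marked True.

import pandas as pd

toy = pd.DataFrame({'exam1': [80, 40, 95], 'admitted': [1, 0, 1]})
mask = toy['admitted'] == 1  # boolean Series: [True, False, True]
print(toy[mask])             # only the first and third rows remain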
def get_X(df):  # extract the features and prepend an x_0 column of ones
    ones = pd.DataFrame({'ones': np.ones(df.shape[0])})
    data = pd.concat([ones, df], axis=1)  # axis=1 concatenates column-wise
    return data.iloc[:, :-1].values  # all feature columns, by position, as an array
def get_y(df):  # read the label column
    return np.array(df.iloc[:, -1])  # converts the column to an array
def normalize_feature(df):  # feature scaling
    return df.apply(lambda column: (column - column.mean()) / column.std())  # z-score standardization
2. Note that the data should be converted to arrays before training the model. Often the extracted data already is an array (<class 'numpy.ndarray'>), so this is partly habitual caution, but in many cases pandas hands you a DataFrame, and then you must remember to convert it.
That is why get_X, after concatenating the DataFrames, calls values to turn the result into an ndarray.
In iloc, [:, :-1] means all rows, and the columns from index 0 up to but not including the last one; the slice is half-open.
In get_y, likewise, we take all rows but only the last column.
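A small sketch of the slicing and conversion just described (the toy frame is again hypothetical):

import pandas as pd
import numpy as np

df = pd.DataFrame({'ones': [1.0, 1.0], 'exam1': [80, 40], 'admitted': [1, 0]})
print(type(df.iloc[:, :-1]))         # <class 'pandas.core.frame.DataFrame'>: all rows, all but the last column
print(type(df.iloc[:, :-1].values))  # <class 'numpy.ndarray'> after .values
print(np.array(df.iloc[:, -1]))      # [1 0]: all rows, only the last column, as an ndarray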
3. The cost function:
def cost(theta, X, y):  # cross-entropy cost function
    first = np.multiply(-y, np.log(sigmoid(X.dot(theta.T))))
    second = np.multiply((1 - y), np.log(1 - sigmoid(X.dot(theta.T))))
    return np.sum(first - second) / len(X)
This implements J(θ) = (1/m) Σ_i [ -y_i·log(h_θ(x_i)) - (1 - y_i)·log(1 - h_θ(x_i)) ], where h_θ(x) = sigmoid(θ^T·x).
Regarding X.dot(theta.T) above: if you write the product with the `*` operator and expect matrix multiplication, both operands must be of the np.matrix type; on ordinary ndarrays, `*` (like np.multiply) is element-wise.
dot, on the other hand, performs true matrix multiplication on ndarrays, provided the number of columns of the first operand equals the number of rows of the second.
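A short demonstration of the difference (toy arrays, for illustration only):

import numpy as np

A = np.array([[1.0, 2.0], [3.0, 4.0]])
v = np.array([10.0, 100.0])
print(A * A)        # on ndarrays, '*' is element-wise, the same as np.multiply(A, A)
print(A @ v)        # matrix-vector product: [210. 430.]
print(A.dot(v))     # dot is equivalent to '@' here
M = np.matrix(A)
print(M * M)        # on np.matrix operands, '*' does matrix multiplication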
4. Visualizing the decision boundary
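The two plotting lines at the end of the script follow from a short derivation. A point is predicted positive when sigmoid(θ^T·x) >= 0.5, which holds exactly when θ^T·x >= 0, so the boundary is the line

θ_0 + θ_1·x_1 + θ_2·x_2 = 0

Solving for x_2 gives

x_2 = (-θ_0/θ_2) + (-θ_1/θ_2)·x_1

which is why the code computes theta_res = -theta_res / theta_res[2] and then plots exam_y = theta_res[0] + theta_res[1] * exam_x over the range of exam1 scores.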
Dataset
34.62365962451697,78.0246928153624,0
30.28671076822607,43.89499752400101,0
35.84740876993872,72.90219802708364,0
60.18259938620976,86.30855209546826,1
79.0327360507101,75.3443764369103,1
45.08327747668339,56.3163717815305,0
61.10666453684766,96.51142588489624,1
75.02474556738889,46.55401354116538,1
76.09878670226257,87.42056971926803,1
84.43281996120035,43.53339331072109,1
95.86155507093572,38.22527805795094,0
75.01365838958247,30.60326323428011,0
82.30705337399482,76.48196330235604,1
69.36458875970939,97.71869196188608,1
39.53833914367223,76.03681085115882,0
53.9710521485623,89.20735013750205,1
69.07014406283025,52.74046973016765,1
67.94685547711617,46.67857410673128,0
70.66150955499435,92.92713789364831,1
76.97878372747498,47.57596364975532,1
67.37202754570876,42.83843832029179,0
89.67677575072079,65.79936592745237,1
50.534788289883,48.85581152764205,0
34.21206097786789,44.20952859866288,0
77.9240914545704,68.9723599933059,1
62.27101367004632,69.95445795447587,1
80.1901807509566,44.82162893218353,1
93.114388797442,38.80067033713209,0
61.83020602312595,50.25610789244621,0
38.78580379679423,64.99568095539578,0
61.379289447425,72.80788731317097,1
85.40451939411645,57.05198397627122,1
52.10797973193984,63.12762376881715,0
52.04540476831827,69.43286012045222,1
40.23689373545111,71.16774802184875,0
54.63510555424817,52.21388588061123,0
33.91550010906887,98.86943574220611,0
64.17698887494485,80.90806058670817,1
74.78925295941542,41.57341522824434,0
34.1836400264419,75.2377203360134,0
83.90239366249155,56.30804621605327,1
51.54772026906181,46.85629026349976,0
94.44336776917852,65.56892160559052,1
82.36875375713919,40.61825515970618,0
51.04775177128865,45.82270145776001,0
62.22267576120188,52.06099194836679,0
77.19303492601364,70.45820000180959,1
97.77159928000232,86.7278223300282,1
62.07306379667647,96.76882412413983,1
91.56497449807442,88.69629254546599,1
79.94481794066932,74.16311935043758,1
99.2725269292572,60.99903099844988,1
90.54671411399852,43.39060180650027,1
34.52451385320009,60.39634245837173,0
50.2864961189907,49.80453881323059,0
49.58667721632031,59.80895099453265,0
97.64563396007767,68.86157272420604,1
32.57720016809309,95.59854761387875,0
74.24869136721598,69.82457122657193,1
71.79646205863379,78.45356224515052,1
75.3956114656803,85.75993667331619,1
35.28611281526193,47.02051394723416,0
56.25381749711624,39.26147251058019,0
30.05882244669796,49.59297386723685,0
44.66826172480893,66.45008614558913,0
66.56089447242954,41.09209807936973,0
40.45755098375164,97.53518548909936,1
49.07256321908844,51.88321182073966,0
80.27957401466998,92.11606081344084,1
66.74671856944039,60.99139402740988,1
32.72283304060323,43.30717306430063,0
64.0393204150601,78.03168802018232,1
72.34649422579923,96.22759296761404,1
60.45788573918959,73.09499809758037,1
58.84095621726802,75.85844831279042,1
99.82785779692128,72.36925193383885,1
47.26426910848174,88.47586499559782,1
50.45815980285988,75.80985952982456,1
60.45555629271532,42.50840943572217,0
82.22666157785568,42.71987853716458,0
88.9138964166533,69.80378889835472,1
94.83450672430196,45.69430680250754,1
67.31925746917527,66.58935317747915,1
57.23870631569862,59.51428198012956,1
80.36675600171273,90.96014789746954,1
68.46852178591112,85.59430710452014,1
42.0754545384731,78.84478600148043,0
75.47770200533905,90.42453899753964,1
78.63542434898018,96.64742716885644,1
52.34800398794107,60.76950525602592,0
94.09433112516793,77.15910509073893,1
90.44855097096364,87.50879176484702,1
55.48216114069585,35.57070347228866,0
74.49269241843041,84.84513684930135,1
89.84580670720979,45.35828361091658,1
83.48916274498238,48.38028579728175,1
42.2617008099817,87.10385094025457,1
99.31500880510394,68.77540947206617,1
55.34001756003703,64.9319380069486,1
74.77589300092767,89.52981289513276,1