数据集:​​蓝凑云​



文章目录



Homework 1: Linear Regression

Load 'train.csv'

import sys
import pandas as pd
import numpy as np
# Read the training table; the CSV is Big5-encoded.
# Alternative location when the file sits on a mounted Google Drive:
# data = pd.read_csv('gdrive/My Drive/hw1-regression/train.csv', header = None, encoding = 'big5')
data = pd.read_csv('./train.csv', encoding='big5')
# Bare expression: shows the frame when run as a notebook cell.
data



日期

測站

測項

0

1

2

3

4

5

6

...

14

15

16

17

18

19

20

21

22

23

0

2014/1/1

豐原

AMB_TEMP

14

14

14

13

12

12

12

...

22

22

21

19

17

16

15

15

15

15

1

2014/1/1

豐原

CH4

1.8

1.8

1.8

1.8

1.8

1.8

1.8

...

1.8

1.8

1.8

1.8

1.8

1.8

1.8

1.8

1.8

1.8

2

2014/1/1

豐原

CO

0.51

0.41

0.39

0.37

0.35

0.3

0.37

...

0.37

0.37

0.47

0.69

0.56

0.45

0.38

0.35

0.36

0.32

3

2014/1/1

豐原

NMHC

0.2

0.15

0.13

0.12

0.11

0.06

0.1

...

0.1

0.13

0.14

0.23

0.18

0.12

0.1

0.09

0.1

0.08

4

2014/1/1

豐原

NO

0.9

0.6

0.5

1.7

1.8

1.5

1.9

...

2.5

2.2

2.5

2.3

2.1

1.9

1.5

1.6

1.8

1.5

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

4315

2014/12/20

豐原

THC

1.8

1.8

1.8

1.8

1.8

1.7

1.7

...

1.8

1.8

2

2.1

2

1.9

1.9

1.9

2

2

4316

2014/12/20

豐原

WD_HR

46

13

61

44

55

68

66

...

59

308

327

21

100

109

108

114

108

109

4317

2014/12/20

豐原

WIND_DIREC

36

55

72

327

74

52

59

...

18

311

52

54

121

97

107

118

100

105

4318

2014/12/20

豐原

WIND_SPEED

1.9

2.4

1.9

2.8

2.3

1.9

2.1

...

2.3

2.6

1.3

1

1.5

1

1.7

1.5

2

2

4319

2014/12/20

豐原

WS_HR

0.7

0.8

1.8

1

1.9

1.7

2.1

...

1.3

1.7

0.7

0.4

1.1

1.4

1.3

1.6

1.8

2

4320 rows × 27 columns

# Keep every row but drop the first three columns (date, station, measurement
# name), leaving the 24 hourly readings. `.copy()` makes `data` an independent
# frame, so the in-place 'NR' replacement performed next does not raise
# SettingWithCopyWarning.
data = data.iloc[:, 3:].copy()



0

1

2

3

4

5

6

7

8

9

...

14

15

16

17

18

19

20

21

22

23

0

14

14

14

13

12

12

12

12

15

17

...

22

22

21

19

17

16

15

15

15

15

1

1.8

1.8

1.8

1.8

1.8

1.8

1.8

1.8

1.8

1.8

...

1.8

1.8

1.8

1.8

1.8

1.8

1.8

1.8

1.8

1.8

2

0.51

0.41

0.39

0.37

0.35

0.3

0.37

0.47

0.78

0.74

...

0.37

0.37

0.47

0.69

0.56

0.45

0.38

0.35

0.36

0.32

3

0.2

0.15

0.13

0.12

0.11

0.06

0.1

0.13

0.26

0.23

...

0.1

0.13

0.14

0.23

0.18

0.12

0.1

0.09

0.1

0.08

4

0.9

0.6

0.5

1.7

1.8

1.5

1.9

2.2

6.6

7.9

...

2.5

2.2

2.5

2.3

2.1

1.9

1.5

1.6

1.8

1.5

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

4315

1.8

1.8

1.8

1.8

1.8

1.7

1.7

1.8

1.8

1.8

...

1.8

1.8

2

2.1

2

1.9

1.9

1.9

2

2

4316

46

13

61

44

55

68

66

70

66

85

...

59

308

327

21

100

109

108

114

108

109

4317

36

55

72

327

74

52

59

83

106

105

...

18

311

52

54

121

97

107

118

100

105

4318

1.9

2.4

1.9

2.8

2.3

1.9

2.1

3.7

2.8

3.8

...

2.3

2.6

1.3

1

1.5

1

1.7

1.5

2

2

4319

0.7

0.8

1.8

1

1.9

1.7

2.1

2

2

1.7

...

1.3

1.7

0.7

0.4

1.1

1.4

1.3

1.6

1.8

2

4320 rows × 24 columns

Preprocessing

取需要的數值部分,將 'RAINFALL' 欄位中的 'NR' 值全部補 0。

# Overwrite every cell holding the string 'NR' with 0 (these occur in the
# RAINFALL rows according to the accompanying text).
is_nr = data == 'NR'
data[is_nr] = 0
<ipython-input-4-e27b85bf64ed>:1: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
data[data == 'NR'] = 0# 将‘RAINFALL’字段全部补0,
/usr/local/anaconda3/lib/python3.8/site-packages/pandas/core/frame.py:2986: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
self._where(-key, value, inplace=True)



0

1

2

3

4

5

6

7

8

9

...

14

15

16

17

18

19

20

21

22

23

0

14

14

14

13

12

12

12

12

15

17

...

22

22

21

19

17

16

15

15

15

15

1

1.8

1.8

1.8

1.8

1.8

1.8

1.8

1.8

1.8

1.8

...

1.8

1.8

1.8

1.8

1.8

1.8

1.8

1.8

1.8

1.8

2

0.51

0.41

0.39

0.37

0.35

0.3

0.37

0.47

0.78

0.74

...

0.37

0.37

0.47

0.69

0.56

0.45

0.38

0.35

0.36

0.32

3

0.2

0.15

0.13

0.12

0.11

0.06

0.1

0.13

0.26

0.23

...

0.1

0.13

0.14

0.23

0.18

0.12

0.1

0.09

0.1

0.08

4

0.9

0.6

0.5

1.7

1.8

1.5

1.9

2.2

6.6

7.9

...

2.5

2.2

2.5

2.3

2.1

1.9

1.5

1.6

1.8

1.5

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

4315

1.8

1.8

1.8

1.8

1.8

1.7

1.7

1.8

1.8

1.8

...

1.8

1.8

2

2.1

2

1.9

1.9

1.9

2

2

4316

46

13

61

44

55

68

66

70

66

85

...

59

308

327

21

100

109

108

114

108

109

4317

36

55

72

327

74

52

59

83

106

105

...

18

311

52

54

121

97

107

118

100

105

4318

1.9

2.4

1.9

2.8

2.3

1.9

2.1

3.7

2.8

3.8

...

2.3

2.6

1.3

1

1.5

1

1.7

1.5

2

2

4319

0.7

0.8

1.8

1

1.9

1.7

2.1

2

2

1.7

...

1.3

1.7

0.7

0.4

1.1

1.4

1.3

1.6

1.8

2

4320 rows × 24 columns

# Convert the DataFrame to a plain 2-D NumPy array (object dtype: the cells
# are still strings at this point). copy=False is the default, spelled out
# here; the following steps only read from raw_data.
raw_data = data.to_numpy(copy=False)
array([['14', '14', '14', ..., '15', '15', '15'],
['1.8', '1.8', '1.8', ..., '1.8', '1.8', '1.8'],
['0.51', '0.41', '0.39', ..., '0.35', '0.36', '0.32'],
...,
['36', '55', '72', ..., '118', '100', '105'],
['1.9', '2.4', '1.9', ..., '1.5', '2', '2'],
['0.7', '0.8', '1.8', ..., '1.6', '1.8', '2']], dtype=object)

將原始 4320 * 24 的資料(4320 = 12 個月 * 20 天 * 18 個 features,24 為每天的小時數)依照每個月分重組成 12 個 18 (features) * 480 (hours) 的資料。

Extract Features (1)

# Regroup the raw table by month. The rows come in groups of 18 features per
# day with 24 hourly columns, and each month contributes 20 days. For every
# month the 20 daily (18, 24) slabs are laid side by side, giving one
# (18, 480) float array per month, keyed by month index 0..11.
month_data = {}
for month in range(12):
    sample = np.empty([18, 480])  # 24 hours per day * 20 days per month
    for day in range(20):
        first_row = 18 * (20 * month + day)  # first feature row of this day
        # Assigning into the float array converts the string cells to floats.
        sample[:, day * 24:(day + 1) * 24] = raw_data[first_row:first_row + 18, :]
    month_data[month] = sample

李宏毅(2020)作业1-hw1_regression_ide

month_data
{0: array([[14.  , 14.  , 14.  , ..., 14.  , 13.  , 13.  ],
[ 1.8 , 1.8 , 1.8 , ..., 1.8 , 1.8 , 1.8 ],
[ 0.51, 0.41, 0.39, ..., 0.34, 0.41, 0.43],
...,
[35. , 79. , 2.4 , ..., 48. , 63. , 53. ],
[ 1.4 , 1.8 , 1. , ..., 1.1 , 1.9 , 1.9 ],
[ 0.5 , 0.9 , 0.6 , ..., 1.2 , 1.2 , 1.3 ]]),
1: array([[ 15. , 14. , 14. , ..., 8.4 , 8. , 7.6 ],
[ 1.8 , 1.8 , 1.7 , ..., 1.7 , 1.7 , 1.7 ],
[ 0.27, 0.26, 0.25, ..., 0.36, 0.35, 0.32],
...,
[113. , 109. , 104. , ..., 72. , 65. , 69. ],
[ 2.3 , 2.2 , 2.6 , ..., 1.9 , 2.9 , 1.5 ],
[ 2.5 , 2.2 , 2.2 , ..., 0.9 , 1.6 , 1.1 ]]),
2: array([[ 18. , 18. , 18. , ..., 14. , 13. , 13. ],
[ 1.8 , 1.8 , 1.8 , ..., 1.8 , 1.8 , 1.8 ],
[ 0.39, 0.36, 0.4 , ..., 0.42, 0.47, 0.49],
...,
[103. , 128. , 115. , ..., 60. , 94. , 53. ],
[ 1.7 , 1.4 , 1.8 , ..., 4.2 , 3.5 , 4.3 ],
[ 1.9 , 0.8 , 1.5 , ..., 3.1 , 2.4 , 2.4 ]]),
3: array([[ 19. , 18. , 17. , ..., 24. , 24. , 23. ],
[ 1.7 , 1.7 , 1.7 , ..., 1.8 , 1.8 , 1.9 ],
[ 0.42, 0.42, 0.42, ..., 0.41, 0.46, 0.42],
...,
[308. , 308. , 320. , ..., 331. , 261. , 273. ],
[ 1.7 , 2.2 , 2. , ..., 1. , 1. , 0.8 ],
[ 1.5 , 1.5 , 1.2 , ..., 0.6 , 1.1 , 0.9 ]]),
4: array([[1.90e+01, 1.90e+01, 2.00e+01, ..., 2.60e+01, 2.60e+01, 2.50e+01],
[1.80e+00, 1.80e+00, 1.80e+00, ..., 1.60e+00, 1.60e+00, 1.60e+00],
[4.80e-01, 4.70e-01, 4.50e-01, ..., 1.50e-01, 1.50e-01, 1.30e-01],
...,
[2.90e+02, 6.90e+01, 2.50e+02, ..., 1.74e+02, 1.95e+02, 1.69e+02],
[1.50e+00, 1.90e+00, 1.70e+00, ..., 3.10e+00, 3.10e+00, 2.90e+00],
[4.00e-01, 5.00e-01, 1.00e+00, ..., 2.90e+00, 2.40e+00, 3.10e+00]]),
5: array([[2.60e+01, 2.50e+01, 2.50e+01, ..., 2.70e+01, 2.70e+01, 2.80e+01],
[1.70e+00, 1.70e+00, 1.70e+00, ..., 1.60e+00, 1.60e+00, 1.60e+00],
[3.50e-01, 3.40e-01, 3.40e-01, ..., 2.60e-01, 1.90e-01, 1.60e-01],
...,
[1.18e+02, 1.22e+02, 1.19e+02, ..., 1.16e+02, 1.59e+02, 1.62e+02],
[1.60e+00, 1.40e+00, 1.30e+00, ..., 1.70e+00, 1.00e+00, 2.40e+00],
[1.50e+00, 1.50e+00, 1.30e+00, ..., 1.30e+00, 1.30e+00, 1.70e+00]]),
6: array([[2.60e+01, 2.50e+01, 2.60e+01, ..., 2.80e+01, 2.80e+01, 2.80e+01],
[1.60e+00, 1.60e+00, 1.60e+00, ..., 1.60e+00, 1.60e+00, 1.70e+00],
[1.40e-01, 1.30e-01, 1.30e-01, ..., 3.10e-01, 3.00e-01, 2.70e-01],
...,
[1.06e+02, 1.24e+02, 1.17e+02, ..., 1.27e+02, 1.33e+02, 1.72e+02],
[1.60e+00, 1.80e+00, 1.20e+00, ..., 1.60e+00, 1.40e+00, 1.70e+00],
[2.00e+00, 2.20e+00, 1.70e+00, ..., 1.70e+00, 1.30e+00, 1.60e+00]]),
7: array([[2.80e+01, 2.80e+01, 2.80e+01, ..., 2.60e+01, 2.60e+01, 2.60e+01],
[1.60e+00, 1.60e+00, 1.60e+00, ..., 1.70e+00, 1.70e+00, 1.70e+00],
[2.60e-01, 2.00e-01, 1.60e-01, ..., 1.60e-01, 1.40e-01, 1.30e-01],
...,
[2.04e+02, 1.77e+02, 1.72e+02, ..., 1.68e+02, 1.80e+02, 1.62e+02],
[2.90e+00, 2.80e+00, 2.70e+00, ..., 2.90e+00, 2.80e+00, 2.50e+00],
[3.00e+00, 2.80e+00, 2.70e+00, ..., 3.10e+00, 2.90e+00, 2.50e+00]]),
8: array([[ 25. , 25. , 25. , ..., 26. , 26. , 26. ],
[ 1.7 , 1.7 , 1.7 , ..., 1.6 , 1.6 , 1.7 ],
[ 0.28, 0.27, 0.26, ..., 0.28, 0.24, 0.23],
...,
[ 98. , 109. , 108. , ..., 163. , 71. , 55. ],
[ 1.8 , 1.9 , 1.1 , ..., 1.2 , 1.1 , 0.7 ],
[ 1.4 , 1.9 , 1.7 , ..., 3.4 , 1. , 0.7 ]]),
9: array([[ 25. , 25. , 25. , ..., 23. , 22. , 22. ],
[ 1.7 , 1.7 , 1.7 , ..., 1.8 , 1.7 , 1.7 ],
[ 0.24, 0.26, 0.27, ..., 0.42, 0.35, 0.26],
...,
[ 72. , 100. , 68. , ..., 109. , 110. , 107. ],
[ 1.1 , 1.4 , 1.1 , ..., 2.2 , 2.4 , 2.5 ],
[ 1.8 , 1.2 , 0.9 , ..., 2.1 , 2.2 , 2.3 ]]),
10: array([[ 22. , 21. , 21. , ..., 19. , 18. , 18. ],
[ 1.9 , 1.9 , 1.9 , ..., 1.7 , 1.7 , 1.7 ],
[ 0.79, 0.71, 0.61, ..., 0.36, 0.36, 0.37],
...,
[100. , 117. , 110. , ..., 117. , 117. , 114. ],
[ 1.1 , 1.9 , 1.7 , ..., 2.1 , 2.2 , 1.9 ],
[ 0.7 , 1.1 , 1.2 , ..., 1.8 , 2.1 , 1.9 ]]),
11: array([[ 23. , 23. , 23. , ..., 13. , 13. , 13. ],
[ 1.6 , 1.7 , 1.7 , ..., 1.8 , 1.8 , 1.8 ],
[ 0.22, 0.2 , 0.18, ..., 0.51, 0.57, 0.56],
...,
[ 93. , 50. , 99. , ..., 118. , 100. , 105. ],
[ 1.8 , 2.1 , 3.2 , ..., 1.5 , 2. , 2. ],
[ 1.3 , 0.9 , 1. , ..., 1.6 , 1.8 , 2. ]])}

数据维度的变化 4320(12 * 20 * 18)* 24==>12 * 18 * 480(20 * 24)

一个月有480h,每连续10h为一个窗口(前9h的全部特征作为输入,第10h的值作为标签),窗口每次滑动1h,因此每个月一共有480-10+1=471笔data

Extract Features (2)

# Slide a 10-hour window over each month's 480 hours (471 windows per month).
# The first 9 hours of all 18 features are flattened into one 162-value input
# row; the value of feature row 9 at the 10th hour is the regression target
# (presumably PM2.5 -- confirm against the feature order of the data set).
train_data = np.zeros((471 * 12, 18 * 9))
label_data = np.zeros((471 * 12, 1))
for month in range(12):
    hours = month_data[month]
    offset = 471 * month  # first output row belonging to this month
    for start in range(471):
        train_data[offset + start, :] = hours[:, start:start + 9].flatten()
        label_data[offset + start, :] = hours[9, start + 9]

数据维度变化,12 * 18 * 480(20 * 24)==>5652(471 * 12) * 162(18 * 9)

train_data
array([[14. , 14. , 14. , ...,  2. ,  2. ,  0.5],
[14. , 14. , 13. , ..., 2. , 0.5, 0.3],
[14. , 13. , 12. , ..., 0.5, 0.3, 0.8],
...,
[17. , 18. , 19. , ..., 1.1, 1.4, 1.3],
[18. , 19. , 18. , ..., 1.4, 1.3, 1.6],
[19. , 18. , 17. , ..., 1.3, 1.6, 1.8]])
label_data
array([[30.],
[41.],
[44.],
...,
[17.],
[24.],
[29.]])

Normalize (1)

归一化(x-mean)/std

# Z-score normalisation, (x - mean) / std, applied per feature column.
# mean_x and std_x are kept because the test set must be scaled with the
# same statistics.
mean_x = np.mean(train_data, axis=0)  # per-column mean, length 18 * 9
std_x = np.std(train_data, axis=0)  # per-column standard deviation, length 18 * 9
# Columns with zero spread are left untouched to avoid dividing by zero.
nonzero_std = std_x != 0
# One vectorised update replaces the element-by-element double loop over all
# 12 * 471 rows and 18 * 9 columns; the arithmetic per element is unchanged.
train_data[:, nonzero_std] = (
    train_data[:, nonzero_std] - mean_x[nonzero_std]
) / std_x[nonzero_std]
array([[-1.35825331, -1.35883937, -1.359222  , ...,  0.26650729,
0.2656797 , -1.14082131],
[-1.35825331, -1.35883937, -1.51819928, ..., 0.26650729,
-1.13963133, -1.32832904],
[-1.35825331, -1.51789368, -1.67717656, ..., -1.13923451,
-1.32700613, -0.85955971],
...,
[-0.88092053, -0.72262212, -0.56433559, ..., -0.57693779,
-0.29644471, -0.39079039],
[-0.7218096 , -0.56356781, -0.72331287, ..., -0.29578943,
-0.39013211, -0.1095288 ],
[-0.56269867, -0.72262212, -0.88229015, ..., -0.38950555,
-0.10906991, 0.07797893]])

李宏毅(2020)作业1-hw1_regression_回归_02

Training

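Adagrad表达式:$w^{t+1} = w^{t} - \dfrac{\eta}{\sqrt{\sum_{i=0}^{t}\left(g^{i}\right)^{2} + \epsilon}}\, g^{t}$,其中 $g^{t}$ 为第 $t$ 次迭代的梯度,$\eta$ 为学习率,$\epsilon$ 为防止除以 0 的小常数。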
李宏毅(2020)作业1-hw1_regression_回归_03
有关Adagrad推导见​​​javascript:void(0)​

# Linear regression fitted by full-batch gradient descent with Adagrad
# per-weight step sizes. The final weights are written to 'weight.npy'.
dim = 18 * 9 + 1  # 162 input features plus one bias term
w = np.zeros([dim, 1])
# Prepend a column of ones so that w[0] plays the role of the bias.
x = np.concatenate((np.ones([12 * 471, 1]), train_data), axis=1).astype(float)
learning_rate = 100
iter_time = 1000
adagrad = np.zeros([dim, 1])  # running sum of squared gradients
eps = 0.0000000001  # keeps the Adagrad denominator away from zero
for t in range(iter_time):
    # Prediction error, shared by the loss and the gradient.
    residual = np.dot(x, w) - label_data
    loss = np.sqrt(np.sum(np.power(residual, 2)) / 471 / 12)  # RMSE
    if t % 100 == 0:
        print(t, loss, sep=":")
    gradient = 2 * np.dot(x.transpose(), residual)  # shape (dim, 1)
    adagrad += gradient ** 2
    w = w - learning_rate * gradient / np.sqrt(adagrad + eps)
np.save('weight.npy', w)
0:27.071214829194115
100:33.78905859777454
200:19.9137512981971
300:13.531068193689686
400:10.645466158446167
500:9.27735345547506
600:8.518042045956495
700:8.014061987588416
800:7.636756824775688
900:7.336563740371121





array([[ 2.13740269e+01],
[ 3.58888909e+00],
[ 4.56386323e+00],
[ 2.16307023e+00],
[-6.58545223e+00],
[-3.38885580e+01],
[ 3.22235518e+01],
[ 3.49340354e+00],
[-4.60308671e+00],
[-1.02374754e+00],
[-3.96791501e-01],
[-1.06908800e-01],
[ 2.22488184e-01],
[ 8.99634117e-02],
[ 1.31243105e-01],
[ 2.15894989e-02],
[-1.52867263e-01],
[ 4.54087776e-02],
[ 5.20999235e-01],
[ 1.60824213e-01],
[-3.17709451e-02],
[ 1.28529025e-02],
[-1.76839437e-01],
[ 1.71241371e-01],
[-1.31190032e-01],
[-3.51614451e-02],
[ 1.00826192e-01],
[ 3.45018257e-01],
[ 4.00130315e-02],
[ 2.54331382e-02],
[-5.04425219e-01],
[ 3.71483018e-01],
[ 8.46357671e-01],
[-8.11920428e-01],
[-8.00217575e-02],
[ 1.52737711e-01],
[ 2.64915130e-01],
[-5.19860416e-02],
[-2.51988315e-01],
[ 3.85246517e-01],
[ 1.65431451e-01],
[-7.83633314e-02],
[-2.89457231e-01],
[ 1.77615023e-01],
[ 3.22506948e-01],
[-4.59955256e-01],
[-3.48635358e-02],
[-5.81764363e-01],
[-6.43394528e-02],
[-6.32876949e-01],
[ 6.36624507e-02],
[ 8.31592506e-02],
[-4.45157961e-01],
[-2.34526366e-01],
[ 9.86608594e-01],
[ 2.65230652e-01],
[ 3.51938093e-02],
[ 3.07464334e-01],
[-1.04311239e-01],
[-6.49166901e-02],
[ 2.11224757e-01],
[-2.43159815e-01],
[-1.31285604e-01],
[ 1.09045810e+00],
[-3.97913710e-02],
[ 9.19563678e-01],
[-9.44824150e-01],
[-5.04137735e-01],
[ 6.81272939e-01],
[-1.34494828e+00],
[-2.68009542e-01],
[ 4.36204342e-02],
[ 1.89619513e+00],
[-3.41873873e-01],
[ 1.89162461e-01],
[ 1.73251268e-02],
[ 3.14431930e-01],
[-3.40828467e-01],
[ 4.92385651e-01],
[ 9.29634214e-02],
[-4.50983589e-01],
[ 1.47456584e+00],
[-3.03417236e-02],
[ 7.71229328e-02],
[ 6.38314494e-01],
[-7.93287087e-01],
[ 8.82877506e-01],
[ 3.18965610e+00],
[-5.75671706e+00],
[ 1.60748945e+00],
[ 1.36142440e+01],
[ 1.50029111e-01],
[-4.78389603e-02],
[-6.29463755e-02],
[-2.85383032e-02],
[-3.01562821e-01],
[ 4.12058013e-01],
[-6.77534154e-02],
[-1.00985479e-01],
[-1.68972973e-01],
[ 1.64093233e+00],
[ 1.89670371e+00],
[ 3.94713816e-01],
[-4.71231449e+00],
[-7.42760774e+00],
[ 6.19781936e+00],
[ 3.53986244e+00],
[-9.56245861e-01],
[-1.04372792e+00],
[-4.92863713e-01],
[ 6.31608790e-01],
[-4.85175956e-01],
[ 2.58400216e-01],
[ 9.43846795e-02],
[-1.29323184e-01],
[-3.81235287e-01],
[ 3.86819479e-01],
[ 4.04211627e-01],
[ 3.75568914e-01],
[ 1.83512261e-01],
[-8.01417708e-02],
[-3.10188597e-01],
[-3.96124612e-01],
[ 3.66227853e-01],
[ 1.79488593e-01],
[-3.14477051e-01],
[-2.37611443e-01],
[ 3.97076104e-02],
[ 1.38775912e-01],
[-3.84015069e-02],
[-5.47557119e-02],
[ 4.19975207e-01],
[ 4.46120687e-01],
[-4.31074826e-01],
[-8.74450768e-02],
[-5.69534264e-02],
[-7.23980157e-02],
[-1.39880128e-02],
[ 1.40489658e-01],
[-2.44952334e-01],
[ 1.83646770e-01],
[-1.64135512e-01],
[-7.41216452e-02],
[-9.71414213e-02],
[ 1.98829041e-02],
[-4.46965919e-01],
[-2.63440959e-01],
[ 1.52924043e-01],
[ 6.52532847e-02],
[ 7.06818266e-01],
[ 9.73757051e-02],
[-3.35687787e-01],
[-2.26559165e-01],
[-3.00117086e-01],
[ 1.24185231e-01],
[ 4.18872344e-01],
[-2.51891946e-01],
[-1.29095731e-01],
[-5.57512471e-01],
[ 8.76239582e-02],
[ 3.02594902e-01],
[-4.23463160e-01],
[ 4.89922051e-01]])
len(w)
163

Testing

載入 test data,並且以相似於訓練資料預先處理和特徵萃取的方式處理,使 test data 形成 240 個維度為 18 * 9 + 1 的資料。

# Load the test set and shape it like the training inputs: 240 samples, each
# 18 features * 9 hours, scaled with the training statistics and prefixed
# with a bias column of ones.
# Alternative location when the file sits on a mounted Google Drive:
# testdata = pd.read_csv('gdrive/My Drive/hw1-regression/test.csv', header = None, encoding = 'big5')
testdata = pd.read_csv('./test.csv', header=None, encoding='big5')
# Drop the two leading columns (presumably sample id and measurement name --
# confirm against the file) and keep the hourly readings. `.copy()` detaches
# the slice so the in-place 'NR' replacement does not raise
# SettingWithCopyWarning.
test_data = testdata.iloc[:, 2:].copy()
test_data[test_data == 'NR'] = 0
test_data = test_data.to_numpy()
# Every consecutive group of 18 rows is one sample; flattening the group row
# by row yields the same 162-value layout as the training rows. Only the
# first 240 * 18 rows are used.
test_x = test_data[:240 * 18].reshape(240, 18 * 9).astype(float)
# Scale with the TRAINING mean/std; zero-spread columns are left untouched.
nonzero_std = std_x != 0
test_x[:, nonzero_std] = (
    test_x[:, nonzero_std] - mean_x[nonzero_std]
) / std_x[nonzero_std]
# Bias column, matching the layout used for training.
test_x = np.concatenate((np.ones([240, 1]), test_x), axis=1).astype(float)
<ipython-input-14-cd2718394ff0>:4: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
test_data[test_data == 'NR'] = 0
/usr/local/anaconda3/lib/python3.8/site-packages/pandas/core/frame.py:2986: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
self._where(-key, value, inplace=True)





array([[ 1. , -0.24447681, -0.24545919, ..., -0.67065391,
-1.04594393, 0.07797893],
[ 1. , -1.35825331, -1.51789368, ..., 0.17279117,
-0.10906991, -0.48454426],
[ 1. , 1.5057434 , 1.34508393, ..., -1.32666675,
-1.04594393, -0.57829812],
...,
[ 1. , 0.3919669 , 0.54981237, ..., 0.26650729,
-0.20275731, 1.20302531],
[ 1. , -1.8355861 , -1.8360023 , ..., -1.04551839,
-1.13963133, -1.14082131],
[ 1. , -1.35825331, -1.35883937, ..., 2.98427476,
3.26367657, 1.76554849]])

Prediction

有了 weight 和測試資料即可預測 target。

# Reload the weights saved after training and apply the linear model;
# ans_y holds one predicted value per row of test_x.
w = np.load('weight.npy')
ans_y = test_x.dot(w)
array([[ 5.17496040e+00],
[ 1.83062143e+01],
[ 2.04912181e+01],
[ 1.15239429e+01],
[ 2.66160568e+01],
[ 2.05313481e+01],
[ 2.19065510e+01],
[ 3.17364687e+01],
[ 1.33916741e+01],
[ 6.44564665e+01],
[ 2.02645688e+01],
[ 1.53585761e+01],
[ 6.85894728e+01],
[ 4.84281137e+01],
[ 1.87023338e+01],
[ 1.01885957e+01],
[ 3.07403629e+01],
[ 7.11322178e+01],
[-4.13051739e+00],
[ 1.82356940e+01],
[ 3.85789223e+01],
[ 7.13115197e+01],
[ 7.41034816e+00],
[ 1.87179553e+01],
[ 1.49372503e+01],
[ 3.67197367e+01],
[ 1.79616970e+01],
[ 7.57894629e+01],
[ 1.23093102e+01],
[ 5.62953517e+01],
[ 2.51131609e+01],
[ 4.61024867e+00],
[ 2.48377055e+00],
[ 2.47594223e+01],
[ 3.04802805e+01],
[ 3.84639307e+01],
[ 4.42023106e+01],
[ 3.00868360e+01],
[ 4.04736750e+01],
[ 2.92264799e+01],
[ 5.60645605e+00],
[ 3.86660161e+01],
[ 3.46102134e+01],
[ 4.83896975e+01],
[ 1.47572477e+01],
[ 3.44668201e+01],
[ 2.74831069e+01],
[ 1.20008794e+01],
[ 2.13780362e+01],
[ 2.85444031e+01],
[ 2.01655138e+01],
[ 1.07966781e+01],
[ 2.21710358e+01],
[ 5.34462631e+01],
[ 1.22195811e+01],
[ 4.33009685e+01],
[ 3.21823351e+01],
[ 2.25672175e+01],
[ 5.67395142e+01],
[ 2.07450529e+01],
[ 1.50288546e+01],
[ 3.98553016e+01],
[ 1.29753407e+01],
[ 5.17416596e+01],
[ 1.87833696e+01],
[ 1.23487528e+01],
[ 1.56336237e+01],
[-5.88714707e-02],
[ 4.15080111e+01],
[ 3.15487475e+01],
[ 1.86042512e+01],
[ 3.74768197e+01],
[ 5.65203907e+01],
[ 6.58787719e+00],
[ 1.22293397e+01],
[ 5.20369640e+00],
[ 4.79273751e+01],
[ 1.30207057e+01],
[ 1.71103017e+01],
[ 2.06032345e+01],
[ 2.12844816e+01],
[ 3.86929353e+01],
[ 3.00207167e+01],
[ 8.87674067e+01],
[ 3.59847002e+01],
[ 2.67569136e+01],
[ 2.39635168e+01],
[ 3.27472428e+01],
[ 2.21890438e+01],
[ 2.09921589e+01],
[ 2.95559943e+01],
[ 4.09921689e+01],
[ 8.62511781e+00],
[ 3.23214718e+01],
[ 4.65980444e+01],
[ 2.28840708e+01],
[ 3.15181297e+01],
[ 1.11982335e+01],
[ 2.85274366e+01],
[ 2.91150680e-01],
[ 1.79669611e+01],
[ 2.71241639e+01],
[ 1.13982328e+01],
[ 1.64264269e+01],
[ 2.34252610e+01],
[ 4.06160827e+01],
[ 2.58641250e+01],
[ 5.42273695e+00],
[ 1.07949211e+01],
[ 7.28621369e+01],
[ 4.80228371e+01],
[ 1.57468083e+01],
[ 2.46704106e+01],
[ 1.28277933e+01],
[ 1.01580576e+01],
[ 2.72692233e+01],
[ 2.92087386e+01],
[ 8.83533962e+00],
[ 2.00510881e+01],
[ 2.02123337e+01],
[ 7.99060093e+01],
[ 1.80616143e+01],
[ 3.05428093e+01],
[ 2.59807924e+01],
[ 5.21257727e+00],
[ 3.03556973e+01],
[ 7.76832289e+00],
[ 1.53282683e+01],
[ 2.26663657e+01],
[ 6.27420542e+01],
[ 1.89507804e+01],
[ 1.90763556e+01],
[ 6.13715741e+01],
[ 1.58845621e+01],
[ 1.34094181e+01],
[ 8.48772484e-01],
[ 7.83499672e+00],
[ 5.70128290e+01],
[ 2.56079968e+01],
[ 4.96170473e+00],
[ 3.64148790e+01],
[ 2.87900067e+01],
[ 4.91941210e+01],
[ 4.03068699e+01],
[ 1.33161806e+01],
[ 2.76610119e+01],
[ 1.71580275e+01],
[ 4.96872626e+01],
[ 2.30302723e+01],
[ 3.92409365e+01],
[ 1.31967539e+01],
[ 5.94889370e+00],
[ 2.58216090e+01],
[ 8.25863421e+00],
[ 1.91463205e+01],
[ 4.31824865e+01],
[ 6.71784358e+00],
[ 3.38696152e+01],
[ 1.53699378e+01],
[ 1.69390450e+01],
[ 3.78853368e+01],
[ 1.92024845e+01],
[ 9.05950472e+00],
[ 1.02833996e+01],
[ 4.86724471e+01],
[ 3.05877162e+01],
[ 2.47740990e+00],
[ 1.28116039e+01],
[ 7.03247898e+01],
[ 1.48409677e+01],
[ 6.88655876e+01],
[ 4.27419924e+01],
[ 2.40002615e+01],
[ 2.34207249e+01],
[ 6.16721244e+01],
[ 2.54942028e+01],
[ 1.90048098e+01],
[ 3.48866829e+01],
[ 9.40231340e+00],
[ 2.95200113e+01],
[ 1.45739659e+01],
[ 9.12556314e+00],
[ 5.28125840e+01],
[ 4.50395380e+01],
[ 1.74524347e+01],
[ 3.84939353e+01],
[ 2.70389191e+01],
[ 6.55817097e+01],
[ 7.03730638e+00],
[ 5.27144771e+01],
[ 3.82064593e+01],
[ 2.11698011e+01],
[ 3.02475569e+01],
[ 2.71442299e+00],
[ 1.99329326e+01],
[-3.41333234e+00],
[ 3.24459994e+01],
[ 1.05829730e+01],
[ 2.17752257e+01],
[ 6.24652921e+01],
[ 2.41329437e+01],
[ 2.62012396e+01],
[ 6.37444772e+01],
[ 2.83429777e+00],
[ 1.43792470e+01],
[ 9.36985073e+00],
[ 9.88116661e+00],
[ 3.49494536e+00],
[ 1.22608049e+02],
[ 2.10835130e+01],
[ 1.75322206e+01],
[ 2.01830983e+01],
[ 3.63931322e+01],
[ 3.49351512e+01],
[ 1.88303127e+01],
[ 3.83445555e+01],
[ 7.79166341e+01],
[ 1.79532355e+00],
[ 1.34458279e+01],
[ 3.61311556e+01],
[ 1.51504035e+01],
[ 1.29418483e+01],
[ 1.13125241e+02],
[ 1.52246047e+01],
[ 1.48240260e+01],
[ 5.92673537e+01],
[ 1.05836953e+01],
[ 2.09930626e+01],
[ 9.78936588e+00],
[ 4.77118001e+00],
[ 4.79278069e+01],
[ 1.23994384e+01],
[ 4.81464766e+01],
[ 4.04663804e+01],
[ 1.69405903e+01],
[ 4.12665445e+01],
[ 6.90278920e+01],
[ 4.03462492e+01],
[ 1.43137440e+01],
[ 1.57707266e+01]])

Save Prediction to CSV File

import csv

# Write the 240 predictions to 'submit.csv' as (id, value) rows and echo the
# header and every row to stdout as they are written.
with open('submit.csv', mode='w', newline='') as submit_file:
    csv_writer = csv.writer(submit_file)
    header = ['id', 'value']
    print(header)
    csv_writer.writerow(header)
    for idx in range(240):
        row = [f'id_{idx}', ans_y[idx][0]]
        csv_writer.writerow(row)
        print(row)
['id', 'value']
['id_0', 5.174960398984744]
['id_1', 18.306214253527884]
['id_2', 20.491218094180525]
['id_3', 11.523942869805381]
['id_4', 26.616056752306168]
['id_5', 20.53134808176121]
['id_6', 21.906551018797394]
['id_7', 31.736468747068834]
['id_8', 13.391674055111721]
['id_9', 64.45646650291957]
['id_10', 20.264568836159427]
['id_11', 15.35857607736121]
['id_12', 68.58947276926725]
['id_13', 48.428113747457175]
['id_14', 18.702333824193207]
['id_15', 10.188595737466702]
['id_16', 30.74036285982042]
['id_17', 71.13221776355113]
['id_18', -4.130517391262456]
['id_19', 18.23569401642868]
['id_20', 38.578922275007756]
['id_21', 71.3115197253133]
['id_22', 7.410348162634051]
['id_23', 18.717955330321388]
['id_24', 14.937250260084564]
['id_25', 36.719736694705304]
['id_26', 17.961697005662693]
['id_27', 75.7894628721054]
['id_28', 12.309310248614443]
['id_29', 56.29535173964958]
['id_30', 25.1131608656615]
['id_31', 4.610248674094034]
['id_32', 2.483770554515047]
['id_33', 24.759422261321255]
['id_34', 30.48028046559117]
['id_35', 38.46393074642665]
['id_36', 44.202310609330034]
['id_37', 30.08683601986599]
['id_38', 40.473675015740085]
['id_39', 29.22647990231738]
['id_40', 5.606456054343944]
['id_41', 38.66601607878964]
['id_42', 34.610213431877206]
['id_43', 48.389697507384795]
['id_44', 14.757247666944167]
['id_45', 34.46682011087209]
['id_46', 27.48310687418435]
['id_47', 12.000879378154032]
['id_48', 21.378036151603766]
['id_49', 28.54440309166329]
['id_50', 20.16551381841159]
['id_51', 10.79667814974648]
['id_52', 22.17103575575012]
['id_53', 53.44626310935226]
['id_54', 12.219581121610041]
['id_55', 43.30096845517151]
['id_56', 32.18233510328543]
['id_57', 22.56721751457082]
['id_58', 56.73951416554705]
['id_59', 20.745052945295463]
['id_60', 15.028854557473274]
['id_61', 39.85530159038511]
['id_62', 12.975340680728308]
['id_63', 51.74165959283005]
['id_64', 18.783369632539817]
['id_65', 12.3487528427777]
['id_66', 15.633623653541882]
['id_67', -0.05887147068501619]
['id_68', 41.50801107307595]
['id_69', 31.548747530656005]
['id_70', 18.604251157547075]
['id_71', 37.47681972488073]
['id_72', 56.52039065762305]
['id_73', 6.587877193521939]
['id_74', 12.229339737435012]
['id_75', 5.203696404134661]
['id_76', 47.9273751038006]
['id_77', 13.020705685594685]
['id_78', 17.110301693903608]
['id_79', 20.603234531002034]
['id_80', 21.2844815607846]
['id_81', 38.69293529051177]
['id_82', 30.020716675725858]
['id_83', 88.7674066672355]
['id_84', 35.9847002396683]
['id_85', 26.756913553477204]
['id_86', 23.963516843564435]
['id_87', 32.74724282808307]
['id_88', 22.189043755319926]
['id_89', 20.99215885362657]
['id_90', 29.555994316645425]
['id_91', 40.9921688665178]
['id_92', 8.62511780991153]
['id_93', 32.32147180887788]
['id_94', 46.59804436536766]
['id_95', 22.884070826723523]
['id_96', 31.518129728251658]
['id_97', 11.198233479766134]
['id_98', 28.5274366425296]
['id_99', 0.2911506800896202]
['id_100', 17.966961079539693]
['id_101', 27.124163929470157]
['id_102', 11.398232780652837]
['id_103', 16.42642686567352]
['id_104', 23.425261046922188]
['id_105', 40.616082670568396]
['id_106', 25.864125026560373]
['id_107', 5.422736951672377]
['id_108', 10.794921122256106]
['id_109', 72.86213692992125]
['id_110', 48.02283705948139]
['id_111', 15.746808276902982]
['id_112', 24.670410614177953]
['id_113', 12.827793326536712]
['id_114', 10.158057570240523]
['id_115', 27.269223342020968]
['id_116', 29.208738577932426]
['id_117', 8.835339619930693]
['id_118', 20.051088137129724]
['id_119', 20.212333743764255]
['id_120', 79.90600929870558]
['id_121', 18.061614288263613]
['id_122', 30.542809341304373]
['id_123', 25.980792377728058]
['id_124', 5.212577268164768]
['id_125', 30.355697305856225]
['id_126', 7.768322888914636]
['id_127', 15.328268255393361]
['id_128', 22.666365717697936]
['id_129', 62.742054211090064]
['id_130', 18.95078036798801]
['id_131', 19.07635563083852]
['id_132', 61.37157409163706]
['id_133', 15.88456205262969]
['id_134', 13.409418077705537]
['id_135', 0.8487724836112776]
['id_136', 7.834996717304136]
['id_137', 57.01282901179681]
['id_138', 25.607996751813808]
['id_139', 4.961704729242088]
['id_140', 36.41487903906275]
['id_141', 28.79000672197592]
['id_142', 49.19412096197634]
['id_143', 40.306869855734476]
['id_144', 13.316180593982693]
['id_145', 27.661011875229143]
['id_146', 17.158027524366748]
['id_147', 49.687262569296834]
['id_148', 23.030272291604792]
['id_149', 39.24093652484275]
['id_150', 13.19675388941252]
['id_151', 5.948893701039445]
['id_152', 25.821608976304244]
['id_153', 8.25863421429164]
['id_154', 19.14632051722559]
['id_155', 43.18248652651674]
['id_156', 6.71784357809301]
['id_157', 33.869615246810646]
['id_158', 15.369937846981856]
['id_159', 16.93904497355191]
['id_160', 37.885336794634846]
['id_161', 19.20248454105439]
['id_162', 9.059504715654704]
['id_163', 10.283399610648479]
['id_164', 48.672447125698284]
['id_165', 30.587716213230777]
['id_166', 2.4774098975321523]
['id_167', 12.811603937805945]
['id_168', 70.32478980976462]
['id_169', 14.840967694067032]
['id_170', 68.8655875667886]
['id_171', 42.74199244486633]
['id_172', 24.000261542920157]
['id_173', 23.420724860321418]
['id_174', 61.67212443568235]
['id_175', 25.494202845059192]
['id_176', 19.00480978686905]
['id_177', 34.886682881896846]
['id_178', 9.402313398379732]
['id_179', 29.520011314408023]
['id_180', 14.573965885700478]
['id_181', 9.125563143203582]
['id_182', 52.81258399813189]
['id_183', 45.03953799438963]
['id_184', 17.452434679183284]
['id_185', 38.49393527971429]
['id_186', 27.03891909264383]
['id_187', 65.58170967424581]
['id_188', 7.037306380769593]
['id_189', 52.71447713411571]
['id_190', 38.20645933704978]
['id_191', 21.169801059557862]
['id_192', 30.2475568794884]
['id_193', 2.7144229897163115]
['id_194', 19.932932587640817]
['id_195', -3.41333233760389]
['id_196', 32.44599940281314]
['id_197', 10.582973029979915]
['id_198', 21.775225707258457]
['id_199', 62.465292065677914]
['id_200', 24.132943687316452]
['id_201', 26.201239647400975]
['id_202', 63.74447723440289]
['id_203', 2.834297774129027]
['id_204', 14.379246986978885]
['id_205', 9.369850731753857]
['id_206', 9.881166613595404]
['id_207', 3.4949453589721333]
['id_208', 122.60804937921782]
['id_209', 21.08351301448058]
['id_210', 17.532220599455105]
['id_211', 20.18309834459702]
['id_212', 36.39313221228185]
['id_213', 34.93515120529068]
['id_214', 18.830312661458635]
['id_215', 38.34455552272334]
['id_216', 77.91663413807039]
['id_217', 1.7953235508882095]
['id_218', 13.445827939135793]
['id_219', 36.13115559041213]
['id_220', 15.150403498166291]
['id_221', 12.94184833441792]
['id_222', 113.1252409378639]
['id_223', 15.224604677934337]
['id_224', 14.824025968612105]
['id_225', 59.26735368854046]
['id_226', 10.583695290718495]
['id_227', 20.993062563532213]
['id_228', 9.789365880830392]
['id_229', 4.771180008705969]
['id_230', 47.92780690481286]
['id_231', 12.399438394751026]
['id_232', 48.14647656264414]
['id_233', 40.46638039656414]
['id_234', 16.940590270332933]
['id_235', 41.26654448941873]
['id_236', 69.027892033729]
['id_237', 40.34624924412242]
['id_238', 14.31374398287113]
['id_239', 15.770726634219834]