2020/04/26
RからPythonへの道(14)
3月、4月と会社の仕事が超多忙で怒涛の日々でした。4月に入って、新型コロナウィルスによる在宅勤務対応等で労務管理時間が激増。ゆっくり息を付く余裕もありませんでした。そのままGWに入って、ほぼ自宅監禁状態なんでしょうね。ストレスをうまく発散しつつ、GWはインドアで何かのんびりと勉強、自己啓発でもしようかなと気持ちを切り替えようとしています。今回は「13. 決定木(分類)(1)」について、RとPythonで計算していきたいと思います。教材はネット上で多く引用されている定番のirisのデータセットです。
まずは、Rのコードです。irisのデータを読み込んで、13〜16行目で学習データ、評価データを7:3に分けました。18行目で学習データを用いてrpart関数で決定木で分類学習し、結果を2種類のグラフに描画しました。その後、25行目で評価データを代入し、28行目以降で性能評価をしました。
# Decision Tree : iris classificationRのコードの結果は以下の通りです。8行目のグラフは以下の通り。赤色、緑色、青色のプロット点はそれぞれ、setosa、versicolor、virginicaを表しています。また、21行目のグラフは以下の通りで、まず、Petal.Lengthが2.45未満ならsetosaで、2.45以上ならversicolorかvirginicaの2つに分類し、次に、versicolorとvirginicaをPetal.Length4.85のしきい値で分類しています。22行目のグラフは以下の通りです。上のグラフと基本は同じですが、分類内訳が数字で表されています。分類された真ん中のversicolorの数字の「0 34 1」は順番にsetosa、versicolor、virginicaの数で、versicolorと判断された中で、34本はversicolor(正解)であり、1本はvirginica、つまり誤判定されたということですね。今回の場合、学習させても100点満点の分類はできていませんがそのまま進んでいます。
library(rpart)
library(partykit)
library(rpart.plot)
head(iris)
# Pairs graph
plot(iris[,1:4], col=c(2,3,4)[iris$Species])
# Analysis
set.seed(100)
df.rows = nrow(iris)
train.rate = 0.7 # training data rate
train.index = sample(df.rows, df.rows * train.rate)
df.train = iris[train.index,] # training data
df.test = iris[-train.index,] # test data
cat("train=", nrow(df.train), "test=", nrow(df.test))
model.rpart = rpart(Species~., data = df.train) # decition tree model
# Graph
plot(as.party(model.rpart)) # pattern 1
rpart.plot(model.rpart , type = 4, extra = 1, digits = 3) # pattern 2
# Predict
pred.rpart = predict(model.rpart, df.test, type = "class")
# Result : cross tabulation table
result = table(pred.rpart, df.test$Species)
print(result)
# Calculation the accuracy
accuracy_prediction = sum(diag(result)) / sum(result)
print(accuracy_prediction)
5行目29行目のクロス集計表の対角成分の19、12、11は正しく分類された本数で、対角成分外の数字の2と1の3本は誤分類された結果でした。正解率は「対角成分(正解数)の和」を「クロス集計表の数字の総和」で割った値なので93.3%ですね。
Sepal.Length Sepal.Width Petal.Length Petal.Width Species
1 5.1 3.5 1.4 0.2 setosa
2 4.9 3.0 1.4 0.2 setosa
3 4.7 3.2 1.3 0.2 setosa
4 4.6 3.1 1.5 0.2 setosa
5 5.0 3.6 1.4 0.2 setosa
6 5.4 3.9 1.7 0.4 setosa
17行目
train= 105 test= 45
29行目
pred.rpart setosa versicolor virginica
setosa 19 0 0
versicolor 0 12 2
virginica 0 1 11
33行目
[1] 0.9333333
次に、Pythonのコードです。流れはRのコードと同じです。
# Decision Tree : iris classificationPythonコードの結果は以下の通りです。23〜24行目のPairプロットグラフは以下の通りです。38行目の学習データから計算された分類結果のグラフは以下の通りです。初めのsetosaとそれ以外の分類がRと異なり、Petal widthで行っていますが、それは学習に使ったデータが全く同じでないためと考えます。グラフ中の「gini」は「ジニ不純度」のことで、簡単に言うと、サンプルの中に異なるクラスのものがどの程度含まれているかを表す指標で、以下の式で表されます。一番上の34、32、39の105を例に取ると、以下の計算です。
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn import tree
from sklearn.model_selection import train_test_split
import pydotplus
from sklearn.externals.six import StringIO
from IPython.display import Image
# Read data
iris = load_iris()
df = pd.DataFrame(iris.data)
df.columns = ['Sepal length', 'Sepal width', 'Petal length', 'Petal width']
df['Species'] = iris.target
df.loc[df['Species'] == 0, 'Species'] = "setosa"
df.loc[df['Species'] == 1, 'Species'] = "versicolor"
df.loc[df['Species'] == 2, 'Species'] = "virginica"
print(df.head())
# Pairs graph
g = sns.pairplot(df,hue='Species')
plt.show()
# Analysis
X = df[['Sepal length', 'Sepal width', 'Petal length', 'Petal width']]
y = df['Species']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
print("train=",str(len(X_train)), "test=", str(len(X_test)))
clf = tree.DecisionTreeClassifier(max_depth=2) # decision tree
clf = clf.fit(X_train, y_train)
# Graph
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf('DT_result.pdf')
# Predict
pred = clf.predict(X_test)
# Result : cross tabulation table
print(pd.crosstab(pred, y_test))
# Calculation the accuracy
print(sum(pred == y_test) / len(y_test))
IG = 1 - (34/105)^2 - (32/105)^2 - (39/105)^2 = 0.6643084
20行目正解率は91.1です。Rの結果でもそうでしたが、うまく分類できているのではないかと思いました。用途によっては、学習時のパラメータを過学習しない程度に調整は必要かもしれません。
Sepal length Sepal width Petal length Petal width Species
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa
3 4.6 3.1 1.5 0.2 setosa
4 5.0 3.6 1.4 0.2 setosa
30行目
train= 105 test= 45
44行目
Species setosa versicolor virginica
row_0
setosa 16 0 0
versicolor 0 17 3
virginica 0 1 8
47行目
0.9111111111111111
次回も決定木で分類を行う予定です。
『RからPythonへの道』バックナンバー
(1) はじめに
(2) 0. 実行環境(作業環境)
(3) 1. PythonからRを使う方法 2. RからPythonを使う方法
(4) 3. データフレーム
(5) 4. ggplot
(6) 5.行列
(7) 6.基本統計量
(8) 7. 回帰分析(単回帰)
(9) 8. 回帰分析(重回帰)
(10) 9. 回帰分析(ロジスティック回帰1)
(11) 10. 回帰分析(ロジスティック回帰2)
(12) 11. 回帰分析(リッジ、ラッソ回帰)
(13) 12. 回帰分析(多項式回帰)