2020年04月 - 考えるエンジニア
FC2ブログ

RからPythonへの道(14)

 3月、4月と会社の仕事が超多忙で怒涛の日々でした。4月に入って、新型コロナウィルスによる在宅勤務対応等で労務管理時間が激増。ゆっくり息を付く余裕もありませんでした。そのままGWに入って、ほぼ自宅監禁状態なんでしょうね。ストレスをうまく発散しつつ、GWはインドアで何かのんびりと勉強、自己啓発でもしようかなと気持ちを切り替えようとしています。

 今回は「13. 決定木(分類)(1)」について、RとPythonで計算していきたいと思います。教材はネット上で多く引用されている定番のirisのデータセットです。

 まずは、Rのコードです。irisのデータを読み込んで、13〜16行目で学習データ、評価データを7:3に分けました。18行目で学習データを用いてrpart関数で決定木で分類学習し、結果を2種類のグラフに描画しました。その後、25行目で評価データを代入し、28行目以降で性能評価をしました。
# Decision Tree : iris classification
library(rpart)
library(partykit)
library(rpart.plot)
head(iris)

# Pairs graph
plot(iris[,1:4], col=c(2,3,4)[iris$Species])

# Analysis
set.seed(100)
df.rows = nrow(iris)
train.rate = 0.7 # training data rate
train.index = sample(df.rows, df.rows * train.rate)
df.train = iris[train.index,] # training data
df.test = iris[-train.index,] # test data
cat("train=", nrow(df.train), "test=", nrow(df.test))
model.rpart = rpart(Species~., data = df.train) # decition tree model

# Graph
plot(as.party(model.rpart)) # pattern 1
rpart.plot(model.rpart , type = 4, extra = 1, digits = 3) # pattern 2

# Predict
pred.rpart = predict(model.rpart, df.test, type = "class")

# Result : cross tabulation table
result = table(pred.rpart, df.test$Species)
print(result)

# Calculation the accuracy
accuracy_prediction = sum(diag(result)) / sum(result)
print(accuracy_prediction)
 Rのコードの結果は以下の通りです。8行目のグラフは以下の通り。赤色、緑色、青色のプロット点はそれぞれ、setosa、versicolor、virginicaを表しています。iris_R_plot_200426.pngまた、21行目のグラフは以下の通りで、まず、Petal.Lengthが2.45未満ならsetosaで、2.45以上ならversicolorかvirginicaの2つに分類し、次に、versicolorとvirginicaをPetal.Length4.85のしきい値で分類しています。iris_R_res_graph1_200426.png22行目のグラフは以下の通りです。上のグラフと基本は同じですが、分類内訳が数字で表されています。分類された真ん中のversicolorの数字の「0 34 1」は順番にsetosa、versicolor、virginicaの数で、versicolorと判断された中で、34本はversicolor(正解)であり、1本はvirginica、つまり誤判定されたということですね。iris_R_res_graph2_200426.png今回の場合、学習させても100点満点の分類はできていませんがそのまま進んでいます。
5行目
Sepal.Length Sepal.Width Petal.Length Petal.Width Species
1 5.1 3.5 1.4 0.2 setosa
2 4.9 3.0 1.4 0.2 setosa
3 4.7 3.2 1.3 0.2 setosa
4 4.6 3.1 1.5 0.2 setosa
5 5.0 3.6 1.4 0.2 setosa
6 5.4 3.9 1.7 0.4 setosa

17行目
train= 105 test= 45

29行目
pred.rpart setosa versicolor virginica
setosa 19 0 0
versicolor 0 12 2
virginica 0 1 11

33行目
[1] 0.9333333
 29行目のクロス集計表の対角成分の19、12、11は正しく分類された本数で、対角成分外の数字の2と1の3本は誤分類された結果でした。正解率は「対角成分(正解数)の和」を「クロス集計表の数字の総和」で割った値なので93.3%ですね。

 次に、Pythonのコードです。流れはRのコードと同じです。
# Decision Tree : iris classification
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn import tree
from sklearn.model_selection import train_test_split
import pydotplus
from sklearn.externals.six import StringIO
from IPython.display import Image

# Read data
iris = load_iris()
df = pd.DataFrame(iris.data)
df.columns = ['Sepal length', 'Sepal width', 'Petal length', 'Petal width']
df['Species'] = iris.target
df.loc[df['Species'] == 0, 'Species'] = "setosa"
df.loc[df['Species'] == 1, 'Species'] = "versicolor"
df.loc[df['Species'] == 2, 'Species'] = "virginica"
print(df.head())

# Pairs graph
g = sns.pairplot(df,hue='Species')
plt.show()

# Analysis
X = df[['Sepal length', 'Sepal width', 'Petal length', 'Petal width']]
y = df['Species']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
print("train=",str(len(X_train)), "test=", str(len(X_test)))
clf = tree.DecisionTreeClassifier(max_depth=2) # decision tree
clf = clf.fit(X_train, y_train)

# Graph
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf('DT_result.pdf')

# Predict
pred = clf.predict(X_test)

# Result : cross tabulation table
print(pd.crosstab(pred, y_test))

# Calculation the accuracy
print(sum(pred == y_test) / len(y_test))
 Pythonコードの結果は以下の通りです。23〜24行目のPairプロットグラフは以下の通りです。iris_Python_plot_200426.png38行目の学習データから計算された分類結果のグラフは以下の通りです。初めのsetosaとそれ以外の分類がRと異なり、Petal widthで行っていますが、それは学習に使ったデータが全く同じでないためと考えます。iris_Python_res_graph_200426.pngグラフ中の「gini」は「ジニ不純度」のことで、簡単に言うと、サンプルの中に異なるクラスのものがどの程度含まれているかを表す指標で、以下の式で表されます。geni_impurity_200426.png一番上の34、32、39の105を例に取ると、以下の計算です。
 IG = 1 - (34/105)^2 - (32/105)^2 - (39/105)^2 = 0.6643084
20行目
Sepal length Sepal width Petal length Petal width Species
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa
3 4.6 3.1 1.5 0.2 setosa
4 5.0 3.6 1.4 0.2 setosa

30行目
train= 105 test= 45

44行目
Species setosa versicolor virginica
row_0
setosa 16 0 0
versicolor 0 17 3
virginica 0 1 8

47行目
0.9111111111111111
正解率は91.1です。Rの結果でもそうでしたが、うまく分類できているのではないかと思いました。用途によっては、学習時のパラメータを過学習しない程度に調整は必要かもしれません。

次回も決定木で分類を行う予定です。


『RからPythonへの道』バックナンバー
(1) はじめに
(2) 0. 実行環境(作業環境)
(3) 1. PythonからRを使う方法 2. RからPythonを使う方法
(4) 3. データフレーム
(5) 4. ggplot
(6) 5.行列
(7) 6.基本統計量
(8) 7. 回帰分析(単回帰)
(9) 8. 回帰分析(重回帰)
(10) 9. 回帰分析(ロジスティック回帰1)
(11) 10. 回帰分析(ロジスティック回帰2)
(12) 11. 回帰分析(リッジ、ラッソ回帰)
(13) 12. 回帰分析(多項式回帰)

今が耐え時

 3月は公私ともに慌しく、ブログのアップデートをしないまま、1か月経ってしまいました。明日あたりにブログページが広告ページに化ける前に投稿しました。

 コロナ疲れなんでしょうね。自律神経のバランスも良くないです。緊急事態宣言の対象エリアではありませんが、会社は在宅勤務にシフトしています。その準備で大変です。

 今が耐え時なんですよね。閉塞感があり、モチベーションが下がり気味ですが、みんな耐えているんですよね。ただ、考えていても答えが出るようなものでないので、しばらくは「考えないエンジニア」に徹しようと思います。You Tubeで「キッサコ」の般若心経でも聴きながら早く寝ようかな。

ご訪問者数

(Since 24 July, 2016)

タグクラウド


プロフィール

Dr.BobT

Author: Dr.BobT
興味のおもむくままに生涯考え続けるエンジニアでありたい。

月別アーカイブ

メールフォーム

名前:
メール:
件名:
本文: