投稿

Pythonで決定木分析 Decision Tree

scikit-learntree Package使ってみました。

決定木はとてもシンプルで特に可視化と合わせると人に説明するのに便利です。予測はもっと複雑なモデルがいいと思いますが、分析して方向性を決めようみたいな話はこちらの方が他の人の腹落ち感は高いイメージです。

決定木はグラフを出力するのに苦労した記憶がありますが、scikit-learnにある「plot_tree」を使うととても簡単だったので計算から可視化までの流れを共有します。
    目次
  • 1: データ入手
  • 2: 計算
  • 3: 可視化
  • 4: 最後に

データ入手

決定木は教師あり学習になるのでデータと答えが必要になります。
今回はscikit-learnにあるiris(アヤメ)のデータを使用します。
sklearn.datasets.load_iris」ですね。
# データ取得
from sklearn.datasets import load_iris
iris = load_iris()
中にはdataとして、sepalとpetalの長さと幅が入っています。要は形状ですね。
他にtargetとしてアヤメ3種類のどの種類なのかがあります。

計算

では計算していきましょう。
sklearn.tree.DecisionTreeClassifier」を使用します。
# treeをインポート
from sklearn import tree

# 計算
clf = tree.DecisionTreeClassifier() #定義
iristree = clf.fit(iris.data, iris.target) #計算。データと結果を引数とする
一応精度を確認しておきましょう。
Methodの中にscoreというのがありますね。
# 精度確認
iristree.score(iris.data, iris.target)
1.0

100%ということですが、訓練データで精度確認しているかもですね。
実際はデータを分離しておいてから検証しましょう。

可視化

今回やってみて一番感動したのは可視化が簡単だったことです。
sklearn.tree.plot_tree」を使います。まずはそのまま
※図が小さかったのでdpi指定しました。
# 可視化
# Packageのインポート
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

# 表示
plt.figure(dpi=150)
plot_tree(iristree)
plt.show()

これだとわからないので見た目を整えます。パラメーターはいろいろいじってみて自分のわかりやすいようにしました。
# 表示
plt.figure(dpi=150)
plot_tree(iristree, feature_names=iris.feature_names, 
          class_names=iris.target_names, 
          filled=True,
          proportion=True)
plt.show()
ちょっとうるさい感じもするので階層を削って3つにしてみます。
# 計算
clf = tree.DecisionTreeClassifier(max_depth=3) #定義
iristree = clf.fit(iris.data, iris.target) #計算。データと結果を引数とする

# 表示
plt.figure(dpi=150)
plot_tree(iristree, feature_names=iris.feature_names, 
          class_names=iris.target_names, 
          filled=True,
          proportion=True)
plt.show()
いいですね。このくらいならパッと見せて説明するのに適しているかと思います。
Petal widthが最初の分かれ道ですね。ここで33.3%がSetosaだということです。

最後に

可視化するとどこで分岐するかがとてもわかりやすいですよね。
マーケティング施策の方向性検討であったり、説明にはとても有益だと思います。決定木は他に回帰にも使えるということですが、その場合はもっと精度を追ったモデル化をした方がいいかと思いました。