without haste but without rest
12. Decision Tree 본문
0. 개요
(1). impurity (불순도)
- entropy
- gini
(2). pruning (가지치기)
1. data load
## 의사결정나무
## iris data
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
iris = load_iris()
y = iris.target
X = iris.data[:, 2:]
feature_names = iris.feature_names[2:]
2. partitioning
#파티션 기준 엔트로피, 브랜치 길이 1
tree = DecisionTreeClassifier(criterion = 'entropy', max_depth = 1, random_state = 0)
tree.fit(X, y)
3. draw Decision Tree
# 디시전 트리 그리기
# conda install pydot
import io
import pydot
from IPython.core.display import Image
from sklearn.tree import export_graphviz
def draw_decision_tree(model):
dot_buf = io.StringIO()
export_graphviz(model, out_file = dot_buf, feature_names = feature_names)
graph = pydot.graph_from_dot_data(dot_buf.getvalue())[0]
image = graph.create_png()
return Image(image)
3-1 max_depth = 1
#파티션 기준 엔트로피, 브랜치 길이 1
tree = DecisionTreeClassifier(criterion = 'entropy', max_depth = 1, random_state = 0)
tree.fit(X, y)
draw_decision_tree(tree)
petal width 속성을 기준으로 첫번째 값을 정확히 분류해낸다. 깊이를 추가하면 2, 3번째 값을 분리하기 위해서 오른쪽 노드에서 뻗어나갈 것이다.
3-2 max_depth = 2
tree = DecisionTreeClassifier(criterion = 'entropy', max_depth = 2, random_state = 0)
tree.fit(X, y)
draw_decision_tree(tree)
다음으로 다시 petal width 속성을 기준으로 분류했는데 5개의 오답과 2개의 오답이 생겼다.
Cancer data
1. data load & split train, test
## cancer data
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
cancer = load_breast_cancer()
print(cancer.DESCR)
print(cancer.data.shape)
X_train, X_test, y_train, y_test = train_test_split(cancer.data, cancer.target,
test_size = 0.1,
stratify = cancer.target,
random_state = 42)
feature_names = cancer.feature_names
2. partitioning & draw desicion tree
2-1 max_depth = 1
tree = DecisionTreeClassifier(criterion = 'entropy', max_depth = 1, random_state = 0)
tree.fit(X_train, y_train)
print("Accuracy on training set: {:.3f}".format(tree.score(X_train, y_train)))
print("Accuracy on test set: {:.3f}".format(tree.score(X_test, y_test)))
draw_decision_tree(tree)
worst perimeter 속성을 기준으로 분류한다. 왼쪽 자식 노드에서 17개의 오답, 오른쪽 자식 노드에서 24개의 오답이 나왔다. 트레이닝 데이터 대비 테스트 데이터 성능이 잘 나온다.
2-2 max_depth = 2
## max depth = 2
tree = DecisionTreeClassifier(criterion = 'entropy', max_depth = 2, random_state = 0)
tree.fit(X_train, y_train)
print("Accuracy on training set: {:.3f}".format(tree.score(X_train, y_train)))
print("Accuracy on test set: {:.3f}".format(tree.score(X_test, y_test)))
draw_decision_tree(tree)
깊이를 2로 늘려서 다시 의사결정 트리를 생성하면 위와 같다. 단 이 모델의 경우 트레이닝 데이터 대비 성능이 깊이가 1인 모델보다 좋지 않다. 또한 단말 노드의 분류를 보면 불순도(impurity)가 더 낮아지지 않는 경향을 보인다. 따라서 가지치기(pruning)을 해야한다.
3. gini
3-1 max_depth = 1
# gini 계수
tree = DecisionTreeClassifier(criterion = 'gini', max_depth = 1, random_state = 0)
tree.fit(X_train, y_train)
print("Accuracy on training set: {:.3f}".format(tree.score(X_train, y_train)))
print("Accuracy on test set: {:.3f}".format(tree.score(X_test, y_test)))
draw_decision_tree(tree)
불순도 기준을 지니 계수를 사용한 모델이다. 엔트로피를 사용한 모델보다 테스트 데이터에서 성능이 비교적 더 높은 것을 확인할 수 있고, 엔트로피와는 다른 속성인 worst radius를 루트 노드의 기준으로 사용한다. 꽤 준수한 성능을 보여주고 있다.
3-2 max_depth = 2
# gini 계수
tree = DecisionTreeClassifier(criterion = 'gini', max_depth = 2, random_state = 0)
tree.fit(X_train, y_train)
print("Accuracy on training set: {:.3f}".format(tree.score(X_train, y_train)))
print("Accuracy on test set: {:.3f}".format(tree.score(X_test, y_test)))
draw_decision_tree(tree)
깊이를 더 늘렸다. 지니 계수를 사용해도 테스트 데이터에서 성능이 감소하는 것을 확인할 수 있다. 그래도 엔트로피 보다는 스코어가 잘 나오긴 한다. 단말 노드에서 왼쪽부터 시작해서 1, 4번째 노드는 분류를 잘하는 편인데, 3, 4번째 노드의 불순도가 높다.
여기서 모델의 단순함과 적당한 스코어를 추구한다면 가지치기를 하면되고, 모델이 더 많이 복잡해져도 스코어를 추구한다면 깊이를 늘릴 수 있다.
3-3 max_depth = 3
트리의 깊이를 3으로 확장하면 위와 같다. 단말 노드들을 확인하면 굳이 불순도를 계산할 필요도 없이 분류가 잘 되었구나 확인할 수 있다. 그런데 깊이가 1인 모델에 비해서 고작 1~2% 의 정확도를 얻는다. 득에 비해서 많이 복잡해지는 모델이다.
'Homework > DataMining' 카테고리의 다른 글
14. Deployment - pickle (0) | 2020.06.21 |
---|---|
13. Cross validation & Pipeline (0) | 2020.06.21 |
11. Logistic regression - deep learning (0) | 2020.06.02 |
10. OLS, SGD (0) | 2020.05.26 |
09. Clustering - dbscan, spectal (0) | 2020.05.19 |