'''
사이킷런에서 제공하는 iris 데이터 사용
'''
import pandas as pd
from sklearn import datasets
from sklearn.model_selection import train_test_split
import numpy as np
import treeCreater
import treePlottter
iris = datasets.load_iris()
X = pd.DataFrame(iris['data'], columns=iris['feature_names'])
y = pd.Series(iris['target_names'][iris['target']])
# 3개의 샘플을 취해 테스트세트로 사용
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=15)
# 남은 120개 샘플 중에서 30개를 검정세트로 사용
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.25, random_state=15)
# 가지치기 않함
tree_no_pruning = treeCreater.DecisionTree('gini')
tree_no_pruning.fit(X_train, y_train, X_val, y_val)
print('가지치기 없음:', np.mean(tree_no_pruning.predict(X_test) == y_test))
# treePlottter.create_plot(tree_no_pruning.tree_)
# 사전 가지치기
tree_pre_pruning = treeCreater.DecisionTree('gini', 'pre_pruning')
tree_pre_pruning.fit(X_train, y_train, X_val, y_val)
print('사전 가지치기:', np.mean(tree_pre_pruning.predict(X_test) == y_test))
# treePlottter.create_plot(tree_pre_pruning.tree_)
# 사후 가지치기
tree_post_pruning = treeCreater.DecisionTree('gini', 'post_pruning')
tree_post_pruning.fit(X_train, y_train, X_val, y_val)
print('사후 가지치기:', np.mean(tree_post_pruning.predict(X_test) == y_test))
# treePlottter.create_plot(tree_post_pruning.tree_)
참고 답안: https://blog.csdn.net/red_stone1/article/details/106110620
'단단한 머신러닝' 카테고리의 다른 글
[단단한 머신러닝 - 연습문제 참고 답안]Chapter4 - 의사결정 트리 4.10 (0) | 2021.08.23 |
---|---|
[단단한 머신러닝 - 연습문제 참고 답안]Chapter4 - 의사결정 트리 4.9 (0) | 2021.08.23 |
[단단한 머신러닝 - 연습문제 참고 답안]Chapter4 - 의사결정 트리 4.6 (2) | 2021.04.17 |
[단단한 머신러닝 - 연습문제 참고 답안]Chapter4 - 의사결정 트리 4.3~4.5 (0) | 2021.04.17 |
[단단한 머신러닝 - 연습문제 참고 답안]Chapter4 - 의사결정 트리 4.2 (0) | 2021.04.17 |