
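'''
Compare decision trees trained with no pruning, pre-pruning, and post-pruning
on the iris dataset provided by scikit-learn.
(사이킷런에서 제공하는 iris 데이터 사용)
'''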

import pandas as pd
from sklearn import datasets
from sklearn.model_selection import train_test_split
import numpy as np
import treeCreater
import treePlottter


# Load the scikit-learn iris dataset: features as a DataFrame with named
# columns, labels as a Series of class-name strings (the integer targets are
# mapped through `target_names`).
iris = datasets.load_iris()
X = pd.DataFrame(iris['data'], columns=iris['feature_names'])
y = pd.Series(iris['target_names'][iris['target']])

# Hold out 20% of the 150 samples (30 samples) as the test set.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=15)

# From the remaining 120 samples, take 25% (30 samples) as the validation set,
# leaving 90 samples for training.
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.25, random_state=15)


def _fit_and_report(label, tree):
    """Fit ``tree``, print its test-set accuracy after ``label``, and return it.

    Reads the module-level splits defined above: ``X_train``/``y_train`` and
    ``X_val``/``y_val`` are passed to ``tree.fit`` (the validation data is
    presumably used by the pruning strategies -- TODO confirm in treeCreater),
    and ``X_test``/``y_test`` are used for scoring.

    Args:
        label: Text printed before the accuracy value.
        tree: An unfitted ``treeCreater.DecisionTree`` instance.

    Returns:
        The same ``tree`` object, after ``fit`` has been called on it.
    """
    tree.fit(X_train, y_train, X_val, y_val)
    predictions = tree.predict(X_test)
    # Compare positionally. `y_test` keeps its shuffled index from
    # train_test_split, and comparing two pandas Series with different indexes
    # raises instead of comparing element-wise.
    accuracy = np.mean(np.asarray(predictions) == np.asarray(y_test))
    print(label, accuracy)
    return tree


# No pruning.
tree_no_pruning = _fit_and_report('가지치기 없음:', treeCreater.DecisionTree('gini'))
# Uncomment to draw the fitted tree via treePlottter.create_plot.
# treePlottter.create_plot(tree_no_pruning.tree_)

# Pre-pruning.
tree_pre_pruning = _fit_and_report('사전 가지치기:', treeCreater.DecisionTree('gini', 'pre_pruning'))
# treePlottter.create_plot(tree_pre_pruning.tree_)

# Post-pruning.
tree_post_pruning = _fit_and_report('사후 가지치기:', treeCreater.DecisionTree('gini', 'post_pruning'))
# treePlottter.create_plot(tree_post_pruning.tree_)

 

# Reference solution (참고 답안): https://blog.csdn.net/red_stone1/article/details/106110620