3.5 Write code to perform linear discriminant analysis on watermelon dataset 3.0α and describe the results.
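For reference, both answer codes below implement the closed-form LDA solution from Section 3.4 of the book: with \mu_0 and \mu_1 the means of the negative and positive class, the within-class scatter matrix and the projection direction are

S_w = \sum_{x \in X_0} (x - \mu_0)(x - \mu_0)^{\mathsf T} + \sum_{x \in X_1} (x - \mu_1)(x - \mu_1)^{\mathsf T}, \qquad w = S_w^{-1}(\mu_0 - \mu_1).

A sample is then classified by projecting it onto w and assigning it to the class whose projected mean is closer.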
Reference answer code (1):
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt


class LDA(object):
    def fit(self, X_, y_, plot_=False):
        pos = y_ == 1
        neg = y_ == 0
        X0 = X_[neg]
        X1 = X_[pos]

        u0 = X0.mean(0, keepdims=True)  # (1, n)
        u1 = X1.mean(0, keepdims=True)

        # within-class scatter matrix S_w
        sw = np.dot((X0 - u0).T, X0 - u0) + np.dot((X1 - u1).T, X1 - u1)
        # projection direction w = S_w^{-1}(u0 - u1)
        w = np.dot(np.linalg.inv(sw), (u0 - u1).T).reshape(1, -1)  # (1, n)

        if plot_:
            fig, ax = plt.subplots()
            ax.spines['right'].set_color('none')
            ax.spines['top'].set_color('none')
            ax.spines['left'].set_position(('data', 0))
            ax.spines['bottom'].set_position(('data', 0))

            plt.scatter(X1[:, 0], X1[:, 1], c='k', marker='o', label='good')
            plt.scatter(X0[:, 0], X0[:, 1], c='r', marker='x', label='bad')
            plt.xlabel('density', labelpad=1)
            plt.ylabel('sugar content')
            plt.legend(loc='upper right')

            # line through the origin along the direction of w
            x_tmp = np.linspace(-0.05, 0.15)
            y_tmp = x_tmp * w[0, 1] / w[0, 0]
            plt.plot(x_tmp, y_tmp, '#808080', linewidth=1)

            wu = w / np.linalg.norm(w)

            # project the samples onto the direction of w
            X0_project = np.dot(X0, np.dot(wu.T, wu))
            plt.scatter(X0_project[:, 0], X0_project[:, 1], c='r', s=15)
            for i in range(X0.shape[0]):
                plt.plot([X0[i, 0], X0_project[i, 0]], [X0[i, 1], X0_project[i, 1]], '--r', linewidth=1)

            X1_project = np.dot(X1, np.dot(wu.T, wu))
            plt.scatter(X1_project[:, 0], X1_project[:, 1], c='k', s=15)
            for i in range(X1.shape[0]):
                plt.plot([X1[i, 0], X1_project[i, 0]], [X1[i, 1], X1_project[i, 1]], '--k', linewidth=1)

            u0_project = np.dot(u0, np.dot(wu.T, wu))
            plt.scatter(u0_project[:, 0], u0_project[:, 1], c='#FF4500', s=60)
            u1_project = np.dot(u1, np.dot(wu.T, wu))
            plt.scatter(u1_project[:, 0], u1_project[:, 1], c='#696969', s=60)

            ax.annotate(r'projection of u0',
                        xy=(u0_project[:, 0], u0_project[:, 1]),
                        xytext=(u0_project[:, 0] - 0.2, u0_project[:, 1] - 0.1),
                        size=13,
                        va="center", ha="left",
                        arrowprops=dict(arrowstyle="->", color="k"))
            ax.annotate(r'projection of u1',
                        xy=(u1_project[:, 0], u1_project[:, 1]),
                        xytext=(u1_project[:, 0] - 0.1, u1_project[:, 1] + 0.1),
                        size=13,
                        va="center", ha="left",
                        arrowprops=dict(arrowstyle="->", color="k"))
            plt.axis("equal")
            plt.show()

        self.w = w
        self.u0 = u0
        self.u1 = u1
        return self

    def predict(self, X):
        project = np.dot(X, self.w.T)
        wu0 = np.dot(self.w, self.u0.T)
        wu1 = np.dot(self.w, self.u1.T)
        # assign each sample to the class whose projected mean is closer
        return (np.abs(project - wu1) < np.abs(project - wu0)).astype(int)


if __name__ == '__main__':
    # change the data path to match your environment
    data_path = r'C:\Users\hanmi\Documents\xiguabook\watermelon3_0_Ch.csv'
    data = pd.read_csv(data_path).values

    X = data[:, 7:9].astype(float)
    y = data[:, 9]
    y[y == 'yes'] = 1
    y[y == 'no'] = 0
    y = y.astype(int)

    lda = LDA()
    lda.fit(X, y, plot_=True)
    print(lda.predict(X))
    print(y)
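As an optional sanity check that is not part of the original answer, the predictions above can be compared with scikit-learn's LinearDiscriminantAnalysis. The sketch below assumes the same watermelon3_0_Ch.csv file and column layout as the script above; note that scikit-learn classifies with a Bayes rule that includes class priors, so its output can differ slightly from the nearest-projected-mean rule used here.

# Optional cross-check (assumption: same CSV layout as in the script above)
import pandas as pd
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

data_path = r'C:\Users\hanmi\Documents\xiguabook\watermelon3_0_Ch.csv'  # adjust to your environment
data = pd.read_csv(data_path).values
X = data[:, 7:9].astype(float)           # density and sugar content
y = (data[:, 9] == 'yes').astype(int)    # 1 = good, 0 = bad

sk_lda = LinearDiscriminantAnalysis().fit(X, y)
print(sk_lda.predict(X))                 # compare with LDA().fit(X, y).predict(X) above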
Reference answer code (2):
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


def LDA(X0, X1):
    """
    Get the optimal params of the LDA model given training data.
    Input:
        X0: np.array with shape [N1, d]
        X1: np.array with shape [N2, d]
    Return:
        omega: np.array with shape [d, 1]. Optimal params of LDA.
    """
    # class means, shape [1, d]
    mean0 = np.mean(X0, axis=0, keepdims=True)
    mean1 = np.mean(X1, axis=0, keepdims=True)
    # within-class scatter matrix, shape [d, d]
    Sw = (X0 - mean0).T.dot(X0 - mean0) + (X1 - mean1).T.dot(X1 - mean1)
    # omega = Sw^{-1}(mean0 - mean1), shape [d, 1]
    omega = np.linalg.inv(Sw).dot((mean0 - mean1).T)
    return omega


if __name__ == "__main__":
    # read data from csv
    work_book = pd.read_csv("watermelon_3a.csv", header=None)
    positive_data = work_book.values[work_book.values[:, -1] == 1.0, :]
    negative_data = work_book.values[work_book.values[:, -1] == 0.0, :]
    print(positive_data)

    # LDA
    omega = LDA(negative_data[:, 1:-1], positive_data[:, 1:-1])

    # plot the samples and the line omega^T x = 0 (perpendicular to the projection direction)
    plt.plot(positive_data[:, 1], positive_data[:, 2], "bo")
    plt.plot(negative_data[:, 1], negative_data[:, 2], "r+")
    lda_left = 0
    lda_right = -(omega[0, 0] * 0.9) / omega[1, 0]  # use scalar entries of omega (shape [2, 1])
    plt.plot([0, 0.9], [lda_left, lda_right], 'g-')

    plt.xlabel('density')
    plt.ylabel('sugar rate')
    plt.title("LDA")
    plt.show()
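Reference code (2) only estimates omega and draws the plot; it does not classify. If predictions are also wanted, a small helper in the spirit of code (1) (nearest projected class mean) could look like the sketch below; classify_lda is a hypothetical name and is not part of the original answer.

import numpy as np

# Hypothetical helper (not in the original answer): classify by projecting onto omega
# and choosing the class whose projected mean is closer, as in reference code (1).
def classify_lda(X, X0, X1, omega):
    mean0 = X0.mean(axis=0, keepdims=True)    # [1, d] mean of the negative class
    mean1 = X1.mean(axis=0, keepdims=True)    # [1, d] mean of the positive class
    proj = X.dot(omega)                       # [N, 1] projections onto omega
    p0, p1 = mean0.dot(omega), mean1.dot(omega)
    return (np.abs(proj - p1) < np.abs(proj - p0)).astype(int).ravel()

# Example: predictions for all samples after running the script above
# preds = classify_lda(work_book.values[:, 1:-1],
#                      negative_data[:, 1:-1], positive_data[:, 1:-1], omega)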
Source: https://blog.csdn.net/weixin_43518584/article/details/105588310