728x90
참고 코드:
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import ConvexHull
class KMeans(object):
def __init__(self, k):
self.k = k
def fit(self, X, initial_centroid_index=None, max_iters=10, seed=16, plt_process=False):
m, n = X.shape
# 특별히 지정한 중심점이 없으면, 중심점을 랜덤으로 초기화
if initial_centroid_index is None:
np.random.seed(seed)
initial_centroid_index = np.random.randint(0, m, self.k)
centroid = X[initial_centroid_index, :]
idx = None
plt.ion()
for i in range(max_iters):
# 중심점을 기점으로 샘플을 분류
idx = self.find_closest_centroids(X, centroid)
if plt_process:
self.plot_converge(X, idx, initial_centroid_index)
# 중심점 다시 계산
centroid = self.compute_centroids(X, idx)
plt.ioff()
plt.show()
return centroid, idx
def find_closest_centroids(self, X, centroid):
distance = np.sum((X[:, np.newaxis, :] - centroid) ** 2, axis=2)
idx = distance.argmin(axis=1)
return idx
def compute_centroids(self, X, idx):
centroids = np.zeros((self.k, X.shape[1]))
for i in range(self.k):
centroids[i, :] = np.mean(X[idx == i], axis=0)
return centroids
def plot_converge(self, X, idx, initial_idx):
plt.cla()
plt.title("k-meas converge process")
plt.xlabel('density')
plt.ylabel('sugar content')
plt.scatter(X[:, 0], X[:, 1], c='lightcoral')
plt.scatter(X[initial_idx, 0], X[initial_idx, 1], label='initial center', c='k')
for i in range(self.k):
X_i = X[idx == i]
hull = ConvexHull(X_i).vertices.tolist()
hull.append(hull[0])
plt.plot(X_i[hull, 0], X_i[hull, 1], 'c--')
plt.legend()
plt.pause(0.5)
if __name__ == '__main__':
data = np.loadtxt('..\data\watermelon4_0_Ch.txt', delimiter=', ')
centroid, idx = KMeans(3).fit(data, plt_process=True, seed=24)
'단단한 머신러닝' 카테고리의 다른 글
[단단한 머신러닝 - 연습문제 참고 답안]Chapter11 - 특성 선택과 희소 학습 11.1 (0) | 2023.07.15 |
---|---|
[단단한 머신러닝 - 연습문제 참고 답안]Chapter9 - 클러스터링 9.5 (0) | 2022.01.23 |
[단단한 머신러닝 - 연습문제 참고 답안]Chapter9 - 클러스터링 9.3 (0) | 2022.01.23 |
[단단한 머신러닝 - 연습문제 참고 답안]Chapter9 - 클러스터링 9.2 (0) | 2022.01.23 |
[단단한 머신러닝 - 연습문제 참고 답안]Chapter10 - 차원 축소와 척도 학습 10.1 (0) | 2021.09.28 |