Skip to content

Commit

Permalink
Updated plotting script
Browse files Browse the repository at this point in the history
  • Loading branch information
percolator committed Nov 20, 2024
1 parent 041271a commit d72d0e7
Showing 1 changed file with 20 additions and 16 deletions.
36 changes: 20 additions & 16 deletions dsbook/unsupervised/cluster.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,14 @@ import matplotlib.pyplot as plt
# Generate sample data
X, y_true = make_blobs(n_samples=300, centers=5, cluster_std=0.60, random_state=1)
# Function to perform one iteration of the k-Means EM step and plot in specified axes
def plot_kmeans_step(ax, X, centers, step_title):
labels = pairwise_distances_argmin(X, centers)
ax.scatter(X[:, 0], X[:, 1], c=labels, s=50, cmap='plasma')
sns.scatterplot(x=centers[:, 0], y=centers[:, 1], color='black', s=200, alpha=0.5, ax=ax)
# Function to plot data points with cluster centers
def plot_kmeans_step(ax, X, centers, labels=None, step_title=""):
if labels is not None:
ax.scatter(X[:, 0], X[:, 1], c=labels, s=50, cmap='plasma', label='Data Points')
else:
ax.scatter(X[:, 0], X[:, 1], color='gray', s=50, alpha=0.6, label='Data Points')
sns.scatterplot(x=centers[:, 0], y=centers[:, 1], color='black', s=200, alpha=0.5, ax=ax, label='Cluster Centers')
ax.set_title(step_title)
return labels
# Initialize the plot with a 4x2 grid (4 steps per column for E and M steps)
fig, axes = plt.subplots(4, 2, figsize=(12, 16))
Expand All @@ -75,28 +76,31 @@ fig.tight_layout(pad=5)
rng = np.random.RandomState(1)
i = rng.permutation(X.shape[0])[:5]
centers = X[i]
plot_kmeans_step(axes[0, 1], X, centers, "Initial Random Cluster Centers")
plot_kmeans_step(axes[0, 1], X, centers, step_title="Initial Random Cluster Centers")
# Step 2: First E-Step (assign points to the nearest cluster center)
labels = plot_kmeans_step(axes[1, 0], X, centers, "First E-Step: Assign Points to Nearest Cluster")
labels = pairwise_distances_argmin(X, centers)
plot_kmeans_step(axes[1, 0], X, centers, labels, "First E-Step: Assign Points to Nearest Cluster")
# Step 3: First M-Step (recalculate cluster centers)
centers = np.array([X[labels == i].mean(0) for i in range(5)])
plot_kmeans_step(axes[1, 1], X, centers, "First M-Step: Recalculate Cluster Centers")
centers = np.array([X[labels == i].mean(axis=0) for i in range(len(centers))])
plot_kmeans_step(axes[1, 1], X, centers, step_title="First M-Step: Recalculate Cluster Centers")
# Step 4: Second E-Step (assign points to the nearest cluster center)
labels = plot_kmeans_step(axes[2, 0], X, centers, "Second E-Step: Assign Points to Nearest Cluster")
labels = pairwise_distances_argmin(X, centers)
plot_kmeans_step(axes[2, 0], X, centers, labels, "Second E-Step: Assign Points to Nearest Cluster")
# Step 5: Second M-Step (recalculate cluster centers)
centers = np.array([X[labels == i].mean(0) for i in range(5)])
plot_kmeans_step(axes[2, 1], X, centers, "Second M-Step: Recalculate Cluster Centers")
centers = np.array([X[labels == i].mean(axis=0) for i in range(len(centers))])
plot_kmeans_step(axes[2, 1], X, centers, step_title="Second M-Step: Recalculate Cluster Centers")
# Step 6: Third E-Step (assign points to the nearest cluster center)
labels = plot_kmeans_step(axes[3, 0], X, centers, "Third E-Step: Assign Points to Nearest Cluster")
labels = pairwise_distances_argmin(X, centers)
plot_kmeans_step(axes[3, 0], X, centers, labels, "Third E-Step: Assign Points to Nearest Cluster")
# Step 7: Third M-Step (recalculate cluster centers)
centers = np.array([X[labels == i].mean(0) for i in range(5)])
plot_kmeans_step(axes[3, 1], X, centers, "Third M-Step: Recalculate Cluster Centers")
centers = np.array([X[labels == i].mean(axis=0) for i in range(len(centers))])
plot_kmeans_step(axes[3, 1], X, centers, step_title="Third M-Step: Recalculate Cluster Centers")
plt.show()
```
Expand Down

0 comments on commit d72d0e7

Please sign in to comment.