Better visualization of the k-means
percolator committed Nov 20, 2024
1 parent d72d0e7 commit 70e79d8
Showing 2 changed files with 45 additions and 36 deletions.
75 changes: 42 additions & 33 deletions dsbook/unsupervised/cluster.ipynb
@@ -2,11 +2,16 @@
"cells": [
{
"cell_type": "markdown",
"id": "4f4047a3",
"id": "7982a0ae",
"metadata": {},
"source": [
"# Unsupervised Machine Learning\n",
"\n",
"[//]: # \"Add maths for GMMs\"\n",
"\n",
"[//]: # \"Coloring of e-step i k-means\"\n",
"\n",
"\n",
"## Introduction\n",
"\n",
"Unsupervised machine learning aims to learn patterns from data without predefined labels. Specifically, the goal is to learn a function $f(x)$ from the dataset $D = \\{\\mathbf{x}_i\\}$ by optimizing an objective function $g(D, f)$, or by simply partitioning the dataset $D$. This chapter provides an overview of clustering methods, which are a core part of unsupervised machine learning.\n",
@@ -44,7 +49,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "037925ef",
"id": "82aaff85",
"metadata": {},
"outputs": [],
"source": [
@@ -57,13 +62,14 @@
"# Generate sample data\n",
"X, y_true = make_blobs(n_samples=300, centers=5, cluster_std=0.60, random_state=1)\n",
"\n",
"# Function to perform one iteration of the k-Means EM step and plot in specified axes\n",
"def plot_kmeans_step(ax, X, centers, step_title):\n",
" labels = pairwise_distances_argmin(X, centers)\n",
" ax.scatter(X[:, 0], X[:, 1], c=labels, s=50, cmap='plasma')\n",
" sns.scatterplot(x=centers[:, 0], y=centers[:, 1], color='black', s=200, alpha=0.5, ax=ax)\n",
"# Function to plot data points with cluster centers\n",
"def plot_kmeans_step(ax, X, centers, labels=None, step_title=\"\"):\n",
" if labels is not None:\n",
" ax.scatter(X[:, 0], X[:, 1], c=labels, s=50, cmap='plasma', label='Data Points')\n",
" else:\n",
" ax.scatter(X[:, 0], X[:, 1], color='gray', s=50, alpha=0.6, label='Data Points')\n",
" sns.scatterplot(x=centers[:, 0], y=centers[:, 1], color='black', s=200, alpha=0.5, ax=ax, label='Cluster Centers')\n",
" ax.set_title(step_title)\n",
" return labels\n",
"\n",
"# Initialize the plot with a 4x2 grid (4 steps per column for E and M steps)\n",
"fig, axes = plt.subplots(4, 2, figsize=(12, 16))\n",
@@ -73,35 +79,38 @@
"rng = np.random.RandomState(1)\n",
"i = rng.permutation(X.shape[0])[:5]\n",
"centers = X[i]\n",
"plot_kmeans_step(axes[0, 1], X, centers, \"Initial Random Cluster Centers\")\n",
"plot_kmeans_step(axes[0, 1], X, centers, step_title=\"Initial Random Cluster Centers\")\n",
"\n",
"# Step 2: First E-Step (assign points to the nearest cluster center)\n",
"labels = plot_kmeans_step(axes[1, 0], X, centers, \"First E-Step: Assign Points to Nearest Cluster\")\n",
"labels = pairwise_distances_argmin(X, centers)\n",
"plot_kmeans_step(axes[1, 0], X, centers, labels, \"First E-Step: Assign Points to Nearest Cluster\")\n",
"\n",
"# Step 3: First M-Step (recalculate cluster centers)\n",
"centers = np.array([X[labels == i].mean(0) for i in range(5)])\n",
"plot_kmeans_step(axes[1, 1], X, centers, \"First M-Step: Recalculate Cluster Centers\")\n",
"centers = np.array([X[labels == i].mean(axis=0) for i in range(len(centers))])\n",
"plot_kmeans_step(axes[1, 1], X, centers, labels, step_title=\"First M-Step: Recalculate Cluster Centers\")\n",
"\n",
"# Step 4: Second E-Step (assign points to the nearest cluster center)\n",
"labels = plot_kmeans_step(axes[2, 0], X, centers, \"Second E-Step: Assign Points to Nearest Cluster\")\n",
"labels = pairwise_distances_argmin(X, centers)\n",
"plot_kmeans_step(axes[2, 0], X, centers, labels, \"Second E-Step: Assign Points to Nearest Cluster\")\n",
"\n",
"# Step 5: Second M-Step (recalculate cluster centers)\n",
"centers = np.array([X[labels == i].mean(0) for i in range(5)])\n",
"plot_kmeans_step(axes[2, 1], X, centers, \"Second M-Step: Recalculate Cluster Centers\")\n",
"centers = np.array([X[labels == i].mean(axis=0) for i in range(len(centers))])\n",
"plot_kmeans_step(axes[2, 1], X, centers, labels, step_title=\"Second M-Step: Recalculate Cluster Centers\")\n",
"\n",
"# Step 6: Third E-Step (assign points to the nearest cluster center)\n",
"labels = plot_kmeans_step(axes[3, 0], X, centers, \"Third E-Step: Assign Points to Nearest Cluster\")\n",
"labels = pairwise_distances_argmin(X, centers)\n",
"plot_kmeans_step(axes[3, 0], X, centers, labels, \"Third E-Step: Assign Points to Nearest Cluster\")\n",
"\n",
"# Step 7: Third M-Step (recalculate cluster centers)\n",
"centers = np.array([X[labels == i].mean(0) for i in range(5)])\n",
"plot_kmeans_step(axes[3, 1], X, centers, \"Third M-Step: Recalculate Cluster Centers\")\n",
"centers = np.array([X[labels == i].mean(axis=0) for i in range(len(centers))])\n",
"plot_kmeans_step(axes[3, 1], X, centers, labels, step_title=\"Third M-Step: Recalculate Cluster Centers\")\n",
"\n",
"plt.show()"
]
},
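The cell above unrolls three E/M iterations by hand so that each step can be plotted. Written as a loop that stops once the centers no longer move, the same procedure might look like the following sketch (assuming, as with the blobs above, that no cluster ever becomes empty):

```python
import numpy as np
from sklearn.metrics import pairwise_distances_argmin

def simple_kmeans(X, centers, max_iter=100):
    """Alternate E- and M-steps until the centers stop moving."""
    for _ in range(max_iter):
        labels = pairwise_distances_argmin(X, centers)           # E-step: nearest center
        new_centers = np.array([X[labels == k].mean(axis=0)      # M-step: recompute means
                                for k in range(len(centers))])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return centers, labels

# e.g. centers, labels = simple_kmeans(X, X[rng.permutation(X.shape[0])[:5]])
```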
{
"cell_type": "markdown",
"id": "03a642ed",
"id": "684b06f7",
"metadata": {},
"source": [
"The algorithm automatically assigns the points to clusters, and we can see that it closely matches what we would expect by visual inspection.\n",
@@ -116,7 +125,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "4f27bb9c",
"id": "c2c293fa",
"metadata": {},
"outputs": [],
"source": [
@@ -151,7 +160,7 @@
},
{
"cell_type": "markdown",
"id": "fd35e3cf",
"id": "e9fb9c38",
"metadata": {},
"source": [
"### Drawbacks of k-Means\n",
@@ -165,7 +174,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "2d9fc73d",
"id": "7ea2c4b5",
"metadata": {},
"outputs": [],
"source": [
@@ -178,7 +187,7 @@
},
{
"cell_type": "markdown",
"id": "a3cd8fa8",
"id": "e9713f76",
"metadata": {},
"source": [
"3. **Linear Cluster Boundaries**: The k-Means algorithm assumes that clusters are spherical and separated by linear boundaries. It struggles with complex geometries. Consider the following dataset with two crescent-shaped clusters:"
@@ -187,7 +196,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "1cff529e",
"id": "47b982f9",
"metadata": {},
"outputs": [],
"source": [
@@ -201,7 +210,7 @@
},
{
"cell_type": "markdown",
"id": "f3df3d1f",
"id": "f292fb4c",
"metadata": {},
"source": [
"4. **Differences in euclidian size**: K-Means assumes that the cluster sizes, in terms of euclidian distance to its borders, are fairly similar for all clusters.\n",
@@ -211,7 +220,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "3361d509",
"id": "5419aa2a",
"metadata": {},
"outputs": [],
"source": [
@@ -245,7 +254,7 @@
},
{
"cell_type": "markdown",
"id": "b74ea11e",
"id": "18e9ef79",
"metadata": {},
"source": [
"## Multivariate Normal Distribution\n",
@@ -270,7 +279,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "fad6e6bc",
"id": "690cfecc",
"metadata": {
"tags": [
"hide-input"
@@ -321,7 +330,7 @@
},
{
"cell_type": "markdown",
"id": "e6b279c0",
"id": "350f1a5a",
"metadata": {},
"source": [
"## Gaussian Mixture Models (GMM)\n",
@@ -339,7 +348,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "23ac57a2",
"id": "0b302c2b",
"metadata": {},
"outputs": [],
"source": [
@@ -411,7 +420,7 @@
},
{
"cell_type": "markdown",
"id": "1e083d68",
"id": "23ea42ef",
"metadata": {},
"source": [
"Points near the cluster boundaries have lower certainty, reflected in smaller marker sizes.\n",
@@ -422,7 +431,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "ee9a2d9f",
"id": "7a827b7c",
"metadata": {},
"outputs": [],
"source": [
@@ -442,7 +451,7 @@
},
{
"cell_type": "markdown",
"id": "34c7affe",
"id": "e715ca59",
"metadata": {},
"source": [
"GMM is able to model more complex, elliptical cluster boundaries, addressing one of the main limitations of k-Means.\n",
6 changes: 3 additions & 3 deletions dsbook/unsupervised/cluster.md
@@ -84,23 +84,23 @@ plot_kmeans_step(axes[1, 0], X, centers, labels, "First E-Step: Assign Points to
# Step 3: First M-Step (recalculate cluster centers)
centers = np.array([X[labels == i].mean(axis=0) for i in range(len(centers))])
plot_kmeans_step(axes[1, 1], X, centers, step_title="First M-Step: Recalculate Cluster Centers")
plot_kmeans_step(axes[1, 1], X, centers, labels, step_title="First M-Step: Recalculate Cluster Centers")
# Step 4: Second E-Step (assign points to the nearest cluster center)
labels = pairwise_distances_argmin(X, centers)
plot_kmeans_step(axes[2, 0], X, centers, labels, "Second E-Step: Assign Points to Nearest Cluster")
# Step 5: Second M-Step (recalculate cluster centers)
centers = np.array([X[labels == i].mean(axis=0) for i in range(len(centers))])
plot_kmeans_step(axes[2, 1], X, centers, step_title="Second M-Step: Recalculate Cluster Centers")
plot_kmeans_step(axes[2, 1], X, centers, labels, step_title="Second M-Step: Recalculate Cluster Centers")
# Step 6: Third E-Step (assign points to the nearest cluster center)
labels = pairwise_distances_argmin(X, centers)
plot_kmeans_step(axes[3, 0], X, centers, labels, "Third E-Step: Assign Points to Nearest Cluster")
# Step 7: Third M-Step (recalculate cluster centers)
centers = np.array([X[labels == i].mean(axis=0) for i in range(len(centers))])
plot_kmeans_step(axes[3, 1], X, centers, step_title="Third M-Step: Recalculate Cluster Centers")
plot_kmeans_step(axes[3, 1], X, centers, labels, step_title="Third M-Step: Recalculate Cluster Centers")
plt.show()
```
