Skip to content

Commit

Permalink
add note
Browse files Browse the repository at this point in the history
add kmeans  digits pca clustering
  • Loading branch information
math4mad committed Nov 3, 2023
1 parent aad4ca2 commit ffa2905
Show file tree
Hide file tree
Showing 5 changed files with 1,902 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"hash": "5cf6d9f94cae9ff1ac54ec601259637d",
"result": {
"markdown": "---\ntitle: \"1-kmeans-on-digits\"\nauthor: math4mad\ncode-fold: false\n---\n\n:::{.callout-note title=\"简介\"}\n **routine: project data to 2d space then proceed kmeans methods**\n\n 1. ref: [K-Means clustering on the handwritten digits data](https://scikit-learn.org/stable/auto_examples/cluster/plot_kmeans_digits.html)\n 2. [kmeans in mlj](https://alan-turing-institute.github.io/MLJ.jl/dev/models/KMeans_Clustering/#KMeans_Clustering)\n 3. decision boundary:[Prediction](https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/examples/support-vector-machine/#Prediction)\n \n:::\n\n## 1. load package\n\n::: {.cell execution_count=1}\n``` {.julia .cell-code}\n include(\"../utils.jl\")\n import MLJ: fit!, predict, transform,fitted_params\n using CSV, DataFrames, GLMakie, Random\n using MLJ\n Random.seed!(34343)\n```\n\n::: {.cell-output .cell-output-display execution_count=48}\n```\nTaskLocalRNG()\n```\n:::\n:::\n\n\n## 2. load data \n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\n load_csv(str::AbstractString) =\n str |> d -> CSV.File(\"./data/$str.csv\") |> DataFrame |> dropmissing\n\n digits = load_csv(\"scikit_digits\")\n digits = coerce(digits, :target => Multiclass)\n y, X = unpack(digits, ==(:target); rng = 123);\n```\n:::\n\n\n## 3. MLJ workflow\n\n### 3.1 load model\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\n PCA = @load PCA pkg = MultivariateStats\n KMeans = @load KMeans pkg = Clustering\n pca_model = PCA(; maxoutdim = 2)\n kmeans_model = KMeans(; k =9)\n\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nimport MLJMultivariateStatsInterface ✔\nimport MLJClusteringInterface ✔\n```\n:::\n\n::: {.cell-output .cell-output-stderr}\n```\n[ Info: For silent loading, specify `verbosity=0`. \n[ Info: For silent loading, specify `verbosity=0`. \n```\n:::\n\n::: {.cell-output .cell-output-display execution_count=50}\n```\nKMeans(\n k = 9, \n metric = Distances.SqEuclidean(0.0), \n init = :kmpp)\n```\n:::\n:::\n\n\n### 3.2 usa pca model project data to 2d space\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\n pca_mach = machine(pca_model, X) |> fit!\n Xproj = transform(pca_mach, X)\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n[ Info: Training machine(PCA(maxoutdim = 2, …), …).\n```\n:::\n\n::: {.cell-output .cell-output-display execution_count=51}\n```{=html}\n<div><div style = \"float: left;\"><span>1797×2 DataFrame</span></div><div style = \"float: right;\"><span style = \"font-style: italic;\">1772 rows omitted</span></div><div style = \"clear: both;\"></div></div><div class = \"data-frame\" style = \"overflow-x: scroll;\"><table class = \"data-frame\" style = \"margin-bottom: 6px;\"><thead><tr class = \"header\"><th class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">Row</th><th style = \"text-align: left;\">x1</th><th style = \"text-align: left;\">x2</th></tr><tr class = \"subheader headerLastRow\"><th class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\"></th><th title = \"Float64\" style = \"text-align: left;\">Float64</th><th title = \"Float64\" style = \"text-align: left;\">Float64</th></tr></thead><tbody><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1</td><td style = \"text-align: right;\">2.53068</td><td style = \"text-align: right;\">7.10818</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">2</td><td style = \"text-align: right;\">-7.06737</td><td style = \"text-align: right;\">2.69455</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">3</td><td style = \"text-align: right;\">9.71838</td><td style = \"text-align: right;\">16.8406</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">4</td><td style = \"text-align: right;\">2.73069</td><td style = \"text-align: right;\">10.0099</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">5</td><td style = \"text-align: right;\">-13.1702</td><td style = \"text-align: right;\">-13.6058</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">6</td><td style = \"text-align: right;\">-13.2824</td><td style = \"text-align: right;\">8.86115</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">7</td><td style = \"text-align: right;\">15.9065</td><td style = \"text-align: right;\">-16.5451</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">8</td><td style = \"text-align: right;\">-17.972</td><td style = \"text-align: right;\">-7.82244</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">9</td><td style = \"text-align: right;\">-5.2641</td><td style = \"text-align: right;\">9.30009</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">10</td><td style = \"text-align: right;\">-11.3988</td><td style = \"text-align: right;\">-5.1612</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">11</td><td style = \"text-align: right;\">-0.171222</td><td style = \"text-align: right;\">-5.94507</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">12</td><td style = \"text-align: right;\">6.49767</td><td style = \"text-align: right;\">-25.2878</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">13</td><td style = \"text-align: right;\">-27.8753</td><td style = \"text-align: right;\">-5.61827</td></tr><tr><td style = \"text-align: right;\">&vellip;</td><td style = \"text-align: right;\">&vellip;</td><td style = \"text-align: right;\">&vellip;</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1786</td><td style = \"text-align: right;\">18.7481</td><td style = \"text-align: right;\">-2.69052</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1787</td><td style = \"text-align: right;\">18.0899</td><td style = \"text-align: right;\">2.95523</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1788</td><td style = \"text-align: right;\">-16.321</td><td style = \"text-align: right;\">7.03648</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1789</td><td style = \"text-align: right;\">-19.6796</td><td style = \"text-align: right;\">0.109262</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1790</td><td style = \"text-align: right;\">19.5261</td><td style = \"text-align: right;\">5.59811</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1791</td><td style = \"text-align: right;\">1.32969</td><td style = \"text-align: right;\">-4.05499</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1792</td><td style = \"text-align: right;\">1.38873</td><td style = \"text-align: right;\">2.25902</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1793</td><td style = \"text-align: right;\">2.15727</td><td style = \"text-align: right;\">-4.79302</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1794</td><td style = \"text-align: right;\">-10.5746</td><td style = \"text-align: right;\">2.93352</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1795</td><td style = \"text-align: right;\">-11.7413</td><td style = \"text-align: right;\">-11.0562</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1796</td><td style = \"text-align: right;\">5.0894</td><td style = \"text-align: right;\">-6.87343</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1797</td><td style = \"text-align: right;\">20.9178</td><td style = \"text-align: right;\">1.59327</td></tr></tbody></table></div>\n```\n:::\n:::\n\n\n### 3.3 project decision data to 2d space\n\n::: {.cell execution_count=5}\n``` {.julia .cell-code}\nfunction boundary_data(df::AbstractDataFrame,;n=200)\n n1=n2=n\n xlow,xhigh=extrema(df[:,:x1])\n ylow,yhigh=extrema(df[:,:x2])\n tx = range(xlow,xhigh; length=n1)\n ty = range(ylow,yhigh; length=n2)\n x_test = mapreduce(collect, hcat, Iterators.product(tx, ty));\n xtest=MLJ.table(x_test')\n return tx,ty, xtest\nend\n\n tx,ty, xtest=boundary_data(Xproj)\n```\n\n::: {.cell-output .cell-output-display execution_count=52}\n```\n(-31.16990412454417:0.3159297962408993:31.70012532739479, -30.0922050904867:0.2893801640107357:27.494447547649695, Tables.MatrixTable{LinearAlgebra.Adjoint{Float64, Matrix{Float64}}} with 40000 rows, 2 columns, and schema:\n :x1 Float64\n :x2 Float64)\n```\n:::\n:::\n\n\n### 3.4 kmean flow\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\n kmeans_mach= machine(kmeans_model, Xproj) |> fit!\n\n ypred= predict(kmeans_mach, xtest)|>Array|>d->reshape(d,200,200)\n\n cen=fitted_params(kmeans_mach) #获取各聚类中心坐标\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n[ Info: Training machine(KMeans(k = 9, …), …).\n```\n:::\n\n::: {.cell-output .cell-output-display execution_count=53}\n```\n(centers = [-14.940039081327807 -4.21080988089939 … 3.623938350681804 16.636412339611788; 6.0940188803548185 -1.9858645156958448 … 7.9454732936643895 -12.507323610991511],)\n```\n:::\n:::\n\n\n## 4. plot results\n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\nfunction plot_res()\n fig = Figure(resolution=(600,600))\n ax = Axis(fig[1, 1],title=\"digits pca kmeans\",subtitle=\"pca->clustering\")\n contourf!(ax, tx,ty,ypred)\n scatter!(ax,eachrow(cen.centers)...;marker=:xcross,markersize = 24,color=(:red,0.8))\n scatter!(ax,eachcol(Xproj)...;markersize = 8,color=(:lightgreen,0.1),strokecolor = :black, strokewidth =1)\n fig\nend\nplot_res()\n```\n\n::: {.cell-output .cell-output-display execution_count=54}\n![](1-kmeans-on-digits_files/figure-html/cell-8-output-1.png){}\n:::\n:::\n\n\n",
"supporting": [
"1-kmeans-on-digits_files"
],
"filters": [],
"includes": {
"include-in-header": [
"<script src=\"https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js\" integrity=\"sha512-c3Nl8+7g4LMSTdrm621y7kf9v3SDPnhxLNhcjFJbKECVnmZHTdo+IRO05sNLTH/D3vA6u1X32ehoLC7WFVdheg==\" crossorigin=\"anonymous\"></script>\n<script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.5.1/jquery.min.js\" integrity=\"sha512-bLT0Qm9VnAYZDflyKcBaQ2gg0hSYNQrJ8RilYldYQ1FxQYoCLtUjuuRuZo+fjqhx/qtq/1itJ0C2ejDxltZVFg==\" crossorigin=\"anonymous\"></script>\n<script type=\"application/javascript\">define('jquery', [],function() {return window.jQuery;})</script>\n"
]
}
}
}
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 2 additions & 0 deletions _quarto.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ website:
contents: category/regression/*.qmd
- section: "Classfication"
contents: category/classification/*.qmd
- section: "Clustering"
contents: category/clustering/*.qmd
- section: "Dimension Reduction"
contents: category/dimension-reduction/*.qmd
- href: category/schedule.qmd
Expand Down
87 changes: 87 additions & 0 deletions category/clustering/1-kmeans-on-digits.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
---
title: "1-kmeans-on-digits"
author: math4mad
code-fold: false
---

:::{.callout-note title="简介"}
**routine: project data to 2d space then proceed kmeans methods**

1. ref: [K-Means clustering on the handwritten digits data](https://scikit-learn.org/stable/auto_examples/cluster/plot_kmeans_digits.html)
2. [kmeans in mlj](https://alan-turing-institute.github.io/MLJ.jl/dev/models/KMeans_Clustering/#KMeans_Clustering)
3. decision boundary:[Prediction](https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/examples/support-vector-machine/#Prediction)

:::

## 1. load package
```{julia}
include("../utils.jl")
import MLJ: fit!, predict, transform,fitted_params
using CSV, DataFrames, GLMakie, Random
using MLJ
Random.seed!(34343)
```

## 2. load data
```{julia}
load_csv(str::AbstractString) =
str |> d -> CSV.File("./data/$str.csv") |> DataFrame |> dropmissing
digits = load_csv("scikit_digits")
digits = coerce(digits, :target => Multiclass)
y, X = unpack(digits, ==(:target); rng = 123);
```

## 3. MLJ workflow

### 3.1 load model

```{julia}
PCA = @load PCA pkg = MultivariateStats
KMeans = @load KMeans pkg = Clustering
pca_model = PCA(; maxoutdim = 2)
kmeans_model = KMeans(; k =9)
```
### 3.2 usa pca model project data to 2d space
```{julia}
pca_mach = machine(pca_model, X) |> fit!
Xproj = transform(pca_mach, X)
```
### 3.3 project decision data to 2d space
```{julia}
function boundary_data(df::AbstractDataFrame,;n=200)
n1=n2=n
xlow,xhigh=extrema(df[:,:x1])
ylow,yhigh=extrema(df[:,:x2])
tx = range(xlow,xhigh; length=n1)
ty = range(ylow,yhigh; length=n2)
x_test = mapreduce(collect, hcat, Iterators.product(tx, ty));
xtest=MLJ.table(x_test')
return tx,ty, xtest
end
tx,ty, xtest=boundary_data(Xproj)
```
### 3.4 kmean flow

```{julia}
kmeans_mach= machine(kmeans_model, Xproj) |> fit!
ypred= predict(kmeans_mach, xtest)|>Array|>d->reshape(d,200,200)
cen=fitted_params(kmeans_mach) #获取各聚类中心坐标
```

## 4. plot results
```{julia}
function plot_res()
fig = Figure(resolution=(600,600))
ax = Axis(fig[1, 1],title="digits pca kmeans",subtitle="pca->clustering")
contourf!(ax, tx,ty,ypred)
scatter!(ax,eachrow(cen.centers)...;marker=:xcross,markersize = 24,color=(:red,0.8))
scatter!(ax,eachcol(Xproj)...;markersize = 8,color=(:lightgreen,0.1),strokecolor = :black, strokewidth =1)
fig
end
plot_res()
```
Loading

0 comments on commit ffa2905

Please sign in to comment.