generated from Pakillo/quarto-course-website-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add kmeans digits pca clustering
- Loading branch information
Showing
5 changed files
with
1,902 additions
and
0 deletions.
There are no files selected for viewing
15 changes: 15 additions & 0 deletions
15
_freeze/category/clustering/1-kmeans-on-digits/execute-results/html.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
{ | ||
"hash": "5cf6d9f94cae9ff1ac54ec601259637d", | ||
"result": { | ||
"markdown": "---\ntitle: \"1-kmeans-on-digits\"\nauthor: math4mad\ncode-fold: false\n---\n\n:::{.callout-note title=\"简介\"}\n **routine: project data to 2d space then proceed kmeans methods**\n\n 1. ref: [K-Means clustering on the handwritten digits data](https://scikit-learn.org/stable/auto_examples/cluster/plot_kmeans_digits.html)\n 2. [kmeans in mlj](https://alan-turing-institute.github.io/MLJ.jl/dev/models/KMeans_Clustering/#KMeans_Clustering)\n 3. decision boundary:[Prediction](https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/examples/support-vector-machine/#Prediction)\n \n:::\n\n## 1. load package\n\n::: {.cell execution_count=1}\n``` {.julia .cell-code}\n include(\"../utils.jl\")\n import MLJ: fit!, predict, transform,fitted_params\n using CSV, DataFrames, GLMakie, Random\n using MLJ\n Random.seed!(34343)\n```\n\n::: {.cell-output .cell-output-display execution_count=48}\n```\nTaskLocalRNG()\n```\n:::\n:::\n\n\n## 2. load data \n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\n load_csv(str::AbstractString) =\n str |> d -> CSV.File(\"./data/$str.csv\") |> DataFrame |> dropmissing\n\n digits = load_csv(\"scikit_digits\")\n digits = coerce(digits, :target => Multiclass)\n y, X = unpack(digits, ==(:target); rng = 123);\n```\n:::\n\n\n## 3. MLJ workflow\n\n### 3.1 load model\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\n PCA = @load PCA pkg = MultivariateStats\n KMeans = @load KMeans pkg = Clustering\n pca_model = PCA(; maxoutdim = 2)\n kmeans_model = KMeans(; k =9)\n\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nimport MLJMultivariateStatsInterface ✔\nimport MLJClusteringInterface ✔\n```\n:::\n\n::: {.cell-output .cell-output-stderr}\n```\n[ Info: For silent loading, specify `verbosity=0`. \n[ Info: For silent loading, specify `verbosity=0`. \n```\n:::\n\n::: {.cell-output .cell-output-display execution_count=50}\n```\nKMeans(\n k = 9, \n metric = Distances.SqEuclidean(0.0), \n init = :kmpp)\n```\n:::\n:::\n\n\n### 3.2 usa pca model project data to 2d space\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\n pca_mach = machine(pca_model, X) |> fit!\n Xproj = transform(pca_mach, X)\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n[ Info: Training machine(PCA(maxoutdim = 2, …), …).\n```\n:::\n\n::: {.cell-output .cell-output-display execution_count=51}\n```{=html}\n<div><div style = \"float: left;\"><span>1797×2 DataFrame</span></div><div style = \"float: right;\"><span style = \"font-style: italic;\">1772 rows omitted</span></div><div style = \"clear: both;\"></div></div><div class = \"data-frame\" style = \"overflow-x: scroll;\"><table class = \"data-frame\" style = \"margin-bottom: 6px;\"><thead><tr class = \"header\"><th class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">Row</th><th style = \"text-align: left;\">x1</th><th style = \"text-align: left;\">x2</th></tr><tr class = \"subheader headerLastRow\"><th class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\"></th><th title = \"Float64\" style = \"text-align: left;\">Float64</th><th title = \"Float64\" style = \"text-align: left;\">Float64</th></tr></thead><tbody><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1</td><td style = \"text-align: right;\">2.53068</td><td style = \"text-align: right;\">7.10818</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">2</td><td style = \"text-align: right;\">-7.06737</td><td style = \"text-align: right;\">2.69455</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">3</td><td style = \"text-align: right;\">9.71838</td><td style = \"text-align: right;\">16.8406</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">4</td><td style = \"text-align: right;\">2.73069</td><td style = \"text-align: right;\">10.0099</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">5</td><td style = \"text-align: right;\">-13.1702</td><td style = \"text-align: right;\">-13.6058</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">6</td><td style = \"text-align: right;\">-13.2824</td><td style = \"text-align: right;\">8.86115</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">7</td><td style = \"text-align: right;\">15.9065</td><td style = \"text-align: right;\">-16.5451</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">8</td><td style = \"text-align: right;\">-17.972</td><td style = \"text-align: right;\">-7.82244</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">9</td><td style = \"text-align: right;\">-5.2641</td><td style = \"text-align: right;\">9.30009</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">10</td><td style = \"text-align: right;\">-11.3988</td><td style = \"text-align: right;\">-5.1612</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">11</td><td style = \"text-align: right;\">-0.171222</td><td style = \"text-align: right;\">-5.94507</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">12</td><td style = \"text-align: right;\">6.49767</td><td style = \"text-align: right;\">-25.2878</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">13</td><td style = \"text-align: right;\">-27.8753</td><td style = \"text-align: right;\">-5.61827</td></tr><tr><td style = \"text-align: right;\">⋮</td><td style = \"text-align: right;\">⋮</td><td style = \"text-align: right;\">⋮</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1786</td><td style = \"text-align: right;\">18.7481</td><td style = \"text-align: right;\">-2.69052</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1787</td><td style = \"text-align: right;\">18.0899</td><td style = \"text-align: right;\">2.95523</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1788</td><td style = \"text-align: right;\">-16.321</td><td style = \"text-align: right;\">7.03648</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1789</td><td style = \"text-align: right;\">-19.6796</td><td style = \"text-align: right;\">0.109262</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1790</td><td style = \"text-align: right;\">19.5261</td><td style = \"text-align: right;\">5.59811</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1791</td><td style = \"text-align: right;\">1.32969</td><td style = \"text-align: right;\">-4.05499</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1792</td><td style = \"text-align: right;\">1.38873</td><td style = \"text-align: right;\">2.25902</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1793</td><td style = \"text-align: right;\">2.15727</td><td style = \"text-align: right;\">-4.79302</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1794</td><td style = \"text-align: right;\">-10.5746</td><td style = \"text-align: right;\">2.93352</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1795</td><td style = \"text-align: right;\">-11.7413</td><td style = \"text-align: right;\">-11.0562</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1796</td><td style = \"text-align: right;\">5.0894</td><td style = \"text-align: right;\">-6.87343</td></tr><tr><td class = \"rowNumber\" style = \"font-weight: bold; text-align: right;\">1797</td><td style = \"text-align: right;\">20.9178</td><td style = \"text-align: right;\">1.59327</td></tr></tbody></table></div>\n```\n:::\n:::\n\n\n### 3.3 project decision data to 2d space\n\n::: {.cell execution_count=5}\n``` {.julia .cell-code}\nfunction boundary_data(df::AbstractDataFrame,;n=200)\n n1=n2=n\n xlow,xhigh=extrema(df[:,:x1])\n ylow,yhigh=extrema(df[:,:x2])\n tx = range(xlow,xhigh; length=n1)\n ty = range(ylow,yhigh; length=n2)\n x_test = mapreduce(collect, hcat, Iterators.product(tx, ty));\n xtest=MLJ.table(x_test')\n return tx,ty, xtest\nend\n\n tx,ty, xtest=boundary_data(Xproj)\n```\n\n::: {.cell-output .cell-output-display execution_count=52}\n```\n(-31.16990412454417:0.3159297962408993:31.70012532739479, -30.0922050904867:0.2893801640107357:27.494447547649695, Tables.MatrixTable{LinearAlgebra.Adjoint{Float64, Matrix{Float64}}} with 40000 rows, 2 columns, and schema:\n :x1 Float64\n :x2 Float64)\n```\n:::\n:::\n\n\n### 3.4 kmean flow\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\n kmeans_mach= machine(kmeans_model, Xproj) |> fit!\n\n ypred= predict(kmeans_mach, xtest)|>Array|>d->reshape(d,200,200)\n\n cen=fitted_params(kmeans_mach) #获取各聚类中心坐标\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n[ Info: Training machine(KMeans(k = 9, …), …).\n```\n:::\n\n::: {.cell-output .cell-output-display execution_count=53}\n```\n(centers = [-14.940039081327807 -4.21080988089939 … 3.623938350681804 16.636412339611788; 6.0940188803548185 -1.9858645156958448 … 7.9454732936643895 -12.507323610991511],)\n```\n:::\n:::\n\n\n## 4. plot results\n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\nfunction plot_res()\n fig = Figure(resolution=(600,600))\n ax = Axis(fig[1, 1],title=\"digits pca kmeans\",subtitle=\"pca->clustering\")\n contourf!(ax, tx,ty,ypred)\n scatter!(ax,eachrow(cen.centers)...;marker=:xcross,markersize = 24,color=(:red,0.8))\n scatter!(ax,eachcol(Xproj)...;markersize = 8,color=(:lightgreen,0.1),strokecolor = :black, strokewidth =1)\n fig\nend\nplot_res()\n```\n\n::: {.cell-output .cell-output-display execution_count=54}\n![](1-kmeans-on-digits_files/figure-html/cell-8-output-1.png){}\n:::\n:::\n\n\n", | ||
"supporting": [ | ||
"1-kmeans-on-digits_files" | ||
], | ||
"filters": [], | ||
"includes": { | ||
"include-in-header": [ | ||
"<script src=\"https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js\" integrity=\"sha512-c3Nl8+7g4LMSTdrm621y7kf9v3SDPnhxLNhcjFJbKECVnmZHTdo+IRO05sNLTH/D3vA6u1X32ehoLC7WFVdheg==\" crossorigin=\"anonymous\"></script>\n<script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.5.1/jquery.min.js\" integrity=\"sha512-bLT0Qm9VnAYZDflyKcBaQ2gg0hSYNQrJ8RilYldYQ1FxQYoCLtUjuuRuZo+fjqhx/qtq/1itJ0C2ejDxltZVFg==\" crossorigin=\"anonymous\"></script>\n<script type=\"application/javascript\">define('jquery', [],function() {return window.jQuery;})</script>\n" | ||
] | ||
} | ||
} | ||
} |
Binary file added
BIN
+283 KB
_freeze/category/clustering/1-kmeans-on-digits/figure-html/cell-8-output-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
--- | ||
title: "1-kmeans-on-digits" | ||
author: math4mad | ||
code-fold: false | ||
--- | ||
|
||
:::{.callout-note title="简介"} | ||
**routine: project data to 2d space then proceed kmeans methods** | ||
|
||
1. ref: [K-Means clustering on the handwritten digits data](https://scikit-learn.org/stable/auto_examples/cluster/plot_kmeans_digits.html) | ||
2. [kmeans in mlj](https://alan-turing-institute.github.io/MLJ.jl/dev/models/KMeans_Clustering/#KMeans_Clustering) | ||
3. decision boundary:[Prediction](https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/examples/support-vector-machine/#Prediction) | ||
|
||
::: | ||
|
||
## 1. load package | ||
```{julia} | ||
include("../utils.jl") | ||
import MLJ: fit!, predict, transform,fitted_params | ||
using CSV, DataFrames, GLMakie, Random | ||
using MLJ | ||
Random.seed!(34343) | ||
``` | ||
|
||
## 2. load data | ||
```{julia} | ||
load_csv(str::AbstractString) = | ||
str |> d -> CSV.File("./data/$str.csv") |> DataFrame |> dropmissing | ||
digits = load_csv("scikit_digits") | ||
digits = coerce(digits, :target => Multiclass) | ||
y, X = unpack(digits, ==(:target); rng = 123); | ||
``` | ||
|
||
## 3. MLJ workflow | ||
|
||
### 3.1 load model | ||
|
||
```{julia} | ||
PCA = @load PCA pkg = MultivariateStats | ||
KMeans = @load KMeans pkg = Clustering | ||
pca_model = PCA(; maxoutdim = 2) | ||
kmeans_model = KMeans(; k =9) | ||
``` | ||
### 3.2 usa pca model project data to 2d space | ||
```{julia} | ||
pca_mach = machine(pca_model, X) |> fit! | ||
Xproj = transform(pca_mach, X) | ||
``` | ||
### 3.3 project decision data to 2d space | ||
```{julia} | ||
function boundary_data(df::AbstractDataFrame,;n=200) | ||
n1=n2=n | ||
xlow,xhigh=extrema(df[:,:x1]) | ||
ylow,yhigh=extrema(df[:,:x2]) | ||
tx = range(xlow,xhigh; length=n1) | ||
ty = range(ylow,yhigh; length=n2) | ||
x_test = mapreduce(collect, hcat, Iterators.product(tx, ty)); | ||
xtest=MLJ.table(x_test') | ||
return tx,ty, xtest | ||
end | ||
tx,ty, xtest=boundary_data(Xproj) | ||
``` | ||
### 3.4 kmean flow | ||
|
||
```{julia} | ||
kmeans_mach= machine(kmeans_model, Xproj) |> fit! | ||
ypred= predict(kmeans_mach, xtest)|>Array|>d->reshape(d,200,200) | ||
cen=fitted_params(kmeans_mach) #获取各聚类中心坐标 | ||
``` | ||
|
||
## 4. plot results | ||
```{julia} | ||
function plot_res() | ||
fig = Figure(resolution=(600,600)) | ||
ax = Axis(fig[1, 1],title="digits pca kmeans",subtitle="pca->clustering") | ||
contourf!(ax, tx,ty,ypred) | ||
scatter!(ax,eachrow(cen.centers)...;marker=:xcross,markersize = 24,color=(:red,0.8)) | ||
scatter!(ax,eachcol(Xproj)...;markersize = 8,color=(:lightgreen,0.1),strokecolor = :black, strokewidth =1) | ||
fig | ||
end | ||
plot_res() | ||
``` |
Oops, something went wrong.