diff --git a/_toc.yml b/_toc.yml index fd31798..b189c73 100644 --- a/_toc.yml +++ b/_toc.yml @@ -13,12 +13,15 @@ subtrees: entries: - file: ch-data-science/data-science-lifecycle - file: ch-data-science/machine-learning + - file: ch-data-science/deep-learning + - file: ch-data-science/hyperparameter - file: ch-data-science/python-ecosystem - file: ch-dask/index entries: - file: ch-dask/dask-intro - file: ch-dask/dask-dataframe-intro - file: ch-dask/dask-distributed + - file: ch-dask/gpu - file: ch-dask/task-graph-partitioning - file: ch-dask-dataframe/index entries: @@ -29,6 +32,8 @@ subtrees: - file: ch-dask-dataframe/shuffle - file: ch-dask-ml/index entries: + - file: ch-dask-ml/preprocessing + - file: ch-dask-ml/hyperparameter - file: ch-dask-ml/distributed-training - file: ch-ray-core/index entries: @@ -47,14 +52,11 @@ subtrees: - file: ch-ray-data/data-load-inspect-save - file: ch-ray-data/data-transform - file: ch-ray-data/preprocessor - - file: ch-ray-train-tune/index + - file: ch-ray-ml/index entries: - - file: ch-ray-train-tune/ray-train - - file: ch-ray-train-tune/ray-tune - - file: ch-ray-train-tune/tune-algorithm-scheduler - - file: ch-ray-serve/index - entries: - - file: ch-ray-serve/ray-serve + - file: ch-ray-ml/ray-train + - file: ch-ray-ml/ray-tune + - file: ch-ray-ml/ray-serve - file: ch-mpi/index entries: - file: ch-mpi/mpi-intro diff --git a/ch-dask-dataframe/indexing.ipynb b/ch-dask-dataframe/indexing.ipynb index 42f8f48..c1cace9 100644 --- a/ch-dask-dataframe/indexing.ipynb +++ b/ch-dask-dataframe/indexing.ipynb @@ -33,7 +33,7 @@ "import os\n", "import sys\n", "sys.path.append(\"..\")\n", - "from datasets import nyc_flights\n", + "from utils import nyc_flights\n", "\n", "import dask\n", "dask.config.set({'dataframe.query-planning': False})\n", diff --git a/ch-dask-dataframe/map-partitions.ipynb b/ch-dask-dataframe/map-partitions.ipynb index e3357cb..7be5ad1 100644 --- a/ch-dask-dataframe/map-partitions.ipynb +++ 
b/ch-dask-dataframe/map-partitions.ipynb @@ -30,7 +30,7 @@ "source": [ "import sys\n", "sys.path.append(\"..\")\n", - "from datasets import nyc_taxi\n", + "from utils import nyc_taxi\n", "\n", "import pandas as pd\n", "import dask\n", diff --git a/ch-dask-dataframe/read-write.ipynb b/ch-dask-dataframe/read-write.ipynb index a9ba5cd..15ac679 100644 --- a/ch-dask-dataframe/read-write.ipynb +++ b/ch-dask-dataframe/read-write.ipynb @@ -56,7 +56,7 @@ "\n", "import sys\n", "sys.path.append(\"..\")\n", - "from datasets import nyc_flights\n", + "from utils import nyc_flights\n", "\n", "import warnings\n", "warnings.simplefilter(action='ignore', category=FutureWarning)\n", diff --git a/ch-dask-ml/distributed-training.ipynb b/ch-dask-ml/distributed-training.ipynb index f70fbb3..a8287d6 100644 --- a/ch-dask-ml/distributed-training.ipynb +++ b/ch-dask-ml/distributed-training.ipynb @@ -9,8 +9,8 @@ "\n", "如果训练数据量很大,Dask-ML 提供了分布式机器学习功能,可以在集群上对大数据进行训练。目前,Dask 提供了两类分布式机器学习 API:\n", "\n", - "* scikit-learn:与 scikit-learn 的调用方式类似\n", - "* XGBoost 和 LightGBM:与 XGBoost 和 LightGBM 的调用方式类似\n", + "* scikit-learn 式:与 scikit-learn 的调用方式类似\n", + "* XGBoost 和 LightGBM 决策树式:与 XGBoost 和 LightGBM 的调用方式类似\n", "\n", "## scikit-learn API\n", "\n", @@ -463,7 +463,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "训练好的模型可以用来预测(`predict`),也可以计算准确度(`score`)。" + "训练好的模型可以用来预测(`predict()`),也可以计算准确度(`score()`)。" ] }, { @@ -515,7 +515,10 @@ "\n", "尽管 Dask-ML 这种分布式训练的 API 与 scikit-learn 极其相似,scikit-learn 只能使用单核,Dask-ML 可以使用多核甚至集群,但并不意味着所有场景下都选择 Dask-ML,因为有些时候 Dask-ML 并非性能或性价比最优的选择。这一点与 Dask DataFrame 和 pandas 关系一样,如果数据量能放进单机内存,原生的 pandas 、NumPy 和 scikit-learn 的性能和兼容性总是最优的。\n", "\n", - "下面的代码对不同规模的训练数据进行了性能分析,在单机多核且数据量较小的场景,Dask-ML 的性能并不比 scikit-learn 更快。主要因为:很多机器学习算法是迭代式的,scikit-learn 中,迭代式算法使用 Python 原生 `for` 循环来实现;Dask-ML 参考了这种 `for` 循环,但对于 Dask 的 Task Graph 来说,`for` 循环会使得 Task Graph 很臃肿,执行效率并不是很高。\n", + "下面的代码对不同规模的训练数据进行了性能分析,在单机多核且数据量较小的场景,Dask-ML 的性能并不比 scikit-learn 更快。原因有很多,包括:\n", + "\n", 
+ "* 很多机器学习算法是迭代式的,scikit-learn 中,迭代式算法使用 Python 原生 `for` 循环来实现;Dask-ML 参考了这种 `for` 循环,但对于 Dask 的 Task Graph 来说,`for` 循环会使得 Task Graph 很臃肿,执行效率并不是很高。\n", + "* 分布式实现需要在不同进程间分发和收集数据,相比单机单进程,额外增加了很多数据同步和通信开销。\n", "\n", "你也可以根据你所拥有的内存来测试一下性能。" ] @@ -2115,9 +2118,9 @@ "\n", "XGBoost 和 LightGBM 是两种决策树模型的实现,他们本身就对分布式训练友好,且集成了 Dask 的分布式能力。下面以 XGBoost 为例,介绍 XGBoost 如何基于 Dask 实现分布式训练,LightGBM 与之类似。\n", "\n", - "在 XGBoost 中,训练一个模型既可以使用 `train` 方法,也可以使用 scikit-learn 式的 `fit` 方法。这两种方式都支持 Dask 分布式训练。\n", + "在 XGBoost 中,训练一个模型既可以使用 `train` 方法,也可以使用 scikit-learn 式的 `fit()` 方法。这两种方式都支持 Dask 分布式训练。\n", "\n", - "下面的代码对单机的 XGBoost 和 Dask 分布式训练两种方式进行了性能对比。如果使用 Dask,需要将 [`xgboost.DMatrix`](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.DMatrix) 修改为 [`xgboost.dask.DaskDMatrix`](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.dask.DaskDMatrix),将 [`xgboost.train`](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.train) 修改为 [`xgboost.dask.train`](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.dask.train);并传入 Dask 集群客户端 `client`。" + "下面的代码对单机的 XGBoost 和 Dask 分布式训练两种方式进行了性能对比。如果使用 Dask,需要将 [`xgboost.DMatrix`](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.DMatrix) 修改为 [`xgboost.dask.DaskDMatrix`](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.dask.DaskDMatrix),`xgboost.dask.DaskDMatrix` 可以将分布式的 Dask Array 或 Dask DataFrame 转化为 XGBoost 所需要的数据格式;再将 [`xgboost.train`](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.train) 修改为 [`xgboost.dask.train`](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.dask.train);并传入 Dask 集群客户端 `client`。" ] }, { diff --git a/ch-dask-ml/hyperparameter.ipynb b/ch-dask-ml/hyperparameter.ipynb new file mode 100644 index 0000000..2ad541c --- /dev/null +++ b/ch-dask-ml/hyperparameter.ipynb @@ -0,0 +1,803 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ 
+ "(sec-dask-ml-hyperparameter)=\n", + "# 超参数调优\n", + "\n", + "我们可以使用 Dask 进行超参数调优,主要有两种方式:\n", + "\n", + "* 基于 scikit-learn 的 joblib 后端,将多个超参数调优任务分布到 Dask 集群\n", + "* 使用 Dask-ML 提供的超参数调优 API\n", + "\n", + "## scikit-learn joblib\n", + "\n", + "单机的 scikit-learn 已经提供了丰富易用的模型训练和超参数调优接口,它默认使用 joblib 在单机多核之间并行。像随机搜索和网格搜索等超参数调优任务容易并行,任务之间没有依赖关系,很容易并行起来。\n", + "\n", + "### 案例:飞机延误预测(scikit-learn)\n", + "\n", + "下面展示一个基于 scikit-learn 的机器学习分类案例,我们使用 scikit-learn 提供的网格搜索。" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import sys\n", + "sys.path.append(\"..\")\n", + "from utils import nyc_flights\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "folder_path = nyc_flights()\n", + "file_path = os.path.join(folder_path, \"1991.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "input_cols = [\n", + " \"Year\",\n", + " \"Month\",\n", + " \"DayofMonth\",\n", + " \"DayOfWeek\",\n", + " \"CRSDepTime\",\n", + " \"CRSArrTime\",\n", + " \"UniqueCarrier\",\n", + " \"FlightNum\",\n", + " \"ActualElapsedTime\",\n", + " \"Origin\",\n", + " \"Dest\",\n", + " \"Distance\",\n", + " \"Diverted\",\n", + " \"ArrDelay\",\n", + "]\n", + "\n", + "df = pd.read_csv(file_path, usecols=input_cols)\n", + "df = df.dropna()\n", + "\n", + "# 预测是否延误\n", + "df[\"ArrDelayBinary\"] = 1.0 * (df[\"ArrDelay\"] > 10)\n", + "\n", + "df = df[df.columns.difference([\"ArrDelay\"])]\n", + "\n", + "# 将 Dest/Origin/UniqueCarrier 等字段转化为 category 类型\n", + "for col in df.select_dtypes([\"object\"]).columns:\n", + " df[col] = df[col].astype(\"category\").cat.codes.astype(np.int32)\n", + "\n", + "for col in df.columns:\n", + " df[col] = df[col].astype(np.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.linear_model import SGDClassifier\n", + "\n", + "from 
sklearn.model_selection import GridSearchCV as SkGridSearchCV\n", + "from sklearn.model_selection import train_test_split as sk_train_test_split\n", + "\n", + "_y_label = \"ArrDelayBinary\"\n", + "X_train, X_test, y_train, y_test = sk_train_test_split(\n", + " df.loc[:, df.columns != _y_label], \n", + " df[_y_label], \n", + " test_size=0.25,\n", + " shuffle=False,\n", + ")\n", + "\n", + "model = SGDClassifier(penalty='elasticnet', max_iter=1_000, warm_start=True, loss='log_loss')\n", + "params = {'alpha': np.logspace(-4, 1, num=81)}\n", + "\n", + "sk_grid_search = SkGridSearchCV(model, params)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "在进行超参数搜索时,只需要添加 `with joblib.parallel_config('dask'):`,将网格搜索计算任务扩展到 Dask 集群。" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-05-08 07:36:02,224 - distributed.client - ERROR - Failed to reconnect to scheduler after 30.00 seconds, closing client\n" + ] + } + ], + "source": [ + "import joblib\n", + "from dask.distributed import Client, LocalCluster\n", + "\n", + "# 修改为你的 Dask Scheduler IP 地址\n", + "client = Client(\"10.0.0.3:8786\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "with joblib.parallel_config('dask'):\n", + " sk_grid_search.fit(X_train, y_train)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "使用 `score()` 方法查看模型的准确度:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.7775224665166276" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sk_grid_search.score(X_test, y_test)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dask-ML API\n", + "\n", + "前面介绍了基于 scikit-learn 的超参数调优,整个流程中只需要修改 
`joblib.parallel_config('dask')`,计算任务就被分发到 Dask 集群上。\n", + "\n", + "Dask-ML 自己也实现了一些超参数调优的 API,除了提供和 scikit-learn 对标的 `GridSearchCV`、`RandomizedSearchCV` 等算法外,还提供了连续减半算法、Hyperband 算法等,比如 [`SuccessiveHalvingSearchCV`](https://ml.dask.org/modules/generated/dask_ml.model_selection.SuccessiveHalvingSearchCV.html)、[`HyperbandSearchCV`](https://ml.dask.org/modules/generated/dask_ml.model_selection.HyperbandSearchCV.html)。\n", + "\n", + "### 案例:飞机延误预测(Dask-ML)\n", + "\n", + "下面展示一个基于 Dask-ML 的 Hyperband 超参数调优案例。\n", + "\n", + "Dask-ML 的超参数调优算法要求输入为 Dask DataFrame 或 Dask Array 等可被切分的数据,而非 pandas DataFrame,因此数据预处理部分需要改为 Dask。\n", + "\n", + "值得注意的是,Dask-ML 提供的 `SuccessiveHalvingSearchCV` 和 `HyperbandSearchCV` 等算法要求模型必须支持 `partial_fit()` 和 `score()`。`partial_fit()` 是 scikit-learn 中迭代式算法(比如梯度下降法)的一次迭代过程。连续减半算法和 Hyperband 算法先分配一些算力额度,不是完成试验的所有迭代,而只做一定次数的迭代(对 `partial_fit()` 进行一定次数的调用),评估性能(在验证集上调用 `score()` 方法),淘汰性能较差的试验。" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "import dask.dataframe as dd\n", + "\n", + "input_cols = [\n", + " \"Year\",\n", + " \"Month\",\n", + " \"DayofMonth\",\n", + " \"DayOfWeek\",\n", + " \"CRSDepTime\",\n", + " \"CRSArrTime\",\n", + " \"UniqueCarrier\",\n", + " \"FlightNum\",\n", + " \"ActualElapsedTime\",\n", + " \"Origin\",\n", + " \"Dest\",\n", + " \"Distance\",\n", + " \"Diverted\",\n", + " \"ArrDelay\",\n", + "]\n", + "\n", + "ddf = dd.read_csv(file_path, usecols=input_cols,)\n", + "\n", + "# 预测是否延误\n", + "ddf[\"ArrDelayBinary\"] = 1.0 * (ddf[\"ArrDelay\"] > 10)\n", + "\n", + "ddf = ddf[ddf.columns.difference([\"ArrDelay\"])]\n", + "ddf = ddf.dropna()\n", + "ddf = ddf.repartition(npartitions=8)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "另外,Dask 处理类型变量时与 pandas/scikit-learn 也稍有不同,我们需要:\n", + "\n", + "* 将该特征转换为 `category` 类型,比如,使用 Dask DataFrame 的 `categorize()` 方法,或 Dask-ML 的 
[`Categorizer`](https://ml.dask.org/modules/generated/dask_ml.preprocessing.Categorizer.html#dask_ml.preprocessing.Categorizer) 预处理器。\n", + "* 进行独热编码:Dask-ML 中的 [`DummyEncoder`](https://ml.dask.org/modules/generated/dask_ml.preprocessing.DummyEncoder.html#dask_ml.preprocessing.DummyEncoder) 对类别特征进行独热编码,是 scikit-learn `OneHotEncoder` 的 Dask 替代。" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "from dask_ml.preprocessing import DummyEncoder\n", + "\n", + "dummy = DummyEncoder()\n", + "ddf = ddf.categorize(columns=[\"Dest\", \"Origin\", \"UniqueCarrier\"])\n", + "dummified_ddf = dummy.fit_transform(ddf)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "并使用 Dask-ML 的 `train_test_split` 方法切分训练集和测试集:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from dask_ml.model_selection import train_test_split as dsk_train_test_split\n", + "\n", + "_y_label = \"ArrDelayBinary\"\n", + "X_train, X_test, y_train, y_test = dsk_train_test_split(\n", + " dummified_ddf.loc[:, dummified_ddf.columns != _y_label], \n", + " dummified_ddf[_y_label], \n", + " test_size=0.25,\n", + " shuffle=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "定义模型和搜索空间的方式与 scikit-learn 类似,然后调用 Dask-ML 的 `HyperbandSearchCV` 进行超参数调优。" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/sklearn/model_selection/_search.py:318: UserWarning: The total space of parameters 30 is smaller than n_iter=81. Running 30 iterations. 
For exhaustive searches, use GridSearchCV.\n", + " warnings.warn(\n", + "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/sklearn/model_selection/_search.py:318: UserWarning: The total space of parameters 30 is smaller than n_iter=34. Running 30 iterations. For exhaustive searches, use GridSearchCV.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/html": [ + "
HyperbandSearchCV(estimator=SGDClassifier(loss='log_loss', penalty='elasticnet',\n",
+       "                                          warm_start=True),\n",
+       "                  max_iter=243,\n",
+       "                  parameters={'alpha': array([1.00000000e-04, 1.48735211e-04, 2.21221629e-04, 3.29034456e-04,\n",
+       "       4.89390092e-04, 7.27895384e-04, 1.08263673e-03, 1.61026203e-03,\n",
+       "       2.39502662e-03, 3.56224789e-03, 5.29831691e-03, 7.88046282e-03,\n",
+       "       1.17210230e-02, 1.74332882e-02, 2.59294380e-02, 3.85662042e-02,\n",
+       "       5.73615251e-02, 8.53167852e-02, 1.26896100e-01, 1.88739182e-01,\n",
+       "       2.80721620e-01, 4.17531894e-01, 6.21016942e-01, 9.23670857e-01,\n",
+       "       1.37382380e+00, 2.04335972e+00, 3.03919538e+00, 4.52035366e+00,\n",
+       "       6.72335754e+00, 1.00000000e+01])})
" + ], + "text/plain": [ + "HyperbandSearchCV(estimator=SGDClassifier(loss='log_loss', penalty='elasticnet',\n", + " warm_start=True),\n", + " max_iter=243,\n", + " parameters={'alpha': array([1.00000000e-04, 1.48735211e-04, 2.21221629e-04, 3.29034456e-04,\n", + " 4.89390092e-04, 7.27895384e-04, 1.08263673e-03, 1.61026203e-03,\n", + " 2.39502662e-03, 3.56224789e-03, 5.29831691e-03, 7.88046282e-03,\n", + " 1.17210230e-02, 1.74332882e-02, 2.59294380e-02, 3.85662042e-02,\n", + " 5.73615251e-02, 8.53167852e-02, 1.26896100e-01, 1.88739182e-01,\n", + " 2.80721620e-01, 4.17531894e-01, 6.21016942e-01, 9.23670857e-01,\n", + " 1.37382380e+00, 2.04335972e+00, 3.03919538e+00, 4.52035366e+00,\n", + " 6.72335754e+00, 1.00000000e+01])})" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from dask_ml.model_selection import HyperbandSearchCV\n", + "\n", + "# client = Client(LocalCluster())\n", + "model = SGDClassifier(penalty='elasticnet', max_iter=1_000, warm_start=True, loss='log_loss')\n", + "params = {'alpha': np.logspace(-4, 1, num=30)}\n", + "\n", + "dsk_hyperband = HyperbandSearchCV(model, params, max_iter=243)\n", + "dsk_hyperband.fit(X_train, y_train, classes=[0.0, 1.0])" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.8241373877422404" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dsk_hyperband.score(X_test, y_test)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "本书还会介绍 Ray 的超参数调优,相比 Dask,Ray 的兼容性和功能完善程度更好,读者可以根据自身需求选择适合自己的框架。" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": 
"python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/ch-dask-ml/index.md b/ch-dask-ml/index.md index 7068513..9e5ffa8 100644 --- a/ch-dask-ml/index.md +++ b/ch-dask-ml/index.md @@ -1,4 +1,14 @@ # Dask 机器学习 +本章将聚焦于 Dask 机器学习,主要介绍 Dask-ML 等库的使用。Dask-ML 基于 Dask 的分布式计算能力,面向机器学习应用,可以无缝对接 scikit-learn、XGBoost 等机器学习库。相比之下,Dask-ML 更适合传统机器学习的训练和推理,比如回归、决策树等等,深度学习相关的训练和推理更多基于 PyTorch 或 TensorFlow 等框架。 + +总结起来,Dask 和 Dask-ML 适合的场景有以下几类: + +* 原始数据无法放到单机内存中,需要进行分布式数据预处理和特征工程; +* 训练数据和模型可放到单机内存中,超参数调优需要多机并行; +* 训练数据无法放到单机内存中,需要进行分布式训练。 + +一方面,Dask 社区将主要精力投入在 Dask DataFrame 上,对 Dask-ML 和分布式训练的优化并不多;另一方面,深度学习已经冲击传统机器学习算法,Dask 设计之初并不是面向深度学习的。读者阅读本章,了解 Dask 机器学习能力后,可以根据自身需求选择适合自己的框架。 + ```{tableofcontents} ``` \ No newline at end of file diff --git a/ch-dask-ml/preprocessing.md b/ch-dask-ml/preprocessing.md new file mode 100644 index 0000000..7fe9a53 --- /dev/null +++ b/ch-dask-ml/preprocessing.md @@ -0,0 +1,8 @@ +(sec-dask-ml-preprocessing)= +# 数据预处理 + +{numref}`sec-data-science-lifecycle` 我们提到过,数据科学工作的重点是理解数据和处理数据,Dask 可以将很多单机的任务横向扩展到集群上,并且可以和 Python 社区数据可视化等库结合,完成探索性数据分析。 + +分布式数据预处理部分更多依赖 Dask DataFrame 和 Dask Array 的能力,这里不再赘述。 + +特征工程部分,Dask-ML 实现了很多 `sklearn.preprocessing` 的 API,比如 [`MinMaxScaler`](https://ml.dask.org/modules/generated/dask_ml.preprocessing.MinMaxScaler.html)。对 Dask 而言,稍有不同的是其独热编码,本书写作时,Dask 使用 [`DummyEncoder`](https://ml.dask.org/modules/generated/dask_ml.preprocessing.DummyEncoder.html) 对类别特征进行独热编码,`DummyEncoder` 是 scikit-learn `OneHotEncoder` 的 Dask 替代。我们将在 {numref}`sec-dask-ml-hyperparameter` 将展示一个类型特征的案例。 \ No newline at end of file diff --git a/ch-dask/gpu.ipynb b/ch-dask/gpu.ipynb new file mode 100644 index 0000000..58ec701 --- /dev/null +++ b/ch-dask/gpu.ipynb @@ -0,0 +1,1220 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(sec-dask-gpu)=\n", + "# GPU\n", + "\n", + "GPU 以及其他异构计算加速器被广泛用来加速深度学习,Dask 社区联合 NVIDIA,提供了基于 GPU 的数据科学工具包,以加速各类任务。\n", + 
"\n", + "## Dask GPU 集群\n", + "\n", + "[Dask-CUDA](https://docs.rapids.ai/api/dask-cuda/stable/) 对 `dask.distributed` 扩展,可以识别和管理 GPU 设备。\n", + "\n", + "使用这些 GPU 设备前,先通过 `pip install dask_cuda` 安装 Dask-CUDA。跟 {numref}`sec-dask-distributed` 提到的`dask.distributed` 类似,Dask-CUDA 提供了一个单机的 `LocalCUDACluster`,`LocalCUDACluster` 会自动发现并注册该计算节点上的多张 GPU 卡,每张 GPU 自动配比一定数量的 CPU 核心。比如,我的环境有 4 张 GPU 卡,启动一个单机 Dask 集群,会自动启动 4 个 Dask Worker,每个 Worker 一张 GPU 卡。" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/distributed/node.py:182: UserWarning: Port 8787 is already in use.\n", + "Perhaps you already have a cluster running?\n", + "Hosting the HTTP server on port 37111 instead\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "
\n", + "
\n", + "

Client: Client-5c3311bf-0ce5-11ef-bd8c-000012e4fe80\n", + "
Connection method: Cluster object | Cluster type: dask_cuda.LocalCUDACluster | Dashboard: http://127.0.0.1:37111/status\n", + "
Cluster Info: LocalCUDACluster 209b2784 | Dashboard: http://127.0.0.1:37111/status | Workers: 4 | Total threads: 4 | Total memory: 90.00 GiB | Status: running | Using processes: True\n", + "
Scheduler Info: Scheduler-39587c13-5825-4748-be18-a18f23c602bb | Comm: tcp://127.0.0.1:46657 | Workers: 4 | Dashboard: http://127.0.0.1:37111/status | Total threads: 4 | Started: Just now | Total memory: 90.00 GiB\n", + "
Worker: 0 | Comm: tcp://127.0.0.1:36681 | Total threads: 1 | Dashboard: http://127.0.0.1:38373/status | Memory: 22.50 GiB | Nanny: tcp://127.0.0.1:41031 | Local directory: /tmp/dask-scratch-space/worker-jkx850hc\n", + "
Worker: 1 | Comm: tcp://127.0.0.1:37987 | Total threads: 1 | Dashboard: http://127.0.0.1:38845/status | Memory: 22.50 GiB | Nanny: tcp://127.0.0.1:36415 | Local directory: /tmp/dask-scratch-space/worker-gelyun5u\n", + "
Worker: 2 | Comm: tcp://127.0.0.1:36139 | Total threads: 1 | Dashboard: http://127.0.0.1:44939/status | Memory: 22.50 GiB | Nanny: tcp://127.0.0.1:40211 | Local directory: /tmp/dask-scratch-space/worker-c6owcg7k\n", + "
Worker: 3 | Comm: tcp://127.0.0.1:46363 | Total threads: 1 | Dashboard: http://127.0.0.1:40611/status | Memory: 22.50 GiB | Nanny: tcp://127.0.0.1:38093 | Local directory: /tmp/dask-scratch-space/worker-hyl9pn8_
" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from dask_cuda import LocalCUDACluster\n", + "from dask.distributed import Client\n", + "\n", + "cluster = LocalCUDACluster()\n", + "client = Client(cluster)\n", + "client" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "我们也可以启动一个 Dask GPU 集群,先启动 Dask Scheduler:\n", + "\n", + "```\n", + "dask scheduler\n", + "```\n", + "\n", + "再在每个 GPU 节点上启动支持 GPU 的 Worker,构成一个 Dask GPU 集群。\n", + "\n", + "```\n", + "dask cuda worker tcp://scheduler:8786\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "
\n", + "
\n", + "

Client: Client-6039933f-0ce3-11ef-b163-000012e4fe80\n", + "
Connection method: Direct | Dashboard: http://10.0.0.3:8787/status\n", + "
Scheduler Info: Scheduler-d073585d-dcac-41bf-9c5c-1055fe07576c | Comm: tcp://10.0.0.3:8786 | Workers: 8 | Dashboard: http://10.0.0.3:8787/status | Total threads: 8 | Started: Just now | Total memory: 180.00 GiB\n", + "
Worker: tcp://10.0.0.2:34491 | Total threads: 1 | Dashboard: http://10.0.0.2:38385/status | Memory: 22.50 GiB | Nanny: tcp://10.0.0.2:37559 | Local directory: /tmp/dask-scratch-space/worker-p2de783n | CPU usage: 4.0% | Last seen: Just now | Memory usage: 216.19 MiB | Spilled bytes: 0 B | Read bytes: 8.81 kiB | Write bytes: 14.61 kiB\n", + "
Worker: tcp://10.0.0.2:39239 | Total threads: 1 | Dashboard: http://10.0.0.2:45797/status | Memory: 22.50 GiB | Nanny: tcp://10.0.0.2:36259 | Local directory: /tmp/dask-scratch-space/worker-mo04yp4a | CPU usage: 6.0% | Last seen: Just now | Memory usage: 216.30 MiB | Spilled bytes: 0 B | Read bytes: 9.76 kiB | Write bytes: 14.86 kiB\n", + "
Worker: tcp://10.0.0.2:40863 | Total threads: 1 | Dashboard: http://10.0.0.2:43677/status | Memory: 22.50 GiB | Nanny: tcp://10.0.0.2:32877 | Local directory: /tmp/dask-scratch-space/worker-4p9jsv4f | CPU usage: 4.0% | Last seen: Just now | Memory usage: 216.27 MiB | Spilled bytes: 0 B | Read bytes: 9.77 kiB | Write bytes: 14.88 kiB\n", + "
Worker: tcp://10.0.0.2:46243 | Total threads: 1 | Dashboard: http://10.0.0.2:40513/status | Memory: 22.50 GiB | Nanny: tcp://10.0.0.2:45107 | Local directory: /tmp/dask-scratch-space/worker-gt5epnxr | CPU usage: 4.0% | Last seen: Just now | Memory usage: 216.21 MiB | Spilled bytes: 0 B | Read bytes: 10.04 kiB | Write bytes: 15.00 kiB\n", + "
Worker: tcp://10.0.0.3:39647 | Total threads: 1 | Dashboard: http://10.0.0.3:38377/status | Memory: 22.50 GiB | Nanny: tcp://10.0.0.3:34843 | Local directory: /tmp/dask-scratch-space/worker-gqcyic7m | CPU usage: 4.0% | Last seen: Just now | Memory usage: 217.51 MiB | Spilled bytes: 0 B | Read bytes: 63.74 kiB | Write bytes: 58.80 kiB\n", + "
Worker: tcp://10.0.0.3:40155 | Total threads: 1 | Dashboard: http://10.0.0.3:34723/status | Memory: 22.50 GiB | Nanny: tcp://10.0.0.3:46339 | Local directory: /tmp/dask-scratch-space/worker-yo78gnof | CPU usage: 6.0% | Last seen: Just now | Memory usage: 218.25 MiB | Spilled bytes: 0 B | Read bytes: 63.73 kiB | Write bytes: 58.80 kiB\n", + "
Worker: tcp://10.0.0.3:45005 | Total threads: 1 | Dashboard: http://10.0.0.3:42503/status | Memory: 22.50 GiB | Nanny: tcp://10.0.0.3:34929 | Local directory: /tmp/dask-scratch-space/worker-skts4xjq | CPU usage: 6.0% | Last seen: Just now | Memory usage: 216.24 MiB | Spilled bytes: 0 B | Read bytes: 63.74 kiB | Write bytes: 58.81 kiB\n", + "
\n", + " \n", + "

Worker: tcp://10.0.0.3:46333

\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + "
\n", + " Comm: tcp://10.0.0.3:46333\n", + " \n", + " Total threads: 1\n", + "
\n", + " Dashboard: http://10.0.0.3:36413/status\n", + " \n", + " Memory: 22.50 GiB\n", + "
\n", + " Nanny: tcp://10.0.0.3:44405\n", + "
\n", + " Local directory: /tmp/dask-scratch-space/worker-pu9uzxbg\n", + "
\n", + " Tasks executing: \n", + " \n", + " Tasks in memory: \n", + "
\n", + " Tasks ready: \n", + " \n", + " Tasks in flight: \n", + "
\n", + " CPU usage: 4.0%\n", + " \n", + " Last seen: Just now\n", + "
\n", + " Memory usage: 218.16 MiB\n", + " \n", + " Spilled bytes: 0 B\n", + "
\n", + " Read bytes: 64.86 kiB\n", + " \n", + " Write bytes: 59.93 kiB\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "\n", + "
\n", + "
\n", + "
\n", + " \n", + "\n", + "
\n", + "
" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "client = Client(\"10.0.0.3:8786\")\n", + "client" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ":::{note}\n", + "Dask-CUDA 只发现并注册这些 GPU,但无法做到 GPU 的隔离,其他非 Dask 的任务仍然可以抢占该 GPU。GPU 资源的隔离应该需要借助 Kubernetes 等容器技术。 \n", + ":::\n", + "\n", + "## GPU 任务\n", + "\n", + "并不是所有任务都能被 GPU 加速,GPU 主要加速一些计算密集型任务,比如机器学习和深度学习等。目前,Dask 支持的 GPU 上框架包括:\n", + "\n", + "* [CuPy](https://cupy.dev/) 的 Dask 分布式版本\n", + "* [Dask-cuDF](https://docs.rapids.ai/api/dask-cudf/stable/) 将 GPU 加速版的 cuDF 横向扩展到 GPU 集群上\n", + "\n", + ":::{note}\n", + "使用 NVIDIA 的 GPU 的各类库,应该将 CUDA 目录添加到 `PATH` 和 `LD_LIBRARY_PATH` 环境变量中,CuPy 和 cuDF 需要依赖所安装的动态链接库。\n", + ":::\n", + "\n", + "### 案例:奇异值分解\n", + "\n", + "下面的代码在 GPU 上进行奇异值分解,是一种适合 GPU 加速的任务。设置 `dask.config.set({\"array.backend\": \"cupy\"})` 即可将 Dask Array 的执行后端改为 GPU 上的 CuPy。\n", + "\n", + "```python\n", + "import cupy\n", + "import dask\n", + "import dask.array as da\n", + "\n", + "dask.config.set({\"array.backend\": \"cupy\"})\n", + "rs = dask.array.random.RandomState(RandomState=cupy.random.RandomState)\n", + "x = rs.random((10000, 1000), chunks=(1000, 1000))\n", + "u, s, v = dask.array.linalg.svd(x)\n", + "```" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/ch-dask/task-graph-partitioning.ipynb b/ch-dask/task-graph-partitioning.ipynb index e37b3bd..63b988e 100644 --- a/ch-dask/task-graph-partitioning.ipynb +++ b/ch-dask/task-graph-partitioning.ipynb @@ -71,10 +71,18 @@ "cell_type": "markdown", 
"metadata": {}, "source": [ - "在这个例子中,Dask Task Graph 一共有 3 个 Task,实现了数据输入和加法计算。Dask 的设计思想是将复杂的并行计算切分成 Task,每个 Task 是一个 Python 函数。`.visualize()` 可视化的 Task Graph 中,圆圈是函数,方框是数据占位符。\n", - "Dask Scheduler 会生成 Task Graph,并将 Task Graph 中的各个具体的计算任务分发到 Dask Worker 上。\n", + "在这个例子中,Dask Task Graph 一共有 3 个 Task,实现了数据输入和加法计算。Dask 的设计思想是将复杂的并行计算切分成 Task,每个 Task 是一个 Python 函数。`.visualize()` 可视化的 Task Graph 中,圆圈是函数,方框是数据占位符。Dask Scheduler 会生成 Task Graph,并将 Task Graph 中的各个计算任务分发到 Dask Worker 上。\n", "\n", - "Dask 并没有创造新的计算引擎,而是通过 Task Graph 的方式将多个 Task 组织起来。Dask 所提供的各类复杂的功能都是基于这种思想实现的。\n", + "值得注意的是,在将多个 Task 组合在一起时,我们使用了 `dask.delayed`。`dask.delayed` 是一个偏底层的接口,允许用户手动构建计算图。如果用户需要自定义某些任务,可以把 `dask.delayed` 当作装饰器使用,比如这样:\n", + "\n", + "```python\n", + "@dask.delayed\n", + "def f(x):\n", + "    x = x + 1\n", + "    return x\n", + "```\n", + "\n", + "如果不使用装饰器,而是对一个普通的 Python 函数 `g` 直接调用 `dask.delayed`,应该使用 `dask.delayed(g)(x, y)`,而非 `dask.delayed(g(x, y))`,因为 `dask.delayed` 修饰的是 Python 函数,而不是函数的输出结果。经过 `dask.delayed` 修饰的 Python 函数将构成 Task Graph 的一个节点。Dask 并没有创造新的计算引擎,而是通过 Task Graph 的方式将多个 Task 组织起来。Dask 所提供的各类复杂的功能都是基于此实现的。\n", "\n", "## 数据切分\n", "\n", diff --git a/ch-data-science/deep-learning.md b/ch-data-science/deep-learning.md new file mode 100644 index 0000000..149416a --- /dev/null +++ b/ch-data-science/deep-learning.md @@ -0,0 +1,109 @@ +(sec-deep-learning-intro)= +# 深度学习 + +## 深度神经网络 + +深度学习是深度神经网络的简称。简单来说,神经网络由很多个下面这样的公式组成,而深度神经网络由很多个神经网络层堆叠而成。 + +$$ +\begin{aligned} +\boldsymbol{z} &= \boldsymbol{W} \cdot \boldsymbol{x} + \boldsymbol{b} \\ +\boldsymbol{a} &= f(\boldsymbol{z}) +\end{aligned} +$$ + +$\boldsymbol{x}$ 是输入,$\boldsymbol{W}$ 是神经网络的参数,又被称为权重。神经网络的学习过程,就是不断更新参数 $\boldsymbol{W}$ 的过程,这也是所谓的训练过程。训练好的模型可以用来推理,即预测未知数据。 + +$f$ 是激活函数(Activation Function)。$\boldsymbol{W}$ 与 $\boldsymbol{x}$ 相乘仅是一个线性变换,就算很多个乘法叠加起来,仍然是线性变换,或者说没有激活函数的多层网络就退化成了一个单层线性模型。激活函数在神经网络中引入了非线性因素,使得多层神经网络理论上可以拟合任何输入数据到输出数据的模式。从模拟生物神经元的角度,激活函数是为了让有些神经元被激活,有些神经元被抑制。常见的激活函数有 Sigmoid 和 ReLU。{numref}`sec-machine-learning-intro` 中我们对 Sigmoid 进行过可视化,ReLU 的公式为:$f(x) 
= \max (0, x)$。 + +## 前向传播 + +{numref}`fig-forward-pass` 是一种最简单的神经网络:将 $\boldsymbol{z^{[n]}} = \boldsymbol{W^{[n]}} \cdot \boldsymbol{a^{[n-1]}} + \boldsymbol{b^{[n]}}$ 和 $\boldsymbol{a^{[n]}} = f(\boldsymbol{z^{[n]}})$ 堆叠,前一层的输出 $\boldsymbol{a^{[n-1]}}$ 作为下一层的输入。这种网络又被称为前馈神经网络(Feedforward Neural Network),或者多层感知机(Multilayer Perceptron,MLP)。多层网络中,为了区分某一层,用方括号上标来表示,比如 $\boldsymbol{a^{[1]}}$ 是第一层的输出,$\boldsymbol{W^{[1]}}$ 是第一层的参数。 + +```{figure} ../img/ch-data-science/forward-pass.svg +--- +width: 800px +name: fig-forward-pass +--- +神经网络的前向传播 +``` + +{numref}`fig-forward-pass` 是神经网络前向传播的过程:假设输入 $\boldsymbol{x}$ 是一个 3 维的向量;{numref}`fig-forward-pass` 中的每个圆圈为向量的一个元素(一个标量值),图中同时也演示了第一层的 $\boldsymbol{a^{[1]}}$ 的向量化计算方式,以及 $z^{[1]}_1$ 的标量化计算方式。实际场景中往往需要使用现代处理器的向量化引擎完成计算。 + +## 反向传播 + +神经网络的训练过程就是不断更新各层的 $\boldsymbol{W}$ 和 $\boldsymbol{b}$。 + +首先以某种随机初始化方式,初始化各层的 $\boldsymbol{W}$ 和 $\boldsymbol{b}$。比如,初始化为服从正态分布的较小随机数。 + +然后确定一个损失函数(Loss Function)$L$。损失函数计算了神经网络预测值 $\hat{y}$ 与真实值 $y$ 之间的差距,训练的目标就是让损失函数变小。比如,在预测房价的案例中,我们使用误差的平方(Squared Error)作为损失函数,某一个样本的损失函数为 $L = (y - \hat{y})^2$。 + +然后计算损失函数对每层参数的导数。$L$ 关于第 $l$ 层 $\boldsymbol{W^{[l]}}$ 和 $\boldsymbol{b^{[l]}}$ 的导数为 $\frac{\partial L}{\partial \boldsymbol{W^{[l]}}}$ 和 $\frac{\partial L}{\partial \boldsymbol{b^{[l]}}}$,再按照下面的公式更新 $\boldsymbol{W^{[l]}}$ 和 $\boldsymbol{b^{[l]}}$。 + +$$ +\begin{aligned} +\boldsymbol{W^{[l]}} &= \boldsymbol{W^{[l]}}-\alpha\frac{\partial L}{\partial \boldsymbol{W^{[l]}}}\\ +\boldsymbol{b^{[l]}} &= \boldsymbol{b^{[l]}}-\alpha\frac{\partial L}{\partial \boldsymbol{b^{[l]}}}\\ +\end{aligned} +$$ + +公式中,$\alpha$ 是学习率,即参数更新的速度。如果学习率太大则容易振荡,不容易收敛;太小则收敛速度又会过慢。 + +各层的导数又被称为梯度。参数沿着梯度的反方向更新,使损失函数下降,这种方法又被称为梯度下降法。计算各层的导数时,往往是从最后的损失函数开始,向前一层一层地求梯度,即先求最后第 $n$ 层的梯度,再结合链式法则,求第 $n-1$ 层的梯度。{numref}`fig-back-propagation` 展示了神经网络的反向传播过程。 + +```{figure} ../img/ch-data-science/back-propagation.svg +--- +width: 800px +name: fig-back-propagation +--- +神经网络的反向传播 +``` + +## 超参数 + 
+神经网络训练过程中,有一些参数需要在训练模型之前人为设定,这些参数不能通过模型的反向传播算法来自动学习,而需要手动选择和调整。这些参数又被称为超参数(Hyperparameter),超参数的选择通常基于经验或反复试验。以下是一些超参数: + +* 学习率:即刚才提到的 $\alpha$,控制着每次更新参数的步长。 +* 网络结构:模型的层数、每层的神经元数量、激活函数的选择等。不同的网络结构对于不同的任务可能有不同的性能表现。 + +## 实现细节 + +神经网络训练实现起来要关注以下三个步骤: + +* 一次前向传播 +* 一次反向传播 +* 一次更新模型权重 + +{numref}`fig-model-training-input-output` 整理了神经网络的第 i 层进行训练时,以上三个步骤的输入和输出。 + +```{figure} ../img/ch-data-science/model-training-input-output.svg +--- +width: 800px +name: fig-model-training-input-output +--- +前向传播、反向传播和更新模型权重的输入和输出 +``` + +对于前向传播,输入有两部分:第 i-1 层的输出 $\boldsymbol{a^{[i-1]}}$ 和第 i 层的模型权重 $\boldsymbol{W^{[i]}}$、$\boldsymbol{b^{[i]}}$;输出又被称为激活(Activation)。 + +对于反向传播,输入有三部分:第 i 层的输出 $\boldsymbol{a^{[i]}}$;第 i 层的模型权重 $\boldsymbol{W^{[i]}}$、$\boldsymbol{b^{[i]}}$;损失对第 i 层输出的导数 $\boldsymbol{\frac{\partial L}{\partial a^{[i]}}}$。根据链式法则,可以求出损失对第 i 层模型权重的导数 $\boldsymbol{\frac{\partial L}{\partial W^{[i]} +}}$、$\boldsymbol{\frac{\partial L}{\partial b^{[i]} +}}$,也就是梯度。 + +得到梯度后,需要沿着梯度下降的方向更新模型权重。如果是最简单的梯度下降法,优化器直接在模型原有权重基础上做减法,不需要额外保存状态,比如:$\boldsymbol{W^{[i]}} = \boldsymbol{W^{[i]}}-\alpha\frac{\partial L}{\partial \boldsymbol{W^{[i]}}}$。 + +复杂一点的优化器,比如 Adam,在梯度下降时引入了动量的概念。动量是梯度的指数移动平均,需要维护一个梯度的移动平均矩阵,这个矩阵就是优化器的状态。因此,优化器状态、原来的模型权重和梯度共同作为输入,可以得到更新后的模型权重。至此才能完成一轮模型的训练。 + +如果只考虑前向传播和反向传播,对于一个神经网络,其训练过程如 {numref}`fig-model-training` 所示。{numref}`fig-model-training` 演示了 3 层神经网络,前向过程用 FWD 表示,反向过程用 BWD 表示。 + +```{figure} ../img/ch-data-science/model-training.svg +--- +width: 800px +name: fig-model-training +--- +前向传播(图中用 FWD 表示)和反向传播(图中用 BWD 表示) +``` + +## 推理 + +模型训练需要前向传播、反向传播和权重更新,模型推理只需要前向传播,只不过输入层换成了需要预测的 $\boldsymbol{x}$。 \ No newline at end of file diff --git a/ch-data-science/hyperparameter.md b/ch-data-science/hyperparameter.md new file mode 100644 index 0000000..9b9fb33 --- /dev/null +++ b/ch-data-science/hyperparameter.md @@ -0,0 +1,121 @@ +(sec-hyperparameter-optimization)= +# 超参数调优 + +{numref}`sec-deep-learning-intro` 
中我们提到了模型的参数和超参数的概念。超参数指的是模型参数(权重)之外的一些参数,比如深度学习模型训练时控制梯度下降速度的学习率,又比如决策树中分支的数量。超参数通常有两类: + +* 模型:神经网络的设计,比如多少层,卷积神经网络的核大小,决策树的分支数量等。 +* 训练和算法:学习率、批量大小等。 + + +## 搜索算法 + +确定这些超参数的方式是开启多个试验(Trial),每个试验测试超参数的某个值,根据模型训练结果的好坏来做选择,这个过程称为超参数调优。这个过程可以手动进行,但手动调优费时费力、效率低下,所以业界提出了一些自动化的方法。常见的自动化搜索方法有如下几种。{numref}`fig-tune-algorithms` 展示了在二维搜索空间中进行超参数搜索的过程:每个点表示一种超参数组合,颜色越暖,表示性能越好。迭代式的算法从初始点开始,后续试验依赖之前试验的结果,最后向性能较好的方向收敛。 + +* 网格搜索(Grid Search):网格搜索是一种穷举搜索方法,它通过遍历所有可能的超参数组合来寻找最优解,这些组合会逐一被用来训练和评估模型。网格搜索简单直观,但当超参数空间很大时,所需的计算成本会急剧增加。 +* 随机搜索(Random Search):随机搜索不是遍历所有可能的组合,而是在解空间中随机选择超参数组合进行评估。这种方法的效率通常高于网格搜索,因为它不需要评估所有可能的组合,而是通过随机抽样来探索参数空间。随机搜索尤其适用于超参数空间非常大或维度很高的情况,它可以在较少的尝试中发现性能良好的超参数配置。然而,由于随机性的存在,随机搜索可能会错过一些局部最优解,因此可能需要更多的尝试次数来确保找到一个好的解。 +* 贝叶斯优化(Bayesian Optimization):贝叶斯优化是一种**迭代式**超参数搜索技术。它基于贝叶斯定理,利用概率模型来指导搜索最优超参数的过程。这种方法的核心思想是构建一个贝叶斯模型,通常是高斯过程(Gaussian Process),来近似评估目标函数的未知部分。贝叶斯优化能够在有限的评估次数内,智能地选择最有希望的超参数组合进行尝试,特别适用于计算成本高昂的场景。 + +超参数调优是一种黑盒优化。所谓黑盒优化,指的是目标函数是一个黑盒,我们只能通过观察其输入和输出来推断其行为。黑盒的概念比较难以理解,我们可以将其与梯度下降算法对比:梯度下降算法**不是**一种黑盒优化算法,因为我们可以得到目标函数的梯度(或近似值),并用梯度来指导搜索方向,最终找到目标函数的(局部)最优解。黑盒优化算法一般无法找到目标函数的数学表达式和梯度,也无法使用基于梯度的优化技术。贝叶斯优化、遗传算法、模拟退火等都是黑盒优化,这些算法通常在超参数搜索空间中选择一些候选解,运行目标函数,得到超参数组合的实际性能,再基于实际性能不断迭代调整,即重复上述过程,直到满足条件。 + +```{figure} ../img/ch-data-science/tune-algorithms.svg +--- +width: 800px +name: fig-tune-algorithms +--- +在一个二维搜索空间中进行超参数搜索,每个点表示一种超参数组合,暖色表示性能较好,冷色表示性能较差。 +``` + +### 贝叶斯优化 + +贝叶斯优化基于贝叶斯定理,这里不深入探讨详细的数学公式。简单来说,它需要先掌握搜索空间中几个观测样本点(Observation)的实际性能,构建概率模型,描述每个超参数在每个取值点上模型性能指标的**均值**和**方差**。其中,均值代表这个点最终的期望效果,均值越大表示模型最终性能指标越大;方差表示这个点的不确定性,方差越大表示这个点越不确定,越值得去探索。{numref}`fig-bayesian-optimization-explained` 展示了在一个 1 维超参数搜索空间中迭代 3 步的过程,虚线是目标函数的真实值,实线是预测值(或者叫后验概率分布均值),实线上下的蓝色区域为置信区间。贝叶斯优化利用了高斯过程回归,即目标函数是由一系列观测样本点所构成的随机过程,通过高斯概率模型来描述这个随机过程的概率分布。贝叶斯优化通过不断地收集观测样本点来更新目标函数的后验分布,直到后验分布基本贴合真实分布。对应 {numref}`fig-bayesian-optimization-explained` 中,进行迭代 3 之前只有两个观测样本点,经过迭代 3 和迭代 4 之后增加了新的观测样本点,这几个样本点附近的预测值逐渐接近真实值。 + +贝叶斯优化有两个核心概念: + +* 代理模型(Surrogate 
Model):代理模型拟合观测值,预测实际性能,可以理解为图中的实线。 +* 采集函数(Acquisition Function):采集函数用于选择下一个采样点,它使用一些方法,衡量每一个点对目标函数优化的贡献,可以理解为图中橘黄色的线。 + +为防止陷入局部最优,采集函数在选取下一个取值点时,应该既利用(Exploit)那些均值较大的点,又探索(Explore)那些方差较大的点,即在利用和探索之间寻找一个平衡。例如,如果模型训练非常耗时,有限的计算资源只能再跑 1 组超参数了,那应该选择均值较大的点,因为这样选到最优结果的可能性最高;如果我们的计算资源还可以跑上千次,那应该多探索不同的可能性。在 {numref}`fig-bayesian-optimization-explained` 的例子中,第 3 次迭代和第 2 次迭代都在第 2 次迭代的观测值附近选择新的点,是在探索和利用之间的一个平衡。 + +```{figure} ../img/ch-data-science/bayesian-optimization-explained.svg +--- +width: 600px +name: fig-bayesian-optimization-explained +--- +使用贝叶斯优化进行过一些迭代后,如何选择下一个点。 +``` + +相比网格搜索和随机搜索,贝叶斯优化并不容易并行化,因为贝叶斯优化需要先运行一些超参数组合,掌握一些实际观测数据。 + +## 调度器 + +### 连续减半算法 + +连续减半算法(Successive Halving Algorithm, SHA){cite}`karnin2013Almost` 的核心思想非常简单,如 {numref}`fig-successive-halving` 所示: + +1. SHA 最开始给每个超参数组合分配一些计算资源额度。 +2. 将这些超参数组合都训练执行完后,对结果进行评估。 +3. 选出排序靠前的超参数组合,进行下一轮(Rung)训练,性能较差的超参数组合早停。 +4. 下一轮每个超参数组合的计算资源额度以一定的策略增加。 + +```{figure} ../img/ch-data-science/successive-halving.svg +--- +width: 600px +name: fig-successive-halving +--- +SHA 算法示意图:优化某指标最小值 +``` + +计算资源额度(Budget)可以是训练的迭代次数,或训练样本数量等。更精确地,SHA 每轮丢掉 $\frac{\eta - 1}{\eta}$ 的超参数组合,留下 $\frac{1}{\eta}$ 进入下一轮,下一轮每个超参数组合的计算资源额度变为原来的 $\eta$ 倍。{numref}`tab-sha-resources` 中,每轮总的计算资源为 $B$,总共 81 个超参数组合;第一轮每个试验能分到 $\frac{B}{81}$ 的计算资源;假设 $\eta$ 为 3,只有 $\frac{1}{3}$ 的试验会被提升到下一轮,经过 5 轮后,某个最优超参数组合会被选拔出来。 + +```{table} 使用 SHA 算法,每个试验所能分配到的计算资源。 +:name: tab-sha-resources +| | 超参数组合数量 $n$ | 每个试验被分配的计算资源 $\frac{B}{n}$ | +|:------: |:---: |:-----: | +| Rung 1 | 81 | $\frac{B}{81}$ | +| Rung 2 | 27 | $\frac{B}{27}$ | +| Rung 3 | 9 | $\frac{B}{9}$ | +| Rung 4 | 3 | $\frac{B}{3}$ | +| Rung 5 | 1 | $B$ | +``` + +SHA 中,需要等待同一轮所有超参数组合训练完并评估结果后,才能进入下一轮;第一轮时,可以并行地执行多个试验,而进入到后几轮,试验越来越少,并行度越来越低。ASHA(Asynchronous Successive Halving Algorithm)针对 SHA 进行了优化:ASHA 算法不需要等某一轮的训练和评估全部结束后再选出下一轮入选者,而是在当前轮进行中的同时,选出可以提升到下一轮的超参数组合,前一轮的训练评估与下一轮的训练评估同时进行。 + +SHA 和 ASHA 
的一个主要假设是,如果一个试验在初始时间表现良好,那么它在更长的时间内也会表现良好。这个假设显然太过粗糙,一个反例是学习率:较大的学习率在短期内可能会比较小的学习率表现得更好,但长远来看,较大学习率不一定是最优的,SHA 调度器很有可能导致较小学习率的试验被错误地提前终止。从另外一个角度,为了避免潜在的优质试验提前结束,需要在第一轮时给每个试验更多的计算资源,但由于总的计算资源额度有限($B$),所以一种折中方式是选择较少的超参数组合,即 $n$ 的数量要少一些。 + +### Hyperband + +SHA/ASHA 等算法面临着 $n$ 和 $\frac{B}{n}$ 相互平衡的问题:如果 $n$ 太大,每个试验所能分到的资源有限,导致优质试验可能提前结束;如果 $n$ 太小,可选择的搜索空间有限,也可能导致优质试验未被囊括到搜索空间中。Hyperband 算法在 SHA 基础上提出了一种对冲机制。Hyperband 有点像金融投资组合,使用多种金融资产来对冲风险,初始轮不是一个固定的 $n$,而是有多个可能的 $n$。如 {numref}`fig-hyperband-algo` 所示,算法实现上,Hyperband 使用了两层循环,内层循环直接调用 SHA 算法,外层循环尝试不同的 $n$,每种可能性对应一个 $s$。Hyperband 额外引入了变量 $R$,$R$ 指的是某一个超参数组合所能分配的最大的计算资源额度;$s_{max}$ 决定了一共有多少种可能性,它可以被计算出来:$s_{max} = \lfloor \log_{\eta}{R} \rfloor$;由于额外引入了 $R$,此时总的计算资源 $B = (s_{max} + 1)R$,加一是因为 $s$ 从 0 开始计算。 + +```{figure} ../img/ch-data-science/hyperband-algo.png +--- +width: 600px +name: fig-hyperband-algo +--- +Hyperband 算法 +``` + +{numref}`fig-hyperband-example` 是一个例子:横轴是外层循环,共有 5 个(0 到 4)可能性,初始的超参数组合数量 $n$ 和每个超参数组合所能获得的计算资源 $r$ 形成一个组合(Bracket);纵轴是内层循环,对于某一种初始的 Bracket,执行 SHA 算法,一直迭代到选出最优试验。 + +```{figure} ../img/ch-data-science/hyperband-example.svg +--- +width: 600px +name: fig-hyperband-example +--- +Hyperband 示意图 +``` + +### BOHB + +BOHB {cite}`falkner2018BOHB` 是一种结合了贝叶斯优化和 Hyperband 的调度器。 + +## Population Based Training + +种群训练(Population Based Training,PBT){cite}`jaderberg2017Population` 主要针对深度神经网络训练,它借鉴了遗传算法的思想,可以同时优化模型参数和超参数。PBT 中,种群可以简单理解成不同的试验,PBT 并行地启动多个试验,每个试验从超参数搜索空间中随机选择一个超参数组合,并随机初始化参数矩阵,训练过程中会定期地评估模型指标。模型训练过程中,基于模型性能指标,PBT 会**利用**或**探索**当前试验的模型参数或超参数。如果当前试验的指标不理想,PBT 会执行“利用”,将当前模型权重换成种群中其他表现较好的试验的模型权重。PBT 也会“探索”:变异生成新的超参数进行接下来的训练。在一次完整的训练过程中,其他超参数调优方法会选择一种超参数组合完成整个训练;PBT 在训练过程中借鉴效果更好的模型权重,或使用新的超参数,因此它被认为同时优化了模型参数和超参数。 + +```{figure} ../img/ch-data-science/population-based-training.svg +--- +width: 600px +name: fig-population-based-training +--- +PBT 训练中的利用和探索。利用指模型表现不理想时,将当前模型权重换成其他表现较好的试验的模型权重;探索指变异生成新的超参数。 +``` \ No newline at end of file diff --git a/ch-data-science/machine-learning.ipynb b/ch-data-science/machine-learning.ipynb 
index 2b868ab..457f0e7 100644 --- a/ch-data-science/machine-learning.ipynb +++ b/ch-data-science/machine-learning.ipynb @@ -17,46 +17,85 @@ "\n", "给定数据集 $D = \\lbrace(\\boldsymbol{x}_{1}, y_{1}), (\\boldsymbol{x}_{2}, y_{2}), ... , (\\boldsymbol{x}_{m}, y_{m}) \\rbrace$ ,数据集中有 $m$ 个数据对。第 $i$ 条数据为 $(\\boldsymbol{x}_{i}, y_{i})$ ,这条数据被称为一组训练样本(Training Example)。在 {numref}`sec-data-science-lifecycle` 的房价例子中,$\\boldsymbol{x_{i}}$ 是一个向量,向量中的每个元素是数据科学家构建的特征,比如街区收入、房屋年龄、房间数、卧室数、街区人口等。我们可以基于这些数据,使用某种机器学习模型对其进行建模,学习到数据中的规律,得到一个模型,其中某个给定的数据集 $D$ 为样本(Sample),又被称为训练集(Training Set),$\\boldsymbol{x}$ 为特征(Feature),$y$ 为真实值(Label)或者目标值(Target)。\n", "\n", - "当前,性能较好的机器学习算法主要有以深度学习为代表的神经网络算法和以梯度提升决策树为代表的决策树算法。下面主要介绍深度学习算法的基础。\n", + "## 线性回归\n", "\n", - "## 神经网络\n", + "### 一元线性回归\n", "\n", - "简单来说,神经网络是由很多个下面的公式组成。\n", + "我们从线性回归开始,了解机器学习模型的数学原理。中学时,我们使用 $y = ax + b$ 对很多问题进行建模,该方程描述了变量 $y$ 如何随变量 $x$ 变化,其图像是一条直线。如果建立好这样的数学模型,已知 $x$ 我们就可以得到预测的 $\\hat{y}$ 了。统计学家给变量 $y$ 带上了一个小帽子,表示这是预测值,以区别于真实观测到的数据。方程只有一个自变量 $x$,且不含平方立方等非一次项,因此被称为**一元线性方程**。\n", + "\n", + "在对数据集进行建模时,我们只关注房屋面积和房价两个维度的数据。我们可以对参数 $a$ 和 $b$ 取不同值来构建不同的直线,这样就形成了一个参数家族。参数家族中有一个最佳组合,可以在统计上以最优的方式描述数据集。那么一元线性回归的监督学习过程就可以被定义为:给定 $m$ 个数据对 $(x, y)$ ,寻找最佳参数 $a^*$ 和 $b^*$,使模型可以更好地拟合这些数据。$a$ 和 $b$ 可以取不同的值,到底哪个参数组合是最佳的呢?如何衡量模型是否以最优的方式拟合数据呢?机器学习用损失函数(Loss Function)来衡量这个问题。损失函数又称为代价函数(Cost Function),它计算了模型预测值 $\\hat{y}$ 和真实值 $y$ 之间的差异程度。从名字也可以看出,这个函数计算的是模型犯错的损失或代价,损失函数的值越大,模型越差,越不能拟合数据。统计学家通常使用 $L(\\hat{y}, y)$ 来表示损失函数。\n", + "\n", + "对于线性回归,一个简单实用的损失函数为预测值与真实值误差平方的平均值,下面公式中,下标 $i$ 表示数据集中的第 $i$ 个样本点:\n", "\n", "$$\n", - "\\begin{aligned}\n", - "\\boldsymbol{z} &= \\boldsymbol{W} \\cdot \\boldsymbol{x} + \\boldsymbol{b} \\\\\n", - "\\boldsymbol{a} &= f(\\boldsymbol{z})\n", - "\\end{aligned}\n", + "L(\\hat{y}, y) = \\frac{1}{m} \\sum_{i=1}^m(\\hat{y}_{i}- y_{i})^2\n", "$$\n", "\n", - "$\\boldsymbol{x}$ 是输入,$\\boldsymbol{W}$ 是神经网络的参数(Parameter),又被称为权重(Weight)。神经网络的学习的过程,就是不断更新参数 $\\boldsymbol{W}$ 的过程,这也是所谓的训练过程。训练好的模型可以用来推理,用来预测未知数据。\n", 
+ "在其基础上代入公式 $\\hat{y}=ax + b$,得到:\n", "\n", - "$f$ 是激活函数(Activation Function)。$\\boldsymbol{W}$ 与 $\\boldsymbol{x}$ 相乘仅是一个线性变换,就算很多个乘法叠加起来,仍然是线性变换,或者说没有激活函数的多层网络就退化成了一个单层线性模型。激活函数可以在神经网络中引入了非线性因素,使得多层神经网络理论上可以拟合任何输入数据到输出数据的模式。从模拟生物的神经元的角度,激活函数是为了让有些神经元被激活,有些神经元被抑制。常见的激活函数有 Sigmoid 和 ReLU。\n", + "$$\n", + "L(\\hat{y}, y) =\\frac{1}{m} \\sum_{i=1}^m[(ax_{i} + b) - y_{i}]^2\n", + "$$\n", "\n", - "Sigmoid 的公式为:\n", + "对于给定数据集,$x$ 和 $y$ 的值是已知的,参数 $a$ 和 $b$ 是要求解的,模型求解的过程就是解下面公式的过程:\n", "\n", "$$\n", - "f(x) = \\frac{1}{1+\\exp{(-x)}}\n", + "a^*, b^* = \\mathop{\\arg\\min}_{a, b}L(a, b)\n", "$$\n", "\n", - "ReLU 的公式为:\n", + "式中 $\\arg\\min$ 是一种常见的数学符号,表示寻找能让 $L$ 函数最小的参数 $a^*$ 和 $b^*$。\n", "\n", "$$\n", - "f(x) = \\max (0, x)\n", + "a^*, b^* = \\mathop{\\arg\\min}_{a, b}\\frac{1}{m}\\sum_{i=1}^m[(ax_{i} + b) - y_{i}]^2\n", "$$\n", "\n", - "将这两个函数可视化,效果如下:" + "\n", + "求解这个问题一般有两个方法:\n", + "\n", + "* 基于微积分和线性代数知识,求使得 $L$ 导数为 0 的点,这个点一般为最优点。这种方式只能解那些简单的模型。\n", + "* 基于梯度下降,迭代地搜索最优点。梯度下降法能解很多复杂的模型,比如深度学习模型,{numref}`sec-deep-learning-intro` 进一步解释了梯度下降法。\n", + "\n", + "### 线性回归的一般形式\n", + "\n", + "我们现在把回归问题扩展到更为一般的场景。假设 $\\boldsymbol{x}$ 是多元的,或者说是多维的。比如,要预测房价,需要考虑的因素很多,包括学区、卧室数量(两居、三居、四居)、周边商业、交通等。如下面公式所示,每个因素 $x_i$ 对应一个权重 $W_i$:\n", + "\n", + "$$\n", + "f(\\boldsymbol{x}) = b + W_1 \\times x_1 + W_2 \\times x_2 + ... 
+ W_n \\times x_n\n", + "$$\n", + "\n", + "这里的 $\\boldsymbol{w}$ 是 **参数**(Parameter),也可以叫做 **权重**(Weight)。这里共有 $n$ 种维度的影响因素,机器学习领域将这 $n$ 种影响因素称为 **特征**(Feature)。用向量表示为:\n", + "\n", + "$$\n", + "f(\\boldsymbol{x}) = b + \\boldsymbol{W} \\boldsymbol{{x}}\n", + "$$\n", + "\n", + "要预测的 $y$ 是实数,从负无穷到正无穷,预测实数模型被称为**回归模型**。\n", + "\n", + "## 逻辑回归\n", + "\n", + "回归问题是指目标值为整个实数域,分类问题是指目标值为有限的离散值。比如,我们想进行一个情感分类,目标值有 0 和 1 两个选项,表示负向和正向,一个二分类函数可以表示为:\n", + "\n", + "$$\n", + "y = \n", + "\\begin{cases}\n", + " 0 &\\text{if } z < 0 \\\\\n", + " 1 &\\text{if } z \\geq 0\n", + "\\end{cases}\n", + "$$\n", + "\n", + "在线性回归的基础上,在其外层套上一个函数 $g(z)$:\n", + "\n", + "$$\n", + "g(z)= \\frac 1 {1+e^{-z}}\n", + "$$\n", + "\n", + "这个 $g(z)$ 被称为 Sigmoid 函数或 Logistic 函数,下面对 Sigmod 函数进行了可视化。" ] }, { "cell_type": "code", - "execution_count": 34, - "metadata": { - "tags": [ - "hide-input" - ] - }, + "execution_count": 7, + "metadata": {}, "outputs": [ { "data": { @@ -64,12 +103,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2024-02-06T21:51:21.931679\n", + " 2024-05-08T13:54:08.374456\n", " image/svg+xml\n", " \n", " \n", @@ -84,42 +123,42 @@ " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + 
" \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "\n" ], "text/plain": [ - "
" + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "
" ] }, "metadata": {}, @@ -970,7 +867,7 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", - "fig, axs = plt.subplots(1, 2, figsize=(6, 3), sharey=True)\n", + "fig, axs = plt.subplots(1, 1, figsize=(4, 3), sharey=True)\n", "\n", "# 创建 x 轴的数据\n", "x = np.linspace(-4, 4, 200)\n", @@ -978,26 +875,13 @@ "# 创建 sigmoid 函数的 y 轴数据\n", "sigmoid = 1 / (1 + np.exp(-x))\n", "\n", - "# 创建 ReLu 函数的 y 轴数据\n", - "relu = np.maximum(0, x)\n", - "\n", - "# 在第一个子图上画 sigmoid 函数\n", - "axs[0].plot(x, sigmoid)\n", - "axs[0].axhline(0, color='black',linewidth=0.5)\n", - "axs[0].axvline(0, color='black',linewidth=0.5)\n", - "axs[0].grid(True)\n", - "axs[0].set_title('Sigmoid')\n", + "plt.plot(x, sigmoid, label='Sigmoid function')\n", + "plt.title('Sigmoid')\n", + "plt.xlabel('x')\n", + "plt.ylabel('g(x)')\n", + "plt.grid(True, which='both', linestyle='--', linewidth=0.5)\n", "\n", - "# 在第二个子图上画 ReLu 函数\n", - "axs[1].plot(x, relu)\n", - "axs[1].axhline(0, color='black',linewidth=0.5)\n", - "axs[1].axvline(0, color='black',linewidth=0.5)\n", - "axs[1].grid(True)\n", - "axs[1].set_title('ReLU')\n", - "\n", - "plt.yticks([0, 1, 2, 3, 4])\n", "plt.tight_layout()\n", - "# 显示图像\n", "plt.show()" ] }, @@ -1005,99 +889,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 前向传播\n", - "\n", - "\n", - "{numref}`fig-forward-pass` 是一种最简单的神经网络:将 $\\boldsymbol{z^{[n]}} = \\boldsymbol{W^{[n]}} \\cdot \\boldsymbol{a^{[n-1]}} + \\boldsymbol{b^{[n]}}$ 和 $\\boldsymbol{a^{[n]}} = f(\\boldsymbol{z^{[n]}})$ 堆叠,前一层的输出 $\\boldsymbol{a^{[n-1]}}$ 作为下一层的输入。这种网络又被成为前馈神经网络(Feedforward Neural Network),或者多层感知机(Multilayer Perceptron,MLP)。多层网络中,为了区分某一层,用方括号上标来表示,比如 $\\boldsymbol{a^{[1]}}$ 是第一层的输出,$\\boldsymbol{W^{[1]}}$ 是第一层的参数。\n", - "\n", - "```{figure} ../img/ch-data-science/forward-pass.svg\n", - "---\n", - "width: 800px\n", - "name: fig-forward-pass\n", - "---\n", - "神经网络的前向传播\n", - "```\n", - "\n", - "{numref}`fig-forward-pass` 是神经网络前向传播的过程:假设输入 $\\boldsymbol{x}$ 是一个 3 
维的向量;{numref}`fig-forward-pass` 中的每个圆圈为向量的一个元素(一个标量值),图中同时也演示了第一层的 $\\boldsymbol{a^{[1]}}$ 的向量化计算方式,以及 $z^{[1]}_1$ 的标量化计算方式,实际场景中往往需要使用现代处理器的向量化引擎完成计算。\n", - "\n", - "## 反向传播\n", - "\n", - "神经网络的训练过程就是不断更新各层的 $\\boldsymbol{W}$ 和 $\\boldsymbol{b}$。\n", - "\n", - "首先以某种随机初始化方式,初始化各层的 $\\boldsymbol{W}$ 和 $\\boldsymbol{b}$。比如,初始化为正态分布的小数。\n", - "\n", - "然后确定一个损失函数(Loss Function)$L$。损失函数计算了神经网络预测值 $\\hat{y}$ 与真实值 $y$ 之间的差距,训练的目标就是让损失函数变小。比如,预测房价的案例,我们使用误差的平方(Squared Error)作为损失函数,某一个样本的损失函数为 $L = (y - \\hat{y})^2$。\n", - "\n", - "然后计算损失函数对每层参数的导数。$L$ 关于第 $l$ 层 $\\boldsymbol{W^{[l]}}$ 和 $\\boldsymbol{b^{[l]}}$ 的导数为 $\\frac{\\partial L}{\\partial \\boldsymbol{W^{[l]}}}$ 和 $\\frac{\\partial L}{\\partial \\boldsymbol{b^{[l]}}}$,再按照下面的公式更新 $\\boldsymbol{W^{[l]}}$ 和 $\\boldsymbol{b^{[l]}}$。\n", + "Logistic 函数的性质决定了它可以将 $(-\\infty, +\\infty)$ 映射到 $(0, 1)$ 上, Logistic 函数有明确的分界线,在中心点处取值为 0.5,因为 Logistic 函数有明确的分界线,可以用来进行分类。我们将线性回归套入 Logistic 函数,可以得到:\n", "\n", "$$\n", - "\\begin{aligned}\n", - "\\boldsymbol{W^{[l]}} &= \\boldsymbol{W^{[l]}}-\\alpha\\frac{\\partial L}{\\partial \\boldsymbol{W^{[l]}}}\\\\\n", - "\\boldsymbol{b^{[l]}} &= \\boldsymbol{b^{[l]}}-\\alpha\\frac{\\partial L}{\\partial \\boldsymbol{b^{[l]}}}\\\\\n", - "\\end{aligned}\n", + "y = f(\\boldsymbol{x}) = g(\\boldsymbol{W} \\boldsymbol{x}) = \\frac 1{1+e^{-\\boldsymbol{W} \\boldsymbol{x}}}\n", "$$\n", "\n", - "公式中,$\\alpha$ 是学习率,即参数更新的速度,如果学习率太大则容易振荡,不容易收敛,太小则收敛速度又会过慢。\n", - "\n", - "各层的导数又被称为梯度,参数沿着梯度方向下降,又被成为梯度下降法。计算各层的导数时,往往是从最后的损失函数开始,向前一层一层地求梯度,即先求最后第 $n$ 层的梯度,得到第 $n$ 层的梯度,结合链式法则,求第 $n-1$ 层的梯度。{numref}`fig-back-propagation` 展示了神经网络的反向传播过程。\n", - "\n", - "```{figure} ../img/ch-data-science/back-propagation.svg\n", - "---\n", - "width: 800px\n", - "name: fig-back-propagation\n", - "---\n", - "神经网络的反向传播\n", - "```\n", - "\n", - "## 超参数\n", - "\n", - "神经网络训练过程中,有很多训练模型之前需要人为设定的一些参数,这些参数不能通过模型的反向传播算法来自动学习,而需要手动选择和调整。这些参数又被成为超参数,超参数的选择通常基于经验或反复试验。以下是一些超参数:\n", - "\n", - "* 学习率,即刚才提到的 $\\alpha$,控制着每次更新参数的步长。\n", - 
"* 网络结构:模型的层数、每层的神经元数量、激活函数的选择等。不同的网络结构对于不同的任务可能有不同的性能表现。\n", - "\n", - "## 实现细节\n", - "\n", - "神经网络训练实现起来要关注以下三个步骤:\n", - "\n", - "* 一次前向传播\n", - "* 一次反向传播\n", - "* 一次更新模型权重\n", - "\n", - "{numref}`fig-model-training-input-output` 整理了神经网络的第 i 层进行训练时,以上三个步骤的输入和输出。\n", - "\n", - "```{figure} ../img/ch-data-science/model-training-input-output.svg\n", - "---\n", - "width: 800px\n", - "name: fig-model-training-input-output\n", - "---\n", - "前向传播、反向传播和更新模型权重的输入和输出\n", - "```\n", - "\n", - "对于前向传播,输入有两部分:i-1 层输出 $\\boldsymbol{a^{[i-1]}}$ 和第 i 层的模型权重 $\\boldsymbol{W^{[i]}}$、$\\boldsymbol{b^{[i]}}$;输出又被称为激活(Activation)。\n", - "\n", - "对于反向传播,输入有三部分:i 层输出 $\\boldsymbol{a^{[i]}}$;第 i 层的模型权重 $\\boldsymbol{W^{[i]}}$、$\\boldsymbol{b^{[i]}}$;损失对 i 层输出的导数 $\\boldsymbol{\\boldsymbol{\\frac{\\partial L}{a^{[i]}}}}$。根据链式法则,可以求出损失对 i 层模型权重的导数 $\\boldsymbol{\\frac{\\partial L}{\\partial W^{[i]}\n", - "}}$、$\\boldsymbol{\\frac{\\partial L}{\\partial b^{[i]}\n", - "}}$,也就是梯度。\n", - "\n", - "得到梯度后,需要沿着梯度下降的方向更新模型权重。如果是最简单的梯度下降法,优化器直接在模型原有权重基础上做减法,不需要额外保存状态,比如:$\\boldsymbol{W^{[l]}} = \\boldsymbol{W^{[l]}}-\\alpha\\frac{\\partial L}{\\partial \\boldsymbol{W^{[l]}}}$\n", - "\n", - "复杂一点的优化器,比如 Adam, 在梯度下降时引入了动量的概念。动量是梯度的指数移动平均,需要维护一个梯度的移动平均矩阵,这个矩阵就是优化器的状态。因此,优化器状态、原来的模型权重和梯度共同作为输入,可以得到更新后的模型权重。至此才能完成一轮模型的训练。\n", - "\n", - "如果只考虑前向传播和反向传播,对于一个神经网络,其训练过程如 {numref}`fig-model-training` 所示。{numref}`fig-model-training` 演示了 3 层神经网络,前向过程用 FWD 表示,反向过程用 BWD 表示。\n", - "\n", - "```{figure} ../img/ch-data-science/model-training.svg\n", - "---\n", - "width: 800px\n", - "name: fig-model-training\n", - "---\n", - "前向传播(图中用 FWD 表示)和反向传播(图中用 BWD 表示)\n", - "```\n", - "\n", - "## 推理\n", - "\n", - "模型训练就是前向和反向传播,模型推理只需要前向传播,只不过输入层换成了需要预测的 $\\boldsymbol{x}$。" + "这里不再赘述 Logistic 回归的训练求解数学推导,感兴趣的读者可以在互联网上搜索相关知识。" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/ch-ray-train-tune/index.md b/ch-modin-xorbits/index.md 
similarity index 54% rename from ch-ray-train-tune/index.md rename to ch-modin-xorbits/index.md index c266b11..5081073 100644 --- a/ch-ray-train-tune/index.md +++ b/ch-modin-xorbits/index.md @@ -1,4 +1,4 @@ -# Ray Train 和 Tune +# Modin 与 Xorbits ```{tableofcontents} ``` \ No newline at end of file diff --git a/ch-modin-xorbits/modin.ipynb b/ch-modin-xorbits/modin.ipynb new file mode 100644 index 0000000..e69de29 diff --git a/ch-ray-core/remote-class.ipynb b/ch-ray-core/remote-class.ipynb index 938794f..ab318bb 100644 --- a/ch-ray-core/remote-class.ipynb +++ b/ch-ray-core/remote-class.ipynb @@ -12,7 +12,7 @@ "\n", "{numref}`sec-remote-function` 展示了如何将一个无状态的函数扩展到 Ray 集群上进行分布式计算,但实际的场景中,我们经常需要进行有状态的计算。最简单的有状态计算包括维护一个计数器,每遇到某种条件,计数器加一。这类有状态的计算对于给定的输入,不一定得到确定的输出。单机场景我们可以使用 Python 的类(Class)来实现,计数器可作为类的成员变量。Ray 可以将 Python 类拓展到集群上,即远程类(Remote Class),又被称为行动者(Actor)。Actor 的名字来自 Actor 编程模型 {cite}`hewitt1973Universal` ,这是一个典型的分布式计算编程模型,被广泛应用在大数据和人工智能领域,但 Actor 编程模型比较抽象,我们先从计数器的案例来入手。\n", "\n", - "## 案例1:分布式计数器" + "## 案例:分布式计数器" ] }, { @@ -316,7 +316,7 @@ "\n", "Actor 编程模型是消息驱动的,给某个 Actor 发送消息,它就会对该消息进行响应,修改自身的状态或者继续给其他 Actor 发送消息。Actor 编程模型不需要显式地在多个进程之间同步数据,因此也没有锁的问题以及同步等待的时间。Actor 编程模型可被用于大量异步操作的场景。\n", "\n", - "## 案例2:排行榜\n", + "## 案例:排行榜\n", "\n", "接下来我们基于 Actor 实现一个更加复杂的案例:成绩排行榜。这个排行榜的状态是一个键值对,名为 `self.board`,键是名字(`name`),是一个 `str` 类型,值是分数(`score`),是一个 `float` 类型。" ] @@ -600,7 +600,7 @@ "id": "25814b1f", "metadata": {}, "source": [ - "## 案例3:Actor Pool\n", + "## 案例:Actor Pool\n", "\n", "实践上,经常创建一个 Actor 资源池(Actor Pool),[`ActorPool`](https://docs.ray.io/en/latest/ray-core/api/doc/ray.util.ActorPool.html) 有点像 `multiprocessing.Pool`,Actor Pool 中有包含多个 Actor,每个 Actor 功能一样,而且可以分式地在多个计算节点上运行。" ] @@ -625,7 +625,6 @@ "source": [ "from ray.util import ActorPool\n", "\n", - "\n", "@ray.remote\n", "class PoolActor:\n", " def add(self, operands):\n", diff --git a/ch-ray-core/remote-function.ipynb b/ch-ray-core/remote-function.ipynb index 5d175e7..36704da 100644 --- 
a/ch-ray-core/remote-function.ipynb +++ b/ch-ray-core/remote-function.ipynb @@ -152,7 +152,7 @@ "id": "8dd151fc", "metadata": {}, "source": [ - "## 案例1:斐波那契数列\n", + "## 案例:斐波那契数列\n", "\n", "接下来,我们用斐波那契数列的案例来演示如何使用 Ray 对 Python 函数进行分布式的扩展。\n", "\n", @@ -338,7 +338,7 @@ "\n", "原生 Python 函数 `func_name()` 的调用是同步执行的,或者说等待结果返回才进行后续计算,又或者说这个调用是阻塞的。一个 Ray 函数`func_name.remote()` 是异步执行的,或者说调用者不需要等待这个函数的计算真正执行完, Ray 就立即返回了一个 `ray.ObjectRef`,函数的计算是在后台某个计算节点上执行的。`ray.get(ObjectRef)` 会等待后台计算结果执行完,将结果返回给调用者。`ray.get(ObjectRef)` 是一个一个阻塞调用。\n", "\n", - "### 案例2:蒙特卡洛估计 $\\pi$\n", + "### 案例:蒙特卡洛估计 $\\pi$\n", "\n", "接下来我们使用蒙特卡洛方法来估计 $\\pi$。如 {ref}`fig-square-circle`: 我们在一个 $2 \\times 2$ 的正方形中随机撒点,正方形内有一个半径为1的圆。所撒的点以一定概率落在圆内,假定我们已知落在圆内的概率是 $\\frac{\\pi}{4}$,我们可以根据随机撒点的概率情况推算出 $\\pi$ 的值。根据概率论相关知识,撒的点越多,概率越接近真实值。\n", "\n", @@ -555,7 +555,7 @@ "id": "9d4f19c4", "metadata": {}, "source": [ - "## 案例3:分布式图片处理\n", + "## 案例:分布式图片处理\n", "\n", "接下来我们模拟一个更加计算密集的分布式图片预处理的任务。所处理内容均为高清像素图片,大概4MB。这些图片数据预处理工作在当前人工智能场景下非常普遍。接下来的任务主要包括:\n", "\n", diff --git a/ch-ray-core/remote-object.ipynb b/ch-ray-core/remote-object.ipynb index 206c98e..782201c 100644 --- a/ch-ray-core/remote-object.ipynb +++ b/ch-ray-core/remote-object.ipynb @@ -286,9 +286,9 @@ "origin_pos": 8 }, "source": [ - "## Example 1: Transforming Data\n", + "## 案例:对数据进行转换\n", "\n", - "The data of remote objects is immutable. For example, the following operation is common in the local memory but cannot be directly applied to a remote object." 
+ "Remote Object 中的数据是不可修改的(Immutable),即无法对变量原地更改。下面的代码中,在单机上,我们可以对变量 `a` 进行赋值,但无法原地更改 Remote Object 的值。" ] }, { diff --git a/ch-ray-data/data-load-inspect-save.ipynb b/ch-ray-data/data-load-inspect-save.ipynb index 3eaf6b7..3e2018d 100644 --- a/ch-ray-data/data-load-inspect-save.ipynb +++ b/ch-ray-data/data-load-inspect-save.ipynb @@ -85,7 +85,7 @@ "\n", "import sys\n", "sys.path.append(\"..\")\n", - "from datasets import nyc_taxi\n", + "from utils import nyc_taxi\n", "\n", "import ray\n", "\n", diff --git a/ch-ray-data/data-transform.ipynb b/ch-ray-data/data-transform.ipynb index 6bbaf40..6596809 100644 --- a/ch-ray-data/data-transform.ipynb +++ b/ch-ray-data/data-transform.ipynb @@ -108,7 +108,7 @@ "from typing import Any, Dict\n", "\n", "sys.path.append(\"..\")\n", - "from datasets import nyc_taxi\n", + "from utils import nyc_taxi\n", "\n", "import numpy as np\n", "import pandas as pd\n", diff --git a/ch-ray-data/map-map-batches.svg b/ch-ray-data/map-map-batches.svg deleted file mode 100644 index 9292629..0000000 --- a/ch-ray-data/map-map-batches.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - -
[SVG figure text labels: map(), map_batches()]
\ No newline at end of file diff --git a/ch-ray-data/preprocessor.ipynb b/ch-ray-data/preprocessor.ipynb index 3d6fc37..d0ce7cf 100644 --- a/ch-ray-data/preprocessor.ipynb +++ b/ch-ray-data/preprocessor.ipynb @@ -7,256 +7,7 @@ "(sec-ray-data-preprocessor)=\n", "# Preprocessor\n", "\n", - "{numref}`sec-ray-data-transform` 介绍了通用接口 `map()` 和 `map_batches()`。对于结构化的表格类数据,Ray Data 在 `map()` 和 `map_batches()` 基础上,增加了一个高阶的 API:预处理器(Preprocessor)。[Preprocessor](https://docs.ray.io/en/latest/data/api/preprocessor.html) 是一系列特征处理操作,可与机器学习模型训练和推理更好地结合。其使用方式与 scikit-learn 的 [sklearn.preprocessing](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.preprocessing) 非常相似,熟悉 scikit-learn 的用户可以快速迁移过来。对于非结构化数据,比如图片、视频等,仍然建议使用 `map()` 或者 `map_batches()`。\n", - "\n", - "## 使用\n", - "\n", - "Preprocessor 主要有 4 类操作:\n", - "\n", - "1. [`fit()`](https://docs.ray.io/en/latest/data/api/doc/ray.data.preprocessor.Preprocessor.fit.html):计算 Ray Data `Dataset` 状态信息,比如计算某一列数据的方差或者均值。\n", - "2. [`transform()`](https://docs.ray.io/en/latest/data/api/doc/ray.data.preprocessor.Preprocessor.transform.html):执行转换操作。如果这个转换操作是有状态的,那必须先进行 `fit()`。\n", - "3. [`transform_batch()`](https://docs.ray.io/en/latest/data/api/doc/ray.data.preprocessor.Preprocessor.transform_batch.html):对一个批次数据进行转换操作。\n", - "4. 
[`fit_transform()`](https://docs.ray.io/en/latest/data/api/doc/ray.data.preprocessor.Preprocessor.fit_transform.html):结合了 `fit()` 和 `transform()` 的一个操作,先对 `Dataset` 进行 `fit()`,再进行 `transform()`。\n", - "\n", - "下面根据出租车数据集,来演示一下如何使用 Preprocessor。出租车数据是一个典型的结构化数据,里面有很多列,比如该旅程的距离,这些列可被用来作为机器学习算法的特征,而喂给机器学习模型前,需要进行特征处理。比如 [`MinMaxScaler`](https://docs.ray.io/en/latest/data/api/doc/ray.data.preprocessors.MinMaxScaler.html) 将特征进行归一化:\n", - "\n", - "$$\n", - "x' = \\frac{x - \\min(x)}{\\max(x) - \\min(x)}\n", - "$$" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "tags": [ - "hide-cell" - ] - }, - "outputs": [ - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[3], line 13\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mray\u001b[39;00m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ray\u001b[38;5;241m.\u001b[39mis_initialized:\n\u001b[0;32m---> 13\u001b[0m \u001b[43mray\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshutdown\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 15\u001b[0m ray\u001b[38;5;241m.\u001b[39minit()\n", - "File \u001b[0;32m~/miniconda3/envs/dispy/lib/python3.11/site-packages/ray/_private/client_mode_hook.py:103\u001b[0m, in \u001b[0;36mclient_mode_hook..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m func\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minit\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m is_client_mode_enabled_by_default:\n\u001b[1;32m 102\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28mgetattr\u001b[39m(ray, func\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m)(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 103\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/miniconda3/envs/dispy/lib/python3.11/site-packages/ray/_private/worker.py:1838\u001b[0m, in \u001b[0;36mshutdown\u001b[0;34m(_exiting_interpreter)\u001b[0m\n\u001b[1;32m 1836\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _global_node\u001b[38;5;241m.\u001b[39mis_head():\n\u001b[1;32m 1837\u001b[0m _global_node\u001b[38;5;241m.\u001b[39mdestroy_external_storage()\n\u001b[0;32m-> 1838\u001b[0m \u001b[43m_global_node\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkill_all_processes\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcheck_alive\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mallow_graceful\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 1839\u001b[0m _global_node \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1840\u001b[0m storage\u001b[38;5;241m.\u001b[39m_reset()\n", - "File \u001b[0;32m~/miniconda3/envs/dispy/lib/python3.11/site-packages/ray/_private/node.py:1604\u001b[0m, in \u001b[0;36mNode.kill_all_processes\u001b[0;34m(self, check_alive, allow_graceful, wait)\u001b[0m\n\u001b[1;32m 1598\u001b[0m \u001b[38;5;66;03m# Kill the raylet first. 
This is important for suppressing errors at\u001b[39;00m\n\u001b[1;32m 1599\u001b[0m \u001b[38;5;66;03m# shutdown because we give the raylet a chance to exit gracefully and\u001b[39;00m\n\u001b[1;32m 1600\u001b[0m \u001b[38;5;66;03m# clean up its child worker processes. If we were to kill the plasma\u001b[39;00m\n\u001b[1;32m 1601\u001b[0m \u001b[38;5;66;03m# store (or Redis) first, that could cause the raylet to exit\u001b[39;00m\n\u001b[1;32m 1602\u001b[0m \u001b[38;5;66;03m# ungracefully, leading to more verbose output from the workers.\u001b[39;00m\n\u001b[1;32m 1603\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ray_constants\u001b[38;5;241m.\u001b[39mPROCESS_TYPE_RAYLET \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mall_processes:\n\u001b[0;32m-> 1604\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_kill_process_type\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1605\u001b[0m \u001b[43m \u001b[49m\u001b[43mray_constants\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mPROCESS_TYPE_RAYLET\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1606\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_alive\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcheck_alive\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1607\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_graceful\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mallow_graceful\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1608\u001b[0m \u001b[43m \u001b[49m\u001b[43mwait\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwait\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1609\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1611\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ray_constants\u001b[38;5;241m.\u001b[39mPROCESS_TYPE_GCS_SERVER \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mall_processes:\n\u001b[1;32m 1612\u001b[0m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_kill_process_type(\n\u001b[1;32m 1613\u001b[0m ray_constants\u001b[38;5;241m.\u001b[39mPROCESS_TYPE_GCS_SERVER,\n\u001b[1;32m 1614\u001b[0m check_alive\u001b[38;5;241m=\u001b[39mcheck_alive,\n\u001b[1;32m 1615\u001b[0m allow_graceful\u001b[38;5;241m=\u001b[39mallow_graceful,\n\u001b[1;32m 1616\u001b[0m wait\u001b[38;5;241m=\u001b[39mwait,\n\u001b[1;32m 1617\u001b[0m )\n", - "File \u001b[0;32m~/miniconda3/envs/dispy/lib/python3.11/site-packages/ray/_private/node.py:1425\u001b[0m, in \u001b[0;36mNode._kill_process_type\u001b[0;34m(self, process_type, allow_graceful, check_alive, wait)\u001b[0m\n\u001b[1;32m 1423\u001b[0m \u001b[38;5;66;03m# Ensure thread safety\u001b[39;00m\n\u001b[1;32m 1424\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mremoval_lock:\n\u001b[0;32m-> 1425\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_kill_process_impl\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1426\u001b[0m \u001b[43m \u001b[49m\u001b[43mprocess_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1427\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_graceful\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mallow_graceful\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1428\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_alive\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcheck_alive\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1429\u001b[0m \u001b[43m \u001b[49m\u001b[43mwait\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwait\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1430\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/miniconda3/envs/dispy/lib/python3.11/site-packages/ray/_private/node.py:1481\u001b[0m, in \u001b[0;36mNode._kill_process_impl\u001b[0;34m(self, process_type, allow_graceful, check_alive, wait)\u001b[0m\n\u001b[1;32m 1479\u001b[0m timeout_seconds \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 1480\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1481\u001b[0m \u001b[43mprocess\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout_seconds\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1482\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m subprocess\u001b[38;5;241m.\u001b[39mTimeoutExpired:\n\u001b[1;32m 1483\u001b[0m \u001b[38;5;28;01mpass\u001b[39;00m\n", - "File \u001b[0;32m~/miniconda3/envs/dispy/lib/python3.11/subprocess.py:1264\u001b[0m, in \u001b[0;36mPopen.wait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 1262\u001b[0m endtime \u001b[38;5;241m=\u001b[39m _time() \u001b[38;5;241m+\u001b[39m timeout\n\u001b[1;32m 1263\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1264\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_wait\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1265\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m:\n\u001b[1;32m 1266\u001b[0m \u001b[38;5;66;03m# https://bugs.python.org/issue25942\u001b[39;00m\n\u001b[1;32m 1267\u001b[0m \u001b[38;5;66;03m# The first keyboard interrupt waits briefly for the child to\u001b[39;00m\n\u001b[1;32m 1268\u001b[0m \u001b[38;5;66;03m# exit under the common assumption that it also received the ^C\u001b[39;00m\n\u001b[1;32m 1269\u001b[0m \u001b[38;5;66;03m# generated SIGINT and will exit rapidly.\u001b[39;00m\n\u001b[1;32m 1270\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", - "File \u001b[0;32m~/miniconda3/envs/dispy/lib/python3.11/subprocess.py:2040\u001b[0m, in \u001b[0;36mPopen._wait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 
2038\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m TimeoutExpired(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs, timeout)\n\u001b[1;32m 2039\u001b[0m delay \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmin\u001b[39m(delay \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2\u001b[39m, remaining, \u001b[38;5;241m.05\u001b[39m)\n\u001b[0;32m-> 2040\u001b[0m time\u001b[38;5;241m.\u001b[39msleep(delay)\n\u001b[1;32m 2041\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2042\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreturncode \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], - "source": [ - "import os\n", - "import shutil\n", - "import urllib.request\n", - "from typing import Any, Dict\n", - "\n", - "import sys\n", - "sys.path.append(\"..\")\n", - "from datasets import nyc_taxi\n", - "\n", - "import ray\n", - "\n", - "if ray.is_initialized:\n", - " ray.shutdown()\n", - "\n", - "ray.init()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'nyc_taxi' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[2], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mray\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpreprocessors\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m MinMaxScaler\n\u001b[0;32m----> 3\u001b[0m dataset_path \u001b[38;5;241m=\u001b[39m \u001b[43mnyc_taxi\u001b[49m()\n\u001b[1;32m 4\u001b[0m ds \u001b[38;5;241m=\u001b[39m 
ray\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mread_parquet(dataset_path,\n\u001b[1;32m 5\u001b[0m columns\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrip_distance\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 6\u001b[0m ds\u001b[38;5;241m.\u001b[39mtake(\u001b[38;5;241m1\u001b[39m)\n", - "\u001b[0;31mNameError\u001b[0m: name 'nyc_taxi' is not defined" - ] - } - ], - "source": [ - "from ray.data.preprocessors import MinMaxScaler\n", - "\n", - "dataset_path = nyc_taxi()\n", - "ds = ray.data.read_parquet(dataset_path,\n", - " columns=[\"trip_distance\"])\n", - "ds.take(1)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "经过 `MinMaxScaler` 归一化之后,原来的值变为一个归一化之后的值。" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-12-15 14:17:29,924\tINFO split_read_output_blocks.py:101 -- Using autodetected parallelism=173 for stage ReadParquet to satisfy output blocks of size at least DataContext.get_current().target_min_block_size=1.0MiB.\n", - "2023-12-15 14:17:29,925\tINFO split_read_output_blocks.py:106 -- To satisfy the requested parallelism of 173, each read task output is split into 173 smaller blocks.\n", - "2023-12-15 14:17:29,926\tINFO streaming_executor.py:104 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[ReadParquet] -> AllToAllOperator[Aggregate] -> LimitOperator[limit=1]\n", - "2023-12-15 14:17:29,927\tINFO streaming_executor.py:105 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n", - "2023-12-15 14:17:29,928\tINFO streaming_executor.py:107 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n", - "\n", - "\u001b[A\n", - 
"\u001b[A\n", - "\n", - "\u001b[A\u001b[A\n", - "\n", - "Running: 0.0/8.0 CPU, 0.0/0.0 GPU, 0.0 MiB/512.0 MiB object_store_memory: 0%| | 0/173 [00:00 TaskPoolMapOperator[ReadParquet] -> TaskPoolMapOperator[MapBatches(MinMaxScaler._transform_pandas)] -> LimitOperator[limit=1]\n", - "2023-12-15 14:17:31,198\tINFO streaming_executor.py:105 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n", - "2023-12-15 14:17:31,200\tINFO streaming_executor.py:107 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n", - "\u001b[36m(ReadParquet->SplitBlocks(173) pid=6869)\u001b[0m /Users/luweizheng/anaconda3/envs/dispy/lib/python3.11/site-packages/ray/data/_internal/arrow_block.py:128: FutureWarning: promote has been superseded by mode='default'.\n", - "\u001b[36m(ReadParquet->SplitBlocks(173) pid=6869)\u001b[0m return transform_pyarrow.concat(tables) \n", - " \r" - ] - }, - { - "data": { - "text/plain": [ - "[{'trip_distance': 1.8353531664835362e-05}]" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "preprocessor = MinMaxScaler(columns=[\"trip_distance\"])\n", - "preprocessor.fit(ds)\n", - "minmax_ds = preprocessor.transform(ds)\n", - "minmax_ds.take(1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/luweizheng/anaconda3/envs/dispy/lib/python3.11/site-packages/ray/data/preprocessor.py:125: UserWarning: `fit` has already been called on the preprocessor (or at least one contained preprocessors if this is a chain). 
All previously fitted state will be overwritten!\n", - " warnings.warn(\n", - "2023-12-15 14:51:29,990\tINFO split_read_output_blocks.py:101 -- Using autodetected parallelism=173 for stage ReadParquet to satisfy output blocks of size at least DataContext.get_current().target_min_block_size=1.0MiB.\n", - "2023-12-15 14:51:29,993\tINFO split_read_output_blocks.py:106 -- To satisfy the requested parallelism of 173, each read task output is split into 173 smaller blocks.\n", - "2023-12-15 14:51:29,995\tINFO streaming_executor.py:104 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[ReadParquet] -> AllToAllOperator[Aggregate] -> LimitOperator[limit=1]\n", - "2023-12-15 14:51:29,997\tINFO streaming_executor.py:105 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n", - "2023-12-15 14:51:29,998\tINFO streaming_executor.py:107 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n", - "\n", - "\u001b[A\n", - "\u001b[A\n", - "\n", - "\u001b[A\u001b[A\n", - "\n", - "Running: 0.0/8.0 CPU, 0.0/0.0 GPU, 0.0 MiB/512.0 MiB object_store_memory: 0%| | 0/173 [00:00 TaskPoolMapOperator[ReadParquet] -> TaskPoolMapOperator[MapBatches(MinMaxScaler._transform_pandas)] -> LimitOperator[limit=1]\n", - "2023-12-15 14:51:31,838\tINFO streaming_executor.py:105 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n", - "2023-12-15 14:51:31,840\tINFO streaming_executor.py:107 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n", - "\u001b[36m(ReadParquet->SplitBlocks(173) pid=6870)\u001b[0m 
/Users/luweizheng/anaconda3/envs/dispy/lib/python3.11/site-packages/ray/data/_internal/arrow_block.py:128: FutureWarning: promote has been superseded by mode='default'.\n", - "\u001b[36m(ReadParquet->SplitBlocks(173) pid=6870)\u001b[0m return transform_pyarrow.concat(tables) \n", - " \r" - ] - }, - { - "data": { - "text/plain": [ - "[{'trip_distance': 1.8353531664835362e-05}]" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[33m(raylet)\u001b[0m [2023-12-15 15:24:05,019 E 6858 18306341] (raylet) file_system_monitor.cc:111: /tmp/ray/session_2023-12-15_14-07-48_510959_95090 is over 95% full, available space: 49982570496; capacity: 1000240963584. Object creation will fail if spilling is required.\n" - ] - } - ], - "source": [ - "minmax_ds_ft = preprocessor.fit_transform(ds)\n", - "minmax_ds_ft.take(1)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 分类变量和数值变量\n", - "\n", - "### 分类变量\n", - "\n", - "机器学习模型无法接受分类变量,所以需要进行一些转换。{numref}`tab-categorical-data-preprocessor` 是几个处理分类变量的 Preprocessor。\n", - "\n", - "```{table} 用于处理分类变量的 Preprocessor\n", - ":name: tab-categorical-data-preprocessor\n", - "\n", - "| Preprocessor \t| 变量类型 \t | 案例 \t|\n", - "|:-----------------:\t|:--------:\t|:----------------------------------: |\n", - "| [`LabelEncoder`](https://docs.ray.io/en/latest/data/api/doc/ray.data.preprocessors.LabelEncoder.html) \t| 无序分类 \t | 猫,狗,牛,羊 \t |\n", - "| [`OrdinalEncoder`](https://docs.ray.io/en/latest/data/api/doc/ray.data.preprocessors.OrdinalEncoder.html) \t| 有序分类 \t | 高中,本科,硕士,博士 \t |\n", - "| [`MultiHotEncoder`](https://docs.ray.io/en/latest/data/api/doc/ray.data.preprocessors.MultiHotEncoder.html) \t| 多分类 \t | [\"喜剧\", \"动画\"], [\"悬疑\", \"动作\"] |\n", - "```\n", - "\n", - "### 数值变量\n", - "\n", - "使用下面的转换将数据进行转换,以适应特定的机器学习模型,{numref}`tab-numerical-data-preprocessor` 是几个处理数值变量的 Preprocessor。\n", - "\n", - "```{table} 
用于处理数值变量的 Preprocessor\n", - ":name: tab-numerical-data-preprocessor\n", - "\n", - "| Preprocessor \t| 变量类型 \t| 计算方式 \t| 备注 \t|\n", - "|--------------------\t|----------------------\t|--------------------------------------------\t|----------------------------------------------------------\t|\n", - "| [`RobustScaler`](https://docs.ray.io/en/latest/data/api/doc/ray.data.preprocessors.RobustScaler.html) \t| 有离群值 \t| $x' = \\frac{x - \\mu_{1/2}}{\\mu_h - \\mu_l}$ \t| $\\mu_{1/2}$ 是中位数,$\\mu_h$ 是最大值,$\\mu_l$ 是最小值 \t|\n", - "| [`MaxAbsScaler`](https://docs.ray.io/en/latest/data/api/doc/ray.data.preprocessors.MaxAbsScaler.html) \t| 数据稀疏 \t| $x' = \\frac{x}{\\max{\\vert x \\vert}}$ \t| \t|\n", - "| [`PowerTransformer`](https://docs.ray.io/en/latest/data/api/doc/ray.data.preprocessors.PowerTransformer.html) \t| 将数据变为正太分布 \t| Yeo-Johnson 或 Box-Cox \t| \t|\n", - "| [`Normalizer`](https://docs.ray.io/en/latest/data/api/doc/ray.data.preprocessors.Normalizer.html) \t| 需要对数据进行正则化 \t| $x' = \\frac{x}{\\lVert x \\rVert_p}$ \t| $p$ 是正则方式,比如 `l1` 正则是绝对值求和 \t|\n", - "```\n" + "{numref}`sec-ray-data-transform` 介绍了通用接口 `map()` 和 `map_batches()`。对于结构化的表格类数据,Ray Data 在 `map()` 和 `map_batches()` 基础上,增加了一个高阶的 API:预处理器(Preprocessor)。[Preprocessor](https://docs.ray.io/en/latest/data/api/preprocessor.html) 是一系列特征处理操作,可与机器学习模型训练和推理更好地结合。其使用方式与 scikit-learn 的 [sklearn.preprocessing](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.preprocessing) 非常相似,熟悉 scikit-learn 的用户可以快速迁移过来。对于非结构化数据,比如图片、视频等,仍然建议使用 `map()` 或者 `map_batches()`。" ] }, { diff --git a/ch-ray-data/ray-data-intro.md b/ch-ray-data/ray-data-intro.md index 8fbe42a..cc15342 100644 --- a/ch-ray-data/ray-data-intro.md +++ b/ch-ray-data/ray-data-intro.md @@ -1,14 +1,14 @@ (sec-ray-data-intro)= # Ray Data 简介 -Ray Data 是基于 Ray Core 的数据处理框架,主要解决机器学习模型训练或推理相关的数据准备与处理问题,即数据的最后一公里问题(Last-mile Preprocessing)。 +Ray Data 是基于 Ray Core 的数据处理框架,主要解决机器学习模型训练或推理相关的数据准备与处理问题,即数据的最后一公里问题(Last-mile Preprocessing)。与 Dask 
DataFrame、Modin、Xorbits 相比,Ray Data 更通用,既可以处理二维表,也可以处理图片、视频;Ray Data 的通用也意味着它在很多方面还不够专业,比如 `groupby` 等操作相对比较粗糙。 Ray Data 对数据提供了一个抽象的类,[`ray.data.Dataset`](https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.html),在 `Dataset` 上提供了常见的大数据处理的原语,覆盖了数据处理的大部分阶段,例如: * 数据的读取,比如读取 Parquet 文件等。 * 对数据的转换(Transformation)操作,比如 [`map_batches()`](https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.map_batches.html)。 * 分组聚合操作,比如 [`groupby()`](https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.groupby.html) -* 涉及数据在计算节点间的交换,比如 [`random_shuffle()`](https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.random_shuffle.html) 和 [`repartition()`](https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.repartition.htmln) 等。 +* 涉及数据在计算节点间的交换,比如 [`random_shuffle()`](https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.random_shuffle.html) 和 [`repartition()`](https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.repartition.html) 等。 ## 关键概念 diff --git a/ch-ray-serve/client-chat.py b/ch-ray-ml/client-chat.py similarity index 100% rename from ch-ray-serve/client-chat.py rename to ch-ray-ml/client-chat.py diff --git a/ch-ray-serve/config.yaml b/ch-ray-ml/config.yaml similarity index 100% rename from ch-ray-serve/config.yaml rename to ch-ray-ml/config.yaml diff --git a/ch-ray-ml/index.md b/ch-ray-ml/index.md new file mode 100644 index 0000000..800edc7 --- /dev/null +++ b/ch-ray-ml/index.md @@ -0,0 +1,6 @@ +# Ray 机器学习 + +本章将聚焦于 Ray 机器学习,主要介绍 Ray Data、Ray Train 和 Ray Tune 等库的使用。这些库基于 Ray 的分布式计算能力,面向机器学习和深度学习应用,深度集成了 PyTorch、TensorFlow、Hugging Face Transformers、XGBoost、scikit-learn 等机器学习库。 + +```{tableofcontents} +``` \ No newline at end of file diff --git a/ch-ray-serve/llm.py b/ch-ray-ml/llm.py similarity index 100% rename from ch-ray-serve/llm.py rename to ch-ray-ml/llm.py diff --git a/ch-ray-serve/ray-serve.md b/ch-ray-ml/ray-serve.md similarity index 100% rename from ch-ray-serve/ray-serve.md rename to ch-ray-ml/ray-serve.md diff --git 
a/ch-ray-train-tune/ray-train.ipynb b/ch-ray-ml/ray-train.ipynb similarity index 100% rename from ch-ray-train-tune/ray-train.ipynb rename to ch-ray-ml/ray-train.ipynb diff --git a/ch-ray-ml/ray-tune.ipynb b/ch-ray-ml/ray-tune.ipynb new file mode 100644 index 0000000..771a443 --- /dev/null +++ b/ch-ray-ml/ray-tune.ipynb @@ -0,0 +1,3168 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(sec-ray-tune)=\n", + "# Ray Tune\n", + "\n", + "Ray Tune 主要面向超参数调优场景,将模型训练、超参数选择和并行计算结合起来,它底层基于 Ray 的 Actor、Task 和 Ray Train,并行地启动多个机器学习训练任务,并选择最好的超参数。Ray Tune 适配了 PyTorch、Keras、XGBoost 等常见机器学习训练框架,提供了常见超参数调优算法(例如随机搜索、贝叶斯优化等)和工具([Hyperopt](https://github.com/hyperopt/hyperopt)、[Optuna](https://github.com/optuna/optuna)等)。用户可以基于 Ray Tune 在 Ray 集群上进行批量超参数调优。读者可以阅读 {numref}`sec-hyperparameter-optimization`,重温超参数调优背景知识。\n", + "\n", + "## 关键组件\n", + "\n", + "Ray Tune 主要包括以下组件:\n", + "\n", + "* 将原有的训练过程抽象为一个可训练的函数(Trainable)\n", + "* 定义需要搜索的超参数搜索空间(Search Space)\n", + "* 使用一些搜索算法(Search Algorithm)和调度器(Scheduler)并行训练和智能调度。\n", + "\n", + "{numref}`fig-ray-tune-key-parts` 展示了适配 Ray Tune 的关键部分。用户创建一个 [`Tuner`](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.Tuner.html),`Tuner` 中包含了需要训练的 Trainable 函数、超参数搜索空间,用户选择搜索算法或者使用某种调度器。不同超参数组合组成了不同的试验,Ray Tune 根据用户所申请的资源和集群已有资源,并行训练。用户可对多个试验的结果进行分析。\n", + "\n", + "```{figure} ../img/ch-ray-train-tune/ray-tune-key-parts.svg\n", + "---\n", + "width: 500px\n", + "name: fig-ray-tune-key-parts\n", + "---\n", + "Ray Tune 关键部分\n", + "```\n", + "\n", + "## Trainable\n", + "\n", + "跟其他超参数优化库一样,Ray Tune 需要一个优化目标(Objective),它是 Ray Tune 试图优化的方向,一般是一些机器学习训练指标,比如模型预测的准确度等。Ray Tune 用户需要将优化目标封装在可训练(Trainable)函数中,可在原有单节点机器学习训练的代码上进行改造。Trainable 函数接收一个字典式的配置,字典中的键是需要搜索的超参数。在 Trainable 函数中,优化目标以 `ray.train.report(...)` 方式存储起来,或者作为 Trainable 函数的返回值直接返回。例如,如果用户想对超参数 `lr` 进行调优,优化目标为 `score`,除了必要的训练代码外,Trainable 函数如下所示:\n", + "\n", + "```python\n", + "def trainable(config):\n", + " lr = config[\"lr\"]\n", + " \n", + " # 训练代码 ...\n", + " 
\n", + " # 以 ray.train.report 方式返回优化目标\n", + " ray.train.report({\"score\": ...})\n", + " # 或者使用 return 或 yield 直接返回\n", + " return {\"score\": ...}\n", + "```\n", + "\n", + "### 案例:图像分类\n", + "\n", + "对图像分类案例进行改造,对 `lr` 和 `momentum` 两个超参数进行搜索,Trainable 函数是代码中的 `train_mnist()`:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "tags": [ + "hide-cell" + ] + }, + "outputs": [], + "source": [ + "import os\n", + "import tempfile\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torchvision\n", + "from torch.utils.data import DataLoader\n", + "from torchvision.models import resnet18\n", + "\n", + "from ray import tune\n", + "from ray.tune.schedulers import ASHAScheduler\n", + "\n", + "import ray\n", + "import ray.train.torch\n", + "from ray.train import Checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "data_dir = os.path.join(os.getcwd(), \"../data\")\n", + "\n", + "def train_func(model, optimizer, criterion, train_loader):\n", + " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + " model.train()\n", + " for data, target in train_loader:\n", + " data, target = data.to(device), target.to(device)\n", + " output = model(data)\n", + " loss = criterion(output, target)\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + "\n", + "def test_func(model, data_loader):\n", + " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + " model.eval()\n", + " correct = 0\n", + " total = 0\n", + " with torch.no_grad():\n", + " for data, target in data_loader:\n", + " data, target = data.to(device), target.to(device)\n", + " outputs = model(data)\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total += target.size(0)\n", + " correct += (predicted == target).sum().item()\n", + "\n", + " return correct / total\n", + "\n", + "def 
train_mnist(config):\n", + " transform = torchvision.transforms.Compose(\n", + " [torchvision.transforms.ToTensor(), \n", + " torchvision.transforms.Normalize((0.5,), (0.5,))]\n", + " )\n", + "\n", + " train_loader = DataLoader(\n", + " torchvision.datasets.FashionMNIST(root=data_dir, train=True, download=True, transform=transform),\n", + " batch_size=128,\n", + " shuffle=True)\n", + " test_loader = DataLoader(\n", + " torchvision.datasets.FashionMNIST(root=data_dir, train=False, download=True, transform=transform),\n", + " batch_size=128,\n", + " shuffle=True)\n", + "\n", + " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + " model = resnet18(num_classes=10)\n", + " model.conv1 = torch.nn.Conv2d(\n", + " 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False\n", + " )\n", + " model.to(device)\n", + "\n", + " criterion = nn.CrossEntropyLoss()\n", + "\n", + " optimizer = torch.optim.SGD(\n", + " model.parameters(), lr=config[\"lr\"], momentum=config[\"momentum\"])\n", + " \n", + " # 训练 10 个 epoch\n", + " for epoch in range(10):\n", + " train_func(model, optimizer, criterion, train_loader)\n", + " acc = test_func(model, test_loader)\n", + "\n", + " with tempfile.TemporaryDirectory() as temp_checkpoint_dir:\n", + " checkpoint = None\n", + " if (epoch + 1) % 5 == 0:\n", + " torch.save(\n", + " model.state_dict(),\n", + " os.path.join(temp_checkpoint_dir, \"model.pth\")\n", + " )\n", + " checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)\n", + "\n", + " ray.train.report({\"mean_accuracy\": acc}, checkpoint=checkpoint)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Ray Tune 同时支持使用函数方式和类方式定义 Trainable,本例中使用函数方式,函数方式和类方式的区别如 {numref}`tab-function-class-trainable` 所示。\n", + "\n", + "```{table} 函数方式和类方式定义 Trainable 的区别\n", + ":name: tab-function-class-trainable\n", + "| 内容 | 函数方式 API | 类方式 API |\n", + "| :----------------: | :----------------------------------------: | 
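:----------------------------------------: |\n", + "| One training iteration | The iteration count increases by one per call to `train.report` | The iteration count increases by one per call to `Trainable.step()` |\n", + "| Reporting metrics | Call `train.report(metrics)` | Return them from `Trainable.step()` |\n", + "| Saving a checkpoint | Call `train.report(..., checkpoint=checkpoint)` | Implement `Trainable.save_checkpoint()` |\n", + "| Loading a checkpoint | Call `train.get_checkpoint()` | Implement `Trainable.load_checkpoint()` |\n", + "| Receiving a hyperparameter combination | Passed in as `config` in `def train_func(config):` | Passed in as `config` in `Trainable.setup(self, config)` |\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Search space\n", + "\n", + "The search space is the set of values each hyperparameter may take, and Ray Tune provides several functions for defining it. For example, [`ray.tune.choice()`](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.choice.html) samples from a given list of candidate values, and [`ray.tune.uniform()`](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.uniform.html) samples from a uniform distribution over a range. We now define the search space for `lr` and `momentum`:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "search_space = {\n", + " \"lr\": tune.choice([0.001, 0.002, 0.005, 0.01, 0.02, 0.05]),\n", + " \"momentum\": tune.uniform(0.1, 0.9),\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Search algorithms and schedulers\n", + "\n", + "Two key concepts in Ray Tune's hyperparameter search are the search algorithm and the scheduler. The search algorithm determines how new hyperparameter combinations (that is, trials) are chosen from the search space; the scheduler decides whether to terminate unpromising trials early, which saves computing resources. A search algorithm is required, while a scheduler is optional. The two can work together: for example, random search can propose trials while the ASHA scheduler stops those that look unpromising. In addition, hyperparameter optimization packages such as [Hyperopt](https://github.com/hyperopt/hyperopt) and [Optuna](https://github.com/optuna/optuna) ship ready-made search algorithms, and some also provide schedulers. Each package has its own usage conventions; Ray Tune wraps them so that they can be used through a largely uniform interface.\n", + "\n", + "We start with random search:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "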
\n", + "
\n", + "
\n", + "

Tune Status

\n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
Current time:2024-04-11 21:24:22
Running for: 00:03:56.63
Memory: 12.8/90.0 GiB
\n", + "
\n", + "
\n", + "
\n", + "

System Info

\n", + " Using FIFO scheduling algorithm.
Logical resource usage: 0/64 CPUs, 1.0/4 GPUs (0.0/1.0 accelerator_type:TITAN)\n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "

Trial Status

\n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
| Trial name | status | loc | lr | momentum | acc | iter | total time (s) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| train_mnist_421ef_00000 | TERMINATED | 10.0.0.3:41485 | 0.002 | 0.29049 | 0.8686 | 10 | 228.323 |
\n", + "
\n", + "
\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[36m(train_mnist pid=41485)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-20-25/train_mnist_421ef_00000_0_lr=0.0020,momentum=0.2905_2024-04-11_21-20-26/checkpoint_000000)\n", + "\u001b[36m(train_mnist pid=41485)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-20-25/train_mnist_421ef_00000_0_lr=0.0020,momentum=0.2905_2024-04-11_21-20-26/checkpoint_000001)\n", + "2024-04-11 21:24:22,692\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", + "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", + "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", + "2024-04-11 21:24:22,696\tINFO tune.py:1016 -- Wrote the latest version of all result files and experiment state to '/home/u20200002/ray_results/train_mnist_2024-04-11_21-20-25' in 0.0079s.\n", + "2024-04-11 21:24:22,706\tINFO tune.py:1048 -- Total run time: 236.94 seconds (236.62 seconds for the tuning loop).\n" + ] + } + ], + "source": [ + "trainable_with_gpu = tune.with_resources(train_mnist, {\"gpu\": 1})\n", + "\n", + "tuner = tune.Tuner(\n", + " trainable_with_gpu,\n", + " param_space=search_space,\n", + ")\n", + "results = tuner.fit()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The scheduler monitors each trial and may terminate one before it finishes training. Ending a trial early is called early stopping; it saves computing resources and leaves them for the most promising trials. The example below uses the [ASHA algorithm](https://openreview.net/forum?id=S1Y7OOlRZ) {cite}`li2018Massively` for scheduling.\n", + "\n", + "The previous example did not set [`ray.tune.TuneConfig`](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.TuneConfig.html), so by default only one trial was run. We now set `num_samples` in `ray.tune.TuneConfig`, which specifies how many trials to run. We also use [ray.tune.schedulers.ASHAScheduler](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.schedulers.ASHAScheduler.html) to terminate poorly performing trials early and leave computing resources to the more promising ones. The `metric` and `mode` arguments of `ASHAScheduler` specify the optimization objective; in this example the goal is to maximize \"mean_accuracy\"." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[36m(train_mnist pid=41806)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00002_2_lr=0.0020,momentum=0.4789_2024-04-11_21-24-22/checkpoint_000000)\n", + "\u001b[36m(train_mnist pid=42212)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00005_5_lr=0.0050,momentum=0.8187_2024-04-11_21-24-22/checkpoint_000000)\u001b[32m [repeated 2x across cluster] (Ray deduplicates logs by default. 
Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)\u001b[0m\n", + "\u001b[36m(train_mnist pid=41806)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00002_2_lr=0.0020,momentum=0.4789_2024-04-11_21-24-22/checkpoint_000001)\n", + "\u001b[36m(train_mnist pid=41809)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00003_3_lr=0.0100,momentum=0.4893_2024-04-11_21-24-22/checkpoint_000001)\n", + "\u001b[36m(train_mnist pid=42394)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00006_6_lr=0.0200,momentum=0.1573_2024-04-11_21-24-22/checkpoint_000000)\n", + "\u001b[36m(train_mnist pid=42212)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00005_5_lr=0.0050,momentum=0.8187_2024-04-11_21-24-22/checkpoint_000001)\n", + "\u001b[36m(train_mnist pid=42619)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00008_8_lr=0.0500,momentum=0.7167_2024-04-11_21-24-22/checkpoint_000000)\n", + "\u001b[36m(train_mnist pid=42394)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00006_6_lr=0.0200,momentum=0.1573_2024-04-11_21-24-22/checkpoint_000001)\n", + "\u001b[36m(train_mnist pid=42619)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, 
path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00008_8_lr=0.0500,momentum=0.7167_2024-04-11_21-24-22/checkpoint_000001)\n", + "\u001b[36m(train_mnist pid=43231)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00014_14_lr=0.0200,momentum=0.7074_2024-04-11_21-24-22/checkpoint_000000)\n", + "\u001b[36m(train_mnist pid=43231)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00014_14_lr=0.0200,momentum=0.7074_2024-04-11_21-24-22/checkpoint_000001)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-04-11 21:34:30,051\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", + "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", + "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", + "2024-04-11 21:34:30,067\tINFO tune.py:1016 -- Wrote the latest version of all result files and experiment state to '/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22' in 0.0224s.\n", + "2024-04-11 21:34:30,083\tINFO tune.py:1048 -- Total run time: 607.32 seconds (607.25 seconds for the tuning loop).\n" + ] + } + ], + "source": [ + "tuner = tune.Tuner(\n", + " trainable_with_gpu,\n", + " tune_config=tune.TuneConfig(\n", + " num_samples=16,\n", + " scheduler=ASHAScheduler(metric=\"mean_accuracy\", mode=\"max\"),\n", + " ),\n", + " param_space=search_space,\n", + ")\n", + "results = tuner.fit()" + 
] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The console prints the hyperparameter values chosen for each trial together with its objective. Poorly performing trials are stopped early after only a few iterations (`iter`). We now visualize the results of these trials:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0, 0.5, 'Mean Accuracy')" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/svg+xml": [ + "\n", + " 2024-04-11T21:34:31.189696\n", + " image/svg+xml\n", + " Matplotlib v3.8.4, https://matplotlib.org/\n", + " 
\n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%config InlineBackend.figure_format = 'svg'\n", + "\n", + "dfs = {result.path: result.metrics_dataframe for result in results}\n", + "ax = None\n", + "for d in dfs.values():\n", + " ax = d.mean_accuracy.plot(ax=ax, legend=False)\n", + "ax.set_xlabel(\"Epochs\")\n", + "ax.set_ylabel(\"Mean Accuracy\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "That completes a full Ray Tune example. Next, we show several examples that use different search algorithms and schedulers.\n", + "\n", + "## Example: flight delay prediction\n", + "\n", + "This example uses flight departure and arrival data and trains XGBoost to predict whether a flight is delayed. XGBoost is a tree-based model whose training involves many hyperparameters, such as the tree depth." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import tempfile\n", + "\n", + "import sys\n", + "sys.path.append(\"..\")\n", + "from utils import nyc_flights\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "import torch\n", + "import torchvision\n", + "import torch.nn as nn\n", + "from torch.utils.data import DataLoader\n", + "from torchvision.models import resnet18\n", + "\n", + "import ray\n", + "from sklearn.model_selection import train_test_split\n", + "from ray.tune.search.hyperopt import HyperOptSearch\n", + "import xgboost as xgb\n", + "from ray import tune\n", + "from ray.tune.schedulers import AsyncHyperBandScheduler\n", + "from ray.tune.integration.xgboost import TuneReportCheckpointCallback\n", + "from ray.tune.schedulers import PopulationBasedTraining\n", + "\n", + "folder_path = nyc_flights()\n", + "file_path = os.path.join(folder_path, \"1991.csv\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Read the data and apply the necessary preprocessing:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "input_cols = [\n", + " \"Year\",\n", + " \"Month\",\n", + " \"DayofMonth\",\n", + " \"DayOfWeek\",\n", + " \"CRSDepTime\",\n", + " 
\"CRSArrTime\",\n", + " \"UniqueCarrier\",\n", + " \"FlightNum\",\n", + " \"ActualElapsedTime\",\n", + " \"Origin\",\n", + " \"Dest\",\n", + " \"Distance\",\n", + " \"Diverted\",\n", + " \"ArrDelay\",\n", + "]\n", + "\n", + "df = pd.read_csv(file_path, usecols=input_cols,)\n", + "\n", + "# Binary label: whether the flight is delayed\n", + "df[\"ArrDelayBinary\"] = 1.0 * (df[\"ArrDelay\"] > 10)\n", + "\n", + "df = df[df.columns.difference([\"ArrDelay\"])]\n", + "\n", + "for col in df.select_dtypes([\"object\"]).columns:\n", + " df[col] = df[col].astype(\"category\").cat.codes.astype(np.int32)\n", + "\n", + "for col in df.columns:\n", + " df[col] = df[col].astype(np.float32)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `params` argument of XGBoost's `train()` function takes hyperparameters such as the tree depth. Note that the `train()` function of frameworks like XGBoost, unlike PyTorch code, has no explicit training loop such as `for epoch in range(...)`. To report metrics right after each training iteration, pass a callback through the `callbacks` argument of `train()`. Ray provides [`TuneReportCheckpointCallback`](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.integration.xgboost.TuneReportCheckpointCallback.html), which reports the relevant metrics to Ray Tune after every training iteration. In this example, `params` contains `\"eval_metric\": [\"logloss\", \"error\"]`, the metrics to compute during evaluation, and `evals=[(test_set, \"eval\")]` means that only the validation set is evaluated. Together, they compute `logloss` and `error` on the validation set, which are reported to Ray Tune under the names `eval-logloss` and `eval-error`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def train_flight(config: dict):\n", + " config.update({\n", + " \"objective\": \"binary:logistic\",\n", + " \"eval_metric\": [\"logloss\", \"error\"]\n", + " })\n", + " _y_label = \"ArrDelayBinary\"\n", + " train_x, test_x, train_y, test_y = train_test_split(\n", + " df.loc[:, df.columns != _y_label], \n", + " df[_y_label], \n", + " test_size=0.25\n", + " )\n", + " \n", + " train_set = xgb.DMatrix(train_x, label=train_y)\n", + " test_set = xgb.DMatrix(test_x, label=test_y)\n", + " \n", + " xgb.train(\n", + " params=config,\n", + " dtrain=train_set,\n", 
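+ " evals=[(test_set, \"eval\")],\n", + " verbose_eval=False,\n", + " # After each iteration, `TuneReportCheckpointCallback` reports the evaluation metrics to Ray Tune\n", + " callbacks=[TuneReportCheckpointCallback(frequency=1)]\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Under the hood we use the Bayesian optimization search algorithm provided by the `hyperopt` package; if it is not installed, install it first with `pip install hyperopt`. Such packages usually have their own format for defining the search space, but users can also use the search space definitions provided by Ray Tune directly.\n", + "\n", + "For the scheduler, we use the HyperBand algorithm. [`AsyncHyperBandScheduler`](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.schedulers.AsyncHyperBandScheduler.html) is the HyperBand implementation recommended by Ray Tune; it is asynchronous and makes fuller use of computing resources. Its `time_attr` argument names the unit in which training time is measured and serves as the basic unit of the resource budget; the default, `training_iteration`, means one training iteration. The other arguments of `AsyncHyperBandScheduler` are expressed in this unit: `max_t` is the total budget a trial can receive, so a trial gets at most `max_t` units of `time_attr`, and `grace_period` guarantees each trial at least `grace_period` units of `time_attr` before it can be stopped. `reduction_factor` is the $\\eta$ in the mathematical description above, and `brackets` is the number of brackets used by the HyperBand algorithm." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "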
\n", + "
\n", + "
\n", + "

Tune Status

\n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
Current time:2024-04-17 23:23:16
Running for: 00:00:10.45
Memory: 12.6/90.0 GiB
\n", + "
\n", + "
\n", + "
\n", + "

System Info

\n", + " Using AsyncHyperBand: num_stopped=16
Bracket: Iter 8.000: -0.2197494153541173 | Iter 4.000: -0.21991977574377797 | Iter 2.000: -0.2211587603958556 | Iter 1.000: -0.22190215118710216
Bracket: Iter 8.000: -0.2228778516006133 | Iter 4.000: -0.2228778516006133 | Iter 2.000: -0.22374514085706762
Bracket: Iter 8.000: -0.2238690393222754 | Iter 4.000: -0.2238690393222754
Logical resource usage: 1.0/64 CPUs, 0/4 GPUs (0.0/1.0 accelerator_type:TITAN)\n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "

Trial Status

\n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
| Trial name | status | loc | eta | max_depth | min_child_weight | subsample | iter | total time (s) | eval-logloss | eval-error |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| train_flight_63737_00000 | TERMINATED | 10.0.0.3:46765 | 0.0160564 | 2 | 3 | 0.963344 | 10 | 0.958357 | 0.522242 | 0.222878 |
| train_flight_63737_00001 | TERMINATED | 10.0.0.3:46724 | 0.0027667 | 3 | 3 | 0.930057 | 10 | 1.11445 | 0.525986 | 0.219595 |
| train_flight_63737_00002 | TERMINATED | 10.0.0.3:46700 | 0.00932612 | 3 | 1 | 0.532473 | 1 | 0.698576 | 0.53213 | 0.223699 |
| train_flight_63737_00003 | TERMINATED | 10.0.0.3:46795 | 0.0807042 | 7 | 1 | 0.824932 | 10 | 1.27819 | 0.42436 | 0.176524 |
| train_flight_63737_00004 | TERMINATED | 10.0.0.3:46796 | 0.0697454 | 1 | 2 | 0.908686 | 10 | 1.01485 | 0.516239 | 0.223466 |
| train_flight_63737_00005 | TERMINATED | 10.0.0.3:46868 | 0.00334937 | 4 | 2 | 0.799064 | 10 | 0.983133 | 0.528863 | 0.223869 |
| train_flight_63737_00006 | TERMINATED | 10.0.0.3:46932 | 0.00637837 | 5 | 2 | 0.555629 | 2 | 0.691448 | 0.528233 | 0.22136 |
| train_flight_63737_00007 | TERMINATED | 10.0.0.3:46935 | 0.000145799 | 8 | 3 | 0.84289 | 1 | 0.668353 | 0.532382 | 0.223079 |
| train_flight_63737_00008 | TERMINATED | 10.0.0.3:46959 | 0.0267405 | 5 | 1 | 0.766606 | 2 | 0.692802 | 0.520686 | 0.221159 |
| train_flight_63737_00009 | TERMINATED | 10.0.0.3:46989 | 0.00848009 | 2 | 3 | 0.576874 | 2 | 0.610592 | 0.53193 | 0.223745 |
| train_flight_63737_00010 | TERMINATED | 10.0.0.3:47125 | 0.0016903 | 8 | 3 | 0.824537 | 2 | 0.716938 | 0.532519 | 0.22407 |
| train_flight_63737_00011 | TERMINATED | 10.0.0.3:47127 | 0.005344 | 7 | 1 | 0.921332 | 1 | 0.609434 | 0.532074 | 0.223993 |
| train_flight_63737_00012 | TERMINATED | 10.0.0.3:47193 | 0.0956213 | 1 | 2 | 0.682057 | 8 | 0.791444 | 0.511592 | 0.219904 |
| train_flight_63737_00013 | TERMINATED | 10.0.0.3:47196 | 0.00796245 | 5 | 2 | 0.570677 | 1 | 0.619144 | 0.531066 | 0.223172 |
| train_flight_63737_00014 | TERMINATED | 10.0.0.3:47198 | 0.0106115 | 2 | 3 | 0.85295 | 1 | 0.582307 | 0.530977 | 0.222444 |
| train_flight_63737_00015 | TERMINATED | 10.0.0.3:47200 | 0.0507297 | 1 | 1 | 0.720122 | 2 | 0.655333 | 0.527164 | 0.221283 |
\n", + "
\n", + "
\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33m(raylet)\u001b[0m Warning: The actor ImplicitFunc is very large (27 MiB). Check that its definition is not implicitly capturing a large array or other object in scope. Tip: use ray.put() to put large objects in the Ray object store.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[36m(train_flight pid=46796)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_flight_2024-04-17_23-23-05/train_flight_63737_00004_4_eta=0.0697,max_depth=1,min_child_weight=2,subsample=0.9087_2024-04-17_23-23-08/checkpoint_000000)\n", + "2024-04-17 23:23:16,344\tINFO tune.py:1016 -- Wrote the latest version of all result files and experiment state to '/home/u20200002/ray_results/train_flight_2024-04-17_23-23-05' in 0.0653s.\n", + "2024-04-17 23:23:16,362\tINFO tune.py:1048 -- Total run time: 10.73 seconds (10.38 seconds for the tuning loop).\n" + ] + } + ], + "source": [ + "search_space = {\n", + " \"max_depth\": tune.randint(1, 9),\n", + " \"min_child_weight\": tune.choice([1, 2, 3]),\n", + " \"subsample\": tune.uniform(0.5, 1.0),\n", + " \"eta\": tune.loguniform(1e-4, 1e-1),\n", + "}\n", + "\n", + "scheduler = AsyncHyperBandScheduler(\n", + " max_t=10,\n", + " grace_period=1,\n", + " reduction_factor=2,\n", + " brackets=3,\n", + ")\n", + "\n", + "tuner = tune.Tuner(\n", + " train_flight,\n", + " tune_config=tune.TuneConfig(\n", + " metric=\"eval-error\",\n", + " mode=\"min\",\n", + " scheduler=scheduler,\n", + " num_samples=16,\n", + " ),\n", + " param_space=search_space,\n", + ")\n", + "results = tuner.fit()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`Tuner.fit()` returns the results of all trials as a `ResultGrid` and also writes the related information to persistent storage, so users can inspect, analyze, and compare the results obtained with different hyperparameters:" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
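| | eval-error | training_iteration | config/max_depth | config/min_child_weight | config/subsample |
| --- | --- | --- | --- | --- | --- |
| 0 | 0.222878 | 10 | 2 | 3 | 0.963344 |
| 1 | 0.219595 | 10 | 3 | 3 | 0.930057 |
| 2 | 0.223699 | 1 | 3 | 1 | 0.532473 |
| 3 | 0.176524 | 10 | 7 | 1 | 0.824932 |
| 4 | 0.223466 | 10 | 1 | 2 | 0.908686 |
| 5 | 0.223869 | 10 | 4 | 2 | 0.799064 |
| 6 | 0.221360 | 2 | 5 | 2 | 0.555629 |
| 7 | 0.223079 | 1 | 8 | 3 | 0.842890 |
| 8 | 0.221159 | 2 | 5 | 1 | 0.766606 |
| 9 | 0.223745 | 2 | 2 | 3 | 0.576874 |
| 10 | 0.224070 | 2 | 8 | 3 | 0.824537 |
| 11 | 0.223993 | 1 | 7 | 1 | 0.921332 |
| 12 | 0.219904 | 8 | 1 | 2 | 0.682057 |
| 13 | 0.223172 | 1 | 5 | 2 | 0.570677 |
| 14 | 0.222444 | 1 | 2 | 3 | 0.852950 |
| 15 | 0.221283 | 2 | 1 | 1 | 0.720122 |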
\n", + "
" + ], + "text/plain": [ + " eval-error training_iteration config/max_depth config/min_child_weight \\\n", + "0 0.222878 10 2 3 \n", + "1 0.219595 10 3 3 \n", + "2 0.223699 1 3 1 \n", + "3 0.176524 10 7 1 \n", + "4 0.223466 10 1 2 \n", + "5 0.223869 10 4 2 \n", + "6 0.221360 2 5 2 \n", + "7 0.223079 1 8 3 \n", + "8 0.221159 2 5 1 \n", + "9 0.223745 2 2 3 \n", + "10 0.224070 2 8 3 \n", + "11 0.223993 1 7 1 \n", + "12 0.219904 8 1 2 \n", + "13 0.223172 1 5 2 \n", + "14 0.222444 1 2 3 \n", + "15 0.221283 2 1 1 \n", + "\n", + " config/subsample \n", + "0 0.963344 \n", + "1 0.930057 \n", + "2 0.532473 \n", + "3 0.824932 \n", + "4 0.908686 \n", + "5 0.799064 \n", + "6 0.555629 \n", + "7 0.842890 \n", + "8 0.766606 \n", + "9 0.576874 \n", + "10 0.824537 \n", + "11 0.921332 \n", + "12 0.682057 \n", + "13 0.570677 \n", + "14 0.852950 \n", + "15 0.720122 " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "results_df = results.get_dataframe()\n", + "results_df[[\"eval-error\", \"training_iteration\", \"config/max_depth\", \"config/min_child_weight\", \"config/subsample\"]]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example: image classification with PBT\n", + "\n", + "PBT adjusts both the model weights and the hyperparameters during training, so the training code must be able to load (update) the model weights from a checkpoint. The training loop also differs: most PyTorch training code uses an explicit loop such as `for epoch in range(...)` that has a termination condition, whereas a PBT training function sets none. Ray Tune terminates the trial once the metric reaches the target or the trial needs to be stopped early, so the training loop uses `while True` and keeps iterating until Ray Tune terminates it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data_dir = os.path.join(os.getcwd(), \"../data\")\n", + "\n", + "def train_func(model, optimizer, criterion, train_loader):\n", + " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + " model.train()\n", + " for data, target in train_loader:\n", + " data, target = data.to(device), target.to(device)\n", + " output = model(data)\n", + " loss = criterion(output, target)\n", + " optimizer.zero_grad()\n", 
+ " loss.backward()\n", + " optimizer.step()\n", + "\n", + "\n", + "def test_func(model, data_loader):\n", + " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + " model.eval()\n", + " correct = 0\n", + " total = 0\n", + " with torch.no_grad():\n", + " for data, target in data_loader:\n", + " data, target = data.to(device), target.to(device)\n", + " outputs = model(data)\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total += target.size(0)\n", + " correct += (predicted == target).sum().item()\n", + "\n", + " return correct / total" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def train_mnist(config):\n", + " step = 1\n", + " transform = torchvision.transforms.Compose(\n", + " [torchvision.transforms.ToTensor(), \n", + " torchvision.transforms.Normalize((0.5,), (0.5,))]\n", + " )\n", + "\n", + " train_loader = DataLoader(\n", + " torchvision.datasets.FashionMNIST(root=data_dir, train=True, download=True, transform=transform),\n", + " batch_size=128,\n", + " shuffle=True)\n", + " test_loader = DataLoader(\n", + " torchvision.datasets.FashionMNIST(root=data_dir, train=False, download=True, transform=transform),\n", + " batch_size=128,\n", + " shuffle=True)\n", + "\n", + " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + " model = resnet18(num_classes=10)\n", + " model.conv1 = torch.nn.Conv2d(\n", + " 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False\n", + " )\n", + " model.to(device)\n", + "\n", + " criterion = nn.CrossEntropyLoss()\n", + "\n", + " optimizer = torch.optim.SGD(\n", + " model.parameters(), \n", + " lr=config.get(\"lr\", 0.01), \n", + " momentum=config.get(\"momentum\", 0.9)\n", + " )\n", + "\n", + " checkpoint = ray.train.get_checkpoint()\n", + " if checkpoint:\n", + " with checkpoint.as_directory() as checkpoint_dir:\n", + " checkpoint_dict = torch.load(os.path.join(checkpoint_dir, 
\"checkpoint.pt\"))\n", + " \n", + " model.load_state_dict(checkpoint_dict[\"model_state_dict\"])\n", + " optimizer.load_state_dict(checkpoint_dict[\"optimizer_state_dict\"])\n", + " \n", + " # Apply the lr and momentum passed in via config to the optimizer \n", + " for param_group in optimizer.param_groups:\n", + " if \"lr\" in config:\n", + " param_group[\"lr\"] = config[\"lr\"]\n", + " if \"momentum\" in config:\n", + " param_group[\"momentum\"] = config[\"momentum\"]\n", + " \n", + " last_step = checkpoint_dict[\"step\"]\n", + " step = last_step + 1\n", + " \n", + " # Ray Tune terminates the trial based on its metrics\n", + " while True:\n", + " train_func(model, optimizer, criterion, train_loader)\n", + " acc = test_func(model, test_loader)\n", + " metrics = {\"mean_accuracy\": acc, \"lr\": config[\"lr\"]}\n", + "\n", + " if step % config[\"checkpoint_interval\"] == 0:\n", + " with tempfile.TemporaryDirectory() as tmpdir:\n", + " torch.save(\n", + " {\n", + " \"step\": step,\n", + " \"model_state_dict\": model.state_dict(),\n", + " \"optimizer_state_dict\": optimizer.state_dict(),\n", + " },\n", + " os.path.join(tmpdir, \"checkpoint.pt\"),\n", + " )\n", + " ray.train.report(metrics, checkpoint=ray.train.Checkpoint.from_directory(tmpdir))\n", + " else:\n", + " ray.train.report(metrics)\n", + "\n", + " step += 1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we define the PBT scheduler with [PopulationBasedTraining](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.schedulers.PopulationBasedTraining.html). As with the schedulers above, `time_attr` is the time unit. `perturbation_interval` is how often the hyperparameters are perturbed (mutated) to generate new ones. It is usually set to the same value as `checkpoint_interval`, because each perturbation also writes a checkpoint to persistent storage, which adds overhead; the interval should therefore not be too short. PBT draws the mutated values from `hyperparam_mutations`, a dictionary that maps each hyperparameter to its candidate values or sampling distribution." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "perturbation_interval = 5\n", + "scheduler = PopulationBasedTraining(\n", + " time_attr=\"training_iteration\",\n", + " 
perturbation_interval=perturbation_interval,\n", + " metric=\"mean_accuracy\",\n", + " mode=\"max\",\n", + " hyperparam_mutations={\n", + " \"lr\": tune.uniform(0.0001, 1),\n", + " \"momentum\": [0.8, 0.9, 0.99],\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can start training. We need to give PBT a stopping condition: in this example, training stops once `mean_accuracy` reaches 0.9 or 20 iterations have been completed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "
\n", + "
\n", + "

Tune Status

\n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
Current time:2024-04-17 18:09:24
Running for: 00:07:35.75
Memory: 16.7/90.0 GiB
\n", + "
\n", + "
\n", + "
\n", + "

System Info

\n", + " PopulationBasedTraining: 9 checkpoints, 1 perturbs
Logical resource usage: 0/64 CPUs, 1.0/4 GPUs (0.0/1.0 accelerator_type:TITAN)\n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "

Trial Status

\n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
| Trial name | status | loc | lr | momentum | acc | iter | total time (s) | lr |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| train_mnist_817a7_00000 | TERMINATED | 10.0.0.3:26907 | 0.291632 | 0.578225 | 0.901 | 7 | 163.06 | 0.291632 |
| train_mnist_817a7_00001 | TERMINATED | 10.0.0.3:26904 | 0.63272 | 0.94472 | 0.0996 | 20 | 446.483 | 0.63272 |
| train_mnist_817a7_00002 | TERMINATED | 10.0.0.3:26903 | 0.615735 | 0.0790379 | 0.901 | 9 | 219.548 | 0.615735 |
| train_mnist_817a7_00003 | TERMINATED | 10.0.0.3:26906 | 0.127736 | 0.486793 | 0.9084 | 8 | 181.952 | 0.127736 |
\n", + "
\n", + "
\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-04-17 18:03:53,880\tINFO pbt.py:716 -- [pbt]: no checkpoint for trial train_mnist_817a7_00003. Skip exploit for Trial train_mnist_817a7_00001\n", + "2024-04-17 18:09:24,486\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", + "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", + "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", + "2024-04-17 18:09:24,492\tINFO tune.py:1016 -- Wrote the latest version of all result files and experiment state to '/home/u20200002/ray_results/pbt_mnist' in 0.0111s.\n", + "2024-04-17 18:09:24,501\tINFO tune.py:1048 -- Total run time: 455.82 seconds (455.74 seconds for the tuning loop).\n" + ] + } + ], + "source": [ + "tuner = tune.Tuner(\n", + " tune.with_resources(train_mnist, {\"gpu\": 1}),\n", + " run_config=ray.train.RunConfig(\n", + " name=\"pbt_mnist\",\n", + " # Stop condition: a trial stops as soon as either `mean_accuracy` or `training_iteration` reaches its threshold\n", + " stop={\"mean_accuracy\": 0.9, \"training_iteration\": 20},\n", + " checkpoint_config=ray.train.CheckpointConfig(\n", + " checkpoint_score_attribute=\"mean_accuracy\",\n", + " num_to_keep=4,\n", + " ),\n", + " storage_path=\"~/ray_results\",\n", + " ),\n", + " tune_config=tune.TuneConfig(\n", + " scheduler=scheduler,\n", + " num_samples=4,\n", + " ),\n", + " param_space={\n", + " \"lr\": tune.uniform(0.001, 1),\n", + " \"momentum\": tune.uniform(0.001, 1),\n", + " \"checkpoint_interval\": perturbation_interval,\n", + " },\n", +
")\n", + "\n", + "results_grid = tuner.fit()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After tuning, we can inspect the results for the different hyperparameter configurations. We pick the best result and plot how its `mean_accuracy` changes over the training iterations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best result path: /home/u20200002/ray_results/pbt_mnist/train_mnist_817a7_00003_3_lr=0.1277,momentum=0.4868_2024-04-17_18-01-48\n", + "Best final iteration hyperparameter config:\n", + " {'lr': 0.1277359940819796, 'momentum': 0.48679312797681595, 'checkpoint_interval': 5}\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "[SVG markup lost in extraction: Matplotlib v3.8.4 figure dated 2024-04-17T19:01:04.895229, plotting Test Accuracy against Training Iterations]\n" + ], + "text/plain": [ + "
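" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%config InlineBackend.figure_format = 'svg'\n", + "\n", + "best_result = results_grid.get_best_result(metric=\"mean_accuracy\", mode=\"max\")\n", + "\n", + "print('Best result path:', best_result.path)\n", + "print(\"Best final iteration hyperparameter config:\\n\", best_result.config)\n", + "\n", + "df = best_result.metrics_dataframe\n", + "df = df.drop_duplicates(subset=\"training_iteration\", keep=\"last\")\n", + "df.plot(\"training_iteration\", \"mean_accuracy\")\n", + "plt.xlabel(\"Training Iterations\")\n", + "plt.ylabel(\"Test Accuracy\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/ch-ray-serve/resnet_pytorch.py b/ch-ray-ml/resnet_pytorch.py similarity index 100% rename from ch-ray-serve/resnet_pytorch.py rename to ch-ray-ml/resnet_pytorch.py diff --git a/ch-ray-train-tune/tune-algorithm-scheduler.ipynb b/ch-ray-ml/tune-algorithm-scheduler.ipynb similarity index 87% rename from ch-ray-train-tune/tune-algorithm-scheduler.ipynb rename to ch-ray-ml/tune-algorithm-scheduler.ipynb index a58f767..d2241fc 100644 --- a/ch-ray-train-tune/tune-algorithm-scheduler.ipynb +++ b/ch-ray-ml/tune-algorithm-scheduler.ipynb @@ -9,100 +9,8 @@ "\n", "Two key concepts in Ray Tune's hyperparameter search are the search algorithm and the scheduler. The search algorithm decides how to pick new hyperparameter combinations (that is, trials) from the search space; the scheduler decides which unpromising trials to stop early, saving compute. A search algorithm is required; a scheduler is optional. The two can work together: for example, random search can be paired with the Async Successive Halving Algorithm (ASHA) scheduler, which stops trials that look hopeless. In addition, many hyperparameter optimization packages ship their own search algorithms, and some also ship schedulers. Each package has its own usage conventions, so Ray Tune wraps them to make their interfaces as uniform as possible. The sections below briefly introduce some common search algorithms and schedulers.\n", "\n", - "## Search Algorithms\n", - "\n", - "Hyperparameter tuning is a form of black-box optimization: the objective function is a black box whose behavior can only be inferred by observing its inputs and outputs. The idea is easier to grasp by contrast with gradient descent, which is **not** a black-box method: there we can obtain the gradient of the objective (or an approximation of it) and use it to steer the search toward a (local) optimum. A black-box optimizer generally has neither a closed-form expression for the objective nor its gradient, so gradient-based techniques do not apply. Bayesian optimization, genetic algorithms, and simulated annealing are all black-box methods. They typically pick some candidates from the hyperparameter search space, run the objective to measure how each combination actually performs, adjust based on those measurements, and repeat until a stopping condition is met. {numref}`fig-tune-algorithms` shows a hyperparameter search in a two-dimensional search space: each point is one hyperparameter combination, and warmer colors mean better performance. An iterative algorithm starts from initial points, chooses later trials based on the results of earlier ones, and eventually converges toward the better-performing region.\n", - "\n", - "```{figure} ../img/ch-ray-train-tune/tune-algorithms.svg\n", - "---\n", - "width: 600px\n", - "name: fig-tune-algorithms\n", - "---\n", - "Hyperparameter search in a two-dimensional search space. Each point is one hyperparameter combination; warm colors indicate better performance and cool colors worse.\n", - "```\n", - "\n", - "### Bayesian Optimization\n", - "\n", - "Bayesian optimization is based on Bayes' theorem; we do not go into the detailed math here. In short, it first measures the actual performance at a few observation points in the search space and then builds a probabilistic model that describes, for every candidate value of each hyperparameter, the **mean** and **variance** of the model's performance metric. The mean is the expected outcome at that point: a larger mean means a larger expected metric. The variance is the uncertainty at that point: a larger variance means the point is less certain and worth exploring. {numref}`fig-bayesian-optimization-explained` shows three iterations in a one-dimensional hyperparameter search space. The dashed line is the true value of the objective, the solid line is the prediction (the mean of the posterior distribution), and the blue band around the solid line is the confidence interval. Bayesian optimization uses Gaussian process regression: the objective is treated as a random process sampled at the observation points, and a Gaussian probability model describes the distribution of that process. As more observations are collected, the posterior distribution of the objective is updated until it closely fits the true function. In {numref}`fig-bayesian-optimization-explained`, only two observations exist before iteration 3; iterations 3 and 4 add new observations, and the predictions near them move closer to the true values.\n", - "\n", - "Bayesian optimization has two core concepts:\n", - "\n", - "* Surrogate model: fits the observations and predicts actual performance; it corresponds to the solid line in the figure.\n", - "* Acquisition function: selects the next point to sample by scoring how much each point is expected to contribute to optimizing the objective; it corresponds to the orange line in the figure.\n", - "\n", - "To avoid getting stuck in a local optimum, the acquisition function should both exploit points with a large mean and explore points with a large variance when choosing the next point, balancing the two. For example, if training is very expensive and the remaining budget allows only one more hyperparameter combination, we should pick a point with a large mean, because it is the most likely to give the best result. If the budget still allows thousands of runs, we should explore more of the possibilities. In the example in {numref}`fig-bayesian-optimization-explained`, iterations 3 and 4 both pick new points near the observation from iteration 2, which strikes a balance between exploration and exploitation.\n", - "\n", - "```{figure} ../img/ch-ray-train-tune/bayesian-optimization-explained.svg\n", - "---\n", - "width: 600px\n", - "name: fig-bayesian-optimization-explained\n", - "---\n", - "How Bayesian optimization chooses the next point after a few iterations.\n", - "```\n", - "\n", -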
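"Compared with grid search and random search, Bayesian optimization is harder to parallelize, because it must first run some hyperparameter combinations to collect actual observations.\n", - "\n", - "## Schedulers\n", - "\n", - "### SHA and ASHA\n", - "\n", - "The Successive Halving Algorithm (SHA) {cite}`karnin2013Almost` is the foundation of ASHA. Its core idea is very simple, as {numref}`fig-successive-halving` shows:\n", - "\n", - "1. SHA starts by giving each hyperparameter combination a small compute budget.\n", - "2. After all combinations have trained on that budget, it evaluates the results.\n", - "3. It promotes the top-ranked combinations to the next round (rung) of training and stops the poorly performing ones early.\n", - "4. In the next rung, the budget per combination is increased according to a fixed policy.\n", - "\n", - "```{figure} ../img/ch-ray-train-tune/successive-halving.svg\n", - "---\n", - "width: 600px\n", - "name: fig-successive-halving\n", - "---\n", - "Illustration of SHA: minimizing a metric\n", - "```\n", - "\n", - "The compute budget can be the number of training iterations, the number of training samples, and so on. More precisely, in every rung SHA discards $\\frac{\\eta - 1}{\\eta}$ of the combinations and keeps $\\frac{1}{\\eta}$ for the next rung, where each surviving combination receives $\\eta$ times its previous budget. In {numref}`tab-sha-resources`, the total budget per rung is $B$ and there are 81 combinations in total. In the first rung each trial gets $\\frac{B}{81}$. With $\\eta = 3$, only $\\frac{1}{3}$ of the trials are promoted to the next rung, and after 5 rungs a single best combination remains.\n", - "\n", - "```{table} Compute budget allocated to each trial under SHA.\n", - ":name: tab-sha-resources\n", - "| \t| Number of combinations $n$ \t| Budget per trial $\\frac{B}{n}$ |\n", - "|:------:\t|:---:\t|:-----:\t|\n", - "| Rung 1 \t| 81 \t| $\\frac{B}{81}$ \t|\n", - "| Rung 2 \t| 27 \t| $\\frac{B}{27}$ \t|\n", - "| Rung 3 \t| 9 \t| $\\frac{B}{9}$ \t|\n", - "| Rung 4 \t| 3 \t| $\\frac{B}{3}$ \t|\n", - "| Rung 5 \t| 1 \t| $B$ \t|\n", - "```\n", - "\n", - "In SHA, a rung must finish training and evaluating all of its combinations before the next rung starts. The first rung can run many trials in parallel, but later rungs have fewer and fewer trials, so parallelism drops. ASHA improves on this: it does not wait for a rung to finish before choosing which combinations advance. Instead it promotes combinations to the next rung while the current rung is still running, so training and evaluation in adjacent rungs proceed concurrently.\n", - "\n", - "A key assumption behind SHA and ASHA is that a trial that performs well early on will also perform well over a longer run. This assumption is clearly too coarse. The learning rate is a counterexample: a larger learning rate may beat a smaller one in the short term without being the best choice in the long run, so an SHA scheduler may well terminate the small-learning-rate trial by mistake. Put another way, to keep potentially good trials from being stopped early, each trial needs more budget in the first rung; but because the total budget $B$ is limited, the compromise is to start with fewer combinations, that is, a smaller $n$.\n", - "\n", - "### HyperBand\n", - "\n", - "Algorithms such as SHA and ASHA must balance $n$ against $\\frac{B}{n}$. If $n$ is too large, each trial gets too little budget and good trials may be stopped early; if $n$ is too small, the search covers only a small part of the space and good configurations may never be sampled. HyperBand adds a hedging mechanism on top of SHA. Much like a financial portfolio that spreads risk over several assets, it does not fix a single $n$ for the first rung but tries several values of $n$. As {numref}`fig-hyperband-algo` shows, the implementation uses two nested loops: the inner loop simply runs SHA, and the outer loop tries different values of $n$, each indexed by a value of $s$. HyperBand also introduces $R$, the maximum budget that can be given to a single combination. The largest index, $s_{max}$, is computed as $\\lfloor \\log_{\\eta}{R} \\rfloor$. With $R$ in place, the total budget becomes $B = (s_{max} + 1)R$; the plus one is there because $s$ starts at 0.\n", - "\n", - "```{figure} ../img/ch-ray-train-tune/hyperband-algo.png\n", - "---\n", - "width: 600px\n", - "name: fig-hyperband-algo\n", - "---\n", - "The HyperBand algorithm\n", - "```\n", - "\n", - "{numref}`fig-hyperband-example` gives an example. The horizontal axis is the outer loop, with 5 possible values of $s$ (0 to 4); each pairs an initial number of combinations $n$ with a per-combination budget $r$ to form a bracket. The vertical axis is the inner loop: for a given bracket, SHA runs until the best trial is selected.\n", - "\n", - "```{figure} ../img/ch-ray-train-tune/hyperband-example.svg\n", - "---\n", - "width: 600px\n", - "name: fig-hyperband-example\n", - "---\n", - "Illustration of HyperBand\n", - "```\n", - "\n", - "### Example: Flight Delay Prediction\n", - "\n", - "This example uses flight departure and arrival data and trains XGBoost to predict whether a flight is delayed. XGBoost is a tree-based model whose training has many hyperparameters, such as tree depth." + "## Hyperband\n", + "\n" ] }, { @@ -120,7 +28,7 @@ "\n", "import sys\n", "sys.path.append(\"..\")\n", - "from datasets import nyc_flights\n", + "from utils import nyc_flights\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", @@ -628,22 +536,6 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### BOHB\n", - "\n", - "BOHB {cite}`falkner2018BOHB` is a scheduler that combines Bayesian optimization with Hyperband.\n", - "\n", - "### Population Based Training\n", - "\n", - "Population Based Training (PBT) {cite}`jaderberg2017Population` mainly targets deep neural network training. It borrows ideas from genetic algorithms and optimizes model parameters and hyperparameters at the same time. In PBT, the population can be thought of simply as the set of trials. PBT launches several trials in parallel; each one samples a random hyperparameter combination from the search space, randomly initializes its weights, and is evaluated periodically during training. Based on those metrics, PBT **exploits** or **explores** the model parameters or hyperparameters of a trial. When a trial's metric is poor, PBT exploits: it replaces the trial's weights with those of a better-performing member of the population. PBT also explores: it mutates the hyperparameters and continues training with the new values. Other tuning methods keep one hyperparameter combination for an entire training run, whereas PBT copies better weights or switches to new hyperparameters during training, which is why it is said to optimize model parameters and hyperparameters together.\n", - "\n", - "```{figure} ../img/ch-ray-train-tune/population-based-training.svg\n", - "---\n", - "width: 600px\n", -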
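"name: fig-population-based-training\n", - "---\n", - "Exploit and explore in PBT. Exploit: when a model performs poorly, replace its weights with those of a better-performing model. Explore: mutate the hyperparameters to generate new ones.\n", - "```\n", - "\n", "### Example: Image Classification with PBT\n", "\n", "PBT adjusts both the model weights and the hyperparameters during training, so the training code must include logic to load (update) model weights. The training loop differs too. Most PyTorch training code has an explicit loop such as `for epoch in range(...)` with a termination condition. A PBT training function sets no termination condition of its own: Ray Tune stops the trial when the metric reaches the target or the trial should be stopped early. The loop is therefore written as `while True` and runs until Ray Tune terminates it." @@ -2007,7 +1899,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.11.7" } }, "nbformat": 4, diff --git a/ch-ray-train-tune/ray-tune.ipynb b/ch-ray-train-tune/ray-tune.ipynb deleted file mode 100644 index d4dc538..0000000 --- a/ch-ray-train-tune/ray-tune.ipynb +++ /dev/null @@ -1,1348 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "(sec-ray-tune)=\n", - "# Ray Tune\n", - "\n", - "Ray Tune targets hyperparameter tuning, combining model training, hyperparameter selection, and parallel computing. Built on Ray's Actors, Tasks, and Ray Train, it launches multiple machine learning training jobs in parallel and selects the best hyperparameters. Ray Tune integrates with common training frameworks such as PyTorch, Keras, and XGBoost, and provides common tuning algorithms (for example random search and Bayesian optimization) and tools ([Hyperopt](https://github.com/hyperopt/hyperopt), [Optuna](https://github.com/optuna/optuna), and others). With Ray Tune, users can run hyperparameter tuning at scale on a Ray cluster.\n", - "\n", - "## Hyperparameter Tuning\n", - "\n", - "{numref}`sec-machine-learning-intro` introduced model parameters and hyperparameters. Hyperparameters are settings other than the model parameters (weights), such as the learning rate that controls the step size of gradient descent when training a deep learning model, or the number of branches in a decision tree. They usually fall into two groups:\n", - "\n", - "* Model: the design of the network, such as the number of layers, the kernel size of a convolutional neural network, or the number of branches in a decision tree.\n", - "* Training and algorithm: learning rate, batch size, and so on.\n", - "\n", - "Hyperparameters are chosen by running multiple trials, each testing particular hyperparameter values, and selecting among them by how well the trained model performs. This process is called hyperparameter tuning. It can be done by hand, but that is slow, laborious, and inefficient, so automated methods have been developed. Common automated search methods include:\n", - "\n", - "* Grid search: an exhaustive method that finds the best solution by iterating over every possible hyperparameter combination, training and evaluating a model for each. It is simple and intuitive, but its compute cost grows sharply as the hyperparameter space gets large.\n", - "* Random search: instead of iterating over every combination, it evaluates combinations sampled at random from the search space. It is usually more efficient than grid search because it explores the space by sampling rather than evaluating everything. It is especially suited to very large or high-dimensional spaces, where it can find a well-performing configuration in relatively few trials. Because it is random, however, it may miss some local optima, so more trials may be needed to be confident of finding a good solution.\n", - "* Bayesian optimization: a technique based on Bayes' theorem that uses a probabilistic model to guide the search for the best hyperparameters. Its core idea is to build a Bayesian model, usually a Gaussian process, that approximates the unknown parts of the objective function. Within a limited number of evaluations it picks the most promising combinations to try, which makes it particularly suitable when each evaluation is expensive.\n", - "\n", - "## Key Components\n", - "\n", - "Ray Tune consists mainly of the following components:\n", - "\n", - "* A Trainable: the existing training procedure wrapped as a trainable function.\n", - "* A search space: the hyperparameters to search over and their candidate values.\n", - "* A search algorithm and a scheduler: used to train in parallel and schedule trials intelligently.\n", - "\n", - "{numref}`fig-ray-tune-key-parts` shows the key parts of adapting code to Ray Tune. The user creates a [`Tuner`](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.Tuner.html) that holds the Trainable function and the hyperparameter search space, and chooses a search algorithm or a scheduler. Each hyperparameter combination forms a trial. Ray Tune runs trials in parallel according to the resources the user requests and the resources available in the cluster. The user can then analyze the results of the trials.\n", - "\n", - "```{figure} ../img/ch-ray-train-tune/ray-tune-key-parts.svg\n", - "---\n", - "width: 500px\n", - "name: fig-ray-tune-key-parts\n", - "---\n", - "Key parts of Ray Tune\n", - "```\n", - "\n", - "## Trainable\n", - "\n", - "Like other hyperparameter optimization libraries, Ray Tune needs an objective: the quantity it tries to optimize, usually a training metric such as prediction accuracy. The user wraps the objective in a Trainable function, which can be adapted from existing single-node training code. The Trainable function takes a dictionary-style config whose keys are the hyperparameters to search. Inside it, the objective is either reported with `ray.train.report(...)` or returned directly as the function's return value. For example, to tune the hyperparameter `lr` with `score` as the objective, the Trainable function looks like this, apart from the actual training code:\n", - "\n", - "```python\n", - "def trainable(config):\n", - " lr = config[\"lr\"]\n", - " \n", - " # Training code ...\n", - " \n", - " # Report the objective with ray.train.report\n", - " ray.train.report({\"score\": ...})\n", - " # Or return it directly with return or yield\n", - " return {\"score\": ...}\n", - "```\n", - "\n", - "### Example: Image Classification\n", - "\n", - "We adapt the image classification example to search over two hyperparameters, `lr` and `momentum`. The Trainable function is `train_mnist()` in the code below:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "tags": [ - "hide-cell" - ] - }, - "outputs": [], - "source": [ - "import os\n", - "import tempfile\n", - "\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn as nn\n", - "import torchvision\n", - "from torch.utils.data import DataLoader\n", - "from torchvision.models import resnet18\n", - "\n", -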
"from ray import tune\n", - "from ray.tune.schedulers import ASHAScheduler\n", - "\n", - "import ray\n", - "import ray.train.torch\n", - "from ray.train import Checkpoint" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "data_dir = os.path.join(os.getcwd(), \"../data\")\n", - "\n", - "def train_func(model, optimizer, criterion, train_loader):\n", - " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - " model.train()\n", - " for data, target in train_loader:\n", - " data, target = data.to(device), target.to(device)\n", - " output = model(data)\n", - " loss = criterion(output, target)\n", - " optimizer.zero_grad()\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - "\n", - "def test_func(model, data_loader):\n", - " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - " model.eval()\n", - " correct = 0\n", - " total = 0\n", - " with torch.no_grad():\n", - " for data, target in data_loader:\n", - " data, target = data.to(device), target.to(device)\n", - " outputs = model(data)\n", - " _, predicted = torch.max(outputs.data, 1)\n", - " total += target.size(0)\n", - " correct += (predicted == target).sum().item()\n", - "\n", - " return correct / total\n", - "\n", - "def train_mnist(config):\n", - " transform = torchvision.transforms.Compose(\n", - " [torchvision.transforms.ToTensor(), \n", - " torchvision.transforms.Normalize((0.5,), (0.5,))]\n", - " )\n", - "\n", - " train_loader = DataLoader(\n", - " torchvision.datasets.FashionMNIST(root=data_dir, train=True, download=True, transform=transform),\n", - " batch_size=128,\n", - " shuffle=True)\n", - " test_loader = DataLoader(\n", - " torchvision.datasets.FashionMNIST(root=data_dir, train=False, download=True, transform=transform),\n", - " batch_size=128,\n", - " shuffle=True)\n", - "\n", - " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "\n", - " model = 
resnet18(num_classes=10)\n", - " model.conv1 = torch.nn.Conv2d(\n", - " 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False\n", - " )\n", - " model.to(device)\n", - "\n", - " criterion = nn.CrossEntropyLoss()\n", - "\n", - " optimizer = torch.optim.SGD(\n", - " model.parameters(), lr=config[\"lr\"], momentum=config[\"momentum\"])\n", - " \n", - " # Train for 10 epochs\n", - " for epoch in range(10):\n", - " train_func(model, optimizer, criterion, train_loader)\n", - " acc = test_func(model, test_loader)\n", - "\n", - " with tempfile.TemporaryDirectory() as temp_checkpoint_dir:\n", - " checkpoint = None\n", - " if (epoch + 1) % 5 == 0:\n", - " torch.save(\n", - " model.state_dict(),\n", - " os.path.join(temp_checkpoint_dir, \"model.pth\")\n", - " )\n", - " checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)\n", - "\n", - " ray.train.report({\"mean_accuracy\": acc}, checkpoint=checkpoint)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Ray Tune supports defining a Trainable either as a function or as a class. This example uses the function form; {numref}`tab-function-class-trainable` lists the differences between the two.\n", - "\n", - "```{table} Differences between function-based and class-based Trainables\n", - ":name: tab-function-class-trainable\n", - "| Item | Function API | Class API |\n", - "| :----------------: | :----------------------------------------: | :----------------------------------------: |\n", - "| One training iteration | Each call to `train.report` increments the iteration count by one | Each call to `Trainable.step()` increments the iteration count by one |\n", - "| Reporting metrics | Call `train.report(metrics)` | Return them from `Trainable.step()` |\n", - "| Writing a checkpoint | Call `train.report(..., checkpoint=checkpoint)` | Implement `Trainable.save_checkpoint()` |\n", - "| Reading a checkpoint | Call `train.get_checkpoint()` | Implement `Trainable.load_checkpoint()` |\n", - "| Reading the hyperparameter combination | Passed in as `config` in `def train_func(config):` | Passed in as `config` in `Trainable.setup(self, config)` |\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Search Space\n", - "\n", - "The search space is the set of values a hyperparameter may take. Ray Tune provides functions for defining it. For example, [`ray.tune.choice()`](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.choice.html) samples from a given list of candidate values, and [`ray.tune.uniform()`](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.uniform.html) samples from a uniform distribution. We now define the search space for the two hyperparameters `lr` and `momentum`:" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "search_space = {\n", - " \"lr\": tune.choice([0.001, 0.002, 0.005, 0.01, 0.02, 0.05]),\n", - " \"momentum\": tune.uniform(0.1, 0.9),\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Search Algorithms and Schedulers\n", - "\n", - "Ray Tune has some search algorithms built in and integrates commonly used packages such as [Hyperopt](https://github.com/hyperopt/hyperopt) and [Optuna](https://github.com/optuna/optuna), covering methods such as Bayesian optimization. If nothing is configured, random search is used by default.\n", - "\n", - "We start with random search:" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
Tune Status\n", - "Current time: 2024-04-11 21:24:22\n", - "Running for: 00:03:56.63\n", - "Memory: 12.8/90.0 GiB\n", - "\n", - "System Info\n", - "Using FIFO scheduling algorithm.\n", - "Logical resource usage: 0/64 CPUs, 1.0/4 GPUs (0.0/1.0 accelerator_type:TITAN)\n", - "\n", - "Trial Status\n", - "| Trial name | status | loc | lr | momentum | acc | iter | total time (s) |\n", - "|:---|:---|:---|---:|---:|---:|---:|---:|\n", - "| train_mnist_421ef_00000 | TERMINATED | 10.0.0.3:41485 | 0.002 | 0.29049 | 0.8686 | 10 | 228.323 |
\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[36m(train_mnist pid=41485)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-20-25/train_mnist_421ef_00000_0_lr=0.0020,momentum=0.2905_2024-04-11_21-20-26/checkpoint_000000)\n", - "\u001b[36m(train_mnist pid=41485)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-20-25/train_mnist_421ef_00000_0_lr=0.0020,momentum=0.2905_2024-04-11_21-20-26/checkpoint_000001)\n", - "2024-04-11 21:24:22,692\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", - "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", - "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", - "2024-04-11 21:24:22,696\tINFO tune.py:1016 -- Wrote the latest version of all result files and experiment state to '/home/u20200002/ray_results/train_mnist_2024-04-11_21-20-25' in 0.0079s.\n", - "2024-04-11 21:24:22,706\tINFO tune.py:1048 -- Total run time: 236.94 seconds (236.62 seconds for the tuning loop).\n" - ] - } - ], - "source": [ - "trainable_with_gpu = tune.with_resources(train_mnist, {\"gpu\": 1})\n", - "\n", - "tuner = tune.Tuner(\n", - " trainable_with_gpu,\n", - " param_space=search_space,\n", - ")\n", - "results = tuner.fit()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The scheduler analyzes each trial and may end a trial before it finishes training. This is called early stopping; it saves compute and leaves the resources to the most promising trials. The following example uses the [ASHA algorithm](https://openreview.net/forum?id=S1Y7OOlRZ) {cite}`li2018Massively` for scheduling.\n", - "\n", - "In the previous example [`ray.tune.TuneConfig`](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.TuneConfig.html) was not set, so only one trial ran by default. We now set `num_samples` in `ray.tune.TuneConfig`, which specifies how many trials to run. We also use [ray.tune.schedulers.ASHAScheduler](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.schedulers.ASHAScheduler.html) to stop poorly performing trials early and leave the compute resources to more promising ones. The `metric` and `mode` arguments of `ASHAScheduler` specify the objective; here the goal is to maximize `mean_accuracy`." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[36m(train_mnist pid=41806)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00002_2_lr=0.0020,momentum=0.4789_2024-04-11_21-24-22/checkpoint_000000)\n", - "\u001b[36m(train_mnist pid=42212)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00005_5_lr=0.0050,momentum=0.8187_2024-04-11_21-24-22/checkpoint_000000)\u001b[32m [repeated 2x across cluster] (Ray deduplicates logs by default.
Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)\u001b[0m\n", - "\u001b[36m(train_mnist pid=41806)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00002_2_lr=0.0020,momentum=0.4789_2024-04-11_21-24-22/checkpoint_000001)\n", - "\u001b[36m(train_mnist pid=41809)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00003_3_lr=0.0100,momentum=0.4893_2024-04-11_21-24-22/checkpoint_000001)\n", - "\u001b[36m(train_mnist pid=42394)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00006_6_lr=0.0200,momentum=0.1573_2024-04-11_21-24-22/checkpoint_000000)\n", - "\u001b[36m(train_mnist pid=42212)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00005_5_lr=0.0050,momentum=0.8187_2024-04-11_21-24-22/checkpoint_000001)\n", - "\u001b[36m(train_mnist pid=42619)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00008_8_lr=0.0500,momentum=0.7167_2024-04-11_21-24-22/checkpoint_000000)\n", - "\u001b[36m(train_mnist pid=42394)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00006_6_lr=0.0200,momentum=0.1573_2024-04-11_21-24-22/checkpoint_000001)\n", - "\u001b[36m(train_mnist pid=42619)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, 
path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00008_8_lr=0.0500,momentum=0.7167_2024-04-11_21-24-22/checkpoint_000001)\n", - "\u001b[36m(train_mnist pid=43231)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00014_14_lr=0.0200,momentum=0.7074_2024-04-11_21-24-22/checkpoint_000000)\n", - "\u001b[36m(train_mnist pid=43231)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22/train_mnist_cf61b_00014_14_lr=0.0200,momentum=0.7074_2024-04-11_21-24-22/checkpoint_000001)\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-11 21:34:30,051\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", - "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", - "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", - "2024-04-11 21:34:30,067\tINFO tune.py:1016 -- Wrote the latest version of all result files and experiment state to '/home/u20200002/ray_results/train_mnist_2024-04-11_21-24-22' in 0.0224s.\n", - "2024-04-11 21:34:30,083\tINFO tune.py:1048 -- Total run time: 607.32 seconds (607.25 seconds for the tuning loop).\n" - ] - } - ], - "source": [ - "tuner = tune.Tuner(\n", - " trainable_with_gpu,\n", - " tune_config=tune.TuneConfig(\n", - " num_samples=16,\n", - " scheduler=ASHAScheduler(metric=\"mean_accuracy\", mode=\"max\"),\n", - " ),\n", - " param_space=search_space,\n", - ")\n", - "results = tuner.fit()" - 
] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The output lists the hyperparameter values chosen for each trial and the resulting metric. Poorly performing trials are stopped early after only a few iterations (`iter`). We now visualize the results of these trials:" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Text(0, 0.5, 'Mean Accuracy')" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/svg+xml": [ - "[SVG markup lost in extraction: Matplotlib v3.8.4 figure dated 2024-04-11T21:34:31.189696, plotting Mean Accuracy against Epochs, one line per trial]\n" - ], - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "%config InlineBackend.figure_format = 'svg'\n", - "\n", - "dfs = {result.path: result.metrics_dataframe for result in results}\n", - "ax = None\n", - "for d in dfs.values():\n", - " ax = d.mean_accuracy.plot(ax=ax, legend=False)\n", - "ax.set_xlabel(\"Epochs\")\n", - "ax.set_ylabel(\"Mean Accuracy\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "以上为 Ray Tune 的完整案例,接下来我们介绍搜索算法和调度器。" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/drawio/ch-ray-train-tune/bayesian-optimization-explained.drawio b/drawio/ch-data-science/bayesian-optimization-explained.drawio similarity index 96% rename from drawio/ch-ray-train-tune/bayesian-optimization-explained.drawio rename to drawio/ch-data-science/bayesian-optimization-explained.drawio index a936da5..218041c 100644 --- a/drawio/ch-ray-train-tune/bayesian-optimization-explained.drawio +++ b/drawio/ch-data-science/bayesian-optimization-explained.drawio @@ -1,111 +1,123 @@ - + - + - - + + - + - + - + - + - + - + - + - + - + - + - + - + - - + + - - - - + - + - + - + - - + + - + - - + + - + - - + + - + - - + + - + - + - + + + + + + + + + + + + + + + + diff --git a/drawio/ch-ray-train-tune/hyperband-example.drawio b/drawio/ch-data-science/hyperband-example.drawio similarity index 72% rename from drawio/ch-ray-train-tune/hyperband-example.drawio rename to drawio/ch-data-science/hyperband-example.drawio index e44e029..af42100 100644 --- a/drawio/ch-ray-train-tune/hyperband-example.drawio +++ 
b/drawio/ch-data-science/hyperband-example.drawio @@ -1,262 +1,262 @@ [drawio XML hunk body elided] diff --git a/drawio/ch-ray-train-tune/population-based-training.drawio b/drawio/ch-data-science/population-based-training.drawio similarity index 100% rename from drawio/ch-ray-train-tune/population-based-training.drawio rename to drawio/ch-data-science/population-based-training.drawio diff --git a/drawio/ch-ray-train-tune/successive-halving.drawio b/drawio/ch-data-science/successive-halving.drawio similarity index 79% rename from drawio/ch-ray-train-tune/successive-halving.drawio rename to drawio/ch-data-science/successive-halving.drawio index 7321bce..94f7ad5 100644 --- a/drawio/ch-ray-train-tune/successive-halving.drawio +++ b/drawio/ch-data-science/successive-halving.drawio @@ -1,151 +1,151 @@ [drawio XML hunk body elided] diff --git a/drawio/ch-ray-train-tune/tune-algorithms.drawio b/drawio/ch-data-science/tune-algorithms.drawio similarity index 98% rename from drawio/ch-ray-train-tune/tune-algorithms.drawio rename to drawio/ch-data-science/tune-algorithms.drawio index 0df18f6..e85c591 100644 --- a/drawio/ch-ray-train-tune/tune-algorithms.drawio +++ b/drawio/ch-data-science/tune-algorithms.drawio @@ -1,106 +1,106 @@ [drawio XML hunk body elided] diff --git a/drawio/ch-ray-train-tune/Pythonplot.py b/drawio/ch-ray-train-tune/Pythonplot.py index 2dfe87d..2421ac0 100644 --- a/drawio/ch-ray-train-tune/Pythonplot.py +++ b/drawio/ch-ray-train-tune/Pythonplot.py @@ -25,4 +25,4
@@ rect = patches.Rectangle((-0.5, -0.5), harvest.shape[1], harvest.shape[0], linewidth=2, edgecolor='black', facecolor='none') ax.add_patch(rect) -fig.savefig(r'C:\Users\LY\Desktop\figure.svg', format="svg") +fig.savefig(r'./figure.svg', format="svg") diff --git a/img/ch-data-science/bayesian-optimization-explained.svg b/img/ch-data-science/bayesian-optimization-explained.svg new file mode 100644 index 0000000..833c848 --- /dev/null +++ b/img/ch-data-science/bayesian-optimization-explained.svg @@ -0,0 +1,4 @@ + + + +
[Figure text labels (in Chinese) for the Bayesian optimization diagram: Iteration 3, Iteration 4, Iteration 5; observed sample points; acquisition function maximum; objective function; new observed sample point; acquisition function; variance; mean]
\ No newline at end of file diff --git a/img/ch-ray-train-tune/hyperband-algo.png b/img/ch-data-science/hyperband-algo.png similarity index 100% rename from img/ch-ray-train-tune/hyperband-algo.png rename to img/ch-data-science/hyperband-algo.png diff --git a/img/ch-data-science/hyperband-example.svg b/img/ch-data-science/hyperband-example.svg new file mode 100644 index 0000000..03dd945 --- /dev/null +++ b/img/ch-data-science/hyperband-example.svg @@ -0,0 +1,4 @@ + + + +
[Figure text labels (in Chinese) for the Hyperband example: bracket headings $$s=0$$ through $$s=4$$; row indices 0 to 4; numeric cell values drawn from 1, 2, 3, 5, 6, 9, 27, 81; captions "outer loop" and "inner loop..." (truncated in the source)]
\ No newline at end of file diff --git a/img/ch-ray-train-tune/population-based-training.svg b/img/ch-data-science/population-based-training.svg similarity index 100% rename from img/ch-ray-train-tune/population-based-training.svg rename to img/ch-data-science/population-based-training.svg diff --git a/img/ch-data-science/successive-halving.svg b/img/ch-data-science/successive-halving.svg new file mode 100644 index 0000000..b1e030a --- /dev/null +++ b/img/ch-data-science/successive-halving.svg @@ -0,0 +1,4 @@ + + + +
[Figure text labels (in Chinese) for the successive halving diagram: "optimization objective"; "compute resource budget"; ticks 0, 12.5%, 25%, 50%, 100%, 1]
\ No newline at end of file diff --git a/img/ch-data-science/tune-algorithms.svg b/img/ch-data-science/tune-algorithms.svg new file mode 100644 index 0000000..a573d8d --- /dev/null +++ b/img/ch-data-science/tune-algorithms.svg @@ -0,0 +1,4 @@ + + + +
[Figure text labels (in Chinese) for the tuning algorithms diagram: panel titles "grid search", "random search", "iterative"; trial numbers 1 to 9 in each panel]
\ No newline at end of file diff --git a/img/ch-ray-train-tune/bayesian-optimization-explained.svg b/img/ch-ray-train-tune/bayesian-optimization-explained.svg deleted file mode 100644 index 7d5d478..0000000 --- a/img/ch-ray-train-tune/bayesian-optimization-explained.svg +++ /dev/null @@ -1,3 +0,0 @@ - - -
[Figure text labels of the deleted English diagram: Iteration 3, Iteration 4; observation; acquisition max; objective function; new observation; acquisition function; posterior uncertainty; posterior mean]
\ No newline at end of file diff --git a/img/ch-ray-train-tune/successive-halving.svg b/img/ch-ray-train-tune/successive-halving.svg deleted file mode 100644 index 0046c73..0000000 --- a/img/ch-ray-train-tune/successive-halving.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - -
[Figure text labels of the deleted English diagram: "loss"; "budget"; ticks 0, 12.5%, 25%, 50%, 100%, 1]
\ No newline at end of file diff --git a/img/ch-ray-train-tune/tune-algorithms.svg b/img/ch-ray-train-tune/tune-algorithms.svg deleted file mode 100644 index 0307a77..0000000 --- a/img/ch-ray-train-tune/tune-algorithms.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - - 2024-05-02T09:16:39.498287 image/svg+xml Matplotlib v3.5.3, https://matplotlib.org/ 2024-05-02T09:16:39.498287 image/svg+xml Matplotlib v3.5.3, https://matplotlib.org/ 2024-05-02T09:16:39.498287 image/svg+xml Matplotlib v3.5.3, https://matplotlib.org/
[Figure text labels of the deleted English diagram: panel titles "Grid Search", "Random Search", "Adaptive Selection"; trial numbers 1 to 9]
\ No newline at end of file diff --git a/datasets.py b/utils.py similarity index 100% rename from datasets.py rename to utils.py