From 93103b739c3db39605c5358be49471b56c8df233 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Thu, 19 Dec 2024 17:37:54 -0600 Subject: [PATCH 1/3] docs: tutorial on reusing fine-tuned models --- .../061_reusing_finetuned_models.ipynb | 941 ++++++++++++++++++ nbs/mint.json | 1 + 2 files changed, 942 insertions(+) create mode 100644 nbs/docs/tutorials/061_reusing_finetuned_models.ipynb diff --git a/nbs/docs/tutorials/061_reusing_finetuned_models.ipynb b/nbs/docs/tutorials/061_reusing_finetuned_models.ipynb new file mode 100644 index 00000000..613de388 --- /dev/null +++ b/nbs/docs/tutorials/061_reusing_finetuned_models.ipynb @@ -0,0 +1,941 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "27371399-17ac-4fcf-8e2d-19091b32cdf7", + "metadata": {}, + "outputs": [], + "source": [ + "#|hide\n", + "#| eval: false\n", + "! [ -e /content ] && pip install -Uqq nixtla" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e428575b-700a-49a6-a0a9-6fa884119d86", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "from nixtla.utils import in_colab" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fba11152-1fbb-43b5-b6c7-ccb5ff688ce2", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide \n", + "IN_COLAB = in_colab()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0438f77-6a7e-400d-8739-09c9e347dcac", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "if not IN_COLAB:\n", + " from nixtla.utils import colab_badge\n", + " from dotenv import load_dotenv" + ] + }, + { + "cell_type": "markdown", + "id": "d4bcec3f-9ffe-41e0-a38b-92e77e460154", + "metadata": {}, + "source": [ + "# Re-using fine-tuned models\n", + "\n", + "Save and re-use models fine-tuned models across all of our endpoints." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56e9125c-53b3-41e4-bace-e920fb827c06", + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Nixtla/nixtla/blob/main/nbs/docs/tutorials/061_reusing_finetuned_models.ipynb.ipynb)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "#| echo: false\n", + "if not IN_COLAB:\n", + " load_dotenv() \n", + " colab_badge('docs/tutorials/061_reusing_finetuned_models.ipynb')" + ] + }, + { + "cell_type": "markdown", + "id": "c7eb9fc0-4541-4c1e-8ffe-442d115fd638", + "metadata": {}, + "source": [ + "## 1. Import packages\n", + "First, we import the required packages and initialize the Nixtla client" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89c80a4a-645d-43f9-9454-415a98685105", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from nixtla import NixtlaClient\n", + "from utilsforecast.losses import rmse\n", + "from utilsforecast.evaluation import evaluate" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "73d7516f-2a78-4be1-972e-41cb70800bcd", + "metadata": {}, + "outputs": [], + "source": [ + "nixtla_client = NixtlaClient(\n", + " # defaults to os.environ[\"NIXTLA_API_KEY\"]\n", + " api_key = 'my_api_key_provided_by_nixtla'\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a60ca743-7d68-4d4b-af72-10f63dbf5b26", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "if not IN_COLAB:\n", + " nixtla_client = NixtlaClient()" + ] + }, + { + "cell_type": "markdown", + "id": "83ca8dec-ca2a-4e9f-8983-886208423769", + "metadata": {}, + "source": [ + "## 2. Load data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb5ef6b1-4756-4f79-8609-12f051503431", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
unique_iddsy
0H11605.0
1H12586.0
2H13586.0
3H14559.0
4H15511.0
\n", + "
" + ], + "text/plain": [ + " unique_id ds y\n", + "0 H1 1 605.0\n", + "1 H1 2 586.0\n", + "2 H1 3 586.0\n", + "3 H1 4 559.0\n", + "4 H1 5 511.0" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_parquet('https://datasets-nixtla.s3.amazonaws.com/m4-hourly.parquet')\n", + "\n", + "h = 48\n", + "valid = df.groupby('unique_id', observed=True).tail(h)\n", + "train = df.drop(valid.index)\n", + "train.head()" + ] + }, + { + "cell_type": "markdown", + "id": "f7b61f18-64a3-4b7f-8f86-76a78d6a0c0c", + "metadata": {}, + "source": [ + "## 3. Zero-shot forecast\n", + "\n", + "We can try forecasting without any finetuning to see how well TimeGPT does." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3e60cbbe-2710-4a7b-a453-27e52bf8b32b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:nixtla.nixtla_client:Validating inputs...\n", + "INFO:nixtla.nixtla_client:Preprocessing dataframes...\n", + "INFO:nixtla.nixtla_client:Querying model metadata...\n", + "WARNING:nixtla.nixtla_client:The specified horizon \"h\" exceeds the model horizon, this may lead to less accurate forecasts. Please consider using a smaller horizon.\n", + "INFO:nixtla.nixtla_client:Restricting input...\n", + "INFO:nixtla.nixtla_client:Calling Forecast Endpoint...\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
metricTimeGPT
0rmse1504.474342
\n", + "
" + ], + "text/plain": [ + " metric TimeGPT\n", + "0 rmse 1504.474342" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fcst_kwargs = {'df': train, 'freq': 1, 'model': 'timegpt-1-long-horizon'}\n", + "fcst = nixtla_client.forecast(h=h, **fcst_kwargs)\n", + "zero_shot_eval = evaluate(fcst.merge(valid), metrics=[rmse], agg_fn='mean')\n", + "zero_shot_eval" + ] + }, + { + "cell_type": "markdown", + "id": "f966407c-9c7d-4bce-8d6c-31870e00e7b5", + "metadata": {}, + "source": [ + "## 4. Fine-tune\n", + "\n", + "We can now fine-tune TimeGPT a little and save our model for later use. We can define the ID that we want that model to have by providing it through `output_model_id`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ffd8395-c30c-4522-b597-349a9d3a4b2e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:nixtla.nixtla_client:Validating inputs...\n", + "INFO:nixtla.nixtla_client:Preprocessing dataframes...\n", + "INFO:nixtla.nixtla_client:Calling Fine-tune Endpoint...\n" + ] + }, + { + "data": { + "text/plain": [ + "'my-first-finetuned-model'" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "first_model_id = 'my-first-finetuned-model'\n", + "nixtla_client.finetune(output_model_id=first_model_id, **fcst_kwargs)" + ] + }, + { + "cell_type": "markdown", + "id": "1198429a-5518-43a3-bd73-2fa5d1f48cc3", + "metadata": {}, + "source": [ + "We can now forecast using this fine-tuned model by providing its ID through the `finetuned_model_id` argument." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb996e6a-37e1-44ea-af8d-3b71cf6276ae", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:nixtla.nixtla_client:Validating inputs...\n", + "INFO:nixtla.nixtla_client:Preprocessing dataframes...\n", + "WARNING:nixtla.nixtla_client:The specified horizon \"h\" exceeds the model horizon, this may lead to less accurate forecasts. Please consider using a smaller horizon.\n", + "INFO:nixtla.nixtla_client:Restricting input...\n", + "INFO:nixtla.nixtla_client:Calling Forecast Endpoint...\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
metricTimeGPT
0rmse1472.024619
\n", + "
" + ], + "text/plain": [ + " metric TimeGPT\n", + "0 rmse 1472.024619" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "first_finetune_fcst = nixtla_client.forecast(h=h, finetuned_model_id=first_model_id, **fcst_kwargs)\n", + "first_finetune_eval = evaluate(first_finetune_fcst.merge(valid), metrics=[rmse], agg_fn='mean')\n", + "first_finetune_eval" + ] + }, + { + "cell_type": "markdown", + "id": "fb763ee8-07c0-4a6b-85dd-deb6c8216ddd", + "metadata": {}, + "source": [ + "We can see the error was reduced." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "97cd9f61-9c51-4db0-bf7b-96c151120b3f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
metricTimeGPT_zero_shotTimeGPT_first_finetune
0rmse1504.4743421472.024619
\n", + "
" + ], + "text/plain": [ + " metric TimeGPT_zero_shot TimeGPT_first_finetune\n", + "0 rmse 1504.474342 1472.024619" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "zero_shot_eval.merge(first_finetune_eval, on=['metric'], suffixes=('_zero_shot', '_first_finetune'))" + ] + }, + { + "cell_type": "markdown", + "id": "4b97ad55-a82c-4dd2-878c-40e2e9bf8945", + "metadata": {}, + "source": [ + "## 5. Further fine-tune\n", + "\n", + "We can now take this model and fine-tune it a bit further by using the `NixtlaClient.finetune` method but providing our already fine-tuned model as `finetuned_model_id`, which will take that model and fine-tune it a bit more. We can also change the fine-tuning settings, like using `finetune_depth=3`, for example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "99ede33c-379b-4569-8e1a-996abbe8576e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:nixtla.nixtla_client:Validating inputs...\n", + "INFO:nixtla.nixtla_client:Preprocessing dataframes...\n", + "INFO:nixtla.nixtla_client:Calling Fine-tune Endpoint...\n" + ] + }, + { + "data": { + "text/plain": [ + "'c47d99c1-6fdc-4b82-82d3-6ae9954c17cd'" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "second_model_id = nixtla_client.finetune(finetuned_model_id=first_model_id, finetune_depth=3, **fcst_kwargs)\n", + "second_model_id" + ] + }, + { + "cell_type": "markdown", + "id": "70f0cab5-7b01-4d2d-8afe-0a2317644eed", + "metadata": {}, + "source": [ + "Since we didn't provide `output_model_id` this time, it got assigned an UUID.\n", + "\n", + "We can now use this model to forecast." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6edc0794-598d-46ef-890e-23bfa66f20b2", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:nixtla.nixtla_client:Validating inputs...\n", + "INFO:nixtla.nixtla_client:Preprocessing dataframes...\n", + "WARNING:nixtla.nixtla_client:The specified horizon \"h\" exceeds the model horizon, this may lead to less accurate forecasts. Please consider using a smaller horizon.\n", + "INFO:nixtla.nixtla_client:Restricting input...\n", + "INFO:nixtla.nixtla_client:Calling Forecast Endpoint...\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
metricTimeGPT
0rmse1435.365211
\n", + "
" + ], + "text/plain": [ + " metric TimeGPT\n", + "0 rmse 1435.365211" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "second_finetune_fcst = nixtla_client.forecast(h=h, finetuned_model_id=second_model_id, **fcst_kwargs)\n", + "second_finetune_eval = evaluate(second_finetune_fcst.merge(valid), metrics=[rmse], agg_fn='mean')\n", + "second_finetune_eval" + ] + }, + { + "cell_type": "markdown", + "id": "04184bf1-de6c-42c4-93be-8959a78b8d24", + "metadata": {}, + "source": [ + "We can see the error was reduced a bit more." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "78e09d44-957b-46fb-9751-716ebc3f63c8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
metricTimeGPT_first_finetuneTimeGPT_second_finetune
0rmse1472.0246191435.365211
\n", + "
" + ], + "text/plain": [ + " metric TimeGPT_first_finetune TimeGPT_second_finetune\n", + "0 rmse 1472.024619 1435.365211" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "first_finetune_eval.merge(second_finetune_eval, on=['metric'], suffixes=('_first_finetune', '_second_finetune'))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "611406fe-2379-4b92-bdd4-5f9a86438d91", + "metadata": {}, + "source": [ + "## 6. Listing fine-tuned models\n", + "\n", + "We can list our fine-tuned models with the `NixtlaClient.finetuned_models` method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c9648bb4-74ad-4a94-8c8a-74625e9795d7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[FinetunedModel(id='c47d99c1-6fdc-4b82-82d3-6ae9954c17cd', created_at=datetime.datetime(2024, 12, 19, 23, 31, 44, 927053, tzinfo=TzInfo(UTC)), created_by='user', base_model_id='my-first-finetuned-model', steps=10, depth=3, loss='default', model='timegpt-1-long-horizon', freq='MS'),\n", + " FinetunedModel(id='my-first-finetuned-model', created_at=datetime.datetime(2024, 12, 19, 23, 31, 38, 889312, tzinfo=TzInfo(UTC)), created_by='user', base_model_id='None', steps=10, depth=1, loss='default', model='timegpt-1-long-horizon', freq='MS')]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "finetuned_models = nixtla_client.finetuned_models()\n", + "finetuned_models" + ] + }, + { + "cell_type": "markdown", + "id": "95e591c8-80b0-43f8-afed-dfa760597af8", + "metadata": {}, + "source": [ + "While that representation may be useful for programmatic use, in this exploratory setting it's nicer to see them as a dataframe." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0cacc468-0aa3-42af-85d9-7c31bfd2a4f3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idcreated_atcreated_bybase_model_idstepsdepthlossmodelfreq
0c47d99c1-6fdc-4b82-82d3-6ae9954c17cd2024-12-19 23:31:44.927053+00:00usermy-first-finetuned-model103defaulttimegpt-1-long-horizonMS
1my-first-finetuned-model2024-12-19 23:31:38.889312+00:00userNone101defaulttimegpt-1-long-horizonMS
\n", + "
" + ], + "text/plain": [ + " id created_at \\\n", + "0 c47d99c1-6fdc-4b82-82d3-6ae9954c17cd 2024-12-19 23:31:44.927053+00:00 \n", + "1 my-first-finetuned-model 2024-12-19 23:31:38.889312+00:00 \n", + "\n", + " created_by base_model_id steps depth loss \\\n", + "0 user my-first-finetuned-model 10 3 default \n", + "1 user None 10 1 default \n", + "\n", + " model freq \n", + "0 timegpt-1-long-horizon MS \n", + "1 timegpt-1-long-horizon MS " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.DataFrame([model.model_dump() for model in finetuned_models])" + ] + }, + { + "cell_type": "markdown", + "id": "9697c759-1b08-4192-a14f-5df1fdb03191", + "metadata": {}, + "source": [ + "We can seee that the `base_model_id` of our second model is our first model, along with other metadata." + ] + }, + { + "cell_type": "markdown", + "id": "eae29db5-de09-4954-9352-4f22eb0c3675", + "metadata": {}, + "source": [ + "## 7. Deleting fine-tuned models\n", + "\n", + "In order to keep things organized, and since there's a limit on the number of fine-tuned models, you can delete models that weren't so promising to make room for more experiments. For example, we can delete our first finetuned model. Note that even though it was used as the base for our second model, they're saved independently so removing it won't affect our second model, except for the dangling metadata." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7232bc3b-9096-4875-978a-430b7627688f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nixtla_client.delete_finetuned_model(first_model_id)" + ] + }, + { + "cell_type": "markdown", + "id": "0973b161-368f-4681-8447-c87537a46583", + "metadata": {}, + "source": [ + "We can verify that our first model model doesn't show up anymore in our available models." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b80edea-8926-4a13-8fb8-ec9bbcf4d575", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idcreated_atcreated_bybase_model_idstepsdepthlossmodelfreq
0c47d99c1-6fdc-4b82-82d3-6ae9954c17cd2024-12-19 23:31:44.927053+00:00usermy-first-finetuned-model103defaulttimegpt-1-long-horizonMS
\n", + "
" + ], + "text/plain": [ + " id created_at \\\n", + "0 c47d99c1-6fdc-4b82-82d3-6ae9954c17cd 2024-12-19 23:31:44.927053+00:00 \n", + "\n", + " created_by base_model_id steps depth loss \\\n", + "0 user my-first-finetuned-model 10 3 default \n", + "\n", + " model freq \n", + "0 timegpt-1-long-horizon MS " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.DataFrame([model.model_dump() for model in nixtla_client.finetuned_models()])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/mint.json b/nbs/mint.json index c4e158c7..54951815 100644 --- a/nbs/mint.json +++ b/nbs/mint.json @@ -95,6 +95,7 @@ "group":"Fine-tuning", "pages":[ "docs/tutorials/finetuning.html", + "docs/tutorials/reusing_finetuned_models.html", "docs/tutorials/loss_function_finetuning.html", "docs/tutorials/finetune_depth_finetuning.html" ] From aa3a7698eafcc44f0da34dbfea520560aa13a126 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Thu, 19 Dec 2024 17:47:01 -0600 Subject: [PATCH 2/3] enhancements --- .../061_reusing_finetuned_models.ipynb | 162 +++--------------- 1 file changed, 26 insertions(+), 136 deletions(-) diff --git a/nbs/docs/tutorials/061_reusing_finetuned_models.ipynb b/nbs/docs/tutorials/061_reusing_finetuned_models.ipynb index 613de388..0f24787d 100644 --- a/nbs/docs/tutorials/061_reusing_finetuned_models.ipynb +++ b/nbs/docs/tutorials/061_reusing_finetuned_models.ipynb @@ -54,7 +54,7 @@ "source": [ "# Re-using fine-tuned models\n", "\n", - "Save and re-use models fine-tuned models across all of our endpoints." + "Save and re-use fine-tuned models across all of our endpoints." ] }, { @@ -396,13 +396,15 @@ " \n", " \n", " metric\n", - " TimeGPT\n", + " TimeGPT_zero_shot\n", + " TimeGPT_first_finetune\n", " \n", " \n", " \n", " \n", " 0\n", " rmse\n", + " 1504.474342\n", " 1472.024619\n", " \n", " \n", @@ -410,8 +412,8 @@ "" ], "text/plain": [ - " metric TimeGPT\n", - "0 rmse 1472.024619" + " metric TimeGPT_zero_shot TimeGPT_first_finetune\n", + "0 rmse 1504.474342 1472.024619" ] }, "execution_count": null, @@ -422,7 +424,7 @@ "source": [ "first_finetune_fcst = nixtla_client.forecast(h=h, finetuned_model_id=first_model_id, **fcst_kwargs)\n", "first_finetune_eval = evaluate(first_finetune_fcst.merge(valid), metrics=[rmse], agg_fn='mean')\n", - "first_finetune_eval" + "zero_shot_eval.merge(first_finetune_eval, on=['metric'], suffixes=('_zero_shot', '_first_finetune'))" ] }, { @@ -433,63 +435,6 @@ "We can see the error was reduced." ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "97cd9f61-9c51-4db0-bf7b-96c151120b3f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
metricTimeGPT_zero_shotTimeGPT_first_finetune
0rmse1504.4743421472.024619
\n", - "
" - ], - "text/plain": [ - " metric TimeGPT_zero_shot TimeGPT_first_finetune\n", - "0 rmse 1504.474342 1472.024619" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "zero_shot_eval.merge(first_finetune_eval, on=['metric'], suffixes=('_zero_shot', '_first_finetune'))" - ] - }, { "cell_type": "markdown", "id": "4b97ad55-a82c-4dd2-878c-40e2e9bf8945", @@ -518,7 +463,7 @@ { "data": { "text/plain": [ - "'c47d99c1-6fdc-4b82-82d3-6ae9954c17cd'" + "'175de80d-0396-4d18-9cb0-26d43ccb95e5'" ] }, "execution_count": null, @@ -544,7 +489,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6edc0794-598d-46ef-890e-23bfa66f20b2", + "id": "4cfeed2e-0a39-4211-82d1-67d1f868b311", "metadata": {}, "outputs": [ { @@ -580,13 +525,15 @@ " \n", " \n", " metric\n", - " TimeGPT\n", + " TimeGPT_first_finetune\n", + " TimeGPT_second_finetune\n", " \n", " \n", " \n", " \n", " 0\n", " rmse\n", + " 1472.024619\n", " 1435.365211\n", " \n", " \n", @@ -594,8 +541,8 @@ "" ], "text/plain": [ - " metric TimeGPT\n", - "0 rmse 1435.365211" + " metric TimeGPT_first_finetune TimeGPT_second_finetune\n", + "0 rmse 1472.024619 1435.365211" ] }, "execution_count": null, @@ -606,74 +553,17 @@ "source": [ "second_finetune_fcst = nixtla_client.forecast(h=h, finetuned_model_id=second_model_id, **fcst_kwargs)\n", "second_finetune_eval = evaluate(second_finetune_fcst.merge(valid), metrics=[rmse], agg_fn='mean')\n", - "second_finetune_eval" + "first_finetune_eval.merge(second_finetune_eval, on=['metric'], suffixes=('_first_finetune', '_second_finetune'))" ] }, { "cell_type": "markdown", - "id": "04184bf1-de6c-42c4-93be-8959a78b8d24", + "id": "a2bc7c72-47be-4cc5-b774-f75980e8d70b", "metadata": {}, "source": [ "We can see the error was reduced a bit more." ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "78e09d44-957b-46fb-9751-716ebc3f63c8", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
metricTimeGPT_first_finetuneTimeGPT_second_finetune
0rmse1472.0246191435.365211
\n", - "
" - ], - "text/plain": [ - " metric TimeGPT_first_finetune TimeGPT_second_finetune\n", - "0 rmse 1472.024619 1435.365211" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "first_finetune_eval.merge(second_finetune_eval, on=['metric'], suffixes=('_first_finetune', '_second_finetune'))" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -694,8 +584,8 @@ { "data": { "text/plain": [ - "[FinetunedModel(id='c47d99c1-6fdc-4b82-82d3-6ae9954c17cd', created_at=datetime.datetime(2024, 12, 19, 23, 31, 44, 927053, tzinfo=TzInfo(UTC)), created_by='user', base_model_id='my-first-finetuned-model', steps=10, depth=3, loss='default', model='timegpt-1-long-horizon', freq='MS'),\n", - " FinetunedModel(id='my-first-finetuned-model', created_at=datetime.datetime(2024, 12, 19, 23, 31, 38, 889312, tzinfo=TzInfo(UTC)), created_by='user', base_model_id='None', steps=10, depth=1, loss='default', model='timegpt-1-long-horizon', freq='MS')]" + "[FinetunedModel(id='175de80d-0396-4d18-9cb0-26d43ccb95e5', created_at=datetime.datetime(2024, 12, 19, 23, 46, 27, 175345, tzinfo=TzInfo(UTC)), created_by='user', base_model_id='my-first-finetuned-model', steps=10, depth=3, loss='default', model='timegpt-1-long-horizon', freq='MS'),\n", + " FinetunedModel(id='my-first-finetuned-model', created_at=datetime.datetime(2024, 12, 19, 23, 45, 58, 151937, tzinfo=TzInfo(UTC)), created_by='user', base_model_id='None', steps=10, depth=1, loss='default', model='timegpt-1-long-horizon', freq='MS')]" ] }, "execution_count": null, @@ -757,8 +647,8 @@ " \n", " \n", " 0\n", - " c47d99c1-6fdc-4b82-82d3-6ae9954c17cd\n", - " 2024-12-19 23:31:44.927053+00:00\n", + " 175de80d-0396-4d18-9cb0-26d43ccb95e5\n", + " 2024-12-19 23:46:27.175345+00:00\n", " user\n", " my-first-finetuned-model\n", " 10\n", @@ -770,7 +660,7 @@ " \n", " 1\n", " my-first-finetuned-model\n", - " 2024-12-19 23:31:38.889312+00:00\n", + " 2024-12-19 23:45:58.151937+00:00\n", " user\n", " None\n", " 10\n", @@ -785,8 +675,8 @@ ], "text/plain": [ " id created_at \\\n", - "0 c47d99c1-6fdc-4b82-82d3-6ae9954c17cd 2024-12-19 23:31:44.927053+00:00 \n", - "1 my-first-finetuned-model 2024-12-19 23:31:38.889312+00:00 \n", + "0 175de80d-0396-4d18-9cb0-26d43ccb95e5 2024-12-19 23:46:27.175345+00:00 \n", + "1 my-first-finetuned-model 2024-12-19 23:45:58.151937+00:00 \n", "\n", " created_by base_model_id steps depth loss \\\n", "0 user my-first-finetuned-model 10 3 default \n", @@ -894,8 +784,8 @@ " \n", " \n", " 0\n", - " c47d99c1-6fdc-4b82-82d3-6ae9954c17cd\n", - " 2024-12-19 23:31:44.927053+00:00\n", + " 175de80d-0396-4d18-9cb0-26d43ccb95e5\n", + " 2024-12-19 23:46:27.175345+00:00\n", " user\n", " my-first-finetuned-model\n", " 10\n", @@ -910,7 +800,7 @@ ], "text/plain": [ " id created_at \\\n", - "0 c47d99c1-6fdc-4b82-82d3-6ae9954c17cd 2024-12-19 23:31:44.927053+00:00 \n", + "0 175de80d-0396-4d18-9cb0-26d43ccb95e5 2024-12-19 23:46:27.175345+00:00 \n", "\n", " created_by base_model_id steps depth loss \\\n", "0 user my-first-finetuned-model 10 3 default \n", From 35c4ff5bd11abd4214ad27bcf7b7a1327db2fe8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Mon, 30 Dec 2024 12:01:24 -0600 Subject: [PATCH 3/3] add as_df argument to finetuned_models --- nbs/docs/reference/01_nixtla_client.ipynb | 32 +++++++++++++++---- .../061_reusing_finetuned_models.ipynb | 30 ++++++++--------- nbs/src/nixtla_client.ipynb | 24 ++++++++++++-- nixtla/nixtla_client.py | 22 +++++++++++-- 4 files changed, 83 insertions(+), 25 deletions(-) diff --git a/nbs/docs/reference/01_nixtla_client.ipynb b/nbs/docs/reference/01_nixtla_client.ipynb index 4b909171..28daadff 100644 --- a/nbs/docs/reference/01_nixtla_client.ipynb +++ b/nbs/docs/reference/01_nixtla_client.ipynb @@ -101,7 +101,12 @@ "\n", "> NixtlaClient.validate_api_key (log:bool=True)\n", "\n", - "*Returns True if your api_key is valid.*" + "*Check API key status.*\n", + "\n", + "| | **Type** | **Default** | **Details** |\n", + "| -- | -------- | ----------- | ----------- |\n", + "| log | bool | True | Show the endpoint's response. |\n", + "| **Returns** | **bool** | | **Whether API key is valid.** |" ], "text/plain": [ "---\n", @@ -110,7 +115,12 @@ "\n", "> NixtlaClient.validate_api_key (log:bool=True)\n", "\n", - "*Returns True if your api_key is valid.*" + "*Check API key status.*\n", + "\n", + "| | **Type** | **Default** | **Details** |\n", + "| -- | -------- | ----------- | ----------- |\n", + "| log | bool | True | Show the endpoint's response. |\n", + "| **Returns** | **bool** | | **Whether API key is valid.** |" ] }, "execution_count": null, @@ -623,18 +633,28 @@ "\n", "## NixtlaClient.finetuned_models\n", "\n", - "> NixtlaClient.finetuned_models ()\n", + "> NixtlaClient.finetuned_models (as_df:bool=False)\n", "\n", - "*List fine-tuned models*" + "*List fine-tuned models*\n", + "\n", + "| | **Type** | **Default** | **Details** |\n", + "| -- | -------- | ----------- | ----------- |\n", + "| as_df | bool | False | Return the fine-tuned models as a pandas dataframe |\n", + "| **Returns** | **Union** | | **List of available fine-tuned models.** |" ], "text/plain": [ "---\n", "\n", "## NixtlaClient.finetuned_models\n", "\n", - "> NixtlaClient.finetuned_models ()\n", + "> NixtlaClient.finetuned_models (as_df:bool=False)\n", "\n", - "*List fine-tuned models*" + "*List fine-tuned models*\n", + "\n", + "| | **Type** | **Default** | **Details** |\n", + "| -- | -------- | ----------- | ----------- |\n", + "| as_df | bool | False | Return the fine-tuned models as a pandas dataframe |\n", + "| **Returns** | **Union** | | **List of available fine-tuned models.** |" ] }, "execution_count": null, diff --git a/nbs/docs/tutorials/061_reusing_finetuned_models.ipynb b/nbs/docs/tutorials/061_reusing_finetuned_models.ipynb index 0f24787d..1faf7ad0 100644 --- a/nbs/docs/tutorials/061_reusing_finetuned_models.ipynb +++ b/nbs/docs/tutorials/061_reusing_finetuned_models.ipynb @@ -463,7 +463,7 @@ { "data": { "text/plain": [ - "'175de80d-0396-4d18-9cb0-26d43ccb95e5'" + "'468b13fb-4b26-447a-bd87-87a64b50d913'" ] }, "execution_count": null, @@ -584,8 +584,8 @@ { "data": { "text/plain": [ - "[FinetunedModel(id='175de80d-0396-4d18-9cb0-26d43ccb95e5', created_at=datetime.datetime(2024, 12, 19, 23, 46, 27, 175345, tzinfo=TzInfo(UTC)), created_by='user', base_model_id='my-first-finetuned-model', steps=10, depth=3, loss='default', model='timegpt-1-long-horizon', freq='MS'),\n", - " FinetunedModel(id='my-first-finetuned-model', created_at=datetime.datetime(2024, 12, 19, 23, 45, 58, 151937, tzinfo=TzInfo(UTC)), created_by='user', base_model_id='None', steps=10, depth=1, loss='default', model='timegpt-1-long-horizon', freq='MS')]" + "[FinetunedModel(id='468b13fb-4b26-447a-bd87-87a64b50d913', created_at=datetime.datetime(2024, 12, 30, 17, 57, 31, 241455, tzinfo=TzInfo(UTC)), created_by='user', base_model_id='my-first-finetuned-model', steps=10, depth=3, loss='default', model='timegpt-1-long-horizon', freq='MS'),\n", + " FinetunedModel(id='my-first-finetuned-model', created_at=datetime.datetime(2024, 12, 30, 17, 57, 16, 978907, tzinfo=TzInfo(UTC)), created_by='user', base_model_id='None', steps=10, depth=1, loss='default', model='timegpt-1-long-horizon', freq='MS')]" ] }, "execution_count": null, @@ -603,7 +603,7 @@ "id": "95e591c8-80b0-43f8-afed-dfa760597af8", "metadata": {}, "source": [ - "While that representation may be useful for programmatic use, in this exploratory setting it's nicer to see them as a dataframe." + "While that representation may be useful for programmatic use, in this exploratory setting it's nicer to see them as a dataframe, which we can get by providing `as_df=True`." ] }, { @@ -647,8 +647,8 @@ " \n", " \n", " 0\n", - " 175de80d-0396-4d18-9cb0-26d43ccb95e5\n", - " 2024-12-19 23:46:27.175345+00:00\n", + " 468b13fb-4b26-447a-bd87-87a64b50d913\n", + " 2024-12-30 17:57:31.241455+00:00\n", " user\n", " my-first-finetuned-model\n", " 10\n", @@ -660,7 +660,7 @@ " \n", " 1\n", " my-first-finetuned-model\n", - " 2024-12-19 23:45:58.151937+00:00\n", + " 2024-12-30 17:57:16.978907+00:00\n", " user\n", " None\n", " 10\n", @@ -675,8 +675,8 @@ ], "text/plain": [ " id created_at \\\n", - "0 175de80d-0396-4d18-9cb0-26d43ccb95e5 2024-12-19 23:46:27.175345+00:00 \n", - "1 my-first-finetuned-model 2024-12-19 23:45:58.151937+00:00 \n", + "0 468b13fb-4b26-447a-bd87-87a64b50d913 2024-12-30 17:57:31.241455+00:00 \n", + "1 my-first-finetuned-model 2024-12-30 17:57:16.978907+00:00 \n", "\n", " created_by base_model_id steps depth loss \\\n", "0 user my-first-finetuned-model 10 3 default \n", @@ -693,7 +693,7 @@ } ], "source": [ - "pd.DataFrame([model.model_dump() for model in finetuned_models])" + "nixtla_client.finetuned_models(as_df=True)" ] }, { @@ -711,7 +711,7 @@ "source": [ "## 7. Deleting fine-tuned models\n", "\n", - "In order to keep things organized, and since there's a limit on the number of fine-tuned models, you can delete models that weren't so promising to make room for more experiments. For example, we can delete our first finetuned model. Note that even though it was used as the base for our second model, they're saved independently so removing it won't affect our second model, except for the dangling metadata." + "In order to keep things organized, and since there's a limit of 50 fine-tuned models, you can delete models that weren't so promising to make room for more experiments. For example, we can delete our first finetuned model. Note that even though it was used as the base for our second model, they're saved independently so removing it won't affect our second model, except for the dangling metadata." ] }, { @@ -784,8 +784,8 @@ " \n", " \n", " 0\n", - " 175de80d-0396-4d18-9cb0-26d43ccb95e5\n", - " 2024-12-19 23:46:27.175345+00:00\n", + " 468b13fb-4b26-447a-bd87-87a64b50d913\n", + " 2024-12-30 17:57:31.241455+00:00\n", " user\n", " my-first-finetuned-model\n", " 10\n", @@ -800,7 +800,7 @@ ], "text/plain": [ " id created_at \\\n", - "0 175de80d-0396-4d18-9cb0-26d43ccb95e5 2024-12-19 23:46:27.175345+00:00 \n", + "0 468b13fb-4b26-447a-bd87-87a64b50d913 2024-12-30 17:57:31.241455+00:00 \n", "\n", " created_by base_model_id steps depth loss \\\n", "0 user my-first-finetuned-model 10 3 default \n", @@ -815,7 +815,7 @@ } ], "source": [ - "pd.DataFrame([model.model_dump() for model in nixtla_client.finetuned_models()])" + "nixtla_client.finetuned_models(as_df=True)" ] } ], diff --git a/nbs/src/nixtla_client.ipynb b/nbs/src/nixtla_client.ipynb index 195d2dc4..ab8bbca7 100644 --- a/nbs/src/nixtla_client.ipynb +++ b/nbs/src/nixtla_client.ipynb @@ -50,6 +50,7 @@ " Optional,\n", " TypeVar,\n", " Union,\n", + " overload,\n", ")\n", "\n", "import annotated_types\n", @@ -1148,8 +1149,24 @@ " resp = self._make_request_with_retries(client, 'v2/finetune', payload)\n", " return resp['finetuned_model_id']\n", "\n", - " def finetuned_models(self) -> list[FinetunedModel]:\n", + " @overload\n", + " def finetuned_models(self, as_df: Literal[False]) -> list[FinetunedModel]:\n", + " ...\n", + "\n", + " @overload\n", + " def finetuned_models(self, as_df: Literal[True]) -> pd.DataFrame:\n", + " ...\n", + "\n", + " def finetuned_models(\n", + " self,\n", + " as_df: bool = False,\n", + " ) -> Union[list[FinetunedModel], pd.DataFrame]:\n", " \"\"\"List fine-tuned models\n", + "\n", + " Parameters\n", + " ----------\n", + " as_df : bool\n", + " Return the fine-tuned models as a pandas dataframe\n", " \n", " Returns\n", " -------\n", @@ -1160,7 +1177,10 @@ " body = resp.json()\n", " if resp.status_code != 200:\n", " raise ApiError(status_code=resp.status_code, body=body)\n", - " return [FinetunedModel(**m) for m in body['finetuned_models']]\n", + " models = [FinetunedModel(**m) for m in body['finetuned_models']]\n", + " if as_df:\n", + " models = pd.DataFrame([m.model_dump() for m in models])\n", + " return models\n", "\n", " def delete_finetuned_model(self, finetuned_model_id: str) -> bool:\n", " \"\"\"Delete a previously fine-tuned model\n", diff --git a/nixtla/nixtla_client.py b/nixtla/nixtla_client.py index ca12e74a..ce0a9674 100644 --- a/nixtla/nixtla_client.py +++ b/nixtla/nixtla_client.py @@ -20,6 +20,7 @@ Optional, TypeVar, Union, + overload, ) import annotated_types @@ -1076,9 +1077,23 @@ def finetune( resp = self._make_request_with_retries(client, "v2/finetune", payload) return resp["finetuned_model_id"] - def finetuned_models(self) -> list[FinetunedModel]: + @overload + def finetuned_models(self, as_df: Literal[False]) -> list[FinetunedModel]: ... + + @overload + def finetuned_models(self, as_df: Literal[True]) -> pd.DataFrame: ... + + def finetuned_models( + self, + as_df: bool = False, + ) -> Union[list[FinetunedModel], pd.DataFrame]: """List fine-tuned models + Parameters + ---------- + as_df : bool + Return the fine-tuned models as a pandas dataframe + Returns ------- list of FinetunedModel @@ -1088,7 +1103,10 @@ def finetuned_models(self) -> list[FinetunedModel]: body = resp.json() if resp.status_code != 200: raise ApiError(status_code=resp.status_code, body=body) - return [FinetunedModel(**m) for m in body["finetuned_models"]] + models = [FinetunedModel(**m) for m in body["finetuned_models"]] + if as_df: + models = pd.DataFrame([m.model_dump() for m in models]) + return models def delete_finetuned_model(self, finetuned_model_id: str) -> bool: """Delete a previously fine-tuned model