# ESM2 Tutorial Updates #426

Draft: wants to merge 19 commits into `main`.
82 changes: 53 additions & 29 deletions docs/docs/user-guide/examples/bionemo-esm2/finetune.md

# Fine-Tuning the Regressor Task Head for ESM2

Now we can put these five requirements together to fine-tune a regressor task head starting from a pre-trained 650M ESM-2 model (`pretrain_ckpt_path`). We can take advantage of a simple training loop in `bionemo.esm2.model.finetune.train` and use the `train_model()` function to start the fine-tuning process as follows.

```python
from bionemo.core.data.load import load

# create a List[Tuple] with (sequence, target) values
data = [(seq, len(seq)/100.0) for seq in artificial_sequence_data]
dataset = InMemorySingleValueDataset(data)
data_module = ESM2FineTuneDataModule(train_dataset=dataset, valid_dataset=dataset)

experiment_name = "finetune_regressor"
n_steps_train = 50
seed = 42

# To download a 650M pre-trained ESM2 model
pretrain_ckpt_path = load("esm2/650m:2.0", source="ngc")

config = ESM2FineTuneSeqConfig(
    initial_ckpt_path=str(pretrain_ckpt_path)
)

# experiment_results_dir: a writable directory for this run's outputs
checkpoint, metrics, trainer = train_model(
    experiment_name=experiment_name,
    experiment_dir=Path(experiment_results_dir),  # new checkpoint will land in a subdir of this
    config=config,  # same config as before since we are just continuing training
    data_module=data_module,
    n_steps_train=n_steps_train,
)
```
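`train_model()` returns the new checkpoint path, the metrics collected during training, and the trainer itself. As a quick check that the run completed, you can print the first two (a minimal sketch; the exact contents of `metrics` depend on the training loop's callbacks):

```python
# Sketch: confirm the fine-tuning run produced a checkpoint and metrics.
print(f"New checkpoint written under: {checkpoint}")
print(f"Collected metrics: {metrics}")
```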

This example is fully implemented in `bionemo.esm2.model.finetune.train` and can be executed by:
```bash
python -m bionemo.esm2.model.finetune.train
```

## Notes
1. The above example fine-tunes a 650M ESM-2 model. The pre-trained checkpoints can be downloaded from NGC resources using either the following bash command or the `load` function in `bionemo.core.data.load`, as shown above.
```bash
download_bionemo_data esm2/650m:2.0 --source ngc
```
3. We are using a small dataset of artificial sequences as our fine-tuning data in this example. You may experience over-fitting and observe no change in the validation metrics; holding out a validation split, as sketched below, makes this easier to spot.
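For example, a minimal sketch using the same `InMemorySingleValueDataset` and `ESM2FineTuneDataModule` API as above, validating on held-out sequences instead of the training data:

```python
# Hold out the last two artificial sequences for validation (sketch; in
# practice use a proper split of your fine-tuning data).
train_data, valid_data = data[:-2], data[-2:]
data_module = ESM2FineTuneDataModule(
    train_dataset=InMemorySingleValueDataset(train_data),
    valid_dataset=InMemorySingleValueDataset(valid_data),
)
```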

# Fine-Tuned ESM-2 Model Inference
Once we have a fine-tuned checkpoint, we can use `bionemo.esm2.model.finetune.infer` to run inference on an example prediction dataset.
Record the checkpoint path reported at the end of the fine-tuning run, after executing `python -m bionemo.esm2.model.finetune.train` (e.g. `/tmp/tmp1b5wlnba/finetune_regressor/checkpoints/finetune_regressor--reduced_train_loss=0.0016-epoch=0-last`), and pass it to the inference script as the `--checkpoint-path` argument.

We download an example CSV dataset of artificial sequences for this inference run. Please refer to the [ESM-2 Inference](./inference) tutorial for a detailed explanation of the arguments and how to create your own CSV file.

```bash
mkdir -p $WORKDIR/esm2_finetune_tutorial

# download sample data CSV for inference
DATA_PATH=$(download_bionemo_data esm2/testdata_esm2_infer:2.0 --source ngc)
RESULTS_PATH=$WORKDIR/esm2_finetune_tutorial/inference_results.pt

infer_esm2 --checkpoint-path <finetune checkpoint path> \
--data-path $DATA_PATH \
--results-path $RESULTS_PATH \
--config-class ESM2FineTuneSeqConfig
```

This will create a results `.pt` file under `$WORKDIR/esm2_finetune_tutorial/inference_results.pt`, which can be loaded via the PyTorch library in a Python environment:
```python
import torch

# Set the path to the results file, e.g.:
results_path = "/workspace/bionemo2/esm2_finetune_tutorial/inference_results.pt"
results = torch.load(results_path)

# results is a python dict which includes the following result tensors for this example:
# results['regression_output'] is a tensor with shape: torch.Size([10, 1])
```
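Since the regression head emits one value per input row, you can pair each prediction with its sequence. Below is a small sketch; `data_path` is a hypothetical local variable standing in for the CSV downloaded above, and it assumes the output rows preserve the input CSV's order:

```python
import pandas as pd

# Align predictions with the input sequences (sketch; assumes row order is
# preserved by the inference script).
data_path = "<path printed by download_bionemo_data>"  # hypothetical placeholder
df = pd.read_csv(data_path)
df["prediction"] = results["regression_output"].squeeze(-1).tolist()
print(df.head())
```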

## Notes
1. For demonstration purposes, executing the above command will run inference with a randomly initialized `ESM2FineTuneSeqModel` unless `initial_ckpt_path` is specified and set to an already trained model.
2. If a fine-tuned checkpoint is provided via `initial_ckpt_path`, the `initial_ckpt_skip_keys_with_these_prefixes` field should be reset to `field(default_factory=list)` so that no parameters are skipped.
    - The ESM-2 inference module takes the `--checkpoint-path` and `--config-class` arguments and creates a config object whose `initial_ckpt_path` points to that checkpoint. Since we need to load all the parameters from this checkpoint (and not skip the head), we reset `initial_ckpt_skip_keys_with_these_prefixes` in this config:

```python
config = ESM2FineTuneSeqConfig(
    initial_ckpt_path="<finetuned checkpoint path>",  # placeholder: your recorded checkpoint
    # Passing an empty list resets the field (equivalent to
    # field(default_factory=list)) so that no parameters are skipped:
    initial_ckpt_skip_keys_with_these_prefixes=[],
)
```
79 changes: 79 additions & 0 deletions docs/docs/user-guide/examples/bionemo-esm2/inference.md
# ESM-2 Inference

This tutorial serves as a demo for [ESM-2](https://www.science.org/doi/abs/10.1126/science.ade2574) inference using a CSV file with a `sequences` column. To pre-train the ESM-2 model, please refer to the [ESM-2 Pretraining](./pretrain.md) tutorial.

# Setup and Assumptions

In this tutorial, we will demonstrate how to download an ESM-2 checkpoint, create a CSV file with protein sequences, and run inference with an ESM-2 model.

All commands should be executed inside the BioNeMo Docker container, which has all ESM-2 dependencies pre-installed. This tutorial assumes that a copy of the BioNeMo framework repo exists on the workstation or server and has been mounted inside the container at `/workspace/bionemo2`. For more information on how to build or pull the BioNeMo2 container, refer to the [Initialization Guide](../../getting-started/initialization-guide.md).

!!! note

The `WORKDIR` referenced below may be `/workspaces/bionemo-framework` if you are using the VSCode Dev Container.

Similar to PyTorch Lightning, we have to define some key classes:

1. `MegatronStrategy` - To launch and setup parallelism for [NeMo](https://github.com/NVIDIA/NeMo/tree/main) and [Megatron-LM](https://github.com/NVIDIA/Megatron-LM).
2. `Trainer` - To configure training and logging.
3. `ESM2FineTuneDataModule` - To load sequence data for both fine-tuning and inference.
4. `ESM2Config` - To configure the ESM-2 model as `BionemoLightningModule`.

Please refer to the [ESM-2 Pretraining](./pretrain.md) and [ESM-2 Fine-Tuning](./finetune.md) tutorials for a detailed description of these classes.

# Create a CSV data file containing your protein sequences

We use the `InMemoryCSVDataset` class to load the protein sequence data from a `.csv` file. This data file should at least have a `sequences` column and can optionally have a `labels` column used for fine-tuning applications. Here is an example of how to create your own inference input data using a list of sequences in Python:

```python
import pandas as pd

artificial_sequence_data = [
"TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
"LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
"GRFNVWLGGNESKIRQVLKAVKEIGVSPTLFAVYEKN",
"DELTALGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
"KLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
"LFGAIGNAISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
"LGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
"LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
"ISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
"SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT",
]

csv_file = "/home/bionemo/sequences.csv"
# Create a DataFrame
df = pd.DataFrame(artificial_sequence_data, columns=["sequences"])
# Save the DataFrame to a CSV file
df.to_csv(csv_file, index=False)
```
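As a quick sanity check, you can read the file back and confirm it contains the expected `sequences` column:

```python
# Verify the CSV matches what InMemoryCSVDataset expects.
df_check = pd.read_csv(csv_file)
assert "sequences" in df_check.columns
print(df_check.head())
```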

For the purpose of this tutorial, we have already provided an example `.csv` file as a downloadable resource in the BioNeMo Framework:

```bash
download_bionemo_data esm2/testdata_esm2_inference:2.0 --source ngc
```

To run inference on this data using an ESM-2 checkpoint, you can use the `infer_esm2` executable, which calls `$WORKDIR/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/infer_esm2.py`:

```bash
DATA_PATH=$(download_bionemo_data esm2/testdata_esm2_inference:2.0 --source ngc)
CHECKPOINT_PATH=$(download_bionemo_data esm2/650m:2.0 --source ngc)

infer_esm2 \
--data-path ${DATA_PATH} \
--checkpoint-path ${CHECKPOINT_PATH} \
--results-path ${RESULTS_MOUNT}/esm2_inference_tutorial.pt \
--micro-batch-size 2 \
--include-hiddens \
--include-embeddings \
--include-logits
```

This script will create the `esm2_inference_tutorial.pt` file under the results mount of your container to store the results. The `.pt` file contains a dictionary of `{'result_key': torch.Tensor}` that can be loaded with PyTorch:

```python
import os
import torch

# RESULTS_MOUNT is the environment variable used in the bash commands above
data = torch.load(os.path.join(os.environ["RESULTS_MOUNT"], "esm2_inference_tutorial.pt"))
```
In this example, `data` is a Python dict with the following keys: `['token_logits', 'binary_logits', 'hidden_states', 'embeddings']`.
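To see what was returned for your run, here is a short sketch that prints each entry's shape (which keys are present depends on the `--include-*` flags passed to `infer_esm2`, and entries may be `None` for outputs that were not requested):

```python
# Inspect every entry in the results dict.
for key, value in data.items():
    print(key, tuple(value.shape) if hasattr(value, "shape") else value)
```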