Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

failed to compile SD for training. #94

Open
geonm opened this issue Dec 24, 2024 · 2 comments
Open

failed to compile SD for training. #94

geonm opened this issue Dec 24, 2024 · 2 comments

Comments

@geonm
Copy link

geonm commented Dec 24, 2024

Hi AWS neuron team.

I found the sample code to train SD which supports SD1.5 and 2.1.

https://github.com/aws-neuron/aws-neuron-samples/tree/master/torch-neuronx/training/stable_diffusion

but everytime I tried to compile the model, it always showed up this error...

Compiler status PASS
2024-12-24 10:43:46.000413:  802783  INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: neuronx-cc compile --framework=XLA /tmp/ubuntu/neuroncc_compile_workdir/c9e126d0-5133-47e2-8473-b29d4b875523/model.MODULE_663265732549708907+d505a3f8.hlo_module.pb --output /tmp/ubuntu/neuroncc_compile_workdir/c9e126d0-5133-47e2-8473-b29d4b875523/model.MODULE_663265732549708907+d505a3f8.neff --target=trn1 --model-type=cnn-training -O1 --enable-saturate-infinity --verbose=35
....................................
Compiler status PASS
2024-12-24 10:46:40.000063:  802786  INFO ||NEURON_PARALLEL_COMPILE||: worker 3 finished with num of tasks 1....
2024-12-24 10:46:40.000063:  802786  INFO ||NEURON_CACHE||: Current remaining items are 0, locked are 3, failed are 0, done are 4, total is 7
.
Compiler status PASS
2024-12-24 10:46:55.000944:  802783  INFO ||NEURON_PARALLEL_COMPILE||: worker 0 finished with num of tasks 4....
2024-12-24 10:46:55.000945:  802783  INFO ||NEURON_CACHE||: Current remaining items are 0, locked are 2, failed are 0, done are 5, total is 7
...............................................................                     .....................................................................
2024-12-24 11:08:40.000928:  802785  ERROR ||NEURON_CC_WRAPPER||: Failed compilation with ['neuronx-cc', 'compile', '--framework=XLA', '/tmp/ubuntu/neuroncc_compile_workdir/d38935b6-ac59-4635-a1c7-d736e1018863/model.MODULE_17201325080867297025+d505a3f8.hlo_module.pb', '--output', '/tmp/ubuntu/neuroncc_compile_workdir/d38935b6-ac59-4635-a1c7-d736e1018863/model.MODULE_17201325080867297025+d505a3f8.neff', '--target=trn1', '--model-type=cnn-training', '-O1', '--enable-saturate-infinity', '--verbose=35']: 2024-12-24T11:08:40Z [MFP002]  Compilation failed for modules(s) 0 4. Please review log-neuron-cc.txt for details. - Please open a support ticket at https://github.com/aws-neuron/aws-neuron-sdk/issues/new. You may also be able to obtain more information using the 'XLA_IR_DEBUG' and 'XLA_HLO_DEBUG' environment variables.

2024-12-24 11:08:40.000928:  802785  ERROR ||NEURON_CC_WRAPPER||: Compilation failed for /tmp/ubuntu/neuroncc_compile_workdir/d38935b6-ac59-4635-a1c7-d736e1018863/model.MODULE_17201325080867297025+d505a3f8.hlo_module.pb after 0 retries.
.....
2024-12-24 11:10:32.000604:  802784  ERROR ||NEURON_CC_WRAPPER||: Failed compilation with ['neuronx-cc', 'compile', '--framework=XLA', '/tmp/ubuntu/neuroncc_compile_workdir/cf1c760d-c8e8-4d30-926f-47a09e0e84a0/model.MODULE_13795660956479274162+d505a3f8.hlo_module.pb', '--output', '/tmp/ubuntu/neuroncc_compile_workdir/cf1c760d-c8e8-4d30-926f-47a09e0e84a0/model.MODULE_13795660956479274162+d505a3f8.neff', '--target=trn1', '--model-type=cnn-training', '-O1', '--enable-saturate-infinity', '--verbose=35']: 2024-12-24T11:10:32Z [MFP002]  Compilation failed for modules(s) 0 4. Please review log-neuron-cc.txt for details. - Please open a support ticket at https://github.com/aws-neuron/aws-neuron-sdk/issues/new. You may also be able to obtain more information using the 'XLA_IR_DEBUG' and 'XLA_HLO_DEBUG' environment variables.

2024-12-24 11:10:32.000604:  802784  ERROR ||NEURON_CC_WRAPPER||: Compilation failed for /tmp/ubuntu/neuroncc_compile_workdir/cf1c760d-c8e8-4d30-926f-47a09e0e84a0/model.MODULE_13795660956479274162+d505a3f8.hlo_module.pb after 0 retries.
concurrent.futures.process._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/torch_neuronx/parallel_compile/neuron_parallel_compile.py", line 174, in compile_task
    compile_task_helper(compiled_hlo_status, compile_cache, hlos, workdir, dump=dump,
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/torch_neuronx/parallel_compile/neuron_parallel_compile.py", line 105, in compile_task_helper
    status, retry = libneuronxla.neuron_cc_wrapper.compile_cache_entry(
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/libneuronxla/neuron_cc_wrapper.py", line 184, in compile_cache_entry
    raise (e)
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/libneuronxla/neuron_cc_wrapper.py", line 163, in compile_cache_entry
    ret = call_neuron_compiler(
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/libneuronxla/neuron_cc_wrapper.py", line 107, in call_neuron_compiler
    raise subprocess.CalledProcessError(res.returncode, cmd, stderr=error_info)
subprocess.CalledProcessError: Command '['neuronx-cc', 'compile', '--framework=XLA', '/tmp/ubuntu/neuroncc_compile_workdir/cf1c760d-c8e8-4d30-926f-47a09e0e84a0/model.MODULE_13795660956479274162+d505a3f8.hlo_module.pb', '--output', '/tmp/ubuntu/neuroncc_compile_workdir/cf1c760d-c8e8-4d30-926f-47a09e0e84a0/model.MODULE_13795660956479274162+d505a3f8.neff', '--target=trn1', '--model-type=cnn-training', '-O1', '--enable-saturate-infinity', '--verbose=35']' returned non-zero exit status 70.
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/torch_neuronx/parallel_compile/neuron_parallel_compile.py", line 243, in parallel_compile
    compiled_hlo_status = feature.result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
subprocess.CalledProcessError: Command '['neuronx-cc', 'compile', '--framework=XLA', '/tmp/ubuntu/neuroncc_compile_workdir/cf1c760d-c8e8-4d30-926f-47a09e0e84a0/model.MODULE_13795660956479274162+d505a3f8.hlo_module.pb', '--output', '/tmp/ubuntu/neuroncc_compile_workdir/cf1c760d-c8e8-4d30-926f-47a09e0e84a0/model.MODULE_13795660956479274162+d505a3f8.neff', '--target=trn1', '--model-type=cnn-training', '-O1', '--enable-saturate-infinity', '--verbose=35']' returned non-zero exit status 70.
2024-12-24 11:10:32.000606:  788421  INFO ||NEURON_PARALLEL_COMPILE||: sub-process 1 got exception: None
concurrent.futures.process._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/torch_neuronx/parallel_compile/neuron_parallel_compile.py", line 174, in compile_task
    compile_task_helper(compiled_hlo_status, compile_cache, hlos, workdir, dump=dump,
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/torch_neuronx/parallel_compile/neuron_parallel_compile.py", line 105, in compile_task_helper
    status, retry = libneuronxla.neuron_cc_wrapper.compile_cache_entry(
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/libneuronxla/neuron_cc_wrapper.py", line 184, in compile_cache_entry
    raise (e)
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/libneuronxla/neuron_cc_wrapper.py", line 163, in compile_cache_entry
    ret = call_neuron_compiler(
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/libneuronxla/neuron_cc_wrapper.py", line 107, in call_neuron_compiler
    raise subprocess.CalledProcessError(res.returncode, cmd, stderr=error_info)
subprocess.CalledProcessError: Command '['neuronx-cc', 'compile', '--framework=XLA', '/tmp/ubuntu/neuroncc_compile_workdir/d38935b6-ac59-4635-a1c7-d736e1018863/model.MODULE_17201325080867297025+d505a3f8.hlo_module.pb', '--output', '/tmp/ubuntu/neuroncc_compile_workdir/d38935b6-ac59-4635-a1c7-d736e1018863/model.MODULE_17201325080867297025+d505a3f8.neff', '--target=trn1', '--model-type=cnn-training', '-O1', '--enable-saturate-infinity', '--verbose=35']' returned non-zero exit status 70.
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/torch_neuronx/parallel_compile/neuron_parallel_compile.py", line 243, in parallel_compile
    compiled_hlo_status = feature.result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
subprocess.CalledProcessError: Command '['neuronx-cc', 'compile', '--framework=XLA', '/tmp/ubuntu/neuroncc_compile_workdir/d38935b6-ac59-4635-a1c7-d736e1018863/model.MODULE_17201325080867297025+d505a3f8.hlo_module.pb', '--output', '/tmp/ubuntu/neuroncc_compile_workdir/d38935b6-ac59-4635-a1c7-d736e1018863/model.MODULE_17201325080867297025+d505a3f8.neff', '--target=trn1', '--model-type=cnn-training', '-O1', '--enable-saturate-infinity', '--verbose=35']' returned non-zero exit status 70.
2024-12-24 11:10:32.000606:  788421  INFO ||NEURON_PARALLEL_COMPILE||: sub-process 2 got exception: None
2024-12-24 11:10:32.000619:  788421  INFO ||NEURON_PARALLEL_COMPILE||: {
    "compilation_summary": {
        "true": 5
    },
    "compilation_report": {
        "/home/ubuntu/sd-controlnet/compiler_cache/neuronxcc-2.16.345.0+69131dd3/MODULE_11100045050900979764+d505a3f8/model.hlo_module.pb": {
            "status": true,
            "retry": 0,
            "compile_time": 2.790724754333496
        },
        "/home/ubuntu/sd-controlnet/compiler_cache/neuronxcc-2.16.345.0+69131dd3/MODULE_18000634353127283720+d505a3f8/model.hlo_module.pb": {
            "status": true,
            "retry": 0,
            "compile_time": 2.7689967155456543
        },
        "/home/ubuntu/sd-controlnet/compiler_cache/neuronxcc-2.16.345.0+69131dd3/MODULE_3698851708420262760+d505a3f8/model.hlo_module.pb": {
            "status": true,
            "retry": 0,
            "compile_time": 4.183488130569458
        },
        "/home/ubuntu/sd-controlnet/compiler_cache/neuronxcc-2.16.345.0+69131dd3/MODULE_663265732549708907+d505a3f8/model.hlo_module.pb": {
            "status": true,
            "retry": 0,
            "compile_time": 189.53241777420044
        },
        "/home/ubuntu/sd-controlnet/compiler_cache/neuronxcc-2.16.345.0+69131dd3/MODULE_17441551917355545706+d505a3f8/model.hlo_module.pb": {
            "status": true,
            "retry": 0,
            "compile_time": 183.39421582221985
        }
    },
    "start_time": 1735037016.6359746,
    "compilation_time": 1615.983805179596
}
2024-12-24 11:10:32.000619:  788421  INFO ||NEURON_PARALLEL_COMPILE||: Total graphs: 5
2024-12-24 11:10:32.000619:  788421  INFO ||NEURON_PARALLEL_COMPILE||: Total successful compilations: 5
2024-12-24 11:10:32.000619:  788421  INFO ||NEURON_PARALLEL_COMPILE||: Total failed compilations: 0

Here is my env (trn1.32xlarge)

neuronx-cc==2.16.345.0+69131dd3
neuronx-distributed==0.9.0
neuronx-distributed-training==1.0.0
torch==2.1.2
torch-neuronx==2.1.2.2.4.0
torch-xla==2.1.6
torchvision==0.16.2

and this is the code to test. (only focus on Unet part)

click to open the code
#!/usr/bin/env python
# coding=utf-8

################################################################################
###                                                                          ###
###                                 IMPORTS                                  ###
###                                                                          ###
################################################################################

# System
import gc
import os
import shutil
import sys
import pathlib
import random
from glob import glob
from typing import Union
    
# Neuron
import torch_xla.core.xla_model as xm
import torch_neuronx
    
# General ML stuff
import torch
import torch.nn.functional as functional
from torchvision import transforms
import numpy as np

# For measuring throughput
import queue
import time

# Model
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer, Adafactor
# Needed for LoRA
from diffusers.loaders import AttnProcsLayers
# LR scheduler
from diffusers.optimization import get_scheduler

# Dataset
from datasets import load_dataset
# For logging and benchmarking
from datetime import datetime
import time

from diffusers import StableDiffusionPipeline

# Command line args
import argparse

# Multicore
import torch.multiprocessing as mp
import torch.distributed as dist
import torch_xla.distributed.xla_backend
import torch_xla.distributed.parallel_loader as xpl
import torch_xla.distributed.xla_multiprocessing as xmp
from torch.utils.data.distributed import DistributedSampler

import torch_xla.debug.profiler as xp

from torch_xla.amp.syncfree.adamw import AdamW
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

from transformers import AutoModelForImageClassification

################################################################################
###                                                                          ###
###                           CONSTANTS, ENV SETUP                           ###
###                                                                          ###
################################################################################

##### Neuron compiler flags #####
# --model-type=cnn-training: To enable various CNN training-specific optimizations, including mixed tiling algorithm and spill-free attention BIR kernel matching
# --enable-saturate-infinity: Needed for correctness. We get garbage data otherwise (probably from the CLIP text encoder)
# -O1: Gets us better compile time, especially when not splitting the model at the FAL level
#compiler_flags = """ --retry_failed_compilation --cache_dir="./compiler_cache" --verbose=INFO -O1 --model-type=cnn-training  --enable-saturate-infinity """
compiler_flags = """ --retry_failed_compilation --cache_dir="./compiler_cache" -O1 --enable-saturate-infinity """

os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + compiler_flags

# Path to where this file is located
curr_dir = str(pathlib.Path(__file__).parent.resolve())
sys.path.append(curr_dir)
    
image_column_name = "image"
caption_column_name = "text"

LOSS_FILE_FSTRING = "LOSSES-RANK-{RANK}.txt"


################################################################################
###                                                                          ###
###                           HELPER FUNCTIONS                               ###
###                                                                          ###
################################################################################
# For measuring throughput
class Throughput:
    def __init__(self, batch_size=8, data_parallel_degree=2, grad_accum_usteps=1, moving_avg_window_size=10):
        self.inputs_per_training_step = batch_size * data_parallel_degree * grad_accum_usteps
        self.moving_avg_window_size = moving_avg_window_size
        self.moving_avg_window = queue.Queue()
        self.window_time = 0
        self.start_time = time.time()

    # Record a training step - to be called anytime we call optimizer.step()
    def step(self):
        step_time = time.time() - self.start_time
        self.start_time += step_time
        self.window_time += step_time
        self.moving_avg_window.put(step_time)
        window_size = self.moving_avg_window.qsize()
        if window_size > self.moving_avg_window_size:
            self.window_time -= self.moving_avg_window.get()
            window_size -= 1
        return
    
    # Returns the throughput measured over the last moving_avg_window_size steps
    def get_throughput(self):
        throughput = self.moving_avg_window.qsize() * self.inputs_per_training_step / self.window_time
        return throughput


# Patch ZeRO Bug - need to explicitly initialize the clip_value as the dtype we want
@torch.no_grad()
def _clip_grad_norm(
    self,
    max_norm: Union[float, int],
    norm_type: Union[float, int] = 2.0,
) -> torch.Tensor:
    """
    Clip all gradients at this point in time. The norm is computed over all
    gradients together, as if they were concatenated into a single vector.
    Gradients are modified in-place.
    """
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    total_norm = self._calc_grad_norm(norm_type)

    clip_coeff = torch.tensor(
        max_norm, device=self.device) / (
            total_norm + 1e-6)
    clip_value = torch.where(clip_coeff < 1, clip_coeff,
                                torch.tensor(1., dtype=clip_coeff.dtype, device=self.device))
    for param_group in self.base_optimizer.param_groups:
        for p in param_group['params']:
            if p.grad is not None:
                p.grad.detach().mul_(clip_value)

ZeroRedundancyOptimizer._clip_grad_norm = _clip_grad_norm


# Saves a pipeline to the specified dir using HuggingFace's built-in methods, suitable for loading
# as a pretrained model in an inference script
def save_pipeline(results_dir, model_id, unet, vae, text_encoder):
    xm.master_print(f"Saving trained model to dir {results_dir}")

    if xm.is_master_ordinal():
        assert not os.path.exists(results_dir), f"Error! Can't save checkpoint to {results_dir} because it already exists."
        os.makedirs(results_dir)

    if xm.is_master_ordinal():
        pipeline = StableDiffusionPipeline.from_pretrained(
            model_id,
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
        )
        pipeline.save_pretrained(results_dir)

    xm.master_print(f"Done saving trained model to dir {results_dir}")
    return


# Saves a checkpoint of the unet and optimizer to the directory specified
# If ZeRO-1 optimizer sharding is enabled, each ordinal needs to save its own checkpoint of the optimizer
def save_checkpoint(results_dir, unet, optimizer, epoch, step, cumulative_step):
    # Save UNet state - only the master needs to save as UNet state is identical between workers
    if xm.is_master_ordinal():
        checkpoint_path = os.path.join(results_dir, f"checkpoint-unet-epoch_{epoch}-step_{step}-cumulative_train_step_{cumulative_step}.pt")
        xm.master_print(f"Saving UNet state checkpoint to {checkpoint_path}")
        data = {
            'epoch': epoch,
            'step': step,
            'cumulative_train_step': cumulative_step,
            'unet_state_dict': unet.state_dict(),
        }
        # Copied from https://github.com/aws-neuron/aws-neuron-samples/blob/master/torch-neuronx/training/dp_bert_hf_pretrain/dp_bert_large_hf_pretrain_hdf5.py
        # Not sure if it's strictly needed
        cpu_data = xm._maybe_convert_to_cpu(data)
        torch.save(cpu_data, checkpoint_path)
        del(cpu_data)
        xm.master_print(f"Done saving UNet state checkpoint to {checkpoint_path}")

    # Save optimizer state
    # Under ZeRO optimizer sharding each worker needs to save the optimizer state
    # as each has its own unique state
    checkpoint_path = os.path.join(results_dir, f"checkpoint-optimizer-epoch_{epoch}-step_{step}-cumulative_train_step_{cumulative_step}-rank_{xm.get_ordinal()}.pt")
    xm.master_print(f"Saving optimizer state checkpoint to {checkpoint_path} (other ranks will ahve each saved their own state checkpoint)")
    data = {
        'epoch': epoch,
        'step': step,
        'cumulative_train_step': cumulative_step,
        'optimizer_state_dict': optimizer.state_dict()
    }
    cpu_data = data
    # Intentionally don't move the data to CPU here - it causes XLA to crash
    # later when loading the optimizer checkpoint once the optimizer gets run
    torch.save(cpu_data, checkpoint_path)
    del(cpu_data)
    xm.master_print(f"Done saving optimizer state checkpoint to {checkpoint_path}")

    # Make the GC collect the CPU data we deleted so the memory actually gets freed
    gc.collect()
    xm.master_print("Done saving checkpoints!")


# Loads a checkpoint of the unet and optimizer from the directory specified
# If ZeRO-1 optimizer sharding is enabled, each ordinal needs to load its own checkpoint of the optimizer
# Returns a tuple of (epoch, step, cumulative_train_step)
def load_checkpoint(results_dir, unet, optimizer, device, resume_step):
    # Put an asterisk in for globbing if the user didn't specify a resume_step
    if resume_step is None:
        resume_step = "*"
    unet_checkpoint_filenames = glob(os.path.join(results_dir, f"checkpoint-unet-epoch_*-step_*-cumulative_train_step_{resume_step}.pt"))
    optimizer_checkpoint_filenames = glob(os.path.join(results_dir, f"checkpoint-optimizer-epoch_*-step_*-cumulative_train_step_{resume_step}-rank_{xm.get_ordinal()}.pt"))

    unet_checkpoint_filenames.sort()
    optimizer_checkpoint_filenames.sort()

    # Load UNet checkpoint
    checkpoint_path = unet_checkpoint_filenames[-1]
    xm.master_print(f"Loading UNet checkpoint from path {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    unet.load_state_dict(checkpoint['unet_state_dict'], strict=True)
    ret = (checkpoint['epoch'], checkpoint['step'], checkpoint['cumulative_train_step'])
    del(checkpoint)

    # Load optimizer checkpoint
    checkpoint_path = optimizer_checkpoint_filenames[-1]
    xm.master_print(f"Loading optimizer checkpoint from path {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if isinstance(optimizer, torch.nn.Module):
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'], strict=True)
    else:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    assert checkpoint['epoch'] == ret[0] and checkpoint['step'] == ret[1] and checkpoint['cumulative_train_step'] == ret[2], \
        "UNet checkpoint and optimizer checkpoint do not agree on the epoch, step, or cumulative_train_step!"
    del(checkpoint)

    gc.collect()

    xm.master_print("Done loading checkpoints!")
    return ret


# Seed various RNG sources that need to be seeded to make training deterministic
# WARNING: calls xm.rendezvous() internally
def seed_rng(device):
    LOCAL_RANK = xm.get_ordinal()
    xm.rendezvous('start-seeding-cpu')
    torch.manual_seed(9999 + LOCAL_RANK)
    random.seed(9999+ LOCAL_RANK)
    np.random.seed(9999 + LOCAL_RANK)

    xm.rendezvous('start-seeding-device')
    xm.set_rng_state(9999 + LOCAL_RANK, device=device)
    # TODO: not sure if we need to print the RNG state on CPU to force seeding to actually happen
    xm.master_print(f"xla rand state after setting RNG state {xm.get_rng_state(device=device)}\n")
    xm.rendezvous('seeding-device-done')

    xm.master_print("Done seeding CPU and device RNG!")




################################################################################
###                                                                          ###
###                           MAIN TRAINING FUNCTION                         ###
###                                                                          ###
################################################################################

def train(args):
    LOCAL_RANK = xm.get_ordinal()

    # Create all the components of our model pipeline and training loop
    xm.master_print('Building training loop components')

    device = xm.xla_device()

    t = torch.tensor([0.1]).to(device=device)
    xm.mark_step()
    xm.master_print(f"Initialized device, t={t.to(device='cpu')}")

    # Warning: calls xm.rendezvous() internally
    seed_rng(device)

    if not xm.is_master_ordinal(): xm.rendezvous('prepare')

    model_id = args.model_id
    noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    #text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
    #vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")

    #unet = AutoModelForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k')
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
    unet.requires_grad_(True)
    optim_params = unet.parameters()

    unet.train()
    unet.to(device)

    # optimizer
    optimizer = torch.optim.AdamW(optim_params, lr=1e-5, betas=(0.9, 0.999), weight_decay=1e-2, eps=1e-08)

    # Download the dataset
    xm.master_print('Downloading dataset')
    # TODO: make this a parameter of the script
    dataset_name = "m1guelpf/nouns"
    dataset = load_dataset(dataset_name)
    args.dataset_name = dataset_name

    # Done anything that might trigger a download
    xm.master_print("Executing `if xm.is_master_ordinal(): xm.rendezvous('prepare')`")
    if xm.is_master_ordinal(): xm.rendezvous('prepare')

    def training_metrics_closure(epoch, global_step, loss):
        loss_val = loss.detach().to('cpu').item()
        loss_f.write(f"{LOCAL_RANK} {epoch} {global_step} {loss_val}\n")
        loss_f.flush()

    xm.rendezvous('prepare-to-load-checkpoint')

    loss_filename = f"LOSSES-RANK-{LOCAL_RANK}.txt"

    if args.resume_from_checkpoint:
        start_epoch, start_step, cumulative_train_step = load_checkpoint(args.results_dir, unet, optimizer, device, args.resume_checkpoint_step)
        loss_f = open(loss_filename, 'a')
    else:
        start_epoch = 0
        start_step = 0
        cumulative_train_step = 0

        loss_f = open(loss_filename, 'w')
        loss_f.write("RANK EPOCH STEP LOSS\n")
    
    xm.rendezvous('done-loading-checkpoint')

    lr_scheduler = get_scheduler(
        "constant",
        optimizer=optimizer
    )

    parameters = filter(lambda p: p.requires_grad, unet.parameters())
    parameters = sum([np.prod(p.size()) * p.element_size() for p in parameters]) / (1024 ** 2)
    xm.master_print('Trainable Parameters: %.3fMB' % parameters)
    total_parameters = 0
    #components = [text_encoder, vae, unet]
    #for component in components:
    #    total_parameters += sum([np.prod(p.size()) * p.element_size() for p in component.parameters()]) / (1024 ** 2)
    xm.master_print('Total parameters: %.3fMB' % total_parameters)
    
    # Preprocess the dataset
    column_names = dataset["train"].column_names
    if image_column_name not in column_names:
        raise ValueError(
            f"Did not find '{image_column_name}' in dataset's 'column_names'"
        )
    if caption_column_name not in column_names:
        raise ValueError(
            f"Did not find '{caption_column_name}' in dataset's 'column_names'"
        )
    
    resolution = args.resolution
    training_transforms = transforms.Compose(
        [
            transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.RandomCrop(resolution),
            transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    
    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column_name]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column_name}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        #print(captions[0])
        return inputs.input_ids
    
    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column_name]]
        examples["pixel_values"] = [training_transforms(image) for image in images]
        examples["input_ids"] = tokenize_captions(examples)
        return examples
    
    train_dataset = dataset["train"].with_transform(preprocess_train)
    args.dataset_size = len(train_dataset)
    
    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        # Set to double so that bf16 autocast keeps it as fp32
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).double()
        input_ids = torch.stack([example["input_ids"] for example in examples])
        return {"pixel_values": pixel_values, "input_ids": input_ids}

    # Create dataloaders
    world_size = xm.xrt_world_size()

    train_sampler = None
    if world_size > 1:
        train_sampler = DistributedSampler(train_dataset,
                                           num_replicas=world_size,
                                           rank=xm.get_ordinal(),
                                           shuffle=True)

    # drop_last=True needed to avoid cases of an incomplete final batch, which would result in new graphs being cut and compiled
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, shuffle=False if train_sampler else True, collate_fn=collate_fn, batch_size=args.batch_size, sampler=train_sampler, drop_last=True
    )

    train_device_loader = xpl.MpDeviceLoader(train_dataloader, device, device_prefetch_size=2)
    
    xm.master_print('Entering training loop')
    xm.rendezvous('training-loop-start')

    found_inf = torch.tensor(0, dtype=torch.double, device=device)
    checkpoints_saved = 0

    # Use a moving average window size of 100 so we have a large sample at
    # the end of training
    throughput_helper = Throughput(args.batch_size, world_size, args.gradient_accumulation_steps, moving_avg_window_size=100)

    for epoch in range(start_epoch, args.epochs):
        start_epoch_time = time.perf_counter_ns()
        before_batch_load_time = time.perf_counter_ns()
        xm.master_print("####################################")
        xm.master_print(f"###### Starting epoch {epoch} ######")
        xm.master_print("####################################")
        # Add 1 to the start_step so that we don't repeat the step we saved the checkpoint after
        
        for step, batch in enumerate(train_device_loader, start=(start_step + 1 if epoch == start_epoch else 0)):
            after_batch_load_time = time.perf_counter_ns()
            start_time = time.perf_counter_ns()

            #with torch.no_grad():
            #xm.master_print('-' * 50)
            # torch.Size([2, 4, 64, 64]) torch.Size([2, 77, 768])
            #noise = torch.randn(latents.size(), dtype=latents.dtype, layout=latents.layout, device='cpu')
            #noise = torch.randn([2, 4, 64, 64])#latents.size(), dtype=latents.dtype, layout=latents.layout, device='cpu')
            #latents = torch.randn([2, 4, 64, 64])
            #encoder_hidden_states = torch.randn([2, 77, 768])
            #xm.master_print('1', torch.sum(noise).item())
            #noise = noise.to(device=device)
            #latents = latents.to(device=device)
            #encoder_hidden_states = encoder_hidden_states.to(device=device)
            #xm.master_print('2', torch.sum(noise).item())
            #
    
            #timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
            #timesteps = timesteps.to(device=device)

            #noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            #xm.master_print('3', torch.sum(noise).item())
            # Run text encoder on caption
            #encoder_hidden_states = text_encoder(batch["input_ids"])[0]
            #print(encoder_hidden_states.shape, torch.sum(encoder_hidden_states, dim=[1,2]))
            #target = noise
            #xm.master_print('4', torch.sum(target).item())
            #continue
            # UNet forward pass
            #print(noisy_latents.shape, encoder_hidden_states.shape)
            #exit()
            
            #xm.master_print('5', torch.sum(model_pred).item(), torch.sum(target).item(), torch.sum(encoder_hidden_states).item())
            #print(model_pred.shape, torch.sum(model_pred), target)
            # Calculate loss
            #model_pred = unet(torch.randn([32,3,224,224]).to(device)).logits
            #target = torch.randn([32, 2]).to(device)

            latents = torch.randn([2, 4, 64, 64]).to(device)
            encoder_hidden_states = torch.randn([2, 77, 768]).to(device)
            bsz = latents.shape[0]
            target = noise = torch.randn([2, 4, 64, 64]).to(device)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,)).to(device)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample



            loss = functional.mse_loss(model_pred.float(), target.float(), reduction="mean")

            #xm.master_print(loss.item(), torch.sum(model_pred).item(), torch.sum(target).item())
            #loss = functional.mse_loss(model_pred, target, reduction="mean")
            
            #print(torch.sum(model_pred), torch.sum(target))
            # Add in extra mark_steps to split the model into FWD / BWD / optimizer - helps with compiler QoR and thus
            # model fit
            # TODO: parametrize how the script splits the model
            #xm.mark_step()

            # Backwards pass
            loss.backward()

            #xm.mark_step()
            
            #continue
            #with torch.no_grad():
                # Optimizer
            if True:#(cumulative_train_step + 1) % args.gradient_accumulation_steps == 0:
                #optimizer.step(found_inf=found_inf)
                xm.optimizer_step(optimizer)
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                #xm.master_print("Finished weight update")
                throughput_helper.step()
                xm.master_print(loss.item(), torch.sum(model_pred).item(), torch.sum(target).item())
            else:
                xm.master_print("Accumulating gradients")
            #exit()
            #continue
            #xm.add_step_closure(training_metrics_closure, (epoch, step, loss.detach()), run_async=True)

            #xm.mark_step()

            #xm.master_print(f"*** Finished epoch {epoch} step {step} (cumulative step {cumulative_train_step})")
            #e2e_time = time.perf_counter_ns()
            #xm.master_print(f" > E2E for epoch {epoch} step {step} took {e2e_time - before_batch_load_time} ns")

            cumulative_train_step += 1

            # Checkpoint if needed

            #before_batch_load_time = time.perf_counter_ns()

            # Only need a handful of training steps for graph extraction. Cut it off so we don't take forever when
            # using a large dataset.
            #if os.environ.get("NEURON_EXTRACT_GRAPHS_ONLY", None) and cumulative_train_step > 5:
            #    break
            if step > 10:
                break
                #return
        '''
        if args.save_model_epochs is not None and epoch % args.save_model_epochs == 0 and not os.environ.get("NEURON_EXTRACT_GRAPHS_ONLY", None):
            save_pipeline(args.results_dir + f"-EPOCH_{epoch}", args.model_id, unet, vae, text_encoder)
        '''
        #end_epoch_time = time.perf_counter_ns()
        #xm.master_print(f" Entire epoch {epoch} took {(end_epoch_time - start_epoch_time) / (10 ** 9)} s")
        #xm.master_print(f" Given {step + 1} many steps, e2e per iteration is {(end_epoch_time - start_epoch_time) / (step + 1) / (10 ** 6)} ms")
        #xm.master_print(f"!!! Finished epoch {epoch}")

        # Only need a handful of training steps for graph extraction. Cut it off so we don't take forever when
        # using a large dataset.
        #if os.environ.get("NEURON_EXTRACT_GRAPHS_ONLY", None) and cumulative_train_step > 5:
        #    break
        break

    # Save the trained model for use in inference
    xm.rendezvous('finish-training')
    '''
    if xm.is_master_ordinal() and not os.environ.get("NEURON_EXTRACT_GRAPHS_ONLY", None):
        save_pipeline(os.path.join(args.results_dir, "stable_diffusion_trained_model_neuron"), args.model_id, unet, vae, text_encoder)
    '''
    loss_f.close()

    xm.master_print(f"!!! Finished all epochs")
     
    # However, I may need to block here to await the async? How to do that???
    xm.wait_device_ops()

    #xm.master_print(f"Average throughput over final 100 training steps was {throughput_helper.get_throughput()} images/s")

    xm.rendezvous('done')
    xm.master_print(f"!!! All done!")

    return




################################################################################
###                                                                          ###
###                             ARG PARSING, MAIN                            ###
###                                                                          ###
################################################################################

def parse_args():
    parser = argparse.ArgumentParser(
                    prog='Neuron SD training script',
                    description='Stable Diffusion training script for Neuron Trn1')
    parser.add_argument('--model', choices=['2.1', '1.5'], help='Which model to train')
    parser.add_argument('--resolution', choices=[64, 128, 512, 768], type=int, help='Which resolution of model to train')
    parser.add_argument('--batch_size', type=int, help='What per-device microbatch size to use')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help='How many gradient accumulation steps to do (1 for no gradient accumulation)')
    parser.add_argument('--epochs', type=int, default=2000, help='How many epochs to train for')

    # Arguments for checkpointing
    parser.add_argument("--checkpointing_steps", type=int, default=None,
        help=(
            "Save a checkpoint of the training state every X training steps. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."),
    )
    parser.add_argument("--max_num_checkpoints", type=int, default=None,
        help=("Max number of checkpoints to store."),
    )

    parser.add_argument("--save_model_epochs", type=int, default=None,
        help=(
            "Save a copy of the trained model every X epochs in a format that can be loaded using HuggingFace's from_pretrained method."
        ))

    # TODO: add ability to specify dir with checkpoints to restore from that is different than the default
    parser.add_argument('--resume_from_checkpoint', action="store_true", help="Resume from checkpoint at resume_step.")
    parser.add_argument('--resume_checkpoint_step', type=int, default=None, help="Which cumulative training step to resume from, looking for checkpoints in the script's work directory. Leave unset to use the latest checkpoint.")

    args = parser.parse_args()

    return args

if __name__ == "__main__":
    env_world_size = os.environ.get("WORLD_SIZE")

    args = parse_args()

    # Lookup model name by model, resolution
    model_id_lookup = {
        "2.1": {
            512: "stabilityai/stable-diffusion-2-1-base",
        },
        "1.5": {
            64: "runwayml/stable-diffusion-v1-5",
            512: "runwayml/stable-diffusion-v1-5"
        }
    }

    assert args.model in model_id_lookup.keys() and \
        args.resolution in model_id_lookup[args.model].keys(), \
        f"Error: model {args.model} at resolution {args.resolution} is not yet supported!"

    model_id = model_id_lookup[args.model][args.resolution]
    args.model_id = model_id

    test_name = f"sd_{args.model}_training-{args.resolution}-batch{args.batch_size}-AdamW-{env_world_size}w-zero1_optimizer-grad_checkpointing"

    # Directory to save artifacts to, like checkpoints
    results_dir = os.path.join(curr_dir, test_name + '_results')
    os.makedirs(results_dir, exist_ok=True)
    args.results_dir = results_dir

    dist.init_process_group('xla')
    world_size = xm.xrt_world_size()

    args.world_size = world_size

    assert int(world_size) == int(env_world_size), f"Error: world_size {world_size} does not match env_world_size {env_world_size}"

    xm.master_print(f"Starting Stable Diffusion training script on Neuron, training model {model_id} with the following configuration:")
    for k, v in vars(args).items():
        xm.master_print(f"{k}: {v}")
    xm.master_print(f"World size is {world_size}")
    xm.master_print("")
    xm.master_print(f"## Neuron RT flags ##")
    xm.master_print(f"NEURON_RT_STOCHASTIC_ROUNDING_SEED: {os.getenv('NEURON_RT_STOCHASTIC_ROUNDING_SEED', None)}")
    xm.master_print(f"NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS: {os.getenv('NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS', None)}")
    xm.master_print("")
    xm.master_print(f"## XLA flags ##")
    xm.master_print(f"XLA_DOWNCAST_BF16: {os.getenv('XLA_DOWNCAST_BF16', None)}")
    
    xm.rendezvous("Entering training function")

    train(args)

    xm.rendezvous("Done training")

I run the code with this instruction.

NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=1 MALLOC_ARENA_MAX=64 XLA_DOWNCAST_BF16=1 NEURON_CC_FLAGS="--model-type=cnn-training --cache_dir=./compiler_cache -O1 --enable-saturate-infinity" neuron_parallel_compile --num_parallel=4 torchrun --nproc_per_node=32 \
  compile_test.py \
  --model 1.5 \
  --resolution 512 \
  --gradient_accumulation_steps 1 \
  --batch_size 2 \
  --save_model_epochs 1 \
  --checkpointing_steps 750

I tried it with various versions of the AWS Neuron SDK, but the result was the same.

There are many opened and closed github issues related to SD inference, but very few for SD training... so I hope this issue would help other developers working on SD training.

Any help would be greatly appreciated!

@satishpasumarthi
Copy link

Thank you for filing the issue. We will take a look and get back to you.

@geonm
Copy link
Author

geonm commented Jan 2, 2025

Thanks!

I tested the SD training sample code with pytorch 1.13 + AWS neuron on AMI (Deep Learning AMI Neuron (Ubuntu 22.04) 20241027) without any modifications but an error was the same.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants