[feature] support async ckpt & pin memory cache (#760)
* [feature] support async ckpt

* [feature] support pin memory cache

* [doc] update readme
ver217 authored Dec 20, 2024
1 parent 38de637 commit 786d4e7
Showing 8 changed files with 523 additions and 127 deletions.
29 changes: 19 additions & 10 deletions README.md

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions configs/opensora-v1-2/train/stage1.py
@@ -108,3 +108,6 @@
ema_decay = 0.99
adam_eps = 1e-15
warmup_steps = 1000

cache_pin_memory = True
pin_memory_cache_pre_alloc_numels = [(290 + 20) * 1024**2] * (2 * 8 + 4)
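For context, a hedged sketch of how these two settings could map onto the PinMemoryCache class added in this commit (the actual wiring happens in the training entry point, which is not part of this diff; the dtype below is an assumption):

import torch

from opensora.datasets.pin_memory_cache import PinMemoryCache

# pin_memory_cache_pre_alloc_numels pre-allocates 2 * 8 + 4 = 20 pinned buffers of
# (290 + 20) * 1024**2 = 310 Mi elements each (element counts, not bytes).
PinMemoryCache.pre_alloc_numels = [(290 + 20) * 1024**2] * (2 * 8 + 4)
PinMemoryCache.force_dtype = torch.float32  # assumption: dtype of the cached video tensors
# Note: pre-allocation only happens when force_dtype is set (see PinMemoryCache.__init__ below).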
210 changes: 206 additions & 4 deletions opensora/datasets/dataloader.py
@@ -1,17 +1,216 @@
import collections
import functools
import queue
import random
import threading
from typing import Optional

import numpy as np
import torch
import torch.multiprocessing as multiprocessing
from torch._utils import ExceptionWrapper
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, _utils
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data.dataloader import (
    IterDataPipe,
    MapDataPipe,
    _BaseDataLoaderIter,
    _MultiProcessingDataLoaderIter,
    _sharding_worker_init_fn,
    _SingleProcessDataLoaderIter,
)

from .datasets import BatchFeatureDataset, VariableVideoTextDataset, VideoTextDataset
from .pin_memory_cache import PinMemoryCache
from .sampler import BatchDistributedSampler, StatefulDistributedSampler, VariableVideoBatchSampler


def _pin_memory_loop(
    in_queue, out_queue, device_id, done_event, device, pin_memory_cache: PinMemoryCache, pin_memory_key: str
):
    # This setting is thread local, and prevents the copy in pin_memory from
    # consuming all CPU cores.
    torch.set_num_threads(1)

    if device == "cuda":
        torch.cuda.set_device(device_id)
    elif device == "xpu":
        torch.xpu.set_device(device_id)  # type: ignore[attr-defined]
    elif device == torch._C._get_privateuse1_backend_name():
        custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
        custom_device_mod.set_device(device_id)

    def do_one_step():
        try:
            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            return
        idx, data = r
        if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
            try:
                assert isinstance(data, dict)
                if pin_memory_key in data:
                    val = data[pin_memory_key]
                    pin_memory_value = pin_memory_cache.get(val)
                    pin_memory_value.copy_(val)
                    data[pin_memory_key] = pin_memory_value
            except Exception:
                data = ExceptionWrapper(where=f"in pin memory thread for device {device_id}")
            r = (idx, data)
        while not done_event.is_set():
            try:
                out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
                break
            except queue.Full:
                continue

    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    while not done_event.is_set():
        # Make sure that we don't preserve any object from one iteration
        # to the next
        do_one_step()


class _MultiProcessingDataLoaderIterForVideo(_MultiProcessingDataLoaderIter):
    pin_memory_key: str = "video"

    def __init__(self, loader):
        _BaseDataLoaderIter.__init__(self, loader)
        self.pin_memory_cache = PinMemoryCache()

        self._prefetch_factor = loader.prefetch_factor

        assert self._num_workers > 0
        assert self._prefetch_factor > 0

        if loader.multiprocessing_context is None:
            multiprocessing_context = multiprocessing
        else:
            multiprocessing_context = loader.multiprocessing_context

        self._worker_init_fn = loader.worker_init_fn

        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
        # Additional worker init function will take care of sharding in MP and Distributed
        if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
            self._worker_init_fn = functools.partial(
                _sharding_worker_init_fn, self._worker_init_fn, self._world_size, self._rank
            )

        # No certainty which module multiprocessing_context is
        self._worker_result_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
        self._worker_pids_set = False
        self._shutdown = False
        self._workers_done_event = multiprocessing_context.Event()

        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):
            # No certainty which module multiprocessing_context is
            index_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
            # Need to `cancel_join_thread` here!
            # See sections (2) and (3b) above.
            index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop,
                args=(
                    self._dataset_kind,
                    self._dataset,
                    index_queue,
                    self._worker_result_queue,
                    self._workers_done_event,
                    self._auto_collation,
                    self._collate_fn,
                    self._drop_last,
                    self._base_seed,
                    self._worker_init_fn,
                    i,
                    self._num_workers,
                    self._persistent_workers,
                    self._shared_seed,
                ),
            )
            w.daemon = True
            # NB: Process.start() actually take some time as it needs to
            # start a process and pass the arguments over via a pipe.
            # Therefore, we only add a worker to self._workers list after
            # it started, so that we do not call .join() if program dies
            # before it starts, and __del__ tries to join but will get:
            # AssertionError: can only join a started process.
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)

        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()

            # Queue is not type-annotated
            self._data_queue = queue.Queue()  # type: ignore[var-annotated]
            if self._pin_memory_device == "xpu":
                current_device = torch.xpu.current_device()  # type: ignore[attr-defined]
            elif self._pin_memory_device == torch._C._get_privateuse1_backend_name():
                custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
                current_device = custom_device_mod.current_device()
            else:
                current_device = torch.cuda.current_device()  # choose cuda for default
            pin_memory_thread = threading.Thread(
                target=_pin_memory_loop,
                args=(
                    self._worker_result_queue,
                    self._data_queue,
                    current_device,
                    self._pin_memory_thread_done_event,
                    self._pin_memory_device,
                    self.pin_memory_cache,
                    self.pin_memory_key,
                ),
            )
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
            self._pin_memory_thread = pin_memory_thread
        else:
            self._data_queue = self._worker_result_queue  # type: ignore[assignment]

        # In some rare cases, persistent workers (daemonic processes)
        # would be terminated before `__del__` of iterator is invoked
        # when main process exits
        # It would cause failure when pin_memory_thread tries to read
        # corrupted data from worker_result_queue
        # atexit is used to shutdown thread and child processes in the
        # right sequence before main process exits
        if self._persistent_workers and self._pin_memory:
            import atexit

            for w in self._workers:
                atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)

        # .pid can be None only before process is spawned (not the case, so ignore)
        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore[misc]
        _utils.signal_handling._set_SIGCHLD_handler()
        self._worker_pids_set = True
        self._reset(loader, first_iter=True)

    def remove_cache(self, output_tensor: torch.Tensor):
        self.pin_memory_cache.remove(output_tensor)

    def get_cache_info(self) -> str:
        return str(self.pin_memory_cache)


class DataloaderForVideo(DataLoader):
    def _get_iterator(self) -> "_BaseDataLoaderIter":
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIterForVideo(self)


# Deterministic dataloader
def get_seed_worker(seed):
    def seed_worker(worker_id):
@@ -35,6 +234,7 @@ def prepare_dataloader(
    bucket_config=None,
    num_bucket_build_workers=1,
    prefetch_factor=None,
    cache_pin_memory=False,
    **kwargs,
):
    _kwargs = kwargs.copy()
@@ -50,8 +250,9 @@
            verbose=True,
            num_bucket_build_workers=num_bucket_build_workers,
        )
        dl_cls = DataloaderForVideo if cache_pin_memory else DataLoader
        return (
            DataLoader(
            dl_cls(
                dataset,
                batch_sampler=batch_sampler,
                worker_init_fn=get_seed_worker(seed),
@@ -71,8 +272,9 @@
            rank=process_group.rank(),
            shuffle=shuffle,
        )
        dl_cls = DataloaderForVideo if cache_pin_memory else DataLoader
        return (
            DataLoader(
            dl_cls(
                dataset,
                batch_size=batch_size,
                sampler=sampler,
@@ -137,7 +339,7 @@ def collate_fn_batch(batch):
"""
# filter out None
batch = [x for x in batch if x is not None]

res = torch.utils.data.default_collate(batch)

# squeeze the first dimension, which is due to torch.stack() in default_collate()
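To make the new path concrete, here is a minimal, hedged usage sketch of prepare_dataloader with cache_pin_memory=True (it assumes torch.distributed is initialized and a CUDA device is available; the dataset object and batch handling are illustrative, not part of this diff):

import torch

from opensora.datasets.dataloader import prepare_dataloader

dataloader, sampler = prepare_dataloader(
    dataset,                # assumption: a VideoTextDataset built elsewhere
    batch_size=2,
    num_workers=4,
    pin_memory=True,        # required so the custom pin-memory thread is used
    cache_pin_memory=True,  # selects DataloaderForVideo instead of DataLoader
    process_group=None,     # falls back to the default process group
)
it = iter(dataloader)                          # _MultiProcessingDataLoaderIterForVideo
batch = next(it)
video = batch["video"]                         # pinned tensor taken from PinMemoryCache
video_gpu = video.to("cuda", non_blocking=True)
torch.cuda.current_stream().synchronize()      # ensure the async H2D copy finished
it.remove_cache(video)                         # hand the pinned buffer back to the cache
print(it.get_cache_info())                     # hit rate and total cache size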
76 changes: 76 additions & 0 deletions opensora/datasets/pin_memory_cache.py
@@ -0,0 +1,76 @@
import threading
from typing import Dict, List, Optional

import torch


class PinMemoryCache:
    force_dtype: Optional[torch.dtype] = None
    min_cache_numel: int = 0
    pre_alloc_numels: List[int] = []

    def __init__(self):
        self.cache: Dict[int, torch.Tensor] = {}
        self.output_to_cache: Dict[int, int] = {}
        self.cache_to_output: Dict[int, int] = {}
        self.lock = threading.Lock()
        self.total_cnt = 0
        self.hit_cnt = 0

        if len(self.pre_alloc_numels) > 0 and self.force_dtype is not None:
            for n in self.pre_alloc_numels:
                cache_tensor = torch.empty(n, dtype=self.force_dtype, device="cpu", pin_memory=True)
                with self.lock:
                    self.cache[id(cache_tensor)] = cache_tensor

    def get(self, tensor: torch.Tensor) -> torch.Tensor:
        """Receive a CPU tensor and return the corresponding pinned tensor. Note that this only
        manages memory allocation; it does not copy the content.

        Args:
            tensor (torch.Tensor): The tensor to be pinned.

        Returns:
            torch.Tensor: The pinned tensor.
        """
        self.total_cnt += 1
        with self.lock:
            # find free cache
            for cache_id, cache_tensor in self.cache.items():
                if cache_id not in self.cache_to_output and cache_tensor.numel() >= tensor.numel():
                    target_cache_tensor = cache_tensor[: tensor.numel()].view(tensor.shape)
                    out_id = id(target_cache_tensor)
                    self.output_to_cache[out_id] = cache_id
                    self.cache_to_output[cache_id] = out_id
                    self.hit_cnt += 1
                    return target_cache_tensor
        # no free cache, create a new one
        dtype = self.force_dtype if self.force_dtype is not None else tensor.dtype
        cache_numel = max(tensor.numel(), self.min_cache_numel)
        cache_tensor = torch.empty(cache_numel, dtype=dtype, device="cpu", pin_memory=True)
        target_cache_tensor = cache_tensor[: tensor.numel()].view(tensor.shape)
        out_id = id(target_cache_tensor)
        with self.lock:
            self.cache[id(cache_tensor)] = cache_tensor
            self.output_to_cache[out_id] = id(cache_tensor)
            self.cache_to_output[id(cache_tensor)] = out_id
        return target_cache_tensor

    def remove(self, output_tensor: torch.Tensor) -> None:
        """Release the corresponding cache tensor.

        Args:
            output_tensor (torch.Tensor): The tensor to be released.
        """
        out_id = id(output_tensor)
        with self.lock:
            if out_id not in self.output_to_cache:
                raise ValueError("Tensor not found in cache.")
            cache_id = self.output_to_cache.pop(out_id)
            del self.cache_to_output[cache_id]

    def __str__(self):
        with self.lock:
            num_cached = len(self.cache)
            num_used = len(self.output_to_cache)
            total_cache_size = sum([v.numel() * v.element_size() for v in self.cache.values()])
        return f"PinMemoryCache(num_cached={num_cached}, num_used={num_used}, total_cache_size={total_cache_size / 1024**3:.2f} GB, hit rate={self.hit_cnt / self.total_cnt:.2f})"
6 changes: 0 additions & 6 deletions opensora/models/text_encoder/t5.py
@@ -170,14 +170,8 @@ def shardformer_t5(self):
        from opensora.utils.misc import requires_grad

        shard_config = ShardConfig(
            tensor_parallel_process_group=None,
            pipeline_stage_manager=None,
            enable_tensor_parallelism=False,
            enable_fused_normalization=False,
            enable_flash_attention=False,
            enable_jit_fused=True,
            enable_sequence_parallelism=False,
            enable_sequence_overlap=False,
        )
        shard_former = ShardFormer(shard_config=shard_config)
        optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy())