Fix CLIP pos embedding interpolation to work on DTensors (#1739)
ebsmothers authored Oct 2, 2024
1 parent bae4b27 commit 7cf656b
Showing 1 changed file with 57 additions and 0 deletions.
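Note: the diff below applies the same gather-interpolate-reshard pattern in three places (the local token, global token, and tile positional embeddings). For context, here is a minimal sketch of that pattern; the helper name _interpolate_maybe_dtensor and the interpolate_fn callback are illustrative only and are not part of the commit.

from torch.distributed._tensor import DTensor, distribute_tensor


def _interpolate_maybe_dtensor(pos_embed, interpolate_fn):
    # F.interpolate only accepts vanilla tensors, so if the positional
    # embedding arrives as a DTensor we gather it to a full tensor first,
    # interpolate, and then reshard with the original mesh and placements.
    is_sharded = isinstance(pos_embed, DTensor)
    if is_sharded:
        device_mesh = pos_embed.device_mesh
        placements = pos_embed.placements
        pos_embed = pos_embed.full_tensor()

    pos_embed = interpolate_fn(pos_embed)

    if is_sharded:
        pos_embed = distribute_tensor(
            pos_embed, device_mesh=device_mesh, placements=placements
        )
    return pos_embed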
57 changes: 57 additions & 0 deletions torchtune/models/clip/_position_embeddings.py
@@ -10,6 +10,7 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributed._tensor import distribute_tensor, DTensor


class TokenPositionalEmbedding(nn.Module):
@@ -137,8 +138,20 @@ def _load_state_dict_hook(
inpt_local_pos_embed = state_dict.get(
prefix + "local_token_positional_embedding"
)

if inpt_local_pos_embed is not None:

# We can only apply F.interpolate to vanilla tensors, not DTensors
# If pos embeds are a DTensor, we gather the full tensor, apply
# interpolate, and then reshard after
if isinstance(inpt_local_pos_embed, DTensor):
local_embed_is_sharded = True
local_embed_device_mesh = inpt_local_pos_embed.device_mesh
local_embed_placements = inpt_local_pos_embed.placements
inpt_local_pos_embed = inpt_local_pos_embed.full_tensor()
else:
local_embed_is_sharded = False

# sanity check
inpt_n_tokens_per_tile, inpt_embed_dim = inpt_local_pos_embed.shape
if math.sqrt(inpt_n_tokens_per_tile - 1) % 1 != 0:
@@ -159,6 +172,13 @@ def _load_state_dict_hook(
tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
)

if local_embed_is_sharded:
inpt_local_pos_embed = distribute_tensor(
inpt_local_pos_embed,
device_mesh=local_embed_device_mesh,
placements=local_embed_placements,
)

# update state dict
state_dict[
prefix + "local_token_positional_embedding"
@@ -176,8 +196,20 @@ def _load_state_dict_hook(
inpt_global_pos_embed = state_dict.get(
prefix + "global_token_positional_embedding"
)

if inpt_global_pos_embed is not None:

# We can only apply F.interpolate to vanilla tensors, not DTensors
# If pos embeds are a DTensor, we gather the full tensor, apply
# interpolate, and then reshard after
if isinstance(inpt_global_pos_embed, DTensor):
global_embed_is_sharded = True
global_embed_device_mesh = inpt_global_pos_embed.device_mesh
global_embed_placements = inpt_global_pos_embed.placements
inpt_global_pos_embed = inpt_global_pos_embed.full_tensor()
else:
global_embed_is_sharded = False

_, _, inpt_n_tokens_per_tile, _ = inpt_global_pos_embed.shape

# sanity check
@@ -202,6 +234,13 @@ def _load_state_dict_hook(
tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
)

if global_embed_is_sharded:
inpt_global_pos_embed = distribute_tensor(
inpt_global_pos_embed,
device_mesh=global_embed_device_mesh,
placements=global_embed_placements,
)

# update state dict
state_dict[
prefix + "global_token_positional_embedding"
@@ -500,6 +539,17 @@ def _load_state_dict_hook(

if embedding is not None:

# We can only apply F.interpolate to vanilla tensors, not DTensors
# If pos embeds are a DTensor, we gather the full tensor, apply
# interpolate, and then reshard after
if isinstance(embedding, DTensor):
embedding_is_sharded = True
device_mesh = embedding.device_mesh
placements = embedding.placements
embedding = embedding.full_tensor()
else:
embedding_is_sharded = False

# ckpt pos emb
(
tgt_max_num_tiles_x,
@@ -534,6 +584,13 @@ def _load_state_dict_hook(
embedding, tgt_max_num_tiles=tgt_max_num_tiles_x
)

if embedding_is_sharded:
embedding_new = distribute_tensor(
embedding_new,
device_mesh=device_mesh,
placements=placements,
)

# update state dict
state_dict[prefix + "embedding"] = embedding_new
if embedding_new.shape != self.embedding.shape:

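For reference, a standalone round trip of the DTensor APIs the hook relies on (assumes a distributed run, e.g. launched with torchrun, so the process group can initialize; the shapes and the 1-D mesh are arbitrary examples):

import torch
import torch.distributed as dist

from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# toy stand-in for a positional embedding of shape (n_tokens_per_tile, embed_dim)
full = torch.randn(1601, 1280)
sharded = distribute_tensor(full, device_mesh=mesh, placements=[Shard(0)])

gathered = sharded.full_tensor()  # plain torch.Tensor, safe to pass to F.interpolate
resharded = distribute_tensor(
    gathered, device_mesh=mesh, placements=sharded.placements
)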