point2emb.py (forked from notmahi/clip-fields)
import torch
from torch import Tensor, nn
import torch.nn.functional as F

POSITION_INPUT_DIMS = 3


class SinusoidalPositionalEmbeddings(nn.Module):
    def __init__(self, num_embs: int, learned: bool = False) -> None:
        super().__init__()
        self.num_embs = num_embs
        self.freqs = nn.Parameter(2.0 ** torch.arange(0, num_embs), requires_grad=learned)

    def out_dims(self) -> int:
        return 2 * self.num_embs + 1

    def forward(self, pos: Tensor) -> Tensor:
        """Applies sinusoidal embeddings to the input positions.

        Args:
            pos: Tensor with shape (..., N)

        Returns:
            Embedded positions, with shape (..., N * (2 * num_embs + 1))
        """
        pos = pos.unsqueeze(-1)
        freq_pos = self.freqs * pos
        sin_embs, cos_embs = torch.sin(freq_pos), torch.cos(freq_pos)
        return torch.cat([pos, sin_embs, cos_embs], dim=-1).flatten(-2)
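
# Illustration (not part of the original file): with num_embs = 6, out_dims() is
# 1 + 2 * 6 = 13, so a 3-D point becomes a 3 * 13 = 39-dimensional embedding
# after flatten(-2).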


def init_weights(layer: nn.Module) -> None:
    if isinstance(layer, nn.Linear):
        nn.init.xavier_uniform_(layer.weight)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
#num_layers: int = ml.conf_field(MISSING, help="Number of MLP layers for encoding position")
#hidden_dims: int = ml.conf_field(MISSING, help="Number of hidden layer dimensions")
#num_pos_embs: int = ml.conf_field(6, help="Number of positional embedding frequencies")
#output_dims: int = ml.conf_field(MISSING, help="Number of output dimensions")
#norm: str = ml.conf_field("no_norm", help="Per-layer normalization to apply")
#act: str = ml.conf_field("relu", help="Activation function to use")


def get_norm_linear(norm: str = 'no_norm', dim: int = 0) -> nn.Module:
    if norm == 'no_norm':
        return nn.Identity()
    if norm == 'layer':
        return nn.LayerNorm(dim)
    if norm == 'batch':
        return nn.BatchNorm1d(dim)
    raise ValueError(f"Unsupported norm: {norm}")


def get_activation(act: str = 'relu') -> nn.Module:
    if act == 'no_act':
        return nn.Identity()
    if act == 'relu':
        return nn.ReLU()
    if act == 'softmax':
        return nn.Softmax(dim=-1)
    raise ValueError(f"Unsupported activation: {act}")


class Point2EmbModel(nn.Module):
    def __init__(
        self,
        num_layers: int,
        hidden_dims: int,
        image_rep_size: int,
        text_rep_size: int,
        num_pos_embs: int = 6,
        norm: str = 'no_norm',
        act: str = 'relu',
    ) -> None:
        super().__init__()
        assert num_layers > 0

        # Gets the position embedding MLP.
        self.image_rep_size = image_rep_size
        self.text_rep_size = text_rep_size
        self.pos_embs = SinusoidalPositionalEmbeddings(num_pos_embs)
        pos_mlp_in_dims = POSITION_INPUT_DIMS * self.pos_embs.out_dims()
        output_dims = image_rep_size + text_rep_size

        layers: list[nn.Module] = []
        # Learnable temperature (CLIP-style logit scale), exponentiated in compute_loss.
        self.temperature = nn.Parameter(torch.log(torch.tensor(1.0 / 0.07)))
        layers += [
            nn.Sequential(
                nn.Linear(
                    pos_mlp_in_dims if i == 0 else hidden_dims,
                    output_dims if i == num_layers - 1 else hidden_dims,
                ),
                get_norm_linear(
                    "no_norm" if i == num_layers - 1 else norm,
                    dim=output_dims if i == num_layers - 1 else hidden_dims,
                ),
                get_activation(
                    "no_act" if i == num_layers - 1 else act,
                ),
            )
            for i in range(num_layers)
        ]
        self.position_mlp = nn.Sequential(*layers)

        self.apply(init_weights)

    def forward(self, points: Tensor) -> tuple[Tensor, Tensor]:
        """Simple model mapping input points to embedding vectors.

        Args:
            points: The point cloud, with shape (B, N, 3)

        Returns:
            The output embeddings for the given points: the label (text) part and
            the image part, with shapes (B, N, text_rep_size) and (B, N, image_rep_size)
        """
        # Embeds the (X, Y, Z) coordinates.
        pos_embs = self.pos_embs(points)
        preds = self.position_mlp(pos_embs)
        # Slice along the feature dimension so batched (B, N, E) outputs split correctly.
        return preds[..., :self.text_rep_size], preds[..., -self.image_rep_size:]

    def compute_loss(
        self, predicted_latents, actual_latents, label_mask=None, weights=None
    ):
        normalized_predicted_latents = F.normalize(predicted_latents, p=2, dim=-1)
        normalized_actual_latents = F.normalize(actual_latents, p=2, dim=-1)
        temp = torch.exp(self.temperature)
        sim = (
            torch.einsum(
                "i d, j d -> i j",
                normalized_predicted_latents,
                normalized_actual_latents,
            )
            * temp
        )
        # Zero out the cells where the labels are the same, as indicated by label_mask.
        if label_mask is not None:
            sim = sim * label_mask
            del label_mask
        labels = torch.arange(len(predicted_latents), device=predicted_latents.device)
        if weights is None:
            loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
        else:
            loss = (
                F.cross_entropy(sim, labels, reduction="none")
                + F.cross_entropy(sim.t(), labels, reduction="none")
            ) / 2
            loss = (loss * weights).mean()
        return loss
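

# --- Hypothetical usage sketch (not part of the original file) ---
# A minimal smoke test showing the expected shapes, assuming CLIP-sized 512-dim
# text and image embeddings; all sizes and tensors below are made up.
if __name__ == "__main__":
    model = Point2EmbModel(
        num_layers=3,
        hidden_dims=256,
        image_rep_size=512,
        text_rep_size=512,
    )
    points = torch.rand(2, 128, 3)  # (B, N, 3) point cloud
    text_preds, image_preds = model(points)  # each (B, N, 512)

    # compute_loss expects flat (M, D) batches of predicted and target latents.
    target_text_latents = torch.rand(2 * 128, 512)
    loss = model.compute_loss(
        text_preds.reshape(-1, model.text_rep_size), target_text_latents
    )
    print(text_preds.shape, image_preds.shape, loss.item())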