Hi, I want to ask how to use a CNN-LSTM model with DP in Opacus.
Let's assume my data consists of image frames of size [batch_size (32), num_seq (20), 1, 32, 32] and targets of size [batch_size,].
My question is how to use this model in that setting.
When I train the model, I keep receiving this error at the following line:
optimizer.step()
RuntimeError: stack expects each tensor to be equal size, but got [640] at entry 0 and [32] at entry 8
This means the first dimension of the input no longer matches the batch size (32).
I think the cause is this line in the model:
x = x.view(batch_size * seq_length, channels, height, width) # [batch_size * num_seq, channels, height, width]
This line changes the input size to [640 (32 * 20), channels, height, width].
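To illustrate (snippet added for clarity, not part of the original post), the reshape with the shapes described above looks like this:

import torch

x = torch.randn(32, 20, 1, 32, 32)   # [batch_size, num_seq, 1, 32, 32]
x = x.view(32 * 20, 1, 32, 32)
print(x.shape)                        # torch.Size([640, 1, 32, 32]) -- the first dim is no longer 32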
How can I train a DP CNN-LSTM model in this setting?
This is the model.
import torch.nn as nn
import torch.nn.functional as F

from opacus.layers import DPLSTM


class DPcnnlstm(nn.Module):
    def __init__(self, num_classes=10, num_layers=1):
        super(DPcnnlstm, self).__init__()
        # CNN layers
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(num_groups=16, num_channels=16),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(num_groups=16, num_channels=32),
            nn.Dropout(p=0.3),
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # halves spatial size: 32x32 -> 16x16 -> 8x8
        # LSTM layer
        self.lstm = DPLSTM(32 * 8 * 8, 256, num_layers=num_layers, batch_first=True)
        # Fully connected layer
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        batch_size, seq_length, channels, height, width = x.size()
        # Fold the sequence dimension into the batch dimension so the CNN processes every frame independently
        x = x.view(batch_size * seq_length, channels, height, width)  # [batch_size * num_seq, channels, height, width]
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten the CNN output back to a per-sample sequence
        x = x.view(batch_size, seq_length, -1)  # [batch_size, seq_length, features]
        # Run the full sequence through the DPLSTM
        lstm_out, _ = self.lstm(x)
        # Fully connected layer for the final classification output
        out = self.fc(lstm_out)  # [batch_size, seq_length, num_classes]
        return out
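The exact Opacus setting was not included in the thread; for completeness, a typical wrapping with PrivacyEngine would look roughly like the sketch below. The dataset, noise_multiplier, and max_grad_norm values are placeholders, not the ones used in this issue.

import torch
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

# Placeholder data with the shapes described above (the real dataset was not included)
frames = torch.randn(128, 20, 1, 32, 32)
labels = torch.randint(0, 10, (128,))
data_loader = DataLoader(TensorDataset(frames, labels), batch_size=32)

model = DPcnnlstm(num_classes=10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,  # placeholder value
    max_grad_norm=1.0,     # placeholder value
)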
Thank you for your question. The issue is indeed with the line you point out: x = x.view(batch_size * seq_length, channels, height, width).
Opacus expects the input to each module to have the batch size as its first or second dimension (the first by default). See, e.g., another issue related to batch size.
There is no workaround for this. What is the source of your model architecture? Is it a standard approach to flatten the input to shape (batch_size * seq_length, ...) and then reshape it after the CNN layers back to (batch_size, seq_length, ...)? For RNN architectures, I have not encountered this approach of essentially grouping the entire sequence into one sample.
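To see the mismatch concretely, here is a small sketch (not from the original thread) that inspects the per-sample gradient shapes Opacus computes for the model above. It assumes the DPcnnlstm class defined earlier and purely illustrative data:

import torch
import torch.nn.functional as F
from opacus.grad_sample import GradSampleModule

model = GradSampleModule(DPcnnlstm(num_classes=10))
x = torch.randn(32, 20, 1, 32, 32)   # [batch_size, num_seq, 1, 32, 32]
y = torch.randint(0, 10, (32,))      # [batch_size,]

out = model(x)                               # [32, 20, 10]
loss = F.cross_entropy(out[:, -1, :], y)     # score only the last time step
loss.backward()

for name, p in model.named_parameters():
    if hasattr(p, "grad_sample") and p.grad_sample is not None:
        print(name, tuple(p.grad_sample.shape))

# The conv/GroupNorm parameters report a leading dimension of 640 (= 32 * 20),
# because those modules see an input whose first dimension is 640, while the
# DPLSTM and fc parameters report 32. When the DP optimizer stacks the
# per-sample gradient norms in optimizer.step(), these inconsistent leading
# dimensions produce exactly the RuntimeError shown above.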