Skip to content

Commit

Permalink
Fix an issue with working with gpu
Browse files Browse the repository at this point in the history
The variable ab needs to be registered to buffer to avoid an error with incompatible devices.
  • Loading branch information
ksugar committed Jul 2, 2023
1 parent 42e87b2 commit 48f418b
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion mobile_sam/modeling/tiny_vit_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,9 @@ def train(self, mode=True):
if mode and hasattr(self, 'ab'):
del self.ab
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
self.register_buffer('ab',
self.attention_biases[:, self.attention_bias_idxs],
persistent=False)

def forward(self, x): # x (B,N,C)
B, N, _ = x.shape
Expand Down

0 comments on commit 48f418b

Please sign in to comment.