Skip to content

Commit

Permalink
[Enhance] Adapt test cases on Ascend NPU. (#1728)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ginray authored Jul 28, 2023
1 parent 4d1dbaf commit c5248b1
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 9 deletions.
3 changes: 2 additions & 1 deletion tests/test_engine/test_hooks/test_densecl_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch
import torch.nn as nn
from mmengine.device import get_device
from mmengine.logging import MMLogger
from mmengine.model import BaseModule
from mmengine.optim import OptimWrapper
Expand Down Expand Up @@ -79,7 +80,7 @@ def tearDown(self):
self.temp_dir.cleanup()

def test_densecl_hook(self):
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = get_device()
dummy_dataset = DummyDataset()
toy_model = ToyModel().to(device)
densecl_hook = DenseCLHook(start_iters=1)
Expand Down
5 changes: 3 additions & 2 deletions tests/test_engine/test_hooks/test_ema_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import torch
import torch.nn as nn
from mmengine.device import get_device
from mmengine.evaluator import Evaluator
from mmengine.logging import MMLogger
from mmengine.model import BaseModel
Expand Down Expand Up @@ -70,7 +71,7 @@ def tearDown(self):
self.temp_dir.cleanup()

def test_load_state_dict(self):
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = get_device()
model = SimpleModel().to(device)
ema_hook = EMAHook()
runner = Runner(
Expand All @@ -95,7 +96,7 @@ def test_load_state_dict(self):

def test_evaluate_on_ema(self):

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = get_device()
model = SimpleModel().to(device)

# Test validate on ema model
Expand Down
3 changes: 2 additions & 1 deletion tests/test_engine/test_hooks/test_simsiam_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch
import torch.nn as nn
from mmengine.device import get_device
from mmengine.logging import MMLogger
from mmengine.model import BaseModule
from mmengine.runner import Runner
Expand Down Expand Up @@ -79,7 +80,7 @@ def tearDown(self):
self.temp_dir.cleanup()

def test_simsiam_hook(self):
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = get_device()
dummy_dataset = DummyDataset()
toy_model = ToyModel().to(device)
simsiam_hook = SimSiamHook(
Expand Down
3 changes: 2 additions & 1 deletion tests/test_engine/test_hooks/test_swav_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch
import torch.nn as nn
from mmengine.device import get_device
from mmengine.logging import MMLogger
from mmengine.model import BaseModule
from mmengine.optim import OptimWrapper
Expand Down Expand Up @@ -86,7 +87,7 @@ def tearDown(self):
self.temp_dir.cleanup()

def test_swav_hook(self):
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = get_device()
dummy_dataset = DummyDataset()
toy_model = ToyModel().to(device)
swav_hook = SwAVHook(
Expand Down
9 changes: 5 additions & 4 deletions tests/test_engine/test_hooks/test_switch_recipe_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch.nn as nn
from mmcv.transforms import Compose
from mmengine.dataset import BaseDataset, ConcatDataset, RepeatDataset
from mmengine.device import get_device
from mmengine.logging import MMLogger
from mmengine.model import BaseDataPreprocessor, BaseModel
from mmengine.optim import OptimWrapper
Expand Down Expand Up @@ -130,7 +131,7 @@ def test_init(self):
self.assertIsNone(hook.schedule[1]['batch_augments'])

def test_do_switch(self):
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = get_device()
model = SimpleModel().to(device)

loss = CrossEntropyLoss(use_soft=True)
Expand Down Expand Up @@ -205,7 +206,7 @@ def test_do_switch(self):
# runner.train()

def test_resume(self):
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = get_device()
model = SimpleModel().to(device)

loss = CrossEntropyLoss(use_soft=True)
Expand Down Expand Up @@ -275,7 +276,7 @@ def test_resume(self):
logs.output)

def test_switch_train_pipeline(self):
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = get_device()
model = SimpleModel().to(device)

runner = Runner(
Expand Down Expand Up @@ -324,7 +325,7 @@ def test_switch_train_pipeline(self):
pipeline)

def test_switch_loss(self):
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = get_device()
model = SimpleModel().to(device)

runner = Runner(
Expand Down

0 comments on commit c5248b1

Please sign in to comment.