Skip to content

Commit

Permalink
Scheduling fixes on MPS (#10549)
Browse files Browse the repository at this point in the history
* use np.int32 in scheduling

* test_add_noise_device

* -np.int32, fixes
  • Loading branch information
hlky authored Jan 16, 2025
1 parent 9e1b8a0 commit 08e62fe
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/scheduling_heun_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def set_timesteps(
timesteps = torch.from_numpy(timesteps)
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])

self.timesteps = timesteps.to(device=device)
self.timesteps = timesteps.to(device=device, dtype=torch.float32)

# empty dt and derivative
self.prev_derivative = None
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/scheduling_lms_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)

self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.float32)
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
Expand Down
2 changes: 1 addition & 1 deletion tests/schedulers/test_scheduler_lcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_add_noise_device(self, num_inference_steps=10):
scaled_sample = scheduler.scale_model_input(sample, 0.0)
self.assertEqual(sample.shape, scaled_sample.shape)

noise = torch.randn_like(scaled_sample).to(torch_device)
noise = torch.randn(scaled_sample.shape).to(torch_device)
t = scheduler.timesteps[5][None]
noised = scheduler.add_noise(scaled_sample, noise, t)
self.assertEqual(noised.shape, scaled_sample.shape)
Expand Down
4 changes: 2 additions & 2 deletions tests/schedulers/test_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def model(sample, t, *args):
if isinstance(t, torch.Tensor):
num_dims = len(sample.shape)
# pad t with 1s to match num_dims
t = t.reshape(-1, *(1,) * (num_dims - 1)).to(sample.device).to(sample.dtype)
t = t.reshape(-1, *(1,) * (num_dims - 1)).to(sample.device, dtype=sample.dtype)

return sample * t / (t + 1)

Expand Down Expand Up @@ -722,7 +722,7 @@ def test_add_noise_device(self):
scaled_sample = scheduler.scale_model_input(sample, 0.0)
self.assertEqual(sample.shape, scaled_sample.shape)

noise = torch.randn_like(scaled_sample).to(torch_device)
noise = torch.randn(scaled_sample.shape).to(torch_device)
t = scheduler.timesteps[5][None]
noised = scheduler.add_noise(scaled_sample, noise, t)
self.assertEqual(noised.shape, scaled_sample.shape)
Expand Down

0 comments on commit 08e62fe

Please sign in to comment.