diff --git a/models/s4/s4.py b/models/s4/s4.py index caec2a8..deb5341 100644 --- a/models/s4/s4.py +++ b/models/s4/s4.py @@ -1375,9 +1375,9 @@ def forward(self, state=None, rate=1.0, L=None): v = v * dt # Incorporate dt into B # Dispatch which Cauchy kernel to use - if has_cuda_extension and z.dtype == torch.cfloat and z.device.type == 'cuda' and self.kernel == 'cuda': + if has_cuda_extension and z.dtype == torch.cfloat and z.device.type == 'cuda' and self.backend == 'cuda': cauchy_mult = cauchy_cuda - elif has_pykeops and self.kernel in ['cuda', 'keops']: + elif has_pykeops and self.backend in ['cuda', 'keops']: cauchy_mult = cauchy_keops else: cauchy_mult = cauchy_naive