From e757cef57d89e448c413de7325ed5601aceaac13 Mon Sep 17 00:00:00 2001 From: Albert Gu Date: Wed, 21 Feb 2024 20:08:10 +0000 Subject: [PATCH] Fix 'backend' usage in DPLR kernel --- models/s4/s4.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/s4/s4.py b/models/s4/s4.py index caec2a8..deb5341 100644 --- a/models/s4/s4.py +++ b/models/s4/s4.py @@ -1375,9 +1375,9 @@ def forward(self, state=None, rate=1.0, L=None): v = v * dt # Incorporate dt into B # Dispatch which Cauchy kernel to use - if has_cuda_extension and z.dtype == torch.cfloat and z.device.type == 'cuda' and self.kernel == 'cuda': + if has_cuda_extension and z.dtype == torch.cfloat and z.device.type == 'cuda' and self.backend == 'cuda': cauchy_mult = cauchy_cuda - elif has_pykeops and self.kernel in ['cuda', 'keops']: + elif has_pykeops and self.backend in ['cuda', 'keops']: cauchy_mult = cauchy_keops else: cauchy_mult = cauchy_naive