From 5189dddd6ea9511f1675fa49a289c3e98e76927f Mon Sep 17 00:00:00 2001 From: momo609 <963372609@qq.com> Date: Mon, 17 Jun 2024 09:21:04 +0800 Subject: [PATCH] fix three_interplote bug. --- .../ops/csrc/pytorch/npu/three_interpolate_npu.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp index 0f4590818f..42d346f7d2 100644 --- a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp @@ -12,17 +12,21 @@ void three_interpolate_forward_npu(int b, int c, int m, int n, TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf), "three_interpolate_forward ascend only support fp32 and fp16."); - auto point_c_trans = points.transpose(1, 2); - + auto point_c_trans = points.transpose(1, 2).to(at::kFloat); + auto weight_cast = weight.to(at::kFloat); + auto out_cast = out.to(at::kFloat); OpCommand cmd; cmd.Name("ThreeInterpolate") .Input(point_c_trans) .Input(idx) - .Input(weight) - .Output(out) + .Input(weight_cast) + .Output(out_cast) .Run(); - auto output = out.view({b, n, c}).transpose(1, 2); + if (originDtype == at::kHalf) { + out_cast = out_cast.to(at::kHalf); + } + auto output = out_cast.view({b, n, c}).transpose(1, 2); auto res = output.contiguous(); out.copy_(res); }