diff --git a/core/utils/decoder_utils.py b/core/utils/decoder_utils.py index c2eaa5b..92d20d3 100644 --- a/core/utils/decoder_utils.py +++ b/core/utils/decoder_utils.py @@ -81,7 +81,8 @@ def decode_sdf_gradient(decoder, latent_vector, points, clamp_dist=0.1, MAX_POIN points_batch = points[start:end] sdf = decode_sdf(decoder, latent_vector, points_batch, clamp_dist=clamp_dist) start = end - grad_tensor = torch.autograd.grad(outputs=sdf, inputs=points_batch, grad_outputs=torch.ones_like(points_batch), create_graph=True, retain_graph=True) + # grad_tensor = torch.autograd.grad(outputs=sdf, inputs=points_batch, grad_outputs=torch.ones_like(points_batch), create_graph=True, retain_graph=True) + grad_tensor = torch.autograd.grad(outputs=sdf, inputs=points_batch, grad_outputs=torch.ones_like(sdf), create_graph=True, retain_graph=True) grad_tensor = grad_tensor[0] if no_grad: grad_tensor = grad_tensor.detach()