diff --git a/dpr_scale/task/citadel_task.py b/dpr_scale/task/citadel_task.py index 07eaf7d..6f1eea6 100644 --- a/dpr_scale/task/citadel_task.py +++ b/dpr_scale/task/citadel_task.py @@ -4,6 +4,8 @@ import torch from pytorch_lightning.strategies import DDPShardedStrategy, DDPStrategy from dpr_scale.task.dpr_task import DenseRetrieverTask +from torch.cuda.amp import autocast + class MultiVecRetrieverTask(DenseRetrieverTask): def __init__( @@ -144,9 +146,15 @@ def sim_score(self, query_repr, context_repr, mask=None, pairwise=False): if mask is not None: scores[mask] = float("-inf") else: - scores = torch.matmul( - query_repr, torch.transpose(context_repr, 0, 1) - ) # num_q x num_ctx + if self.trainer.precision == 16: + with autocast(enabled=False): + scores = torch.matmul( + query_repr, torch.transpose(context_repr, 0, 1) + ) # num_q x num_ctx + else: + scores = torch.matmul( + query_repr, torch.transpose(context_repr, 0, 1) + ) # num_q x num_ctx if mask is not None: # mask is size num_ctx scores[mask.repeat(scores.size(0), 1)] = float("-inf")