diff --git a/mlx_vlm/models/deepseek_vl_v2/language.py b/mlx_vlm/models/deepseek_vl_v2/language.py index 0d904f3..774b6ab 100644 --- a/mlx_vlm/models/deepseek_vl_v2/language.py +++ b/mlx_vlm/models/deepseek_vl_v2/language.py @@ -408,9 +408,9 @@ def __call__(self, x): # Calculate group scores using top-2 sum per group scores_reshaped = scores_for_choice.reshape(bsz * seq_len, self.n_group, -1) - k = 2 - group_scores_topk = mx.sort(scores_reshaped, axis=-1)[..., -k:] - group_scores = group_scores_topk.sum(axis=-1) + + # Get top 2 scores per group + group_scores = mx.topk(scores_reshaped, 2, axis=-1).sum(axis=-1) # Get top groups k = self.n_group - self.topk_group